Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for endpoint_url for local dynamodb table #300

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 52 additions & 39 deletions credstash.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,14 @@ def paddedInt(i):
return (pad * "0") + i_str


def getHighestVersion(name, region=None, table="credential-store",
def getHighestVersion(name, region=None, endpoint_url=None, table="credential-store",
**kwargs):
'''
Return the highest version of `name` in the table
'''
session = get_session(**kwargs)

dynamodb = session.resource('dynamodb', region_name=region)
dynamodb = session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
secrets = dynamodb.Table(table)

response = secrets.query(Limit=1,
Expand Down Expand Up @@ -285,15 +285,14 @@ def clean_error(*args, **kwargs):
return clean_error

@clean_fail
def listSecrets(region=None, table="credential-store", session=None, **kwargs):
def listSecrets(region=None, table="credential-store", endpoint_url=None, session=None, **kwargs):
'''
do a full-table scan of the credential-store,
and return the names and versions of every credential
'''
if session is None:
session = get_session(**kwargs)

dynamodb = session.resource('dynamodb', region_name=region)
dynamodb = session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
secrets = dynamodb.Table(table)

items = []
Expand All @@ -315,7 +314,7 @@ def listSecrets(region=None, table="credential-store", session=None, **kwargs):

@clean_fail
def putSecret(name, secret, version="", kms_key="alias/credstash",
region=None, table="credential-store", context=None,
region=None, endpoint_url=None, table="credential-store", context=None,
digest=DEFAULT_DIGEST, comment="", kms=None, dynamodb=None,
kms_region=None, **kwargs):
'''
Expand All @@ -328,7 +327,7 @@ def putSecret(name, secret, version="", kms_key="alias/credstash",
if dynamodb is None or kms is None:
session = get_session(**kwargs)
if dynamodb is None:
dynamodb = session.resource('dynamodb', region_name=region)
dynamodb = session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
if kms is None:
kms = session.client('kms', region_name=kms_region or region)

Expand All @@ -338,7 +337,6 @@ def putSecret(name, secret, version="", kms_key="alias/credstash",
secret,
digest_method=digest,
)

secrets = dynamodb.Table(table)

data = {
Expand All @@ -348,12 +346,11 @@ def putSecret(name, secret, version="", kms_key="alias/credstash",
if comment:
data['comment'] = comment
data.update(sealed)

return secrets.put_item(Item=data, ConditionExpression=Attr('name').not_exists())


def putSecretAutoversion(name, secret, kms_key="alias/credstash",
region=None, table="credential-store", context=None,
region=None, endpoint_url=None, table="credential-store", context=None,
digest=DEFAULT_DIGEST, comment="", kms_region=None, **kwargs):
"""
This function put secrets to credstash using autoversioning
Expand All @@ -364,24 +361,24 @@ def putSecretAutoversion(name, secret, kms_key="alias/credstash",
incremented_version = paddedInt(int(latest_version) + 1)
try:
putSecret(name=name, secret=secret, version=incremented_version,
kms_key=kms_key, region=region, kms_region=kms_region,
kms_key=kms_key, region=region, endpoint_url=endpoint_url, kms_region=kms_region,
table=table, context=context, digest=digest, comment=comment, **kwargs)
print("Secret '{0}' has been stored in table {1}".format(name, table))
except KmsError as e:
fatal(e)


def getAllSecrets(version="", region=None, table="credential-store",
def getAllSecrets(version="", region=None, endpoint_url=None, table="credential-store",
context=None, credential=None, session=None,
kms_region=None, **kwargs):
'''
fetch and decrypt all secrets
'''
if session is None:
session = get_session(**kwargs)
dynamodb = session.resource('dynamodb', region_name=region)
dynamodb = session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
kms = session.client('kms', region_name=kms_region or region)
secrets = listSecrets(region, table, session, **kwargs)
secrets = listSecrets(region, table, endpoint_url, session, **kwargs)

# Only return the secrets that match the pattern in `credential`
# This already works out of the box with the CLI get action,
Expand Down Expand Up @@ -415,9 +412,10 @@ def getAllSecrets(version="", region=None, table="credential-store",


@clean_fail
def getAllAction(args, region, kms_region, **session_params):
def getAllAction(args, region, endpoint_url, kms_region, **session_params):
secrets = getAllSecrets(args.version,
region=region,
endpoint_url=endpoint_url,
kms_region=kms_region,
table=args.table,
context=args.context,
Expand All @@ -440,10 +438,11 @@ def getAllAction(args, region, kms_region, **session_params):


@clean_fail
def putSecretAction(args, region, kms_region, **session_params):
def putSecretAction(args, region, endpoint_url, kms_region, **session_params):
if args.autoversion:
latestVersion = getHighestVersion(args.credential,
region,
endpoint_url,
args.table,
**session_params)
try:
Expand All @@ -458,15 +457,15 @@ def putSecretAction(args, region, kms_region, **session_params):
if(args.prompt):
value = getpass("{}: ".format(args.credential))
if putSecret(args.credential, value, version=version,
kms_key=args.key, region=region, kms_region=kms_region,
kms_key=args.key, region=region, endpoint_url=endpoint_url, kms_region=kms_region,
table=args.table, context=args.context, digest=args.digest,
comment=args.comment, **session_params):
print("{0} has been stored".format(args.credential))
except KmsError as e:
fatal(e)
except botocore.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "ConditionalCheckFailedException":
latestVersion = getHighestVersion(args.credential, region,
latestVersion = getHighestVersion(args.credential, region, endpoint_url,
args.table,
**session_params)
fatal("%s version %s is already in the credential store. "
Expand All @@ -477,7 +476,7 @@ def putSecretAction(args, region, kms_region, **session_params):


@clean_fail
def putAllSecretsAction(args, region, kms_region, **session_params):
def putAllSecretsAction(args, region, endpoint_url, kms_region, **session_params):
credentials = json.loads(args.credentials)

for credential, value in credentials.items():
Expand All @@ -486,26 +485,28 @@ def putAllSecretsAction(args, region, kms_region, **session_params):
args.value = value
args.comment = None
args.prompt = None
putSecretAction(args, region, kms_region, **session_params)
putSecretAction(args, region, endpoint_url, kms_region, **session_params)
except SystemExit as e:
pass


@clean_fail
def getSecretAction(args, region, kms_region, **session_params):
def getSecretAction(args, region, endpoint_url, kms_region, **session_params):
try:
if WILDCARD_CHAR in args.credential:
names = expand_wildcard(args.credential,
[x["name"]
for x
in listSecrets(region=region,
endpoint_url=endpoint_url,
table=args.table,
**session_params)])
secrets = {
name:getSecret(
name,
version=args.version,
region=region,
endpoint_url=endpoint_url,
kms_region=kms_region,
table=args.table,
context=args.context,
Expand Down Expand Up @@ -533,7 +534,8 @@ def getSecretAction(args, region, kms_region, **session_params):
sys.stdout.write(getSecret(
args.credential,
version=args.version,
region=region,
region=region,
endpoint_url=endpoint_url,
kms_region=kms_region,
table=args.table,
context=args.context,
Expand All @@ -549,7 +551,7 @@ def getSecretAction(args, region, kms_region, **session_params):
fatal(e)

@clean_fail
def getSecret(name, version="", region=None, table="credential-store", context=None,
def getSecret(name, version="", region=None, endpoint_url=None, table="credential-store", context=None,
dynamodb=None, kms=None, kms_region=None, **kwargs):
'''
fetch and decrypt the secret called `name`
Expand All @@ -561,7 +563,7 @@ def getSecret(name, version="", region=None, table="credential-store", context=N
if dynamodb is None or kms is None:
session = get_session(**kwargs)
if dynamodb is None:
dynamodb = session.resource('dynamodb', region_name=region)
dynamodb = session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
if kms is None:
kms = session.client('kms', region_name=kms_region or region)

Expand Down Expand Up @@ -591,10 +593,10 @@ def getSecret(name, version="", region=None, table="credential-store", context=N


@clean_fail
def deleteSecrets(name, region=None, table="credential-store",
def deleteSecrets(name, endpoint_url=None, region=None, table="credential-store",
**kwargs):
session = get_session(**kwargs)
dynamodb = session.resource('dynamodb', region_name=region)
dynamodb = session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
secrets = dynamodb.Table(table)

response = {'LastEvaluatedKey': None}
Expand Down Expand Up @@ -652,12 +654,12 @@ def writeConfig(options):


@clean_fail
def createDdbTable(region=None, table="credential-store", tags=None, **kwargs):
def createDdbTable(region=None, endpoint_url=None, table="credential-store", tags=None, **kwargs):
'''
create the secret store table in DDB in the specified region
'''
session = get_session(**kwargs)
dynamodb = session.resource("dynamodb", region_name=region)
dynamodb = session.resource("dynamodb", region_name=region, endpoint_url=endpoint_url)
if table in (t.name for t in dynamodb.tables.all()):
print("Credential Store table already exists")
return
Expand Down Expand Up @@ -692,7 +694,7 @@ def createDdbTable(region=None, table="credential-store", tags=None, **kwargs):
)

print("Waiting for table to be created...")
client = session.client("dynamodb", region_name=region)
client = session.client("dynamodb", region_name=region, endpoint_url=endpoint_url)

response = client.describe_table(TableName=table)

Expand Down Expand Up @@ -842,9 +844,10 @@ def get_digest(digest):


@clean_fail
def list_credentials(region, args, **session_params):
def list_credentials(region, args, endpoint_url, **session_params):
credential_list = listSecrets(region=region,
table=args.table,
endpoint_url=endpoint_url,
**session_params)
if credential_list:
# print list of credential names and versions,
Expand All @@ -859,9 +862,10 @@ def list_credentials(region, args, **session_params):


@clean_fail
def list_credential_keys(region, args, **session_params):
def list_credential_keys(region, endpoint_url, args, **session_params):
credential_list = listSecrets(region=region,
table=args.table,
endpoint_url=endpoint_url,
**session_params)
if credential_list:
creds = sorted(set(cred["name"] for cred in credential_list))
Expand Down Expand Up @@ -907,6 +911,14 @@ def get_parser():
"CREDSTASH_DEFAULT_TABLE env variable, "
"or if that is not set, the value "
"`credential-store` will be used")
parsers['super'].add_argument("--endpoint_url", default=os.environ.get("DYNAMODB_ENDPOINT_URL", None),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super-small issue - could you please change this to --endpoint-url to be consistent with the other args?

Suggested change
parsers['super'].add_argument("--endpoint_url", default=os.environ.get("DYNAMODB_ENDPOINT_URL", None),
parsers['super'].add_argument("--endpoint-url", default=os.environ.get("DYNAMODB_ENDPOINT_URL", None),

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

help="DynamoDB endpoint to use for credential storage. "
"If not specified, credstash "
"will use the value of the "
"DYNAMODB_ENDPOINT_URL env variable, "
"or if that is not set, the value "
"`None` will be used, "
"which will auto-generate the dynamodb url.")
parsers['super'].add_argument("--log-level",
help="Set the log level, default WARNING",
default='WARNING'
Expand Down Expand Up @@ -1100,8 +1112,9 @@ def main():
# test for region
try:
region = args.region
endpoint_url = args.endpoint_url
session = get_session(**session_params)
session.resource('dynamodb', region_name=region)
session.resource('dynamodb', region_name=region, endpoint_url=endpoint_url)
except botocore.exceptions.NoRegionError:
if 'AWS_DEFAULT_REGION' not in os.environ:
region = DEFAULT_REGION
Expand All @@ -1117,27 +1130,27 @@ def main():
**session_params)
return
if args.action == "list":
list_credentials(region, args, **session_params)
list_credentials(region, args, endpoint_url, **session_params)
return
if args.action == "keys":
list_credential_keys(region, args, **session_params)
list_credential_keys(region, endpoint_url, args, **session_params)
return
if args.action == "put":
putSecretAction(args, region, kms_region, **session_params)
putSecretAction(args, region, endpoint_url, kms_region, **session_params)
return
if args.action == "putall":
putAllSecretsAction(args, region, kms_region, **session_params)
putAllSecretsAction(args, region, endpoint_url, kms_region, **session_params)
return
if args.action == "get":
getSecretAction(args, region, kms_region, **session_params)
getSecretAction(args, region, endpoint_url, kms_region, **session_params)
return
if args.action == "getall":
getAllAction(args, region, kms_region, **session_params)
getAllAction(args, region, endpoint_url, kms_region, **session_params)
return
if args.action == "setup":
if args.save_kms_region:
setKmsRegion(args)
createDdbTable(region=region, table=args.table,
createDdbTable(region=region, table=args.table, endpoint_url=endpoint_url,
tags=args.tags, **session_params)
return
else:
Expand Down