Skip to content

Commit

Permalink
Merge pull request #68 from IdentityPython/develop
Browse files Browse the repository at this point in the history
Prepare release 1.3.0
  • Loading branch information
jschlyter authored Sep 11, 2020
2 parents 3f24ac8 + dcad904 commit 46df6c3
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/cryptojwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
except ImportError:
pass

__version__ = "1.2.0"
__version__ = "1.3.0"

logger = logging.getLogger(__name__)

Expand Down
4 changes: 0 additions & 4 deletions src/cryptojwt/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ class UpdateFailed(KeyIOError):
pass


class UnknownKeytype(Invalid):
"""An unknown key type"""


class JWKException(JWKESTException):
pass

Expand Down
73 changes: 50 additions & 23 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import time
from datetime import datetime
from functools import cmp_to_key

import requests
Expand Down Expand Up @@ -156,6 +157,7 @@ def __init__(
keys=None,
source="",
cache_time=300,
ignore_errors_period=0,
fileformat="jwks",
keytype="RSA",
keyusage=None,
Expand Down Expand Up @@ -188,6 +190,8 @@ def __init__(
self.remote = False
self.local = False
self.cache_time = cache_time
self.ignore_errors_period = ignore_errors_period
self.ignore_errors_until = None # UNIX timestamp of last error
self.time_out = 0
self.etag = ""
self.source = None
Expand Down Expand Up @@ -314,7 +318,11 @@ def do_local_jwk(self, filename):
Load a JWKS from a local file
:param filename: Name of the file from which the JWKS should be loaded
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False

LOGGER.info("Reading local JWKS from %s", filename)
with open(filename) as input_file:
_info = json.load(input_file)
Expand All @@ -324,6 +332,7 @@ def do_local_jwk(self, filename):
self.do_keys([_info])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_local_der(self, filename, keytype, keyusage=None, kid=""):
"""
Expand All @@ -332,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
:param filename: Name of the file
:param keytype: Presently 'rsa' and 'ec' supported
:param keyusage: encryption ('enc') or signing ('sig') or both
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False

LOGGER.info("Reading local DER from %s", filename)
key_args = {}
_kty = keytype.lower()
Expand All @@ -355,16 +368,25 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
self.do_keys([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_remote(self):
"""
Load a JWKS from a webpage.
:return: True or False if load was successful
:return: True if load was successful or False if remote hasn't been modified
"""
# if self.verify_ssl is not None:
# self.httpc_params["verify"] = self.verify_ssl

if self.ignore_errors_until and time.time() < self.ignore_errors_until:
LOGGER.warning(
"Not reading remote JWKS from %s (in error holddown until %s)",
self.source,
datetime.fromtimestamp(self.ignore_errors_until),
)
return False

LOGGER.info("Reading remote JWKS from %s", self.source)
try:
LOGGER.debug("KeyBundle fetch keys from: %s", self.source)
Expand All @@ -378,7 +400,10 @@ def do_remote(self):
LOGGER.error(err)
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))

if _http_resp.status_code == 200: # New content
load_successful = _http_resp.status_code == 200
not_modified = _http_resp.status_code == 304

if load_successful:
self.time_out = time.time() + self.cache_time

self.imp_jwks = self._parse_remote_response(_http_resp)
Expand All @@ -390,25 +415,27 @@ def do_remote(self):
self.do_keys(self.imp_jwks["keys"])
except KeyError:
LOGGER.error("No 'keys' keyword in JWKS")
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(MALFORMED.format(self.source))

if hasattr(_http_resp, "headers"):
headers = getattr(_http_resp, "headers")
self.last_remote = headers.get("last-modified") or headers.get("date")

elif _http_resp.status_code == 304: # Not modified
elif not_modified:
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
self.time_out = time.time() + self.cache_time

else:
LOGGER.warning(
"HTTP status %d reading remote JWKS from %s",
_http_resp.status_code,
self.source,
)
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))

self.last_updated = time.time()
return True
self.ignore_errors_until = None
return load_successful

def _parse_remote_response(self, response):
"""
Expand All @@ -433,23 +460,20 @@ def _parse_remote_response(self, response):
return None

def _uptodate(self):
res = False
if self.remote or self.local:
if time.time() > self.time_out:
if self.local and not self._local_update_required():
res = True
elif self.update():
res = True
return res
return self.update()
return False

def update(self):
"""
Reload the keys if necessary.
This is a forced update, will happen even if cache time has not elapsed.
Replaced keys will be marked as inactive and not removed.
:return: True if update was ok or False if we encountered an error during update.
"""
res = True # An update was successful
if self.source:
_old_keys = self._keys # just in case

Expand All @@ -459,24 +483,27 @@ def update(self):
try:
if self.local:
if self.fileformat in ["jwks", "jwk"]:
self.do_local_jwk(self.source)
updated = self.do_local_jwk(self.source)
elif self.fileformat == "der":
self.do_local_der(self.source, self.keytype, self.keyusage)
updated = self.do_local_der(self.source, self.keytype, self.keyusage)
elif self.remote:
res = self.do_remote()
updated = self.do_remote()
except Exception as err:
LOGGER.error("Key bundle update failed: %s", err)
self._keys = _old_keys # restore
return False

now = time.time()
for _key in _old_keys:
if _key not in self._keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
if updated:
now = time.time()
for _key in _old_keys:
if _key not in self._keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
else:
self._keys = _old_keys

return res
return True

def get(self, typ="", only_active=True):
"""
Expand Down
45 changes: 44 additions & 1 deletion tests/test_03_key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from cryptojwt.jwk.rsa import import_rsa_key_from_cert_file
from cryptojwt.jwk.rsa import new_rsa_key
from cryptojwt.key_bundle import KeyBundle
from cryptojwt.key_bundle import UpdateFailed
from cryptojwt.key_bundle import build_key_bundle
from cryptojwt.key_bundle import dump_jwks
from cryptojwt.key_bundle import init_key
Expand Down Expand Up @@ -566,6 +567,7 @@ def test_update_2():
ec_key = new_ec_key(crv="P-256", key_ops=["sign"])
_jwks = {"keys": [rsa_key.serialize(), ec_key.serialize()]}

time.sleep(0.5)
with open(fname, "w") as fp:
fp.write(json.dumps(_jwks))

Expand Down Expand Up @@ -1008,7 +1010,7 @@ def test_remote_not_modified():

with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, status=304, headers=headers)
assert kb.do_remote()
assert not kb.do_remote()
assert kb.last_remote == headers.get("Last-Modified")
timeout2 = kb.time_out

Expand All @@ -1018,9 +1020,50 @@ def test_remote_not_modified():
kb2 = KeyBundle().load(exp)
assert kb2.source == source
assert len(kb2.keys()) == 3
assert len(kb2.active_keys()) == 3
assert len(kb2.get("rsa")) == 1
assert len(kb2.get("oct")) == 1
assert len(kb2.get("ec")) == 1
assert kb2.httpc_params == {"timeout": (2, 2)}
assert kb2.imp_jwks
assert kb2.last_updated


def test_ignore_errors_period():
source_good = "https://example.com/keys.json"
source_bad = "https://example.com/keys-bad.json"
ignore_errors_period = 1
# Mock response
with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source_good, json=JWKS_DICT, status=200)
rsps.add(method="GET", url=source_bad, json=JWKS_DICT, status=500)
httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds
kb = KeyBundle(
source=source_good,
httpc=requests.request,
httpc_params=httpc_params,
ignore_errors_period=ignore_errors_period,
)
res = kb.do_remote()
assert res == True
assert kb.ignore_errors_until is None

# refetch, but fail by using a bad source
kb.source = source_bad
try:
res = kb.do_remote()
except UpdateFailed:
pass

# retry should fail silently as we're in holddown
res = kb.do_remote()
assert kb.ignore_errors_until is not None
assert res == False

# wait until holddown
time.sleep(ignore_errors_period + 1)

# try again
kb.source = source_good
res = kb.do_remote()
assert res == True
6 changes: 6 additions & 0 deletions tests/test_04_key_jar.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,12 @@ def test_aud(self):
keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer)
assert len(keys) == 1

def test_inactive_verify_key(self):
_jwt = factory(self.sjwt_b)
self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive()
keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt)
assert len(keys) == 0


def test_copy():
kj = KeyJar()
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ envlist = py{36,37,38},quality
[testenv]
passenv = CI TRAVIS TRAVIS_*
commands =
py.test --cov=cryptojwt --isort --black {posargs}
pytest -vvv -ra --cov=cryptojwt --isort --black {posargs}
codecov
extras = testing
deps =
Expand Down

0 comments on commit 46df6c3

Please sign in to comment.