Skip to content

Commit

Permalink
Merge pull request #80 from IdentityPython/develop
Browse files Browse the repository at this point in the history
Changes as an effect of changing persistent storage model.
  • Loading branch information
rohe authored Mar 21, 2021
2 parents a54f653 + 99f3780 commit e1eabfd
Show file tree
Hide file tree
Showing 8 changed files with 286 additions and 107 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ exclude_lines = [

[tool.poetry]
name = "cryptojwt"
version = "1.4.1"
version = "1.5.0"
description = "Python implementation of JWT, JWE, JWS and JWK"
authors = ["Roland Hedberg <[email protected]>"]
license = "Apache-2.0"
Expand Down
148 changes: 103 additions & 45 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import time
from datetime import datetime
from functools import cmp_to_key
from typing import List
from typing import Optional

import requests

Expand All @@ -24,7 +26,6 @@
from .jwk.jwk import dump_jwk
from .jwk.jwk import import_jwk
from .jwk.rsa import RSAKey
from .jwk.rsa import import_private_rsa_key_from_file
from .jwk.rsa import new_rsa_key
from .utils import as_unicode

Expand Down Expand Up @@ -152,6 +153,26 @@ def ec_init(spec):
class KeyBundle:
"""The Key Bundle"""

params = {
"cache_time": 0,
"etag": "",
"fileformat": "jwks",
"httpc_params": {},
"ignore_errors_period": 0,
"ignore_errors_until": None,
"ignore_invalid_keys": True,
"imp_jwks": None,
"keytype": "RSA",
"keyusage": None,
"last_local": None,
"last_remote": None,
"last_updated": 0,
"local": False,
"remote": False,
"source": None,
"time_out": 0,
}

def __init__(
self,
keys=None,
Expand Down Expand Up @@ -189,22 +210,22 @@ def __init__(
"""

self._keys = []
self.remote = False
self.local = False
self.cache_time = cache_time
self.ignore_errors_period = ignore_errors_period
self.ignore_errors_until = None # UNIX timestamp of last error
self.time_out = 0
self.etag = ""
self.source = None
self.fileformat = fileformat.lower()
self.ignore_errors_period = ignore_errors_period
self.ignore_errors_until = None # UNIX timestamp of last error
self.ignore_invalid_keys = ignore_invalid_keys
self.imp_jwks = None
self.keytype = keytype
self.keyusage = keyusage
self.imp_jwks = None
self.last_updated = 0
self.last_remote = None # HTTP Date of last remote update
self.last_local = None # UNIX timestamp of last local update
self.ignore_invalid_keys = ignore_invalid_keys
self.last_remote = None # HTTP Date of last remote update
self.last_updated = 0
self.local = False
self.remote = False
self.source = None
self.time_out = 0

if httpc:
self.httpc = httpc
Expand Down Expand Up @@ -490,6 +511,7 @@ def update(self):

# reread everything
self._keys = []
updated = None

try:
if self.local:
Expand Down Expand Up @@ -751,48 +773,68 @@ def difference(self, bundle):

return [k for k in self._keys if k not in bundle]

def dump(self):
_keys = []
for _k in self._keys:
_ser = _k.to_dict()
if _k.inactive_since:
_ser["inactive_since"] = _k.inactive_since
_keys.append(_ser)

res = {
"keys": _keys,
"fileformat": self.fileformat,
"last_updated": self.last_updated,
"last_remote": self.last_remote,
"last_local": self.last_local,
"httpc_params": self.httpc_params,
"remote": self.remote,
"local": self.local,
"imp_jwks": self.imp_jwks,
"time_out": self.time_out,
"cache_time": self.cache_time,
}
def dump(self, exclude_attributes: Optional[List[str]] = None):
if exclude_attributes is None:
exclude_attributes = []

if self.source:
res["source"] = self.source
res = {}

if "keys" not in exclude_attributes:
_keys = []
for _k in self._keys:
_ser = _k.to_dict()
if _k.inactive_since:
_ser["inactive_since"] = _k.inactive_since
_keys.append(_ser)
res["keys"] = _keys

for attr, default in self.params.items():
if attr in exclude_attributes:
continue
val = getattr(self, attr)
res[attr] = val

return res

def load(self, spec):
"""
Sets attributes according to a specification.
Does not overwrite an existing attributes value with a default value.
:param spec: Dictionary with attributes and value to populate the instance with
:return: The instance itself
"""
_keys = spec.get("keys", [])
if _keys:
self.do_keys(_keys)
self.source = spec.get("source", None)
self.fileformat = spec.get("fileformat", "jwks")
self.last_updated = spec.get("last_updated", 0)
self.last_remote = spec.get("last_remote", None)
self.last_local = spec.get("last_local", None)
self.remote = spec.get("remote", False)
self.local = spec.get("local", False)
self.imp_jwks = spec.get("imp_jwks", None)
self.time_out = spec.get("time_out", 0)
self.cache_time = spec.get("cache_time", 0)
self.httpc_params = spec.get("httpc_params", {})

for attr, default in self.params.items():
val = spec.get(attr)
if val:
setattr(self, attr, val)

return self

def flush(self):
self._keys = []
self.cache_time = (300,)
self.etag = ""
self.fileformat = "jwks"
# self.httpc=None,
self.httpc_params = (None,)
self.ignore_errors_period = 0
self.ignore_errors_until = None
self.ignore_invalid_keys = True
self.imp_jwks = None
self.keytype = ("RSA",)
self.keyusage = (None,)
self.last_local = None # UNIX timestamp of last local update
self.last_remote = None # HTTP Date of last remote update
self.last_updated = 0
self.local = False
self.remote = False
self.source = None
self.time_out = 0
return self


Expand Down Expand Up @@ -1246,3 +1288,19 @@ def init_key(filename, type, kid="", **kwargs):
_new_key = key_gen(type, kid=kid, **kwargs)
dump_jwk(filename, _new_key)
return _new_key


def key_by_alg(alg: str):
if alg.startswith("RS"):
return key_gen("RSA", alg="RS256")
elif alg.startswith("ES"):
if alg == "ES256":
return key_gen("EC", crv="P-256")
elif alg == "ES384":
return key_gen("EC", crv="P-384")
elif alg == "ES512":
return key_gen("EC", crv="P-521")
elif alg.startswith("HS"):
return key_gen("sym")

raise ValueError("Don't know who to create a key to use with '{}'".format(alg))
77 changes: 50 additions & 27 deletions src/cryptojwt/key_issuer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import logging
import os
from typing import List
from typing import Optional

from requests import request

Expand All @@ -15,13 +17,21 @@

__author__ = "Roland Hedberg"


logger = logging.getLogger(__name__)


class KeyIssuer(object):
""" A key issuer instance contains a number of KeyBundles. """

params = {
"ca_certs": None,
"httpc_params": None,
"keybundle_cls": KeyBundle,
"name": "",
"remove_after": 3600,
"spec2key": None,
}

def __init__(
self,
ca_certs=None,
Expand All @@ -45,14 +55,13 @@ def __init__(

self._bundles = []

self.keybundle_cls = keybundle_cls
self.name = name

self.spec2key = {}
self.ca_certs = ca_certs
self.remove_after = remove_after
self.httpc = httpc or request
self.httpc_params = httpc_params or {}
self.keybundle_cls = keybundle_cls
self.name = name
self.remove_after = remove_after
self.spec2key = {}

def __repr__(self) -> str:
return '<KeyIssuer "{}" {}>'.format(self.name, self.key_summary())
Expand Down Expand Up @@ -350,43 +359,57 @@ def __len__(self):
nr += len(kb)
return nr

def dump(self, exclude=None):
def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict:
"""
Returns the content as a dictionary.
:param exclude_attributes: List of attribute names for objects that should be ignored.
:return: A dictionary
"""

_bundles = []
for kb in self._bundles:
_bundles.append(kb.dump())

info = {
"name": self.name,
"bundles": _bundles,
"keybundle_cls": qualified_name(self.keybundle_cls),
"spec2key": self.spec2key,
"ca_certs": self.ca_certs,
"remove_after": self.remove_after,
"httpc_params": self.httpc_params,
}
if exclude_attributes is None:
exclude_attributes = []

info = {}
for attr, default in self.params.items():
if attr in exclude_attributes:
continue
val = getattr(self, attr)
if attr == "keybundle_cls":
val = qualified_name(val)
info[attr] = val

if "bundles" not in exclude_attributes:
_bundles = []
for kb in self._bundles:
_bundles.append(kb.dump(exclude_attributes=exclude_attributes))
info["bundles"] = _bundles

return info

def load(self, info):
"""
:param items: A list with the information
:param items: A dictionary with the information to load
:return:
"""
self.name = info["name"]
self.keybundle_cls = importer(info["keybundle_cls"])
self.spec2key = info["spec2key"]
self.ca_certs = info["ca_certs"]
self.remove_after = info["remove_after"]
self.httpc_params = info["httpc_params"]
for attr, default in self.params.items():
val = info.get(attr)
if val:
if attr == "keybundle_cls":
val = importer(val)
setattr(self, attr, val)

self._bundles = [KeyBundle().load(val) for val in info["bundles"]]
return self

def flush(self):
for attr, default in self.params.items():
setattr(self, attr, default)

self._bundles = []
return self

def update(self):
for kb in self._bundles:
kb.update()
Expand Down
Loading

0 comments on commit e1eabfd

Please sign in to comment.