diff --git a/src/josepy/jwk.py b/src/josepy/jwk.py index 7a4abf86..8e507e3d 100644 --- a/src/josepy/jwk.py +++ b/src/josepy/jwk.py @@ -1,6 +1,5 @@ """JSON Web Key.""" import abc -import base64 import json import logging import math @@ -407,38 +406,45 @@ class JWKOKP(JWK): ) required = ('crv', JWK.type_field_name, 'x') - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args, **kwargs): if 'key' in kwargs and not isinstance(kwargs['key'], util.ComparableOKPKey): kwargs['key'] = util.ComparableOKPKey(kwargs['key']) super().__init__(*args, **kwargs) - def public_key(self) -> Union[ - ed25519.Ed25519PublicKey, ed448.Ed448PublicKey, - x25519.X25519PublicKey, x448.X448PublicKey, - ]: - return self._wrapped.__class__.public_key() + def public_key(self): + return self.key._wrapped.__class__.public_key() + + def _key_to_crv(self): + if isinstance(self.key._wrapped, (ed25519.Ed25519PrivateKey, ed25519.Ed25519PrivateKey)): + return "Ed25519" + elif isinstance(self.key._wrapped, (ed448.Ed448PrivateKey, ed448.Ed448PrivateKey)): + return "Ed448" + elif isinstance(self.key._wrapped, (x25519.X25519PrivateKey, x25519.X25519PrivateKey)): + return "X25519" + elif isinstance(self.key._wrapped, (x448.X448PrivateKey, x448.X448PrivateKey)): + return "X448" + return NotImplemented def fields_to_partial_json(self) -> Dict: - params = {} # type: Dict + params = {} + print(dir(self)) if self.key.is_private(): - params['d'] = base64.b64encode(self.key.private_bytes( + params['d'] = json_util.encode_b64jose(self.key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption() )) params['x'] = self.key.public_key().public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) else: - params['x'] = base64.b64decode(self.key.public_bytes( + params['x'] = json_util.encode_b64jose(self.key.public_bytes( serialization.Encoding.Raw, serialization.PublicFormat.Raw, serialization.NoEncryption(), )) - # TODO find a better way to get the curve name - params['crv'] = 'ed25519' + params['crv'] = self._key_to_crv() return params @classmethod @@ -463,12 +469,12 @@ def fields_from_json(cls, jobj) -> ComparableOKPKey: if "x" not in obj: raise errors.DeserializationError('OKP should have "x" parameter') - x = base64.b64decode(jobj.get("x")) + x = json_util.decode_b64jose(jobj.get("x")) try: if "d" not in obj: return jobj["key"]._wrapped.__class__.from_public_bytes(x) # noqa - d = base64.b64decode(obj.get("d")) + d = json_util.decode_b64jose(obj.get("d")) return jobj["key"]._wrapped.__class__.from_private_bytes(d) # noqa except ValueError as err: raise errors.DeserializationError("Invalid key parameter") from err diff --git a/src/josepy/jwk_test.py b/src/josepy/jwk_test.py index fe37b71e..8ba31747 100644 --- a/src/josepy/jwk_test.py +++ b/src/josepy/jwk_test.py @@ -342,7 +342,6 @@ def setUp(self): self.private = self.x448_key self.jwk = self.private # Test vectors taken from - # self.jwked25519json = { 'kty': 'OKP', 'crv': 'Ed25519', @@ -368,36 +367,31 @@ def setUp(self): 'x': 'jjQtV-fA7J_tK8dPzYq7jRPNjF8r5p6LW2R25S2Gw5U', } -# def test_encode_ed448(self): -# from josepy.jwk import JWKOKP -# import josepy -# data = """-----BEGIN PRIVATE KEY----- -# MEcCAQAwBQYDK2VxBDsEOfqsAFWdop10FFPW7Ha2tx2AZh0Ii+jfL2wFXU/dY/fe -# iU7/vrGmQ+ux26NkgzfploOHZjEmltLJ9w== -# -----END PRIVATE KEY-----""" -# key = JWKOKP.load(data) -# data = key.to_partial_json() -# # key = JWKOKP.load(data) -# x = josepy.json_util.decode_b64jose(data['x']) -# self.assertEqual(len(x), 64) + def test_encode_ed448(self): + from josepy.jwk import JWKOKP + import josepy + data = b"""-----BEGIN PRIVATE KEY----- +MEcCAQAwBQYDK2VxBDsEOfqsAFWdop10FFPW7Ha2tx2AZh0Ii+jfL2wFXU/dY/fe +iU7/vrGmQ+ux26NkgzfploOHZjEmltLJ9w== +-----END PRIVATE KEY-----""" + key = JWKOKP.load(data) + data = key.to_partial_json() + x = josepy.json_util.encode_b64jose(data['x']) + self.assertEqual(len(x), 195) def test_encode_ed25519(self): import josepy from josepy.jwk import JWKOKP data = b"""-----BEGIN PRIVATE KEY----- - MC4CAQAwBQYDK2VwBCIEIPIAha9VqyHHpY1GtEW8JXWqLU5mrPRhXPwJqCtL3bWZ - -----END PRIVATE KEY-----""" +MC4CAQAwBQYDK2VwBCIEIPIAha9VqyHHpY1GtEW8JXWqLU5mrPRhXPwJqCtL3bWZ +-----END PRIVATE KEY-----""" key = JWKOKP.load(data) data = key.to_partial_json() - print(data) - # key = jwk.JWKOKP.load(data) - # y = josepy.json_util.decode_b64jose(data['x']) - # self.assertEqual(len(y), 64) - - # def test_init_auto_comparable(self): - # self.assertIsInstance( - # self.jwk256_not_comparable.key, util.ComparableECKey) - # self.assertEqual(self.jwk256, self.jwk256_not_comparable) + x = josepy.json_util.encode_b64jose(data['x']) + self.assertEqual(len(x), 151) + + def test_init_auto_comparable(self): + self.assertIsInstance(self.x448_key.key, util.ComparableOKPKey) def test_unknown_crv_name(self): from josepy.jwk import JWK