Skip to content

Commit 1bf46d6

Browse files
committed
Fixed some dump/load problems.
1 parent cd543a9 commit 1bf46d6

File tree

4 files changed

+129
-90
lines changed

4 files changed

+129
-90
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 75 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,26 @@ def ec_init(spec):
153153
class KeyBundle:
154154
"""The Key Bundle"""
155155

156+
params = {
157+
"cache_time": 0,
158+
"etag": "",
159+
"fileformat": "jwks",
160+
"httpc_params": {},
161+
"ignore_errors_period": 0,
162+
"ignore_errors_until": None,
163+
"ignore_invalid_keys": True,
164+
"imp_jwks": None,
165+
"keytype": "RSA",
166+
"keyusage": None,
167+
"last_local": None,
168+
"last_remote": None,
169+
"last_updated": 0,
170+
"local": False,
171+
"remote": False,
172+
"source": None,
173+
"time_out": 0,
174+
}
175+
156176
def __init__(
157177
self,
158178
keys=None,
@@ -491,6 +511,7 @@ def update(self):
491511

492512
# reread everything
493513
self._keys = []
514+
updated = None
494515

495516
try:
496517
if self.local:
@@ -753,66 +774,67 @@ def difference(self, bundle):
753774
return [k for k in self._keys if k not in bundle]
754775

755776
def dump(self, exclude_attributes: Optional[List[str]] = None):
756-
_keys = []
757-
for _k in self._keys:
758-
_ser = _k.to_dict()
759-
if _k.inactive_since:
760-
_ser["inactive_since"] = _k.inactive_since
761-
_keys.append(_ser)
762-
763-
res = {
764-
"keys": _keys,
765-
"cache_time": self.cache_time,
766-
"etag": self.etag,
767-
"fileformat": self.fileformat,
768-
"httpc_params": self.httpc_params,
769-
"ignore_errors_period": self.ignore_errors_period,
770-
"ignore_errors_until": self.ignore_errors_until,
771-
"ignore_invalid_keys": self.ignore_invalid_keys,
772-
"imp_jwks": self.imp_jwks,
773-
"keytype": self.keytype,
774-
"keyusage": self.keyusage,
775-
"last_local": self.last_local,
776-
"last_remote": self.last_remote,
777-
"last_updated": self.last_updated,
778-
"local": self.local,
779-
"remote": self.remote,
780-
"time_out": self.time_out,
781-
}
782-
783-
if self.source:
784-
res["source"] = self.source
785-
786-
if exclude_attributes:
787-
for attr in exclude_attributes:
788-
try:
789-
del res[attr]
790-
except KeyError:
791-
pass
777+
if exclude_attributes is None:
778+
exclude_attributes = []
779+
780+
res = {}
781+
782+
if "keys" not in exclude_attributes:
783+
_keys = []
784+
for _k in self._keys:
785+
_ser = _k.to_dict()
786+
if _k.inactive_since:
787+
_ser["inactive_since"] = _k.inactive_since
788+
_keys.append(_ser)
789+
res["keys"] = _keys
790+
791+
for attr, default in self.params.items():
792+
if attr in exclude_attributes:
793+
continue
794+
val = getattr(self, attr)
795+
res[attr] = val
796+
797+
# res = {
798+
# "cache_time": self.cache_time,
799+
# "etag": self.etag,
800+
# "fileformat": self.fileformat,
801+
# "httpc_params": self.httpc_params,
802+
# "ignore_errors_period": self.ignore_errors_period,
803+
# "ignore_errors_until": self.ignore_errors_until,
804+
# "ignore_invalid_keys": self.ignore_invalid_keys,
805+
# "imp_jwks": self.imp_jwks,
806+
# "keytype": self.keytype,
807+
# "keyusage": self.keyusage,
808+
# "last_local": self.last_local,
809+
# "last_remote": self.last_remote,
810+
# "last_updated": self.last_updated,
811+
# "local": self.local,
812+
# "remote": self.remote,
813+
# "time_out": self.time_out,
814+
# }
815+
816+
# if self.source:
817+
# res["source"] = self.source
792818

793819
return res
794820

795821
def load(self, spec):
822+
"""
823+
Sets attributes according to a specification.
824+
Does not overwrite an existing attributes value with a default value.
825+
826+
:param spec: Dictionary with attributes and value to populate the instance with
827+
:return: The instance itself
828+
"""
796829
_keys = spec.get("keys", [])
797830
if _keys:
798831
self.do_keys(_keys)
799-
self.cache_time = spec.get("cache_time", 0)
800-
self.etag = spec.get("etag", "")
801-
self.fileformat = spec.get("fileformat", "jwks")
802-
self.httpc_params = spec.get("httpc_params", {})
803-
self.ignore_errors_period = spec.get("ignore_errors_period", 0)
804-
self.ignore_errors_until = spec.get("ignore_errors_until", None)
805-
self.ignore_invalid_keys = spec.get("ignore_invalid_keys", True)
806-
self.imp_jwks = spec.get("imp_jwks", None)
807-
self.keytype = (spec.get("keytype", "RSA"),)
808-
self.keyusage = (spec.get("keyusage", None),)
809-
self.last_local = spec.get("last_local", None)
810-
self.last_remote = spec.get("last_remote", None)
811-
self.last_updated = spec.get("last_updated", 0)
812-
self.local = spec.get("local", False)
813-
self.remote = spec.get("remote", False)
814-
self.source = spec.get("source", None)
815-
self.time_out = spec.get("time_out", 0)
832+
833+
for attr, default in self.params.items():
834+
val = spec.get(attr)
835+
if val:
836+
setattr(self, attr, val)
837+
816838
return self
817839

818840
def flush(self):

src/cryptojwt/key_issuer.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,22 @@
1717

1818
__author__ = "Roland Hedberg"
1919

20-
2120
logger = logging.getLogger(__name__)
2221

2322

2423
class KeyIssuer(object):
2524
""" A key issuer instance contains a number of KeyBundles. """
2625

26+
params = {
27+
"ca_certs": None,
28+
"keybundle_cls": KeyBundle,
29+
"remove_after": 3600,
30+
"httpc": None,
31+
"httpc_params": None,
32+
"name": "",
33+
"spec2key": None,
34+
}
35+
2736
def __init__(
2837
self,
2938
ca_certs=None,
@@ -360,27 +369,23 @@ def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict:
360369
:return: A dictionary
361370
"""
362371

363-
_bundles = []
364-
for kb in self._bundles:
365-
_bundles.append(kb.dump(exclude_attributes=exclude_attributes))
366-
367-
info = {
368-
"name": self.name,
369-
"bundles": _bundles,
370-
"keybundle_cls": qualified_name(self.keybundle_cls),
371-
"spec2key": self.spec2key,
372-
"ca_certs": self.ca_certs,
373-
"remove_after": self.remove_after,
374-
"httpc_params": self.httpc_params,
375-
}
376-
377-
# remove after the fact
378-
if exclude_attributes:
379-
for attr in exclude_attributes:
380-
try:
381-
del info[attr]
382-
except KeyError:
383-
pass
372+
if exclude_attributes is None:
373+
exclude_attributes = []
374+
375+
info = {}
376+
for attr, default in self.params.items():
377+
if attr in exclude_attributes:
378+
continue
379+
val = getattr(self, attr)
380+
if attr == "keybundle_cls":
381+
val = qualified_name(val)
382+
info[attr] = val
383+
384+
if "bundles" not in exclude_attributes:
385+
_bundles = []
386+
for kb in self._bundles:
387+
_bundles.append(kb.dump(exclude_attributes=exclude_attributes))
388+
info["bundles"] = _bundles
384389

385390
return info
386391

@@ -390,24 +395,20 @@ def load(self, info):
390395
:param items: A dictionary with the information to load
391396
:return:
392397
"""
393-
self.name = info["name"]
394-
self.keybundle_cls = importer(info["keybundle_cls"])
395-
self.spec2key = info["spec2key"]
396-
self.ca_certs = info["ca_certs"]
397-
self.remove_after = info["remove_after"]
398-
self.httpc_params = info["httpc_params"]
398+
for attr, default in self.params.items():
399+
val = info.get(attr)
400+
if val:
401+
if attr == "keybundle_cls":
402+
val = importer(val)
403+
setattr(self, attr, val)
404+
399405
self._bundles = [KeyBundle().load(val) for val in info["bundles"]]
400406
return self
401407

402408
def flush(self):
403-
self.ca_certs = (None,)
404-
self.keybundle_cls = (KeyBundle,)
405-
self.remove_after = (3600,)
406-
self.httpc = (None,)
407-
self.httpc_params = (None,)
408-
self.name = ""
409-
self.spec2key = None
410-
self.remove_after = 0
409+
for attr, default in self.params.items():
410+
setattr(self, attr, default)
411+
411412
self._bundles = []
412413
return self
413414

src/cryptojwt/serialize/item.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class KeyIssuer:
77
@staticmethod
88
def serialize(item: key_issuer.KeyIssuer) -> str:
99
""" Convert from KeyIssuer to JSON """
10-
return json.dumps(item.dump())
10+
return json.dumps(item.dump(exclude_attributes=["keybundle_cls", "httpc"]))
1111

1212
def deserialize(self, spec: str) -> key_issuer.KeyIssuer:
1313
""" Convert from JSON to KeyIssuer """

tests/test_03_key_bundle.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,7 @@ def test_export_inactive():
968968
"last_local",
969969
"remote",
970970
"local",
971+
"source",
971972
"time_out",
972973
}
973974

@@ -1085,3 +1086,18 @@ def test_ignore_invalid_keys():
10851086

10861087
with pytest.raises(UnknownKeyType):
10871088
KeyBundle(keys={"keys": [rsa_key_dict]}, ignore_invalid_keys=False)
1089+
1090+
1091+
def test_exclude_attributes():
1092+
source = "https://example.com/keys.json"
1093+
# Mock response
1094+
with responses.RequestsMock() as rsps:
1095+
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)
1096+
httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds
1097+
kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params)
1098+
kb.do_remote()
1099+
1100+
exp = kb.dump(exclude_attributes=["cache_time", "ignore_invalid_keys"])
1101+
kb2 = KeyBundle(cache_time=600, ignore_invalid_keys=False).load(exp)
1102+
assert kb2.cache_time == 600
1103+
assert kb2.ignore_invalid_keys is False

0 commit comments

Comments
 (0)