Skip to content

Commit 8d85731

Browse files
committed
ensure we only update keys once
1 parent 87a9908 commit 8d85731

File tree

2 files changed

+44
-49
lines changed

2 files changed

+44
-49
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 29 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ def __init__(
263263
if self.local:
264264
self._keys = self._do_local(kid)
265265

266+
266267
def _set_source(self, source, fileformat):
267268
if source.startswith("file://"):
268269
self.source = source[7:]
@@ -284,10 +285,10 @@ def _set_source(self, source, fileformat):
284285

285286
def _do_local(self, kid):
286287
if self.fileformat in ["jwks", "jwk"]:
287-
updated, res = self._do_local_jwk(self.source)
288+
updated, keys = self._do_local_jwk(self.source)
288289
elif self.fileformat == "der":
289-
updated, res = self._do_local_der(self.source, self.keytype, self.keyusage, kid)
290-
return res
290+
updated, keys = self._do_local_der(self.source, self.keytype, self.keyusage, kid)
291+
return keys
291292

292293
def _local_update_required(self) -> bool:
293294
stat = os.stat(self.source)
@@ -311,14 +312,9 @@ def add_jwk_dicts(self, keys):
311312
:param keys: List of JWK dictionaries
312313
:return:
313314
"""
314-
self._add_jwk_dicts(keys)
315+
self._keys.extend(self.jwk_dicts_as_keys(keys))
315316
self.last_updated = time.time()
316317

317-
def _add_jwk_dicts(self, keys):
318-
_new_keys = self.jwk_dicts_as_keys(keys)
319-
if _new_keys:
320-
self._keys.extend(_new_keys)
321-
322318
def jwk_dicts_as_keys(self, keys):
323319
"""
324320
Return JWK dictionaries as list of JWK objects
@@ -392,13 +388,13 @@ def _do_local_jwk(self, filename):
392388
with open(filename) as input_file:
393389
_info = json.load(input_file)
394390
if "keys" in _info:
395-
res = self.jwk_dicts_as_keys(_info["keys"])
391+
new_keys = self.jwk_dicts_as_keys(_info["keys"])
396392
else:
397-
res = self.jwk_dicts_as_keys([_info])
393+
new_keys = self.jwk_dicts_as_keys([_info])
398394

399395
self.last_local = time.time()
400396
self.time_out = self.last_local + self.cache_time
401-
return True, res
397+
return True, new_keys
402398

403399
def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
404400
"""
@@ -431,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
431427
if kid:
432428
key_args["kid"] = kid
433429

434-
res = self.jwk_dicts_as_keys([key_args])
430+
new_keys = self.jwk_dicts_as_keys([key_args])
435431
self.last_local = time.time()
436432
self.time_out = self.last_local + self.cache_time
437-
return True, res
433+
return True, new_keys
438434

439-
def _do_remote(self):
435+
def _do_remote(self, set_keys=True):
440436
"""
441437
Load a JWKS from a webpage.
442438
@@ -451,7 +447,7 @@ def _do_remote(self):
451447
self.source,
452448
datetime.fromtimestamp(self.ignore_errors_until),
453449
)
454-
return False
450+
return False, None
455451

456452
LOGGER.info("Reading remote JWKS from %s", self.source)
457453
try:
@@ -500,11 +496,12 @@ def _do_remote(self):
500496
self.ignore_errors_until = time.time() + self.ignore_errors_period
501497
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))
502498

503-
if new_keys is not None:
499+
if set_keys and new_keys:
504500
self._keys = new_keys
501+
505502
self.last_updated = time.time()
506503
self.ignore_errors_until = None
507-
return load_successful
504+
return load_successful, new_keys
508505

509506
def _parse_remote_response(self, response):
510507
"""
@@ -545,38 +542,31 @@ def update(self):
545542
:return: True if update was ok or False if we encountered an error during update.
546543
"""
547544
if self.source:
548-
_old_keys = self._keys # just in case
549-
550-
# reread everything
551-
self._keys = []
545+
new_keys = []
552546
updated = None
553547

554548
try:
555549
if self.local:
556550
if self.fileformat in ["jwks", "jwk"]:
557551
updated, k = self._do_local_jwk(self.source)
558-
if k:
559-
self._keys.extend(k)
560552
elif self.fileformat == "der":
561553
updated, k = self._do_local_der(self.source, self.keytype, self.keyusage)
562-
if k:
563-
self._keys.extend(k)
564554
elif self.remote:
565-
updated = self._do_remote()
555+
updated, k = self._do_remote(set_keys=False)
556+
if k:
557+
new_keys.extend(k)
566558
except Exception as err:
567559
LOGGER.error("Key bundle update failed: %s", err)
568-
self._keys = _old_keys # restore
569560
return False
570561

571562
if updated:
572563
now = time.time()
573-
for _key in _old_keys:
574-
if _key not in self._keys:
564+
for _key in self._keys:
565+
if _key not in new_keys:
575566
if not _key.inactive_since: # If already marked don't mess
576567
_key.inactive_since = now
577-
self._keys.append(_key)
578-
else:
579-
self._keys = _old_keys
568+
new_keys.append(_key)
569+
self._keys = new_keys
580570

581571
return True
582572

@@ -592,9 +582,9 @@ def get(self, typ="", only_active=True):
592582

593583
if typ:
594584
_typs = [typ.lower(), typ.upper()]
595-
_keys = [k for k in self._keys[:] if k.kty in _typs]
585+
_keys = [k for k in self._keys if k.kty in _typs]
596586
else:
597-
_keys = self._keys[:]
587+
_keys = self._keys
598588

599589
if only_active:
600590
return [k for k in _keys if not k.inactive_since]
@@ -609,7 +599,7 @@ def keys(self, update: bool = True):
609599
"""
610600
if update:
611601
self._uptodate()
612-
return self._keys[:]
602+
return self._keys
613603

614604
def active_keys(self):
615605
"""Return the set of active keys."""
@@ -836,9 +826,11 @@ def load(self, spec):
836826
:param spec: Dictionary with attributes and value to populate the instance with
837827
:return: The instance itself
838828
"""
829+
839830
_keys = spec.get("keys", [])
840831
if _keys:
841-
self._add_jwk_dicts(_keys)
832+
self._keys.extend(self.jwk_dicts_as_keys(_keys))
833+
self.last_updated = time.time()
842834

843835
for attr, default in self.params.items():
844836
val = spec.get(attr)

tests/test_03_key_bundle.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ def test_httpc_params_1():
480480
rsps.add(method=responses.GET, url=source, json=JWKS_DICT, status=200)
481481
httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds
482482
kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params)
483-
assert kb._do_remote()
483+
updated, _ = kb._do_remote()
484+
assert updated == True
484485

485486

486487
@pytest.mark.network
@@ -920,7 +921,7 @@ def test_export_inactive():
920921

921922

922923
def test_remote():
923-
source = "https://example.com/keys.json"
924+
source = "https://example.com/test_remote/keys.json"
924925
# Mock response
925926
with responses.RequestsMock() as rsps:
926927
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)
@@ -941,7 +942,7 @@ def test_remote():
941942

942943

943944
def test_remote_not_modified():
944-
source = "https://example.com/keys.json"
945+
source = "https://example.com/test_remote_not_modified/keys.json"
945946
headers = {
946947
"Date": "Fri, 15 Mar 2019 10:14:25 GMT",
947948
"Last-Modified": "Fri, 1 Jan 1970 00:00:00 GMT",
@@ -954,13 +955,15 @@ def test_remote_not_modified():
954955

955956
with responses.RequestsMock() as rsps:
956957
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers)
957-
assert kb._do_remote()
958+
updated, _ = kb._do_remote()
959+
assert updated == True
958960
assert kb.last_remote == headers.get("Last-Modified")
959961
timeout1 = kb.time_out
960962

961963
with responses.RequestsMock() as rsps:
962964
rsps.add(method="GET", url=source, status=304, headers=headers)
963-
assert not kb._do_remote()
965+
updated, _ = kb._do_remote()
966+
assert not updated
964967
assert kb.last_remote == headers.get("Last-Modified")
965968
timeout2 = kb.time_out
966969

@@ -980,8 +983,8 @@ def test_remote_not_modified():
980983

981984

982985
def test_ignore_errors_period():
983-
source_good = "https://example.com/keys.json"
984-
source_bad = "https://example.com/keys-bad.json"
986+
source_good = "https://example.com/test_ignore_errors_period/keys.json"
987+
source_bad = "https://example.com/test_ignore_errors_period/keys-bad.json"
985988
ignore_errors_period = 1
986989
# Mock response
987990
with responses.RequestsMock() as rsps:
@@ -994,19 +997,19 @@ def test_ignore_errors_period():
994997
httpc_params=httpc_params,
995998
ignore_errors_period=ignore_errors_period,
996999
)
997-
res = kb._do_remote()
1000+
res, _ = kb._do_remote()
9981001
assert res == True
9991002
assert kb.ignore_errors_until is None
10001003

10011004
# refetch, but fail by using a bad source
10021005
kb.source = source_bad
10031006
try:
1004-
res = kb._do_remote()
1007+
res, _ = kb._do_remote()
10051008
except UpdateFailed:
10061009
pass
10071010

10081011
# retry should fail silently as we're in holddown
1009-
res = kb._do_remote()
1012+
res, _ = kb._do_remote()
10101013
assert kb.ignore_errors_until is not None
10111014
assert res == False
10121015

@@ -1015,7 +1018,7 @@ def test_ignore_errors_period():
10151018

10161019
# try again
10171020
kb.source = source_good
1018-
res = kb._do_remote()
1021+
res, _ = kb._do_remote()
10191022
assert res == True
10201023

10211024

@@ -1031,7 +1034,7 @@ def test_ignore_invalid_keys():
10311034

10321035

10331036
def test_exclude_attributes():
1034-
source = "https://example.com/keys.json"
1037+
source = "https://example.com/test_exclude_attributes/keys.json"
10351038
# Mock response
10361039
with responses.RequestsMock() as rsps:
10371040
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)

0 commit comments

Comments
 (0)