Skip to content

Commit 0bb1c54

Browse files
authored
Merge pull request #115 from jschlyter/atomic_keys_update
2 parents 633118d + baacdd1 commit 0bb1c54

File tree

2 files changed

+54
-53
lines changed

2 files changed

+54
-53
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,16 @@ def __init__(
252252
self.source = None
253253
if isinstance(keys, dict):
254254
if "keys" in keys:
255-
self._add_jwk_dicts(keys["keys"])
255+
initial_keys = keys["keys"]
256256
else:
257-
self._add_jwk_dicts([keys])
257+
initial_keys = [keys]
258258
else:
259-
self._add_jwk_dicts(keys)
259+
initial_keys = keys
260+
self._keys = self.jwk_dicts_as_keys(initial_keys)
260261
else:
261262
self._set_source(source, fileformat)
262263
if self.local:
263-
self._do_local(kid)
264+
self._keys = self._do_local(kid)
264265

265266
def _set_source(self, source, fileformat):
266267
if source.startswith("file://"):
@@ -283,9 +284,10 @@ def _set_source(self, source, fileformat):
283284

284285
def _do_local(self, kid):
285286
if self.fileformat in ["jwks", "jwk"]:
286-
self._do_local_jwk(self.source)
287+
updated, keys = self._do_local_jwk(self.source)
287288
elif self.fileformat == "der":
288-
self._do_local_der(self.source, self.keytype, self.keyusage, kid)
289+
updated, keys = self._do_local_der(self.source, self.keytype, self.keyusage, kid)
290+
return keys
289291

290292
def _local_update_required(self) -> bool:
291293
stat = os.stat(self.source)
@@ -309,13 +311,8 @@ def add_jwk_dicts(self, keys):
309311
:param keys: List of JWK dictionaries
310312
:return:
311313
"""
312-
self._add_jwk_dicts(keys)
313-
314-
def _add_jwk_dicts(self, keys):
315-
_new_keys = self.jwk_dicts_as_keys(keys)
316-
if _new_keys:
317-
self._keys.extend(_new_keys)
318-
self.last_updated = time.time()
314+
self._keys.extend(self.jwk_dicts_as_keys(keys))
315+
self.last_updated = time.time()
319316

320317
def jwk_dicts_as_keys(self, keys):
321318
"""
@@ -384,18 +381,19 @@ def _do_local_jwk(self, filename):
384381
:return: True if load was successful or False if file hasn't been modified
385382
"""
386383
if not self._local_update_required():
387-
return False
384+
return False, None
388385

389386
LOGGER.info("Reading local JWKS from %s", filename)
390387
with open(filename) as input_file:
391388
_info = json.load(input_file)
392389
if "keys" in _info:
393-
self._add_jwk_dicts(_info["keys"])
390+
new_keys = self.jwk_dicts_as_keys(_info["keys"])
394391
else:
395-
self._add_jwk_dicts([_info])
392+
new_keys = self.jwk_dicts_as_keys([_info])
393+
396394
self.last_local = time.time()
397395
self.time_out = self.last_local + self.cache_time
398-
return True
396+
return True, new_keys
399397

400398
def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
401399
"""
@@ -407,7 +405,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
407405
:return: True if load was successful or False if file hasn't been modified
408406
"""
409407
if not self._local_update_required():
410-
return False
408+
return False, None
411409

412410
LOGGER.info("Reading local DER from %s", filename)
413411
key_args = {}
@@ -428,12 +426,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
428426
if kid:
429427
key_args["kid"] = kid
430428

431-
self._add_jwk_dicts([key_args])
429+
new_keys = self.jwk_dicts_as_keys([key_args])
432430
self.last_local = time.time()
433431
self.time_out = self.last_local + self.cache_time
434-
return True
432+
return True, new_keys
435433

436-
def _do_remote(self):
434+
def _do_remote(self, set_keys=True):
437435
"""
438436
Load a JWKS from a webpage.
439437
@@ -448,7 +446,7 @@ def _do_remote(self):
448446
self.source,
449447
datetime.fromtimestamp(self.ignore_errors_until),
450448
)
451-
return False
449+
return False, None
452450

453451
LOGGER.info("Reading remote JWKS from %s", self.source)
454452
try:
@@ -497,11 +495,12 @@ def _do_remote(self):
497495
self.ignore_errors_until = time.time() + self.ignore_errors_period
498496
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))
499497

500-
if new_keys is not None:
498+
if set_keys and new_keys:
501499
self._keys = new_keys
500+
502501
self.last_updated = time.time()
503502
self.ignore_errors_until = None
504-
return load_successful
503+
return load_successful, new_keys
505504

506505
def _parse_remote_response(self, response):
507506
"""
@@ -542,34 +541,31 @@ def update(self):
542541
:return: True if update was ok or False if we encountered an error during update.
543542
"""
544543
if self.source:
545-
_old_keys = self._keys # just in case
546-
547-
# reread everything
548-
self._keys = []
544+
new_keys = []
549545
updated = None
550546

551547
try:
552548
if self.local:
553549
if self.fileformat in ["jwks", "jwk"]:
554-
updated = self._do_local_jwk(self.source)
550+
updated, k = self._do_local_jwk(self.source)
555551
elif self.fileformat == "der":
556-
updated = self._do_local_der(self.source, self.keytype, self.keyusage)
552+
updated, k = self._do_local_der(self.source, self.keytype, self.keyusage)
557553
elif self.remote:
558-
updated = self._do_remote()
554+
updated, k = self._do_remote(set_keys=False)
555+
if k:
556+
new_keys.extend(k)
559557
except Exception as err:
560558
LOGGER.error("Key bundle update failed: %s", err)
561-
self._keys = _old_keys # restore
562559
return False
563560

564561
if updated:
565562
now = time.time()
566-
for _key in _old_keys:
567-
if _key not in self._keys:
563+
for _key in self._keys:
564+
if _key not in new_keys:
568565
if not _key.inactive_since: # If already marked don't mess
569566
_key.inactive_since = now
570-
self._keys.append(_key)
571-
else:
572-
self._keys = _old_keys
567+
new_keys.append(_key)
568+
self._keys = new_keys
573569

574570
return True
575571

@@ -585,9 +581,9 @@ def get(self, typ="", only_active=True):
585581

586582
if typ:
587583
_typs = [typ.lower(), typ.upper()]
588-
_keys = [k for k in self._keys[:] if k.kty in _typs]
584+
_keys = [k for k in self._keys if k.kty in _typs]
589585
else:
590-
_keys = self._keys[:]
586+
_keys = self._keys
591587

592588
if only_active:
593589
return [k for k in _keys if not k.inactive_since]
@@ -602,7 +598,7 @@ def keys(self, update: bool = True):
602598
"""
603599
if update:
604600
self._uptodate()
605-
return self._keys[:]
601+
return self._keys
606602

607603
def active_keys(self):
608604
"""Return the set of active keys."""
@@ -829,9 +825,11 @@ def load(self, spec):
829825
:param spec: Dictionary with attributes and value to populate the instance with
830826
:return: The instance itself
831827
"""
828+
832829
_keys = spec.get("keys", [])
833830
if _keys:
834-
self._add_jwk_dicts(_keys)
831+
self._keys.extend(self.jwk_dicts_as_keys(_keys))
832+
self.last_updated = time.time()
835833

836834
for attr, default in self.params.items():
837835
val = spec.get(attr)

tests/test_03_key_bundle.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,8 @@ def test_httpc_params_1():
480480
rsps.add(method=responses.GET, url=source, json=JWKS_DICT, status=200)
481481
httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds
482482
kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params)
483-
assert kb._do_remote()
483+
updated, _ = kb._do_remote()
484+
assert updated == True
484485

485486

486487
@pytest.mark.network
@@ -920,7 +921,7 @@ def test_export_inactive():
920921

921922

922923
def test_remote():
923-
source = "https://example.com/keys.json"
924+
source = "https://example.com/test_remote/keys.json"
924925
# Mock response
925926
with responses.RequestsMock() as rsps:
926927
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)
@@ -941,7 +942,7 @@ def test_remote():
941942

942943

943944
def test_remote_not_modified():
944-
source = "https://example.com/keys.json"
945+
source = "https://example.com/test_remote_not_modified/keys.json"
945946
headers = {
946947
"Date": "Fri, 15 Mar 2019 10:14:25 GMT",
947948
"Last-Modified": "Fri, 1 Jan 1970 00:00:00 GMT",
@@ -954,13 +955,15 @@ def test_remote_not_modified():
954955

955956
with responses.RequestsMock() as rsps:
956957
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers)
957-
assert kb._do_remote()
958+
updated, _ = kb._do_remote()
959+
assert updated == True
958960
assert kb.last_remote == headers.get("Last-Modified")
959961
timeout1 = kb.time_out
960962

961963
with responses.RequestsMock() as rsps:
962964
rsps.add(method="GET", url=source, status=304, headers=headers)
963-
assert not kb._do_remote()
965+
updated, _ = kb._do_remote()
966+
assert not updated
964967
assert kb.last_remote == headers.get("Last-Modified")
965968
timeout2 = kb.time_out
966969

@@ -980,8 +983,8 @@ def test_remote_not_modified():
980983

981984

982985
def test_ignore_errors_period():
983-
source_good = "https://example.com/keys.json"
984-
source_bad = "https://example.com/keys-bad.json"
986+
source_good = "https://example.com/test_ignore_errors_period/keys.json"
987+
source_bad = "https://example.com/test_ignore_errors_period/keys-bad.json"
985988
ignore_errors_period = 1
986989
# Mock response
987990
with responses.RequestsMock() as rsps:
@@ -994,19 +997,19 @@ def test_ignore_errors_period():
994997
httpc_params=httpc_params,
995998
ignore_errors_period=ignore_errors_period,
996999
)
997-
res = kb._do_remote()
1000+
res, _ = kb._do_remote()
9981001
assert res == True
9991002
assert kb.ignore_errors_until is None
10001003

10011004
# refetch, but fail by using a bad source
10021005
kb.source = source_bad
10031006
try:
1004-
res = kb._do_remote()
1007+
res, _ = kb._do_remote()
10051008
except UpdateFailed:
10061009
pass
10071010

10081011
# retry should fail silently as we're in holddown
1009-
res = kb._do_remote()
1012+
res, _ = kb._do_remote()
10101013
assert kb.ignore_errors_until is not None
10111014
assert res == False
10121015

@@ -1015,7 +1018,7 @@ def test_ignore_errors_period():
10151018

10161019
# try again
10171020
kb.source = source_good
1018-
res = kb._do_remote()
1021+
res, _ = kb._do_remote()
10191022
assert res == True
10201023

10211024

@@ -1031,7 +1034,7 @@ def test_ignore_invalid_keys():
10311034

10321035

10331036
def test_exclude_attributes():
1034-
source = "https://example.com/keys.json"
1037+
source = "https://example.com/test_exclude_attributes/keys.json"
10351038
# Mock response
10361039
with responses.RequestsMock() as rsps:
10371040
rsps.add(method="GET", url=source, json=JWKS_DICT, status=200)

0 commit comments

Comments
 (0)