@@ -252,15 +252,16 @@ def __init__(
252
252
self .source = None
253
253
if isinstance (keys , dict ):
254
254
if "keys" in keys :
255
- self . _add_jwk_dicts ( keys ["keys" ])
255
+ initial_keys = keys ["keys" ]
256
256
else :
257
- self . _add_jwk_dicts ( [keys ])
257
+ initial_keys = [keys ]
258
258
else :
259
- self ._add_jwk_dicts (keys )
259
+ initial_keys = keys
260
+ self ._keys = self .jwk_dicts_as_keys (initial_keys )
260
261
else :
261
262
self ._set_source (source , fileformat )
262
263
if self .local :
263
- self ._do_local (kid )
264
+ self ._keys = self . _do_local (kid )
264
265
265
266
def _set_source (self , source , fileformat ):
266
267
if source .startswith ("file://" ):
@@ -283,9 +284,10 @@ def _set_source(self, source, fileformat):
283
284
284
285
def _do_local (self , kid ):
285
286
if self .fileformat in ["jwks" , "jwk" ]:
286
- self ._do_local_jwk (self .source )
287
+ updated , keys = self ._do_local_jwk (self .source )
287
288
elif self .fileformat == "der" :
288
- self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
289
+ updated , keys = self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
290
+ return keys
289
291
290
292
def _local_update_required (self ) -> bool :
291
293
stat = os .stat (self .source )
@@ -309,13 +311,8 @@ def add_jwk_dicts(self, keys):
309
311
:param keys: List of JWK dictionaries
310
312
:return:
311
313
"""
312
- self ._add_jwk_dicts (keys )
313
-
314
- def _add_jwk_dicts (self , keys ):
315
- _new_keys = self .jwk_dicts_as_keys (keys )
316
- if _new_keys :
317
- self ._keys .extend (_new_keys )
318
- self .last_updated = time .time ()
314
+ self ._keys .extend (self .jwk_dicts_as_keys (keys ))
315
+ self .last_updated = time .time ()
319
316
320
317
def jwk_dicts_as_keys (self , keys ):
321
318
"""
@@ -384,18 +381,19 @@ def _do_local_jwk(self, filename):
384
381
:return: True if load was successful or False if file hasn't been modified
385
382
"""
386
383
if not self ._local_update_required ():
387
- return False
384
+ return False , None
388
385
389
386
LOGGER .info ("Reading local JWKS from %s" , filename )
390
387
with open (filename ) as input_file :
391
388
_info = json .load (input_file )
392
389
if "keys" in _info :
393
- self ._add_jwk_dicts (_info ["keys" ])
390
+ new_keys = self .jwk_dicts_as_keys (_info ["keys" ])
394
391
else :
395
- self ._add_jwk_dicts ([_info ])
392
+ new_keys = self .jwk_dicts_as_keys ([_info ])
393
+
396
394
self .last_local = time .time ()
397
395
self .time_out = self .last_local + self .cache_time
398
- return True
396
+ return True , new_keys
399
397
400
398
def _do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
401
399
"""
@@ -407,7 +405,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
407
405
:return: True if load was successful or False if file hasn't been modified
408
406
"""
409
407
if not self ._local_update_required ():
410
- return False
408
+ return False , None
411
409
412
410
LOGGER .info ("Reading local DER from %s" , filename )
413
411
key_args = {}
@@ -428,12 +426,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
428
426
if kid :
429
427
key_args ["kid" ] = kid
430
428
431
- self ._add_jwk_dicts ([key_args ])
429
+ new_keys = self .jwk_dicts_as_keys ([key_args ])
432
430
self .last_local = time .time ()
433
431
self .time_out = self .last_local + self .cache_time
434
- return True
432
+ return True , new_keys
435
433
436
- def _do_remote (self ):
434
+ def _do_remote (self , set_keys = True ):
437
435
"""
438
436
Load a JWKS from a webpage.
439
437
@@ -448,7 +446,7 @@ def _do_remote(self):
448
446
self .source ,
449
447
datetime .fromtimestamp (self .ignore_errors_until ),
450
448
)
451
- return False
449
+ return False , None
452
450
453
451
LOGGER .info ("Reading remote JWKS from %s" , self .source )
454
452
try :
@@ -497,11 +495,12 @@ def _do_remote(self):
497
495
self .ignore_errors_until = time .time () + self .ignore_errors_period
498
496
raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
499
497
500
- if new_keys is not None :
498
+ if set_keys and new_keys :
501
499
self ._keys = new_keys
500
+
502
501
self .last_updated = time .time ()
503
502
self .ignore_errors_until = None
504
- return load_successful
503
+ return load_successful , new_keys
505
504
506
505
def _parse_remote_response (self , response ):
507
506
"""
@@ -542,34 +541,31 @@ def update(self):
542
541
:return: True if update was ok or False if we encountered an error during update.
543
542
"""
544
543
if self .source :
545
- _old_keys = self ._keys # just in case
546
-
547
- # reread everything
548
- self ._keys = []
544
+ new_keys = []
549
545
updated = None
550
546
551
547
try :
552
548
if self .local :
553
549
if self .fileformat in ["jwks" , "jwk" ]:
554
- updated = self ._do_local_jwk (self .source )
550
+ updated , k = self ._do_local_jwk (self .source )
555
551
elif self .fileformat == "der" :
556
- updated = self ._do_local_der (self .source , self .keytype , self .keyusage )
552
+ updated , k = self ._do_local_der (self .source , self .keytype , self .keyusage )
557
553
elif self .remote :
558
- updated = self ._do_remote ()
554
+ updated , k = self ._do_remote (set_keys = False )
555
+ if k :
556
+ new_keys .extend (k )
559
557
except Exception as err :
560
558
LOGGER .error ("Key bundle update failed: %s" , err )
561
- self ._keys = _old_keys # restore
562
559
return False
563
560
564
561
if updated :
565
562
now = time .time ()
566
- for _key in _old_keys :
567
- if _key not in self . _keys :
563
+ for _key in self . _keys :
564
+ if _key not in new_keys :
568
565
if not _key .inactive_since : # If already marked don't mess
569
566
_key .inactive_since = now
570
- self ._keys .append (_key )
571
- else :
572
- self ._keys = _old_keys
567
+ new_keys .append (_key )
568
+ self ._keys = new_keys
573
569
574
570
return True
575
571
@@ -585,9 +581,9 @@ def get(self, typ="", only_active=True):
585
581
586
582
if typ :
587
583
_typs = [typ .lower (), typ .upper ()]
588
- _keys = [k for k in self ._keys [:] if k .kty in _typs ]
584
+ _keys = [k for k in self ._keys if k .kty in _typs ]
589
585
else :
590
- _keys = self ._keys [:]
586
+ _keys = self ._keys
591
587
592
588
if only_active :
593
589
return [k for k in _keys if not k .inactive_since ]
@@ -602,7 +598,7 @@ def keys(self, update: bool = True):
602
598
"""
603
599
if update :
604
600
self ._uptodate ()
605
- return self ._keys [:]
601
+ return self ._keys
606
602
607
603
def active_keys (self ):
608
604
"""Return the set of active keys."""
@@ -829,9 +825,11 @@ def load(self, spec):
829
825
:param spec: Dictionary with attributes and value to populate the instance with
830
826
:return: The instance itself
831
827
"""
828
+
832
829
_keys = spec .get ("keys" , [])
833
830
if _keys :
834
- self ._add_jwk_dicts (_keys )
831
+ self ._keys .extend (self .jwk_dicts_as_keys (_keys ))
832
+ self .last_updated = time .time ()
835
833
836
834
for attr , default in self .params .items ():
837
835
val = spec .get (attr )
0 commit comments