@@ -263,6 +263,7 @@ def __init__(
263
263
if self .local :
264
264
self ._keys = self ._do_local (kid )
265
265
266
+
266
267
def _set_source (self , source , fileformat ):
267
268
if source .startswith ("file://" ):
268
269
self .source = source [7 :]
@@ -284,10 +285,10 @@ def _set_source(self, source, fileformat):
284
285
285
286
def _do_local (self , kid ):
286
287
if self .fileformat in ["jwks" , "jwk" ]:
287
- updated , res = self ._do_local_jwk (self .source )
288
+ updated , keys = self ._do_local_jwk (self .source )
288
289
elif self .fileformat == "der" :
289
- updated , res = self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
290
- return res
290
+ updated , keys = self ._do_local_der (self .source , self .keytype , self .keyusage , kid )
291
+ return keys
291
292
292
293
def _local_update_required (self ) -> bool :
293
294
stat = os .stat (self .source )
@@ -311,14 +312,9 @@ def add_jwk_dicts(self, keys):
311
312
:param keys: List of JWK dictionaries
312
313
:return:
313
314
"""
314
- self ._add_jwk_dicts ( keys )
315
+ self ._keys . extend ( self . jwk_dicts_as_keys ( keys ) )
315
316
self .last_updated = time .time ()
316
317
317
- def _add_jwk_dicts (self , keys ):
318
- _new_keys = self .jwk_dicts_as_keys (keys )
319
- if _new_keys :
320
- self ._keys .extend (_new_keys )
321
-
322
318
def jwk_dicts_as_keys (self , keys ):
323
319
"""
324
320
Return JWK dictionaries as list of JWK objects
@@ -392,13 +388,13 @@ def _do_local_jwk(self, filename):
392
388
with open (filename ) as input_file :
393
389
_info = json .load (input_file )
394
390
if "keys" in _info :
395
- res = self .jwk_dicts_as_keys (_info ["keys" ])
391
+ new_keys = self .jwk_dicts_as_keys (_info ["keys" ])
396
392
else :
397
- res = self .jwk_dicts_as_keys ([_info ])
393
+ new_keys = self .jwk_dicts_as_keys ([_info ])
398
394
399
395
self .last_local = time .time ()
400
396
self .time_out = self .last_local + self .cache_time
401
- return True , res
397
+ return True , new_keys
402
398
403
399
def _do_local_der (self , filename , keytype , keyusage = None , kid = "" ):
404
400
"""
@@ -431,12 +427,12 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""):
431
427
if kid :
432
428
key_args ["kid" ] = kid
433
429
434
- res = self .jwk_dicts_as_keys ([key_args ])
430
+ new_keys = self .jwk_dicts_as_keys ([key_args ])
435
431
self .last_local = time .time ()
436
432
self .time_out = self .last_local + self .cache_time
437
- return True , res
433
+ return True , new_keys
438
434
439
- def _do_remote (self ):
435
+ def _do_remote (self , set_keys = True ):
440
436
"""
441
437
Load a JWKS from a webpage.
442
438
@@ -451,7 +447,7 @@ def _do_remote(self):
451
447
self .source ,
452
448
datetime .fromtimestamp (self .ignore_errors_until ),
453
449
)
454
- return False
450
+ return False , None
455
451
456
452
LOGGER .info ("Reading remote JWKS from %s" , self .source )
457
453
try :
@@ -500,11 +496,12 @@ def _do_remote(self):
500
496
self .ignore_errors_until = time .time () + self .ignore_errors_period
501
497
raise UpdateFailed (REMOTE_FAILED .format (self .source , _http_resp .status_code ))
502
498
503
- if new_keys is not None :
499
+ if set_keys and new_keys :
504
500
self ._keys = new_keys
501
+
505
502
self .last_updated = time .time ()
506
503
self .ignore_errors_until = None
507
- return load_successful
504
+ return load_successful , new_keys
508
505
509
506
def _parse_remote_response (self , response ):
510
507
"""
@@ -545,38 +542,31 @@ def update(self):
545
542
:return: True if update was ok or False if we encountered an error during update.
546
543
"""
547
544
if self .source :
548
- _old_keys = self ._keys # just in case
549
-
550
- # reread everything
551
- self ._keys = []
545
+ new_keys = []
552
546
updated = None
553
547
554
548
try :
555
549
if self .local :
556
550
if self .fileformat in ["jwks" , "jwk" ]:
557
551
updated , k = self ._do_local_jwk (self .source )
558
- if k :
559
- self ._keys .extend (k )
560
552
elif self .fileformat == "der" :
561
553
updated , k = self ._do_local_der (self .source , self .keytype , self .keyusage )
562
- if k :
563
- self ._keys .extend (k )
564
554
elif self .remote :
565
- updated = self ._do_remote ()
555
+ updated , k = self ._do_remote (set_keys = False )
556
+ if k :
557
+ new_keys .extend (k )
566
558
except Exception as err :
567
559
LOGGER .error ("Key bundle update failed: %s" , err )
568
- self ._keys = _old_keys # restore
569
560
return False
570
561
571
562
if updated :
572
563
now = time .time ()
573
- for _key in _old_keys :
574
- if _key not in self . _keys :
564
+ for _key in self . _keys :
565
+ if _key not in new_keys :
575
566
if not _key .inactive_since : # If already marked don't mess
576
567
_key .inactive_since = now
577
- self ._keys .append (_key )
578
- else :
579
- self ._keys = _old_keys
568
+ new_keys .append (_key )
569
+ self ._keys = new_keys
580
570
581
571
return True
582
572
@@ -592,9 +582,9 @@ def get(self, typ="", only_active=True):
592
582
593
583
if typ :
594
584
_typs = [typ .lower (), typ .upper ()]
595
- _keys = [k for k in self ._keys [:] if k .kty in _typs ]
585
+ _keys = [k for k in self ._keys if k .kty in _typs ]
596
586
else :
597
- _keys = self ._keys [:]
587
+ _keys = self ._keys
598
588
599
589
if only_active :
600
590
return [k for k in _keys if not k .inactive_since ]
@@ -609,7 +599,7 @@ def keys(self, update: bool = True):
609
599
"""
610
600
if update :
611
601
self ._uptodate ()
612
- return self ._keys [:]
602
+ return self ._keys
613
603
614
604
def active_keys (self ):
615
605
"""Return the set of active keys."""
@@ -836,9 +826,11 @@ def load(self, spec):
836
826
:param spec: Dictionary with attributes and value to populate the instance with
837
827
:return: The instance itself
838
828
"""
829
+
839
830
_keys = spec .get ("keys" , [])
840
831
if _keys :
841
- self ._add_jwk_dicts (_keys )
832
+ self ._keys .extend (self .jwk_dicts_as_keys (_keys ))
833
+ self .last_updated = time .time ()
842
834
843
835
for attr , default in self .params .items ():
844
836
val = spec .get (attr )
0 commit comments