diff --git a/src/cryptojwt/jwt.py b/src/cryptojwt/jwt.py index f6c738e8..e212c772 100755 --- a/src/cryptojwt/jwt.py +++ b/src/cryptojwt/jwt.py @@ -1,9 +1,8 @@ """Basic JSON Web Token implementation.""" import json import logging +import time import uuid -from datetime import datetime -from datetime import timezone from json import JSONDecodeError from .exception import HeaderError @@ -28,9 +27,7 @@ def utc_time_sans_frac(): :return: A number of seconds """ - - now_timestampt = int(datetime.now(timezone.utc).timestamp()) - return now_timestampt + return int(time.time()) def pick_key(keys, use, alg="", key_type="", kid=""): @@ -95,6 +92,7 @@ def __init__( allowed_sign_algs=None, allowed_enc_algs=None, allowed_enc_encs=None, + allowed_max_lifetime=None, zip="", ): self.key_jar = key_jar # KeyJar instance @@ -115,6 +113,7 @@ def __init__( self.allowed_sign_algs = allowed_sign_algs self.allowed_enc_algs = allowed_enc_algs self.allowed_enc_encs = allowed_enc_encs + self.allowed_max_lifetime = allowed_max_lifetime self.zip = zip def receiver_keys(self, recv, use): @@ -176,13 +175,13 @@ def put_together_aud(recv, aud=None): return _aud - def pack_init(self, recv, aud): + def pack_init(self, recv, aud, iat=None): """ Gather initial information for the payload. :return: A dictionary with claims and values """ - argv = {"iss": self.iss, "iat": utc_time_sans_frac()} + argv = {"iss": self.iss, "iat": iat or utc_time_sans_frac()} if self.lifetime: argv["exp"] = argv["iat"] + self.lifetime @@ -207,7 +206,7 @@ def pack_key(self, issuer_id="", kid=""): return keys[0] # Might be more then one if kid == '' - def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, **kwargs): + def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None, **kwargs): """ :param payload: Information to be carried as payload in the JWT @@ -216,13 +215,14 @@ def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, **kwargs): :param recv: The intended immediate receiver :param aud: Intended audience for this JWS/JWE, not expected to contain the recipient. + :param iat: Override issued at (default current timestamp) :param kwargs: Extra keyword arguments :return: A signed or signed and encrypted Json Web Token """ _args = {} if payload is not None: _args.update(payload) - _args.update(self.pack_init(recv, aud)) + _args.update(self.pack_init(recv, aud, iat)) try: _encrypt = kwargs["encrypt"] @@ -304,11 +304,12 @@ def verify_profile(msg_cls, info, **kwargs): raise VerificationError() return _msg - def unpack(self, token): + def unpack(self, token, timestamp=None): """ Unpack a received signed or signed and encrypted Json Web Token :param token: The Json Web Token + :param timestamp: Time for evaluation (default now) :return: If decryption and signature verification work the payload will be returned as a Message instance if possible. """ @@ -378,6 +379,26 @@ def unpack(self, token): except KeyError: _msg_cls = None + timestamp = timestamp or utc_time_sans_frac() + + if "nbf" in _info: + nbf = int(_info["nbf"]) + if timestamp < nbf - self.skew: + raise VerificationError("Token not yet valid") + + if "exp" in _info: + exp = int(_info["exp"]) + if timestamp >= exp + self.skew: + raise VerificationError("Token expired") + else: + exp = None + + if "iat" in _info: + iat = int(_info["iat"]) + if self.allowed_max_lifetime and exp: + if abs(exp - iat) > self.allowed_max_lifetime: + raise VerificationError("Token lifetime exceeded") + if _msg_cls: vp_args = {"skew": self.skew} if self.iss: diff --git a/tests/test_09_jwt.py b/tests/test_09_jwt.py index 2f645fe7..0bb912fd 100755 --- a/tests/test_09_jwt.py +++ b/tests/test_09_jwt.py @@ -5,7 +5,9 @@ from cryptojwt.exception import IssuerNotFound from cryptojwt.jws.exception import NoSuitableSigningKeys from cryptojwt.jwt import JWT +from cryptojwt.jwt import VerificationError from cryptojwt.jwt import pick_key +from cryptojwt.jwt import utc_time_sans_frac from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_jar import KeyJar from cryptojwt.key_jar import init_key_jar @@ -81,15 +83,82 @@ def test_jwt_pack_and_unpack(): assert set(info.keys()) == {"iat", "iss", "sub"} -def test_jwt_pack_and_unpack_unknown_issuer(): +def test_jwt_pack_and_unpack_valid(): alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256") + t = utc_time_sans_frac() + payload = {"sub": "sub", "nbf": t, "exp": t + 3600} + _jwt = alice.pack(payload=payload) + + bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"]) + info = bob.unpack(_jwt) + + assert set(info.keys()) == {"iat", "iss", "sub", "nbf", "exp"} + + +def test_jwt_pack_and_unpack_not_yet_valid(): + lifetime = 3600 + skew = 15 + alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) + timestamp = utc_time_sans_frac() + payload = {"sub": "sub", "nbf": timestamp} + _jwt = alice.pack(payload=payload) + + bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], skew=skew) + _ = bob.unpack(_jwt, timestamp=timestamp - skew) + with pytest.raises(VerificationError): + _ = bob.unpack(_jwt, timestamp=timestamp - skew - 1) + + +def test_jwt_pack_and_unpack_expired(): + lifetime = 3600 + skew = 15 + alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) payload = {"sub": "sub"} _jwt = alice.pack(payload=payload) - kj = KeyJar() - bob = JWT(key_jar=kj, iss=BOB, allowed_sign_algs=["RS256"]) - with pytest.raises(IssuerNotFound): - info = bob.unpack(_jwt) + bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], skew=skew) + iat = bob.unpack(_jwt)["iat"] + _ = bob.unpack(_jwt, timestamp=iat + lifetime + skew - 1) + with pytest.raises(VerificationError): + _ = bob.unpack(_jwt, timestamp=iat + lifetime + skew) + + +def test_jwt_pack_and_unpack_max_lifetime_exceeded(): + lifetime = 3600 + alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) + payload = {"sub": "sub"} + _jwt = alice.pack(payload=payload) + + bob = JWT( + key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], allowed_max_lifetime=lifetime - 1 + ) + with pytest.raises(VerificationError): + _ = bob.unpack(_jwt) + + +def test_jwt_pack_and_unpack_max_lifetime_exceeded(): + lifetime = 3600 + alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) + payload = {"sub": "sub"} + _jwt = alice.pack(payload=payload) + + bob = JWT( + key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], allowed_max_lifetime=lifetime - 1 + ) + with pytest.raises(VerificationError): + _ = bob.unpack(_jwt) + + +def test_jwt_pack_and_unpack_timestamp(): + lifetime = 3600 + alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) + payload = {"sub": "sub"} + _jwt = alice.pack(payload=payload, iat=42) + + bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"]) + _ = bob.unpack(_jwt, timestamp=42) + with pytest.raises(VerificationError): + _ = bob.unpack(_jwt) def test_jwt_pack_and_unpack_unknown_key(): @@ -261,4 +330,4 @@ def test_eddsa_jwt(): kj = KeyJar() kj.add_kb(ISSUER, KeyBundle(JWKS_DICT)) jwt = JWT(key_jar=kj) - _ = jwt.unpack(JWT_TEST) + _ = jwt.unpack(JWT_TEST, timestamp=1655278809)