diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py index 8b91e32..8481947 100644 --- a/src/fastapi_oauth2/middleware.py +++ b/src/fastapi_oauth2/middleware.py @@ -1,5 +1,6 @@ from datetime import datetime from datetime import timedelta +from datetime import timezone from typing import Any from typing import Awaitable from typing import Callable @@ -27,6 +28,7 @@ from .claims import Claims from .config import OAuth2Config from .core import OAuth2Core +from .exceptions import OAuth2AuthenticationError class Auth(AuthCredentials): @@ -51,7 +53,7 @@ def jwt_decode(cls, token: str) -> dict: @classmethod def jwt_create(cls, token_data: dict) -> str: - expire = datetime.utcnow() + timedelta(seconds=cls.expires) + expire = datetime.now(timezone.utc) + timedelta(seconds=cls.expires) return cls.jwt_encode({**token_data, "exp": expire}) @@ -106,7 +108,11 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]: if not scheme or not param: return Auth(), User() - user = User(Auth.jwt_decode(param)) + token_data = Auth.jwt_decode(param) + if token_data["exp"] and token_data["exp"] < int(datetime.now(timezone.utc).timestamp()): + raise OAuth2AuthenticationError(401, "Token expired") + + user = User(token_data) auth = Auth(user.pop("scope", [])) auth.provider = auth.clients.get(user.get("provider")) claims = auth.provider.claims if auth.provider else {}