diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 817762a..de2453a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,13 +32,13 @@ jobs: env: py311-fastapi84 - python: "3.7" - env: py37-fastapi99 + env: py37-fastapi100 - python: "3.9" - env: py39-fastapi99 + env: py39-fastapi100 - python: "3.10" - env: py310-fastapi99 + env: py310-fastapi100 - python: "3.11" - env: py311-fastapi99 + env: py311-fastapi100 steps: - uses: actions/checkout@v2 diff --git a/README.md b/README.md index ac8f71f..490cbe2 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,7 @@ the [social-core](https://github.com/python-social-auth/social-core) authenticat - Use multiple OAuth2 providers at the same time * There need to be provided a way to configure the OAuth2 for multiple providers -- Token -> user data, user data -> token easy conversion - Customizable OAuth2 routes -- Registration support ## Installation @@ -43,12 +41,14 @@ middleware configuration is declared with the `OAuth2Config` and `OAuth2Client` - `client_secret` - The OAuth2 client secret for the particular provider. - `redirect_uri` - The OAuth2 redirect URI to redirect to after success. Defaults to the base URL. - `scope` - The OAuth2 scope for the particular provider. Defaults to `[]`. +- `claims` - Claims mapping for the certain provider. It is also important to mention that for the configured clients of the auth providers, the authorization URLs are accessible by the `/oauth2/{provider}/auth` path where the `provider` variable represents the exact value of the auth provider backend `name` attribute. ```python +from fastapi_oauth2.claims import Claims from fastapi_oauth2.client import OAuth2Client from fastapi_oauth2.config import OAuth2Config from social_core.backends.github import GithubOAuth2 @@ -65,6 +65,10 @@ oauth2_config = OAuth2Config( client_secret=os.getenv("OAUTH2_CLIENT_SECRET"), redirect_uri="https://pysnippet.org/", scope=["user:email"], + claims=Claims( + picture="avatar_url", + identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")), + ), ), ] ) diff --git a/examples/demonstration/config.py b/examples/demonstration/config.py index c63b136..935c2b1 100644 --- a/examples/demonstration/config.py +++ b/examples/demonstration/config.py @@ -3,6 +3,7 @@ from dotenv import load_dotenv from social_core.backends.github import GithubOAuth2 +from fastapi_oauth2.claims import Claims from fastapi_oauth2.client import OAuth2Client from fastapi_oauth2.config import OAuth2Config @@ -20,6 +21,10 @@ client_secret=os.getenv("OAUTH2_CLIENT_SECRET"), # redirect_uri="http://127.0.0.1:8000/", scope=["user:email"], + claims=Claims( + picture="avatar_url", + identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")), + ), ), ] ) diff --git a/examples/demonstration/database.py b/examples/demonstration/database.py new file mode 100644 index 0000000..a22915f --- /dev/null +++ b/examples/demonstration/database.py @@ -0,0 +1,21 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +engine = create_engine( + "sqlite:///./database.sqlite", + connect_args={ + "check_same_thread": False, + }, +) + +Base = declarative_base() +SessionLocal = sessionmaker(bind=engine, autoflush=False) + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/examples/demonstration/main.py b/examples/demonstration/main.py index 1fd1291..e657bf1 100644 --- a/examples/demonstration/main.py +++ b/examples/demonstration/main.py @@ -1,14 +1,39 @@ from fastapi import APIRouter from fastapi import FastAPI +from sqlalchemy.orm import Session from config import oauth2_config +from database import Base +from database import engine +from database import get_db +from fastapi_oauth2.middleware import Auth from fastapi_oauth2.middleware import OAuth2Middleware +from fastapi_oauth2.middleware import User from fastapi_oauth2.router import router as oauth2_router +from models import User as UserModel from router import router as app_router +Base.metadata.create_all(bind=engine) + router = APIRouter() + +async def on_auth(auth: Auth, user: User): + # perform a check for user existence in + # the database and create if not exists + db: Session = next(get_db()) + query = db.query(UserModel) + if user.identity and not query.filter_by(identity=user.identity).first(): + UserModel(**{ + "identity": user.get("identity"), + "username": user.get("username"), + "image": user.get("image"), + "email": user.get("email"), + "name": user.get("name"), + }).save(db) + + app = FastAPI() app.include_router(app_router) app.include_router(oauth2_router) -app.add_middleware(OAuth2Middleware, config=oauth2_config) +app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth) diff --git a/examples/demonstration/models.py b/examples/demonstration/models.py new file mode 100644 index 0000000..ed86f45 --- /dev/null +++ b/examples/demonstration/models.py @@ -0,0 +1,27 @@ +from sqlalchemy import Column +from sqlalchemy import Integer +from sqlalchemy import String +from sqlalchemy.orm import Session + +from database import Base + + +class BaseModel(Base): + __abstract__ = True + + def save(self, db: Session): + db.add(self) + db.commit() + db.refresh(self) + return self + + +class User(BaseModel): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + username = Column(String) + email = Column(String) + name = Column(String) + image = Column(String) + identity = Column(String, unique=True) # provider_name:user_id diff --git a/examples/demonstration/router.py b/examples/demonstration/router.py index fb8cf24..8656b1a 100644 --- a/examples/demonstration/router.py +++ b/examples/demonstration/router.py @@ -1,12 +1,16 @@ import json +from fastapi import APIRouter from fastapi import Depends from fastapi import Request -from fastapi import APIRouter from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates +from sqlalchemy.orm import Session +from starlette.responses import RedirectResponse +from database import get_db from fastapi_oauth2.security import OAuth2 +from models import User oauth2 = OAuth2() router = APIRouter() @@ -18,6 +22,39 @@ async def root(request: Request): return templates.TemplateResponse("index.html", {"request": request, "user": request.user, "json": json}) +@router.get("/auth") +def sim_auth(request: Request): + access_token = request.auth.jwt_create({ + "id": 1, + "identity": "demo:1", + "image": None, + "display_name": "John Doe", + "email": "john.doe@auth.sim", + "username": "JohnDoe", + "exp": 3689609839, + }) + response = RedirectResponse("/") + response.set_cookie( + "Authorization", + value=f"Bearer {access_token}", + max_age=request.auth.expires, + expires=request.auth.expires, + httponly=request.auth.http, + ) + return response + + @router.get("/user") -def user(request: Request, _: str = Depends(oauth2)): +def user_get(request: Request, _: str = Depends(oauth2)): return request.user + + +@router.get("/users") +def users_get(request: Request, db: Session = Depends(get_db), _: str = Depends(oauth2)): + return db.query(User).all() + + +@router.post("/users") +async def users_post(request: Request, db: Session = Depends(get_db), _: str = Depends(oauth2)): + data = await request.json() + return User(**data).save(db) diff --git a/examples/demonstration/templates/index.html b/examples/demonstration/templates/index.html index c42adf1..9a8b81d 100644 --- a/examples/demonstration/templates/index.html +++ b/examples/demonstration/templates/index.html @@ -12,8 +12,15 @@
{% if request.user.is_authenticated %} Sign out - Pic + {% if request.user.picture %} + Pic + {% else %} + Pic + {% endif %} {% else %} + + Simulate Login + @@ -25,7 +32,7 @@
{% if request.user.is_authenticated %} -

Hi, {{ request.user.name }}

+

Hi, {{ request.user.display_name }}

This is what your JWT contains currently

{{ json.dumps(request.user, indent=4) }}
{% else %} diff --git a/src/fastapi_oauth2/__init__.py b/src/fastapi_oauth2/__init__.py index 5f1e750..a390618 100644 --- a/src/fastapi_oauth2/__init__.py +++ b/src/fastapi_oauth2/__init__.py @@ -1 +1 @@ -__version__ = "1.0.0-alpha" +__version__ = "1.0.0-alpha.1" diff --git a/src/fastapi_oauth2/claims.py b/src/fastapi_oauth2/claims.py new file mode 100644 index 0000000..2f5ef68 --- /dev/null +++ b/src/fastapi_oauth2/claims.py @@ -0,0 +1,19 @@ +from typing import Any +from typing import Callable +from typing import Union + + +class Claims(dict): + """Claims configuration for a single provider.""" + + display_name: Union[str, Callable[[dict], Any]] + identity: Union[str, Callable[[dict], Any]] + picture: Union[str, Callable[[dict], Any]] + email: Union[str, Callable[[dict], Any]] + + def __init__(self, seq=None, **kwargs) -> None: + super().__init__(seq or {}, **kwargs) + self["display_name"] = kwargs.get("display_name", self.get("display_name", "name")) + self["identity"] = kwargs.get("identity", self.get("identity", "sub")) + self["picture"] = kwargs.get("picture", self.get("picture", "picture")) + self["email"] = kwargs.get("email", self.get("email", "email")) diff --git a/src/fastapi_oauth2/client.py b/src/fastapi_oauth2/client.py index b88e245..9db113c 100644 --- a/src/fastapi_oauth2/client.py +++ b/src/fastapi_oauth2/client.py @@ -1,16 +1,22 @@ from typing import Optional from typing import Sequence from typing import Type +from typing import Union from social_core.backends.oauth import BaseOAuth2 +from .claims import Claims + class OAuth2Client: + """OAuth2 client configuration for a single provider.""" + backend: Type[BaseOAuth2] client_id: str client_secret: str redirect_uri: Optional[str] scope: Optional[Sequence[str]] + claims: Optional[Union[Claims, dict]] def __init__( self, @@ -20,9 +26,11 @@ def __init__( client_secret: str, redirect_uri: Optional[str] = None, scope: Optional[Sequence[str]] = None, - ): + claims: Optional[Union[Claims, dict]] = None, + ) -> None: self.backend = backend self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri self.scope = scope or [] + self.claims = Claims(claims) diff --git a/src/fastapi_oauth2/config.py b/src/fastapi_oauth2/config.py index 707ad66..8eb4b85 100644 --- a/src/fastapi_oauth2/config.py +++ b/src/fastapi_oauth2/config.py @@ -6,6 +6,8 @@ class OAuth2Config: + """Configuration class of the authentication middleware.""" + allow_http: bool jwt_secret: str jwt_expires: int @@ -20,7 +22,7 @@ def __init__( jwt_expires: Union[int, str] = 900, jwt_algorithm: str = "HS256", clients: List[OAuth2Client] = None, - ): + ) -> None: if allow_http: os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1" self.allow_http = allow_http diff --git a/src/fastapi_oauth2/core.py b/src/fastapi_oauth2/core.py index e8676de..a9e7291 100644 --- a/src/fastapi_oauth2/core.py +++ b/src/fastapi_oauth2/core.py @@ -16,37 +16,39 @@ from starlette.requests import Request from starlette.responses import RedirectResponse +from .claims import Claims from .client import OAuth2Client class OAuth2LoginError(HTTPException): - """Raised when any login-related error occurs - (such as when user is not verified or if there was an attempt for fake login) - """ + """Raised when any login-related error occurs.""" class OAuth2Strategy(BaseStrategy): - def request_data(self, merge=True): + """Dummy strategy for using the `BaseOAuth2.user_data` method.""" + + def request_data(self, merge=True) -> Dict[str, Any]: return {} - def absolute_uri(self, path=None): + def absolute_uri(self, path=None) -> str: return path - def get_setting(self, name): - return None + def get_setting(self, name) -> Any: + """Mocked setting method.""" @staticmethod - def get_json(url, method='GET', *args, **kwargs): + def get_json(url, method='GET', *args, **kwargs) -> httpx.Response: return httpx.request(method, url, *args, **kwargs) class OAuth2Core: - """Base class (mixin) for all SSO providers""" + """OAuth2 flow handler of a certain provider.""" client_id: str = None client_secret: str = None callback_url: Optional[str] = None scope: Optional[List[str]] = None + claims: Optional[Claims] = None backend: BaseOAuth2 = None _oauth_client: Optional[WebApplicationClient] = None @@ -56,8 +58,10 @@ class OAuth2Core: def __init__(self, client: OAuth2Client) -> None: self.client_id = client.client_id self.client_secret = client.client_secret - self.scope = client.scope or self.scope + self.scope = client.scope + self.claims = client.claims self.provider = client.backend.name + self.redirect_uri = client.redirect_uri self.backend = client.backend(OAuth2Strategy()) self.authorization_endpoint = client.backend.AUTHORIZATION_URL self.token_endpoint = client.backend.ACCESS_TOKEN_URL @@ -71,17 +75,14 @@ def oauth_client(self) -> WebApplicationClient: def get_redirect_uri(self, request: Request) -> str: return urljoin(str(request.base_url), "/oauth2/%s/token" % self.provider) - async def get_login_url(self, request: Request) -> Any: + async def login_redirect(self, request: Request) -> RedirectResponse: redirect_uri = self.get_redirect_uri(request) state = "".join([random.choice(string.ascii_letters) for _ in range(32)]) - return self.oauth_client.prepare_request_uri( + return RedirectResponse(str(self.oauth_client.prepare_request_uri( self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope - ) - - async def login_redirect(self, request: Request) -> RedirectResponse: - return RedirectResponse(await self.get_login_url(request), 303) + )), 303) - async def get_token_data(self, request: Request) -> Optional[Dict[str, Any]]: + async def token_redirect(self, request: Request) -> RedirectResponse: if not request.query_params.get("code"): raise OAuth2LoginError(400, "'code' parameter was not found in callback request") if not request.query_params.get("state"): @@ -108,14 +109,10 @@ async def get_token_data(self, request: Request) -> Optional[Dict[str, Any]]: async with httpx.AsyncClient() as session: response = await session.post(token_url, headers=headers, content=content, auth=auth) token = self.oauth_client.parse_request_body_response(json.dumps(response.json())) - data = self.backend.user_data(token.get("access_token")) - - return {**data, "scope": self.scope} + token_data = self.standardize(self.backend.user_data(token.get("access_token"))) + access_token = request.auth.jwt_create(token_data) - async def token_redirect(self, request: Request) -> RedirectResponse: - token_data = await self.get_token_data(request) - access_token = request.auth.jwt_create(token_data) - response = RedirectResponse(request.base_url) + response = RedirectResponse(self.redirect_uri or request.base_url) response.set_cookie( "Authorization", value=f"Bearer {access_token}", @@ -124,3 +121,8 @@ async def token_redirect(self, request: Request) -> RedirectResponse: httponly=request.auth.http, ) return response + + def standardize(self, data: Dict[str, Any]) -> Dict[str, Any]: + data["provider"] = self.provider + data["scope"] = self.scope + return data diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py index dc47017..c921f7b 100644 --- a/src/fastapi_oauth2/middleware.py +++ b/src/fastapi_oauth2/middleware.py @@ -1,15 +1,21 @@ from datetime import datetime from datetime import timedelta +from typing import Any +from typing import Awaitable +from typing import Callable from typing import Dict from typing import List from typing import Optional +from typing import Sequence from typing import Tuple from typing import Union from fastapi.security.utils import get_authorization_scheme_param from jose.jwt import decode as jwt_decode from jose.jwt import encode as jwt_encode +from starlette.authentication import AuthCredentials from starlette.authentication import AuthenticationBackend +from starlette.authentication import BaseUser from starlette.middleware.authentication import AuthenticationMiddleware from starlette.requests import Request from starlette.types import ASGIApp @@ -17,12 +23,15 @@ from starlette.types import Scope from starlette.types import Send +from .claims import Claims from .client import OAuth2Client from .config import OAuth2Config from .core import OAuth2Core -class Auth: +class Auth(AuthCredentials): + """Extended auth credentials schema based on Starlette AuthCredentials.""" + http: bool secret: str expires: int @@ -30,8 +39,16 @@ class Auth: scopes: List[str] clients: Dict[str, OAuth2Core] = {} - def __init__(self, scopes: Optional[List[str]] = None) -> None: - self.scopes = scopes or [] + provider: str + default_provider: str = "local" + + def __init__( + self, + scopes: Optional[Sequence[str]] = None, + provider: str = default_provider, + ) -> None: + super().__init__(scopes) + self.provider = provider @classmethod def set_http(cls, http: bool) -> None: @@ -67,24 +84,57 @@ def jwt_create(cls, token_data: dict) -> str: return cls.jwt_encode({**token_data, "exp": expire}) -class User(dict): - is_authenticated: bool +class User(BaseUser, dict): + """Extended user schema based on Starlette BaseUser.""" + + @property + def is_authenticated(self) -> bool: + return bool(self) + + @property + def display_name(self) -> str: + return self.__getprop__("display_name") + + @property + def identity(self) -> str: + return self.__getprop__("identity") + + @property + def picture(self) -> str: + return self.__getprop__("picture") - def __init__(self, seq: Optional[dict] = None, **kwargs) -> None: - self.is_authenticated = seq is not None - super().__init__(seq or {}, **kwargs) + @property + def email(self) -> str: + return self.__getprop__("email") + + def use_claims(self, claims: Claims) -> "User": + for attr, item in claims.items(): + self[attr] = self.__getprop__(item) + return self + + def __getprop__(self, item, default="") -> Any: + if callable(item): + return item(self) + return self.get(item, default) class OAuth2Backend(AuthenticationBackend): - def __init__(self, config: OAuth2Config) -> None: + """Authentication backend for AuthenticationMiddleware.""" + + def __init__( + self, + config: OAuth2Config, + callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None, + ) -> None: Auth.set_http(config.allow_http) Auth.set_secret(config.jwt_secret) Auth.set_expires(config.jwt_expires) Auth.set_algorithm(config.jwt_algorithm) for client in config.clients: Auth.register_client(client) + self.callback = callback - async def authenticate(self, request: Request) -> Optional[Tuple["Auth", "User"]]: + async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]: authorization = request.headers.get( "Authorization", request.cookies.get("Authorization"), @@ -94,19 +144,44 @@ async def authenticate(self, request: Request) -> Optional[Tuple["Auth", "User"] if not scheme or not param: return Auth(), User() - user = Auth.jwt_decode(param) - return Auth(user.pop("scope", [])), User(user) + user = User(Auth.jwt_decode(param)) + user.update(provider=user.get("provider", Auth.default_provider)) + auth = Auth(user.pop("scope", []), user.get("provider")) + client = Auth.clients.get(auth.provider) + claims = client.claims if client else Claims() + user = user.use_claims(claims) + + # Call the callback function on authentication + if callable(self.callback): + coroutine = self.callback(auth, user) + if issubclass(type(coroutine), Awaitable): + await coroutine + return auth, user class OAuth2Middleware: + """Wrapper for the Starlette AuthenticationMiddleware.""" + auth_middleware: AuthenticationMiddleware = None - def __init__(self, app: ASGIApp, config: Union[OAuth2Config, dict]) -> None: + def __init__( + self, + app: ASGIApp, + config: Union[OAuth2Config, dict], + callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None, + **kwargs, # AuthenticationMiddleware kwargs + ) -> None: + """Initiates the middleware with the given configuration. + + :param app: FastAPI application instance + :param config: middleware configuration + :param callback: callback function to be called after authentication + """ if isinstance(config, dict): config = OAuth2Config(**config) elif not isinstance(config, OAuth2Config): raise TypeError("config is not a valid type") - self.auth_middleware = AuthenticationMiddleware(app, OAuth2Backend(config)) + self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.auth_middleware(scope, receive, send) diff --git a/src/fastapi_oauth2/security.py b/src/fastapi_oauth2/security.py index 4c6aeed..0f5d3b3 100644 --- a/src/fastapi_oauth2/security.py +++ b/src/fastapi_oauth2/security.py @@ -1,3 +1,10 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type + from fastapi.security import OAuth2 as FastAPIOAuth2 from fastapi.security import OAuth2AuthorizationCodeBearer as FastAPICodeBearer from fastapi.security import OAuth2PasswordBearer as FastAPIPasswordBearer @@ -5,9 +12,11 @@ from starlette.requests import Request -def use_cookie(cls: FastAPIOAuth2): - def _use_cookie(*args, **kwargs): - async def __call__(self, request: Request): +def use_cookies(cls: Type[FastAPIOAuth2]) -> Callable[[Tuple[Any], Dict[str, Any]], FastAPIOAuth2]: + """OAuth2 classes wrapped with this decorator will use cookies for the Authorization header.""" + + def _use_cookies(*args, **kwargs) -> FastAPIOAuth2: + async def __call__(self: FastAPIOAuth2, request: Request) -> Optional[str]: authorization = request.headers.get("Authorization", request.cookies.get("Authorization")) if authorization: request._headers = Headers({**request.headers, "Authorization": authorization}) @@ -16,19 +25,19 @@ async def __call__(self, request: Request): cls.__call__ = __call__ return cls(*args, **kwargs) - return _use_cookie + return _use_cookies -@use_cookie +@use_cookies class OAuth2(FastAPIOAuth2): - ... + """Wrapper class of the `fastapi.security.OAuth2` class.""" -@use_cookie +@use_cookies class OAuth2PasswordBearer(FastAPIPasswordBearer): - ... + """Wrapper class of the `fastapi.security.OAuth2PasswordBearer` class.""" -@use_cookie +@use_cookies class OAuth2AuthorizationCodeBearer(FastAPICodeBearer): - ... + """Wrapper class of the `fastapi.security.OAuth2AuthorizationCodeBearer` class.""" diff --git a/tox.ini b/tox.ini index 9ec924f..7edde4a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,11 +1,11 @@ [tox] envlist = py{36,38,310,311}-fastapi68 - py{37,39,310,311}-fastapi{84,99} + py{37,39,310,311}-fastapi{84,100} [testenv] deps = - fastapi99: fastapi>=0.99.0 + fastapi100: fastapi>=0.100.0 fastapi84: fastapi<=0.84.0 fastapi68: fastapi<=0.68.1 -r{toxinidir}/tests/requirements.txt