Skip to content

Commit 5b8cae6

Browse files
committed
Implement a custom strategy for using the user_data method
1 parent a4dedfa commit 5b8cae6

File tree

2 files changed

+51
-59
lines changed

2 files changed

+51
-59
lines changed

src/fastapi_oauth2/core.py

Lines changed: 45 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import json
2+
import random
23
import re
4+
import string
35
from typing import Any
46
from typing import Dict
57
from typing import List
68
from typing import Optional
79
from urllib.parse import urljoin
810

911
import httpx
12+
import requests
1013
from oauthlib.oauth2 import WebApplicationClient
14+
from social_core.backends.oauth import BaseOAuth2
15+
from social_core.strategy import BaseStrategy
1116
from starlette.exceptions import HTTPException
1217
from starlette.requests import Request
1318
from starlette.responses import RedirectResponse
@@ -21,31 +26,42 @@ class OAuth2LoginError(HTTPException):
2126
"""
2227

2328

29+
class OAuth2Strategy(BaseStrategy):
30+
def request_data(self, merge=True):
31+
return {}
32+
33+
def absolute_uri(self, path=None):
34+
return path
35+
36+
def get_setting(self, name):
37+
return None
38+
39+
@staticmethod
40+
def get_json(url, method='GET', *args, **kwargs):
41+
return requests.request(method, url, *args, **kwargs)
42+
43+
2444
class OAuth2Core:
2545
"""Base class (mixin) for all SSO providers"""
2646

2747
client_id: str = None
2848
client_secret: str = None
2949
callback_url: Optional[str] = None
30-
allow_http: bool = False
3150
scope: Optional[List[str]] = None
32-
state: Optional[str] = None
51+
backend: BaseOAuth2 = None
3352
_oauth_client: Optional[WebApplicationClient] = None
34-
additional_headers: Optional[Dict[str, Any]] = None
3553

3654
authorization_endpoint: str = None
3755
token_endpoint: str = None
38-
userinfo_endpoint: str = None
3956

4057
def __init__(self, client: OAuth2Client) -> None:
4158
self.client_id = client.client_id
4259
self.client_secret = client.client_secret
4360
self.scope = client.scope or self.scope
4461
self.provider = client.backend.name
62+
self.backend = client.backend(OAuth2Strategy())
4563
self.authorization_endpoint = client.backend.AUTHORIZATION_URL
4664
self.token_endpoint = client.backend.ACCESS_TOKEN_URL
47-
self.userinfo_endpoint = "https://api.github.com/user"
48-
self.additional_headers = {"Content-Type": "application/x-www-form-urlencoded", "Accept": "application/json"}
4965

5066
@property
5167
def oauth_client(self) -> WebApplicationClient:
@@ -56,47 +72,24 @@ def oauth_client(self) -> WebApplicationClient:
5672
def get_redirect_uri(self, request: Request) -> str:
5773
return urljoin(str(request.base_url), "/oauth2/%s/token" % self.provider)
5874

59-
async def get_login_url(
60-
self,
61-
request: Request,
62-
*,
63-
params: Optional[Dict[str, Any]] = None,
64-
state: Optional[str] = None,
65-
) -> Any:
66-
self.state = state
67-
params = params or {}
75+
async def get_login_url(self, request: Request) -> Any:
6876
redirect_uri = self.get_redirect_uri(request)
77+
state = "".join([random.choice(string.ascii_letters) for _ in range(32)])
6978
return self.oauth_client.prepare_request_uri(
70-
self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope, **params
79+
self.authorization_endpoint, redirect_uri=redirect_uri, state=state, scope=self.scope
7180
)
7281

73-
async def login_redirect(
74-
self,
75-
request: Request,
76-
*,
77-
params: Optional[Dict[str, Any]] = None,
78-
state: Optional[str] = None,
79-
) -> RedirectResponse:
80-
login_uri = await self.get_login_url(request, params=params, state=state)
81-
return RedirectResponse(login_uri, 303)
82-
83-
async def get_token_data(
84-
self,
85-
request: Request,
86-
*,
87-
params: Optional[Dict[str, Any]] = None,
88-
headers: Optional[Dict[str, Any]] = None,
89-
) -> Optional[Dict[str, Any]]:
90-
params = params or {}
91-
additional_headers = headers or {}
92-
additional_headers.update(self.additional_headers or {})
82+
async def login_redirect(self, request: Request) -> RedirectResponse:
83+
return RedirectResponse(await self.get_login_url(request), 303)
84+
85+
async def get_token_data(self, request: Request) -> Optional[Dict[str, Any]]:
9386
if not request.query_params.get("code"):
9487
raise OAuth2LoginError(400, "'code' parameter was not found in callback request")
95-
if self.state != request.query_params.get("state"):
96-
raise OAuth2LoginError(400, "'state' parameter does not match")
88+
if not request.query_params.get("state"):
89+
raise OAuth2LoginError(400, "'state' parameter was not found in callback request")
9790

9891
url = request.url
99-
scheme = "http" if self.allow_http else "https"
92+
scheme = "http" if request.auth.http else "https"
10093
current_url = re.sub(r"^https?", scheme, str(url))
10194
redirect_uri = self.get_redirect_uri(request)
10295

@@ -105,36 +98,30 @@ async def get_token_data(
10598
redirect_url=redirect_uri,
10699
authorization_response=current_url,
107100
code=request.query_params.get("code"),
108-
**params,
101+
state=request.query_params.get("state"),
109102
)
110103

111-
headers.update(additional_headers)
104+
headers.update({
105+
"Accept": "application/json",
106+
"Content-Type": "application/x-www-form-urlencoded",
107+
})
112108
auth = httpx.BasicAuth(self.client_id, self.client_secret)
113109
async with httpx.AsyncClient() as session:
114110
response = await session.post(token_url, headers=headers, content=content, auth=auth)
115-
self.oauth_client.parse_request_body_response(json.dumps(response.json()))
116-
117-
url, headers, _ = self.oauth_client.add_token(self.userinfo_endpoint)
118-
response = await session.get(url, headers=headers)
119-
content = response.json()
120-
121-
return {**content, "scope": self.scope}
122-
123-
async def token_redirect(
124-
self,
125-
request: Request,
126-
*,
127-
params: Optional[Dict[str, Any]] = None,
128-
headers: Optional[Dict[str, Any]] = None,
129-
) -> RedirectResponse:
130-
token_data = await self.get_token_data(request, params=params, headers=headers)
111+
token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
112+
data = self.backend.user_data(token.get("access_token"))
113+
114+
return {**data, "scope": self.scope}
115+
116+
async def token_redirect(self, request: Request) -> RedirectResponse:
117+
token_data = await self.get_token_data(request)
131118
access_token = request.auth.jwt_create(token_data)
132119
response = RedirectResponse(request.base_url)
133120
response.set_cookie(
134121
"Authorization",
135122
value=f"Bearer {access_token}",
136-
httponly=self.allow_http,
137123
max_age=request.auth.expires,
138124
expires=request.auth.expires,
125+
httponly=request.auth.http,
139126
)
140127
return response

src/fastapi_oauth2/middleware.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424

2525
class Auth:
26+
http: bool
2627
secret: str
2728
expires: int
2829
algorithm: str
@@ -32,6 +33,10 @@ class Auth:
3233
def __init__(self, scopes: Optional[List[str]] = None) -> None:
3334
self.scopes = scopes or []
3435

36+
@classmethod
37+
def set_http(cls, http: bool) -> None:
38+
cls.http = http
39+
3540
@classmethod
3641
def set_secret(cls, secret: str) -> None:
3742
cls.secret = secret
@@ -72,10 +77,10 @@ def __init__(self, seq: Optional[dict] = None, **kwargs) -> None:
7277

7378
class OAuth2Backend(AuthenticationBackend):
7479
def __init__(self, config: OAuth2Config) -> None:
80+
Auth.set_http(config.allow_http)
7581
Auth.set_secret(config.jwt_secret)
7682
Auth.set_expires(config.jwt_expires)
7783
Auth.set_algorithm(config.jwt_algorithm)
78-
OAuth2Core.allow_http = config.allow_http
7984
for client in config.clients:
8085
Auth.register_client(client)
8186

0 commit comments

Comments
 (0)