1
1
import json
2
+ import random
2
3
import re
4
+ import string
3
5
from typing import Any
4
6
from typing import Dict
5
7
from typing import List
6
8
from typing import Optional
7
9
from urllib .parse import urljoin
8
10
9
11
import httpx
12
+ import requests
10
13
from oauthlib .oauth2 import WebApplicationClient
14
+ from social_core .backends .oauth import BaseOAuth2
15
+ from social_core .strategy import BaseStrategy
11
16
from starlette .exceptions import HTTPException
12
17
from starlette .requests import Request
13
18
from starlette .responses import RedirectResponse
@@ -21,31 +26,42 @@ class OAuth2LoginError(HTTPException):
21
26
"""
22
27
23
28
29
+ class OAuth2Strategy (BaseStrategy ):
30
+ def request_data (self , merge = True ):
31
+ return {}
32
+
33
+ def absolute_uri (self , path = None ):
34
+ return path
35
+
36
+ def get_setting (self , name ):
37
+ return None
38
+
39
+ @staticmethod
40
+ def get_json (url , method = 'GET' , * args , ** kwargs ):
41
+ return requests .request (method , url , * args , ** kwargs )
42
+
43
+
24
44
class OAuth2Core :
25
45
"""Base class (mixin) for all SSO providers"""
26
46
27
47
client_id : str = None
28
48
client_secret : str = None
29
49
callback_url : Optional [str ] = None
30
- allow_http : bool = False
31
50
scope : Optional [List [str ]] = None
32
- state : Optional [ str ] = None
51
+ backend : BaseOAuth2 = None
33
52
_oauth_client : Optional [WebApplicationClient ] = None
34
- additional_headers : Optional [Dict [str , Any ]] = None
35
53
36
54
authorization_endpoint : str = None
37
55
token_endpoint : str = None
38
- userinfo_endpoint : str = None
39
56
40
57
def __init__ (self , client : OAuth2Client ) -> None :
41
58
self .client_id = client .client_id
42
59
self .client_secret = client .client_secret
43
60
self .scope = client .scope or self .scope
44
61
self .provider = client .backend .name
62
+ self .backend = client .backend (OAuth2Strategy ())
45
63
self .authorization_endpoint = client .backend .AUTHORIZATION_URL
46
64
self .token_endpoint = client .backend .ACCESS_TOKEN_URL
47
- self .userinfo_endpoint = "https://api.github.com/user"
48
- self .additional_headers = {"Content-Type" : "application/x-www-form-urlencoded" , "Accept" : "application/json" }
49
65
50
66
@property
51
67
def oauth_client (self ) -> WebApplicationClient :
@@ -56,47 +72,24 @@ def oauth_client(self) -> WebApplicationClient:
56
72
def get_redirect_uri (self , request : Request ) -> str :
57
73
return urljoin (str (request .base_url ), "/oauth2/%s/token" % self .provider )
58
74
59
- async def get_login_url (
60
- self ,
61
- request : Request ,
62
- * ,
63
- params : Optional [Dict [str , Any ]] = None ,
64
- state : Optional [str ] = None ,
65
- ) -> Any :
66
- self .state = state
67
- params = params or {}
75
+ async def get_login_url (self , request : Request ) -> Any :
68
76
redirect_uri = self .get_redirect_uri (request )
77
+ state = "" .join ([random .choice (string .ascii_letters ) for _ in range (32 )])
69
78
return self .oauth_client .prepare_request_uri (
70
- self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope , ** params
79
+ self .authorization_endpoint , redirect_uri = redirect_uri , state = state , scope = self .scope
71
80
)
72
81
73
- async def login_redirect (
74
- self ,
75
- request : Request ,
76
- * ,
77
- params : Optional [Dict [str , Any ]] = None ,
78
- state : Optional [str ] = None ,
79
- ) -> RedirectResponse :
80
- login_uri = await self .get_login_url (request , params = params , state = state )
81
- return RedirectResponse (login_uri , 303 )
82
-
83
- async def get_token_data (
84
- self ,
85
- request : Request ,
86
- * ,
87
- params : Optional [Dict [str , Any ]] = None ,
88
- headers : Optional [Dict [str , Any ]] = None ,
89
- ) -> Optional [Dict [str , Any ]]:
90
- params = params or {}
91
- additional_headers = headers or {}
92
- additional_headers .update (self .additional_headers or {})
82
+ async def login_redirect (self , request : Request ) -> RedirectResponse :
83
+ return RedirectResponse (await self .get_login_url (request ), 303 )
84
+
85
+ async def get_token_data (self , request : Request ) -> Optional [Dict [str , Any ]]:
93
86
if not request .query_params .get ("code" ):
94
87
raise OAuth2LoginError (400 , "'code' parameter was not found in callback request" )
95
- if self . state != request .query_params .get ("state" ):
96
- raise OAuth2LoginError (400 , "'state' parameter does not match " )
88
+ if not request .query_params .get ("state" ):
89
+ raise OAuth2LoginError (400 , "'state' parameter was not found in callback request " )
97
90
98
91
url = request .url
99
- scheme = "http" if self . allow_http else "https"
92
+ scheme = "http" if request . auth . http else "https"
100
93
current_url = re .sub (r"^https?" , scheme , str (url ))
101
94
redirect_uri = self .get_redirect_uri (request )
102
95
@@ -105,36 +98,30 @@ async def get_token_data(
105
98
redirect_url = redirect_uri ,
106
99
authorization_response = current_url ,
107
100
code = request .query_params .get ("code" ),
108
- ** params ,
101
+ state = request . query_params . get ( "state" ) ,
109
102
)
110
103
111
- headers .update (additional_headers )
104
+ headers .update ({
105
+ "Accept" : "application/json" ,
106
+ "Content-Type" : "application/x-www-form-urlencoded" ,
107
+ })
112
108
auth = httpx .BasicAuth (self .client_id , self .client_secret )
113
109
async with httpx .AsyncClient () as session :
114
110
response = await session .post (token_url , headers = headers , content = content , auth = auth )
115
- self .oauth_client .parse_request_body_response (json .dumps (response .json ()))
116
-
117
- url , headers , _ = self .oauth_client .add_token (self .userinfo_endpoint )
118
- response = await session .get (url , headers = headers )
119
- content = response .json ()
120
-
121
- return {** content , "scope" : self .scope }
122
-
123
- async def token_redirect (
124
- self ,
125
- request : Request ,
126
- * ,
127
- params : Optional [Dict [str , Any ]] = None ,
128
- headers : Optional [Dict [str , Any ]] = None ,
129
- ) -> RedirectResponse :
130
- token_data = await self .get_token_data (request , params = params , headers = headers )
111
+ token = self .oauth_client .parse_request_body_response (json .dumps (response .json ()))
112
+ data = self .backend .user_data (token .get ("access_token" ))
113
+
114
+ return {** data , "scope" : self .scope }
115
+
116
+ async def token_redirect (self , request : Request ) -> RedirectResponse :
117
+ token_data = await self .get_token_data (request )
131
118
access_token = request .auth .jwt_create (token_data )
132
119
response = RedirectResponse (request .base_url )
133
120
response .set_cookie (
134
121
"Authorization" ,
135
122
value = f"Bearer { access_token } " ,
136
- httponly = self .allow_http ,
137
123
max_age = request .auth .expires ,
138
124
expires = request .auth .expires ,
125
+ httponly = request .auth .http ,
139
126
)
140
127
return response
0 commit comments