1
- """SSO login base dependency
2
- """
3
- # pylint: disable=too-few-public-methods
1
+ """SSO login base dependency"""
4
2
5
3
import json
6
4
import sys
7
5
import warnings
8
6
from typing import Any , Dict , List , Optional
9
7
10
8
import httpx
11
- import pydantic
12
9
from oauthlib .oauth2 import WebApplicationClient
13
10
from starlette .exceptions import HTTPException
14
11
from starlette .requests import Request
@@ -34,19 +31,6 @@ class SSOLoginError(HTTPException):
34
31
"""
35
32
36
33
37
- class OpenID (pydantic .BaseModel ): # pylint: disable=no-member
38
- """Class (schema) to represent information got from sso provider in a common form."""
39
-
40
- id : Optional [str ] = None
41
- email : Optional [str ] = None
42
- first_name : Optional [str ] = None
43
- last_name : Optional [str ] = None
44
- display_name : Optional [str ] = None
45
- picture : Optional [str ] = None
46
- provider : Optional [str ] = None
47
-
48
-
49
- # pylint: disable=too-many-instance-attributes
50
34
class SSOBase :
51
35
"""Base class (mixin) for all SSO providers"""
52
36
@@ -59,15 +43,14 @@ class SSOBase:
59
43
additional_headers : Optional [Dict [str , Any ]] = None
60
44
61
45
def __init__ (
62
- self ,
63
- client_id : str ,
64
- client_secret : str ,
65
- redirect_uri : Optional [str ] = None ,
66
- allow_insecure_http : bool = False ,
67
- use_state : bool = False ,
68
- scope : Optional [List [str ]] = None ,
46
+ self ,
47
+ client_id : str ,
48
+ client_secret : str ,
49
+ redirect_uri : Optional [str ] = None ,
50
+ allow_insecure_http : bool = False ,
51
+ use_state : bool = False ,
52
+ scope : Optional [List [str ]] = None ,
69
53
):
70
- # pylint: disable=too-many-arguments
71
54
self .client_id = client_id
72
55
self .client_secret = client_secret
73
56
self .redirect_uri = redirect_uri
@@ -116,8 +99,8 @@ def refresh_token(self) -> Optional[str]:
116
99
return self ._refresh_token or self .oauth_client .refresh_token
117
100
118
101
@classmethod
119
- async def openid_from_response (cls , response : dict ) -> OpenID :
120
- """Return {OpenID } object from provider's user info endpoint response"""
102
+ async def openid_from_response (cls , response : dict ) -> dict :
103
+ """Return {dict } object from provider's user info endpoint response"""
121
104
raise NotImplementedError (f"Provider { cls .provider } not supported" )
122
105
123
106
async def get_discovery_document (self ) -> DiscoveryDocument :
@@ -143,11 +126,11 @@ async def userinfo_endpoint(self) -> Optional[str]:
143
126
return discovery .get ("userinfo_endpoint" )
144
127
145
128
async def get_login_url (
146
- self ,
147
- * ,
148
- redirect_uri : Optional [str ] = None ,
149
- params : Optional [Dict [str , Any ]] = None ,
150
- state : Optional [str ] = None ,
129
+ self ,
130
+ * ,
131
+ redirect_uri : Optional [str ] = None ,
132
+ params : Optional [Dict [str , Any ]] = None ,
133
+ state : Optional [str ] = None ,
151
134
) -> str :
152
135
"""Return prepared login url. This is low-level, see {get_login_redirect} instead."""
153
136
params = params or {}
@@ -160,11 +143,11 @@ async def get_login_url(
160
143
return request_uri
161
144
162
145
async def get_login_redirect (
163
- self ,
164
- * ,
165
- redirect_uri : Optional [str ] = None ,
166
- params : Optional [Dict [str , Any ]] = None ,
167
- state : Optional [str ] = None ,
146
+ self ,
147
+ * ,
148
+ redirect_uri : Optional [str ] = None ,
149
+ params : Optional [Dict [str , Any ]] = None ,
150
+ state : Optional [str ] = None ,
168
151
) -> RedirectResponse :
169
152
"""Return redirect response by Stalette to login page of Oauth SSO provider
170
153
@@ -182,13 +165,13 @@ async def get_login_redirect(
182
165
return response
183
166
184
167
async def verify_and_process (
185
- self ,
186
- request : Request ,
187
- * ,
188
- params : Optional [Dict [str , Any ]] = None ,
189
- headers : Optional [Dict [str , Any ]] = None ,
190
- redirect_uri : Optional [str ] = None ,
191
- ) -> Optional [OpenID ]:
168
+ self ,
169
+ request : Request ,
170
+ * ,
171
+ params : Optional [Dict [str , Any ]] = None ,
172
+ headers : Optional [Dict [str , Any ]] = None ,
173
+ redirect_uri : Optional [str ] = None ,
174
+ ) -> Optional [dict ]:
192
175
"""Get FastAPI (Starlette) Request object and process login.
193
176
This handler should be used for your /callback path.
194
177
@@ -197,7 +180,7 @@ async def verify_and_process(
197
180
params {Optional[Dict[str, Any]]} -- Optional additional query parameters to pass to the provider
198
181
199
182
Returns:
200
- Optional[OpenID ] -- OpenID if the login was successfull
183
+ Optional[dict ] -- dict if the login was successfully
201
184
"""
202
185
headers = headers or {}
203
186
code = request .query_params .get ("code" )
@@ -209,22 +192,21 @@ async def verify_and_process(
209
192
)
210
193
211
194
async def process_login (
212
- self ,
213
- code : str ,
214
- request : Request ,
215
- * ,
216
- params : Optional [Dict [str , Any ]] = None ,
217
- additional_headers : Optional [Dict [str , Any ]] = None ,
218
- redirect_uri : Optional [str ] = None ,
219
- ) -> Optional [OpenID ]:
195
+ self ,
196
+ code : str ,
197
+ request : Request ,
198
+ * ,
199
+ params : Optional [Dict [str , Any ]] = None ,
200
+ additional_headers : Optional [Dict [str , Any ]] = None ,
201
+ redirect_uri : Optional [str ] = None ,
202
+ ) -> Optional [dict ]:
220
203
"""This method should be called from callback endpoint to verify the user and request user info endpoint.
221
204
This is low level, you should use {verify_and_process} instead.
222
205
223
206
Arguments:
224
207
params {Optional[Dict[str, Any]]} -- Optional additional query parameters to pass to the provider
225
208
additional_headers {Optional[Dict[str, Any]]} -- Optional additional headers to be added to all requests
226
209
"""
227
- # pylint: disable=too-many-locals
228
210
params = params or {}
229
211
additional_headers = additional_headers or {}
230
212
additional_headers .update (self .additional_headers or {})
0 commit comments