Skip to content

Commit 5a40010

Browse files
committed
Include all incoming response attributes into token
1 parent 8f0d17f commit 5a40010

File tree

6 files changed

+59
-83
lines changed

6 files changed

+59
-83
lines changed

examples/airnominal/.env

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
AIRNOMINAL_GITHUB_CLIENT_ID=eccd08d6736b7999a32a
2-
AIRNOMINAL_GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
3-
AIRNOMINAL_GITHUB_REDIRECT_URL=http://127.0.0.1:8000/auth/callback
4-
AIRNMONIAL_MAIN_PAGE_REDIRECT_URL=http://127.0.0.1:8000/
1+
GITHUB_CLIENT_ID=eccd08d6736b7999a32a
2+
GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
3+
GITHUB_REDIRECT_URL=http://127.0.0.1:8000/auth/callback
4+
MAIN_PAGE_REDIRECT_URL=http://127.0.0.1:8000/
55

6-
AIRNOMINAL_JWT_SECRET_KEY=secret
7-
AIRNOMINAL_JWT_ALGORITHM=HS256
8-
AIRNOMINAL_JWT_TOKEN_EXPIRES=300
6+
JWT_SECRET_KEY=secret
7+
JWT_ALGORITHM=HS256
8+
JWT_TOKEN_EXPIRES=300

examples/airnominal/auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
router = APIRouter()
2222

23-
# config for github SSO
23+
# config for GitHub SSO
2424
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
2525

2626
sso = GithubSSO(

examples/airnominal/config.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44

55
load_dotenv()
66

7-
# config for github SSO
8-
CLIENT_ID = os.getenv("AIRNOMINAL_GITHUB_CLIENT_ID")
9-
CLIENT_SECRET = os.getenv("AIRNOMINAL_GITHUB_CLIENT_SECRET")
10-
redirect_url = os.getenv("AIRNOMINAL_GITHUB_REDIRECT_URL")
11-
redirect_url_main_page = os.getenv("AIRNMONIAL_MAIN_PAGE_REDIRECT_URL")
7+
# config for GitHub SSO
8+
CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
9+
CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
10+
redirect_url = os.getenv("GITHUB_REDIRECT_URL")
11+
redirect_url_main_page = os.getenv("MAIN_PAGE_REDIRECT_URL")
1212

1313
# config for jwt generation
14-
SECRET_KEY = os.getenv("AIRNOMINAL_JWT_SECRET_KEY")
15-
ALGORITHM = os.getenv("AIRNOMINAL_JWT_ALGORITHM")
16-
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("AIRNOMINAL_JWT_TOKEN_EXPIRES"))
14+
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
15+
ALGORITHM = os.getenv("JWT_ALGORITHM")
16+
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_TOKEN_EXPIRES"))

examples/airnominal/fastapi_sso/base.py

Lines changed: 36 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1-
"""SSO login base dependency
2-
"""
3-
# pylint: disable=too-few-public-methods
1+
"""SSO login base dependency"""
42

53
import json
64
import sys
75
import warnings
86
from typing import Any, Dict, List, Optional
97

108
import httpx
11-
import pydantic
129
from oauthlib.oauth2 import WebApplicationClient
1310
from starlette.exceptions import HTTPException
1411
from starlette.requests import Request
@@ -34,19 +31,6 @@ class SSOLoginError(HTTPException):
3431
"""
3532

3633

37-
class OpenID(pydantic.BaseModel): # pylint: disable=no-member
38-
"""Class (schema) to represent information got from sso provider in a common form."""
39-
40-
id: Optional[str] = None
41-
email: Optional[str] = None
42-
first_name: Optional[str] = None
43-
last_name: Optional[str] = None
44-
display_name: Optional[str] = None
45-
picture: Optional[str] = None
46-
provider: Optional[str] = None
47-
48-
49-
# pylint: disable=too-many-instance-attributes
5034
class SSOBase:
5135
"""Base class (mixin) for all SSO providers"""
5236

@@ -59,15 +43,14 @@ class SSOBase:
5943
additional_headers: Optional[Dict[str, Any]] = None
6044

6145
def __init__(
62-
self,
63-
client_id: str,
64-
client_secret: str,
65-
redirect_uri: Optional[str] = None,
66-
allow_insecure_http: bool = False,
67-
use_state: bool = False,
68-
scope: Optional[List[str]] = None,
46+
self,
47+
client_id: str,
48+
client_secret: str,
49+
redirect_uri: Optional[str] = None,
50+
allow_insecure_http: bool = False,
51+
use_state: bool = False,
52+
scope: Optional[List[str]] = None,
6953
):
70-
# pylint: disable=too-many-arguments
7154
self.client_id = client_id
7255
self.client_secret = client_secret
7356
self.redirect_uri = redirect_uri
@@ -116,8 +99,8 @@ def refresh_token(self) -> Optional[str]:
11699
return self._refresh_token or self.oauth_client.refresh_token
117100

118101
@classmethod
119-
async def openid_from_response(cls, response: dict) -> OpenID:
120-
"""Return {OpenID} object from provider's user info endpoint response"""
102+
async def openid_from_response(cls, response: dict) -> dict:
103+
"""Return {dict} object from provider's user info endpoint response"""
121104
raise NotImplementedError(f"Provider {cls.provider} not supported")
122105

123106
async def get_discovery_document(self) -> DiscoveryDocument:
@@ -143,11 +126,11 @@ async def userinfo_endpoint(self) -> Optional[str]:
143126
return discovery.get("userinfo_endpoint")
144127

145128
async def get_login_url(
146-
self,
147-
*,
148-
redirect_uri: Optional[str] = None,
149-
params: Optional[Dict[str, Any]] = None,
150-
state: Optional[str] = None,
129+
self,
130+
*,
131+
redirect_uri: Optional[str] = None,
132+
params: Optional[Dict[str, Any]] = None,
133+
state: Optional[str] = None,
151134
) -> str:
152135
"""Return prepared login url. This is low-level, see {get_login_redirect} instead."""
153136
params = params or {}
@@ -160,11 +143,11 @@ async def get_login_url(
160143
return request_uri
161144

162145
async def get_login_redirect(
163-
self,
164-
*,
165-
redirect_uri: Optional[str] = None,
166-
params: Optional[Dict[str, Any]] = None,
167-
state: Optional[str] = None,
146+
self,
147+
*,
148+
redirect_uri: Optional[str] = None,
149+
params: Optional[Dict[str, Any]] = None,
150+
state: Optional[str] = None,
168151
) -> RedirectResponse:
169152
"""Return redirect response by Stalette to login page of Oauth SSO provider
170153
@@ -182,13 +165,13 @@ async def get_login_redirect(
182165
return response
183166

184167
async def verify_and_process(
185-
self,
186-
request: Request,
187-
*,
188-
params: Optional[Dict[str, Any]] = None,
189-
headers: Optional[Dict[str, Any]] = None,
190-
redirect_uri: Optional[str] = None,
191-
) -> Optional[OpenID]:
168+
self,
169+
request: Request,
170+
*,
171+
params: Optional[Dict[str, Any]] = None,
172+
headers: Optional[Dict[str, Any]] = None,
173+
redirect_uri: Optional[str] = None,
174+
) -> Optional[dict]:
192175
"""Get FastAPI (Starlette) Request object and process login.
193176
This handler should be used for your /callback path.
194177
@@ -197,7 +180,7 @@ async def verify_and_process(
197180
params {Optional[Dict[str, Any]]} -- Optional additional query parameters to pass to the provider
198181
199182
Returns:
200-
Optional[OpenID] -- OpenID if the login was successfull
183+
Optional[dict] -- dict if the login was successfully
201184
"""
202185
headers = headers or {}
203186
code = request.query_params.get("code")
@@ -209,22 +192,21 @@ async def verify_and_process(
209192
)
210193

211194
async def process_login(
212-
self,
213-
code: str,
214-
request: Request,
215-
*,
216-
params: Optional[Dict[str, Any]] = None,
217-
additional_headers: Optional[Dict[str, Any]] = None,
218-
redirect_uri: Optional[str] = None,
219-
) -> Optional[OpenID]:
195+
self,
196+
code: str,
197+
request: Request,
198+
*,
199+
params: Optional[Dict[str, Any]] = None,
200+
additional_headers: Optional[Dict[str, Any]] = None,
201+
redirect_uri: Optional[str] = None,
202+
) -> Optional[dict]:
220203
"""This method should be called from callback endpoint to verify the user and request user info endpoint.
221204
This is low level, you should use {verify_and_process} instead.
222205
223206
Arguments:
224207
params {Optional[Dict[str, Any]]} -- Optional additional query parameters to pass to the provider
225208
additional_headers {Optional[Dict[str, Any]]} -- Optional additional headers to be added to all requests
226209
"""
227-
# pylint: disable=too-many-locals
228210
params = params or {}
229211
additional_headers = additional_headers or {}
230212
additional_headers.update(self.additional_headers or {})
Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""Github SSO Oauth Helper class"""
1+
"""GitHub SSO Oauth Helper class"""
22

3-
from .base import DiscoveryDocument, OpenID, SSOBase
3+
from .base import DiscoveryDocument, SSOBase
44

55

66
class GithubSSO(SSOBase):
7-
"""Class providing login via Github SSO"""
7+
"""Class providing login via GitHub SSO"""
88

99
provider = "github"
1010
scope = ["user:email"]
@@ -18,11 +18,5 @@ async def get_discovery_document(self) -> DiscoveryDocument:
1818
}
1919

2020
@classmethod
21-
async def openid_from_response(cls, response: dict) -> OpenID:
22-
return OpenID(
23-
email=response["email"],
24-
provider=cls.provider,
25-
id=response["id"],
26-
display_name=response["login"],
27-
picture=response["avatar_url"],
28-
)
21+
async def openid_from_response(cls, response: dict) -> dict:
22+
return {**response, "provider": cls.provider}

examples/airnominal/templates/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
</head>
1010
<body>
1111
{% if request.user %}
12-
<h1>Hello, {{ request.user.display_name }}</h1>
13-
<img src="{{ request.user.picture }}" alt="Pic">
12+
<h1>Hello, {{ request.user.name }}</h1>
13+
<img src="{{ request.user.avatar_url }}" alt="Pic">
1414
<a href="/auth/logout">Logout</a>
1515
{% else %}
1616
<a href="/auth/login">Sign in</a>

0 commit comments

Comments
 (0)