Skip to content

feat(event_handler): add ability to expose a Swagger UI #3254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 214 additions & 23 deletions aws_lambda_powertools/event_handler/api_gateway.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions aws_lambda_powertools/event_handler/lambda_function_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ def __init__(
strip_prefixes,
enable_validation,
)

def _get_base_path(self) -> str:
stage = self.current_event.request_context.stage
return f"/{stage}" if stage and stage != "$default" else "/"
Original file line number Diff line number Diff line change
Expand Up @@ -94,20 +94,31 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
else:
# Re-write the route_args with the validated values, and call the next middleware
app.context["_route_args"] = values
response = next_middleware(app)

# Process the response body if it exists
raw_response = jsonable_encoder(response.body)
# Call the handler by calling the next middleware
response = next_middleware(app)

# Validate and serialize the response
return self._serialize_response(field=route.dependant.return_param, response_content=raw_response)
# Process the response
return self._handle_response(route=route, response=response)
except RequestValidationError as e:
return Response(
status_code=422,
content_type="application/json",
body=json.dumps({"detail": e.errors()}),
)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
if response.body:
# Validate and serialize the response, if it's JSON
if response.is_json():
response.body = json.dumps(
self._serialize_response(field=route.dependant.return_param, response_content=response.body),
sort_keys=True,
)

return response

def _serialize_response(
self,
*,
Expand Down
32 changes: 16 additions & 16 deletions aws_lambda_powertools/event_handler/openapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,24 @@ class Config:
extra = "allow"


# https://swagger.io/specification/#tag-object
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None

if PYDANTIC_V2:
model_config = {"extra": "allow"}

else:

class Config:
extra = "allow"


# https://swagger.io/specification/#operation-object
class Operation(BaseModel):
tags: Optional[List[str]] = None
tags: Optional[List[Tag]] = None
summary: Optional[str] = None
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
Expand Down Expand Up @@ -540,21 +555,6 @@ class Config:
extra = "allow"


# https://swagger.io/specification/#tag-object
class Tag(BaseModel):
name: str
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None

if PYDANTIC_V2:
model_config = {"extra": "allow"}

else:

class Config:
extra = "allow"


# https://swagger.io/specification/#openapi-object
class OpenAPI(BaseModel):
openapi: str
Expand Down
Empty file.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions aws_lambda_powertools/event_handler/vpc_lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def __init__(
"""Amazon VPC Lattice resolver"""
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)

def _get_base_path(self) -> str:
return "/"


class VPCLatticeV2Resolver(ApiGatewayResolver):
"""VPC Lattice resolver
Expand Down Expand Up @@ -98,3 +101,6 @@ def __init__(
):
"""Amazon VPC Lattice resolver"""
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)

def _get_base_path(self) -> str:
return "/"
23 changes: 23 additions & 0 deletions tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,29 @@ def handler(event, context):
assert headers["Content-Encoding"] == ["gzip"]


def test_response_is_json_without_content_type():
response = Response(200, None, "")

assert response.is_json() is False


def test_response_is_json_with_json_content_type():
response = Response(200, content_types.APPLICATION_JSON, "")
assert response.is_json() is True


def test_response_is_json_with_multiple_json_content_types():
response = Response(
200,
None,
"",
{
"Content-Type": [content_types.APPLICATION_JSON, content_types.APPLICATION_JSON],
},
)
assert response.is_json() is True


def test_compress():
# GIVEN a function that has compress=True
# AND an event with a "Accept-Encoding" that include gzip
Expand Down
103 changes: 103 additions & 0 deletions tests/functional/event_handler/test_base_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from aws_lambda_powertools.event_handler import (
ALBResolver,
APIGatewayHttpResolver,
APIGatewayRestResolver,
LambdaFunctionUrlResolver,
VPCLatticeResolver,
VPCLatticeV2Resolver,
)
from tests.functional.utils import load_event


def test_base_path_api_gateway_rest():
app = APIGatewayRestResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("apiGatewayProxyEvent.json")
event["path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == '"/"'


def test_base_path_api_gateway_http():
app = APIGatewayHttpResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("apiGatewayProxyV2Event.json")
event["rawPath"] = "/"
event["requestContext"]["http"]["path"] = "/"
event["requestContext"]["http"]["method"] = "GET"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == '"/"'


def test_base_path_alb():
app = ALBResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("albEvent.json")
event["path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == '"/"'


def test_base_path_lambda_function_url():
app = LambdaFunctionUrlResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("lambdaFunctionUrlIAMEvent.json")
event["rawPath"] = "/"
event["requestContext"]["http"]["path"] = "/"
event["requestContext"]["http"]["method"] = "GET"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == '"/"'


def test_vpc_lattice():
app = VPCLatticeResolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("vpcLatticeEvent.json")
event["raw_path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == '"/"'


def test_vpc_latticev2():
app = VPCLatticeV2Resolver(enable_validation=True)

@app.get("/")
def handle():
return app._get_base_path()

event = load_event("vpcLatticeV2Event.json")
event["path"] = "/"

result = app(event, {})
assert result["statusCode"] == 200
assert result["body"] == '"/"'
93 changes: 92 additions & 1 deletion tests/functional/event_handler/test_openapi_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Parameter,
ParameterInType,
Schema,
Tag,
)
from aws_lambda_powertools.event_handler.openapi.params import (
Body,
Expand Down Expand Up @@ -118,7 +119,7 @@ def handler(
assert get.summary == "Get Users"
assert get.operationId == "GetUsers"
assert get.description == "Get paginated users"
assert get.tags == ["Users"]
assert get.tags == [Tag(name="Users")]

parameter = get.parameters[0]
assert parameter.required is False
Expand Down Expand Up @@ -152,6 +153,54 @@ def handler() -> str:
assert response.schema_.type == "string"


def test_openapi_with_omitted_param():
app = APIGatewayRestResolver()

@app.get("/")
def handler(page: Annotated[str, Query(include_in_schema=False)]):
return page

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert get.parameters is None


def test_openapi_with_description():
app = APIGatewayRestResolver()

@app.get("/")
def handler(page: Annotated[str, Query(description="This is a description")]):
return page

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert len(get.parameters) == 1

parameter = get.parameters[0]
assert parameter.description == "This is a description"


def test_openapi_with_deprecated():
app = APIGatewayRestResolver()

@app.get("/")
def handler(page: Annotated[str, Query(deprecated=True)]):
return page

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/"].get
assert len(get.parameters) == 1

parameter = get.parameters[0]
assert parameter.deprecated is True


def test_openapi_with_pydantic_returns():
app = APIGatewayRestResolver()

Expand Down Expand Up @@ -283,6 +332,48 @@ def handler(user: Annotated[User, Body(embed=True)]):
assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User"


def test_openapi_with_tags():
app = APIGatewayRestResolver()

@app.get("/users")
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema(tags=["Orders"])
assert len(schema.tags) == 1

tag = schema.tags[0]
assert tag.name == "Orders"


def test_openapi_operation_with_tags():
app = APIGatewayRestResolver()

@app.get("/users", tags=["Users"])
def handler():
raise NotImplementedError()

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1

get = schema.paths["/users"].get
assert len(get.tags) == 1

tag = get.tags[0]
assert tag.name == "Users"


def test_openapi_with_excluded_operations():
app = APIGatewayRestResolver()

@app.get("/secret", include_in_schema=False)
def secret():
return "password"

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 0


def test_create_header():
header = Header(convert_underscores=True)
assert header.convert_underscores is True
Expand Down
Loading