diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index a66fed3014d..53d05d0383b 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -2,7 +2,7 @@ import inspect import logging import os -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TypeVar, cast from ..shared import constants from ..shared.functions import resolve_truthy_env_var_choice @@ -11,9 +11,14 @@ logger = logging.getLogger(__name__) +FuncType = TypeVar("FuncType", bound=Callable[..., Any]) + # Maintenance: we can't yet provide an accurate return type without ParamSpec etc. see #1066 -def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_execution: Optional[bool] = None) -> Callable: +def lambda_handler_decorator( + decorator: Optional[Callable[..., Any]] = None, + trace_execution: Optional[bool] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorator factory for decorating Lambda handlers. You can use lambda_handler_decorator to create your own middlewares, @@ -34,7 +39,7 @@ def lambda_handler_decorator(decorator: Optional[Callable] = None, trace_executi Parameters ---------- - decorator: Callable + decorator: Callable[..., Any] Middleware to be wrapped by this factory trace_execution: bool Flag to explicitly enable trace execution for middlewares.\n @@ -104,7 +109,10 @@ def lambda_handler(event, context): """ if decorator is None: - return functools.partial(lambda_handler_decorator, trace_execution=trace_execution) + return cast( + Callable[[Callable[..., Any]], Callable[..., Any]], + functools.partial(lambda_handler_decorator, trace_execution=trace_execution), + ) trace_execution = resolve_truthy_env_var_choice( env=os.getenv(constants.MIDDLEWARE_FACTORY_TRACE_ENV, "false"), @@ -112,10 +120,10 @@ def lambda_handler(event, context): ) @functools.wraps(decorator) - def final_decorator(func: Optional[Callable] = None, **kwargs: Any): + def final_decorator(*args: Any, func: Optional[FuncType] = None, **kwargs: Any) -> FuncType: # If called with kwargs return new func with kwargs if func is None: - return functools.partial(final_decorator, **kwargs) + return cast(FuncType, functools.partial(final_decorator, *args, **kwargs)) if not inspect.isfunction(func): # @custom_middleware(True) vs @custom_middleware(log_event=True) @@ -124,9 +132,11 @@ def final_decorator(func: Optional[Callable] = None, **kwargs: Any): ) @functools.wraps(func) - def wrapper(event, context, **handler_kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: try: - middleware = functools.partial(decorator, func, event, context, **kwargs, **handler_kwargs) + if decorator is None: + raise ValueError("Decorator cannot be None") + middleware = functools.partial(decorator, func, *args, **kwargs) if trace_execution: tracer = Tracer(auto_patch=False) with tracer.provider.in_subsegment(name=f"## {decorator.__qualname__}"): @@ -135,9 +145,11 @@ def wrapper(event, context, **handler_kwargs): response = middleware() return response except Exception: - logger.exception(f"Caught exception in {decorator.__qualname__}") + logger.exception( + f"Caught exception in {decorator.__qualname__ if decorator is not None else 'UnknownDecorator'}", + ) raise - return wrapper + return cast(FuncType, wrapper) return final_decorator diff --git a/aws_lambda_powertools/utilities/data_classes/event_source.py b/aws_lambda_powertools/utilities/data_classes/event_source.py index 3968f923573..e7c3630a310 100644 --- a/aws_lambda_powertools/utilities/data_classes/event_source.py +++ b/aws_lambda_powertools/utilities/data_classes/event_source.py @@ -11,7 +11,7 @@ def event_source( event: Dict[str, Any], context: LambdaContext, data_class: Type[DictWrapper], -): +) -> Any: """Middleware to create an instance of the passed in event source data class Parameters