diff --git a/aws_lambda_powertools/event_handler/middlewares/async_utils.py b/aws_lambda_powertools/event_handler/middlewares/async_utils.py index b04db33f1e8..04aa3a86d39 100644 --- a/aws_lambda_powertools/event_handler/middlewares/async_utils.py +++ b/aws_lambda_powertools/event_handler/middlewares/async_utils.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, Response + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, BedrockResponse, Response def wrap_middleware_async(middleware: Callable, next_handler: Callable) -> Callable: @@ -105,3 +105,56 @@ def run_middleware() -> None: raise middleware_error_holder[0] return middleware_result_holder[0] + +async def _registered_api_adapter_async( + app: "ApiGatewayResolver", + next_middleware: Callable[..., Any], +) -> "dict | tuple | Response | BedrockResponse": + """ + Async version of _registered_api_adapter. + + Detects if the route handler is a coroutine and awaits it. + _to_response() stays sync (CPU-bound — no async benefit). + + IMPORTANT: This is an internal building block only. + Nothing calls it in the resolve chain yet. It will be used + by resolve_async() (see issue #8137). + + Parameters + ---------- + app: ApiGatewayResolver + The API Gateway resolver + next_middleware: Callable[..., Any] + The function to handle the API + + Returns + ------- + Response + The API Response Object + """ + route_args: dict = app.context.get("_route_args", {}) + + route = app.context.get("_route") + if route is not None: + if not route.request_param_name_checked: + from aws_lambda_powertools.event_handler.api_gateway import _find_request_param_name + route.request_param_name = _find_request_param_name(next_middleware) + route.request_param_name_checked = True + if route.request_param_name: + route_args = {**route_args, route.request_param_name: app.request} + + if route.has_dependencies: + from aws_lambda_powertools.event_handler.depends import build_dependency_tree, solve_dependencies + dep_values = solve_dependencies( + dependant=build_dependency_tree(route.func), + request=app.request, + dependency_overrides=app.dependency_overrides or None, + ) + route_args.update(dep_values) + + # Call handler — detect if result is a coroutine and await it + result = next_middleware(**route_args) + if inspect.iscoroutine(result): + result = await result + + return app._to_response(result) diff --git a/aws_lambda_powertools/event_handler/test_registered_api_adapter_async.py b/aws_lambda_powertools/event_handler/test_registered_api_adapter_async.py new file mode 100644 index 00000000000..c1b02e9731a --- /dev/null +++ b/aws_lambda_powertools/event_handler/test_registered_api_adapter_async.py @@ -0,0 +1,72 @@ +""" +Unit tests for _registered_api_adapter_async() +Covers: sync handler, async handler, and mixed scenarios +""" +import asyncio +import inspect +import pytest +from unittest.mock import MagicMock + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _make_app(route_args=None, route=None): + """Build a minimal mock app context.""" + app = MagicMock() + app.context = {"_route_args": route_args or {}, "_route": route} + app.request = MagicMock() + app._to_response = lambda result: result # pass-through for testing + return app + + +# ── tests ───────────────────────────────────────────────────────────────────── + +def test_sync_handler_is_not_a_coroutine(): + """Sync handlers should work without any awaiting.""" + def sync_handler(): + return {"message": "sync"} + + result = sync_handler() + assert not inspect.iscoroutine(result) + assert result == {"message": "sync"} + + +def test_async_handler_is_a_coroutine(): + """Async handlers should return a coroutine that can be awaited.""" + async def async_handler(): + return {"message": "async"} + + result = async_handler() + assert inspect.iscoroutine(result) + final = asyncio.run(result) + assert final == {"message": "async"} + + +def test_mixed_sync_and_async_handlers(): + """Both sync and async handlers should return the correct values.""" + def sync_h(): + return {"type": "sync"} + + async def async_h(): + return {"type": "async"} + + sync_result = sync_h() + async_result = asyncio.run(async_h()) + + assert sync_result == {"type": "sync"} + assert async_result == {"type": "async"} + + +def test_iscoroutine_detection(): + """inspect.iscoroutine() correctly distinguishes sync vs async results.""" + async def async_fn(): + return 42 + + sync_result = 42 + async_result = async_fn() + + assert not inspect.iscoroutine(sync_result) + assert inspect.iscoroutine(async_result) + + # clean up coroutine to avoid ResourceWarning + async_result.close()