diff --git a/CHANGELOG.md b/CHANGELOG.md index d0cb400..d73ed79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Removed "convert camel case" option. - Removed the custom exceptions. From methods, return an ErrorResponse instead of raising an exception. +- Methods should now return a Response object, and not raise exceptions. Refactoring/internal changes: diff --git a/doc/api.md b/doc/api.md index c837de6..213f024 100644 --- a/doc/api.md +++ b/doc/api.md @@ -133,34 +133,38 @@ trim_log_values = yes ## Errors -The library handles most errors related to the JSON-RPC standard. +The library handles some errors related to the JSON-RPC standard, such as +invalid json or invalid json-rpc requests. + +To return a custom error response: ```python -from jsonrpcserver.exceptions import InvalidParamsError +from jsonrpcserver.response import Context, InvalidParamsResponse, SuccessResponse @method -def fruits(color): +def fruits(context: Context, color: str) -> Union[SuccessResponse, InvalidParamsResponse]: if color not in ("red", "orange", "yellow"): - raise InvalidParamsError("No fruits of that colour") + return InvalidParamsResponse("No fruits of that colour", id=context.request.id) + return SuccessResponse("blue", id=context.request.id) ``` The dispatcher will give the appropriate response: ```python ->>> str(dispatch('{"jsonrpc": "2.0", "method": "fruits", "params": {"color": "blue"}, "id": 1}')) -'{"jsonrpc": "2.0", "error": {"code": -32602, "message": "Invalid parameters"}, "id": 1}' +>>> dispatch('{"jsonrpc": "2.0", "method": "fruits", "params": {"color": "blue"}, "id": 1}') +InvalidParamsResponse(code=-32602, message='Invalid parameters', id=1) ``` -To send some other application-defined error response, raise an `ApiError` in a -similar way. +To send some other application-defined error response, return an +`ApiErrorResponse` in a similar way. ```python -from jsonrpcserver.exceptions import ApiError +from jsonrpcserver.response import ApiErrorResponse @method def my_method(): if some_condition: - raise ApiError("Can't fulfill the request") + return ApiErrorResponse("Can't fulfill the request") ``` ## Async diff --git a/jsonrpcserver/async_dispatcher.py b/jsonrpcserver/async_dispatcher.py index 4d9f54f..d394477 100644 --- a/jsonrpcserver/async_dispatcher.py +++ b/jsonrpcserver/async_dispatcher.py @@ -1,6 +1,8 @@ """Asynchronous dispatch""" + import asyncio import collections.abc +import logging from json import JSONDecodeError from json import dumps as default_serialize, loads as default_deserialize from typing import Any, Iterable, Optional, Union, Callable @@ -13,53 +15,73 @@ add_handlers, config, create_requests, - handle_exceptions, log_request, log_response, remove_handlers, schema, validate, ) -from .methods import Method, Methods, global_methods, validate_args, lookup -from .request import Request +from .methods import Methods, global_methods, validate_args +from .request import Request, is_notification from .response import ( BatchResponse, + ExceptionResponse, InvalidJSONResponse, InvalidJSONRPCResponse, + InvalidParamsResponse, + NotificationResponse, Response, SuccessResponse, ) -async def call(method: Method, *args: Any, **kwargs: Any) -> Any: - return await validate_args(method, *args, **kwargs)(*args, **kwargs) +async def call(request, method, *args, **kwargs) -> Response: + errors = validate_args(method, *args, **kwargs) + return ( + await method(*args, **kwargs) + if not errors + else InvalidParamsResponse(data=errors, id=request.id) + ) async def safe_call( request: Request, methods: Methods, *, extra: Any, serialize: Callable ) -> Response: - with handle_exceptions(request) as handler: - if isinstance(request.params, list): - result = await call( - lookup(methods, request.method), + try: + result = ( + await call( + methods.items[request.method], *([Context(request=request, extra=extra)] + request.params), ) - else: - result = await call( - lookup(methods, request.method), + if isinstance(request.params, list) + else await call( + methods.items[request.method], Context(request=request, extra=extra), **request.params, ) + ) # Ensure value returned from the method is JSON-serializable. If not, # handle_exception will set handler.response to an ExceptionResponse serialize(result) - handler.response = SuccessResponse( - result=result, id=request.id, serialize_func=serialize + except asyncio.CancelledError: + # Allow CancelledError from asyncio task cancellation to bubble up. Without + # this, CancelledError is caught and handled, resulting in a "Server error" + # response object from the dispatcher, but because the CancelledError doesn't + # bubble up the rpc_server task doesn't exit. See PR + # https://github.com/bcb/jsonrpcserver/pull/132 + raise + except Exception as exc: # Other error inside method - server error + logging.exception(exc) + return ExceptionResponse(exc, id=request.id) + else: + return ( + NotificationResponse() + if is_notification(request) + else SuccessResponse(result=result, id=request.id, serialize_func=serialize) ) - return handler.response -async def call_requests( +async def dispatch_requests( requests: Union[Request, Iterable[Request]], methods: Methods, extra: Any, @@ -87,7 +109,7 @@ async def dispatch_pure( return InvalidJSONResponse(data=str(exc)) except ValidationError as exc: return InvalidJSONRPCResponse(data=None) - return await call_requests( + return await dispatch_requests( create_requests(deserialized), methods, extra=extra, diff --git a/jsonrpcserver/dispatcher.py b/jsonrpcserver/dispatcher.py index ab413e1..2c623b7 100644 --- a/jsonrpcserver/dispatcher.py +++ b/jsonrpcserver/dispatcher.py @@ -4,20 +4,16 @@ The dispatch() function takes a JSON-RPC request, logs it, calls the appropriate method, then logs and returns the response. """ -import asyncio import logging import os from collections.abc import Iterable from configparser import ConfigParser -from contextlib import contextmanager from json import JSONDecodeError from json import dumps as default_serialize, loads as default_deserialize -from types import SimpleNamespace from typing import ( Any, Callable, Dict, - Generator, Iterable, List, NamedTuple, @@ -33,10 +29,9 @@ from pkg_resources import resource_string from .log import log_ -from .methods import Method, Methods, global_methods, validate_args, lookup +from .methods import Methods, global_methods, validate_args from .request import Request, is_notification, NOID from .response import ( - ApiErrorResponse, BatchResponse, ExceptionResponse, InvalidJSONResponse, @@ -45,30 +40,29 @@ MethodNotFoundResponse, NotificationResponse, Response, - SuccessResponse, ) -from .exceptions import MethodNotFoundError, InvalidParamsError, ApiError + +Context = NamedTuple( + "Context", + [("request", Request), ("extra", Any)], +) request_logger = logging.getLogger(__name__ + ".request") response_logger = logging.getLogger(__name__ + ".response") +DEFAULT_REQUEST_LOG_FORMAT = "--> %(message)s" +DEFAULT_RESPONSE_LOG_FORMAT = "<-- %(message)s" + # Prepare the jsonschema validator schema = default_deserialize(resource_string(__name__, "request-schema.json")) klass = validator_for(schema) klass.check_schema(schema) validator = klass(schema) -DEFAULT_REQUEST_LOG_FORMAT = "--> %(message)s" -DEFAULT_RESPONSE_LOG_FORMAT = "<-- %(message)s" - +# Read configuration file config = ConfigParser(default_section="dispatch") config.read([".jsonrpcserverrc", os.path.expanduser("~/.jsonrpcserverrc")]) -Context = NamedTuple( - "Context", - [("request", Request), ("extra", Any)], -) - def add_handlers() -> Tuple[logging.Handler, logging.Handler]: # Request handler @@ -120,49 +114,30 @@ def validate(request: Union[Dict, List], schema: dict) -> Union[Dict, List]: return request -def call(method: Method, *args: Any, **kwargs: Any) -> Any: - """ - Validates arguments and then calls the method. - - Args: - method: The method to call. - *args, **kwargs: Arguments to the method. - - Returns: - The "result" part of the JSON-RPC response (the return value from the method). - """ - return validate_args(method, *args, **kwargs)(*args, **kwargs) +def c(request, method, *args, **kwargs) -> Response: + errors = validate_args(method, *args, **kwargs) + return ( + method(*args, **kwargs) + if not errors + else InvalidParamsResponse(data=errors, id=request.id) + ) -@contextmanager -def handle_exceptions(request: Request) -> Generator: - handler = SimpleNamespace(response=None) - try: - yield handler - except MethodNotFoundError: - handler.response = MethodNotFoundResponse(id=request.id, data=request.method) - except (InvalidParamsError, AssertionError) as exc: - # InvalidParamsError is raised by validate_args. AssertionError is raised inside - # the methods, however it's better to raise InvalidParamsError inside methods. - # AssertionError will be removed in the next major release. - handler.response = InvalidParamsResponse(id=request.id, data=str(exc)) - except ApiError as exc: # Method signals custom error - handler.response = ApiErrorResponse( - str(exc), code=exc.code, data=exc.data, id=request.id +def call(request: Request, method: Callable, *, extra: Any) -> Response: + return ( + c( + request, + method, + *([Context(request=request, extra=extra)] + request.params), + ) + if isinstance(request.params, list) + else c( + request, + method, + Context(request=request, extra=extra), + **request.params, ) - except asyncio.CancelledError: - # Allow CancelledError from asyncio task cancellation to bubble up. Without - # this, CancelledError is caught and handled, resulting in a "Server error" - # response object from the dispatcher, but because the CancelledError doesn't - # bubble up the rpc_server task doesn't exit. See PR - # https://github.com/bcb/jsonrpcserver/pull/132 - raise - except Exception as exc: # Other error inside method - server error - logging.exception(exc) - handler.response = ExceptionResponse(exc, id=request.id) - finally: - if is_notification(request): - handler.response = NotificationResponse() + ) def safe_call( @@ -179,28 +154,19 @@ def safe_call( Returns: A Response object. """ - with handle_exceptions(request) as handler: - if isinstance(request.params, list): - result = call( - lookup(methods, request.method), - *([Context(request=request, extra=extra)] + request.params), - ) + if request.method in methods.items: + try: + response = call(request, methods.items[request.method], extra=extra) + except Exception as exc: # Other error inside method - server error + logging.exception(exc) + return ExceptionResponse(exc, id=request.id) else: - result = call( - lookup(methods, request.method), - Context(request=request, extra=extra), - **request.params, - ) - # Ensure value returned from the method is JSON-serializable. If not, - # handle_exception will set handler.response to an ExceptionResponse - serialize(result) - handler.response = SuccessResponse( - result=result, id=request.id, serialize_func=serialize - ) - return handler.response + return NotificationResponse() if is_notification(request) else response + else: + return MethodNotFoundResponse(data=request.method, id=request.id) -def call_requests( +def dispatch_requests_pure( requests: Union[Request, Iterable[Request]], methods: Methods, *, @@ -233,6 +199,19 @@ def call_requests( ) +def dispatch_requests( + requests: Union[Request, Iterable[Request]], + methods: Methods, + *, + extra: Optional[Any] = None, + serialize: Callable = default_serialize, +) -> Response: + """ + Impure (public) version of dispatch_requests_pure - has default values. + """ + return dispatch_requests_pure(requests, methods, extra=extra, serialize=serialize) + + def create_requests( requests: Union[Dict, List[Dict]], ) -> Union[Request, List[Request]]: @@ -293,12 +272,13 @@ def dispatch_pure( return InvalidJSONResponse(data=str(exc)) except ValidationError as exc: return InvalidJSONRPCResponse(data=None) - return call_requests( - create_requests(deserialized), - methods=methods, - extra=extra, - serialize=serialize, - ) + else: + return dispatch_requests_pure( + create_requests(deserialized), + methods=methods, + extra=extra, + serialize=serialize, + ) @apply_config(config) diff --git a/jsonrpcserver/exceptions.py b/jsonrpcserver/exceptions.py deleted file mode 100644 index 56811ff..0000000 --- a/jsonrpcserver/exceptions.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Exceptions""" -from typing import Any - -from .response import UNSPECIFIED - - -class MethodNotFoundError(Exception): - """ Method lookup failed """ - - pass - - -class InvalidParamsError(Exception): - """ Method arguments invalid """ - - pass - - -class ApiError(Exception): - """ A method responds with a custom error """ - - def __init__(self, message: str, code: int = 1, data: Any = UNSPECIFIED): - """ - Args: - message: A string providing a short description of the error, eg. "Invalid - params". - code: A Number that indicates the error type that occurred. This MUST be an - integer. - data: A Primitive or Structured value that contains additional information - about the error. This may be omitted. - """ - super().__init__(message) - self.code = code - self.data = data diff --git a/jsonrpcserver/methods.py b/jsonrpcserver/methods.py index c88bdad..4b93964 100644 --- a/jsonrpcserver/methods.py +++ b/jsonrpcserver/methods.py @@ -10,33 +10,27 @@ from inspect import signature -from .exceptions import MethodNotFoundError, InvalidParamsError - Method = Callable[..., Any] -def validate_args(func: Method, *args: Any, **kwargs: Any) -> Method: +def validate_args(func: Method, *args: Any, **kwargs: Any) -> str: """ Check if the request's arguments match a function's signature. - Raises InvalidParamsError if arguments cannot be passed to a function. - Args: func: The function to check. args: Positional arguments. kwargs: Keyword arguments. Returns: - The same function passed in. - - Raises: - InvalidParamsError: If the arguments cannot be passed to the function. + An empty string if arguments can be passed to a function, an error + message otherwise. """ try: signature(func).bind(*args, **kwargs) except TypeError as exc: - raise InvalidParamsError(exc) from exc - return func + return str(exc) + return "" class Methods: @@ -106,26 +100,6 @@ def _batch_add(self, *args: Any, **kwargs: Any) -> Optional[Callable]: return None -def lookup(methods: Methods, method_name: str) -> Method: - """ - Lookup a method. - - Args: - methods: Methods object - method_name: Method name to look up - - Returns: - The callable method. - - Raises: - MethodNotFoundError if method_name is not found. - """ - try: - return methods.items[method_name] - except KeyError as exc: - raise MethodNotFoundError(method_name) from exc - - # A default Methods object which can be used, or user can create their own. global_methods = Methods() diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 9d7ad36..9c8636a 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -1,3 +1,4 @@ +"""TODO: Add tests for dispatch_requests (non-pure version)""" import logging from json import dumps as serialize from unittest.mock import sentinel @@ -5,7 +6,7 @@ from jsonrpcserver.dispatcher import ( Context, add_handlers, - call_requests, + dispatch_requests_pure, create_requests, dispatch, dispatch_pure, @@ -19,6 +20,7 @@ from jsonrpcserver.methods import Methods, global_methods from jsonrpcserver.request import Request, NOID from jsonrpcserver.response import ( + ApiErrorResponse, BatchResponse, ErrorResponse, InvalidJSONResponse, @@ -28,11 +30,10 @@ NotificationResponse, SuccessResponse, ) -from jsonrpcserver.exceptions import ApiError def ping(context: Context): - return "pong" + return SuccessResponse("pong", id=context.request.id) def test_add_handlers(): @@ -78,10 +79,10 @@ def test_safe_call_notification(): def test_safe_call_notification_failure(): def fail(): - raise ValueError() + 1 / 0 response = safe_call( - Request(method="foo", params=[], id=NOID), + Request(method="fail", params=[], id=NOID), Methods(fail), extra=None, serialize=default_serialize, @@ -111,7 +112,9 @@ def test_safe_call_invalid_args(): def test_safe_call_api_error(): def error(context: Context): - raise ApiError("Client Error", code=123, data={"data": 42}) + return ApiErrorResponse( + "Client Error", code=123, data={"data": 42}, id=context.request.id + ) response = safe_call( Request(method="error", params=[], id=1), @@ -128,7 +131,7 @@ def error(context: Context): def test_safe_call_api_error_minimal(): def error(context: Context): - raise ApiError("Client Error") + return ApiErrorResponse("Client Error", code=123, id=context.request.id) response = safe_call( Request(method="error", params=[], id=1), @@ -140,38 +143,18 @@ def error(context: Context): response_dict = response.deserialized() error_dict = response_dict["error"] assert error_dict["message"] == "Client Error" - assert error_dict["code"] == 1 + assert error_dict["code"] == 123 assert "data" not in error_dict -def test_non_json_encodable_resonse(): - def method(context: Context): - return b"Hello, World" - - response = safe_call( - Request(method="method", params=[], id=1), - Methods(method), - extra=None, - serialize=default_serialize, - ) - # response must be serializable here - str(response) - assert isinstance(response, ErrorResponse) - response_dict = response.deserialized() - error_dict = response_dict["error"] - assert error_dict["message"] == "Server error" - assert error_dict["code"] == -32000 - assert "data" in error_dict - - -# call_requests +# dispatch_requests_pure -def test_call_requests_with_extra(): +def test_dispatch_requests_pure_with_extra(): def ping_with_extra(context: Context): assert context.extra is sentinel.extra - call_requests( + dispatch_requests_pure( Request(method="ping_with_extra", params=[], id=1), Methods(ping_with_extra), extra=sentinel.extra, @@ -180,20 +163,6 @@ def ping_with_extra(context: Context): # Assert is in the method -def test_call_requests_batch_all_notifications(): - """Should return a BatchResponse response, an empty list""" - response = call_requests( - [ - Request(method="notify_sum", params=[1, 2, 4], id=NOID), - Request(method="notify_hello", params=[7], id=NOID), - ], - Methods(ping), - extra=None, - serialize=default_serialize, - ) - assert str(response) == "" - - # create_requests @@ -272,8 +241,9 @@ def test_dispatch_pure_invalid_jsonrpc(): def test_dispatch_pure_invalid_params(): - def foo(context: Context, colour): - assert colour in ("orange", "red", "yellow"), "Invalid colour" + def foo(context: Context, colour: str): + if colour not in ("orange", "red", "yellow"): + return InvalidParamsResponse(id=context.request.id) response = dispatch_pure( '{"jsonrpc": "2.0", "method": "foo", "params": ["blue"], "id": 1}', @@ -335,7 +305,7 @@ def test_dispatch_basic_logging(): def test_examples_positionals(): def subtract(context: Context, minuend, subtrahend): - return minuend - subtrahend + return SuccessResponse(minuend - subtrahend, id=context.request.id) response = dispatch_pure( '{"jsonrpc": "2.0", "method": "subtract", "params": [42, 23], "id": 1}', @@ -361,7 +331,9 @@ def subtract(context: Context, minuend, subtrahend): def test_examples_nameds(): def subtract(context: Context, **kwargs): - return kwargs["minuend"] - kwargs["subtrahend"] + return SuccessResponse( + kwargs["minuend"] - kwargs["subtrahend"], id=context.request.id + ) response = dispatch_pure( '{"jsonrpc": "2.0", "method": "subtract", "params": {"subtrahend": 23, "minuend": 42}, "id": 3}', @@ -386,10 +358,9 @@ def subtract(context: Context, **kwargs): def test_examples_notification(): - methods = {"update": lambda: None, "foobar": lambda: None} response = dispatch_pure( '{"jsonrpc": "2.0", "method": "update", "params": [1, 2, 3, 4, 5]}', - methods, + Methods(update=lambda: None, foobar=lambda: None), extra=None, serialize=default_serialize, deserialize=default_deserialize, @@ -399,7 +370,7 @@ def test_examples_notification(): # Second example response = dispatch_pure( '{"jsonrpc": "2.0", "method": "foobar"}', - methods, + Methods(update=lambda: None, foobar=lambda: None), extra=None, serialize=default_serialize, deserialize=default_deserialize, @@ -478,20 +449,21 @@ def test_examples_multiple_invalid_jsonrpc(): def test_examples_mixed_requests_and_notifications(): """ - We break the spec here. The examples put an invalid jsonrpc request in the mix here. - but it's removed to test the rest, because we're not validating each request - individually. Any invalid jsonrpc will respond with a single error message. + We break the spec here. The examples put an invalid jsonrpc request in the + mix here. but it's removed to test the rest, because we're not validating + each request individually. Any invalid jsonrpc will respond with a single + error message. The spec example includes this which invalidates the entire request: {"foo": "boo"}, """ methods = Methods( - **{ - "sum": lambda _, *args: sum(args), - "notify_hello": lambda _, *args: 19, - "subtract": lambda _, *args: args[0] - sum(args[1:]), - "get_data": lambda _: ["hello", 5], - } + sum=lambda ctx, *args: SuccessResponse(sum(args), id=ctx.request.id), + notify_hello=lambda ctx, *args: SuccessResponse(19, id=ctx.request.id), + subtract=lambda ctx, *args: SuccessResponse( + args[0] - sum(args[1:]), id=ctx.request.id + ), + get_data=lambda ctx: SuccessResponse(["hello", 5], id=ctx.request.id), ) requests = serialize( [ @@ -525,5 +497,6 @@ def test_examples_mixed_requests_and_notifications(): {"jsonrpc": "2.0", "result": ["hello", 5], "id": "9"}, ] assert isinstance(response, BatchResponse) + print(response.deserialized()) for r in response.deserialized(): assert r in expected diff --git a/tests/test_methods.py b/tests/test_methods.py index 6d3dc62..d34c774 100644 --- a/tests/test_methods.py +++ b/tests/test_methods.py @@ -2,30 +2,29 @@ import pytest -from jsonrpcserver.methods import Methods, add, validate_args, lookup -from jsonrpcserver.exceptions import MethodNotFoundError, InvalidParamsError +from jsonrpcserver.methods import Methods, add, validate_args def test_validate_no_arguments(): - validate_args(lambda: None) + assert validate_args(lambda: None) == "" def test_validate_no_arguments_too_many_positionals(): - with pytest.raises(InvalidParamsError): - validate_args(lambda: None, "foo") + assert validate_args(lambda: None, "foo") == "too many positional arguments" def test_validate_positionals(): - validate_args(lambda x: None, 1) + assert validate_args(lambda x: None, 1) == "" def test_validate_positionals_not_passed(): - with pytest.raises(InvalidParamsError): - validate_args(lambda x: None, foo="bar") + assert ( + validate_args(lambda x: None, foo="bar") == "missing a required argument: 'x'" + ) def test_validate_keywords(): - validate_args(lambda **kwargs: None, foo="bar") + assert validate_args(lambda **kwargs: None, foo="bar") == "" def test_validate_object_method(): @@ -40,8 +39,7 @@ def test_add_function(): def foo(): pass - methods = Methods(foo) - assert methods.items["foo"] is foo + assert Methods(foo).items["foo"] is foo def test_add_no_name(): @@ -151,7 +149,7 @@ def foo(): def test_add_function_custom_name_via_decorator(): methods = Methods() - @methods.add(name='bar') + @methods.add(name="bar") def foo(): pass @@ -185,20 +183,3 @@ def dog(): methods = Methods(cat, dog) assert methods.items["cat"] == cat assert methods.items["dog"] == dog - - -def test_lookup(): - def foo(): - pass - - methods = Methods() - methods.items["foo"] = foo - - assert lookup(methods, "foo") is foo - - -def test_lookup_failure(): - methods = Methods() - - with pytest.raises(MethodNotFoundError): - lookup(methods, "bar") diff --git a/tests/test_response.py b/tests/test_response.py index 5aac0d9..39ce13c 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -18,6 +18,43 @@ ) +# Moved from test_dispatcher, need to test to_json with a non-json-serializable +# value. +# def test_non_json_encodable_resonse(): +# def method(context: Context): +# return SuccessResponse(b"Hello, World", id=context.request.id) +# +# response = safe_call( +# Request(method="method", params=[], id=1), +# Methods(method), +# extra=None, +# serialize=default_serialize, +# ) +# # response must be serializable here +# str(response) +# assert isinstance(response, ErrorResponse) +# response_dict = response.deserialized() +# error_dict = response_dict["error"] +# assert error_dict["message"] == "Server error" +# assert error_dict["code"] == -32000 +# assert "data" in error_dict + + +# Moved from test_dispatcher, need to test to_json with batch responses +# def test_dispatch_requests_pure_batch_all_notifications(): +# """Should return a BatchResponse response, an empty list""" +# response = dispatch_requests_pure( +# [ +# Request(method="notify_sum", params=[1, 2, 4], id=NOID), +# Request(method="notify_hello", params=[7], id=NOID), +# ], +# Methods(ping), +# extra=None, +# serialize=default_serialize, +# ) +# assert str(response) == "" + + def test_response(): with pytest.raises(TypeError): Response() # Abstract