diff --git a/pyproject.toml b/pyproject.toml index 58f1220..31d488f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "grpc-interceptor" -version = "0.15.2" +version = "0.15.3" description = "Simplifies gRPC interceptors" license = "MIT" readme = "README.md" diff --git a/src/grpc_interceptor/exception_to_status.py b/src/grpc_interceptor/exception_to_status.py index b335614..ef6ad35 100644 --- a/src/grpc_interceptor/exception_to_status.py +++ b/src/grpc_interceptor/exception_to_status.py @@ -33,9 +33,9 @@ class ExceptionToStatusInterceptor(ServerInterceptor): status_on_unknown_exception: Specify what to do if an exception which is not a subclass of GrpcException is raised. If None, do nothing (by default, grpc will set the status to UNKNOWN). If not None, then the - status code will be set to this value. It must not be OK. The details - will be set to the value of repr(e), where e is the exception. In any - case, the exception will be propagated. + status code will be set to this value if `context.abort` hasn't been called + earlier. It must not be OK. The details will be set to the value of repr(e), + where e is the exception. In any case, the exception will be propagated. Raises: ValueError: If status_code is OK. @@ -93,7 +93,7 @@ def handle_exception( """ if isinstance(ex, GrpcException): context.abort(ex.status_code, ex.details) - else: + elif not context.code(): if self._status_on_unknown_exception is not None: context.abort(self._status_on_unknown_exception, repr(ex)) raise ex diff --git a/src/grpc_interceptor/testing/dummy_client.py b/src/grpc_interceptor/testing/dummy_client.py index bc7da41..a611ff9 100644 --- a/src/grpc_interceptor/testing/dummy_client.py +++ b/src/grpc_interceptor/testing/dummy_client.py @@ -3,6 +3,7 @@ import asyncio from concurrent import futures from contextlib import contextmanager +from inspect import iscoroutine from threading import Event, Thread from typing import ( Any, @@ -37,6 +38,21 @@ def _get_output(self, request: DummyRequest, context: grpc.ServicerContext) -> s return output + async def _get_output_async( + self, + request: DummyRequest, + context: grpc_aio.ServicerContext + ) -> str: + input = request.input + + output = input + if input in self._special_cases: + output = self._special_cases[input](input, context) + if iscoroutine(output): + output = await output + + return output + class DummyService(dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin): """A gRPC service used for testing. @@ -99,7 +115,7 @@ async def Execute( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> DummyResponse: """Echo the input, or take on of the special cases actions.""" - return DummyResponse(output=self._get_output(request, context)) + return DummyResponse(output=await self._get_output_async(request, context)) async def ExecuteClientStream( self, @@ -107,16 +123,17 @@ async def ExecuteClientStream( context: grpc_aio.ServicerContext, ) -> DummyResponse: """Iterate over the input and concatenates the strings into the output.""" - output = "".join( - [self._get_output(request, context) async for request in request_iter] - ) + output = "".join([ + await self._get_output_async(request, context) + async for request in request_iter + ]) # noqa: E501 return DummyResponse(output=output) async def ExecuteServerStream( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> AsyncGenerator[DummyResponse, None]: """Stream one character at a time from the input.""" - for c in self._get_output(request, context): + for c in await self._get_output_async(request, context): yield DummyResponse(output=c) async def ExecuteClientServerStream( @@ -126,7 +143,7 @@ async def ExecuteClientServerStream( ) -> AsyncGenerator[DummyResponse, None]: """Stream input to output.""" async for request in request_iter: - yield DummyResponse(output=self._get_output(request, context)) + yield DummyResponse(output=await self._get_output_async(request, context)) class AsyncReadWriteDummyService( @@ -147,7 +164,7 @@ async def Execute( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> DummyResponse: """Echo the input, or take on of the special cases actions.""" - return DummyResponse(output=self._get_output(request, context)) + return DummyResponse(output=await self._get_output_async(request, context)) async def ExecuteClientStream( self, @@ -160,7 +177,7 @@ async def ExecuteClientStream( request = await context.read() if request == grpc_aio.EOF: break - output.append(self._get_output(request, context)) + output.append(await self._get_output_async(request, context)) return DummyResponse(output="".join(output)) @@ -168,7 +185,7 @@ async def ExecuteServerStream( self, request: DummyRequest, context: grpc_aio.ServicerContext ) -> None: """Stream one character at a time from the input.""" - for c in self._get_output(request, context): + for c in await self._get_output_async(request, context): await context.write(DummyResponse(output=c)) async def ExecuteClientServerStream( @@ -182,7 +199,7 @@ async def ExecuteClientServerStream( if request == grpc_aio.EOF: break await context.write( - DummyResponse(output=self._get_output(request, context)) + DummyResponse(output=await self._get_output_async(request, context)) ) diff --git a/tests/test_exception_to_status.py b/tests/test_exception_to_status.py index 80e7ad4..319e812 100644 --- a/tests/test_exception_to_status.py +++ b/tests/test_exception_to_status.py @@ -1,8 +1,9 @@ """Test cases for ExceptionToStatusInterceptor.""" import re -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import grpc +from grpc import aio as grpc_aio import pytest from grpc_interceptor import exceptions as gx @@ -155,6 +156,28 @@ def test_non_grpc_exception_with_override(aio): assert re.fullmatch(r"ValueError\('oops',?\)", e.value.details()) +@pytest.mark.parametrize("aio", [False, True]) +def test_aborted_context(aio): + """If the context is aborted, the exception is propagated.""" + def error(request: Any, context: grpc.ServicerContext) -> None: + context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted') + + async def async_error(request: Any, context: grpc_aio.ServicerContext) -> None: + await context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted') + + interceptors = _get_interceptors(aio, grpc.StatusCode.INTERNAL) + special_cases = { + "error": async_error if aio else error + } + + with dummy_client( + special_cases=special_cases, interceptors=interceptors, aio_server=aio + ) as client: + with pytest.raises(grpc.RpcError) as e: + client.Execute(DummyRequest(input="error")) + assert e.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED + + def test_override_with_ok(): """We cannot set the default status code to OK.""" with pytest.raises(ValueError):