Skip to content

Commit

Permalink
Fix wrong status code if context.abort has been called (#43)
Browse files Browse the repository at this point in the history
* Use correct status code if context.abort has been called earlier, not status_on_unknown_exception

* use async context

* Add linting support for Python 3.11
  • Loading branch information
thangtp authored Aug 15, 2023
1 parent 64d1206 commit 25edf03
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "grpc-interceptor"
version = "0.15.2"
version = "0.15.3"
description = "Simplifies gRPC interceptors"
license = "MIT"
readme = "README.md"
Expand Down
8 changes: 4 additions & 4 deletions src/grpc_interceptor/exception_to_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class ExceptionToStatusInterceptor(ServerInterceptor):
status_on_unknown_exception: Specify what to do if an exception which is
not a subclass of GrpcException is raised. If None, do nothing (by
default, grpc will set the status to UNKNOWN). If not None, then the
status code will be set to this value. It must not be OK. The details
will be set to the value of repr(e), where e is the exception. In any
case, the exception will be propagated.
status code will be set to this value if `context.abort` hasn't been called
earlier. It must not be OK. The details will be set to the value of repr(e),
where e is the exception. In any case, the exception will be propagated.
Raises:
ValueError: If status_code is OK.
Expand Down Expand Up @@ -93,7 +93,7 @@ def handle_exception(
"""
if isinstance(ex, GrpcException):
context.abort(ex.status_code, ex.details)
else:
elif not context.code():
if self._status_on_unknown_exception is not None:
context.abort(self._status_on_unknown_exception, repr(ex))
raise ex
Expand Down
37 changes: 27 additions & 10 deletions src/grpc_interceptor/testing/dummy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
from concurrent import futures
from contextlib import contextmanager
from inspect import iscoroutine
from threading import Event, Thread
from typing import (
Any,
Expand Down Expand Up @@ -37,6 +38,21 @@ def _get_output(self, request: DummyRequest, context: grpc.ServicerContext) -> s

return output

async def _get_output_async(
self,
request: DummyRequest,
context: grpc_aio.ServicerContext
) -> str:
input = request.input

output = input
if input in self._special_cases:
output = self._special_cases[input](input, context)
if iscoroutine(output):
output = await output

return output


class DummyService(dummy_pb2_grpc.DummyServiceServicer, _SpecialCaseMixin):
"""A gRPC service used for testing.
Expand Down Expand Up @@ -99,24 +115,25 @@ async def Execute(
self, request: DummyRequest, context: grpc_aio.ServicerContext
) -> DummyResponse:
"""Echo the input, or take on of the special cases actions."""
return DummyResponse(output=self._get_output(request, context))
return DummyResponse(output=await self._get_output_async(request, context))

async def ExecuteClientStream(
self,
request_iter: AsyncIterable[DummyRequest],
context: grpc_aio.ServicerContext,
) -> DummyResponse:
"""Iterate over the input and concatenates the strings into the output."""
output = "".join(
[self._get_output(request, context) async for request in request_iter]
)
output = "".join([
await self._get_output_async(request, context)
async for request in request_iter
]) # noqa: E501
return DummyResponse(output=output)

async def ExecuteServerStream(
self, request: DummyRequest, context: grpc_aio.ServicerContext
) -> AsyncGenerator[DummyResponse, None]:
"""Stream one character at a time from the input."""
for c in self._get_output(request, context):
for c in await self._get_output_async(request, context):
yield DummyResponse(output=c)

async def ExecuteClientServerStream(
Expand All @@ -126,7 +143,7 @@ async def ExecuteClientServerStream(
) -> AsyncGenerator[DummyResponse, None]:
"""Stream input to output."""
async for request in request_iter:
yield DummyResponse(output=self._get_output(request, context))
yield DummyResponse(output=await self._get_output_async(request, context))


class AsyncReadWriteDummyService(
Expand All @@ -147,7 +164,7 @@ async def Execute(
self, request: DummyRequest, context: grpc_aio.ServicerContext
) -> DummyResponse:
"""Echo the input, or take on of the special cases actions."""
return DummyResponse(output=self._get_output(request, context))
return DummyResponse(output=await self._get_output_async(request, context))

async def ExecuteClientStream(
self,
Expand All @@ -160,15 +177,15 @@ async def ExecuteClientStream(
request = await context.read()
if request == grpc_aio.EOF:
break
output.append(self._get_output(request, context))
output.append(await self._get_output_async(request, context))

return DummyResponse(output="".join(output))

async def ExecuteServerStream(
self, request: DummyRequest, context: grpc_aio.ServicerContext
) -> None:
"""Stream one character at a time from the input."""
for c in self._get_output(request, context):
for c in await self._get_output_async(request, context):
await context.write(DummyResponse(output=c))

async def ExecuteClientServerStream(
Expand All @@ -182,7 +199,7 @@ async def ExecuteClientServerStream(
if request == grpc_aio.EOF:
break
await context.write(
DummyResponse(output=self._get_output(request, context))
DummyResponse(output=await self._get_output_async(request, context))
)


Expand Down
25 changes: 24 additions & 1 deletion tests/test_exception_to_status.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Test cases for ExceptionToStatusInterceptor."""
import re
from typing import List, Optional, Union
from typing import Any, List, Optional, Union

import grpc
from grpc import aio as grpc_aio
import pytest

from grpc_interceptor import exceptions as gx
Expand Down Expand Up @@ -155,6 +156,28 @@ def test_non_grpc_exception_with_override(aio):
assert re.fullmatch(r"ValueError\('oops',?\)", e.value.details())


@pytest.mark.parametrize("aio", [False, True])
def test_aborted_context(aio):
"""If the context is aborted, the exception is propagated."""
def error(request: Any, context: grpc.ServicerContext) -> None:
context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted')

async def async_error(request: Any, context: grpc_aio.ServicerContext) -> None:
await context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, 'resource exhausted')

interceptors = _get_interceptors(aio, grpc.StatusCode.INTERNAL)
special_cases = {
"error": async_error if aio else error
}

with dummy_client(
special_cases=special_cases, interceptors=interceptors, aio_server=aio
) as client:
with pytest.raises(grpc.RpcError) as e:
client.Execute(DummyRequest(input="error"))
assert e.value.code() == grpc.StatusCode.RESOURCE_EXHAUSTED


def test_override_with_ok():
"""We cannot set the default status code to OK."""
with pytest.raises(ValueError):
Expand Down

0 comments on commit 25edf03

Please sign in to comment.