Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support yielding in interaction handler #1383

Merged
merged 13 commits into from
Jan 1, 2023
1 change: 1 addition & 0 deletions changes/1383.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support yielding in interaction listeners.
13 changes: 10 additions & 3 deletions hikari/api/interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@
]


ListenerT = typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]]
ListenerT = typing.Union[
typing.Callable[["_InteractionT_co"], typing.Awaitable["_ResponseT_co"]],
typing.Callable[["_InteractionT_co"], typing.AsyncGenerator["_ResponseT_co", None]],
]
"""Type hint of a Interaction server's listener callback.

This should be an async callback which takes in one positional argument which
Expand Down Expand Up @@ -255,8 +258,12 @@ def set_listener(
interaction_type : typing.Type[hikari.interactions.base_interactions.PartialInteraction]
The type of interaction this listener should be registered for.
listener : typing.Optional[ListenerT[hikari.interactions.base_interactions.PartialInteraction, hikari.api.special_endpoints.InteractionResponseBuilder]]
The asynchronous listener callback to set or `None` to
unset the previous listener.
The asynchronous listener callback to set or `None` to unset the previous listener.

An asynchronous listener can be either a normal coroutine or an
async generator which should yield exactly once. This allows
sending an initial response to the request, while still
later executing further logic.

Other Parameters
----------------
Expand Down
36 changes: 35 additions & 1 deletion hikari/impl/interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
__all__: typing.Sequence[str] = ("InteractionServer",)

import asyncio
import inspect
import logging
import typing

Expand Down Expand Up @@ -166,6 +167,22 @@ async def write(self, writer: aiohttp.abc.AbstractStreamWriter) -> None:
await writer.write(chunk)


async def _consume_generator_listener(generator: typing.AsyncGenerator[typing.Any, None]) -> None:
try:
await generator.__anext__()

# We expect only one yield!
await generator.athrow(RuntimeError("Generator listener yielded more than once, expected only one yield"))

except StopAsyncIteration:
pass

except Exception as exc:
asyncio.get_running_loop().call_exception_handler(
{"message": "Exception occurred during interaction post dispatch", "exception": exc}
)


class InteractionServer(interaction_server.InteractionServer):
"""Standard implementation of `hikari.api.interaction_server.InteractionServer`.

Expand Down Expand Up @@ -201,6 +218,7 @@ class InteractionServer(interaction_server.InteractionServer):
"_public_key",
"_rest_client",
"_server",
"_running_generator_listeners",
)

def __init__(
Expand Down Expand Up @@ -237,6 +255,7 @@ def __init__(
self._rest_client = rest_client
self._server: typing.Optional[aiohttp.web_runner.AppRunner] = None
self._public_key = nacl.signing.VerifyKey(public_key) if public_key is not None else None
self._running_generator_listeners: typing.List[asyncio.Task[None]] = []

@property
def is_alive(self) -> bool:
Expand Down Expand Up @@ -365,6 +384,11 @@ async def close(self) -> None:
await self._server.cleanup()
self._server = None
self._application_fetch_lock = None

# Wait for handlers to complete
await asyncio.gather(*self._running_generator_listeners)
self._running_generator_listeners = []

self._close_event.set()
self._close_event = None
self._is_closing = False
Expand Down Expand Up @@ -440,7 +464,17 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes)
if listener := self._listeners.get(type(interaction)):
_LOGGER.debug("Dispatching interaction %s", interaction.id)
try:
result = await listener(interaction)
call = listener(interaction)

if inspect.isasyncgen(call):
result = await call.__anext__()
task = asyncio.create_task(_consume_generator_listener(call))
task.add_done_callback(self._running_generator_listeners.remove)
self._running_generator_listeners.append(task)

else:
result = await call

raw_payload, files = result.build(self._entity_factory)
payload = self._dumps(raw_payload)

Expand Down
9 changes: 9 additions & 0 deletions hikari/interactions/command_interactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,15 @@ def build_deferred_response(self) -> special_endpoints.InteractionDeferredBuilde
the result of this call can be returned as is without any modifications
being made to it.

Examples
--------
.. code-block:: python

async def handle_command_interaction(interaction: CommandInteraction) -> InteractionMessageBuilder:
yield interaction.build_deferred_response()

await interaction.edit_initial_response("Pong!")

Returns
-------
hikari.api.special_endpoints.InteractionMessageBuilder
Expand Down
144 changes: 142 additions & 2 deletions tests/hikari/impl/test_interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,75 @@ async def write_headers(self, status_line: str, headers: "multidict.CIMultiDict[
pass


@pytest.mark.asyncio()
class TestConsumeGeneratorListener:
async def test_normal_behaviour(self):
async def mock_generator_listener():
nonlocal g_continued

yield

g_continued = True

g_continued = False
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True

async def test_when_more_than_one_yield(self):
async def mock_generator_listener():
nonlocal g_continued

yield

g_continued = True

yield

g_continued = False
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

loop = mock.Mock()
with mock.patch.object(asyncio, "get_running_loop", return_value=loop):
await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True
args, _ = loop.call_exception_handler.call_args_list[0]
exception = args[0]["exception"]
assert isinstance(exception, RuntimeError)
assert exception.args == ("Generator listener yielded more than once, expected only one yield",)

async def test_when_exception(self):
async def mock_generator_listener():
nonlocal g_continued, exception

yield

g_continued = True

raise exception

g_continued = False
exception = ValueError("Some random exception")
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

loop = mock.Mock()
with mock.patch.object(asyncio, "get_running_loop", return_value=loop):
await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True
args, _ = loop.call_exception_handler.call_args_list[0]
assert args[0]["exception"] is exception


@pytest.fixture()
def valid_edd25519():
body = (
Expand Down Expand Up @@ -521,13 +590,31 @@ async def test_close(self, mock_interaction_server: interaction_server_impl.Inte
mock_interaction_server._is_closing = False
mock_interaction_server._server = mock_runner
mock_interaction_server._close_event = mock_event

await mock_interaction_server.close()
generator_listener_1 = mock.Mock()
generator_listener_2 = mock.Mock()
generator_listener_3 = mock.Mock()
generator_listener_4 = mock.Mock()
mock_interaction_server._running_generator_listeners = [
generator_listener_1,
generator_listener_2,
generator_listener_3,
generator_listener_4,
]

with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()) as gather:
await mock_interaction_server.close()

mock_runner.shutdown.assert_awaited_once()
mock_runner.cleanup.assert_awaited_once()
mock_event.set.assert_called_once()
assert mock_interaction_server._is_closing is False
assert mock_interaction_server._running_generator_listeners == []
gather.assert_awaited_once_with(
generator_listener_1,
generator_listener_2,
generator_listener_3,
generator_listener_4,
)

@pytest.mark.asyncio()
async def test_close_when_closing(self, mock_interaction_server: interaction_server_impl.InteractionServer):
Expand All @@ -537,13 +624,16 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser
mock_interaction_server._close_event = mock_event
mock_interaction_server._is_closing = True
mock_interaction_server.join = mock.AsyncMock()
mock_listener = object()
mock_interaction_server._running_generator_listeners = [mock_listener]

await mock_interaction_server.close()

mock_runner.shutdown.assert_not_called()
mock_runner.cleanup.assert_not_called()
mock_event.set.assert_not_called()
mock_interaction_server.join.assert_awaited_once()
assert mock_interaction_server._running_generator_listeners == [mock_listener]

@pytest.mark.asyncio()
async def test_close_when_not_running(self, mock_interaction_server: interaction_server_impl.InteractionServer):
Expand Down Expand Up @@ -596,6 +686,56 @@ async def test_on_interaction(
assert result.payload == b'{"ok": "No boomer"}'
assert result.status_code == 200

@pytest.mark.asyncio()
async def test_on_interaction_with_generator_listener(
self,
mock_interaction_server: interaction_server_impl.InteractionServer,
mock_entity_factory: entity_factory_impl.EntityFactoryImpl,
public_key: bytes,
valid_edd25519: bytes,
valid_payload: bytes,
):
async def mock_generator_listener(event):
nonlocal g_called, g_complete

g_called = True
assert event is mock_entity_factory.deserialize_interaction.return_value

yield mock_builder

g_complete = True

mock_interaction_server._public_key = nacl.signing.VerifyKey(public_key)
mock_file_1 = mock.Mock()
mock_file_2 = mock.Mock()
mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction(
app=None, id=123, application_id=541324, type=2, token="ok", version=1
)
mock_builder = mock.Mock(build=mock.Mock(return_value=({"ok": "No boomer"}, [mock_file_1, mock_file_2])))
g_called = False
g_complete = False
mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_generator_listener)

result = await mock_interaction_server.on_interaction(*valid_edd25519)

mock_builder.build.assert_called_once_with(mock_entity_factory)
mock_entity_factory.deserialize_interaction.assert_called_once_with(valid_payload)
assert result.content_type == "application/json"
assert result.charset == "UTF-8"
assert result.files == [mock_file_1, mock_file_2]
assert result.headers is None
assert result.payload == b'{"ok": "No boomer"}'
assert result.status_code == 200

assert g_called is True
assert g_complete is False
assert len(mock_interaction_server._running_generator_listeners) != 0
# Give some time for the task to complete
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME)

assert g_complete is True
assert len(mock_interaction_server._running_generator_listeners) == 0

@pytest.mark.asyncio()
async def test_on_interaction_calls__fetch_public_key(
self, mock_interaction_server: interaction_server_impl.InteractionServer
Expand Down