From ef3bd28aba889db12c37a56072157d9e149ffef0 Mon Sep 17 00:00:00 2001 From: davfsa Date: Sun, 25 Dec 2022 12:56:37 +0100 Subject: [PATCH] Add tests --- hikari/impl/interaction_server.py | 7 +- tests/hikari/impl/test_interaction_server.py | 167 +++++++++++++++++++ 2 files changed, 170 insertions(+), 4 deletions(-) diff --git a/hikari/impl/interaction_server.py b/hikari/impl/interaction_server.py index 02d50dc839..eb8ad0b0fd 100644 --- a/hikari/impl/interaction_server.py +++ b/hikari/impl/interaction_server.py @@ -171,9 +171,8 @@ async def _consume_generator_listener(generator: typing.AsyncGenerator[typing.An try: await generator.__anext__() - # We expect only one! - exc = RuntimeError("Generator listener yielded more than once, expected only one yield") - await generator.athrow(exc) + # We expect only one yield! + await generator.athrow(RuntimeError("Generator listener yielded more than once, expected only one yield")) except StopAsyncIteration: pass @@ -477,8 +476,8 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes) if inspect.isasyncgen(call): result = await call.__anext__() task = asyncio.create_task(_consume_generator_listener(call)) - self._running_generator_listeners.append(task) task.add_done_callback(self._running_generator_listeners.remove) + self._running_generator_listeners.append(task) else: result = await call diff --git a/tests/hikari/impl/test_interaction_server.py b/tests/hikari/impl/test_interaction_server.py index 8f995fbf91..ac7da471d1 100644 --- a/tests/hikari/impl/test_interaction_server.py +++ b/tests/hikari/impl/test_interaction_server.py @@ -74,6 +74,75 @@ async def write_headers(self, status_line: str, headers: "multidict.CIMultiDict[ pass +@pytest.mark.asyncio() +class TestConsumeGeneratorListener: + async def test_normal_behaviour(self): + async def mock_generator_listener(): + nonlocal g_continued + + yield + + g_continued = True + + g_continued = False + generator = mock_generator_listener() + # The function expects the generator to have already yielded once + await generator.__anext__() + + await interaction_server_impl._consume_generator_listener(generator) + + assert g_continued is True + + async def test_when_more_than_one_yield(self): + async def mock_generator_listener(): + nonlocal g_continued + + yield + + g_continued = True + + yield + + g_continued = False + generator = mock_generator_listener() + # The function expects the generator to have already yielded once + await generator.__anext__() + + loop = mock.Mock() + with mock.patch.object(asyncio, "get_running_loop", return_value=loop): + await interaction_server_impl._consume_generator_listener(generator) + + assert g_continued is True + args, _ = loop.call_exception_handler.call_args_list[0] + exception = args[0]["exception"] + assert isinstance(exception, RuntimeError) + assert exception.args == ("Generator listener yielded more than once, expected only one yield",) + + async def test_when_exception(self): + async def mock_generator_listener(): + nonlocal g_continued, exception + + yield + + g_continued = True + + raise exception + + g_continued = False + exception = ValueError("Some random exception") + generator = mock_generator_listener() + # The function expects the generator to have already yielded once + await generator.__anext__() + + loop = mock.Mock() + with mock.patch.object(asyncio, "get_running_loop", return_value=loop): + await interaction_server_impl._consume_generator_listener(generator) + + assert g_continued is True + args, _ = loop.call_exception_handler.call_args_list[0] + assert args[0]["exception"] is exception + + @pytest.fixture() def valid_edd25519(): body = ( @@ -516,11 +585,51 @@ async def test_aiohttp_hook_when_no_body(self, mock_interaction_server: interact @pytest.mark.asyncio() async def test_close(self, mock_interaction_server: interaction_server_impl.InteractionServer): + class TaskMock: + def __init__(self, done, cancelled): + self._awaited_count = 0 + self._done = done + self._cancelled = cancelled + + self.cancel = mock.Mock() + + def __await__(self): + if False: + yield # Turns it into a generator + + self._awaited_count += 1 + + raise asyncio.CancelledError + + def assert_properly_cancelled(self): + self.cancel.assert_called_once_with() + assert self._awaited_count == 1 + + def assert_not_cancelled(self): + self.cancel.assert_not_called() + assert self._awaited_count == 0 + + def done(self): + return self._done + + def cancelled(self): + return self._cancelled + mock_runner = mock.AsyncMock() mock_event = mock.Mock() mock_interaction_server._is_closing = False mock_interaction_server._server = mock_runner mock_interaction_server._close_event = mock_event + generator_listener_1 = TaskMock(False, False) + generator_listener_2 = TaskMock(False, True) + generator_listener_3 = TaskMock(True, False) + generator_listener_4 = TaskMock(True, True) + mock_interaction_server._running_generator_listeners = [ + generator_listener_1, + generator_listener_2, + generator_listener_3, + generator_listener_4, + ] await mock_interaction_server.close() @@ -528,6 +637,11 @@ async def test_close(self, mock_interaction_server: interaction_server_impl.Inte mock_runner.cleanup.assert_awaited_once() mock_event.set.assert_called_once() assert mock_interaction_server._is_closing is False + assert mock_interaction_server._running_generator_listeners == [] + generator_listener_1.assert_properly_cancelled() + generator_listener_2.assert_not_cancelled() + generator_listener_3.assert_not_cancelled() + generator_listener_4.assert_not_cancelled() @pytest.mark.asyncio() async def test_close_when_closing(self, mock_interaction_server: interaction_server_impl.InteractionServer): @@ -537,6 +651,8 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser mock_interaction_server._close_event = mock_event mock_interaction_server._is_closing = True mock_interaction_server.join = mock.AsyncMock() + mock_listener = object() + mock_interaction_server._running_generator_listeners = [mock_listener] await mock_interaction_server.close() @@ -544,6 +660,7 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser mock_runner.cleanup.assert_not_called() mock_event.set.assert_not_called() mock_interaction_server.join.assert_awaited_once() + assert mock_interaction_server._running_generator_listeners == [mock_listener] @pytest.mark.asyncio() async def test_close_when_not_running(self, mock_interaction_server: interaction_server_impl.InteractionServer): @@ -596,6 +713,56 @@ async def test_on_interaction( assert result.payload == b'{"ok": "No boomer"}' assert result.status_code == 200 + @pytest.mark.asyncio() + async def test_on_interaction_with_generator_listener( + self, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_entity_factory: entity_factory_impl.EntityFactoryImpl, + public_key: bytes, + valid_edd25519: bytes, + valid_payload: bytes, + ): + async def mock_generator_listener(event): + nonlocal g_called, g_complete + + g_called = True + assert event is mock_entity_factory.deserialize_interaction.return_value + + yield mock_builder + + g_complete = True + + mock_interaction_server._public_key = nacl.signing.VerifyKey(public_key) + mock_file_1 = mock.Mock() + mock_file_2 = mock.Mock() + mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( + app=None, id=123, application_id=541324, type=2, token="ok", version=1 + ) + mock_builder = mock.Mock(build=mock.Mock(return_value=({"ok": "No boomer"}, [mock_file_1, mock_file_2]))) + g_called = False + g_complete = False + mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_generator_listener) + + result = await mock_interaction_server.on_interaction(*valid_edd25519) + + mock_builder.build.assert_called_once_with(mock_entity_factory) + mock_entity_factory.deserialize_interaction.assert_called_once_with(valid_payload) + assert result.content_type == "application/json" + assert result.charset == "UTF-8" + assert result.files == [mock_file_1, mock_file_2] + assert result.headers is None + assert result.payload == b'{"ok": "No boomer"}' + assert result.status_code == 200 + + assert g_called is True + assert g_complete is False + assert len(mock_interaction_server._running_generator_listeners) != 0 + # Give some time for the task to complete + await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME) + + assert g_complete is True + assert len(mock_interaction_server._running_generator_listeners) == 0 + @pytest.mark.asyncio() async def test_on_interaction_calls__fetch_public_key( self, mock_interaction_server: interaction_server_impl.InteractionServer