Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Dec 25, 2022
1 parent 11e7b50 commit ef3bd28
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 4 deletions.
7 changes: 3 additions & 4 deletions hikari/impl/interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,8 @@ async def _consume_generator_listener(generator: typing.AsyncGenerator[typing.An
try:
await generator.__anext__()

# We expect only one!
exc = RuntimeError("Generator listener yielded more than once, expected only one yield")
await generator.athrow(exc)
# We expect only one yield!
await generator.athrow(RuntimeError("Generator listener yielded more than once, expected only one yield"))

except StopAsyncIteration:
pass
Expand Down Expand Up @@ -477,8 +476,8 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes)
if inspect.isasyncgen(call):
result = await call.__anext__()
task = asyncio.create_task(_consume_generator_listener(call))
self._running_generator_listeners.append(task)
task.add_done_callback(self._running_generator_listeners.remove)
self._running_generator_listeners.append(task)

else:
result = await call
Expand Down
167 changes: 167 additions & 0 deletions tests/hikari/impl/test_interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,75 @@ async def write_headers(self, status_line: str, headers: "multidict.CIMultiDict[
pass


@pytest.mark.asyncio()
class TestConsumeGeneratorListener:
async def test_normal_behaviour(self):
async def mock_generator_listener():
nonlocal g_continued

yield

g_continued = True

g_continued = False
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True

async def test_when_more_than_one_yield(self):
async def mock_generator_listener():
nonlocal g_continued

yield

g_continued = True

yield

g_continued = False
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

loop = mock.Mock()
with mock.patch.object(asyncio, "get_running_loop", return_value=loop):
await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True
args, _ = loop.call_exception_handler.call_args_list[0]
exception = args[0]["exception"]
assert isinstance(exception, RuntimeError)
assert exception.args == ("Generator listener yielded more than once, expected only one yield",)

async def test_when_exception(self):
async def mock_generator_listener():
nonlocal g_continued, exception

yield

g_continued = True

raise exception

g_continued = False
exception = ValueError("Some random exception")
generator = mock_generator_listener()
# The function expects the generator to have already yielded once
await generator.__anext__()

loop = mock.Mock()
with mock.patch.object(asyncio, "get_running_loop", return_value=loop):
await interaction_server_impl._consume_generator_listener(generator)

assert g_continued is True
args, _ = loop.call_exception_handler.call_args_list[0]
assert args[0]["exception"] is exception


@pytest.fixture()
def valid_edd25519():
body = (
Expand Down Expand Up @@ -516,18 +585,63 @@ async def test_aiohttp_hook_when_no_body(self, mock_interaction_server: interact

@pytest.mark.asyncio()
async def test_close(self, mock_interaction_server: interaction_server_impl.InteractionServer):
class TaskMock:
def __init__(self, done, cancelled):
self._awaited_count = 0
self._done = done
self._cancelled = cancelled

self.cancel = mock.Mock()

def __await__(self):
if False:
yield # Turns it into a generator

self._awaited_count += 1

raise asyncio.CancelledError

def assert_properly_cancelled(self):
self.cancel.assert_called_once_with()
assert self._awaited_count == 1

def assert_not_cancelled(self):
self.cancel.assert_not_called()
assert self._awaited_count == 0

def done(self):
return self._done

def cancelled(self):
return self._cancelled

mock_runner = mock.AsyncMock()
mock_event = mock.Mock()
mock_interaction_server._is_closing = False
mock_interaction_server._server = mock_runner
mock_interaction_server._close_event = mock_event
generator_listener_1 = TaskMock(False, False)
generator_listener_2 = TaskMock(False, True)
generator_listener_3 = TaskMock(True, False)
generator_listener_4 = TaskMock(True, True)
mock_interaction_server._running_generator_listeners = [
generator_listener_1,
generator_listener_2,
generator_listener_3,
generator_listener_4,
]

await mock_interaction_server.close()

mock_runner.shutdown.assert_awaited_once()
mock_runner.cleanup.assert_awaited_once()
mock_event.set.assert_called_once()
assert mock_interaction_server._is_closing is False
assert mock_interaction_server._running_generator_listeners == []
generator_listener_1.assert_properly_cancelled()
generator_listener_2.assert_not_cancelled()
generator_listener_3.assert_not_cancelled()
generator_listener_4.assert_not_cancelled()

@pytest.mark.asyncio()
async def test_close_when_closing(self, mock_interaction_server: interaction_server_impl.InteractionServer):
Expand All @@ -537,13 +651,16 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser
mock_interaction_server._close_event = mock_event
mock_interaction_server._is_closing = True
mock_interaction_server.join = mock.AsyncMock()
mock_listener = object()
mock_interaction_server._running_generator_listeners = [mock_listener]

await mock_interaction_server.close()

mock_runner.shutdown.assert_not_called()
mock_runner.cleanup.assert_not_called()
mock_event.set.assert_not_called()
mock_interaction_server.join.assert_awaited_once()
assert mock_interaction_server._running_generator_listeners == [mock_listener]

@pytest.mark.asyncio()
async def test_close_when_not_running(self, mock_interaction_server: interaction_server_impl.InteractionServer):
Expand Down Expand Up @@ -596,6 +713,56 @@ async def test_on_interaction(
assert result.payload == b'{"ok": "No boomer"}'
assert result.status_code == 200

@pytest.mark.asyncio()
async def test_on_interaction_with_generator_listener(
self,
mock_interaction_server: interaction_server_impl.InteractionServer,
mock_entity_factory: entity_factory_impl.EntityFactoryImpl,
public_key: bytes,
valid_edd25519: bytes,
valid_payload: bytes,
):
async def mock_generator_listener(event):
nonlocal g_called, g_complete

g_called = True
assert event is mock_entity_factory.deserialize_interaction.return_value

yield mock_builder

g_complete = True

mock_interaction_server._public_key = nacl.signing.VerifyKey(public_key)
mock_file_1 = mock.Mock()
mock_file_2 = mock.Mock()
mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction(
app=None, id=123, application_id=541324, type=2, token="ok", version=1
)
mock_builder = mock.Mock(build=mock.Mock(return_value=({"ok": "No boomer"}, [mock_file_1, mock_file_2])))
g_called = False
g_complete = False
mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_generator_listener)

result = await mock_interaction_server.on_interaction(*valid_edd25519)

mock_builder.build.assert_called_once_with(mock_entity_factory)
mock_entity_factory.deserialize_interaction.assert_called_once_with(valid_payload)
assert result.content_type == "application/json"
assert result.charset == "UTF-8"
assert result.files == [mock_file_1, mock_file_2]
assert result.headers is None
assert result.payload == b'{"ok": "No boomer"}'
assert result.status_code == 200

assert g_called is True
assert g_complete is False
assert len(mock_interaction_server._running_generator_listeners) != 0
# Give some time for the task to complete
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME)

assert g_complete is True
assert len(mock_interaction_server._running_generator_listeners) == 0

@pytest.mark.asyncio()
async def test_on_interaction_calls__fetch_public_key(
self, mock_interaction_server: interaction_server_impl.InteractionServer
Expand Down

0 comments on commit ef3bd28

Please sign in to comment.