Skip to content

Commit

Permalink
Improve filtering by using per event bitmasks
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Mar 1, 2022
1 parent fa1b5b0 commit 0233162
Show file tree
Hide file tree
Showing 9 changed files with 673 additions and 916 deletions.
15 changes: 5 additions & 10 deletions hikari/api/entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,7 @@


class GatewayGuildDefinition(abc.ABC):
"""Structure for handling entities within guild create and update events.
!!! warning
The methods on this class may raise `builtins.LookupError` if called
when the relevant resource isn't available in the inner payload.
"""
"""Structure for handling entities within guild create and update events."""

__slots__: typing.Sequence[str] = ()

Expand All @@ -71,7 +66,7 @@ def id(self) -> snowflakes.Snowflake:
"""ID of the guild the definition is for."""

@abc.abstractmethod
def channels(self) -> typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]:
def channels(self) -> typing.Optional[typing.Mapping[snowflakes.Snowflake, channel_models.GuildChannel]]:
"""Get a mapping of channel IDs to the channels that belong to the guild."""

@abc.abstractmethod
Expand All @@ -83,15 +78,15 @@ def guild(self) -> guild_models.GatewayGuild:
"""Get the object of the guild this definition is for."""

@abc.abstractmethod
def members(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Member]:
def members(self) -> typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_models.Member]]:
"""Get a mapping of user IDs to the members that belong to the guild.
!!! note
This may be a partial mapping of members in the guild.
"""

@abc.abstractmethod
def presences(self) -> typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]:
def presences(self) -> typing.Optional[typing.Mapping[snowflakes.Snowflake, presence_models.MemberPresence]]:
"""Get a mapping of user IDs to the presences that are active in the guild.
!!! note
Expand All @@ -103,7 +98,7 @@ def roles(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Role]:
"""Get a mapping of role IDs to the roles that belong to the guild."""

@abc.abstractmethod
def voice_states(self) -> typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]:
def voice_states(self) -> typing.Optional[typing.Mapping[snowflakes.Snowflake, voice_models.VoiceState]]:
"""Get a mapping of user IDs to the voice states that are active in the guild."""


Expand Down
5 changes: 3 additions & 2 deletions hikari/api/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ async def on_message(event):
Wait_for: `hikari.api.event_manager.EventManager.wait_for`
"""

# FIXME: Breaking change here, make sure to document
@abc.abstractmethod
def get_listeners(
self,
Expand All @@ -372,8 +373,8 @@ def get_listeners(
The event type to look for.
`T` must be a subclass of `hikari.events.base_events.Event`.
polymorphic : builtins.bool
If `builtins.True`, this will also return the listeners of the
subclasses of the given event type. If `builtins.False`, then
If `builtins.True`, this will also return the listeners for all the
event types `event_type` will dispatch. If `builtins.False`, then
only listeners for this class specifically are returned. The
default is `builtins.True`.
Expand Down
15 changes: 14 additions & 1 deletion hikari/events/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@
REQUIRED_INTENTS_ATTR: typing.Final[str] = "___requiresintents___"
NO_RECURSIVE_THROW_ATTR: typing.Final[str] = "___norecursivethrow___"

_id_counter = 1 # We start at 1 since Event is 0


class Event(abc.ABC):
"""Base event type that all Hikari events should subclass."""

__slots__: typing.Sequence[str] = ()

__dispatches: typing.ClassVar[typing.Tuple[typing.Type[Event], ...]]
__bitmask: typing.ClassVar[int]

def __init_subclass__(cls) -> None:
super().__init_subclass__()
Expand All @@ -68,11 +71,16 @@ def __init_subclass__(cls) -> None:
Event.__dispatches
except AttributeError:
Event.__dispatches = (Event,)
Event.__bitmask = 1 << 0

global _id_counter

mro = cls.mro()
# We don't have to explicitly include Event here as issubclass(Event, Event) returns True.
# Non-event classes should be ignored.
cls.__dispatches = tuple(cls for cls in mro if issubclass(cls, Event))
cls.__dispatches = tuple(sub_cls for sub_cls in mro if issubclass(sub_cls, Event))
cls.__bitmask = 1 << _id_counter
_id_counter += 1

@property
@abc.abstractmethod
Expand All @@ -90,6 +98,11 @@ def dispatches(cls) -> typing.Sequence[typing.Type[Event]]:
"""Sequence of the event classes this event is dispatched as."""
return cls.__dispatches

@classmethod
def bitmask(cls) -> int:
"""Bitmask for this event."""
return cls.__bitmask


def get_required_intents_for(event_type: typing.Type[Event]) -> typing.Collection[intents.Intents]:
"""Retrieve the intents that are required to listen to an event type.
Expand Down
17 changes: 5 additions & 12 deletions hikari/impl/entity_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,6 @@ def members(self) -> typing.Optional[typing.Mapping[snowflakes.Snowflake, guild_
for m in self._payload["members"]
}

for member_payload in self._payload["members"]:
member = self._entity_factory.deserialize_member(member_payload, guild_id=self.id)
self._members[member.user.id] = member
else:
self._members = None

Expand All @@ -309,21 +306,17 @@ def presences(self) -> typing.Optional[typing.Mapping[snowflakes.Snowflake, pres
for p in self._payload["presences"]
}

for presence_payload in self._payload["presences"]:
presence = self._entity_factory.deserialize_member_presence(presence_payload, guild_id=self.id)
self._presences[presence.user_id] = presence
else:
self._presences = None

return self._presences

def roles(self) -> typing.Mapping[snowflakes.Snowflake, guild_models.Role]:
if self._roles is None:
if self._roles is undefined.UNDEFINED:
self._roles = {
snowflakes.Snowflake(r["id"]): self._entity_factory.deserialize_role(r, guild_id=self.id)
for r in self._payload["roles"]
}
if self._roles is undefined.UNDEFINED:
self._roles = {
snowflakes.Snowflake(r["id"]): self._entity_factory.deserialize_role(r, guild_id=self.id)
for r in self._payload["roles"]
}

return self._roles

Expand Down
91 changes: 52 additions & 39 deletions hikari/impl/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,33 @@ async def on_channel_pins_update(self, shard: gateway_shard.GatewayShard, payloa
# TODO: we need a method for this specifically
await self.dispatch(self._event_factory.deserialize_channel_pins_update_event(shard, payload))

# Internal granularity is preferred for GUILD_CREATE over decorator based filtering due to its large cache scope.
async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None:
# Internal granularity is preferred for GUILD_CREATE over decorator based filtering due to its large scope.
async def on_guild_create( # noqa: C901 - Function too complex
self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject
) -> None:
"""See https://discord.com/developers/docs/topics/gateway#guild-create for more info."""
enabled_for_event = self._enabled_for_event(guild_events.GuildAvailableEvent)
if not enabled_for_event and self._cache:
event: typing.Union[guild_events.GuildAvailableEvent, guild_events.GuildJoinEvent, None]

if "unavailable" in payload and self._enabled_for_event(guild_events.GuildAvailableEvent):
event = self._event_factory.deserialize_guild_available_event(shard, payload)
elif "unavailable" not in payload and self._enabled_for_event(guild_events.GuildJoinEvent):
event = self._event_factory.deserialize_guild_join_event(shard, payload)
else:
event = None

if event:
# We also filter here to prevent iterating over them and calling a function that won't do anything
channels = event.channels if self._cache_enabled_for(config.CacheComponents.GUILD_CHANNELS) else None
emojis = event.emojis if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None
guild = event.guild if self._cache_enabled_for(config.CacheComponents.GUILDS) else None
guild_id = event.guild.id
members = event.members if self._cache_enabled_for(config.CacheComponents.MEMBERS) else None
presences = event.presences if self._cache_enabled_for(config.CacheComponents.PRESENCES) else None
roles = event.roles if self._cache_enabled_for(config.CacheComponents.ROLES) else None
voice_states = event.voice_states if self._cache_enabled_for(config.CacheComponents.VOICE_STATES) else None

elif self._cache:
_LOGGER.log(ux.TRACE, "Skipping on_guild_create dispatch due to lack of any registered listeners")
event: typing.Union[guild_events.GuildAvailableEvent, guild_events.GuildJoinEvent, None] = None
gd = self._entity_factory.deserialize_gateway_guild(payload)

channels = gd.channels() if self._cache_enabled_for(config.CacheComponents.GUILD_CHANNELS) else None
Expand All @@ -180,23 +200,11 @@ async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data
roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None
voice_states = gd.voice_states() if self._cache_enabled_for(config.CacheComponents.VOICE_STATES) else None

elif enabled_for_event:
if "unavailable" in payload:
event = self._event_factory.deserialize_guild_available_event(shard, payload)
else:
event = self._event_factory.deserialize_guild_join_event(shard, payload)

channels = event.channels
emojis = event.emojis
guild = event.guild
guild_id = guild.id
members = event.members
presences = event.presences
roles = event.roles
voice_states = event.voice_states

else:
event = None
_LOGGER.log(
ux.TRACE, "Skipping on_guild_create raw dispatch due to lack of any registered listeners or cache need"
)

channels = None
emojis = None
guild = None
Expand Down Expand Up @@ -241,16 +249,19 @@ async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data
for voice_state in voice_states.values():
self._cache.set_voice_state(voice_state)

recv_chunks = self._enabled_for_event(shard_events.MemberChunkEvent) or self._cache_enabled_for(
config.CacheComponents.MEMBERS
)
members_declared = self._intents & intents_.Intents.GUILD_MEMBERS
presences_declared = self._intents & intents_.Intents.GUILD_PRESENCES

# When intents are enabled discord will only send other member objects on the guild create
# When intents are enabled Discord will only send other member objects on the guild create
# payload if presence intents are also declared, so if this isn't the case then we also want
# to chunk small guilds.
if recv_chunks and members_declared and (payload.get("large") or not presences_declared):
if (
self._intents & intents_.Intents.GUILD_MEMBERS
and (payload.get("large") or not presences_declared)
and (
self._cache_enabled_for(config.CacheComponents.MEMBERS)
or self._enabled_for_event(shard_events.MemberChunkEvent)
)
):
# We create a task here instead of awaiting the result to avoid any rate-limits from delaying dispatch.
nonce = f"{shard.id}.{_fixed_size_nonce()}"

Expand All @@ -263,28 +274,30 @@ async def on_guild_create(self, shard: gateway_shard.GatewayShard, payload: data
if event:
await self.dispatch(event)

# Internal granularity is preferred for GUILD_UPDATE over decorator based filtering due to its large cache scope.
# Internal granularity is preferred for GUILD_UPDATE over decorator based filtering due to its large scope.
async def on_guild_update(self, shard: gateway_shard.GatewayShard, payload: data_binding.JSONObject) -> None:
"""See https://discord.com/developers/docs/topics/gateway#guild-update for more info."""
enabled_for_event = self._enabled_for_event(guild_events.GuildUpdateEvent)
event: typing.Optional[guild_events.GuildUpdateEvent]
if self._enabled_for_event(guild_events.GuildUpdateEvent):
guild_id = snowflakes.Snowflake(payload["id"])
old = self._cache.get_guild(guild_id) if self._cache else None
event = self._event_factory.deserialize_guild_update_event(shard, payload, old_guild=old)

# We also filter here to prevent iterating over them and calling a function that won't do anything
emojis = event.emojis if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None
guild = event.guild if self._cache_enabled_for(config.CacheComponents.GUILDS) else None
roles = event.roles if self._cache_enabled_for(config.CacheComponents.ROLES) else None

if not enabled_for_event and self._cache:
elif self._cache:
_LOGGER.log(ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners")
event: typing.Optional[guild_events.GuildUpdateEvent] = None
event = None

gd = self._entity_factory.deserialize_gateway_guild(payload)
emojis = gd.emojis() if self._cache_enabled_for(config.CacheComponents.EMOJIS) else None
guild = gd.guild() if self._cache_enabled_for(config.CacheComponents.GUILDS) else None
guild_id = gd.id
roles = gd.roles() if self._cache_enabled_for(config.CacheComponents.ROLES) else None

elif enabled_for_event:
guild_id = snowflakes.Snowflake(payload["id"])
old = self._cache.get_guild(guild_id) if self._cache else None
event = self._event_factory.deserialize_guild_update_event(shard, payload, old_guild=old)
emojis = event.emojis
guild = event.guild
roles = event.roles

else:
_LOGGER.log(
ux.TRACE, "Skipping on_guild_update raw dispatch due to lack of any registered listeners or cache need"
Expand Down
Loading

0 comments on commit 0233162

Please sign in to comment.