diff --git a/hikari/api/rest.py b/hikari/api/rest.py index a20a030a8e..ff8c4079b5 100644 --- a/hikari/api/rest.py +++ b/hikari/api/rest.py @@ -308,7 +308,7 @@ async def edit_permission_overwrites( channel: snowflakes.SnowflakeishOr[channels.GuildChannel], target: snowflakes.Snowflakeish, *, - target_type: typing.Union[channels.PermissionOverwriteType, str], + target_type: channels.PermissionOverwriteType, allow: undefined.UndefinedOr[permissions_.Permissions] = undefined.UNDEFINED, deny: undefined.UndefinedOr[permissions_.Permissions] = undefined.UNDEFINED, reason: undefined.UndefinedOr[str] = undefined.UNDEFINED, @@ -323,7 +323,7 @@ async def edit_permission_overwrites( snowflakes.Snowflakeish, users.PartialUser, guilds.PartialRole, channels.PermissionOverwrite ], *, - target_type: undefined.UndefinedOr[typing.Union[channels.PermissionOverwriteType, str]] = undefined.UNDEFINED, + target_type: undefined.UndefinedOr[channels.PermissionOverwriteType] = undefined.UNDEFINED, allow: undefined.UndefinedOr[permissions_.Permissions] = undefined.UNDEFINED, deny: undefined.UndefinedOr[permissions_.Permissions] = undefined.UNDEFINED, reason: undefined.UndefinedOr[str] = undefined.UNDEFINED, @@ -338,7 +338,7 @@ async def edit_permission_overwrites( target : typing.Union[hikari.users.PartialUser, hikari.guilds.PartialRole, hikari.channels.PermissionOverwrite, hikari.snowflakes.Snowflakeish] The channel overwrite to edit. This may be the object or the ID of an existing overwrite. - target_type : hikari.undefined.UndefinedOr[typing.Union[hikari.channels.PermissionOverwriteType, builtins.str]] + target_type : hikari.undefined.UndefinedOr[hikari.channels.PermissionOverwriteType] If provided, the type of the target to update. If unset, will attempt to get the type from `target`. allow : hikari.undefined.UndefinedOr[hikari.permissions.Permissions] @@ -481,7 +481,7 @@ async def create_invite( target_user : hikari.undefined.UndefinedOr[hikari.snowflakes.SnowflakeishOr[hikari.users.PartialUser]] If provided, the target user id for this invite. This may be the object or the ID of an existing user. - target_user_type : hikari.undefined.UndefinedOr[typing.Union[hikari.invites.TargetUserType, builtins.int]] + target_user_type : hikari.undefined.UndefinedOr[hikari.invites.TargetUserType] If provided, the type of target user for this invite. reason : hikari.undefined.UndefinedOr[builtins.str] If provided, the reason that will be recorded in the audit logs. diff --git a/hikari/api/shard.py b/hikari/api/shard.py index 99d232c0d9..514415eb4f 100644 --- a/hikari/api/shard.py +++ b/hikari/api/shard.py @@ -25,10 +25,10 @@ __all__: typing.List[str] = ["GatewayDataFormat", "GatewayCompression", "GatewayShard"] import abc -import enum import typing from hikari import undefined +from hikari.utilities import enums if typing.TYPE_CHECKING: import datetime @@ -41,9 +41,8 @@ from hikari import users as users_ -@enum.unique @typing.final -class GatewayDataFormat(str, enum.Enum): +class GatewayDataFormat(str, enums.Enum): """Format of inbound gateway payloads.""" JSON = "json" @@ -52,9 +51,8 @@ class GatewayDataFormat(str, enum.Enum): """Erlang transmission format.""" -@enum.unique @typing.final -class GatewayCompression(str, enum.Enum): +class GatewayCompression(str, enums.Enum): """Types of gateway compression that may be supported.""" TRANSPORT_ZLIB_STREAM = "transport_zlib_stream" diff --git a/hikari/api/special_endpoints.py b/hikari/api/special_endpoints.py index 80faa374ce..5da9efff96 100644 --- a/hikari/api/special_endpoints.py +++ b/hikari/api/special_endpoints.py @@ -179,7 +179,9 @@ class GuildBuilder(abc.ABC): If not overridden, the guild will use the default voice region for Discord. """ - verification_level: undefined.UndefinedOr[guilds.GuildVerificationLevel] = attr.ib(default=undefined.UNDEFINED) + verification_level: undefined.UndefinedOr[typing.Union[guilds.GuildVerificationLevel, int]] = attr.ib( + default=undefined.UNDEFINED + ) """Verification level required to join the guild that can be overwritten. If not overridden, the guild will use the default verification level for diff --git a/hikari/applications.py b/hikari/applications.py index 03ec51c016..52bafce749 100644 --- a/hikari/applications.py +++ b/hikari/applications.py @@ -34,7 +34,6 @@ "TeamMembershipState", ] -import enum import typing import attr @@ -44,6 +43,7 @@ from hikari import snowflakes from hikari import urls from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import routes if typing.TYPE_CHECKING: @@ -52,9 +52,8 @@ from hikari import users -@enum.unique @typing.final -class OAuth2Scope(str, enum.Enum): +class OAuth2Scope(str, enums.Enum): """OAuth2 Scopes that Discord allows. These are categories of permissions for applications using the OAuth2 API @@ -184,9 +183,8 @@ def __str__(self) -> str: return self.name -@enum.unique @typing.final -class ConnectionVisibility(enum.IntEnum): +class ConnectionVisibility(int, enums.Enum): """Describes who can see a connection with a third party account.""" NONE = 0 @@ -235,7 +233,7 @@ class OwnConnection: is_activity_visible: bool = attr.ib(eq=False, hash=False, repr=False) """`builtins.True` if this connection's activities are shown in the user's presence.""" - visibility: ConnectionVisibility = attr.ib(eq=False, hash=False, repr=True) + visibility: typing.Union[ConnectionVisibility, int] = attr.ib(eq=False, hash=False, repr=True) """The visibility of the connection.""" @@ -250,8 +248,8 @@ class OwnGuild(guilds.PartialGuild): """The guild-level permissions that apply to the current user or bot.""" -@enum.unique -class TeamMembershipState(enum.IntEnum): +@typing.final +class TeamMembershipState(int, enums.Enum): """Represents the state of a user's team membership.""" INVITED = 1 @@ -272,7 +270,7 @@ class TeamMember: app: traits.RESTAware = attr.ib(repr=False, metadata={attr_extensions.SKIP_DEEP_COPY: True}) """The client application that models may use for procedures.""" - membership_state: TeamMembershipState = attr.ib(repr=False) + membership_state: typing.Union[TeamMembershipState, int] = attr.ib(repr=False) """The state of this user's membership.""" permissions: typing.Sequence[str] = attr.ib(repr=False) diff --git a/hikari/audit_logs.py b/hikari/audit_logs.py index a473b37a07..c5c37d0480 100644 --- a/hikari/audit_logs.py +++ b/hikari/audit_logs.py @@ -42,13 +42,13 @@ import abc import datetime -import enum import typing import attr from hikari import snowflakes from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import mapping if typing.TYPE_CHECKING: @@ -60,7 +60,7 @@ @typing.final -class AuditLogChangeKey(str, enum.Enum): +class AuditLogChangeKey(str, enums.Enum): """Commonly known and documented keys for audit log change objects. Others may exist. These should be expected to default to the raw string @@ -142,9 +142,8 @@ class AuditLogChange: """The name of the audit log change's key.""" -@enum.unique @typing.final -class AuditLogEventType(enum.IntEnum): +class AuditLogEventType(int, enums.Enum): """The type of event that occurred.""" GUILD_UPDATE = 1 @@ -204,7 +203,7 @@ class ChannelOverwriteEntryInfo(BaseAuditLogEntryInfo, snowflakes.Unique): id: snowflakes.Snowflake = attr.ib(eq=True, hash=True, repr=True) """The ID of this entity.""" - type: channels.PermissionOverwriteType = attr.ib(repr=True) + type: typing.Union[channels.PermissionOverwriteType, str] = attr.ib(repr=True) """The type of entity this overwrite targets.""" role_name: typing.Optional[str] = attr.ib(repr=True) diff --git a/hikari/channels.py b/hikari/channels.py index f6089a1d41..04b03c87a8 100644 --- a/hikari/channels.py +++ b/hikari/channels.py @@ -42,7 +42,6 @@ ] import abc -import enum import typing import attr @@ -54,6 +53,7 @@ from hikari import urls from hikari import users from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import routes if typing.TYPE_CHECKING: @@ -67,9 +67,8 @@ from hikari import webhooks -@enum.unique @typing.final -class ChannelType(enum.IntEnum): +class ChannelType(int, enums.Enum): """The known channel types that are exposed to us by the API.""" GUILD_TEXT = 0 @@ -201,9 +200,8 @@ def channel(self) -> typing.Union[GuildNewsChannel, GuildTextChannel]: return channel -@enum.unique @typing.final -class PermissionOverwriteType(str, enum.Enum): +class PermissionOverwriteType(str, enums.Enum): """The type of entity a Permission Overwrite targets.""" ROLE = "role" @@ -252,7 +250,9 @@ class PermissionOverwrite(snowflakes.Unique): ) """The ID of this entity.""" - type: PermissionOverwriteType = attr.ib(converter=PermissionOverwriteType, eq=True, hash=True, repr=True) + type: typing.Union[PermissionOverwriteType, str] = attr.ib( + converter=PermissionOverwriteType, eq=True, hash=True, repr=True + ) """The type of entity this overwrite targets.""" # Flags are lazily loaded, due to the IntFlag mechanism being overly slow @@ -306,7 +306,7 @@ class PartialChannel(snowflakes.Unique): name: typing.Optional[str] = attr.ib(eq=False, hash=False, repr=True) """The channel's name. This will be missing for DM channels.""" - type: ChannelType = attr.ib(eq=False, hash=False, repr=True) + type: typing.Union[ChannelType, int] = attr.ib(eq=False, hash=False, repr=True) """The channel's type.""" def __str__(self) -> str: diff --git a/hikari/errors.py b/hikari/errors.py index d6e1d4d197..9372ee13ae 100644 --- a/hikari/errors.py +++ b/hikari/errors.py @@ -47,13 +47,13 @@ "VoiceError", ] -import enum import http import typing import attr from hikari.utilities import attr_extensions +from hikari.utilities import enums if typing.TYPE_CHECKING: from hikari import intents as intents_ @@ -109,9 +109,8 @@ def __str__(self) -> str: return self.reason -@enum.unique @typing.final -class ShardCloseCode(enum.IntEnum): +class ShardCloseCode(int, enums.Enum): """Reasons for a shard connection closure.""" NORMAL_CLOSURE = 1000 diff --git a/hikari/guilds.py b/hikari/guilds.py index 1ef0a48b0d..d5fbe9cfba 100644 --- a/hikari/guilds.py +++ b/hikari/guilds.py @@ -59,6 +59,7 @@ from hikari import urls from hikari import users from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import flag from hikari.utilities import routes @@ -76,9 +77,8 @@ from hikari import voices as voices_ -@enum.unique @typing.final -class GuildExplicitContentFilterLevel(enum.IntEnum): +class GuildExplicitContentFilterLevel(int, enums.Enum): """Represents the explicit content filter setting for a guild.""" DISABLED = 0 @@ -94,9 +94,8 @@ def __str__(self) -> str: return self.name -@enum.unique @typing.final -class GuildFeature(str, enum.Enum): +class GuildFeature(str, enums.Enum): """Features that a guild can provide.""" ANIMATED_ICON = "ANIMATED_ICON" @@ -173,9 +172,8 @@ def __str__(self) -> str: """ -@enum.unique @typing.final -class GuildMessageNotificationsLevel(enum.IntEnum): +class GuildMessageNotificationsLevel(int, enums.Enum): """Represents the default notification level for new messages in a guild.""" ALL_MESSAGES = 0 @@ -188,9 +186,8 @@ def __str__(self) -> str: return self.name -@enum.unique @typing.final -class GuildMFALevel(enum.IntEnum): +class GuildMFALevel(int, enums.Enum): """Represents the multi-factor authorization requirement for a guild.""" NONE = 0 @@ -203,9 +200,8 @@ def __str__(self) -> str: return self.name -@enum.unique @typing.final -class GuildPremiumTier(enum.IntEnum): +class GuildPremiumTier(int, enums.Enum): """Tier for Discord Nitro boosting in a guild.""" NONE = 0 @@ -239,9 +235,8 @@ class GuildSystemChannelFlag(flag.Flag): """Display a message when the guild is Nitro boosted.""" -@enum.unique @typing.final -class GuildVerificationLevel(enum.IntEnum): +class GuildVerificationLevel(int, enums.Enum): """Represents the level of verification of a guild.""" NONE = 0 @@ -517,9 +512,8 @@ def colour(self) -> colours.Colour: return self.color -@enum.unique @typing.final -class IntegrationExpireBehaviour(enum.IntEnum): +class IntegrationExpireBehaviour(int, enums.Enum): """Behavior for expiring integration subscribers.""" REMOVE_ROLE = 0 @@ -581,7 +575,7 @@ class Integration(PartialIntegration): is_emojis_enabled: typing.Optional[bool] = attr.ib(eq=False, hash=False, repr=False) """Whether users under this integration are allowed to use it's custom emojis.""" - expire_behavior: IntegrationExpireBehaviour = attr.ib(eq=False, hash=False, repr=False) + expire_behavior: typing.Union[IntegrationExpireBehaviour, int] = attr.ib(eq=False, hash=False, repr=False) """How members should be treated after their connected subscription expires. This will not be enacted until after `GuildIntegration.expire_grace_period` @@ -827,13 +821,17 @@ class Guild(PartialGuild, abc.ABC): AFK and are moved to the AFK channel (`Guild.afk_channel_id`). """ - default_message_notifications: GuildMessageNotificationsLevel = attr.ib(eq=False, hash=False, repr=False) + default_message_notifications: typing.Union[GuildMessageNotificationsLevel, int] = attr.ib( + eq=False, hash=False, repr=False + ) """The default setting for message notifications in this guild.""" - explicit_content_filter: GuildExplicitContentFilterLevel = attr.ib(eq=False, hash=False, repr=False) + explicit_content_filter: typing.Union[GuildExplicitContentFilterLevel, int] = attr.ib( + eq=False, hash=False, repr=False + ) """The setting for the explicit content filter in this guild.""" - mfa_level: GuildMFALevel = attr.ib(eq=False, hash=False, repr=False) + mfa_level: typing.Union[GuildMFALevel, int] = attr.ib(eq=False, hash=False, repr=False) """The required MFA level for users wishing to participate in this guild.""" application_id: typing.Optional[snowflakes.Snowflake] = attr.ib(eq=False, hash=False, repr=False) @@ -895,7 +893,7 @@ class Guild(PartialGuild, abc.ABC): `Guild.features` for this guild. For all other purposes, it is `builtins.None`. """ - premium_tier: GuildPremiumTier = attr.ib(eq=False, hash=False, repr=False) + premium_tier: typing.Union[GuildPremiumTier, int] = attr.ib(eq=False, hash=False, repr=False) """The premium tier for this guild.""" premium_subscription_count: typing.Optional[int] = attr.ib(eq=False, hash=False, repr=False) @@ -919,25 +917,13 @@ class Guild(PartialGuild, abc.ABC): this guild. For all other purposes, it should be considered to be `builtins.None`. """ + verification_level: typing.Union[GuildVerificationLevel, int] = attr.ib(eq=False, hash=False, repr=False) + """The verification level needed for a user to participate in this guild.""" + # Flags are lazily loaded, due to the IntFlag mechanism being overly slow # to execute. - _verification_level: int = attr.ib(eq=False, hash=False, repr=False) _system_channel_flags: int = attr.ib(eq=False, hash=False, repr=False) - @property - def verification_level(self) -> GuildVerificationLevel: - """Return the verification level required for this guild. - - This defines the verification level needed for a user to participate in - this guild. - - Returns - ------- - GuildVerificationLevel - The verification level required for this guild. - """ - return GuildVerificationLevel(self._verification_level) - @property def system_channel_flags(self) -> GuildSystemChannelFlag: """Return flags for the guild system channel. diff --git a/hikari/impl/entity_factory.py b/hikari/impl/entity_factory.py index d63eb55c44..0915329a53 100644 --- a/hikari/impl/entity_factory.py +++ b/hikari/impl/entity_factory.py @@ -79,7 +79,7 @@ class _PartialGuildFields: id: snowflakes.Snowflake = attr.ib() name: str = attr.ib() icon_hash: str = attr.ib() - features: typing.Sequence[typing.Union[guild_models.GuildFeature, str]] = attr.ib() + features: typing.Sequence[guild_models.GuildFeatureish] = attr.ib() @attr_extensions.with_copy @@ -87,7 +87,7 @@ class _PartialGuildFields: class _GuildChannelFields: id: snowflakes.Snowflake = attr.ib() name: typing.Optional[str] = attr.ib() - type: channel_models.ChannelType = attr.ib() + type: typing.Union[channel_models.ChannelType, int] = attr.ib() guild_id: snowflakes.Snowflake = attr.ib() position: int = attr.ib() permission_overwrites: typing.Mapping[snowflakes.Snowflake, channel_models.PermissionOverwrite] = attr.ib() @@ -113,10 +113,10 @@ class _GuildFields(_PartialGuildFields): region: str = attr.ib() afk_channel_id: typing.Optional[snowflakes.Snowflake] = attr.ib() afk_timeout: datetime.timedelta = attr.ib() - verification_level: guild_models.GuildVerificationLevel = attr.ib() - default_message_notifications: guild_models.GuildMessageNotificationsLevel = attr.ib() - explicit_content_filter: guild_models.GuildExplicitContentFilterLevel = attr.ib() - mfa_level: guild_models.GuildMFALevel = attr.ib() + verification_level: typing.Union[guild_models.GuildVerificationLevel, int] = attr.ib() + default_message_notifications: typing.Union[guild_models.GuildMessageNotificationsLevel, int] = attr.ib() + explicit_content_filter: typing.Union[guild_models.GuildVerificationLevel, int] = attr.ib() + mfa_level: typing.Union[guild_models.GuildMFALevel, int] = attr.ib() application_id: typing.Optional[snowflakes.Snowflake] = attr.ib() widget_channel_id: typing.Optional[snowflakes.Snowflake] = attr.ib() system_channel_id: typing.Optional[snowflakes.Snowflake] = attr.ib() @@ -127,7 +127,7 @@ class _GuildFields(_PartialGuildFields): vanity_url_code: typing.Optional[str] = attr.ib() description: typing.Optional[str] = attr.ib() banner_hash: typing.Optional[str] = attr.ib() - premium_tier: guild_models.GuildPremiumTier = attr.ib() + premium_tier: typing.Union[guild_models.GuildPremiumTier, int] = attr.ib() premium_subscription_count: typing.Optional[int] = attr.ib() preferred_locale: str = attr.ib() public_updates_channel_id: typing.Optional[snowflakes.Snowflake] = attr.ib() @@ -143,7 +143,7 @@ class _InviteFields: channel_id: snowflakes.Snowflake = attr.ib() inviter: typing.Optional[user_models.User] = attr.ib() target_user: typing.Optional[user_models.User] = attr.ib() - target_user_type: typing.Optional[invite_models.TargetUserType] = attr.ib() + target_user_type: typing.Union[invite_models.TargetUserType, int, None] = attr.ib() approximate_active_member_count: typing.Optional[int] = attr.ib() approximate_member_count: typing.Optional[int] = attr.ib() @@ -402,10 +402,7 @@ def deserialize_audit_log(self, payload: data_binding.JSONObject) -> audit_log_m if (change_payloads := entry_payload.get("changes")) is not None: for change_payload in change_payloads: key: typing.Union[audit_log_models.AuditLogChangeKey, str] - try: - key = audit_log_models.AuditLogChangeKey(change_payload["key"]) - except ValueError: - key = change_payload["key"] + key = audit_log_models.AuditLogChangeKey(change_payload["key"]) new_value: typing.Any = change_payload.get("new_value") old_value: typing.Any = change_payload.get("old_value") @@ -424,10 +421,7 @@ def deserialize_audit_log(self, payload: data_binding.JSONObject) -> audit_log_m user_id = snowflakes.Snowflake(raw_user_id) action_type: typing.Union[audit_log_models.AuditLogEventType, int] - try: - action_type = audit_log_models.AuditLogEventType(entry_payload["action_type"]) - except ValueError: - action_type = entry_payload["action_type"] + action_type = audit_log_models.AuditLogEventType(entry_payload["action_type"]) options: typing.Optional[audit_log_models.BaseAuditLogEntryInfo] = None if (raw_option := entry_payload.get("options")) is not None: @@ -1084,10 +1078,7 @@ def deserialize_guild_member_ban(self, payload: data_binding.JSONObject) -> guil def _set_partial_guild_attributes(payload: data_binding.JSONObject) -> _PartialGuildFields: features = [] for feature in payload["features"]: - try: - features.append(guild_models.GuildFeature(feature)) - except ValueError: - features.append(feature) + features.append(guild_models.GuildFeature(feature)) return _PartialGuildFields( id=snowflakes.Snowflake(payload["id"]), diff --git a/hikari/invites.py b/hikari/invites.py index f1770f74f0..9371d0b185 100644 --- a/hikari/invites.py +++ b/hikari/invites.py @@ -33,7 +33,6 @@ ] import abc -import enum import typing import attr @@ -43,6 +42,7 @@ from hikari import snowflakes from hikari import urls from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import routes if typing.TYPE_CHECKING: @@ -53,9 +53,8 @@ from hikari import users -@enum.unique @typing.final -class TargetUserType(enum.IntEnum): +class TargetUserType(int, enums.Enum): """The reason a invite targets a user.""" STREAM = 1 @@ -121,7 +120,7 @@ class InviteGuild(guilds.PartialGuild): Otherwise, this will always be `builtins.None`. For all other purposes, it is `builtins.None`. """ - verification_level: guilds.GuildVerificationLevel = attr.ib(eq=False, hash=False, repr=False) + verification_level: typing.Union[guilds.GuildVerificationLevel, int] = attr.ib(eq=False, hash=False, repr=False) """The verification level required for a user to participate in this guild.""" vanity_url_code: typing.Optional[str] = attr.ib(eq=False, hash=False, repr=True) @@ -248,7 +247,7 @@ class Invite(InviteCode): target_user: typing.Optional[users.User] = attr.ib(eq=False, hash=False, repr=False) """The object of the user who this invite targets, if set.""" - target_user_type: typing.Optional[TargetUserType] = attr.ib(eq=False, hash=False, repr=False) + target_user_type: typing.Union[TargetUserType, int, None] = attr.ib(eq=False, hash=False, repr=False) """The type of user target this invite is, if applicable.""" approximate_active_member_count: typing.Optional[int] = attr.ib(eq=False, hash=False, repr=False) diff --git a/hikari/messages.py b/hikari/messages.py index 3a240e77d9..e7bfe497e4 100644 --- a/hikari/messages.py +++ b/hikari/messages.py @@ -45,6 +45,7 @@ from hikari import undefined from hikari import urls from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import flag if typing.TYPE_CHECKING: @@ -59,9 +60,8 @@ from hikari import users -@enum.unique @typing.final -class MessageType(enum.IntEnum): +class MessageType(int, enums.Enum): """The type of a message.""" DEFAULT = 0 @@ -131,9 +131,8 @@ class MessageFlag(flag.Flag): """This message came from the urgent message system.""" -@enum.unique @typing.final -class MessageActivityType(enum.IntEnum): +class MessageActivityType(int, enums.Enum): """The type of a rich presence message activity.""" NONE = 0 @@ -212,7 +211,7 @@ def __str__(self) -> str: class MessageActivity: """Represents the activity of a rich presence-enabled message.""" - type: MessageActivityType = attr.ib(repr=True) + type: typing.Union[MessageActivityType, int] = attr.ib(repr=True) """The type of message activity.""" party_id: typing.Optional[str] = attr.ib(repr=True) @@ -331,7 +330,7 @@ class PartialMessage(snowflakes.Unique): webhook_id: undefined.UndefinedNoneOr[snowflakes.Snowflake] = attr.ib(repr=False) """If the message was generated by a webhook, the webhook's ID.""" - type: undefined.UndefinedOr[MessageType] = attr.ib(repr=False) + type: undefined.UndefinedOr[typing.Union[MessageType, int]] = attr.ib(repr=False) """The message's type.""" activity: undefined.UndefinedNoneOr[MessageActivity] = attr.ib(repr=False) @@ -883,7 +882,7 @@ class Message(PartialMessage): webhook_id: typing.Optional[snowflakes.Snowflake] """If the message was generated by a webhook, the webhook's id.""" - type: MessageType + type: typing.Union[MessageType, int] """The message type.""" activity: typing.Optional[MessageActivity] diff --git a/hikari/permissions.py b/hikari/permissions.py index 8d1946e54b..cd2430b4e8 100644 --- a/hikari/permissions.py +++ b/hikari/permissions.py @@ -36,7 +36,7 @@ class Permissions(flag.Flag): """Represents the permissions available in a given channel or guild. - This is an int-flag enum. This means that you can **combine multiple + This enum is an `enum.IntFlag`. This means that you can **combine multiple permissions together** into one value using the bitwise-OR operator (`|`). my_perms = Permissions.MANAGE_CHANNELS | Permissions.MANAGE_GUILD diff --git a/hikari/presences.py b/hikari/presences.py index 926640e6d4..a81cc1ea87 100644 --- a/hikari/presences.py +++ b/hikari/presences.py @@ -44,6 +44,7 @@ from hikari import snowflakes from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import flag if typing.TYPE_CHECKING: @@ -53,9 +54,8 @@ from hikari import traits -@enum.unique @typing.final -class ActivityType(enum.IntEnum): +class ActivityType(int, enums.Enum): """The activity type.""" PLAYING = 0 @@ -191,7 +191,7 @@ class Activity: url: typing.Optional[str] = attr.ib(default=None, repr=False) """The activity URL. Only valid for `STREAMING` activities.""" - type: ActivityType = attr.ib(converter=ActivityType, default=ActivityType.PLAYING) + type: typing.Union[ActivityType, int] = attr.ib(converter=ActivityType, default=ActivityType.PLAYING) """The activity type.""" def __str__(self) -> str: @@ -251,7 +251,7 @@ def flags(self) -> typing.Optional[ActivityFlag]: @typing.final -class Status(str, enum.Enum): +class Status(str, enums.Enum): """The status of a member.""" ONLINE = "online" @@ -275,13 +275,13 @@ def __str__(self) -> str: class ClientStatus: """The client statuses for this member.""" - desktop: Status = attr.ib(repr=True) + desktop: typing.Union[Status, str] = attr.ib(repr=True) """The status of the target user's desktop session.""" - mobile: Status = attr.ib(repr=True) + mobile: typing.Union[Status, str] = attr.ib(repr=True) """The status of the target user's mobile session.""" - web: Status = attr.ib(repr=True) + web: typing.Union[Status, str] = attr.ib(repr=True) """The status of the target user's web session.""" @@ -306,7 +306,7 @@ class MemberPresence: guild_id: snowflakes.Snowflake = attr.ib(eq=True, hash=True, repr=True) """The ID of the guild this presence belongs to.""" - visible_status: Status = attr.ib(eq=False, hash=False, repr=True) + visible_status: typing.Union[Status, str] = attr.ib(eq=False, hash=False, repr=True) """This user's current status being displayed by the client.""" activities: typing.Sequence[RichActivity] = attr.ib(eq=False, hash=False, repr=False) diff --git a/hikari/users.py b/hikari/users.py index 4f045ff9e6..22a9e2981e 100644 --- a/hikari/users.py +++ b/hikari/users.py @@ -36,6 +36,7 @@ from hikari import undefined from hikari import urls from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import flag from hikari.utilities import routes @@ -94,9 +95,8 @@ class UserFlag(flag.Flag): """ -@enum.unique @typing.final -class PremiumType(enum.IntEnum): +class PremiumType(int, enums.Enum): """The types of Nitro.""" NONE = 0 @@ -431,7 +431,7 @@ class OwnUser(UserImpl): scope. Will always be `builtins.None` for bot users. """ - premium_type: typing.Optional[PremiumType] = attr.ib(eq=False, hash=False, repr=False) + premium_type: typing.Union[PremiumType, int, None] = attr.ib(eq=False, hash=False, repr=False) """The type of Nitro Subscription this user account had. This will always be `builtins.None` for bots. diff --git a/hikari/utilities/cache.py b/hikari/utilities/cache.py index e6b9b87b0e..6753f63dbc 100644 --- a/hikari/utilities/cache.py +++ b/hikari/utilities/cache.py @@ -422,7 +422,7 @@ class InviteData(BaseData[invites.InviteWithMetadata]): channel_id: snowflakes.Snowflake = attr.ib() inviter_id: typing.Optional[snowflakes.Snowflake] = attr.ib() target_user_id: typing.Optional[snowflakes.Snowflake] = attr.ib() - target_user_type: typing.Optional[invites.TargetUserType] = attr.ib() + target_user_type: typing.Union[invites.TargetUserType, int, None] = attr.ib() uses: int = attr.ib() max_uses: typing.Optional[int] = attr.ib() max_age: typing.Optional[datetime.timedelta] = attr.ib() @@ -567,7 +567,7 @@ class RichActivityData(BaseData[presences.RichActivity]): name: str = attr.ib() url: typing.Optional[str] = attr.ib() - type: presences.ActivityType = attr.ib() + type: typing.Union[presences.ActivityType, int] = attr.ib() created_at: datetime.datetime = attr.ib() timestamps: typing.Optional[presences.ActivityTimestamps] = attr.ib() application_id: typing.Optional[snowflakes.Snowflake] = attr.ib() @@ -638,7 +638,7 @@ class MemberPresenceData(BaseData[presences.MemberPresence]): user_id: snowflakes.Snowflake = attr.ib() role_ids: typing.Optional[typing.Tuple[snowflakes.Snowflake, ...]] = attr.ib() guild_id: snowflakes.Snowflake = attr.ib() - visible_status: presences.Status = attr.ib() + visible_status: typing.Union[presences.Status, str] = attr.ib() activities: typing.Tuple[RichActivityData, ...] = attr.ib() client_status: presences.ClientStatus = attr.ib() premium_since: typing.Optional[datetime.datetime] = attr.ib() diff --git a/hikari/utilities/enums.py b/hikari/utilities/enums.py new file mode 100644 index 0000000000..314f52d42b --- /dev/null +++ b/hikari/utilities/enums.py @@ -0,0 +1,236 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Implementation of parts of Python's `enum` protocol to be faster.""" +from __future__ import annotations + +__all__: typing.List[str] = ["Enum"] + +import os +import sys +import types +import typing + +_T = typing.TypeVar("_T") + + +class _EnumNamespace(typing.Dict[str, typing.Any]): + __slots__: typing.Sequence[str] = ("base", "names_to_values", "values_to_names") + + def __init__(self, base: typing.Type[typing.Any]) -> None: + super().__init__() + self.base = base + self.names_to_values: typing.Dict[str, typing.Any] = {} + self.values_to_names: typing.Dict[str, typing.Any] = {} + self["__doc__"] = "An enumeration." + + def __contains__(self, item: typing.Any) -> bool: + try: + _ = self[item] + return True + except KeyError: + return False + + def __getitem__(self, name: str) -> typing.Any: + try: + return super().__getitem__(name) + except KeyError: + try: + return self.names_to_values[name] + except KeyError: + raise KeyError(name) from None + + def __iter__(self) -> typing.Iterator[str]: + yield from super().__iter__() + yield self.names_to_values + + def __setitem__(self, name: str, value: typing.Any) -> None: + if name == "" or name == "mro": + raise TypeError(f"Invalid enum member name: {name!r}") + + if name.startswith("_"): + # Dunder/sunder, so skip. + super().__setitem__(name, value) + return + + if hasattr(value, "__get__") or hasattr(value, "__set__") or hasattr(value, "__del__"): + super().__setitem__(name, value) + return + + if not isinstance(value, self.base): + raise TypeError(f"Expected member {name} to be of type {self.base.__name__} but was {type(value).__name__}") + + name = sys.intern(name) + + if issubclass(self.base, str): + value = sys.intern(value) + else: + try: + # This will fail if unhashable. + hash(value) + except TypeError: + raise TypeError(f"Cannot have unhashable values in this enum type ({name}: {value!r})") from None + + if name in self.names_to_values: + raise TypeError("Cannot define same name twice") + if value in self.values_to_names: + # We must have defined some alias, so just register the name + self.names_to_values[name] = value + return + if not isinstance(value, self.base): + raise TypeError("Enum values must be an instance of the base type of the enum") + + self.names_to_values[name] = value + self.values_to_names[value] = name + + +# We refer to these from the metaclasses, but obviously this won't work +# until these classes are created, and since they use the metaclasses as +# a base metaclass, we have to give these values for _EnumMeta to not +# flake out when initializing them. +_Enum = NotImplemented + + +def _attr_mutator(self, *_: typing.Any) -> typing.NoReturn: + raise TypeError("Cannot mutate enum members") + + +class _EnumMeta(type): + def __call__(cls, value: typing.Any) -> typing.Any: + try: + return cls._value2member_map_[value] + except KeyError: + # If we cant find the value, just return what got casted in + return value + + def __dir__(cls) -> typing.List[str]: + members = ["__class__", "__doc__", "__members__", "__module__"] + try: + members += list(cls._name2member_map_) + finally: + return members + + def __getattr__(cls, name: str) -> typing.Any: + if name.startswith("_") and name.endswith("_"): + # Stop recursion errors by trying to look up _name2member_map_ + # recursively. + raise AttributeError(name) + try: + return cls._name2member_map_[name] + except KeyError: + try: + return super().__getattribute__(name) + except AttributeError: + raise AttributeError(name) from None + + def __getitem__(cls, name: str) -> typing.Any: + return cls._name2member_map_[name] + + def __iter__(cls) -> typing.Iterator[str]: + yield cls._name2member_map_ + + @staticmethod + def __new__( + mcs: typing.Type[_T], + name: str, + bases: typing.Tuple[typing.Type[typing.Any], ...], + namespace: _EnumNamespace, + ) -> _T: + global _Enum + + if name == "Enum" and _Enum is NotImplemented: + # noinspection PyRedundantParentheses + return (_Enum := super().__new__(mcs, name, bases, namespace)) + + try: + base, enum_type = bases + except ValueError: + raise TypeError("Expected two base classes for an enum") from None + + if not issubclass(enum_type, _Enum): + raise TypeError("second base type for enum must be derived from Enum") + + new_namespace = { + "__objtype__": base, + "__enumtype__": enum_type, + "_name2member_map_": (name2member := {}), + "_value2member_map_": (value2member := {}), + # Required to be immutable by enum API itself. + "__members__": types.MappingProxyType(namespace.names_to_values), + **namespace, + } + + cls = super().__new__(mcs, name, bases, new_namespace) + + for name, value in namespace.names_to_values.items(): + # Patching the member init call is around 100ns faster per call than + # using the default type.__call__ which would make us do the lookup + # in cls.__new__. Reason for this is that python will also always + # invoke cls.__init__ if we do this, so we end up with two function + # calls. + member = cls.__new__(cls, value) + member.name = name + member.value = value + name2member[name] = member + value2member[value] = member + + cls.__setattr__ = _attr_mutator + cls.__delattr__ = _attr_mutator + + return cls + + @classmethod + def __prepare__(mcs, name: str, bases: typing.Tuple[typing.Type[typing.Any], ...] = ()) -> _EnumNamespace: + try: + # Fails if Enum is not defined. We check this in `__new__` properly. + base, enum_type = bases + + if isinstance(base, _EnumMeta): + raise TypeError("First base to an enum must be the type to combine with, not _EnumMeta") + if not isinstance(enum_type, _EnumMeta): + raise TypeError("Second base to an enum must be the enum type (derived from _EnumMeta) to be used") + + return _EnumNamespace(base) + except ValueError: + return _EnumNamespace(object) + + def __repr__(cls) -> str: + return f"" + + __str__ = __repr__ + + +class Enum(metaclass=_EnumMeta): + """Re-implementation of parts of Python's `enum` to be faster.""" + + def __getattr__(self, name: str) -> typing.Any: + return getattr(self.value, name) + + def __repr__(self) -> str: + return f"<{type(self).__name__}.{self.name}: {self.value!r}>" + + def __str__(self) -> str: + return f"{type(self).__name__}.{self.name}" + + +# We have to use this fallback, or Pdoc will fail to document some stuff correctly... +if os.getenv("PDOC3_GENERATING") == "1": # pragma: no cover + from enum import Enum # noqa: F811 - Redefinition intended diff --git a/hikari/utilities/enums.pyi b/hikari/utilities/enums.pyi new file mode 100644 index 0000000000..9abb63a6cb --- /dev/null +++ b/hikari/utilities/enums.pyi @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# cython: language_level=3 +# Copyright (c) 2020 Nekokatt +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +"""Typehints for `hikari.utilities.enums`.""" + +# Enums use a lot of internal voodoo that will not type check nicely, so we +# skip that module with MyPy and just accept that "here be dragons". +# +# The caveat to implementing this is that MyPy has to have a special module to +# understand how to use Python's enum types. I really don't want to have to +# ship my own MyPy plugin for this, so just make MyPy think that the types +# we are using are just aliases from the enum types in the standard library. + +import enum as __enum + +Enum = __enum.Enum + +__all__ = ["Enum"] diff --git a/hikari/webhooks.py b/hikari/webhooks.py index 91ebc62cbe..ba9071de3a 100644 --- a/hikari/webhooks.py +++ b/hikari/webhooks.py @@ -25,7 +25,6 @@ __all__: typing.List[str] = ["WebhookType", "Webhook"] -import enum import typing import attr @@ -35,6 +34,7 @@ from hikari import undefined from hikari import urls from hikari.utilities import attr_extensions +from hikari.utilities import enums from hikari.utilities import routes if typing.TYPE_CHECKING: @@ -46,9 +46,8 @@ from hikari import users as users_ -@enum.unique @typing.final -class WebhookType(enum.IntEnum): +class WebhookType(int, enums.Enum): """Types of webhook.""" INCOMING = 1 @@ -77,7 +76,7 @@ class Webhook(snowflakes.Unique): id: snowflakes.Snowflake = attr.ib(eq=True, hash=True, repr=True) """The ID of this entity.""" - type: WebhookType = attr.ib(eq=False, hash=False, repr=True) + type: typing.Union[WebhookType, int] = attr.ib(eq=False, hash=False, repr=True) """The type of the webhook.""" guild_id: typing.Optional[snowflakes.Snowflake] = attr.ib(eq=False, hash=False, repr=True) diff --git a/scripts/enum_benchmark.py b/scripts/enum_benchmark.py new file mode 100644 index 0000000000..ab9cfa2f1b --- /dev/null +++ b/scripts/enum_benchmark.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Nekokatt +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import cProfile +import enum as py_enum +import timeit + +from hikari.utilities import enums as hikari_enum + + +class BasicPyEnum(str, py_enum.Enum): + a = "0" + b = "1" + c = "2" + d = "3" + e = "4" + f = "5" + g = "6" + h = "7" + i = "8" + j = "9" + k = "10" + l = "11" + m = "12" + n = "13" + o = "14" + p = "15" + q = "16" + r = "17" + s = "18" + t = "19" + u = "20" + v = "21" + w = "22" + x = "23" + y = "24" + z = "25" + + +class BasicHikariEnum(str, hikari_enum.Enum): + a = "0" + b = "1" + c = "2" + d = "3" + e = "4" + f = "5" + g = "6" + h = "7" + i = "8" + j = "9" + k = "10" + l = "11" + m = "12" + n = "13" + o = "14" + p = "15" + q = "16" + r = "17" + s = "18" + t = "19" + u = "20" + v = "21" + w = "22" + x = "23" + y = "24" + z = "25" + + +# Dummy work to churn the CPU up. +for i in range(100_000): + assert sum(i for i in range(10)) > 0 + +py_enum_call_time = timeit.timeit("BasicPyEnum('25')", number=1_000_000, globals=globals()) +hikari_enum_call_time = timeit.timeit("BasicHikariEnum('25')", number=1_000_000, globals=globals()) +py_enum_delegate_to_map_time = timeit.timeit( + "BasicPyEnum._value2member_map_['25']", number=1_000_000, globals=globals() +) +hikari_enum_delegate_to_map_time = timeit.timeit( + "BasicHikariEnum._value2member_map_['25']", number=1_000_000, globals=globals() +) +py_enum_getitem_time = timeit.timeit("BasicPyEnum['z']", number=1_000_000, globals=globals()) +hikari_enum_getitem_time = timeit.timeit("BasicHikariEnum['z']", number=1_000_000, globals=globals()) + +print("BasicPyEnum.__call__('25')", py_enum_call_time, "µs") +print("BasicHikariEnum.__call__('25')", hikari_enum_call_time, "µs") +print("BasicPyEnum._value2member_map_['25']", py_enum_delegate_to_map_time, "µs") +print("BasicHikariEnum._value2member_map['25']", hikari_enum_delegate_to_map_time, "µs") +print("BasicPyEnum.__getitem__['z']", py_enum_getitem_time, "µs") +print("BasicHikariEnum.__getitem__['z']", hikari_enum_getitem_time, "µs") + +print("BasicPyEnum.__call__ profile") +cProfile.runctx("for i in range(1_000_000): BasicPyEnum('25')", globals=globals(), locals=locals()) + +print("BasicHikariEnum.__call__ profile") +cProfile.runctx("for i in range(1_000_000): BasicHikariEnum('25')", globals=globals(), locals=locals()) + +print("BasicPyEnum.__getitem__ profile") +cProfile.runctx("for i in range(1_000_000): BasicPyEnum['z']", globals=globals(), locals=locals()) + +print("BasicHikariEnum.__getitem__ profile") +cProfile.runctx("for i in range(1_000_000): BasicHikariEnum['z']", globals=globals(), locals=locals()) diff --git a/tests/hikari/utilities/test_enums.py b/tests/hikari/utilities/test_enums.py new file mode 100644 index 0000000000..97828afaa0 --- /dev/null +++ b/tests/hikari/utilities/test_enums.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2020 Nekokatt +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import mock +import pytest + +from hikari.utilities import enums + + +class TestEnum: + @mock.patch.object(enums, "_Enum", new=NotImplemented) + def test_init_first_enum_type_populates_Enum(self): + class Enum(metaclass=enums._EnumMeta): + pass + + assert enums._Enum is Enum + + @mock.patch.object(enums, "_Enum", new=NotImplemented) + def test_init_first_enum_type_with_wrong_name_and_no_bases_raises_TypeError(self): + with pytest.raises(TypeError): + + class Potato(metaclass=enums._EnumMeta): + pass + + assert enums._Enum is NotImplemented + + def test_init_second_enum_type_with_no_bases_does_not_change_Enum_attribute_and_raises_TypeError(self): + expect = enums._Enum + + with pytest.raises(TypeError): + + class Enum(metaclass=enums._EnumMeta): + pass + + assert enums._Enum is expect + + @pytest.mark.parametrize( + ("args", "kwargs"), + [([str], {"metaclass": enums._EnumMeta}), ([enums.Enum], {"metaclass": enums._EnumMeta}), ([enums.Enum], {})], + ) + def test_init_enum_type_with_one_base_is_TypeError(self, args, kwargs): + with pytest.raises(TypeError): + + class Enum(*args, **kwargs): + pass + + @pytest.mark.parametrize( + ("args", "kwargs"), + [ + ([enums.Enum, str], {"metaclass": enums._EnumMeta}), + ([enums.Enum, str], {}), + ], + ) + def test_init_enum_type_with_bases_in_wrong_order_is_TypeError(self, args, kwargs): + with pytest.raises(TypeError): + + class Enum(*args, **kwargs): + pass + + def test_init_enum_type_default_docstring_set(self): + class Enum(str, enums.Enum): + pass + + assert Enum.__doc__ == "An enumeration." + + def test_init_enum_type_disallows_objects_that_are_not_instances_of_the_first_base(self): + with pytest.raises(TypeError): + + class Enum(str, enums.Enum): + foo = 1 + + def test_init_enum_type_allows_any_object_if_it_has_a_dunder_name(self): + class Enum(str, enums.Enum): + __foo__ = 1 + __bar = 2 + + assert Enum is not None + + def test_init_enum_type_allows_any_object_if_it_has_a_sunder_name(self): + class Enum(str, enums.Enum): + _foo_ = 1 + _bar = 2 + + assert Enum is not None + + def test_init_enum_type_allows_methods(self): + class Enum(int, enums.Enum): + def foo(self): + return "foo" + + assert Enum.foo(12) == "foo" + + def test_init_enum_type_allows_classmethods(self): + class Enum(int, enums.Enum): + @classmethod + def foo(cls): + assert cls is Enum + return "foo" + + assert Enum.foo() == "foo" + + def test_init_enum_type_allows_staticmethods(self): + class Enum(int, enums.Enum): + @staticmethod + def foo(): + return "foo" + + assert Enum.foo() == "foo" + + def test_init_enum_type_allows_descriptors(self): + class Enum(int, enums.Enum): + @property + def foo(self): + return "foo" + + assert isinstance(Enum.foo, property) + + def test_init_enum_type_maps_names_in___members__(self): + class Enum(int, enums.Enum): + foo = 9 + bar = 18 + baz = 27 + + @staticmethod + def sm(): + pass + + @classmethod + def cm(cls): + pass + + def m(self): + pass + + @property + def p(self): + pass + + __dunder__ = "aaa" + _sunder_ = "bbb" + __priv = "ccc" + _prot = "ddd" + + assert Enum.__members__ == {"foo": 9, "bar": 18, "baz": 27} + + def test___call___when_member(self): + class Enum(int, enums.Enum): + foo = 9 + bar = 18 + baz = 27 + + returned = Enum(9) + assert returned == Enum.foo + assert type(returned) == Enum + + def test___call___when_not_member(self): + class Enum(int, enums.Enum): + foo = 9 + bar = 18 + baz = 27 + + returned = Enum(69) + assert returned == 69 + assert type(returned) != Enum + + def test___getitem__(self): + class Enum(int, enums.Enum): + foo = 9 + bar = 18 + baz = 27 + + returned = Enum["foo"] + assert returned == Enum.foo + assert type(returned) == Enum