Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.11] gh-103365: [Enum] STRICT boundary corrections (GH-103494) #103513

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Doc/library/enum.rst
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,8 @@ Data Types

.. attribute:: STRICT

Out-of-range values cause a :exc:`ValueError` to be raised::
Out-of-range values cause a :exc:`ValueError` to be raised. This is the
default for :class:`Flag`::

>>> from enum import Flag, STRICT, auto
>>> class StrictFlag(Flag, boundary=STRICT):
Expand All @@ -709,7 +710,7 @@ Data Types
.. attribute:: CONFORM

Out-of-range values have invalid values removed, leaving a valid *Flag*
value. This is the default for :class:`Flag`::
value::

>>> from enum import Flag, CONFORM, auto
>>> class ConformFlag(Flag, boundary=CONFORM):
Expand Down
67 changes: 39 additions & 28 deletions Lib/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ def __set_name__(self, enum_class, member_name):
enum_member.__objclass__ = enum_class
enum_member.__init__(*args)
enum_member._sort_order_ = len(enum_class._member_names_)

if Flag is not None and issubclass(enum_class, Flag):
enum_class._flag_mask_ |= value
if _is_single_bit(value):
enum_class._singles_mask_ |= value
enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1

# If another member with the same value was already defined, the
# new member becomes an alias to the existing one.
try:
Expand Down Expand Up @@ -525,12 +532,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
classdict['_use_args_'] = use_args
#
# convert future enum members into temporary _proto_members
# and record integer values in case this will be a Flag
flag_mask = 0
for name in member_names:
value = classdict[name]
if isinstance(value, int):
flag_mask |= value
classdict[name] = _proto_member(value)
#
# house-keeping structures
Expand All @@ -547,8 +550,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
boundary
or getattr(first_enum, '_boundary_', None)
)
classdict['_flag_mask_'] = flag_mask
classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
classdict['_flag_mask_'] = 0
classdict['_singles_mask_'] = 0
classdict['_all_bits_'] = 0
classdict['_inverted_'] = None
try:
exc = None
Expand Down Expand Up @@ -637,21 +641,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k
):
delattr(enum_class, '_boundary_')
delattr(enum_class, '_flag_mask_')
delattr(enum_class, '_singles_mask_')
delattr(enum_class, '_all_bits_')
delattr(enum_class, '_inverted_')
elif Flag is not None and issubclass(enum_class, Flag):
# ensure _all_bits_ is correct and there are no missing flags
single_bit_total = 0
multi_bit_total = 0
for flag in enum_class._member_map_.values():
flag_value = flag._value_
if _is_single_bit(flag_value):
single_bit_total |= flag_value
else:
# multi-bit flags are considered aliases
multi_bit_total |= flag_value
enum_class._flag_mask_ = single_bit_total
#
# set correct __iter__
member_list = [m._value_ for m in enum_class]
if member_list != sorted(member_list):
Expand Down Expand Up @@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
class FlagBoundary(StrEnum):
"""
control how out of range values are handled
"strict" -> error is raised
"conform" -> extra bits are discarded [default for Flag]
"strict" -> error is raised [default for Flag]
"conform" -> extra bits are discarded
"eject" -> lose flag status
"keep" -> keep flag status and all bits [default for IntFlag]
"""
Expand All @@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
STRICT, CONFORM, EJECT, KEEP = FlagBoundary


class Flag(Enum, boundary=CONFORM):
class Flag(Enum, boundary=STRICT):
"""
Support for flags
"""
Expand Down Expand Up @@ -1393,6 +1386,7 @@ def _missing_(cls, value):
# - value must not include any skipped flags (e.g. if bit 2 is not
# defined, then 0d10 is invalid)
flag_mask = cls._flag_mask_
singles_mask = cls._singles_mask_
all_bits = cls._all_bits_
neg_value = None
if (
Expand Down Expand Up @@ -1424,7 +1418,8 @@ def _missing_(cls, value):
value = all_bits + 1 + value
# get members and unknown
unknown = value & ~flag_mask
member_value = value & flag_mask
aliases = value & ~singles_mask
member_value = value & singles_mask
if unknown and cls._boundary_ is not KEEP:
raise ValueError(
'%s(%r) --> unknown values %r [%s]'
Expand All @@ -1438,11 +1433,25 @@ def _missing_(cls, value):
pseudo_member = cls._member_type_.__new__(cls, value)
if not hasattr(pseudo_member, '_value_'):
pseudo_member._value_ = value
if member_value:
pseudo_member._name_ = '|'.join([
m._name_ for m in cls._iter_member_(member_value)
])
if unknown:
if member_value or aliases:
members = []
combined_value = 0
for m in cls._iter_member_(member_value):
members.append(m)
combined_value |= m._value_
if aliases:
value = member_value | aliases
for n, pm in cls._member_map_.items():
if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
members.append(pm)
combined_value |= pm._value_
unknown = value ^ combined_value
pseudo_member._name_ = '|'.join([m._name_ for m in members])
if not combined_value:
pseudo_member._name_ = None
elif unknown and cls._boundary_ is STRICT:
raise ValueError('%r: no members with value %r' % (cls, unknown))
elif unknown:
pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
else:
pseudo_member._name_ = None
Expand Down Expand Up @@ -1671,6 +1680,7 @@ def convert_class(cls):
body['_boundary_'] = boundary or etype._boundary_
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_singles_mask_'] = None
body['_inverted_'] = None
body['__or__'] = Flag.__or__
body['__xor__'] = Flag.__xor__
Expand Down Expand Up @@ -1743,7 +1753,8 @@ def convert_class(cls):
else:
multi_bits |= value
gnv_last_values.append(value)
enum_class._flag_mask_ = single_bits
enum_class._flag_mask_ = single_bits | multi_bits
enum_class._singles_mask_ = single_bits
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
# set correct __iter__
member_list = [m._value_ for m in enum_class]
Expand Down
47 changes: 39 additions & 8 deletions Lib/test/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -2758,6 +2758,8 @@ def __new__(cls, c):
#
a = ord('a')
#
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)
#
Expand All @@ -2772,6 +2774,8 @@ def __new__(cls, c):
a = ord('a')
z = 1
#
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674)
self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672)
self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674)
#
Expand All @@ -2785,6 +2789,8 @@ def __new__(cls, c):
#
a = ord('a')
#
self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343)
self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a, 158456325028528675187087900672)
self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673)

Expand Down Expand Up @@ -2962,18 +2968,18 @@ def test_bool(self):
self.assertEqual(bool(f.value), bool(f))

def test_boundary(self):
self.assertIs(enum.Flag._boundary_, CONFORM)
class Iron(Flag, boundary=STRICT):
self.assertIs(enum.Flag._boundary_, STRICT)
class Iron(Flag, boundary=CONFORM):
ONE = 1
TWO = 2
EIGHT = 8
self.assertIs(Iron._boundary_, STRICT)
self.assertIs(Iron._boundary_, CONFORM)
#
class Water(Flag, boundary=CONFORM):
class Water(Flag, boundary=STRICT):
ONE = 1
TWO = 2
EIGHT = 8
self.assertIs(Water._boundary_, CONFORM)
self.assertIs(Water._boundary_, STRICT)
#
class Space(Flag, boundary=EJECT):
ONE = 1
Expand All @@ -2986,17 +2992,42 @@ class Bizarre(Flag, boundary=KEEP):
c = 4
d = 6
#
self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7)
self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7)
#
self.assertIs(Water(7), Water.ONE|Water.TWO)
self.assertIs(Water(~9), Water.TWO)
self.assertIs(Iron(7), Iron.ONE|Iron.TWO)
self.assertIs(Iron(~9), Iron.TWO)
#
self.assertEqual(Space(7), 7)
self.assertTrue(type(Space(7)) is int)
#
self.assertEqual(list(Bizarre), [Bizarre.c])
self.assertIs(Bizarre(3), Bizarre.b)
self.assertIs(Bizarre(6), Bizarre.d)
#
class SkipFlag(enum.Flag):
A = 1
B = 2
C = 4 | B
#
self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C))
self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42)
#
class SkipIntFlag(enum.IntFlag):
A = 1
B = 2
C = 4 | B
#
self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C))
self.assertEqual(SkipIntFlag(42).value, 42)
#
class MethodHint(Flag):
HiddenText = 0x10
DigitsOnly = 0x01
LettersOnly = 0x02
OnlyMask = 0x0f
#
self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask')


def test_iter(self):
Color = self.Color
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Set default Flag boundary to ``STRICT`` and fix bitwise operations.