From 2194071540313e2bbdc7214d77453b9ce3034a5c Mon Sep 17 00:00:00 2001 From: Ethan Furman Date: Thu, 13 Apr 2023 08:24:33 -0700 Subject: [PATCH] gh-103365: [Enum] STRICT boundary corrections (GH-103494) STRICT boundary: - fix bitwise operations - make default for Flag --- Doc/library/enum.rst | 5 +- Lib/enum.py | 67 +++++++++++-------- Lib/test/test_enum.py | 47 ++++++++++--- ...-04-12-17-59-55.gh-issue-103365.UBEE0U.rst | 1 + 4 files changed, 82 insertions(+), 38 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst diff --git a/Doc/library/enum.rst b/Doc/library/enum.rst index c690c837309ea5..07acf9da33e275 100644 --- a/Doc/library/enum.rst +++ b/Doc/library/enum.rst @@ -696,7 +696,8 @@ Data Types .. attribute:: STRICT - Out-of-range values cause a :exc:`ValueError` to be raised:: + Out-of-range values cause a :exc:`ValueError` to be raised. This is the + default for :class:`Flag`:: >>> from enum import Flag, STRICT, auto >>> class StrictFlag(Flag, boundary=STRICT): @@ -714,7 +715,7 @@ Data Types .. attribute:: CONFORM Out-of-range values have invalid values removed, leaving a valid *Flag* - value. This is the default for :class:`Flag`:: + value:: >>> from enum import Flag, CONFORM, auto >>> class ConformFlag(Flag, boundary=CONFORM): diff --git a/Lib/enum.py b/Lib/enum.py index 10902c4b202a2d..432d7456b4b9f1 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -275,6 +275,13 @@ def __set_name__(self, enum_class, member_name): enum_member.__objclass__ = enum_class enum_member.__init__(*args) enum_member._sort_order_ = len(enum_class._member_names_) + + if Flag is not None and issubclass(enum_class, Flag): + enum_class._flag_mask_ |= value + if _is_single_bit(value): + enum_class._singles_mask_ |= value + enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1 + # If another member with the same value was already defined, the # new member becomes an alias to the existing one. try: @@ -532,12 +539,8 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k classdict['_use_args_'] = use_args # # convert future enum members into temporary _proto_members - # and record integer values in case this will be a Flag - flag_mask = 0 for name in member_names: value = classdict[name] - if isinstance(value, int): - flag_mask |= value classdict[name] = _proto_member(value) # # house-keeping structures @@ -554,8 +557,9 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k boundary or getattr(first_enum, '_boundary_', None) ) - classdict['_flag_mask_'] = flag_mask - classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1 + classdict['_flag_mask_'] = 0 + classdict['_singles_mask_'] = 0 + classdict['_all_bits_'] = 0 classdict['_inverted_'] = None try: exc = None @@ -644,21 +648,10 @@ def __new__(metacls, cls, bases, classdict, *, boundary=None, _simple=False, **k ): delattr(enum_class, '_boundary_') delattr(enum_class, '_flag_mask_') + delattr(enum_class, '_singles_mask_') delattr(enum_class, '_all_bits_') delattr(enum_class, '_inverted_') elif Flag is not None and issubclass(enum_class, Flag): - # ensure _all_bits_ is correct and there are no missing flags - single_bit_total = 0 - multi_bit_total = 0 - for flag in enum_class._member_map_.values(): - flag_value = flag._value_ - if _is_single_bit(flag_value): - single_bit_total |= flag_value - else: - # multi-bit flags are considered aliases - multi_bit_total |= flag_value - enum_class._flag_mask_ = single_bit_total - # # set correct __iter__ member_list = [m._value_ for m in enum_class] if member_list != sorted(member_list): @@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto): class FlagBoundary(StrEnum): """ control how out of range values are handled - "strict" -> error is raised - "conform" -> extra bits are discarded [default for Flag] + "strict" -> error is raised [default for Flag] + "conform" -> extra bits are discarded "eject" -> lose flag status "keep" -> keep flag status and all bits [default for IntFlag] """ @@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum): STRICT, CONFORM, EJECT, KEEP = FlagBoundary -class Flag(Enum, boundary=CONFORM): +class Flag(Enum, boundary=STRICT): """ Support for flags """ @@ -1394,6 +1387,7 @@ def _missing_(cls, value): # - value must not include any skipped flags (e.g. if bit 2 is not # defined, then 0d10 is invalid) flag_mask = cls._flag_mask_ + singles_mask = cls._singles_mask_ all_bits = cls._all_bits_ neg_value = None if ( @@ -1425,7 +1419,8 @@ def _missing_(cls, value): value = all_bits + 1 + value # get members and unknown unknown = value & ~flag_mask - member_value = value & flag_mask + aliases = value & ~singles_mask + member_value = value & singles_mask if unknown and cls._boundary_ is not KEEP: raise ValueError( '%s(%r) --> unknown values %r [%s]' @@ -1439,11 +1434,25 @@ def _missing_(cls, value): pseudo_member = cls._member_type_.__new__(cls, value) if not hasattr(pseudo_member, '_value_'): pseudo_member._value_ = value - if member_value: - pseudo_member._name_ = '|'.join([ - m._name_ for m in cls._iter_member_(member_value) - ]) - if unknown: + if member_value or aliases: + members = [] + combined_value = 0 + for m in cls._iter_member_(member_value): + members.append(m) + combined_value |= m._value_ + if aliases: + value = member_value | aliases + for n, pm in cls._member_map_.items(): + if pm not in members and pm._value_ and pm._value_ & value == pm._value_: + members.append(pm) + combined_value |= pm._value_ + unknown = value ^ combined_value + pseudo_member._name_ = '|'.join([m._name_ for m in members]) + if not combined_value: + pseudo_member._name_ = None + elif unknown and cls._boundary_ is STRICT: + raise ValueError('%r: no members with value %r' % (cls, unknown)) + elif unknown: pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown) else: pseudo_member._name_ = None @@ -1675,6 +1684,7 @@ def convert_class(cls): body['_boundary_'] = boundary or etype._boundary_ body['_flag_mask_'] = None body['_all_bits_'] = None + body['_singles_mask_'] = None body['_inverted_'] = None body['__or__'] = Flag.__or__ body['__xor__'] = Flag.__xor__ @@ -1750,7 +1760,8 @@ def convert_class(cls): else: multi_bits |= value gnv_last_values.append(value) - enum_class._flag_mask_ = single_bits + enum_class._flag_mask_ = single_bits | multi_bits + enum_class._singles_mask_ = single_bits enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1 # set correct __iter__ member_list = [m._value_ for m in enum_class] diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index e4151bf9e6d9b3..89294e95df2a83 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -2873,6 +2873,8 @@ def __new__(cls, c): # a = ord('a') # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) # @@ -2887,6 +2889,8 @@ def __new__(cls, c): a = ord('a') z = 1 # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900674) self.assertEqual(FlagFromChar.a.value, 158456325028528675187087900672) self.assertEqual((FlagFromChar.a|FlagFromChar.z).value, 158456325028528675187087900674) # @@ -2900,6 +2904,8 @@ def __new__(cls, c): # a = ord('a') # + self.assertEqual(FlagFromChar._all_bits_, 316912650057057350374175801343) + self.assertEqual(FlagFromChar._flag_mask_, 158456325028528675187087900672) self.assertEqual(FlagFromChar.a, 158456325028528675187087900672) self.assertEqual(FlagFromChar.a|1, 158456325028528675187087900673) @@ -3077,18 +3083,18 @@ def test_bool(self): self.assertEqual(bool(f.value), bool(f)) def test_boundary(self): - self.assertIs(enum.Flag._boundary_, CONFORM) - class Iron(Flag, boundary=STRICT): + self.assertIs(enum.Flag._boundary_, STRICT) + class Iron(Flag, boundary=CONFORM): ONE = 1 TWO = 2 EIGHT = 8 - self.assertIs(Iron._boundary_, STRICT) + self.assertIs(Iron._boundary_, CONFORM) # - class Water(Flag, boundary=CONFORM): + class Water(Flag, boundary=STRICT): ONE = 1 TWO = 2 EIGHT = 8 - self.assertIs(Water._boundary_, CONFORM) + self.assertIs(Water._boundary_, STRICT) # class Space(Flag, boundary=EJECT): ONE = 1 @@ -3101,10 +3107,10 @@ class Bizarre(Flag, boundary=KEEP): c = 4 d = 6 # - self.assertRaisesRegex(ValueError, 'invalid value 7', Iron, 7) + self.assertRaisesRegex(ValueError, 'invalid value 7', Water, 7) # - self.assertIs(Water(7), Water.ONE|Water.TWO) - self.assertIs(Water(~9), Water.TWO) + self.assertIs(Iron(7), Iron.ONE|Iron.TWO) + self.assertIs(Iron(~9), Iron.TWO) # self.assertEqual(Space(7), 7) self.assertTrue(type(Space(7)) is int) @@ -3112,6 +3118,31 @@ class Bizarre(Flag, boundary=KEEP): self.assertEqual(list(Bizarre), [Bizarre.c]) self.assertIs(Bizarre(3), Bizarre.b) self.assertIs(Bizarre(6), Bizarre.d) + # + class SkipFlag(enum.Flag): + A = 1 + B = 2 + C = 4 | B + # + self.assertTrue(SkipFlag.C in (SkipFlag.A|SkipFlag.C)) + self.assertRaisesRegex(ValueError, 'SkipFlag.. invalid value 42', SkipFlag, 42) + # + class SkipIntFlag(enum.IntFlag): + A = 1 + B = 2 + C = 4 | B + # + self.assertTrue(SkipIntFlag.C in (SkipIntFlag.A|SkipIntFlag.C)) + self.assertEqual(SkipIntFlag(42).value, 42) + # + class MethodHint(Flag): + HiddenText = 0x10 + DigitsOnly = 0x01 + LettersOnly = 0x02 + OnlyMask = 0x0f + # + self.assertEqual(str(MethodHint.HiddenText|MethodHint.OnlyMask), 'MethodHint.HiddenText|DigitsOnly|LettersOnly|OnlyMask') + def test_iter(self): Color = self.Color diff --git a/Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst b/Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst new file mode 100644 index 00000000000000..4d69f6f6fff713 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-04-12-17-59-55.gh-issue-103365.UBEE0U.rst @@ -0,0 +1 @@ +Set default Flag boundary to ``STRICT`` and fix bitwise operations.