From af2afd7c224389856e118d3f544c9621a016599f Mon Sep 17 00:00:00 2001 From: Victorien <65306057+Viicos@users.noreply.github.com> Date: Fri, 28 Mar 2025 05:56:09 +0100 Subject: gh-119180: Use equality when comparing against `annotationlib.Format` (#131755) --- Lib/test/test_annotationlib.py | 4 ++-- Lib/test/test_typing.py | 2 ++ Lib/typing.py | 6 +++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 20f74b4..495606b 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -517,7 +517,7 @@ class TestGetAnnotations(unittest.TestCase): foo.__annotations__ = {"a": "foo", "b": "str"} for format in Format: - if format is Format.VALUE_WITH_FAKE_GLOBALS: + if format == Format.VALUE_WITH_FAKE_GLOBALS: continue with self.subTest(format=format): self.assertEqual( @@ -816,7 +816,7 @@ class TestGetAnnotations(unittest.TestCase): wa = WeirdAnnotations() for format in Format: - if format is Format.VALUE_WITH_FAKE_GLOBALS: + if format == Format.VALUE_WITH_FAKE_GLOBALS: continue with ( self.subTest(format=format), diff --git a/Lib/test/test_typing.py b/Lib/test/test_typing.py index 4023534..2c02973 100644 --- a/Lib/test/test_typing.py +++ b/Lib/test/test_typing.py @@ -7158,6 +7158,8 @@ class GetTypeHintTests(BaseTestCase): self.assertEqual(get_type_hints(C, format=annotationlib.Format.STRING), {'x': 'undefined'}) + # Make sure using an int as format also works: + self.assertEqual(get_type_hints(C, format=4), {'x': 'undefined'}) def test_get_type_hints_format_function(self): def func(x: undefined) -> undefined: ... diff --git a/Lib/typing.py b/Lib/typing.py index 9621155..e36da7e 100644 --- a/Lib/typing.py +++ b/Lib/typing.py @@ -2315,7 +2315,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False, hints = {} for base in reversed(obj.__mro__): ann = annotationlib.get_annotations(base, format=format) - if format is annotationlib.Format.STRING: + if format == annotationlib.Format.STRING: hints.update(ann) continue if globalns is None: @@ -2339,7 +2339,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False, value = _eval_type(value, base_globals, base_locals, base.__type_params__, format=format, owner=obj) hints[name] = value - if include_extras or format is annotationlib.Format.STRING: + if include_extras or format == annotationlib.Format.STRING: return hints else: return {k: _strip_annotations(t) for k, t in hints.items()} @@ -2353,7 +2353,7 @@ def get_type_hints(obj, globalns=None, localns=None, include_extras=False, and not hasattr(obj, '__annotate__') ): raise TypeError(f"{obj!r} is not a module, class, or callable.") - if format is annotationlib.Format.STRING: + if format == annotationlib.Format.STRING: return hints if globalns is None: -- cgit v0.12