summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/annotationlib.py81
-rw-r--r--Lib/test/test_annotationlib.py88
2 files changed, 143 insertions, 26 deletions
diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py
index be3bc27..20c9542 100644
--- a/Lib/annotationlib.py
+++ b/Lib/annotationlib.py
@@ -664,28 +664,38 @@ def get_annotations(
if eval_str and format != Format.VALUE:
raise ValueError("eval_str=True is only supported with format=Format.VALUE")
- # For VALUE format, we look at __annotations__ directly.
- if format != Format.VALUE:
- annotate = get_annotate_function(obj)
- if annotate is not None:
- ann = call_annotate_function(annotate, format, owner=obj)
- if not isinstance(ann, dict):
- raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
- return dict(ann)
-
- if isinstance(obj, type):
- try:
- ann = _BASE_GET_ANNOTATIONS(obj)
- except AttributeError:
- # For static types, the descriptor raises AttributeError.
- return {}
- else:
- ann = getattr(obj, "__annotations__", None)
- if ann is None:
- return {}
-
- if not isinstance(ann, dict):
- raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
+ match format:
+ case Format.VALUE:
+ # For VALUE, we only look at __annotations__
+ ann = _get_dunder_annotations(obj)
+ case Format.FORWARDREF:
+ # For FORWARDREF, we use __annotations__ if it exists
+ try:
+ ann = _get_dunder_annotations(obj)
+ except NameError:
+ pass
+ else:
+ return dict(ann)
+
+ # But if __annotations__ threw a NameError, we try calling __annotate__
+ ann = _get_and_call_annotate(obj, format)
+ if ann is not None:
+ return ann
+
+ # If that didn't work either, we have a very weird object: evaluating
+ # __annotations__ threw NameError and there is no __annotate__. In that case,
+ # we fall back to trying __annotations__ again.
+ return dict(_get_dunder_annotations(obj))
+ case Format.SOURCE:
+ # For SOURCE, we try to call __annotate__
+ ann = _get_and_call_annotate(obj, format)
+ if ann is not None:
+ return ann
+ # But if we didn't get it, we use __annotations__ instead.
+ ann = _get_dunder_annotations(obj)
+ return ann
+ case _:
+ raise ValueError(f"Unsupported format {format!r}")
if not ann:
return {}
@@ -750,3 +760,30 @@ def get_annotations(
for key, value in ann.items()
}
return return_value
+
+
+def _get_and_call_annotate(obj, format):
+ annotate = get_annotate_function(obj)
+ if annotate is not None:
+ ann = call_annotate_function(annotate, format, owner=obj)
+ if not isinstance(ann, dict):
+ raise ValueError(f"{obj!r}.__annotate__ returned a non-dict")
+ return dict(ann)
+ return None
+
+
+def _get_dunder_annotations(obj):
+ if isinstance(obj, type):
+ try:
+ ann = _BASE_GET_ANNOTATIONS(obj)
+ except AttributeError:
+ # For static types, the descriptor raises AttributeError.
+ return {}
+ else:
+ ann = getattr(obj, "__annotations__", None)
+ if ann is None:
+ return {}
+
+ if not isinstance(ann, dict):
+ raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
+ return dict(ann)
diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py
index cc051ef..5b052da 100644
--- a/Lib/test/test_annotationlib.py
+++ b/Lib/test/test_annotationlib.py
@@ -740,17 +740,97 @@ class TestGetAnnotations(unittest.TestCase):
self.assertEqual(annotationlib.get_annotations(f), {"x": int})
self.assertEqual(
- annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF),
+ annotationlib.get_annotations(f, format=Format.FORWARDREF),
{"x": int},
)
f.__annotations__["x"] = str
# The modification is reflected in VALUE (the default)
self.assertEqual(annotationlib.get_annotations(f), {"x": str})
- # ... but not in FORWARDREF, which uses __annotate__
+ # ... and also in FORWARDREF, which tries __annotations__ if available
self.assertEqual(
- annotationlib.get_annotations(f, format=annotationlib.Format.FORWARDREF),
- {"x": int},
+ annotationlib.get_annotations(f, format=Format.FORWARDREF),
+ {"x": str},
+ )
+ # ... but not in SOURCE which always uses __annotate__
+ self.assertEqual(
+ annotationlib.get_annotations(f, format=Format.SOURCE),
+ {"x": "int"},
+ )
+
+ def test_non_dict_annotations(self):
+ class WeirdAnnotations:
+ @property
+ def __annotations__(self):
+ return "not a dict"
+
+ wa = WeirdAnnotations()
+ for format in Format:
+ with (
+ self.subTest(format=format),
+ self.assertRaisesRegex(
+ ValueError, r".*__annotations__ is neither a dict nor None"
+ ),
+ ):
+ annotationlib.get_annotations(wa, format=format)
+
+ def test_annotations_on_custom_object(self):
+ class HasAnnotations:
+ @property
+ def __annotations__(self):
+ return {"x": int}
+
+ ha = HasAnnotations()
+ self.assertEqual(
+ annotationlib.get_annotations(ha, format=Format.VALUE), {"x": int}
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(ha, format=Format.FORWARDREF), {"x": int}
+ )
+
+ # TODO(gh-124412): This should return {'x': 'int'} instead.
+ self.assertEqual(
+ annotationlib.get_annotations(ha, format=Format.SOURCE), {"x": int}
+ )
+
+ def test_raising_annotations_on_custom_object(self):
+ class HasRaisingAnnotations:
+ @property
+ def __annotations__(self):
+ return {"x": undefined}
+
+ hra = HasRaisingAnnotations()
+
+ with self.assertRaises(NameError):
+ annotationlib.get_annotations(hra, format=Format.VALUE)
+
+ with self.assertRaises(NameError):
+ annotationlib.get_annotations(hra, format=Format.FORWARDREF)
+
+ undefined = float
+ self.assertEqual(
+ annotationlib.get_annotations(hra, format=Format.VALUE), {"x": float}
+ )
+
+ def test_forwardref_prefers_annotations(self):
+ class HasBoth:
+ @property
+ def __annotations__(self):
+ return {"x": int}
+
+ @property
+ def __annotate__(self):
+ return lambda format: {"x": str}
+
+ hb = HasBoth()
+ self.assertEqual(
+ annotationlib.get_annotations(hb, format=Format.VALUE), {"x": int}
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(hb, format=Format.FORWARDREF), {"x": int}
+ )
+ self.assertEqual(
+ annotationlib.get_annotations(hb, format=Format.SOURCE), {"x": str}
)
def test_pep695_generic_class_with_future_annotations(self):