summaryrefslogtreecommitdiffstats
path: root/Lib/functools.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/functools.py')
-rw-r--r--Lib/functools.py35
1 files changed, 28 insertions, 7 deletions
diff --git a/Lib/functools.py b/Lib/functools.py
index 77ec852..ccac6f8 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -837,6 +837,14 @@ def singledispatch(func):
dispatch_cache[cls] = impl
return impl
+ def _is_union_type(cls):
+ from typing import get_origin, Union
+ return get_origin(cls) in {Union, types.UnionType}
+
+ def _is_valid_union_type(cls):
+ from typing import get_args
+ return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
+
def register(cls, func=None):
"""generic_func.register(cls, func) -> func
@@ -845,7 +853,7 @@ def singledispatch(func):
"""
nonlocal cache_token
if func is None:
- if isinstance(cls, type):
+ if isinstance(cls, type) or _is_valid_union_type(cls):
return lambda f: register(cls, f)
ann = getattr(cls, '__annotations__', {})
if not ann:
@@ -859,12 +867,25 @@ def singledispatch(func):
# only import typing if annotation parsing is necessary
from typing import get_type_hints
argname, cls = next(iter(get_type_hints(func).items()))
- if not isinstance(cls, type):
- raise TypeError(
- f"Invalid annotation for {argname!r}. "
- f"{cls!r} is not a class."
- )
- registry[cls] = func
+ if not isinstance(cls, type) and not _is_valid_union_type(cls):
+ if _is_union_type(cls):
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} not all arguments are classes."
+ )
+ else:
+ raise TypeError(
+ f"Invalid annotation for {argname!r}. "
+ f"{cls!r} is not a class."
+ )
+
+ if _is_union_type(cls):
+ from typing import get_args
+
+ for arg in get_args(cls):
+ registry[arg] = func
+ else:
+ registry[cls] = func
if cache_token is None and hasattr(cls, '__abstractmethods__'):
cache_token = get_cache_token()
dispatch_cache.clear()