diff options
Diffstat (limited to 'Lib/functools.py')
-rw-r--r-- | Lib/functools.py | 35 |
1 files changed, 28 insertions, 7 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 77ec852..ccac6f8 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -837,6 +837,14 @@ def singledispatch(func): dispatch_cache[cls] = impl return impl + def _is_union_type(cls): + from typing import get_origin, Union + return get_origin(cls) in {Union, types.UnionType} + + def _is_valid_union_type(cls): + from typing import get_args + return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls)) + def register(cls, func=None): """generic_func.register(cls, func) -> func @@ -845,7 +853,7 @@ def singledispatch(func): """ nonlocal cache_token if func is None: - if isinstance(cls, type): + if isinstance(cls, type) or _is_valid_union_type(cls): return lambda f: register(cls, f) ann = getattr(cls, '__annotations__', {}) if not ann: @@ -859,12 +867,25 @@ def singledispatch(func): # only import typing if annotation parsing is necessary from typing import get_type_hints argname, cls = next(iter(get_type_hints(func).items())) - if not isinstance(cls, type): - raise TypeError( - f"Invalid annotation for {argname!r}. " - f"{cls!r} is not a class." - ) - registry[cls] = func + if not isinstance(cls, type) and not _is_valid_union_type(cls): + if _is_union_type(cls): + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} not all arguments are classes." + ) + else: + raise TypeError( + f"Invalid annotation for {argname!r}. " + f"{cls!r} is not a class." + ) + + if _is_union_type(cls): + from typing import get_args + + for arg in get_args(cls): + registry[arg] = func + else: + registry[cls] = func if cache_token is None and hasattr(cls, '__abstractmethods__'): cache_token = get_cache_token() dispatch_cache.clear() |