diff options
-rw-r--r-- | Lib/functools.py | 92 | ||||
-rw-r--r-- | Lib/test/test_functools.py | 175 | ||||
-rw-r--r-- | Misc/NEWS | 5 | ||||
-rw-r--r-- | Modules/_functoolsmodule.c | 2 |
4 files changed, 182 insertions, 92 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 214523c..9845df2 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -21,6 +21,7 @@ from abc import get_cache_token from collections import namedtuple from types import MappingProxyType from weakref import WeakKeyDictionary +from reprlib import recursive_repr try: from _thread import RLock except ImportError: @@ -237,26 +238,83 @@ except ImportError: ################################################################################ # Purely functional, no descriptor behaviour -def partial(func, *args, **keywords): +class partial: """New function with partial application of the given arguments and keywords. """ - if hasattr(func, 'func'): - args = func.args + args - tmpkw = func.keywords.copy() - tmpkw.update(keywords) - keywords = tmpkw - del tmpkw - func = func.func - - def newfunc(*fargs, **fkeywords): - newkeywords = keywords.copy() - newkeywords.update(fkeywords) - return func(*(args + fargs), **newkeywords) - newfunc.func = func - newfunc.args = args - newfunc.keywords = keywords - return newfunc + + __slots__ = "func", "args", "keywords", "__dict__", "__weakref__" + + def __new__(*args, **keywords): + if not args: + raise TypeError("descriptor '__new__' of partial needs an argument") + if len(args) < 2: + raise TypeError("type 'partial' takes at least one argument") + cls, func, *args = args + if not callable(func): + raise TypeError("the first argument must be callable") + args = tuple(args) + + if hasattr(func, "func"): + args = func.args + args + tmpkw = func.keywords.copy() + tmpkw.update(keywords) + keywords = tmpkw + del tmpkw + func = func.func + + self = super(partial, cls).__new__(cls) + + self.func = func + self.args = args + self.keywords = keywords + return self + + def __call__(*args, **keywords): + if not args: + raise TypeError("descriptor '__call__' of partial needs an argument") + self, *args = args + newkeywords = self.keywords.copy() + newkeywords.update(keywords) + return self.func(*self.args, *args, **newkeywords) + + @recursive_repr() + def __repr__(self): + qualname = type(self).__qualname__ + args = [repr(self.func)] + args.extend(repr(x) for x in self.args) + args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items()) + if type(self).__module__ == "functools": + return f"functools.{qualname}({', '.join(args)})" + return f"{qualname}({', '.join(args)})" + + def __reduce__(self): + return type(self), (self.func,), (self.func, self.args, + self.keywords or None, self.__dict__ or None) + + def __setstate__(self, state): + if not isinstance(state, tuple): + raise TypeError("argument to __setstate__ must be a tuple") + if len(state) != 4: + raise TypeError(f"expected 4 items in state, got {len(state)}") + func, args, kwds, namespace = state + if (not callable(func) or not isinstance(args, tuple) or + (kwds is not None and not isinstance(kwds, dict)) or + (namespace is not None and not isinstance(namespace, dict))): + raise TypeError("invalid partial state") + + args = tuple(args) # just in case it's a subclass + if kwds is None: + kwds = {} + elif type(kwds) is not dict: # XXX does it need to be *exactly* dict? + kwds = dict(kwds) + if namespace is None: + namespace = {} + + self.__dict__ = namespace + self.func = func + self.args = args + self.keywords = kwds try: from _functools import partial diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 40f2234..fa66510 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -8,6 +8,7 @@ import sys from test import support import unittest from weakref import proxy +import contextlib try: import threading except ImportError: @@ -20,6 +21,14 @@ c_functools = support.import_fresh_module('functools', fresh=['_functools']) decimal = support.import_fresh_module('decimal', fresh=['_decimal']) +@contextlib.contextmanager +def replaced_module(name, replacement): + original_module = sys.modules[name] + sys.modules[name] = replacement + try: + yield + finally: + sys.modules[name] = original_module def capture(*args, **kw): """capture all positional and keyword arguments""" @@ -167,58 +176,35 @@ class TestPartial: p2.new_attr = 'spam' self.assertEqual(p2.new_attr, 'spam') - -@unittest.skipUnless(c_functools, 'requires the C _functools module') -class TestPartialC(TestPartial, unittest.TestCase): - if c_functools: - partial = c_functools.partial - - def test_attributes_unwritable(self): - # attributes should not be writable - p = self.partial(capture, 1, 2, a=10, b=20) - self.assertRaises(AttributeError, setattr, p, 'func', map) - self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) - self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) - - p = self.partial(hex) - try: - del p.__dict__ - except TypeError: - pass - else: - self.fail('partial object allowed __dict__ to be deleted') - def test_repr(self): args = (object(), object()) args_repr = ', '.join(repr(a) for a in args) kwargs = {'a': object(), 'b': object()} kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs), 'b={b!r}, a={a!r}'.format_map(kwargs)] - if self.partial is c_functools.partial: + if self.partial in (c_functools.partial, py_functools.partial): name = 'functools.partial' else: name = self.partial.__name__ f = self.partial(capture) - self.assertEqual('{}({!r})'.format(name, capture), - repr(f)) + self.assertEqual(f'{name}({capture!r})', repr(f)) f = self.partial(capture, *args) - self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr), - repr(f)) + self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f)) f = self.partial(capture, **kwargs) self.assertIn(repr(f), - ['{}({!r}, {})'.format(name, capture, kwargs_repr) + [f'{name}({capture!r}, {kwargs_repr})' for kwargs_repr in kwargs_reprs]) f = self.partial(capture, *args, **kwargs) self.assertIn(repr(f), - ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr) + [f'{name}({capture!r}, {args_repr}, {kwargs_repr})' for kwargs_repr in kwargs_reprs]) def test_recursive_repr(self): - if self.partial is c_functools.partial: + if self.partial in (c_functools.partial, py_functools.partial): name = 'functools.partial' else: name = self.partial.__name__ @@ -226,30 +212,31 @@ class TestPartialC(TestPartial, unittest.TestCase): f = self.partial(capture) f.__setstate__((f, (), {}, {})) try: - self.assertEqual(repr(f), '%s(%s(...))' % (name, name)) + self.assertEqual(repr(f), '%s(...)' % (name,)) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (f,), {}, {})) try: - self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name)) + self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,)) finally: f.__setstate__((capture, (), {}, {})) f = self.partial(capture) f.__setstate__((capture, (), {'a': f}, {})) try: - self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name)) + self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,)) finally: f.__setstate__((capture, (), {}, {})) def test_pickle(self): - f = self.partial(signature, ['asdf'], bar=[True]) - f.attr = [] - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - f_copy = pickle.loads(pickle.dumps(f, proto)) - self.assertEqual(signature(f_copy), signature(f)) + with self.AllowPickle(): + f = self.partial(signature, ['asdf'], bar=[True]) + f.attr = [] + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f_copy = pickle.loads(pickle.dumps(f, proto)) + self.assertEqual(signature(f_copy), signature(f)) def test_copy(self): f = self.partial(signature, ['asdf'], bar=[True]) @@ -274,11 +261,13 @@ class TestPartialC(TestPartial, unittest.TestCase): def test_setstate(self): f = self.partial(signature) f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) + self.assertEqual(signature(f), (capture, (1,), dict(a=10), dict(attr=[]))) self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) f.__setstate__((capture, (1,), dict(a=10), None)) + self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) @@ -325,38 +314,39 @@ class TestPartialC(TestPartial, unittest.TestCase): self.assertIs(type(r[0]), tuple) def test_recursive_pickle(self): - f = self.partial(capture) - f.__setstate__((f, (), {}, {})) - try: - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - with self.assertRaises(RecursionError): - pickle.dumps(f, proto) - finally: - f.__setstate__((capture, (), {}, {})) - - f = self.partial(capture) - f.__setstate__((capture, (f,), {}, {})) - try: - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - f_copy = pickle.loads(pickle.dumps(f, proto)) - try: - self.assertIs(f_copy.args[0], f_copy) - finally: - f_copy.__setstate__((capture, (), {}, {})) - finally: - f.__setstate__((capture, (), {}, {})) - - f = self.partial(capture) - f.__setstate__((capture, (), {'a': f}, {})) - try: - for proto in range(pickle.HIGHEST_PROTOCOL + 1): - f_copy = pickle.loads(pickle.dumps(f, proto)) - try: - self.assertIs(f_copy.keywords['a'], f_copy) - finally: - f_copy.__setstate__((capture, (), {}, {})) - finally: - f.__setstate__((capture, (), {}, {})) + with self.AllowPickle(): + f = self.partial(capture) + f.__setstate__((f, (), {}, {})) + try: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.assertRaises(RecursionError): + pickle.dumps(f, proto) + finally: + f.__setstate__((capture, (), {}, {})) + + f = self.partial(capture) + f.__setstate__((capture, (f,), {}, {})) + try: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f_copy = pickle.loads(pickle.dumps(f, proto)) + try: + self.assertIs(f_copy.args[0], f_copy) + finally: + f_copy.__setstate__((capture, (), {}, {})) + finally: + f.__setstate__((capture, (), {}, {})) + + f = self.partial(capture) + f.__setstate__((capture, (), {'a': f}, {})) + try: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + f_copy = pickle.loads(pickle.dumps(f, proto)) + try: + self.assertIs(f_copy.keywords['a'], f_copy) + finally: + f_copy.__setstate__((capture, (), {}, {})) + finally: + f.__setstate__((capture, (), {}, {})) # Issue 6083: Reference counting bug def test_setstate_refcount(self): @@ -375,24 +365,60 @@ class TestPartialC(TestPartial, unittest.TestCase): f = self.partial(object) self.assertRaises(TypeError, f.__setstate__, BadSequence()) +@unittest.skipUnless(c_functools, 'requires the C _functools module') +class TestPartialC(TestPartial, unittest.TestCase): + if c_functools: + partial = c_functools.partial + + class AllowPickle: + def __enter__(self): + return self + def __exit__(self, type, value, tb): + return False + + def test_attributes_unwritable(self): + # attributes should not be writable + p = self.partial(capture, 1, 2, a=10, b=20) + self.assertRaises(AttributeError, setattr, p, 'func', map) + self.assertRaises(AttributeError, setattr, p, 'args', (1, 2)) + self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2)) + + p = self.partial(hex) + try: + del p.__dict__ + except TypeError: + pass + else: + self.fail('partial object allowed __dict__ to be deleted') class TestPartialPy(TestPartial, unittest.TestCase): - partial = staticmethod(py_functools.partial) + partial = py_functools.partial + class AllowPickle: + def __init__(self): + self._cm = replaced_module("functools", py_functools) + def __enter__(self): + return self._cm.__enter__() + def __exit__(self, type, value, tb): + return self._cm.__exit__(type, value, tb) if c_functools: - class PartialSubclass(c_functools.partial): + class CPartialSubclass(c_functools.partial): pass +class PyPartialSubclass(py_functools.partial): + pass @unittest.skipUnless(c_functools, 'requires the C _functools module') class TestPartialCSubclass(TestPartialC): if c_functools: - partial = PartialSubclass + partial = CPartialSubclass # partial subclasses are not optimized for nested calls test_nested_optimization = None +class TestPartialPySubclass(TestPartialPy): + partial = PyPartialSubclass class TestPartialMethod(unittest.TestCase): @@ -683,9 +709,10 @@ class TestWraps(TestUpdateWrapper): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) - +@unittest.skipUnless(c_functools, 'requires the C _functools module') class TestReduce(unittest.TestCase): - func = functools.reduce + if c_functools: + func = c_functools.reduce def test_reduce(self): class Squares: @@ -135,6 +135,11 @@ Core and Builtins Library ------- +- Issue #27137: the pure Python fallback implementation of ``functools.partial`` + now matches the behaviour of its accelerated C counterpart for subclassing, + pickling and text representation purposes. Patch by Emanuel Barry and + Serhiy Storchaka. + - Issue #28019: itertools.count() no longer rounds non-integer step in range between 1.0 and 2.0 to 1. diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c index 848a03c..fa5fad3 100644 --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -229,7 +229,7 @@ partial_repr(partialobject *pto) if (status != 0) { if (status < 0) return NULL; - return PyUnicode_FromFormat("%s(...)", Py_TYPE(pto)->tp_name); + return PyUnicode_FromString("..."); } arglist = PyUnicode_FromString(""); |