summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/functools.py92
-rw-r--r--Lib/test/test_functools.py175
-rw-r--r--Misc/NEWS5
-rw-r--r--Modules/_functoolsmodule.c2
4 files changed, 182 insertions, 92 deletions
diff --git a/Lib/functools.py b/Lib/functools.py
index 214523c..9845df2 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -21,6 +21,7 @@ from abc import get_cache_token
from collections import namedtuple
from types import MappingProxyType
from weakref import WeakKeyDictionary
+from reprlib import recursive_repr
try:
from _thread import RLock
except ImportError:
@@ -237,26 +238,83 @@ except ImportError:
################################################################################
# Purely functional, no descriptor behaviour
-def partial(func, *args, **keywords):
+class partial:
"""New function with partial application of the given arguments
and keywords.
"""
- if hasattr(func, 'func'):
- args = func.args + args
- tmpkw = func.keywords.copy()
- tmpkw.update(keywords)
- keywords = tmpkw
- del tmpkw
- func = func.func
-
- def newfunc(*fargs, **fkeywords):
- newkeywords = keywords.copy()
- newkeywords.update(fkeywords)
- return func(*(args + fargs), **newkeywords)
- newfunc.func = func
- newfunc.args = args
- newfunc.keywords = keywords
- return newfunc
+
+ __slots__ = "func", "args", "keywords", "__dict__", "__weakref__"
+
+ def __new__(*args, **keywords):
+ if not args:
+ raise TypeError("descriptor '__new__' of partial needs an argument")
+ if len(args) < 2:
+ raise TypeError("type 'partial' takes at least one argument")
+ cls, func, *args = args
+ if not callable(func):
+ raise TypeError("the first argument must be callable")
+ args = tuple(args)
+
+ if hasattr(func, "func"):
+ args = func.args + args
+ tmpkw = func.keywords.copy()
+ tmpkw.update(keywords)
+ keywords = tmpkw
+ del tmpkw
+ func = func.func
+
+ self = super(partial, cls).__new__(cls)
+
+ self.func = func
+ self.args = args
+ self.keywords = keywords
+ return self
+
+ def __call__(*args, **keywords):
+ if not args:
+ raise TypeError("descriptor '__call__' of partial needs an argument")
+ self, *args = args
+ newkeywords = self.keywords.copy()
+ newkeywords.update(keywords)
+ return self.func(*self.args, *args, **newkeywords)
+
+ @recursive_repr()
+ def __repr__(self):
+ qualname = type(self).__qualname__
+ args = [repr(self.func)]
+ args.extend(repr(x) for x in self.args)
+ args.extend(f"{k}={v!r}" for (k, v) in self.keywords.items())
+ if type(self).__module__ == "functools":
+ return f"functools.{qualname}({', '.join(args)})"
+ return f"{qualname}({', '.join(args)})"
+
+ def __reduce__(self):
+ return type(self), (self.func,), (self.func, self.args,
+ self.keywords or None, self.__dict__ or None)
+
+ def __setstate__(self, state):
+ if not isinstance(state, tuple):
+ raise TypeError("argument to __setstate__ must be a tuple")
+ if len(state) != 4:
+ raise TypeError(f"expected 4 items in state, got {len(state)}")
+ func, args, kwds, namespace = state
+ if (not callable(func) or not isinstance(args, tuple) or
+ (kwds is not None and not isinstance(kwds, dict)) or
+ (namespace is not None and not isinstance(namespace, dict))):
+ raise TypeError("invalid partial state")
+
+ args = tuple(args) # just in case it's a subclass
+ if kwds is None:
+ kwds = {}
+ elif type(kwds) is not dict: # XXX does it need to be *exactly* dict?
+ kwds = dict(kwds)
+ if namespace is None:
+ namespace = {}
+
+ self.__dict__ = namespace
+ self.func = func
+ self.args = args
+ self.keywords = kwds
try:
from _functools import partial
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 40f2234..fa66510 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -8,6 +8,7 @@ import sys
from test import support
import unittest
from weakref import proxy
+import contextlib
try:
import threading
except ImportError:
@@ -20,6 +21,14 @@ c_functools = support.import_fresh_module('functools', fresh=['_functools'])
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
+@contextlib.contextmanager
+def replaced_module(name, replacement):
+ original_module = sys.modules[name]
+ sys.modules[name] = replacement
+ try:
+ yield
+ finally:
+ sys.modules[name] = original_module
def capture(*args, **kw):
"""capture all positional and keyword arguments"""
@@ -167,58 +176,35 @@ class TestPartial:
p2.new_attr = 'spam'
self.assertEqual(p2.new_attr, 'spam')
-
-@unittest.skipUnless(c_functools, 'requires the C _functools module')
-class TestPartialC(TestPartial, unittest.TestCase):
- if c_functools:
- partial = c_functools.partial
-
- def test_attributes_unwritable(self):
- # attributes should not be writable
- p = self.partial(capture, 1, 2, a=10, b=20)
- self.assertRaises(AttributeError, setattr, p, 'func', map)
- self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
- self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
-
- p = self.partial(hex)
- try:
- del p.__dict__
- except TypeError:
- pass
- else:
- self.fail('partial object allowed __dict__ to be deleted')
-
def test_repr(self):
args = (object(), object())
args_repr = ', '.join(repr(a) for a in args)
kwargs = {'a': object(), 'b': object()}
kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
'b={b!r}, a={a!r}'.format_map(kwargs)]
- if self.partial is c_functools.partial:
+ if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial'
else:
name = self.partial.__name__
f = self.partial(capture)
- self.assertEqual('{}({!r})'.format(name, capture),
- repr(f))
+ self.assertEqual(f'{name}({capture!r})', repr(f))
f = self.partial(capture, *args)
- self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
- repr(f))
+ self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
f = self.partial(capture, **kwargs)
self.assertIn(repr(f),
- ['{}({!r}, {})'.format(name, capture, kwargs_repr)
+ [f'{name}({capture!r}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs])
f = self.partial(capture, *args, **kwargs)
self.assertIn(repr(f),
- ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
+ [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
for kwargs_repr in kwargs_reprs])
def test_recursive_repr(self):
- if self.partial is c_functools.partial:
+ if self.partial in (c_functools.partial, py_functools.partial):
name = 'functools.partial'
else:
name = self.partial.__name__
@@ -226,30 +212,31 @@ class TestPartialC(TestPartial, unittest.TestCase):
f = self.partial(capture)
f.__setstate__((f, (), {}, {}))
try:
- self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
+ self.assertEqual(repr(f), '%s(...)' % (name,))
finally:
f.__setstate__((capture, (), {}, {}))
f = self.partial(capture)
f.__setstate__((capture, (f,), {}, {}))
try:
- self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
+ self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
finally:
f.__setstate__((capture, (), {}, {}))
f = self.partial(capture)
f.__setstate__((capture, (), {'a': f}, {}))
try:
- self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
+ self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
finally:
f.__setstate__((capture, (), {}, {}))
def test_pickle(self):
- f = self.partial(signature, ['asdf'], bar=[True])
- f.attr = []
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- f_copy = pickle.loads(pickle.dumps(f, proto))
- self.assertEqual(signature(f_copy), signature(f))
+ with self.AllowPickle():
+ f = self.partial(signature, ['asdf'], bar=[True])
+ f.attr = []
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ self.assertEqual(signature(f_copy), signature(f))
def test_copy(self):
f = self.partial(signature, ['asdf'], bar=[True])
@@ -274,11 +261,13 @@ class TestPartialC(TestPartial, unittest.TestCase):
def test_setstate(self):
f = self.partial(signature)
f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
+
self.assertEqual(signature(f),
(capture, (1,), dict(a=10), dict(attr=[])))
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
f.__setstate__((capture, (1,), dict(a=10), None))
+
self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
@@ -325,38 +314,39 @@ class TestPartialC(TestPartial, unittest.TestCase):
self.assertIs(type(r[0]), tuple)
def test_recursive_pickle(self):
- f = self.partial(capture)
- f.__setstate__((f, (), {}, {}))
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- with self.assertRaises(RecursionError):
- pickle.dumps(f, proto)
- finally:
- f.__setstate__((capture, (), {}, {}))
-
- f = self.partial(capture)
- f.__setstate__((capture, (f,), {}, {}))
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- f_copy = pickle.loads(pickle.dumps(f, proto))
- try:
- self.assertIs(f_copy.args[0], f_copy)
- finally:
- f_copy.__setstate__((capture, (), {}, {}))
- finally:
- f.__setstate__((capture, (), {}, {}))
-
- f = self.partial(capture)
- f.__setstate__((capture, (), {'a': f}, {}))
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- f_copy = pickle.loads(pickle.dumps(f, proto))
- try:
- self.assertIs(f_copy.keywords['a'], f_copy)
- finally:
- f_copy.__setstate__((capture, (), {}, {}))
- finally:
- f.__setstate__((capture, (), {}, {}))
+ with self.AllowPickle():
+ f = self.partial(capture)
+ f.__setstate__((f, (), {}, {}))
+ try:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.assertRaises(RecursionError):
+ pickle.dumps(f, proto)
+ finally:
+ f.__setstate__((capture, (), {}, {}))
+
+ f = self.partial(capture)
+ f.__setstate__((capture, (f,), {}, {}))
+ try:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ try:
+ self.assertIs(f_copy.args[0], f_copy)
+ finally:
+ f_copy.__setstate__((capture, (), {}, {}))
+ finally:
+ f.__setstate__((capture, (), {}, {}))
+
+ f = self.partial(capture)
+ f.__setstate__((capture, (), {'a': f}, {}))
+ try:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ f_copy = pickle.loads(pickle.dumps(f, proto))
+ try:
+ self.assertIs(f_copy.keywords['a'], f_copy)
+ finally:
+ f_copy.__setstate__((capture, (), {}, {}))
+ finally:
+ f.__setstate__((capture, (), {}, {}))
# Issue 6083: Reference counting bug
def test_setstate_refcount(self):
@@ -375,24 +365,60 @@ class TestPartialC(TestPartial, unittest.TestCase):
f = self.partial(object)
self.assertRaises(TypeError, f.__setstate__, BadSequence())
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
+class TestPartialC(TestPartial, unittest.TestCase):
+ if c_functools:
+ partial = c_functools.partial
+
+ class AllowPickle:
+ def __enter__(self):
+ return self
+ def __exit__(self, type, value, tb):
+ return False
+
+ def test_attributes_unwritable(self):
+ # attributes should not be writable
+ p = self.partial(capture, 1, 2, a=10, b=20)
+ self.assertRaises(AttributeError, setattr, p, 'func', map)
+ self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
+ self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
+
+ p = self.partial(hex)
+ try:
+ del p.__dict__
+ except TypeError:
+ pass
+ else:
+ self.fail('partial object allowed __dict__ to be deleted')
class TestPartialPy(TestPartial, unittest.TestCase):
- partial = staticmethod(py_functools.partial)
+ partial = py_functools.partial
+ class AllowPickle:
+ def __init__(self):
+ self._cm = replaced_module("functools", py_functools)
+ def __enter__(self):
+ return self._cm.__enter__()
+ def __exit__(self, type, value, tb):
+ return self._cm.__exit__(type, value, tb)
if c_functools:
- class PartialSubclass(c_functools.partial):
+ class CPartialSubclass(c_functools.partial):
pass
+class PyPartialSubclass(py_functools.partial):
+ pass
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
if c_functools:
- partial = PartialSubclass
+ partial = CPartialSubclass
# partial subclasses are not optimized for nested calls
test_nested_optimization = None
+class TestPartialPySubclass(TestPartialPy):
+ partial = PyPartialSubclass
class TestPartialMethod(unittest.TestCase):
@@ -683,9 +709,10 @@ class TestWraps(TestUpdateWrapper):
self.assertEqual(wrapper.attr, 'This is a different test')
self.assertEqual(wrapper.dict_attr, f.dict_attr)
-
+@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
- func = functools.reduce
+ if c_functools:
+ func = c_functools.reduce
def test_reduce(self):
class Squares:
diff --git a/Misc/NEWS b/Misc/NEWS
index fd69eb7..a7a9104 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -135,6 +135,11 @@ Core and Builtins
Library
-------
+- Issue #27137: the pure Python fallback implementation of ``functools.partial``
+ now matches the behaviour of its accelerated C counterpart for subclassing,
+ pickling and text representation purposes. Patch by Emanuel Barry and
+ Serhiy Storchaka.
+
- Issue #28019: itertools.count() no longer rounds non-integer step in range
between 1.0 and 2.0 to 1.
diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c
index 848a03c..fa5fad3 100644
--- a/Modules/_functoolsmodule.c
+++ b/Modules/_functoolsmodule.c
@@ -229,7 +229,7 @@ partial_repr(partialobject *pto)
if (status != 0) {
if (status < 0)
return NULL;
- return PyUnicode_FromFormat("%s(...)", Py_TYPE(pto)->tp_name);
+ return PyUnicode_FromString("...");
}
arglist = PyUnicode_FromString("");