summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/operator.py46
-rw-r--r--Lib/test/test_operator.py109
2 files changed, 153 insertions, 2 deletions
diff --git a/Lib/operator.py b/Lib/operator.py
index 856036d..0db51c1 100644
--- a/Lib/operator.py
+++ b/Lib/operator.py
@@ -231,10 +231,13 @@ class attrgetter:
After h = attrgetter('name.first', 'name.last'), the call h(r) returns
(r.name.first, r.name.last).
"""
+ __slots__ = ('_attrs', '_call')
+
def __init__(self, attr, *attrs):
if not attrs:
if not isinstance(attr, str):
raise TypeError('attribute name must be a string')
+ self._attrs = (attr,)
names = attr.split('.')
def func(obj):
for name in names:
@@ -242,7 +245,8 @@ class attrgetter:
return obj
self._call = func
else:
- getters = tuple(map(attrgetter, (attr,) + attrs))
+ self._attrs = (attr,) + attrs
+ getters = tuple(map(attrgetter, self._attrs))
def func(obj):
return tuple(getter(obj) for getter in getters)
self._call = func
@@ -250,19 +254,30 @@ class attrgetter:
def __call__(self, obj):
return self._call(obj)
+ def __repr__(self):
+ return '%s.%s(%s)' % (self.__class__.__module__,
+ self.__class__.__qualname__,
+ ', '.join(map(repr, self._attrs)))
+
+ def __reduce__(self):
+ return self.__class__, self._attrs
+
class itemgetter:
"""
Return a callable object that fetches the given item(s) from its operand.
After f = itemgetter(2), the call f(r) returns r[2].
After g = itemgetter(2, 5, 3), the call g(r) returns (r[2], r[5], r[3])
"""
+ __slots__ = ('_items', '_call')
+
def __init__(self, item, *items):
if not items:
+ self._items = (item,)
def func(obj):
return obj[item]
self._call = func
else:
- items = (item,) + items
+ self._items = items = (item,) + items
def func(obj):
return tuple(obj[i] for i in items)
self._call = func
@@ -270,6 +285,14 @@ class itemgetter:
def __call__(self, obj):
return self._call(obj)
+ def __repr__(self):
+ return '%s.%s(%s)' % (self.__class__.__module__,
+ self.__class__.__name__,
+ ', '.join(map(repr, self._items)))
+
+ def __reduce__(self):
+ return self.__class__, self._items
+
class methodcaller:
"""
Return a callable object that calls the given method on its operand.
@@ -277,6 +300,7 @@ class methodcaller:
After g = methodcaller('name', 'date', foo=1), the call g(r) returns
r.name('date', foo=1).
"""
+ __slots__ = ('_name', '_args', '_kwargs')
def __init__(*args, **kwargs):
if len(args) < 2:
@@ -284,12 +308,30 @@ class methodcaller:
raise TypeError(msg)
self = args[0]
self._name = args[1]
+ if not isinstance(self._name, str):
+ raise TypeError('method name must be a string')
self._args = args[2:]
self._kwargs = kwargs
def __call__(self, obj):
return getattr(obj, self._name)(*self._args, **self._kwargs)
+ def __repr__(self):
+ args = [repr(self._name)]
+ args.extend(map(repr, self._args))
+ args.extend('%s=%r' % (k, v) for k, v in self._kwargs.items())
+ return '%s.%s(%s)' % (self.__class__.__module__,
+ self.__class__.__name__,
+ ', '.join(args))
+
+ def __reduce__(self):
+ if not self._kwargs:
+ return self.__class__, (self._name,) + self._args
+ else:
+ from functools import partial
+ return partial(self.__class__, self._name, **self._kwargs), self._args
+
+
# In-place Operations *********************************************************#
def iadd(a, b):
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
index 1bd0391..ef9cf3e 100644
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -1,4 +1,6 @@
import unittest
+import pickle
+import sys
from test import support
@@ -35,6 +37,9 @@ class Seq2(object):
class OperatorTestCase:
+ def setUp(self):
+ sys.modules['operator'] = self.module
+
def test_lt(self):
operator = self.module
self.assertRaises(TypeError, operator.lt)
@@ -396,6 +401,7 @@ class OperatorTestCase:
def test_methodcaller(self):
operator = self.module
self.assertRaises(TypeError, operator.methodcaller)
+ self.assertRaises(TypeError, operator.methodcaller, 12)
class A:
def foo(self, *args, **kwds):
return args[0] + args[1]
@@ -491,5 +497,108 @@ class PyOperatorTestCase(OperatorTestCase, unittest.TestCase):
class COperatorTestCase(OperatorTestCase, unittest.TestCase):
module = c_operator
+
+class OperatorPickleTestCase:
+ def copy(self, obj, proto):
+ with support.swap_item(sys.modules, 'operator', self.module):
+ pickled = pickle.dumps(obj, proto)
+ with support.swap_item(sys.modules, 'operator', self.module2):
+ return pickle.loads(pickled)
+
+ def test_attrgetter(self):
+ attrgetter = self.module.attrgetter
+ attrgetter = self.module.attrgetter
+ class A:
+ pass
+ a = A()
+ a.x = 'X'
+ a.y = 'Y'
+ a.z = 'Z'
+ a.t = A()
+ a.t.u = A()
+ a.t.u.v = 'V'
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto):
+ f = attrgetter('x')
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+ # multiple gets
+ f = attrgetter('x', 'y', 'z')
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+ # recursive gets
+ f = attrgetter('t.u.v')
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+
+ def test_itemgetter(self):
+ itemgetter = self.module.itemgetter
+ a = 'ABCDE'
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto):
+ f = itemgetter(2)
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+ # multiple gets
+ f = itemgetter(2, 0, 4)
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+
+ def test_methodcaller(self):
+ methodcaller = self.module.methodcaller
+ class A:
+ def foo(self, *args, **kwds):
+ return args[0] + args[1]
+ def bar(self, f=42):
+ return f
+ def baz(*args, **kwds):
+ return kwds['name'], kwds['self']
+ a = A()
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto):
+ f = methodcaller('bar')
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+ # positional args
+ f = methodcaller('foo', 1, 2)
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+ # keyword args
+ f = methodcaller('bar', f=5)
+ f2 = self.copy(f, proto)
+ self.assertEqual(repr(f2), repr(f))
+ self.assertEqual(f2(a), f(a))
+ f = methodcaller('baz', self='eggs', name='spam')
+ f2 = self.copy(f, proto)
+ # Can't test repr consistently with multiple keyword args
+ self.assertEqual(f2(a), f(a))
+
+class PyPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+ module = py_operator
+ module2 = py_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class PyCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+ module = py_operator
+ module2 = c_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class CPyOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+ module = c_operator
+ module2 = py_operator
+
+@unittest.skipUnless(c_operator, 'requires _operator')
+class CCOperatorPickleTestCase(OperatorPickleTestCase, unittest.TestCase):
+ module = c_operator
+ module2 = c_operator
+
+
if __name__ == "__main__":
unittest.main()