diff options
author | Serhiy Storchaka <storchaka@gmail.com> | 2023-09-06 20:55:42 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-06 20:55:42 (GMT) |
commit | 6f3c138dfa868b32d3288898923bbfa388f2fa5d (patch) | |
tree | 98098db4d0d2f256f2a17d515c1750a28424166c | |
parent | 9f0c0a46f00d687e921990ee83894b2f4ce8a6e7 (diff) | |
download | cpython-6f3c138dfa868b32d3288898923bbfa388f2fa5d.zip cpython-6f3c138dfa868b32d3288898923bbfa388f2fa5d.tar.gz cpython-6f3c138dfa868b32d3288898923bbfa388f2fa5d.tar.bz2 |
gh-108751: Add copy.replace() function (GH-108752)
It creates a modified copy of an object by calling the object's
__replace__() method.
It is a generalization of dataclasses.replace(), named tuple's _replace()
method and replace() methods in various classes, and supports all these
stdlib classes.
-rw-r--r-- | Doc/library/collections.rst | 2 | ||||
-rw-r--r-- | Doc/library/copy.rst | 30 | ||||
-rw-r--r-- | Doc/library/dataclasses.rst | 2 | ||||
-rw-r--r-- | Doc/library/datetime.rst | 9 | ||||
-rw-r--r-- | Doc/library/inspect.rst | 11 | ||||
-rw-r--r-- | Doc/library/types.rst | 2 | ||||
-rw-r--r-- | Doc/whatsnew/3.13.rst | 12 | ||||
-rw-r--r-- | Lib/_pydatetime.py | 6 | ||||
-rw-r--r-- | Lib/collections/__init__.py | 1 | ||||
-rw-r--r-- | Lib/copy.py | 13 | ||||
-rw-r--r-- | Lib/dataclasses.py | 9 | ||||
-rw-r--r-- | Lib/inspect.py | 4 | ||||
-rw-r--r-- | Lib/test/datetimetester.py | 137 | ||||
-rw-r--r-- | Lib/test/test_code.py | 7 | ||||
-rw-r--r-- | Lib/test/test_copy.py | 66 | ||||
-rw-r--r-- | Lib/test/test_inspect.py | 65 | ||||
-rw-r--r-- | Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst | 2 | ||||
-rw-r--r-- | Modules/_datetimemodule.c | 6 | ||||
-rw-r--r-- | Objects/codeobject.c | 1 |
19 files changed, 314 insertions, 71 deletions
diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst index b8b231b..03cb1dc 100644 --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -979,6 +979,8 @@ field names, the method and attribute names start with an underscore. >>> for partnum, record in inventory.items(): ... inventory[partnum] = record._replace(price=newprices[partnum], timestamp=time.now()) + Named tuples are also supported by generic function :func:`copy.replace`. + .. attribute:: somenamedtuple._fields Tuple of strings listing the field names. Useful for introspection diff --git a/Doc/library/copy.rst b/Doc/library/copy.rst index 8f32477..cc4ca03 100644 --- a/Doc/library/copy.rst +++ b/Doc/library/copy.rst @@ -17,14 +17,22 @@ operations (explained below). Interface summary: -.. function:: copy(x) +.. function:: copy(obj) - Return a shallow copy of *x*. + Return a shallow copy of *obj*. -.. function:: deepcopy(x[, memo]) +.. function:: deepcopy(obj[, memo]) - Return a deep copy of *x*. + Return a deep copy of *obj*. + + +.. function:: replace(obj, /, **changes) + + Creates a new object of the same type as *obj*, replacing fields with values + from *changes*. + + .. versionadded:: 3.13 .. exception:: Error @@ -89,6 +97,20 @@ with the component as first argument and the memo dictionary as second argument. The memo dictionary should be treated as an opaque object. +.. index:: + single: __replace__() (replace protocol) + +Function :func:`replace` is more limited than :func:`copy` and :func:`deepcopy`, +and only supports named tuples created by :func:`~collections.namedtuple`, +:mod:`dataclasses`, and other classes which define method :meth:`!__replace__`. + + .. method:: __replace__(self, /, **changes) + :noindex: + +:meth:`!__replace__` should create a new object of the same type, +replacing fields with values from *changes*. + + .. seealso:: Module :mod:`pickle` diff --git a/Doc/library/dataclasses.rst b/Doc/library/dataclasses.rst index d687487..d78a607 100644 --- a/Doc/library/dataclasses.rst +++ b/Doc/library/dataclasses.rst @@ -456,6 +456,8 @@ Module contents ``replace()`` (or similarly named) method which handles instance copying. + Dataclass instances are also supported by generic function :func:`copy.replace`. + .. function:: is_dataclass(obj) Return ``True`` if its parameter is a dataclass or an instance of one, diff --git a/Doc/library/datetime.rst b/Doc/library/datetime.rst index 04cc755..0b9d42f 100644 --- a/Doc/library/datetime.rst +++ b/Doc/library/datetime.rst @@ -652,6 +652,9 @@ Instance methods: >>> d.replace(day=26) datetime.date(2002, 12, 26) + :class:`date` objects are also supported by generic function + :func:`copy.replace`. + .. method:: date.timetuple() @@ -1251,6 +1254,9 @@ Instance methods: ``tzinfo=None`` can be specified to create a naive datetime from an aware datetime with no conversion of date and time data. + :class:`datetime` objects are also supported by generic function + :func:`copy.replace`. + .. versionadded:: 3.6 Added the ``fold`` argument. @@ -1827,6 +1833,9 @@ Instance methods: ``tzinfo=None`` can be specified to create a naive :class:`.time` from an aware :class:`.time`, without conversion of the time data. + :class:`time` objects are also supported by generic function + :func:`copy.replace`. + .. versionadded:: 3.6 Added the ``fold`` argument. diff --git a/Doc/library/inspect.rst b/Doc/library/inspect.rst index 603ac32..fe0ed13 100644 --- a/Doc/library/inspect.rst +++ b/Doc/library/inspect.rst @@ -689,8 +689,8 @@ function. The optional *return_annotation* argument, can be an arbitrary Python object, is the "return" annotation of the callable. - Signature objects are *immutable*. Use :meth:`Signature.replace` to make a - modified copy. + Signature objects are *immutable*. Use :meth:`Signature.replace` or + :func:`copy.replace` to make a modified copy. .. versionchanged:: 3.5 Signature objects are picklable and :term:`hashable`. @@ -746,6 +746,9 @@ function. >>> str(new_sig) "(a, b) -> 'new return anno'" + Signature objects are also supported by generic function + :func:`copy.replace`. + .. classmethod:: Signature.from_callable(obj, *, follow_wrapped=True, globalns=None, localns=None) Return a :class:`Signature` (or its subclass) object for a given callable @@ -769,7 +772,7 @@ function. .. class:: Parameter(name, kind, *, default=Parameter.empty, annotation=Parameter.empty) Parameter objects are *immutable*. Instead of modifying a Parameter object, - you can use :meth:`Parameter.replace` to create a modified copy. + you can use :meth:`Parameter.replace` or :func:`copy.replace` to create a modified copy. .. versionchanged:: 3.5 Parameter objects are picklable and :term:`hashable`. @@ -892,6 +895,8 @@ function. >>> str(param.replace(default=Parameter.empty, annotation='spam')) "foo:'spam'" + Parameter objects are also supported by generic function :func:`copy.replace`. + .. versionchanged:: 3.4 In Python 3.3 Parameter objects were allowed to have ``name`` set to ``None`` if their ``kind`` was set to ``POSITIONAL_ONLY``. diff --git a/Doc/library/types.rst b/Doc/library/types.rst index 8cbe17d..82300af 100644 --- a/Doc/library/types.rst +++ b/Doc/library/types.rst @@ -200,6 +200,8 @@ Standard names are defined for the following types: Return a copy of the code object with new values for the specified fields. + Code objects are also supported by generic function :func:`copy.replace`. + .. versionadded:: 3.8 .. data:: CellType diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst index de23172..8c64675 100644 --- a/Doc/whatsnew/3.13.rst +++ b/Doc/whatsnew/3.13.rst @@ -115,6 +115,18 @@ array It can be used instead of ``'u'`` type code, which is deprecated. (Contributed by Inada Naoki in :gh:`80480`.) +copy +---- + +* Add :func:`copy.replace` function which allows to create a modified copy of + an object, which is especially usefule for immutable objects. + It supports named tuples created with the factory function + :func:`collections.namedtuple`, :class:`~dataclasses.dataclass` instances, + various :mod:`datetime` objects, :class:`~inspect.Signature` objects, + :class:`~inspect.Parameter` objects, :ref:`code object <code-objects>`, and + any user classes which define the :meth:`!__replace__` method. + (Contributed by Serhiy Storchaka in :gh:`108751`.) + dbm --- diff --git a/Lib/_pydatetime.py b/Lib/_pydatetime.py index 549fcda..df616bb 100644 --- a/Lib/_pydatetime.py +++ b/Lib/_pydatetime.py @@ -1112,6 +1112,8 @@ class date: day = self._day return type(self)(year, month, day) + __replace__ = replace + # Comparisons of date objects with other. def __eq__(self, other): @@ -1637,6 +1639,8 @@ class time: fold = self._fold return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold) + __replace__ = replace + # Pickle support. def _getstate(self, protocol=3): @@ -1983,6 +1987,8 @@ class datetime(date): return type(self)(year, month, day, hour, minute, second, microsecond, tzinfo, fold=fold) + __replace__ = replace + def _local_timezone(self): if self.tzinfo is None: ts = self._mktime() diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 8652dc8..a461550 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -495,6 +495,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non '_field_defaults': field_defaults, '__new__': __new__, '_make': _make, + '__replace__': _replace, '_replace': _replace, '__repr__': __repr__, '_asdict': _asdict, diff --git a/Lib/copy.py b/Lib/copy.py index da2908e..6d7bb9a 100644 --- a/Lib/copy.py +++ b/Lib/copy.py @@ -290,3 +290,16 @@ def _reconstruct(x, memo, func, args, return y del types, weakref + + +def replace(obj, /, **changes): + """Return a new object replacing specified fields with new values. + + This is especially useful for immutable objects, like named tuples or + frozen dataclasses. + """ + cls = obj.__class__ + func = getattr(cls, '__replace__', None) + if func is None: + raise TypeError(f"replace() does not support {cls.__name__} objects") + return func(obj, **changes) diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 21f3fa5..84f8d68 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -1073,6 +1073,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, globals, slots, )) + _set_new_attribute(cls, '__replace__', _replace) # Get the fields as a list, and include only real fields. This is # used in all of the following methods. @@ -1546,13 +1547,15 @@ def replace(obj, /, **changes): c1 = replace(c, x=3) assert c1.x == 3 and c1.y == 2 """ + if not _is_dataclass_instance(obj): + raise TypeError("replace() should be called on dataclass instances") + return _replace(obj, **changes) + +def _replace(obj, /, **changes): # We're going to mutate 'changes', but that's okay because it's a # new dict, even if called with 'replace(obj, **my_changes)'. - if not _is_dataclass_instance(obj): - raise TypeError("replace() should be called on dataclass instances") - # It's an error to have init=False fields in 'changes'. # If a field is not in 'changes', read its value from the provided obj. diff --git a/Lib/inspect.py b/Lib/inspect.py index c821183..aaa22be 100644 --- a/Lib/inspect.py +++ b/Lib/inspect.py @@ -2870,6 +2870,8 @@ class Parameter: return formatted + __replace__ = replace + def __repr__(self): return '<{} "{}">'.format(self.__class__.__name__, self) @@ -3130,6 +3132,8 @@ class Signature: return type(self)(parameters, return_annotation=return_annotation) + __replace__ = replace + def _hash_basis(self): params = tuple(param for param in self.parameters.values() if param.kind != _KEYWORD_ONLY) diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py index 55e0619..8bda173 100644 --- a/Lib/test/datetimetester.py +++ b/Lib/test/datetimetester.py @@ -1699,22 +1699,23 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase): cls = self.theclass args = [1, 2, 3] base = cls(*args) - self.assertEqual(base, base.replace()) + self.assertEqual(base.replace(), base) + self.assertEqual(copy.replace(base), base) - i = 0 - for name, newval in (("year", 2), - ("month", 3), - ("day", 4)): + changes = (("year", 2), + ("month", 3), + ("day", 4)) + for i, (name, newval) in enumerate(changes): newargs = args[:] newargs[i] = newval expected = cls(*newargs) - got = base.replace(**{name: newval}) - self.assertEqual(expected, got) - i += 1 + self.assertEqual(base.replace(**{name: newval}), expected) + self.assertEqual(copy.replace(base, **{name: newval}), expected) # Out of bounds. base = cls(2000, 2, 29) self.assertRaises(ValueError, base.replace, year=2001) + self.assertRaises(ValueError, copy.replace, base, year=2001) def test_subclass_replace(self): class DateSubclass(self.theclass): @@ -1722,6 +1723,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase): dt = DateSubclass(2012, 1, 1) self.assertIs(type(dt.replace(year=2013)), DateSubclass) + self.assertIs(type(copy.replace(dt, year=2013)), DateSubclass) def test_subclass_date(self): @@ -2856,26 +2858,27 @@ class TestDateTime(TestDate): cls = self.theclass args = [1, 2, 3, 4, 5, 6, 7] base = cls(*args) - self.assertEqual(base, base.replace()) - - i = 0 - for name, newval in (("year", 2), - ("month", 3), - ("day", 4), - ("hour", 5), - ("minute", 6), - ("second", 7), - ("microsecond", 8)): + self.assertEqual(base.replace(), base) + self.assertEqual(copy.replace(base), base) + + changes = (("year", 2), + ("month", 3), + ("day", 4), + ("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8)) + for i, (name, newval) in enumerate(changes): newargs = args[:] newargs[i] = newval expected = cls(*newargs) - got = base.replace(**{name: newval}) - self.assertEqual(expected, got) - i += 1 + self.assertEqual(base.replace(**{name: newval}), expected) + self.assertEqual(copy.replace(base, **{name: newval}), expected) # Out of bounds. base = cls(2000, 2, 29) self.assertRaises(ValueError, base.replace, year=2001) + self.assertRaises(ValueError, copy.replace, base, year=2001) @support.run_with_tz('EDT4') def test_astimezone(self): @@ -3671,19 +3674,19 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase): cls = self.theclass args = [1, 2, 3, 4] base = cls(*args) - self.assertEqual(base, base.replace()) - - i = 0 - for name, newval in (("hour", 5), - ("minute", 6), - ("second", 7), - ("microsecond", 8)): + self.assertEqual(base.replace(), base) + self.assertEqual(copy.replace(base), base) + + changes = (("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8)) + for i, (name, newval) in enumerate(changes): newargs = args[:] newargs[i] = newval expected = cls(*newargs) - got = base.replace(**{name: newval}) - self.assertEqual(expected, got) - i += 1 + self.assertEqual(base.replace(**{name: newval}), expected) + self.assertEqual(copy.replace(base, **{name: newval}), expected) # Out of bounds. base = cls(1) @@ -3691,6 +3694,10 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase): self.assertRaises(ValueError, base.replace, minute=-1) self.assertRaises(ValueError, base.replace, second=100) self.assertRaises(ValueError, base.replace, microsecond=1000000) + self.assertRaises(ValueError, copy.replace, base, hour=24) + self.assertRaises(ValueError, copy.replace, base, minute=-1) + self.assertRaises(ValueError, copy.replace, base, second=100) + self.assertRaises(ValueError, copy.replace, base, microsecond=1000000) def test_subclass_replace(self): class TimeSubclass(self.theclass): @@ -3698,6 +3705,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase): ctime = TimeSubclass(12, 30) self.assertIs(type(ctime.replace(hour=10)), TimeSubclass) + self.assertIs(type(copy.replace(ctime, hour=10)), TimeSubclass) def test_subclass_time(self): @@ -4085,31 +4093,37 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase): zm200 = FixedOffset(timedelta(minutes=-200), "-200") args = [1, 2, 3, 4, z100] base = cls(*args) - self.assertEqual(base, base.replace()) - - i = 0 - for name, newval in (("hour", 5), - ("minute", 6), - ("second", 7), - ("microsecond", 8), - ("tzinfo", zm200)): + self.assertEqual(base.replace(), base) + self.assertEqual(copy.replace(base), base) + + changes = (("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8), + ("tzinfo", zm200)) + for i, (name, newval) in enumerate(changes): newargs = args[:] newargs[i] = newval expected = cls(*newargs) - got = base.replace(**{name: newval}) - self.assertEqual(expected, got) - i += 1 + self.assertEqual(base.replace(**{name: newval}), expected) + self.assertEqual(copy.replace(base, **{name: newval}), expected) # Ensure we can get rid of a tzinfo. self.assertEqual(base.tzname(), "+100") base2 = base.replace(tzinfo=None) self.assertIsNone(base2.tzinfo) self.assertIsNone(base2.tzname()) + base22 = copy.replace(base, tzinfo=None) + self.assertIsNone(base22.tzinfo) + self.assertIsNone(base22.tzname()) # Ensure we can add one. base3 = base2.replace(tzinfo=z100) self.assertEqual(base, base3) self.assertIs(base.tzinfo, base3.tzinfo) + base32 = copy.replace(base22, tzinfo=z100) + self.assertEqual(base, base32) + self.assertIs(base.tzinfo, base32.tzinfo) # Out of bounds. base = cls(1) @@ -4117,6 +4131,10 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase): self.assertRaises(ValueError, base.replace, minute=-1) self.assertRaises(ValueError, base.replace, second=100) self.assertRaises(ValueError, base.replace, microsecond=1000000) + self.assertRaises(ValueError, copy.replace, base, hour=24) + self.assertRaises(ValueError, copy.replace, base, minute=-1) + self.assertRaises(ValueError, copy.replace, base, second=100) + self.assertRaises(ValueError, copy.replace, base, microsecond=1000000) def test_mixed_compare(self): t1 = self.theclass(1, 2, 3) @@ -4885,38 +4903,45 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase): zm200 = FixedOffset(timedelta(minutes=-200), "-200") args = [1, 2, 3, 4, 5, 6, 7, z100] base = cls(*args) - self.assertEqual(base, base.replace()) - - i = 0 - for name, newval in (("year", 2), - ("month", 3), - ("day", 4), - ("hour", 5), - ("minute", 6), - ("second", 7), - ("microsecond", 8), - ("tzinfo", zm200)): + self.assertEqual(base.replace(), base) + self.assertEqual(copy.replace(base), base) + + changes = (("year", 2), + ("month", 3), + ("day", 4), + ("hour", 5), + ("minute", 6), + ("second", 7), + ("microsecond", 8), + ("tzinfo", zm200)) + for i, (name, newval) in enumerate(changes): newargs = args[:] newargs[i] = newval expected = cls(*newargs) - got = base.replace(**{name: newval}) - self.assertEqual(expected, got) - i += 1 + self.assertEqual(base.replace(**{name: newval}), expected) + self.assertEqual(copy.replace(base, **{name: newval}), expected) # Ensure we can get rid of a tzinfo. self.assertEqual(base.tzname(), "+100") base2 = base.replace(tzinfo=None) self.assertIsNone(base2.tzinfo) self.assertIsNone(base2.tzname()) + base22 = copy.replace(base, tzinfo=None) + self.assertIsNone(base22.tzinfo) + self.assertIsNone(base22.tzname()) # Ensure we can add one. base3 = base2.replace(tzinfo=z100) self.assertEqual(base, base3) self.assertIs(base.tzinfo, base3.tzinfo) + base32 = copy.replace(base22, tzinfo=z100) + self.assertEqual(base, base32) + self.assertIs(base.tzinfo, base32.tzinfo) # Out of bounds. base = cls(2000, 2, 29) self.assertRaises(ValueError, base.replace, year=2001) + self.assertRaises(ValueError, copy.replace, base, year=2001) def test_more_astimezone(self): # The inherited test_astimezone covered some trivial and error cases. diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py index e056c16..812c068 100644 --- a/Lib/test/test_code.py +++ b/Lib/test/test_code.py @@ -125,6 +125,7 @@ consts: ('None',) """ +import copy import inspect import sys import threading @@ -280,11 +281,17 @@ class CodeTest(unittest.TestCase): with self.subTest(attr=attr, value=value): new_code = code.replace(**{attr: value}) self.assertEqual(getattr(new_code, attr), value) + new_code = copy.replace(code, **{attr: value}) + self.assertEqual(getattr(new_code, attr), value) new_code = code.replace(co_varnames=code2.co_varnames, co_nlocals=code2.co_nlocals) self.assertEqual(new_code.co_varnames, code2.co_varnames) self.assertEqual(new_code.co_nlocals, code2.co_nlocals) + new_code = copy.replace(code, co_varnames=code2.co_varnames, + co_nlocals=code2.co_nlocals) + self.assertEqual(new_code.co_varnames, code2.co_varnames) + self.assertEqual(new_code.co_nlocals, code2.co_nlocals) def test_nlocals_mismatch(self): def func(): diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py index 826e468..c66c6ee 100644 --- a/Lib/test/test_copy.py +++ b/Lib/test/test_copy.py @@ -4,7 +4,7 @@ import copy import copyreg import weakref import abc -from operator import le, lt, ge, gt, eq, ne +from operator import le, lt, ge, gt, eq, ne, attrgetter import unittest from test import support @@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase): g.b() +class TestReplace(unittest.TestCase): + + def test_unsupported(self): + self.assertRaises(TypeError, copy.replace, 1) + self.assertRaises(TypeError, copy.replace, []) + self.assertRaises(TypeError, copy.replace, {}) + def f(): pass + self.assertRaises(TypeError, copy.replace, f) + class A: pass + self.assertRaises(TypeError, copy.replace, A) + self.assertRaises(TypeError, copy.replace, A()) + + def test_replace_method(self): + class A: + def __new__(cls, x, y=0): + self = object.__new__(cls) + self.x = x + self.y = y + return self + + def __init__(self, *args, **kwargs): + self.z = self.x + self.y + + def __replace__(self, **changes): + x = changes.get('x', self.x) + y = changes.get('y', self.y) + return type(self)(x, y) + + attrs = attrgetter('x', 'y', 'z') + a = A(11, 22) + self.assertEqual(attrs(copy.replace(a)), (11, 22, 33)) + self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23)) + self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13)) + self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3)) + + def test_namedtuple(self): + from collections import namedtuple + Point = namedtuple('Point', 'x y', defaults=(0,)) + p = Point(11, 22) + self.assertEqual(copy.replace(p), (11, 22)) + self.assertEqual(copy.replace(p, x=1), (1, 22)) + self.assertEqual(copy.replace(p, y=2), (11, 2)) + self.assertEqual(copy.replace(p, x=1, y=2), (1, 2)) + with self.assertRaisesRegex(ValueError, 'unexpected field name'): + copy.replace(p, x=1, error=2) + + def test_dataclass(self): + from dataclasses import dataclass + @dataclass + class C: + x: int + y: int = 0 + + attrs = attrgetter('x', 'y') + c = C(11, 22) + self.assertEqual(attrs(copy.replace(c)), (11, 22)) + self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22)) + self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2)) + self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2)) + with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): + copy.replace(c, x=1, error=2) + + def global_foo(x, y): return x+y + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py index 78ef817..2fb356a 100644 --- a/Lib/test/test_inspect.py +++ b/Lib/test/test_inspect.py @@ -1,6 +1,7 @@ import asyncio import builtins import collections +import copy import datetime import functools import importlib @@ -3830,6 +3831,28 @@ class TestSignatureObject(unittest.TestCase): P('bar', P.VAR_POSITIONAL)])), '(foo, /, *bar)') + def test_signature_replace_parameters(self): + def test(a, b) -> 42: + pass + + sig = inspect.signature(test) + parameters = sig.parameters + sig = sig.replace(parameters=list(parameters.values())[1:]) + self.assertEqual(list(sig.parameters), ['b']) + self.assertEqual(sig.parameters['b'], parameters['b']) + self.assertEqual(sig.return_annotation, 42) + sig = sig.replace(parameters=()) + self.assertEqual(dict(sig.parameters), {}) + + sig = inspect.signature(test) + parameters = sig.parameters + sig = copy.replace(sig, parameters=list(parameters.values())[1:]) + self.assertEqual(list(sig.parameters), ['b']) + self.assertEqual(sig.parameters['b'], parameters['b']) + self.assertEqual(sig.return_annotation, 42) + sig = copy.replace(sig, parameters=()) + self.assertEqual(dict(sig.parameters), {}) + def test_signature_replace_anno(self): def test() -> 42: pass @@ -3843,6 +3866,15 @@ class TestSignatureObject(unittest.TestCase): self.assertEqual(sig.return_annotation, 42) self.assertEqual(sig, inspect.signature(test)) + sig = inspect.signature(test) + sig = copy.replace(sig, return_annotation=None) + self.assertIs(sig.return_annotation, None) + sig = copy.replace(sig, return_annotation=sig.empty) + self.assertIs(sig.return_annotation, sig.empty) + sig = copy.replace(sig, return_annotation=42) + self.assertEqual(sig.return_annotation, 42) + self.assertEqual(sig, inspect.signature(test)) + def test_signature_replaced(self): def test(): pass @@ -4187,41 +4219,66 @@ class TestParameterObject(unittest.TestCase): p = inspect.Parameter('foo', default=42, kind=inspect.Parameter.KEYWORD_ONLY) - self.assertIsNot(p, p.replace()) - self.assertEqual(p, p.replace()) + self.assertIsNot(p.replace(), p) + self.assertEqual(p.replace(), p) + self.assertIsNot(copy.replace(p), p) + self.assertEqual(copy.replace(p), p) p2 = p.replace(annotation=1) self.assertEqual(p2.annotation, 1) p2 = p2.replace(annotation=p2.empty) - self.assertEqual(p, p2) + self.assertEqual(p2, p) + p3 = copy.replace(p, annotation=1) + self.assertEqual(p3.annotation, 1) + p3 = copy.replace(p3, annotation=p3.empty) + self.assertEqual(p3, p) p2 = p2.replace(name='bar') self.assertEqual(p2.name, 'bar') self.assertNotEqual(p2, p) + p3 = copy.replace(p3, name='bar') + self.assertEqual(p3.name, 'bar') + self.assertNotEqual(p3, p) with self.assertRaisesRegex(ValueError, 'name is a required attribute'): p2 = p2.replace(name=p2.empty) + with self.assertRaisesRegex(ValueError, + 'name is a required attribute'): + p3 = copy.replace(p3, name=p3.empty) p2 = p2.replace(name='foo', default=None) self.assertIs(p2.default, None) self.assertNotEqual(p2, p) + p3 = copy.replace(p3, name='foo', default=None) + self.assertIs(p3.default, None) + self.assertNotEqual(p3, p) p2 = p2.replace(name='foo', default=p2.empty) self.assertIs(p2.default, p2.empty) - + p3 = copy.replace(p3, name='foo', default=p3.empty) + self.assertIs(p3.default, p3.empty) p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD) self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD) self.assertNotEqual(p2, p) + p3 = copy.replace(p3, default=42, kind=p3.POSITIONAL_OR_KEYWORD) + self.assertEqual(p3.kind, p3.POSITIONAL_OR_KEYWORD) + self.assertNotEqual(p3, p) with self.assertRaisesRegex(ValueError, "value <class 'inspect._empty'> " "is not a valid Parameter.kind"): p2 = p2.replace(kind=p2.empty) + with self.assertRaisesRegex(ValueError, + "value <class 'inspect._empty'> " + "is not a valid Parameter.kind"): + p3 = copy.replace(p3, kind=p3.empty) p2 = p2.replace(kind=p2.KEYWORD_ONLY) self.assertEqual(p2, p) + p3 = copy.replace(p3, kind=p3.KEYWORD_ONLY) + self.assertEqual(p3, p) def test_signature_parameter_positional_only(self): with self.assertRaisesRegex(TypeError, 'name must be a str'): diff --git a/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst b/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst new file mode 100644 index 0000000..7bc21fe --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst @@ -0,0 +1,2 @@ +Add :func:`copy.replace` function which allows to create a modified copy of +an object. It supports named tuples, dataclasses, and many other objects. diff --git a/Modules/_datetimemodule.c b/Modules/_datetimemodule.c index 191db3f..0d35677 100644 --- a/Modules/_datetimemodule.c +++ b/Modules/_datetimemodule.c @@ -3590,6 +3590,8 @@ static PyMethodDef date_methods[] = { {"replace", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Return date with new specified fields.")}, + {"__replace__", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS}, + {"__reduce__", (PyCFunction)date_reduce, METH_NOARGS, PyDoc_STR("__reduce__() -> (cls, state)")}, @@ -4719,6 +4721,8 @@ static PyMethodDef time_methods[] = { {"replace", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Return time with new specified fields.")}, + {"__replace__", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS}, + {"fromisoformat", (PyCFunction)time_fromisoformat, METH_O | METH_CLASS, PyDoc_STR("string -> time from a string in ISO 8601 format")}, @@ -6579,6 +6583,8 @@ static PyMethodDef datetime_methods[] = { {"replace", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS, PyDoc_STR("Return datetime with new specified fields.")}, + {"__replace__", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS}, + {"astimezone", _PyCFunction_CAST(datetime_astimezone), METH_VARARGS | METH_KEYWORDS, PyDoc_STR("tz -> convert to local time in new timezone tz\n")}, diff --git a/Objects/codeobject.c b/Objects/codeobject.c index 70a0c2e..5830607 100644 --- a/Objects/codeobject.c +++ b/Objects/codeobject.c @@ -2145,6 +2145,7 @@ static struct PyMethodDef code_methods[] = { {"co_positions", (PyCFunction)code_positionsiterator, METH_NOARGS}, CODE_REPLACE_METHODDEF CODE__VARNAME_FROM_OPARG_METHODDEF + {"__replace__", _PyCFunction_CAST(code_replace), METH_FASTCALL|METH_KEYWORDS}, {NULL, NULL} /* sentinel */ }; |