summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2023-09-06 20:55:42 (GMT)
committerGitHub <noreply@github.com>2023-09-06 20:55:42 (GMT)
commit6f3c138dfa868b32d3288898923bbfa388f2fa5d (patch)
tree98098db4d0d2f256f2a17d515c1750a28424166c
parent9f0c0a46f00d687e921990ee83894b2f4ce8a6e7 (diff)
downloadcpython-6f3c138dfa868b32d3288898923bbfa388f2fa5d.zip
cpython-6f3c138dfa868b32d3288898923bbfa388f2fa5d.tar.gz
cpython-6f3c138dfa868b32d3288898923bbfa388f2fa5d.tar.bz2
gh-108751: Add copy.replace() function (GH-108752)
It creates a modified copy of an object by calling the object's __replace__() method. It is a generalization of dataclasses.replace(), named tuple's _replace() method and replace() methods in various classes, and supports all these stdlib classes.
-rw-r--r--Doc/library/collections.rst2
-rw-r--r--Doc/library/copy.rst30
-rw-r--r--Doc/library/dataclasses.rst2
-rw-r--r--Doc/library/datetime.rst9
-rw-r--r--Doc/library/inspect.rst11
-rw-r--r--Doc/library/types.rst2
-rw-r--r--Doc/whatsnew/3.13.rst12
-rw-r--r--Lib/_pydatetime.py6
-rw-r--r--Lib/collections/__init__.py1
-rw-r--r--Lib/copy.py13
-rw-r--r--Lib/dataclasses.py9
-rw-r--r--Lib/inspect.py4
-rw-r--r--Lib/test/datetimetester.py137
-rw-r--r--Lib/test/test_code.py7
-rw-r--r--Lib/test/test_copy.py66
-rw-r--r--Lib/test/test_inspect.py65
-rw-r--r--Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst2
-rw-r--r--Modules/_datetimemodule.c6
-rw-r--r--Objects/codeobject.c1
19 files changed, 314 insertions, 71 deletions
diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst
index b8b231b..03cb1dc 100644
--- a/Doc/library/collections.rst
+++ b/Doc/library/collections.rst
@@ -979,6 +979,8 @@ field names, the method and attribute names start with an underscore.
>>> for partnum, record in inventory.items():
... inventory[partnum] = record._replace(price=newprices[partnum], timestamp=time.now())
+ Named tuples are also supported by generic function :func:`copy.replace`.
+
.. attribute:: somenamedtuple._fields
Tuple of strings listing the field names. Useful for introspection
diff --git a/Doc/library/copy.rst b/Doc/library/copy.rst
index 8f32477..cc4ca03 100644
--- a/Doc/library/copy.rst
+++ b/Doc/library/copy.rst
@@ -17,14 +17,22 @@ operations (explained below).
Interface summary:
-.. function:: copy(x)
+.. function:: copy(obj)
- Return a shallow copy of *x*.
+ Return a shallow copy of *obj*.
-.. function:: deepcopy(x[, memo])
+.. function:: deepcopy(obj[, memo])
- Return a deep copy of *x*.
+ Return a deep copy of *obj*.
+
+
+.. function:: replace(obj, /, **changes)
+
+ Creates a new object of the same type as *obj*, replacing fields with values
+ from *changes*.
+
+ .. versionadded:: 3.13
.. exception:: Error
@@ -89,6 +97,20 @@ with the component as first argument and the memo dictionary as second argument.
The memo dictionary should be treated as an opaque object.
+.. index::
+ single: __replace__() (replace protocol)
+
+Function :func:`replace` is more limited than :func:`copy` and :func:`deepcopy`,
+and only supports named tuples created by :func:`~collections.namedtuple`,
+:mod:`dataclasses`, and other classes which define method :meth:`!__replace__`.
+
+ .. method:: __replace__(self, /, **changes)
+ :noindex:
+
+:meth:`!__replace__` should create a new object of the same type,
+replacing fields with values from *changes*.
+
+
.. seealso::
Module :mod:`pickle`
diff --git a/Doc/library/dataclasses.rst b/Doc/library/dataclasses.rst
index d687487..d78a607 100644
--- a/Doc/library/dataclasses.rst
+++ b/Doc/library/dataclasses.rst
@@ -456,6 +456,8 @@ Module contents
``replace()`` (or similarly named) method which handles instance
copying.
+ Dataclass instances are also supported by generic function :func:`copy.replace`.
+
.. function:: is_dataclass(obj)
Return ``True`` if its parameter is a dataclass or an instance of one,
diff --git a/Doc/library/datetime.rst b/Doc/library/datetime.rst
index 04cc755..0b9d42f 100644
--- a/Doc/library/datetime.rst
+++ b/Doc/library/datetime.rst
@@ -652,6 +652,9 @@ Instance methods:
>>> d.replace(day=26)
datetime.date(2002, 12, 26)
+ :class:`date` objects are also supported by generic function
+ :func:`copy.replace`.
+
.. method:: date.timetuple()
@@ -1251,6 +1254,9 @@ Instance methods:
``tzinfo=None`` can be specified to create a naive datetime from an aware
datetime with no conversion of date and time data.
+ :class:`datetime` objects are also supported by generic function
+ :func:`copy.replace`.
+
.. versionadded:: 3.6
Added the ``fold`` argument.
@@ -1827,6 +1833,9 @@ Instance methods:
``tzinfo=None`` can be specified to create a naive :class:`.time` from an
aware :class:`.time`, without conversion of the time data.
+ :class:`time` objects are also supported by generic function
+ :func:`copy.replace`.
+
.. versionadded:: 3.6
Added the ``fold`` argument.
diff --git a/Doc/library/inspect.rst b/Doc/library/inspect.rst
index 603ac32..fe0ed13 100644
--- a/Doc/library/inspect.rst
+++ b/Doc/library/inspect.rst
@@ -689,8 +689,8 @@ function.
The optional *return_annotation* argument, can be an arbitrary Python object,
is the "return" annotation of the callable.
- Signature objects are *immutable*. Use :meth:`Signature.replace` to make a
- modified copy.
+ Signature objects are *immutable*. Use :meth:`Signature.replace` or
+ :func:`copy.replace` to make a modified copy.
.. versionchanged:: 3.5
Signature objects are picklable and :term:`hashable`.
@@ -746,6 +746,9 @@ function.
>>> str(new_sig)
"(a, b) -> 'new return anno'"
+ Signature objects are also supported by generic function
+ :func:`copy.replace`.
+
.. classmethod:: Signature.from_callable(obj, *, follow_wrapped=True, globalns=None, localns=None)
Return a :class:`Signature` (or its subclass) object for a given callable
@@ -769,7 +772,7 @@ function.
.. class:: Parameter(name, kind, *, default=Parameter.empty, annotation=Parameter.empty)
Parameter objects are *immutable*. Instead of modifying a Parameter object,
- you can use :meth:`Parameter.replace` to create a modified copy.
+ you can use :meth:`Parameter.replace` or :func:`copy.replace` to create a modified copy.
.. versionchanged:: 3.5
Parameter objects are picklable and :term:`hashable`.
@@ -892,6 +895,8 @@ function.
>>> str(param.replace(default=Parameter.empty, annotation='spam'))
"foo:'spam'"
+ Parameter objects are also supported by generic function :func:`copy.replace`.
+
.. versionchanged:: 3.4
In Python 3.3 Parameter objects were allowed to have ``name`` set
to ``None`` if their ``kind`` was set to ``POSITIONAL_ONLY``.
diff --git a/Doc/library/types.rst b/Doc/library/types.rst
index 8cbe17d..82300af 100644
--- a/Doc/library/types.rst
+++ b/Doc/library/types.rst
@@ -200,6 +200,8 @@ Standard names are defined for the following types:
Return a copy of the code object with new values for the specified fields.
+ Code objects are also supported by generic function :func:`copy.replace`.
+
.. versionadded:: 3.8
.. data:: CellType
diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst
index de23172..8c64675 100644
--- a/Doc/whatsnew/3.13.rst
+++ b/Doc/whatsnew/3.13.rst
@@ -115,6 +115,18 @@ array
It can be used instead of ``'u'`` type code, which is deprecated.
(Contributed by Inada Naoki in :gh:`80480`.)
+copy
+----
+
+* Add :func:`copy.replace` function which allows to create a modified copy of
+ an object, which is especially usefule for immutable objects.
+ It supports named tuples created with the factory function
+ :func:`collections.namedtuple`, :class:`~dataclasses.dataclass` instances,
+ various :mod:`datetime` objects, :class:`~inspect.Signature` objects,
+ :class:`~inspect.Parameter` objects, :ref:`code object <code-objects>`, and
+ any user classes which define the :meth:`!__replace__` method.
+ (Contributed by Serhiy Storchaka in :gh:`108751`.)
+
dbm
---
diff --git a/Lib/_pydatetime.py b/Lib/_pydatetime.py
index 549fcda..df616bb 100644
--- a/Lib/_pydatetime.py
+++ b/Lib/_pydatetime.py
@@ -1112,6 +1112,8 @@ class date:
day = self._day
return type(self)(year, month, day)
+ __replace__ = replace
+
# Comparisons of date objects with other.
def __eq__(self, other):
@@ -1637,6 +1639,8 @@ class time:
fold = self._fold
return type(self)(hour, minute, second, microsecond, tzinfo, fold=fold)
+ __replace__ = replace
+
# Pickle support.
def _getstate(self, protocol=3):
@@ -1983,6 +1987,8 @@ class datetime(date):
return type(self)(year, month, day, hour, minute, second,
microsecond, tzinfo, fold=fold)
+ __replace__ = replace
+
def _local_timezone(self):
if self.tzinfo is None:
ts = self._mktime()
diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py
index 8652dc8..a461550 100644
--- a/Lib/collections/__init__.py
+++ b/Lib/collections/__init__.py
@@ -495,6 +495,7 @@ def namedtuple(typename, field_names, *, rename=False, defaults=None, module=Non
'_field_defaults': field_defaults,
'__new__': __new__,
'_make': _make,
+ '__replace__': _replace,
'_replace': _replace,
'__repr__': __repr__,
'_asdict': _asdict,
diff --git a/Lib/copy.py b/Lib/copy.py
index da2908e..6d7bb9a 100644
--- a/Lib/copy.py
+++ b/Lib/copy.py
@@ -290,3 +290,16 @@ def _reconstruct(x, memo, func, args,
return y
del types, weakref
+
+
+def replace(obj, /, **changes):
+ """Return a new object replacing specified fields with new values.
+
+ This is especially useful for immutable objects, like named tuples or
+ frozen dataclasses.
+ """
+ cls = obj.__class__
+ func = getattr(cls, '__replace__', None)
+ if func is None:
+ raise TypeError(f"replace() does not support {cls.__name__} objects")
+ return func(obj, **changes)
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 21f3fa5..84f8d68 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -1073,6 +1073,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
globals,
slots,
))
+ _set_new_attribute(cls, '__replace__', _replace)
# Get the fields as a list, and include only real fields. This is
# used in all of the following methods.
@@ -1546,13 +1547,15 @@ def replace(obj, /, **changes):
c1 = replace(c, x=3)
assert c1.x == 3 and c1.y == 2
"""
+ if not _is_dataclass_instance(obj):
+ raise TypeError("replace() should be called on dataclass instances")
+ return _replace(obj, **changes)
+
+def _replace(obj, /, **changes):
# We're going to mutate 'changes', but that's okay because it's a
# new dict, even if called with 'replace(obj, **my_changes)'.
- if not _is_dataclass_instance(obj):
- raise TypeError("replace() should be called on dataclass instances")
-
# It's an error to have init=False fields in 'changes'.
# If a field is not in 'changes', read its value from the provided obj.
diff --git a/Lib/inspect.py b/Lib/inspect.py
index c821183..aaa22be 100644
--- a/Lib/inspect.py
+++ b/Lib/inspect.py
@@ -2870,6 +2870,8 @@ class Parameter:
return formatted
+ __replace__ = replace
+
def __repr__(self):
return '<{} "{}">'.format(self.__class__.__name__, self)
@@ -3130,6 +3132,8 @@ class Signature:
return type(self)(parameters,
return_annotation=return_annotation)
+ __replace__ = replace
+
def _hash_basis(self):
params = tuple(param for param in self.parameters.values()
if param.kind != _KEYWORD_ONLY)
diff --git a/Lib/test/datetimetester.py b/Lib/test/datetimetester.py
index 55e0619..8bda173 100644
--- a/Lib/test/datetimetester.py
+++ b/Lib/test/datetimetester.py
@@ -1699,22 +1699,23 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
cls = self.theclass
args = [1, 2, 3]
base = cls(*args)
- self.assertEqual(base, base.replace())
+ self.assertEqual(base.replace(), base)
+ self.assertEqual(copy.replace(base), base)
- i = 0
- for name, newval in (("year", 2),
- ("month", 3),
- ("day", 4)):
+ changes = (("year", 2),
+ ("month", 3),
+ ("day", 4))
+ for i, (name, newval) in enumerate(changes):
newargs = args[:]
newargs[i] = newval
expected = cls(*newargs)
- got = base.replace(**{name: newval})
- self.assertEqual(expected, got)
- i += 1
+ self.assertEqual(base.replace(**{name: newval}), expected)
+ self.assertEqual(copy.replace(base, **{name: newval}), expected)
# Out of bounds.
base = cls(2000, 2, 29)
self.assertRaises(ValueError, base.replace, year=2001)
+ self.assertRaises(ValueError, copy.replace, base, year=2001)
def test_subclass_replace(self):
class DateSubclass(self.theclass):
@@ -1722,6 +1723,7 @@ class TestDate(HarmlessMixedComparison, unittest.TestCase):
dt = DateSubclass(2012, 1, 1)
self.assertIs(type(dt.replace(year=2013)), DateSubclass)
+ self.assertIs(type(copy.replace(dt, year=2013)), DateSubclass)
def test_subclass_date(self):
@@ -2856,26 +2858,27 @@ class TestDateTime(TestDate):
cls = self.theclass
args = [1, 2, 3, 4, 5, 6, 7]
base = cls(*args)
- self.assertEqual(base, base.replace())
-
- i = 0
- for name, newval in (("year", 2),
- ("month", 3),
- ("day", 4),
- ("hour", 5),
- ("minute", 6),
- ("second", 7),
- ("microsecond", 8)):
+ self.assertEqual(base.replace(), base)
+ self.assertEqual(copy.replace(base), base)
+
+ changes = (("year", 2),
+ ("month", 3),
+ ("day", 4),
+ ("hour", 5),
+ ("minute", 6),
+ ("second", 7),
+ ("microsecond", 8))
+ for i, (name, newval) in enumerate(changes):
newargs = args[:]
newargs[i] = newval
expected = cls(*newargs)
- got = base.replace(**{name: newval})
- self.assertEqual(expected, got)
- i += 1
+ self.assertEqual(base.replace(**{name: newval}), expected)
+ self.assertEqual(copy.replace(base, **{name: newval}), expected)
# Out of bounds.
base = cls(2000, 2, 29)
self.assertRaises(ValueError, base.replace, year=2001)
+ self.assertRaises(ValueError, copy.replace, base, year=2001)
@support.run_with_tz('EDT4')
def test_astimezone(self):
@@ -3671,19 +3674,19 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
cls = self.theclass
args = [1, 2, 3, 4]
base = cls(*args)
- self.assertEqual(base, base.replace())
-
- i = 0
- for name, newval in (("hour", 5),
- ("minute", 6),
- ("second", 7),
- ("microsecond", 8)):
+ self.assertEqual(base.replace(), base)
+ self.assertEqual(copy.replace(base), base)
+
+ changes = (("hour", 5),
+ ("minute", 6),
+ ("second", 7),
+ ("microsecond", 8))
+ for i, (name, newval) in enumerate(changes):
newargs = args[:]
newargs[i] = newval
expected = cls(*newargs)
- got = base.replace(**{name: newval})
- self.assertEqual(expected, got)
- i += 1
+ self.assertEqual(base.replace(**{name: newval}), expected)
+ self.assertEqual(copy.replace(base, **{name: newval}), expected)
# Out of bounds.
base = cls(1)
@@ -3691,6 +3694,10 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
self.assertRaises(ValueError, base.replace, minute=-1)
self.assertRaises(ValueError, base.replace, second=100)
self.assertRaises(ValueError, base.replace, microsecond=1000000)
+ self.assertRaises(ValueError, copy.replace, base, hour=24)
+ self.assertRaises(ValueError, copy.replace, base, minute=-1)
+ self.assertRaises(ValueError, copy.replace, base, second=100)
+ self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
def test_subclass_replace(self):
class TimeSubclass(self.theclass):
@@ -3698,6 +3705,7 @@ class TestTime(HarmlessMixedComparison, unittest.TestCase):
ctime = TimeSubclass(12, 30)
self.assertIs(type(ctime.replace(hour=10)), TimeSubclass)
+ self.assertIs(type(copy.replace(ctime, hour=10)), TimeSubclass)
def test_subclass_time(self):
@@ -4085,31 +4093,37 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
zm200 = FixedOffset(timedelta(minutes=-200), "-200")
args = [1, 2, 3, 4, z100]
base = cls(*args)
- self.assertEqual(base, base.replace())
-
- i = 0
- for name, newval in (("hour", 5),
- ("minute", 6),
- ("second", 7),
- ("microsecond", 8),
- ("tzinfo", zm200)):
+ self.assertEqual(base.replace(), base)
+ self.assertEqual(copy.replace(base), base)
+
+ changes = (("hour", 5),
+ ("minute", 6),
+ ("second", 7),
+ ("microsecond", 8),
+ ("tzinfo", zm200))
+ for i, (name, newval) in enumerate(changes):
newargs = args[:]
newargs[i] = newval
expected = cls(*newargs)
- got = base.replace(**{name: newval})
- self.assertEqual(expected, got)
- i += 1
+ self.assertEqual(base.replace(**{name: newval}), expected)
+ self.assertEqual(copy.replace(base, **{name: newval}), expected)
# Ensure we can get rid of a tzinfo.
self.assertEqual(base.tzname(), "+100")
base2 = base.replace(tzinfo=None)
self.assertIsNone(base2.tzinfo)
self.assertIsNone(base2.tzname())
+ base22 = copy.replace(base, tzinfo=None)
+ self.assertIsNone(base22.tzinfo)
+ self.assertIsNone(base22.tzname())
# Ensure we can add one.
base3 = base2.replace(tzinfo=z100)
self.assertEqual(base, base3)
self.assertIs(base.tzinfo, base3.tzinfo)
+ base32 = copy.replace(base22, tzinfo=z100)
+ self.assertEqual(base, base32)
+ self.assertIs(base.tzinfo, base32.tzinfo)
# Out of bounds.
base = cls(1)
@@ -4117,6 +4131,10 @@ class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
self.assertRaises(ValueError, base.replace, minute=-1)
self.assertRaises(ValueError, base.replace, second=100)
self.assertRaises(ValueError, base.replace, microsecond=1000000)
+ self.assertRaises(ValueError, copy.replace, base, hour=24)
+ self.assertRaises(ValueError, copy.replace, base, minute=-1)
+ self.assertRaises(ValueError, copy.replace, base, second=100)
+ self.assertRaises(ValueError, copy.replace, base, microsecond=1000000)
def test_mixed_compare(self):
t1 = self.theclass(1, 2, 3)
@@ -4885,38 +4903,45 @@ class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
zm200 = FixedOffset(timedelta(minutes=-200), "-200")
args = [1, 2, 3, 4, 5, 6, 7, z100]
base = cls(*args)
- self.assertEqual(base, base.replace())
-
- i = 0
- for name, newval in (("year", 2),
- ("month", 3),
- ("day", 4),
- ("hour", 5),
- ("minute", 6),
- ("second", 7),
- ("microsecond", 8),
- ("tzinfo", zm200)):
+ self.assertEqual(base.replace(), base)
+ self.assertEqual(copy.replace(base), base)
+
+ changes = (("year", 2),
+ ("month", 3),
+ ("day", 4),
+ ("hour", 5),
+ ("minute", 6),
+ ("second", 7),
+ ("microsecond", 8),
+ ("tzinfo", zm200))
+ for i, (name, newval) in enumerate(changes):
newargs = args[:]
newargs[i] = newval
expected = cls(*newargs)
- got = base.replace(**{name: newval})
- self.assertEqual(expected, got)
- i += 1
+ self.assertEqual(base.replace(**{name: newval}), expected)
+ self.assertEqual(copy.replace(base, **{name: newval}), expected)
# Ensure we can get rid of a tzinfo.
self.assertEqual(base.tzname(), "+100")
base2 = base.replace(tzinfo=None)
self.assertIsNone(base2.tzinfo)
self.assertIsNone(base2.tzname())
+ base22 = copy.replace(base, tzinfo=None)
+ self.assertIsNone(base22.tzinfo)
+ self.assertIsNone(base22.tzname())
# Ensure we can add one.
base3 = base2.replace(tzinfo=z100)
self.assertEqual(base, base3)
self.assertIs(base.tzinfo, base3.tzinfo)
+ base32 = copy.replace(base22, tzinfo=z100)
+ self.assertEqual(base, base32)
+ self.assertIs(base.tzinfo, base32.tzinfo)
# Out of bounds.
base = cls(2000, 2, 29)
self.assertRaises(ValueError, base.replace, year=2001)
+ self.assertRaises(ValueError, copy.replace, base, year=2001)
def test_more_astimezone(self):
# The inherited test_astimezone covered some trivial and error cases.
diff --git a/Lib/test/test_code.py b/Lib/test/test_code.py
index e056c16..812c068 100644
--- a/Lib/test/test_code.py
+++ b/Lib/test/test_code.py
@@ -125,6 +125,7 @@ consts: ('None',)
"""
+import copy
import inspect
import sys
import threading
@@ -280,11 +281,17 @@ class CodeTest(unittest.TestCase):
with self.subTest(attr=attr, value=value):
new_code = code.replace(**{attr: value})
self.assertEqual(getattr(new_code, attr), value)
+ new_code = copy.replace(code, **{attr: value})
+ self.assertEqual(getattr(new_code, attr), value)
new_code = code.replace(co_varnames=code2.co_varnames,
co_nlocals=code2.co_nlocals)
self.assertEqual(new_code.co_varnames, code2.co_varnames)
self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
+ new_code = copy.replace(code, co_varnames=code2.co_varnames,
+ co_nlocals=code2.co_nlocals)
+ self.assertEqual(new_code.co_varnames, code2.co_varnames)
+ self.assertEqual(new_code.co_nlocals, code2.co_nlocals)
def test_nlocals_mismatch(self):
def func():
diff --git a/Lib/test/test_copy.py b/Lib/test/test_copy.py
index 826e468..c66c6ee 100644
--- a/Lib/test/test_copy.py
+++ b/Lib/test/test_copy.py
@@ -4,7 +4,7 @@ import copy
import copyreg
import weakref
import abc
-from operator import le, lt, ge, gt, eq, ne
+from operator import le, lt, ge, gt, eq, ne, attrgetter
import unittest
from test import support
@@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase):
g.b()
+class TestReplace(unittest.TestCase):
+
+ def test_unsupported(self):
+ self.assertRaises(TypeError, copy.replace, 1)
+ self.assertRaises(TypeError, copy.replace, [])
+ self.assertRaises(TypeError, copy.replace, {})
+ def f(): pass
+ self.assertRaises(TypeError, copy.replace, f)
+ class A: pass
+ self.assertRaises(TypeError, copy.replace, A)
+ self.assertRaises(TypeError, copy.replace, A())
+
+ def test_replace_method(self):
+ class A:
+ def __new__(cls, x, y=0):
+ self = object.__new__(cls)
+ self.x = x
+ self.y = y
+ return self
+
+ def __init__(self, *args, **kwargs):
+ self.z = self.x + self.y
+
+ def __replace__(self, **changes):
+ x = changes.get('x', self.x)
+ y = changes.get('y', self.y)
+ return type(self)(x, y)
+
+ attrs = attrgetter('x', 'y', 'z')
+ a = A(11, 22)
+ self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
+ self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
+ self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
+ self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
+
+ def test_namedtuple(self):
+ from collections import namedtuple
+ Point = namedtuple('Point', 'x y', defaults=(0,))
+ p = Point(11, 22)
+ self.assertEqual(copy.replace(p), (11, 22))
+ self.assertEqual(copy.replace(p, x=1), (1, 22))
+ self.assertEqual(copy.replace(p, y=2), (11, 2))
+ self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
+ with self.assertRaisesRegex(ValueError, 'unexpected field name'):
+ copy.replace(p, x=1, error=2)
+
+ def test_dataclass(self):
+ from dataclasses import dataclass
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+
+ attrs = attrgetter('x', 'y')
+ c = C(11, 22)
+ self.assertEqual(attrs(copy.replace(c)), (11, 22))
+ self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
+ self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
+ self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
+ with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
+ copy.replace(c, x=1, error=2)
+
+
def global_foo(x, y): return x+y
+
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/test/test_inspect.py b/Lib/test/test_inspect.py
index 78ef817..2fb356a 100644
--- a/Lib/test/test_inspect.py
+++ b/Lib/test/test_inspect.py
@@ -1,6 +1,7 @@
import asyncio
import builtins
import collections
+import copy
import datetime
import functools
import importlib
@@ -3830,6 +3831,28 @@ class TestSignatureObject(unittest.TestCase):
P('bar', P.VAR_POSITIONAL)])),
'(foo, /, *bar)')
+ def test_signature_replace_parameters(self):
+ def test(a, b) -> 42:
+ pass
+
+ sig = inspect.signature(test)
+ parameters = sig.parameters
+ sig = sig.replace(parameters=list(parameters.values())[1:])
+ self.assertEqual(list(sig.parameters), ['b'])
+ self.assertEqual(sig.parameters['b'], parameters['b'])
+ self.assertEqual(sig.return_annotation, 42)
+ sig = sig.replace(parameters=())
+ self.assertEqual(dict(sig.parameters), {})
+
+ sig = inspect.signature(test)
+ parameters = sig.parameters
+ sig = copy.replace(sig, parameters=list(parameters.values())[1:])
+ self.assertEqual(list(sig.parameters), ['b'])
+ self.assertEqual(sig.parameters['b'], parameters['b'])
+ self.assertEqual(sig.return_annotation, 42)
+ sig = copy.replace(sig, parameters=())
+ self.assertEqual(dict(sig.parameters), {})
+
def test_signature_replace_anno(self):
def test() -> 42:
pass
@@ -3843,6 +3866,15 @@ class TestSignatureObject(unittest.TestCase):
self.assertEqual(sig.return_annotation, 42)
self.assertEqual(sig, inspect.signature(test))
+ sig = inspect.signature(test)
+ sig = copy.replace(sig, return_annotation=None)
+ self.assertIs(sig.return_annotation, None)
+ sig = copy.replace(sig, return_annotation=sig.empty)
+ self.assertIs(sig.return_annotation, sig.empty)
+ sig = copy.replace(sig, return_annotation=42)
+ self.assertEqual(sig.return_annotation, 42)
+ self.assertEqual(sig, inspect.signature(test))
+
def test_signature_replaced(self):
def test():
pass
@@ -4187,41 +4219,66 @@ class TestParameterObject(unittest.TestCase):
p = inspect.Parameter('foo', default=42,
kind=inspect.Parameter.KEYWORD_ONLY)
- self.assertIsNot(p, p.replace())
- self.assertEqual(p, p.replace())
+ self.assertIsNot(p.replace(), p)
+ self.assertEqual(p.replace(), p)
+ self.assertIsNot(copy.replace(p), p)
+ self.assertEqual(copy.replace(p), p)
p2 = p.replace(annotation=1)
self.assertEqual(p2.annotation, 1)
p2 = p2.replace(annotation=p2.empty)
- self.assertEqual(p, p2)
+ self.assertEqual(p2, p)
+ p3 = copy.replace(p, annotation=1)
+ self.assertEqual(p3.annotation, 1)
+ p3 = copy.replace(p3, annotation=p3.empty)
+ self.assertEqual(p3, p)
p2 = p2.replace(name='bar')
self.assertEqual(p2.name, 'bar')
self.assertNotEqual(p2, p)
+ p3 = copy.replace(p3, name='bar')
+ self.assertEqual(p3.name, 'bar')
+ self.assertNotEqual(p3, p)
with self.assertRaisesRegex(ValueError,
'name is a required attribute'):
p2 = p2.replace(name=p2.empty)
+ with self.assertRaisesRegex(ValueError,
+ 'name is a required attribute'):
+ p3 = copy.replace(p3, name=p3.empty)
p2 = p2.replace(name='foo', default=None)
self.assertIs(p2.default, None)
self.assertNotEqual(p2, p)
+ p3 = copy.replace(p3, name='foo', default=None)
+ self.assertIs(p3.default, None)
+ self.assertNotEqual(p3, p)
p2 = p2.replace(name='foo', default=p2.empty)
self.assertIs(p2.default, p2.empty)
-
+ p3 = copy.replace(p3, name='foo', default=p3.empty)
+ self.assertIs(p3.default, p3.empty)
p2 = p2.replace(default=42, kind=p2.POSITIONAL_OR_KEYWORD)
self.assertEqual(p2.kind, p2.POSITIONAL_OR_KEYWORD)
self.assertNotEqual(p2, p)
+ p3 = copy.replace(p3, default=42, kind=p3.POSITIONAL_OR_KEYWORD)
+ self.assertEqual(p3.kind, p3.POSITIONAL_OR_KEYWORD)
+ self.assertNotEqual(p3, p)
with self.assertRaisesRegex(ValueError,
"value <class 'inspect._empty'> "
"is not a valid Parameter.kind"):
p2 = p2.replace(kind=p2.empty)
+ with self.assertRaisesRegex(ValueError,
+ "value <class 'inspect._empty'> "
+ "is not a valid Parameter.kind"):
+ p3 = copy.replace(p3, kind=p3.empty)
p2 = p2.replace(kind=p2.KEYWORD_ONLY)
self.assertEqual(p2, p)
+ p3 = copy.replace(p3, kind=p3.KEYWORD_ONLY)
+ self.assertEqual(p3, p)
def test_signature_parameter_positional_only(self):
with self.assertRaisesRegex(TypeError, 'name must be a str'):
diff --git a/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst b/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst
new file mode 100644
index 0000000..7bc21fe
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-09-01-13-14-08.gh-issue-108751.2itqwe.rst
@@ -0,0 +1,2 @@
+Add :func:`copy.replace` function which allows to create a modified copy of
+an object. It supports named tuples, dataclasses, and many other objects.
diff --git a/Modules/_datetimemodule.c b/Modules/_datetimemodule.c
index 191db3f..0d35677 100644
--- a/Modules/_datetimemodule.c
+++ b/Modules/_datetimemodule.c
@@ -3590,6 +3590,8 @@ static PyMethodDef date_methods[] = {
{"replace", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Return date with new specified fields.")},
+ {"__replace__", _PyCFunction_CAST(date_replace), METH_VARARGS | METH_KEYWORDS},
+
{"__reduce__", (PyCFunction)date_reduce, METH_NOARGS,
PyDoc_STR("__reduce__() -> (cls, state)")},
@@ -4719,6 +4721,8 @@ static PyMethodDef time_methods[] = {
{"replace", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Return time with new specified fields.")},
+ {"__replace__", _PyCFunction_CAST(time_replace), METH_VARARGS | METH_KEYWORDS},
+
{"fromisoformat", (PyCFunction)time_fromisoformat, METH_O | METH_CLASS,
PyDoc_STR("string -> time from a string in ISO 8601 format")},
@@ -6579,6 +6583,8 @@ static PyMethodDef datetime_methods[] = {
{"replace", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("Return datetime with new specified fields.")},
+ {"__replace__", _PyCFunction_CAST(datetime_replace), METH_VARARGS | METH_KEYWORDS},
+
{"astimezone", _PyCFunction_CAST(datetime_astimezone), METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("tz -> convert to local time in new timezone tz\n")},
diff --git a/Objects/codeobject.c b/Objects/codeobject.c
index 70a0c2e..5830607 100644
--- a/Objects/codeobject.c
+++ b/Objects/codeobject.c
@@ -2145,6 +2145,7 @@ static struct PyMethodDef code_methods[] = {
{"co_positions", (PyCFunction)code_positionsiterator, METH_NOARGS},
CODE_REPLACE_METHODDEF
CODE__VARNAME_FROM_OPARG_METHODDEF
+ {"__replace__", _PyCFunction_CAST(code_replace), METH_FASTCALL|METH_KEYWORDS},
{NULL, NULL} /* sentinel */
};