summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Doc/c-api/dict.rst20
-rw-r--r--Doc/whatsnew/3.13.rst6
-rw-r--r--Include/cpython/dictobject.h10
-rw-r--r--Lib/test/test_capi/test_dict.py22
-rw-r--r--Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst5
-rw-r--r--Modules/_testcapi/dict.c26
-rw-r--r--Objects/dictobject.c91
7 files changed, 160 insertions, 20 deletions
diff --git a/Doc/c-api/dict.rst b/Doc/c-api/dict.rst
index 8471c98..03f3d28 100644
--- a/Doc/c-api/dict.rst
+++ b/Doc/c-api/dict.rst
@@ -174,6 +174,26 @@ Dictionary Objects
.. versionadded:: 3.4
+.. c:function:: int PyDict_SetDefaultRef(PyObject *p, PyObject *key, PyObject *default_value, PyObject **result)
+
+ Inserts *default_value* into the dictionary *p* with a key of *key* if the
+ key is not already present in the dictionary. If *result* is not ``NULL``,
+ then *\*result* is set to a :term:`strong reference` to either
+ *default_value*, if the key was not present, or the existing value, if *key*
+ was already present in the dictionary.
+ Returns ``1`` if the key was present and *default_value* was not inserted,
+ or ``0`` if the key was not present and *default_value* was inserted.
+ On failure, returns ``-1``, sets an exception, and sets ``*result``
+ to ``NULL``.
+
+ For clarity: if you have a strong reference to *default_value* before
+ calling this function, then after it returns, you hold a strong reference
+ to both *default_value* and *\*result* (if it's not ``NULL``).
+ These may refer to the same object: in that case you hold two separate
+ references to it.
+ .. versionadded:: 3.13
+
+
.. c:function:: int PyDict_Pop(PyObject *p, PyObject *key, PyObject **result)
Remove *key* from dictionary *p* and optionally return the removed value.
diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst
index 3727577..e034d34 100644
--- a/Doc/whatsnew/3.13.rst
+++ b/Doc/whatsnew/3.13.rst
@@ -1440,6 +1440,12 @@ New Features
not needed.
(Contributed by Victor Stinner in :gh:`106004`.)
+* Added :c:func:`PyDict_SetDefaultRef`, which is similar to
+ :c:func:`PyDict_SetDefault` but returns a :term:`strong reference` instead of
+ a :term:`borrowed reference`. This function returns ``-1`` on error, ``0`` on
+ insertion, and ``1`` if the key was already present in the dictionary.
+ (Contributed by Sam Gross in :gh:`112066`.)
+
* Add :c:func:`PyDict_ContainsString` function: same as
:c:func:`PyDict_Contains`, but *key* is specified as a :c:expr:`const char*`
UTF-8 encoded bytes string, rather than a :c:expr:`PyObject*`.
diff --git a/Include/cpython/dictobject.h b/Include/cpython/dictobject.h
index 1720fe6..35b6a82 100644
--- a/Include/cpython/dictobject.h
+++ b/Include/cpython/dictobject.h
@@ -41,6 +41,16 @@ PyAPI_FUNC(PyObject *) _PyDict_GetItemStringWithError(PyObject *, const char *);
PyAPI_FUNC(PyObject *) PyDict_SetDefault(
PyObject *mp, PyObject *key, PyObject *defaultobj);
+// Inserts `key` with a value `default_value`, if `key` is not already present
+// in the dictionary. If `result` is not NULL, then the value associated
+// with `key` is returned in `*result` (either the existing value, or the now
+// inserted `default_value`).
+// Returns:
+// -1 on error
+// 0 if `key` was not present and `default_value` was inserted
+// 1 if `key` was present and `default_value` was not inserted
+PyAPI_FUNC(int) PyDict_SetDefaultRef(PyObject *mp, PyObject *key, PyObject *default_value, PyObject **result);
+
/* Get the number of items of a dictionary. */
static inline Py_ssize_t PyDict_GET_SIZE(PyObject *op) {
PyDictObject *mp;
diff --git a/Lib/test/test_capi/test_dict.py b/Lib/test/test_capi/test_dict.py
index 57a7238..cca6145 100644
--- a/Lib/test/test_capi/test_dict.py
+++ b/Lib/test/test_capi/test_dict.py
@@ -339,6 +339,28 @@ class CAPITest(unittest.TestCase):
# CRASHES setdefault({}, 'a', NULL)
# CRASHES setdefault(NULL, 'a', 5)
+ def test_dict_setdefaultref(self):
+ setdefault = _testcapi.dict_setdefaultref
+ dct = {}
+ self.assertEqual(setdefault(dct, 'a', 5), 5)
+ self.assertEqual(dct, {'a': 5})
+ self.assertEqual(setdefault(dct, 'a', 8), 5)
+ self.assertEqual(dct, {'a': 5})
+
+ dct2 = DictSubclass()
+ self.assertEqual(setdefault(dct2, 'a', 5), 5)
+ self.assertEqual(dct2, {'a': 5})
+ self.assertEqual(setdefault(dct2, 'a', 8), 5)
+ self.assertEqual(dct2, {'a': 5})
+
+ self.assertRaises(TypeError, setdefault, {}, [], 5) # unhashable
+ self.assertRaises(SystemError, setdefault, UserDict(), 'a', 5)
+ self.assertRaises(SystemError, setdefault, [1], 0, 5)
+ self.assertRaises(SystemError, setdefault, 42, 'a', 5)
+ # CRASHES setdefault({}, NULL, 5)
+ # CRASHES setdefault({}, 'a', NULL)
+ # CRASHES setdefault(NULL, 'a', 5)
+
def test_mapping_keys_valuesitems(self):
class BadMapping(dict):
def keys(self):
diff --git a/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst b/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst
new file mode 100644
index 0000000..ae2b8b2
--- /dev/null
+++ b/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst
@@ -0,0 +1,5 @@
+Add :c:func:`PyDict_SetDefaultRef`: insert a key and value into a dictionary
+if the key is not already present. This is similar to
+:meth:`dict.setdefault`, but returns an integer value indicating if the key
+was already present. It is also similar to :c:func:`PyDict_SetDefault`, but
+returns a strong reference instead of a borrowed reference.
diff --git a/Modules/_testcapi/dict.c b/Modules/_testcapi/dict.c
index 42e056b..fe03c24 100644
--- a/Modules/_testcapi/dict.c
+++ b/Modules/_testcapi/dict.c
@@ -226,6 +226,31 @@ dict_setdefault(PyObject *self, PyObject *args)
}
static PyObject *
+dict_setdefaultref(PyObject *self, PyObject *args)
+{
+ PyObject *obj, *key, *default_value, *result = UNINITIALIZED_PTR;
+ if (!PyArg_ParseTuple(args, "OOO", &obj, &key, &default_value)) {
+ return NULL;
+ }
+ NULLABLE(obj);
+ NULLABLE(key);
+ NULLABLE(default_value);
+ switch (PyDict_SetDefaultRef(obj, key, default_value, &result)) {
+ case -1:
+ assert(result == NULL);
+ return NULL;
+ case 0:
+ assert(result == default_value);
+ return result;
+ case 1:
+ return result;
+ default:
+ Py_FatalError("PyDict_SetDefaultRef() returned invalid code");
+ Py_UNREACHABLE();
+ }
+}
+
+static PyObject *
dict_delitem(PyObject *self, PyObject *args)
{
PyObject *mapping, *key;
@@ -433,6 +458,7 @@ static PyMethodDef test_methods[] = {
{"dict_delitem", dict_delitem, METH_VARARGS},
{"dict_delitemstring", dict_delitemstring, METH_VARARGS},
{"dict_setdefault", dict_setdefault, METH_VARARGS},
+ {"dict_setdefaultref", dict_setdefaultref, METH_VARARGS},
{"dict_keys", dict_keys, METH_O},
{"dict_values", dict_values, METH_O},
{"dict_items", dict_items, METH_O},
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 4bb818b..11b388d 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -3355,8 +3355,9 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value)
return Py_NewRef(val);
}
-PyObject *
-PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
+static int
+dict_setdefault_ref(PyObject *d, PyObject *key, PyObject *default_value,
+ PyObject **result, int incref_result)
{
PyDictObject *mp = (PyDictObject *)d;
PyObject *value;
@@ -3365,41 +3366,64 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
if (!PyDict_Check(d)) {
PyErr_BadInternalCall();
- return NULL;
+ if (result) {
+ *result = NULL;
+ }
+ return -1;
}
if (!PyUnicode_CheckExact(key) || (hash = unicode_get_hash(key)) == -1) {
hash = PyObject_Hash(key);
- if (hash == -1)
- return NULL;
+ if (hash == -1) {
+ if (result) {
+ *result = NULL;
+ }
+ return -1;
+ }
}
if (mp->ma_keys == Py_EMPTY_KEYS) {
if (insert_to_emptydict(interp, mp, Py_NewRef(key), hash,
- Py_NewRef(defaultobj)) < 0) {
- return NULL;
+ Py_NewRef(default_value)) < 0) {
+ if (result) {
+ *result = NULL;
+ }
+ return -1;
+ }
+ if (result) {
+ *result = incref_result ? Py_NewRef(default_value) : default_value;
}
- return defaultobj;
+ return 0;
}
if (!PyUnicode_CheckExact(key) && DK_IS_UNICODE(mp->ma_keys)) {
if (insertion_resize(interp, mp, 0) < 0) {
- return NULL;
+ if (result) {
+ *result = NULL;
+ }
+ return -1;
}
}
Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &value);
- if (ix == DKIX_ERROR)
- return NULL;
+ if (ix == DKIX_ERROR) {
+ if (result) {
+ *result = NULL;
+ }
+ return -1;
+ }
if (ix == DKIX_EMPTY) {
uint64_t new_version = _PyDict_NotifyEvent(
- interp, PyDict_EVENT_ADDED, mp, key, defaultobj);
+ interp, PyDict_EVENT_ADDED, mp, key, default_value);
mp->ma_keys->dk_version = 0;
- value = defaultobj;
+ value = default_value;
if (mp->ma_keys->dk_usable <= 0) {
if (insertion_resize(interp, mp, 1) < 0) {
- return NULL;
+ if (result) {
+ *result = NULL;
+ }
+ return -1;
}
}
Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash);
@@ -3431,11 +3455,16 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
mp->ma_keys->dk_usable--;
mp->ma_keys->dk_nentries++;
assert(mp->ma_keys->dk_usable >= 0);
+ ASSERT_CONSISTENT(mp);
+ if (result) {
+ *result = incref_result ? Py_NewRef(value) : value;
+ }
+ return 0;
}
else if (value == NULL) {
uint64_t new_version = _PyDict_NotifyEvent(
- interp, PyDict_EVENT_ADDED, mp, key, defaultobj);
- value = defaultobj;
+ interp, PyDict_EVENT_ADDED, mp, key, default_value);
+ value = default_value;
assert(_PyDict_HasSplitTable(mp));
assert(mp->ma_values->values[ix] == NULL);
MAINTAIN_TRACKING(mp, key, value);
@@ -3443,10 +3472,33 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
_PyDictValues_AddToInsertionOrder(mp->ma_values, ix);
mp->ma_used++;
mp->ma_version_tag = new_version;
+ ASSERT_CONSISTENT(mp);
+ if (result) {
+ *result = incref_result ? Py_NewRef(value) : value;
+ }
+ return 0;
}
ASSERT_CONSISTENT(mp);
- return value;
+ if (result) {
+ *result = incref_result ? Py_NewRef(value) : value;
+ }
+ return 1;
+}
+
+int
+PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value,
+ PyObject **result)
+{
+ return dict_setdefault_ref(d, key, default_value, result, 1);
+}
+
+PyObject *
+PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj)
+{
+ PyObject *result;
+ dict_setdefault_ref(d, key, defaultobj, &result, 0);
+ return result;
}
/*[clinic input]
@@ -3467,9 +3519,8 @@ dict_setdefault_impl(PyDictObject *self, PyObject *key,
/*[clinic end generated code: output=f8c1101ebf69e220 input=0f063756e815fd9d]*/
{
PyObject *val;
-
- val = PyDict_SetDefault((PyObject *)self, key, default_value);
- return Py_XNewRef(val);
+ PyDict_SetDefaultRef((PyObject *)self, key, default_value, &val);
+ return val;
}