summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_type_cache.py158
-rw-r--r--Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst2
-rw-r--r--Modules/_testcapimodule.c29
-rw-r--r--Python/specialize.c28
4 files changed, 216 insertions, 1 deletions
diff --git a/Lib/test/test_type_cache.py b/Lib/test/test_type_cache.py
index 9dc91dc..11158f9 100644
--- a/Lib/test/test_type_cache.py
+++ b/Lib/test/test_type_cache.py
@@ -1,5 +1,6 @@
""" Tests for the internal type cache in CPython. """
import unittest
+import dis
from test import support
from test.support import import_helper
try:
@@ -8,7 +9,17 @@ except ImportError:
_clear_type_cache = None
# Skip this test if the _testcapi module isn't available.
-type_get_version = import_helper.import_module('_testcapi').type_get_version
+_testcapi = import_helper.import_module("_testcapi")
+type_get_version = _testcapi.type_get_version
+type_assign_specific_version_unsafe = _testcapi.type_assign_specific_version_unsafe
+type_modified = _testcapi.type_modified
+
+
+def type_assign_version(type_):
+ try:
+ type_.x
+ except AttributeError:
+ pass
@support.cpython_only
@@ -42,6 +53,151 @@ class TypeCacheTests(unittest.TestCase):
self.assertEqual(len(set(all_version_tags)), 30,
msg=f"{all_version_tags} contains non-unique versions")
+ def test_type_assign_specific_version(self):
+ """meta-test for type_assign_specific_version_unsafe"""
+ class C:
+ pass
+
+ type_assign_version(C)
+ orig_version = type_get_version(C)
+ if orig_version == 0:
+ self.skipTest("Could not assign a valid type version")
+
+ type_modified(C)
+ type_assign_specific_version_unsafe(C, orig_version + 5)
+ type_assign_version(C) # this should do nothing
+
+ new_version = type_get_version(C)
+ self.assertEqual(new_version, orig_version + 5)
+
+ _clear_type_cache()
+
+
+@support.cpython_only
+class TypeCacheWithSpecializationTests(unittest.TestCase):
+ def tearDown(self):
+ _clear_type_cache()
+
+ def _assign_valid_version_or_skip(self, type_):
+ type_modified(type_)
+ type_assign_version(type_)
+ if type_get_version(type_) == 0:
+ self.skipTest("Could not assign valid type version")
+
+ def _assign_and_check_version_0(self, user_type):
+ type_modified(user_type)
+ type_assign_specific_version_unsafe(user_type, 0)
+ self.assertEqual(type_get_version(user_type), 0)
+
+ def _all_opnames(self, func):
+ return set(instr.opname for instr in dis.Bytecode(func, adaptive=True))
+
+ def _check_specialization(self, func, arg, opname, *, should_specialize):
+ for _ in range(100):
+ func(arg)
+
+ if should_specialize:
+ self.assertNotIn(opname, self._all_opnames(func))
+ else:
+ self.assertIn(opname, self._all_opnames(func))
+
+ def test_load_method_specialization_user_type(self):
+ class A:
+ def foo(self):
+ pass
+
+ self._assign_valid_version_or_skip(A)
+
+ def load_foo_1(instance):
+ instance.foo()
+
+ self._check_specialization(
+ load_foo_1, A(), "LOAD_METHOD_ADAPTIVE", should_specialize=True
+ )
+ del load_foo_1
+
+ self._assign_and_check_version_0(A)
+
+ def load_foo_2(instance):
+ instance.foo()
+
+ self._check_specialization(
+ load_foo_2, A(), "LOAD_METHOD_ADAPTIVE", should_specialize=False
+ )
+
+ def test_store_attr_specialization_user_type(self):
+ class B:
+ __slots__ = ("bar",)
+
+ self._assign_valid_version_or_skip(B)
+
+ def store_bar_1(instance):
+ instance.bar = 10
+
+ self._check_specialization(
+ store_bar_1, B(), "STORE_ATTR_ADAPTIVE", should_specialize=True
+ )
+ del store_bar_1
+
+ self._assign_and_check_version_0(B)
+
+ def store_bar_2(instance):
+ instance.bar = 10
+
+ self._check_specialization(
+ store_bar_2, B(), "STORE_ATTR_ADAPTIVE", should_specialize=False
+ )
+
+ def test_load_attr_specialization_user_type(self):
+ class C:
+ __slots__ = ("biz",)
+ def __init__(self):
+ self.biz = 8
+
+ self._assign_valid_version_or_skip(C)
+
+ def load_biz_1(type_):
+ type_.biz
+
+ self._check_specialization(
+ load_biz_1, C(), "LOAD_ATTR_ADAPTIVE", should_specialize=True
+ )
+ del load_biz_1
+
+ self._assign_and_check_version_0(C)
+
+ def load_biz_2(type_):
+ type_.biz
+
+ self._check_specialization(
+ load_biz_2, C(), "LOAD_ATTR_ADAPTIVE", should_specialize=False
+ )
+
+ def test_binary_subscript_specialization_user_type(self):
+ class D:
+ def __getitem__(self, _):
+ return 1
+
+ self._assign_valid_version_or_skip(D)
+
+ def subscript_1(instance):
+ instance[6]
+
+ self._check_specialization(
+ subscript_1, D(), "BINARY_SUBSCR_ADAPTIVE", should_specialize=True
+ )
+ del subscript_1
+
+ self._assign_and_check_version_0(D)
+
+ def subscript_2(instance):
+ instance[6]
+
+ self._check_specialization(
+ subscript_2, D(), "BINARY_SUBSCR_ADAPTIVE", should_specialize=False
+ )
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst b/Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst
new file mode 100644
index 0000000..90bd981
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2024-01-03-12-19-37.gh-issue-89811.cZOj6d.rst
@@ -0,0 +1,2 @@
+Check for a valid ``tp_version_tag`` before performing bytecode specializations that
+rely on this value being usable.
diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c
index 2f1801f..a4411d1 100644
--- a/Modules/_testcapimodule.c
+++ b/Modules/_testcapimodule.c
@@ -5899,6 +5899,32 @@ type_get_version(PyObject *self, PyObject *type)
return res;
}
+static PyObject *
+type_modified(PyObject *self, PyObject *type)
+{
+ if (!PyType_Check(type)) {
+ PyErr_SetString(PyExc_TypeError, "argument must be a type");
+ return NULL;
+ }
+ PyType_Modified((PyTypeObject *)type);
+ Py_RETURN_NONE;
+}
+
+// Circumvents standard version assignment machinery - use with caution and only on
+// short-lived heap types
+static PyObject *
+type_assign_specific_version_unsafe(PyObject *self, PyObject *args)
+{
+ PyTypeObject *type;
+ unsigned int version;
+ if (!PyArg_ParseTuple(args, "Oi:type_assign_specific_version_unsafe", &type, &version)) {
+ return NULL;
+ }
+ assert(!PyType_HasFeature(type, Py_TPFLAGS_IMMUTABLETYPE));
+ type->tp_version_tag = version;
+ type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
+ Py_RETURN_NONE;
+}
// Test PyThreadState C API
static PyObject *
@@ -6782,6 +6808,9 @@ static PyMethodDef TestMethods[] = {
{"fatal_error", test_fatal_error, METH_VARARGS,
PyDoc_STR("fatal_error(message, release_gil=False): call Py_FatalError(message)")},
{"type_get_version", type_get_version, METH_O, PyDoc_STR("type->tp_version_tag")},
+ {"type_modified", type_modified, METH_O, PyDoc_STR("PyType_Modified")},
+ {"type_assign_specific_version_unsafe", type_assign_specific_version_unsafe, METH_VARARGS,
+ PyDoc_STR("forcefully assign type->tp_version_tag")},
{"test_tstate_capi", test_tstate_capi, METH_NOARGS, NULL},
{"float_pack", test_float_pack, METH_VARARGS, NULL},
{"float_unpack", test_float_unpack, METH_VARARGS, NULL},
diff --git a/Python/specialize.c b/Python/specialize.c
index 4a5213c..3441e84 100644
--- a/Python/specialize.c
+++ b/Python/specialize.c
@@ -481,6 +481,7 @@ miss_counter_start(void) {
#define SPEC_FAIL_UNPACK_SEQUENCE_ITERATOR 8
#define SPEC_FAIL_UNPACK_SEQUENCE_SEQUENCE 9
+static uint32_t type_get_version(PyTypeObject *t, int opcode);
static int
specialize_module_load_attr(PyObject *owner, _Py_CODEUNIT *instr,
@@ -673,6 +674,9 @@ _Py_Specialize_LoadAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
}
PyObject *descr;
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 0);
+ if (type_get_version(type, LOAD_ATTR) == 0) {
+ goto fail;
+ }
switch(kind) {
case OVERRIDING:
SPECIALIZATION_FAIL(LOAD_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -766,6 +770,9 @@ _Py_Specialize_StoreAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
}
PyObject *descr;
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 1);
+ if (type_get_version(type, STORE_ATTR) == 0) {
+ goto fail;
+ }
switch(kind) {
case OVERRIDING:
SPECIALIZATION_FAIL(STORE_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -889,6 +896,9 @@ specialize_class_load_method(PyObject *owner, _Py_CODEUNIT *instr,
PyObject *descr = NULL;
DescriptorClassification kind = 0;
kind = analyze_descriptor((PyTypeObject *)owner, name, &descr, 0);
+ if (type_get_version((PyTypeObject *)owner, LOAD_METHOD) == 0) {
+ return -1;
+ }
switch (kind) {
case METHOD:
case NON_DESCRIPTOR:
@@ -950,6 +960,9 @@ _Py_Specialize_LoadMethod(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
PyObject *descr = NULL;
DescriptorClassification kind = 0;
kind = analyze_descriptor(owner_cls, name, &descr, 0);
+ if (type_get_version(owner_cls, LOAD_METHOD) == 0) {
+ goto fail;
+ }
assert(descr != NULL || kind == ABSENT || kind == GETSET_OVERRIDDEN);
if (kind != METHOD) {
SPECIALIZATION_FAIL(LOAD_METHOD, load_method_fail_kind(kind));
@@ -1183,6 +1196,18 @@ function_kind(PyCodeObject *code) {
return SIMPLE_FUNCTION;
}
+/* Returning 0 indicates a failure. */
+static uint32_t
+type_get_version(PyTypeObject *t, int opcode)
+{
+ uint32_t version = t->tp_version_tag;
+ if (version == 0) {
+ SPECIALIZATION_FAIL(opcode, SPEC_FAIL_OUT_OF_VERSIONS);
+ return 0;
+ }
+ return version;
+}
+
int
_Py_Specialize_BinarySubscr(
PyObject *container, PyObject *sub, _Py_CODEUNIT *instr)
@@ -1231,6 +1256,9 @@ _Py_Specialize_BinarySubscr(
SPECIALIZATION_FAIL(BINARY_SUBSCR, SPEC_FAIL_WRONG_NUMBER_ARGUMENTS);
goto fail;
}
+ if (type_get_version(cls, BINARY_SUBSCR) == 0) {
+ goto fail;
+ }
assert(cls->tp_version_tag != 0);
write_u32(cache->type_version, cls->tp_version_tag);
int version = _PyFunction_GetVersionForCurrentState(func);