Skip to content

Commit

Permalink
[3.12] Check for valid tp_version_tag in specializer (pythongh-89811) (
Browse files Browse the repository at this point in the history
  • Loading branch information
lazorchakp authored Jan 19, 2024
1 parent ffac6ac commit ae2a25b
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 2 deletions.
146 changes: 144 additions & 2 deletions Lib/test/test_type_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Tests for the internal type cache in CPython. """
import unittest
import dis
from test import support
from test.support import import_helper
try:
Expand All @@ -8,8 +9,11 @@
_clear_type_cache = None

# Skip this test if the _testcapi module isn't available.
type_get_version = import_helper.import_module('_testcapi').type_get_version
type_assign_version = import_helper.import_module('_testcapi').type_assign_version
_testcapi = import_helper.import_module("_testcapi")
type_get_version = _testcapi.type_get_version
type_assign_specific_version_unsafe = _testcapi.type_assign_specific_version_unsafe
type_assign_version = _testcapi.type_assign_version
type_modified = _testcapi.type_modified


@support.cpython_only
Expand Down Expand Up @@ -56,6 +60,144 @@ class C:
self.assertNotEqual(type_get_version(C), 0)
self.assertNotEqual(type_get_version(C), c_ver)

def test_type_assign_specific_version(self):
"""meta-test for type_assign_specific_version_unsafe"""
class C:
pass

type_assign_version(C)
orig_version = type_get_version(C)
if orig_version == 0:
self.skipTest("Could not assign a valid type version")

type_modified(C)
type_assign_specific_version_unsafe(C, orig_version + 5)
type_assign_version(C) # this should do nothing

new_version = type_get_version(C)
self.assertEqual(new_version, orig_version + 5)

_clear_type_cache()


@support.cpython_only
class TypeCacheWithSpecializationTests(unittest.TestCase):
def tearDown(self):
_clear_type_cache()

def _assign_valid_version_or_skip(self, type_):
type_modified(type_)
type_assign_version(type_)
if type_get_version(type_) == 0:
self.skipTest("Could not assign valid type version")

def _assign_and_check_version_0(self, user_type):
type_modified(user_type)
type_assign_specific_version_unsafe(user_type, 0)
self.assertEqual(type_get_version(user_type), 0)

def _all_opnames(self, func):
return set(instr.opname for instr in dis.Bytecode(func, adaptive=True))

def _check_specialization(self, func, arg, opname, *, should_specialize):
for _ in range(100):
func(arg)

if should_specialize:
self.assertNotIn(opname, self._all_opnames(func))
else:
self.assertIn(opname, self._all_opnames(func))

def test_class_load_attr_specialization_user_type(self):
class A:
def foo(self):
pass

self._assign_valid_version_or_skip(A)

def load_foo_1(type_):
type_.foo

self._check_specialization(load_foo_1, A, "LOAD_ATTR", should_specialize=True)
del load_foo_1

self._assign_and_check_version_0(A)

def load_foo_2(type_):
return type_.foo

self._check_specialization(load_foo_2, A, "LOAD_ATTR", should_specialize=False)

def test_class_load_attr_specialization_static_type(self):
self._assign_valid_version_or_skip(str)
self._assign_valid_version_or_skip(bytes)

def get_capitalize_1(type_):
return type_.capitalize

self._check_specialization(get_capitalize_1, str, "LOAD_ATTR", should_specialize=True)
self.assertEqual(get_capitalize_1(str)('hello'), 'Hello')
self.assertEqual(get_capitalize_1(bytes)(b'hello'), b'Hello')
del get_capitalize_1

# Permanently overflow the static type version counter, and force str and bytes
# to have tp_version_tag == 0
for _ in range(2**16):
type_modified(str)
type_assign_version(str)
type_modified(bytes)
type_assign_version(bytes)

self.assertEqual(type_get_version(str), 0)
self.assertEqual(type_get_version(bytes), 0)

def get_capitalize_2(type_):
return type_.capitalize

self._check_specialization(get_capitalize_2, str, "LOAD_ATTR", should_specialize=False)
self.assertEqual(get_capitalize_2(str)('hello'), 'Hello')
self.assertEqual(get_capitalize_2(bytes)(b'hello'), b'Hello')

def test_property_load_attr_specialization_user_type(self):
class G:
@property
def x(self):
return 9

self._assign_valid_version_or_skip(G)

def load_x_1(instance):
instance.x

self._check_specialization(load_x_1, G(), "LOAD_ATTR", should_specialize=True)
del load_x_1

self._assign_and_check_version_0(G)

def load_x_2(instance):
instance.x

self._check_specialization(load_x_2, G(), "LOAD_ATTR", should_specialize=False)

def test_store_attr_specialization_user_type(self):
class B:
__slots__ = ("bar",)

self._assign_valid_version_or_skip(B)

def store_bar_1(type_):
type_.bar = 10

self._check_specialization(store_bar_1, B(), "STORE_ATTR", should_specialize=True)
del store_bar_1

self._assign_and_check_version_0(B)

def store_bar_2(type_):
type_.bar = 10

self._check_specialization(store_bar_2, B(), "STORE_ATTR", should_specialize=False)


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Check for a valid ``tp_version_tag`` before performing bytecode specializations that
rely on this value being usable.
29 changes: 29 additions & 0 deletions Modules/_testcapimodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -2530,6 +2530,32 @@ type_get_version(PyObject *self, PyObject *type)
return res;
}

static PyObject *
type_modified(PyObject *self, PyObject *type)
{
if (!PyType_Check(type)) {
PyErr_SetString(PyExc_TypeError, "argument must be a type");
return NULL;
}
PyType_Modified((PyTypeObject *)type);
Py_RETURN_NONE;
}

// Circumvents standard version assignment machinery - use with caution and only on
// short-lived heap types
static PyObject *
type_assign_specific_version_unsafe(PyObject *self, PyObject *args)
{
PyTypeObject *type;
unsigned int version;
if (!PyArg_ParseTuple(args, "Oi:type_assign_specific_version_unsafe", &type, &version)) {
return NULL;
}
assert(!PyType_HasFeature(type, Py_TPFLAGS_IMMUTABLETYPE));
type->tp_version_tag = version;
type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
Py_RETURN_NONE;
}

static PyObject *
type_assign_version(PyObject *self, PyObject *type)
Expand Down Expand Up @@ -3357,6 +3383,9 @@ static PyMethodDef TestMethods[] = {
{"test_py_is_macros", test_py_is_macros, METH_NOARGS},
{"test_py_is_funcs", test_py_is_funcs, METH_NOARGS},
{"type_get_version", type_get_version, METH_O, PyDoc_STR("type->tp_version_tag")},
{"type_modified", type_modified, METH_O, PyDoc_STR("PyType_Modified")},
{"type_assign_specific_version_unsafe", type_assign_specific_version_unsafe, METH_VARARGS,
PyDoc_STR("forcefully assign type->tp_version_tag")},
{"type_assign_version", type_assign_version, METH_O, PyDoc_STR("PyUnstable_Type_AssignVersionTag")},
{"type_get_tp_bases", type_get_tp_bases, METH_O},
{"type_get_tp_mro", type_get_tp_mro, METH_O},
Expand Down
22 changes: 22 additions & 0 deletions Python/specialize.c
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ _PyCode_Quicken(PyCodeObject *code)
static int function_kind(PyCodeObject *code);
static bool function_check_args(PyObject *o, int expected_argcount, int opcode);
static uint32_t function_get_version(PyObject *o, int opcode);
static uint32_t type_get_version(PyTypeObject *t, int opcode);

static int
specialize_module_load_attr(
Expand Down Expand Up @@ -746,6 +747,9 @@ _Py_Specialize_LoadAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
PyObject *descr = NULL;
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 0);
assert(descr != NULL || kind == ABSENT || kind == GETSET_OVERRIDDEN);
if (type_get_version(type, LOAD_ATTR) == 0) {
goto fail;
}
switch(kind) {
case OVERRIDING:
SPECIALIZATION_FAIL(LOAD_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
Expand Down Expand Up @@ -917,6 +921,9 @@ _Py_Specialize_StoreAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
}
PyObject *descr;
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 1);
if (type_get_version(type, STORE_ATTR) == 0) {
goto fail;
}
switch(kind) {
case OVERRIDING:
SPECIALIZATION_FAIL(STORE_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
Expand Down Expand Up @@ -1043,6 +1050,9 @@ specialize_class_load_attr(PyObject *owner, _Py_CODEUNIT *instr,
PyObject *descr = NULL;
DescriptorClassification kind = 0;
kind = analyze_descriptor((PyTypeObject *)owner, name, &descr, 0);
if (type_get_version((PyTypeObject *)owner, LOAD_ATTR) == 0) {
return -1;
}
switch (kind) {
case METHOD:
case NON_DESCRIPTOR:
Expand Down Expand Up @@ -1317,6 +1327,18 @@ function_get_version(PyObject *o, int opcode)
return version;
}

/* Returning 0 indicates a failure. */
static uint32_t
type_get_version(PyTypeObject *t, int opcode)
{
uint32_t version = t->tp_version_tag;
if (version == 0) {
SPECIALIZATION_FAIL(opcode, SPEC_FAIL_OUT_OF_VERSIONS);
return 0;
}
return version;
}

void
_Py_Specialize_BinarySubscr(
PyObject *container, PyObject *sub, _Py_CODEUNIT *instr)
Expand Down

0 comments on commit ae2a25b

Please sign in to comment.