diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py index a08038b5dbd407..ef20495dcc01ea 100644 --- a/Lib/test/test_context.py +++ b/Lib/test/test_context.py @@ -556,6 +556,36 @@ def fun(): ctx.run(fun) + def test_context_eq_reentrant_contextvar_set(self): + var = contextvars.ContextVar("v") + ctx1 = contextvars.Context() + ctx2 = contextvars.Context() + + class ReentrantEq: + def __eq__(self, other): + ctx1.run(lambda: var.set(object())) + return True + + ctx1.run(var.set, ReentrantEq()) + ctx2.run(var.set, object()) + ctx1 == ctx2 + + def test_context_eq_reentrant_contextvar_set_in_hash(self): + var = contextvars.ContextVar("v") + ctx1 = contextvars.Context() + ctx2 = contextvars.Context() + + class ReentrantHash: + def __hash__(self): + ctx1.run(lambda: var.set(object())) + return 0 + def __eq__(self, other): + return isinstance(other, ReentrantHash) + + ctx1.run(var.set, ReentrantHash()) + ctx2.run(var.set, ReentrantHash()) + ctx1 == ctx2 + # HAMT Tests diff --git a/Misc/NEWS.d/next/Core_and_Builtins/2025-12-17-19-45-10.gh-issue-142829.ICtLXy.rst b/Misc/NEWS.d/next/Core_and_Builtins/2025-12-17-19-45-10.gh-issue-142829.ICtLXy.rst new file mode 100644 index 00000000000000..b85003071ac188 --- /dev/null +++ b/Misc/NEWS.d/next/Core_and_Builtins/2025-12-17-19-45-10.gh-issue-142829.ICtLXy.rst @@ -0,0 +1,3 @@ +Fix a use-after-free crash in :class:`contextvars.Context` comparison when a +custom ``__eq__`` method modifies the context via +:meth:`~contextvars.ContextVar.set`. diff --git a/Python/hamt.c b/Python/hamt.c index e372b1a1b4c18b..881290a0e60db8 100644 --- a/Python/hamt.c +++ b/Python/hamt.c @@ -2328,6 +2328,10 @@ _PyHamt_Eq(PyHamtObject *v, PyHamtObject *w) return 0; } + Py_INCREF(v); + Py_INCREF(w); + + int res = 1; PyHamtIteratorState iter; hamt_iter_t iter_res; hamt_find_t find_res; @@ -2343,25 +2347,38 @@ _PyHamt_Eq(PyHamtObject *v, PyHamtObject *w) find_res = hamt_find(w, v_key, &w_val); switch (find_res) { case F_ERROR: - return -1; + res = -1; + goto done; case F_NOT_FOUND: - return 0; + res = 0; + goto done; case F_FOUND: { + Py_INCREF(v_key); + Py_INCREF(v_val); + Py_INCREF(w_val); int cmp = PyObject_RichCompareBool(v_val, w_val, Py_EQ); + Py_DECREF(v_key); + Py_DECREF(v_val); + Py_DECREF(w_val); if (cmp < 0) { - return -1; + res = -1; + goto done; } if (cmp == 0) { - return 0; + res = 0; + goto done; } } } } } while (iter_res != I_END); - return 1; +done: + Py_DECREF(v); + Py_DECREF(w); + return res; } Py_ssize_t