Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
fc910a3
Hide that C recursion protection is implemented with a counter. There…
markshannon Feb 10, 2025
afeb866
Implement C recursion protection with limit pointers
markshannon Feb 11, 2025
22ca169
Use uintptr_t instead of char * to avoid warnings and UB
markshannon Feb 11, 2025
e8d8c4b
Merge branch 'main' into c-recursion-limit
markshannon Feb 11, 2025
b3638a5
Fix typo and update stable ABI
markshannon Feb 11, 2025
774efb5
Tweak AST test numbers
markshannon Feb 11, 2025
151c88f
Improve logic handling trial C stack overflow
markshannon Feb 11, 2025
350f8ec
Remove calls to PyOS_CheckStack
markshannon Feb 11, 2025
428c46a
Up the limits for recursion tests
markshannon Feb 11, 2025
a9be141
Use deeper stack for test
markshannon Feb 12, 2025
03fc52e
Remove exceeds_recursion_limit and get_c_recursion_limit. Use platfor…
markshannon Feb 12, 2025
afac1e6
Do fewer probes when growing stack limit
markshannon Feb 12, 2025
9da904d
Tweak depths
markshannon Feb 12, 2025
a802ff6
Merge branch 'main' into c-recursion-limit
markshannon Feb 12, 2025
dbcf6f0
Perform lazy initialization of c recursion check
markshannon Feb 12, 2025
2cc3287
Post merge fixup
markshannon Feb 12, 2025
e697926
Up depth again
markshannon Feb 12, 2025
f8a9143
Drop 'failing' depth
markshannon Feb 12, 2025
7d6d77f
Add news
markshannon Feb 12, 2025
31a83dc
Increase headroom
markshannon Feb 12, 2025
47c50aa
Update test
markshannon Feb 12, 2025
495c4ea
Tweak some more thresholds and tests
markshannon Feb 12, 2025
9e0cc67
Add stack protection to parser
markshannon Feb 13, 2025
857a7bb
Make tests more robust to low stacks
markshannon Feb 13, 2025
9c9326a
Improve error messages for stack overflow
markshannon Feb 13, 2025
75d3219
Merge branch 'main' into c-recursion-limit
markshannon Feb 13, 2025
3e41b46
Fix formatting
markshannon Feb 13, 2025
158401a
Halve size of WASI stack
markshannon Feb 13, 2025
c1eb229
Reduce webassembly 'stack size' by 10
markshannon Feb 13, 2025
7407d2b
Halve size of WASI stack, again
markshannon Feb 13, 2025
978b5e7
Change WASI stack back to 100k
markshannon Feb 13, 2025
7b36f59
Add many skip tests for WASI due to stack issues
markshannon Feb 13, 2025
5e5db03
Probe all pages when extending stack limits
markshannon Feb 14, 2025
82173ed
Fix compiler warnings
markshannon Feb 14, 2025
64cfd86
Use GetCurrentThreadStackLimits instead of probing with alloca
markshannon Feb 14, 2025
e52137f
Refactor a bit
markshannon Feb 14, 2025
704c336
Merge branch 'main' into c-recursion-limit
markshannon Feb 14, 2025
21366c3
Fix logic error in test
markshannon Feb 17, 2025
7761d31
Make ABI function needed for Py_TRASHCAN_BEGIN private
markshannon Feb 17, 2025
b067e3e
Move new fields to _PyThreadStateImpl to avoid any potential API brea…
markshannon Feb 17, 2025
b0e695f
Fix missing cast
markshannon Feb 17, 2025
c4cd68f
yet another missing _
markshannon Feb 17, 2025
9654790
Tidy up a bit
markshannon Feb 18, 2025
2cef96f
Restore use of exceeds_recursion_limit
markshannon Feb 18, 2025
c5d8a40
Address remaining review comments
markshannon Feb 18, 2025
33e11c8
Merge branch 'main' into c-recursion-limit
markshannon Feb 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Hide that C recursion protection is implemented with a counter. There is an imbalance in the AST somewhere.
  • Loading branch information
markshannon committed Feb 11, 2025
commit fc910a355198973988cb73fb2bae94604c84800d
4 changes: 4 additions & 0 deletions Include/ceval.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ PyAPI_FUNC(int) Py_GetRecursionLimit(void);
PyAPI_FUNC(int) Py_EnterRecursiveCall(const char *where);
PyAPI_FUNC(void) Py_LeaveRecursiveCall(void);

PyAPI_FUNC(int) Py_ReachedRecursionLimit(PyThreadState *tstate, int margin_count);
PyAPI_FUNC(void) _Py_EnterRecursiveCallUnchecked(PyThreadState *tstate);
PyAPI_FUNC(void) Py_LeaveRecursiveCallTstate(PyThreadState *tstate);

PyAPI_FUNC(const char *) PyEval_GetFuncName(PyObject *);
PyAPI_FUNC(const char *) PyEval_GetFuncDesc(PyObject *);

Expand Down
8 changes: 4 additions & 4 deletions Include/cpython/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,15 +490,15 @@ PyAPI_FUNC(void) _PyTrash_thread_destroy_chain(PyThreadState *tstate);
#define Py_TRASHCAN_BEGIN(op, dealloc) \
do { \
PyThreadState *tstate = PyThreadState_Get(); \
if (tstate->c_recursion_remaining <= Py_TRASHCAN_HEADROOM && Py_TYPE(op)->tp_dealloc == (destructor)dealloc) { \
if (Py_ReachedRecursionLimit(tstate, 1) && Py_TYPE(op)->tp_dealloc == (destructor)dealloc) { \
_PyTrash_thread_deposit_object(tstate, (PyObject *)op); \
break; \
} \
tstate->c_recursion_remaining--;
_Py_EnterRecursiveCallUnchecked(tstate);
/* The body of the deallocator is here. */
#define Py_TRASHCAN_END \
tstate->c_recursion_remaining++; \
if (tstate->delete_later && tstate->c_recursion_remaining > (Py_TRASHCAN_HEADROOM*2)) { \
Py_LeaveRecursiveCallTstate(tstate); \
if (tstate->delete_later && !Py_ReachedRecursionLimit(tstate, 2)) { \
_PyTrash_thread_destroy_chain(tstate); \
} \
} while (0);
Expand Down
8 changes: 7 additions & 1 deletion Include/internal/pycore_ceval.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ static inline int _Py_EnterRecursiveCallTstate(PyThreadState *tstate,
}

static inline void _Py_EnterRecursiveCallTstateUnchecked(PyThreadState *tstate) {
assert(tstate->c_recursion_remaining > 0);
assert(tstate->c_recursion_remaining >= -2); // Allow a bit of wiggle room
tstate->c_recursion_remaining--;
}

Expand All @@ -234,6 +234,12 @@ static inline void _Py_LeaveRecursiveCallTstate(PyThreadState *tstate) {
tstate->c_recursion_remaining++;
}

#define Py_RECURSION_LIMIT_MARGIN_MULTIPLIER 50

/* Report whether the thread's remaining C recursion budget is at or below
 * a safety margin.
 *
 * The margin is margin_count * Py_RECURSION_LIMIT_MARGIN_MULTIPLIER counter
 * steps.  Returns 1 when tstate->c_recursion_remaining is less than or equal
 * to that margin, and 0 otherwise. */
static inline int _Py_ReachedRecursionLimit(PyThreadState *tstate, int margin_count) {
    const int margin = margin_count * Py_RECURSION_LIMIT_MARGIN_MULTIPLIER;
    if (tstate->c_recursion_remaining > margin) {
        return 0;
    }
    return 1;
}

static inline void _Py_LeaveRecursiveCall(void) {
PyThreadState *tstate = _PyThreadState_GET();
_Py_LeaveRecursiveCallTstate(tstate);
Expand Down
2 changes: 0 additions & 2 deletions Include/internal/pycore_symtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ struct symtable {
PyObject *st_private; /* name of current class or NULL */
_PyFutureFeatures *st_future; /* module's future features that affect
the symbol table */
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};

typedef struct _symtable_entry {
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_ast/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def next(self):
def test_ast_recursion_limit(self):
fail_depth = support.exceeds_recursion_limit()
crash_depth = 100_000
success_depth = int(support.get_c_recursion_limit() * 0.8)
success_depth = int(support.get_c_recursion_limit() * 0.6)
if _testinternalcapi is not None:
remaining = _testinternalcapi.get_c_recursion_remaining()
success_depth = min(success_depth, remaining)
Expand Down
2 changes: 1 addition & 1 deletion Lib/test/test_capi/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def test_trashcan_subclass(self):
# activated when its tp_dealloc is being called by a subclass
from _testcapi import MyList
L = None
for i in range(1000):
for i in range(support.get_c_recursion_limit()):
L = MyList((L,))

@support.requires_resource('cpu')
Expand Down
5 changes: 2 additions & 3 deletions Objects/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -2910,8 +2910,7 @@ _PyTrash_thread_destroy_chain(PyThreadState *tstate)
tups = [(tup,) for tup in tups]
del tups
*/
assert(tstate->c_recursion_remaining > Py_TRASHCAN_HEADROOM);
tstate->c_recursion_remaining--;
_Py_EnterRecursiveCallTstateUnchecked(tstate);
while (tstate->delete_later) {
PyObject *op = tstate->delete_later;
destructor dealloc = Py_TYPE(op)->tp_dealloc;
Expand All @@ -2933,7 +2932,7 @@ _PyTrash_thread_destroy_chain(PyThreadState *tstate)
_PyObject_ASSERT(op, Py_REFCNT(op) == 0);
(*dealloc)(op);
}
tstate->c_recursion_remaining++;
_Py_LeaveRecursiveCallTstate(tstate);
}

void _Py_NO_RETURN
Expand Down
63 changes: 18 additions & 45 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ def emit_sequence_constructor(self, name, type):
class PyTypesDeclareVisitor(PickleVisitor):

def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes:
Expand All @@ -759,7 +759,7 @@ def visitSum(self, sum, name):
ptype = "void*"
if is_simple(sum):
ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
for t in sum.types:
self.visitConstructor(t, name)

Expand Down Expand Up @@ -1734,16 +1734,16 @@ def visitModule(self, mod):

/* Conversion AST -> Python */

static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, void*))
{
Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n);
PyObject *value;
if (!result)
return NULL;
for (i = 0; i < n; i++) {
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
if (!value) {
Py_DECREF(result);
return NULL;
Expand All @@ -1753,7 +1753,7 @@ def visitModule(self, mod):
return result;
}

static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
{
PyObject *op = (PyObject*)o;
if (!op) {
Expand All @@ -1765,7 +1765,7 @@ def visitModule(self, mod):
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object

static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
{
return PyLong_FromLong(b);
}
Expand Down Expand Up @@ -2014,25 +2014,23 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name):
ctype = get_c_type(name)
self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1)
self.emit("PyTypeObject *tp;", 1)
self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1)
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit('if (Py_EnterRecursiveCall("during ast construction")) {', 1)
self.emit("return NULL;", 2)
self.emit("}", 1)

def func_end(self):
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_LeaveRecursiveCall();", 1)
self.emit("return result;", 1)
self.emit("failed:", 0)
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_LeaveRecursiveCall();", 1)
self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1)
Expand All @@ -2050,15 +2048,15 @@ def visitSum(self, sum, name):
self.visitConstructor(t, i + 1, name)
self.emit("}", 1)
for a in sum.attributes:
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2)
self.emit('Py_DECREF(value);', 1)
self.func_end()

def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
self.emit("{", 0)
self.emit("switch(o) {", 1)
for t in sum.types:
Expand All @@ -2076,7 +2074,7 @@ def visitProduct(self, prod, name):
for field in prod.fields:
self.visitField(field, name, 1, True)
for a in prod.attributes:
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2)
Expand Down Expand Up @@ -2117,7 +2115,7 @@ def set(self, field, value, depth):
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type,
value
),
Expand All @@ -2126,9 +2124,9 @@ def set(self, field, value, depth):
)
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else:
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)


class PartingShots(StaticVisitor):
Expand All @@ -2140,28 +2138,8 @@ class PartingShots(StaticVisitor):
if (state == NULL) {
return NULL;
}
PyObject *result = ast2obj_mod(state, t);

int starting_recursion_depth;
/* Be careful here to prevent overflow. */
PyThreadState *tstate = _PyThreadState_GET();
if (!tstate) {
return NULL;
}
struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth;
vstate.recursion_depth = starting_recursion_depth;

PyObject *result = ast2obj_mod(state, &vstate, t);

/* Check that the recursion depth counting balanced correctly */
if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, vstate.recursion_depth);
return NULL;
}
return result;
}

Expand Down Expand Up @@ -2293,11 +2271,6 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "structmember.h"
#include <stddef.h>

struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};

// Forward declaration
static int init_types(void *arg);

Expand Down
Loading