1 #define PY_SSIZE_T_CLEAN
2 #include <torch/csrc/dynamo/cache_entry.h>
3 #include <torch/csrc/dynamo/cpp_shim.h>
4 #include <torch/csrc/dynamo/cpython_defs.h>
5 #include <torch/csrc/dynamo/cpython_includes.h>
6 #include <torch/csrc/dynamo/debug_macros.h>
7 #include <torch/csrc/dynamo/extra_state.h>
8 #include <torch/csrc/dynamo/framelocals_mapping.h>
9 #include <torch/csrc/utils/python_compat.h>
10 #include <opcode.h>
11 #include <stdbool.h>
12
// Capacity (in bytes) of the compile_context buffer below.
#define MAX_COMPILE_CONTEXT_SIZE 100

// Hook object for guard errors; NULL until set.
// NOTE(review): presumably installed by set_guard_error_hook, whose body is
// outside the visible part of this file — confirm.
PyObject* guard_error_hook = NULL;
// Name of the profiler event recorded around cache lookups in
// _custom_eval_frame.
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
// Context string handed to the "Torch-Compiled Region" profiler event in
// eval_custom_code. NOTE(review): written outside the visible section.
static char compile_context[MAX_COMPILE_CONTEXT_SIZE];
// Number of threads whose callback is not None. Maintained by
// increment/decrement_working_threads; the eval-frame shim is installed while
// this is positive.
static int active_dynamo_threads = 0;

// Thread-specific slot holding each thread's Dynamo callback
// (None / False / callable). An unset slot reads back as Py_None.
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
21
eval_frame_callback_get(void)22 inline static PyObject* eval_frame_callback_get(void) {
23 void* result = PyThread_tss_get(&eval_frame_callback_key);
24 if (unlikely(result == NULL)) {
25 return (PyObject*)Py_None;
26 } else {
27 return (PyObject*)result;
28 }
29 }
30
// Stores `obj` as the calling thread's Dynamo callback. No reference counting
// happens here, and the status code of PyThread_tss_set is discarded.
inline static void eval_frame_callback_set(PyObject* obj) {
  void* slot_value = obj;
  (void)PyThread_tss_set(&eval_frame_callback_key, slot_value);
}
34
35 // 3.14 Not supported at all. See cpython_defs.c for hints
36 #if !(IS_PYTHON_3_14_PLUS)
37
// All the eval APIs change in 3.11, so we select which one to use at compile time
39 // https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
40 #if IS_PYTHON_3_11_PLUS
41 #define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
42
43 // We need to be able to return the _PyInterpreterFrame to python so create
44 // a python binding for it
45
// Python-visible wrapper around a CPython _PyInterpreterFrame, so the frame
// can be passed to the Dynamo callback (see call_callback).
typedef struct THPPyInterpreterFrame {
  PyObject_HEAD
  _PyInterpreterFrame* frame; // Borrowed reference: the wrapper does not keep the frame alive
  // Locals mapping attached by call_callback. Not owned: it is assigned
  // without Py_INCREF and this type defines no tp_dealloc that releases it.
  // NULL for wrappers that only went through THPPyInterpreterFrame_New.
  PyObject* locals;
} THPPyInterpreterFrame;

// Allocates a wrapper around `frame` with locals == NULL; NULL on failure.
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
53
54 #define DECLARE_PYOBJ_ATTR(name) \
55 static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
56 PyObject* res = (PyObject*)self->frame->name; \
57 Py_XINCREF(res); \
58 return res; \
59 }
60
61 #if IS_PYTHON_3_12_PLUS
62 DECLARE_PYOBJ_ATTR(f_funcobj)
63 #else
64 DECLARE_PYOBJ_ATTR(f_func)
65 #endif
66
DECLARE_PYOBJ_ATTR(f_globals)67 DECLARE_PYOBJ_ATTR(f_globals)
68 DECLARE_PYOBJ_ATTR(f_builtins)
69
70 static PyObject* THPPyInterpreterFrame_f_locals(THPPyInterpreterFrame* self, PyObject* _noargs) {
71 DEBUG_NULL_CHECK(self->locals);
72 Py_XINCREF(self->locals);
73 return self->locals;
74 }
75
76 #if IS_PYTHON_3_13_PLUS
77 DECLARE_PYOBJ_ATTR(f_executable)
78 #else
79 DECLARE_PYOBJ_ATTR(f_code)
80 #endif
81
DECLARE_PYOBJ_ATTR(frame_obj)82 DECLARE_PYOBJ_ATTR(frame_obj)
83
84 #undef DECLARE_PYOBJ_ATTR
85
86 static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
87 THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
88 return res;
89 }
90
91 // This is not a true attribute of the class but we do access it in python and it is hard to implement
92 // on the python side, so do it here:
THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame * self,PyObject * _noargs)93 static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
94 return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
95 }
96
// Current line of the wrapped frame. Without a materialized PyFrameObject the
// code object's first line is reported; otherwise PyFrame_GetLineNumber is
// used, and a negative answer from it becomes None.
static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) {
  PyFrameObject* frame_obj = self->frame->frame_obj;
  int line;
  if (frame_obj == NULL) {
    line = F_CODE(self->frame)->co_firstlineno;
  } else {
    line = PyFrame_GetLineNumber(frame_obj);
    if (line < 0) {
      Py_RETURN_NONE;
    }
  }
  return PyLong_FromLong(line);
}
107
// Getter for f_back: a new reference to the caller's PyFrameObject, or None
// when this frame has no PyFrameObject or no outer frame.
// PyFrame_GetBack returns NULL *without* setting an exception for the
// outermost frame. A getter returning NULL with no exception set makes CPython
// raise SystemError, so that case is mapped to None as well.
static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) {
  if (!self->frame->frame_obj) {
    Py_RETURN_NONE;
  }
  PyFrameObject* back = PyFrame_GetBack(self->frame->frame_obj);
  if (back == NULL) {
    Py_RETURN_NONE;
  }
  return (PyObject*)back;
}
114
// Read-only attributes exposed on _PyInterpreterFrame wrappers (every entry
// has a getter and no setter). "f_func" and "f_code" map to whichever CPython
// field holds that object in the running version.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
#if IS_PYTHON_3_12_PLUS
    {"f_func", (getter)THPPyInterpreterFrame_f_funcobj, NULL, NULL, NULL},
#else
    {"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
#endif
    {"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
    {"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
    {"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
#if IS_PYTHON_3_13_PLUS
    {"f_code", (getter)THPPyInterpreterFrame_f_executable, NULL, NULL, NULL},
#else
    {"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
#endif
    {"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
    {"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
    {"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
    {"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
    {"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
    {NULL}}; // sentinel
136
// Static type object for the wrapper. Only the name, size, default flags and
// getset table are set here (no tp_new, no tp_dealloc).
// NOTE(review): PyType_Ready is presumably called on it during module init,
// outside the visible part of this file — confirm.
static PyTypeObject THPPyInterpreterFrameType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame",
    .tp_basicsize = sizeof(THPPyInterpreterFrame),
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_getset = THPPyInterpreterFrame_properties,
};
144
145
THPPyInterpreterFrame_New(_PyInterpreterFrame * frame)146 THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
147 PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
148 THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
149 if (!self)
150 return NULL;
151 self->frame = frame;
152 self->locals = NULL;
153 return self;
154 }
155
156
157 #else
158 #define THP_EVAL_API_FRAME_OBJECT PyFrameObject
159
160 static int
THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT * frame,int * free_vars_copied)161 THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_vars_copied) {
162 return PyFrame_FastToLocalsWithError(frame);
163 }
164 #endif
165
// Forward declarations: both are defined further down but referenced earlier.
static PyObject* _custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag);
static PyObject* _custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback,
    int* should_clear_frame);
// Evaluator that was active before custom_eval_frame_shim was installed
// (3.9+ only). eval_frame_default chains to it and enable_eval_frame_default
// restores it; NULL whenever the shim is not installed.
static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
    THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) = NULL;
178
179 #if PY_VERSION_HEX >= 0x03090000
custom_eval_frame_shim(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)180 static PyObject* custom_eval_frame_shim(
181 PyThreadState* tstate,
182 THP_EVAL_API_FRAME_OBJECT* frame,
183 int throw_flag) {
184 return _custom_eval_frame_shim(tstate, frame, throw_flag);
185 }
186 #else
custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)187 static PyObject* custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) {
188 PyThreadState* tstate = PyThreadState_GET();
189 return _custom_eval_frame_shim(tstate, frame, throw_flag);
190 }
191 #endif
192
eval_frame_default(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)193 inline static PyObject* eval_frame_default(
194 PyThreadState* tstate,
195 THP_EVAL_API_FRAME_OBJECT* frame,
196 int throw_flag) {
197 #if PY_VERSION_HEX >= 0x03090000
198 if (tstate == NULL) {
199 tstate = PyThreadState_GET();
200 }
201 if (previous_eval_frame) {
202 return previous_eval_frame(tstate, frame, throw_flag);
203 }
204 else {
205 return _PyEval_EvalFrameDefault(tstate, frame, throw_flag);
206 }
207 #else
208 return _PyEval_EvalFrameDefault(frame, throw_flag);
209 #endif
210 }
211
// Installs custom_eval_frame_shim as the interpreter's frame evaluator.
// On 3.9+ the evaluator it replaces is saved in previous_eval_frame.
// Does nothing if the shim is already installed.
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  PyInterpreterState* interp = tstate->interp;
  if (_PyInterpreterState_GetEvalFrameFunc(interp) == &custom_eval_frame_shim) {
    return;
  }
  DEBUG_CHECK(previous_eval_frame == NULL);
  previous_eval_frame = _PyInterpreterState_GetEvalFrameFunc(interp);
  _PyInterpreterState_SetEvalFrameFunc(interp, &custom_eval_frame_shim);
#else
  if (tstate->interp->eval_frame == &custom_eval_frame_shim) {
    return;
  }
  tstate->interp->eval_frame = &custom_eval_frame_shim;
#endif
}
228
// Removes the shim. On 3.9+ it reinstalls the evaluator saved in
// previous_eval_frame and clears that pointer; before 3.9 it reinstalls
// _PyEval_EvalFrameDefault. Does nothing if that evaluator is already active.
inline static void enable_eval_frame_default(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  PyInterpreterState* interp = tstate->interp;
  if (_PyInterpreterState_GetEvalFrameFunc(interp) == previous_eval_frame) {
    return;
  }
  DEBUG_CHECK(previous_eval_frame != NULL);
  _PyInterpreterState_SetEvalFrameFunc(interp, previous_eval_frame);
  previous_eval_frame = NULL;
#else
  if (tstate->interp->eval_frame == &_PyEval_EvalFrameDefault) {
    return;
  }
  tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
#endif
}
245
246
// Returns the code object's name as a UTF-8 C string, via PyUnicode_AsUTF8
// (NULL if that conversion fails).
inline static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
  PyObject* const code_name = F_CODE(frame)->co_name;
  DEBUG_CHECK(PyUnicode_Check(code_name));
  return PyUnicode_AsUTF8(code_name);
}
252
// Invokes the Dynamo callback as callable(frame, cache_entry, frame_state).
// On 3.11+ the frame is passed as a THPPyInterpreterFrame wrapper carrying
// `locals`; earlier versions pass the PyFrameObject itself. Returns the
// callback's result (new reference) or NULL on failure.
static inline PyObject* call_callback(
    PyObject* callable,
    THP_EVAL_API_FRAME_OBJECT* _frame,
    PyObject* locals,
    CacheEntry* cache_entry,
    FrameState* frame_state) {
  // remember to update the type signature for DynamoCallbackFn.__call__ in torch/_dynamo/types.py
  // if this function changes
#if IS_PYTHON_3_11_PLUS
  THPPyInterpreterFrame* wrapped = THPPyInterpreterFrame_New(_frame);
  if (wrapped == NULL) {
    return NULL;
  }
  wrapped->locals = locals;
  PyObject* frame_arg = (PyObject*)wrapped;
#else
  PyObject* frame_arg = Py_NewRef(_frame);
#endif

  PyObject* entry_arg = CacheEntry_to_obj(cache_entry);
  PyObject* outcome = PyObject_CallFunction(
      callable, "OOO", frame_arg, entry_arg, frame_state);
  Py_DECREF(frame_arg);
  Py_DECREF(entry_arg);
  return outcome;
}
283
// On 3.12+, runs THP_PyFrame_Clear on `frame` and then
// THP_PyThreadState_PopFrame for it on `tstate` (see the NOTE above
// _custom_eval_frame for when this is needed). Does nothing on older versions.
static inline void clear_old_frame_if_python_312_plus(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame) {
#if !(IS_PYTHON_3_12_PLUS)
  (void)tstate;
  (void)frame;
#else
  THP_PyFrame_Clear(frame);
  THP_PyThreadState_PopFrame(tstate, frame);
#endif
}
294
// Evaluates `code` in place of the code object `frame` was created for.
// A new "shadow" frame is built for `code`; the arguments, cell variables and
// free variables of `frame` are copied into it by index, and the shadow frame
// is run through eval_frame_default.
//
// tstate, frame, code: expected non-NULL (checked only via DEBUG_NULL_CHECK).
// throw_flag: forwarded unchanged to eval_frame_default.
// free_vars_copied: (3.11+) non-zero when THP_PyFrame_FastToLocalsWithError
//   already filled the free-variable slots, in which case the shadow frame's
//   leading COPY_FREE_VARS instruction is skipped.
// Returns the shadow frame's result, or NULL on failure. On 3.12+ the
// original `frame` is left for the caller to clear.
inline static PyObject* eval_custom_code_impl(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    PyCodeObject* code,
    int throw_flag,
    int free_vars_copied) {

  DEBUG_NULL_CHECK(tstate);
  DEBUG_NULL_CHECK(frame);
  DEBUG_NULL_CHECK(code);

#if IS_PYTHON_3_11_PLUS

  // Generate Python function object and _PyInterpreterFrame in a way similar to
  // https://github.com/python/cpython/blob/e715da6db1d1d70cd779dc48e1ba8110c51cc1bf/Python/ceval.c#L1130
#if IS_PYTHON_3_12_PLUS
  PyFunctionObject* old_func = (PyFunctionObject*) frame->f_funcobj;
  size_t size = code->co_framesize;
#else
  PyFunctionObject* old_func = frame->f_func;
  size_t size = code->co_nlocalsplus + code->co_stacksize + FRAME_SPECIALS_SIZE;
#endif

  // Function object derived from old_func but carrying `code` (owned here).
  PyFunctionObject* func = _PyFunction_CopyWithNewCode(old_func, code);
  if (func == NULL) {
    return NULL;
  }

  // Allocate the shadow frame with room for `size` slots.
  THP_EVAL_API_FRAME_OBJECT* shadow = THP_PyThreadState_BumpFramePointerSlow(tstate, size);
  if (shadow == NULL) {
    Py_DECREF(func);
    return NULL;
  }

  // Second reference for the frame; the one we own is released after evaluation.
  Py_INCREF(func);
  // consumes reference to func
#if IS_PYTHON_3_12_PLUS
  _PyFrame_Initialize(shadow, func, NULL, code, 0);
#else
  _PyFrame_InitializeSpecials(shadow, func, NULL, code->co_nlocalsplus);
#endif

  PyObject** fastlocals_old = frame->localsplus;
  PyObject** fastlocals_new = shadow->localsplus;
  Py_ssize_t n_old = F_CODE(frame)->co_nlocalsplus;
  Py_ssize_t n_new = code->co_nlocalsplus;

  // localsplus are XINCREF'd by default eval frame, so all values must be valid
  // (either NULL or a real object).
#if !(IS_PYTHON_3_12_PLUS)
  // _PyFrame_Initialize in 3.12 already does this
  for (int i = 0; i < code->co_nlocalsplus; i++) {
    fastlocals_new[i] = NULL;
  }
#endif

  // for 3.11+, if free_vars_copied is true, we do not need to
  // run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
  // already did the equivalent action.
  if (free_vars_copied && _Py_OPCODE(_PyCode_CODE(F_CODE(shadow))[0]) == COPY_FREE_VARS) {
    PREV_INSTR(shadow) = _PyCode_CODE(F_CODE(shadow));
  }

#else

  THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
  if (shadow == NULL) {
    return NULL;
  }

  PyObject** fastlocals_old = frame->f_localsplus;
  PyObject** fastlocals_new = shadow->f_localsplus;
  Py_ssize_t n_old = F_CODE(frame)->co_nlocals + PyCode_GetNFreevars(F_CODE(frame)) + PyCode_GetNCellvars(F_CODE(frame));
  Py_ssize_t n_new = code->co_nlocals + PyCode_GetNFreevars(code) + PyCode_GetNCellvars(code);

#endif

  // ============== Initialize new frame from old frame ============
  // Python internal for executing a function:
  //  1. CPython interpreter first creates an empty frame according to the code object
  //  2. CPython interpreter initializes the frame by filling arguments/free variables into frame and initializing cell variables
  //  3. CPython interpreter executes the code object
  //
  // Dynamo hooks the 3rd step: before executing the code object, Dynamo transforms the code object into a new code object. Then, the old frame is not suitable for executing the new code. Therefore, Dynamo needs to manually create and initialize a new frame to execute the new code.
  // The main task is to copy data in old frame to new frame, concerning a storage space named `localsplus`.
  //
  // localsplus storage is an array with the following layout:
  // |   args   |   new_locals    |    cell_variables |   free_variables    |
  // | <--- from left to right, index from 0 to n - 1 ---> |
  // code.co_varnames == args + new_locals, code.co_nlocals == len(code.co_varnames)
  // code.co_freevars == free_variables
  // In Python 3.10 and lower, `n == code.co_nlocals + len(code.co_cellvars) + len(code.co_freevars)` (Python expression)
  // In Python 3.11 and higher, `n <= code.co_nlocals + len(code.co_cellvars) + len(code.co_freevars)` (Python expression). There is an extra field in Python C-API: `n == code->co_nlocalsplus` (C expression) to retrieve the length of array.
  // The complexity happens if an argument becomes a cell variable:
  //  In Python 3.10 and lower, `code.co_cellvars == cell_variables`, and the corresponding slot in args becomes `NULL`.
  //  In Python 3.11 and higher, `len(code.co_cellvars) > len(cell_variables)`: that cell variable is still stored in args, with a flag set in the corresponding item's `co_localspluskinds`.
  //
  // ideally, we need to look up new localsplus from old localsplus by name:
  // for i, name, value in enumerate(localsplusnames_old):
  //   if value != NULL: (NULL happens for new local variables and arguments that becomes cell variables)
  //     name_to_idx[name] = i
  // for i, name in enumerate(localsplusnames_new):
  //  if name in name_to_idx:
  //    fastlocals_new[i] = fastlocals_old[name_to_idx[name]]
  //
  // The above process of building a `name_to_idx` mapping is expensive.
  // Dynamo makes the following assumptions:
  //  1. new code has the same arguments as the old code (both the number and the order)
  //  2. new code has the same cell variables as the old code (both the number and the order)
  //  3. new code has the same free variables as the old code (both the number and the order)
  //  The only flexibility lies in new local variables: new code can introduce their own variables.
  // With these assumptions, Dynamo can copy data directly by index. Dynamo just needs to take care of copying cell variables correctly.
  // To avoid runtime cost, the assumptions are checked when we first generate the code object in pytorch/torch/_dynamo/convert_frame.py .


  // copy args
  // according to https://docs.python.org/3/library/inspect.html , `co_argcount` is the number of arguments (not including keyword only arguments, * or ** args). So the total number of argument slots is `co_argcount + co_kwonlyargcount`, plus one slot for *args and one for **kwargs when the corresponding `co_flags` bit is set.
  // !!(F_CODE(frame)->co_flags & CO_VARARGS) is 1 if the function has *args, 0 otherwise
  // !!(F_CODE(frame)->co_flags & CO_VARKEYWORDS) is 1 if the function has **kwargs, 0 otherwise
  // they convert bit flags to 0 or 1, and avoid branching.
  // This is performance critical code, so we really care about performance.
  Py_ssize_t total_argcount_old = F_CODE(frame)->co_argcount + F_CODE(frame)->co_kwonlyargcount + !!(F_CODE(frame)->co_flags & CO_VARARGS) + !!(F_CODE(frame)->co_flags & CO_VARKEYWORDS);

  for (Py_ssize_t i = 0; i < total_argcount_old; i++) {
    Py_XINCREF(fastlocals_old[i]);
    fastlocals_new[i] = fastlocals_old[i];
  }

  // copy free vars (they occupy the last nfrees_old slots of both arrays)
  Py_ssize_t nfrees_old = PyCode_GetNFreevars(F_CODE(frame));

  for (Py_ssize_t i = 0; i < nfrees_old; i++) {
    Py_XINCREF(fastlocals_old[n_old - 1 - i]);
    fastlocals_new[n_new - 1 - i] = fastlocals_old[n_old - 1 - i];
  }

  // copy cell vars, from high index to low index, until it meets a variable that is not cell variable.
  for (Py_ssize_t i = n_old - nfrees_old - 1, j = n_new - nfrees_old - 1; i >= total_argcount_old; i--, j--) {

    // conditional test to tell if a variable is not a cell variable
    // this is straightforward in Python 3.11 and higher, as there are bit flags in `co_localspluskinds` to tell if a variable is a cell variable.
    // in Python 3.10 and lower, essentially we are checking if a variable is a new local variable (because of the layout mentioned above, the first variable that is not cell variable is the first new local variable). the corresponding slot in `fastlocals_old` is NULL for new local variables.
#if IS_PYTHON_3_11_PLUS
    if(!(_PyLocals_GetKind(F_CODE(frame)->co_localspluskinds, i) & CO_FAST_CELL))
    {
      break;
    }
#else
    if(fastlocals_old[i] == NULL)
    {
      break;
    }
#endif

    Py_XINCREF(fastlocals_old[i]);
    fastlocals_new[j] = fastlocals_old[i];
  }

  // NOTE: if you want to evaluate frame instead of shadow in 3.12+,
  // you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
  // calling eval_frame_default (i.e. here) and comment out the
  // clear_old_frame_if_python_312_plus call on the original frame.

  PyObject* result = eval_frame_default(tstate, shadow, throw_flag);

#if IS_PYTHON_3_12_PLUS

  // frame is cleared by caller
  Py_DECREF(func);

#elif IS_PYTHON_3_11_PLUS

  // In 3.11, shadow has is_entry set to true, so _PyEvalFrameClearAndPop is not called,
  // so we manually clear and pop the shadow frame.
  THP_PyFrame_Clear(shadow);
  THP_PyThreadState_PopFrame(tstate, shadow);
  Py_DECREF(func);

#else

  // Release the PyFrameObject created by PyFrame_New above.
  Py_DECREF(shadow);

#endif

  return result;
}
480
481 // This wrapper function adds a profiler event
eval_custom_code(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,PyCodeObject * code,int throw_flag,int free_vars_copied)482 inline static PyObject* eval_custom_code(
483 PyThreadState* tstate,
484 THP_EVAL_API_FRAME_OBJECT* frame,
485 PyCodeObject* code,
486 int throw_flag,
487 int free_vars_copied) {
488 const char* trace_id = compile_context;
489 _PytorchRecordFunctionState* rf = _pytorch_record_function_enter_with_context("Torch-Compiled Region", trace_id);
490 PyObject* result = eval_custom_code_impl(
491 tstate,
492 frame,
493 code,
494 throw_flag,
495 free_vars_copied
496 );
497 _pytorch_record_function_exit(rf);
498 return result;
499 }
500
_custom_eval_frame_shim(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)501 static PyObject* _custom_eval_frame_shim(
502 PyThreadState* tstate,
503 THP_EVAL_API_FRAME_OBJECT* frame,
504 int throw_flag) {
505 // Shims logic into one of three states. Can probably be refactored into a
506 // single func, later:
507 // - None: disables TorchDynamo
508 // - False: run-only mode (reuse existing compiles)
509 // - Python callable(): enables TorchDynamo
510 PyObject* callback = eval_frame_callback_get();
511
512 if (callback == Py_None) {
513 return eval_frame_default(tstate, frame, throw_flag);
514 }
515
516 int should_clear_frame = 0;
517 PyObject* result = _custom_eval_frame(tstate, frame, throw_flag, callback, &should_clear_frame);
518 if (should_clear_frame) {
519 clear_old_frame_if_python_312_plus(tstate, frame);
520 }
521 return result;
522 }
523
// Sentinel compared by identity against the Dynamo callback's result; a match
// marks the frame's code SKIP_CODE_RECURSIVE. NOTE(review): assigned outside
// the visible part of this file — confirm where.
static PyObject* skip_code_recursive_flag;
525
526 // NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping
527 // the frame, meaning that unless we default evaluate the original frame,
528 // we are responsible for clearing it - via clear_old_frame_if_python_312_plus.
529 // The should_clear_frame flag is used to indicate whether the frame should be
530 // cleared by _custom_eval_frame's caller.
531 // Generally should_clear_frame should be set if and only we don't eval_frame_default.
_custom_eval_frame(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag,PyObject * callback,int * should_clear_frame)532 static PyObject* _custom_eval_frame(
533 PyThreadState* tstate,
534 THP_EVAL_API_FRAME_OBJECT* frame,
535 int throw_flag,
536 PyObject* callback,
537 int* should_clear_frame) {
538 #if IS_PYTHON_3_11_PLUS
539 DEBUG_TRACE(
540 "begin %s %s %i %i",
541 get_frame_name(frame),
542 PyUnicode_AsUTF8(F_CODE(frame)->co_filename),
543 F_CODE(frame)->co_firstlineno,
544 _PyInterpreterFrame_LASTI(frame));
545 #else
546 DEBUG_TRACE(
547 "begin %s %s %i %i %i",
548 get_frame_name(frame),
549 PyUnicode_AsUTF8(F_CODE(frame)->co_filename),
550 frame->f_lineno,
551 frame->f_lasti,
552 frame->f_iblock);
553 #endif
554
555 if (throw_flag) {
556 // When unwinding generators, eval frame is called with throw_flag ==
557 // true. Frame evaluation is supposed to continue unwinding by propagating
558 // the exception. Dynamo doesn't really know how to do this, nor does it
559 // really want to do this, because there's unlikely any code to capture
560 // (you're going to immediately quit out of the frame, perhaps running
561 // some unwinding logic along the way). So we just run the default
562 // handler in this case.
563 //
564 // NB: A previous version of this patch returned NULL. This is wrong,
565 // because returning NULL is *different* from unwinding an exception.
566 // In particular, you will not execute things like context manager
567 // __exit__ if you just return NULL.
568 //
569 // NB: It's /conceivable/ that you might want to actually still call the
570 // Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
571 // do any stack unwinding code. But this is not really useful because
572 // (1) Dynamo doesn't actually know how to do stack unwinding, so it would
573 // immediately skip the frame, and (2) even if it did, this would only
574 // be profitable if there was tensor code in the unwinding code. Seems
575 // unlikely.
576 DEBUG_TRACE("throw %s", get_frame_name(frame));
577 return eval_frame_default(tstate, frame, throw_flag);
578 }
579
580 ExtraState* extra = get_extra_state(F_CODE(frame));
581 if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
582 DEBUG_TRACE("skip %s", get_frame_name(frame));
583 return eval_frame_default(tstate, frame, throw_flag);
584 }
585 if (extra == SKIP_CODE_RECURSIVE) {
586 DEBUG_TRACE("skip recursive %s", get_frame_name(frame));
587 eval_frame_callback_set(Py_None);
588 PyObject* result = eval_frame_default(tstate, frame, throw_flag);
589 eval_frame_callback_set(callback);
590 return result;
591 }
592
593 if (extra == NULL) {
594 extra = init_and_set_extra_state(F_CODE(frame));
595 }
596
597
598 int free_vars_copied = 0;
599 #if IS_PYTHON_3_12_PLUS
600 PyObject *locals = get_framelocals_mapping(frame);
601 #else
602 if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) {
603 DEBUG_TRACE("error %s", get_frame_name(frame));
604 *should_clear_frame = 1;
605 return NULL;
606 }
607 PyObject *locals = frame->f_locals;
608 Py_INCREF(locals);
609 #endif
610
611 PyObject* backend = get_backend(callback);
612
613 // A callback of Py_False indicates "run only" mode, the cache is checked, but
614 // we never compile.
615 if (callback == Py_False) {
616 DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
617 _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
618 PyObject* maybe_cached_code = lookup(extra, locals, backend);
619 _pytorch_record_function_exit(rf);
620
621 Py_DECREF(locals);
622
623 if (maybe_cached_code == NULL) {
624 // guard eval failed, keep propagating
625 *should_clear_frame = 1;
626 return NULL;
627 } else if (maybe_cached_code == Py_None) {
628 DEBUG_TRACE("cache miss %s", get_frame_name(frame));
629 return eval_frame_default(tstate, frame, throw_flag);
630 }
631 PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
632 // used cached version
633 DEBUG_TRACE("cache hit %s", get_frame_name(frame));
634 *should_clear_frame = 1;
635 return eval_custom_code(tstate, frame, cached_code, throw_flag, 0);
636 }
637 DEBUG_CHECK(PyDict_CheckExact(locals));
638 DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
639 DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));
640
641 // We don't run the current custom_eval_frame behavior for guards.
642 // So we temporarily set the callback to Py_None to drive the correct behavior
643 // in the shim.
644 eval_frame_callback_set(Py_None);
645
646 _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
647 PyObject* maybe_cached_code = lookup(extra, locals, backend);
648 _pytorch_record_function_exit(rf);
649 if (maybe_cached_code == NULL) {
650 // Python error
651 *should_clear_frame = 1;
652 Py_DECREF(locals);
653 return NULL;
654 } else if (maybe_cached_code != Py_None) {
655 PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
656 // used cached version
657 DEBUG_TRACE("cache hit %s", get_frame_name(frame));
658 // Re-enable custom behavior
659 eval_frame_callback_set(callback);
660 *should_clear_frame = 1;
661 Py_DECREF(locals);
662 return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
663 }
664 // cache miss
665 CacheEntry* cache_entry = extract_cache_entry(extra);
666 FrameState* frame_state = extract_frame_state(extra);
667 PyObject* result =
668 call_callback(callback, frame, locals, cache_entry, frame_state);
669 Py_DECREF(locals);
670 if (result == NULL) {
671 // internal exception, returning here will leak the exception into user code
672 // this is useful for debugging -- but we dont want it to happen outside of
673 // testing
674 // NB: we intentionally DO NOT re-enable custom behavior to prevent
675 // cascading failure from internal exceptions. The upshot is if
676 // Dynamo barfs, that's it for Dynamo, even if you catch the exception
677 // inside the torch.compile block we won't try to Dynamo anything else.
678 *should_clear_frame = 1;
679 return NULL;
680 } else if (result == skip_code_recursive_flag) {
681 // Dynamo returned skip_code_recursive_flag, so we should recursively skip code.
682 DEBUG_TRACE("create skip recursive %s", get_frame_name(frame));
683 set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE);
684 PyObject* r = eval_frame_default(tstate, frame, throw_flag);
685 // Re-enable custom behavior
686 eval_frame_callback_set(callback);
687 return r;
688 } else if (result != Py_None) {
689 DEBUG_TRACE("create cache %s", get_frame_name(frame));
690
691 // NB: We could use extract_cache_entry to get the cache_entry, but
692 // extract_cache_entry returns a borrowed reference. Modifying a borrowed
693 // reference seems wrong. Therefore, we directly access the
694 // extra->cache_entry. extra wont be NULL here.
695 CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
696 Py_DECREF(result);
697
698 // Update the existing cache_entry on the extra object. This extra object is
699 // sitting on the extra scratch space, we are just changing the cache_entry
700 // ptr. As a result, extra now becomes the owner of CacheEntry object. This
701 // will be cleaned up when set_extra_state is called.
702 // Re-enable custom behavior
703 eval_frame_callback_set(callback);
704 *should_clear_frame = 1;
705 return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
706 } else {
707 DEBUG_TRACE("create skip %s", get_frame_name(frame));
708 Py_DECREF(result);
709 set_extra_state(F_CODE(frame), SKIP_CODE);
710 // Re-enable custom behavior
711 eval_frame_callback_set(callback);
712 return eval_frame_default(tstate, frame, throw_flag);
713 }
714 }
715
716 #else // IS_PYTHON_3_14_PLUS
717
718 // Fake definitions for everything we removed
719
// Minimal wrapper layout, so the type object below can still be defined.
typedef struct THPPyInterpreterFrame {
  PyObject_HEAD
  _PyInterpreterFrame* frame; // Borrowed reference
} THPPyInterpreterFrame;

// Both switches are no-ops here: the eval-frame shim is never installed.
inline static void enable_eval_frame_shim(PyThreadState* tstate) {}
inline static void enable_eval_frame_default(PyThreadState* tstate) {}

// Attribute table containing only the sentinel entry.
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL};

static PyTypeObject THPPyInterpreterFrameType = {
    PyVarObject_HEAD_INIT(NULL, 0)
    .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame",
    .tp_basicsize = sizeof(THPPyInterpreterFrame),
    .tp_flags = Py_TPFLAGS_DEFAULT,
    .tp_getset = THPPyInterpreterFrame_properties,
};
737
738 #endif // CPython 3.14
739
increment_working_threads(PyThreadState * tstate)740 static PyObject* increment_working_threads(PyThreadState* tstate) {
741 active_dynamo_threads = active_dynamo_threads + 1;
742 if (active_dynamo_threads > 0) {
743 enable_eval_frame_shim(tstate);
744 }
745 Py_RETURN_NONE;
746 }
747
decrement_working_threads(PyThreadState * tstate)748 static PyObject* decrement_working_threads(PyThreadState* tstate) {
749 if (active_dynamo_threads > 0) {
750 active_dynamo_threads = active_dynamo_threads - 1;
751 if (active_dynamo_threads == 0) {
752 enable_eval_frame_default(tstate);
753 }
754 }
755 Py_RETURN_NONE;
756 }
757
// Replaces this thread's Dynamo callback with `new_callback` and returns the
// previous callback as a new reference.
//  - None: disables TorchDynamo
//  - False: run-only mode (reuse existing compiles)
//  - Python callable(): enables TorchDynamo
// Switching between None and non-None updates the active-thread count, which
// installs or removes the eval-frame shim.
//
// Slot ownership: a non-None callback stored in the thread-local slot holds a
// strong reference, while Py_None is stored without one. This matches the
// unset slot (eval_frame_callback_get reports Py_None with no reference behind
// it) and _custom_eval_frame, which swaps Py_None in and out without touching
// its refcount. Treating the old None as owned would over-release it.
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate) {
  PyObject* old_callback = eval_frame_callback_get();

  // owned by caller
  Py_INCREF(old_callback);

  if (old_callback != Py_None && new_callback == Py_None) {
    // The helper returns a new reference to None; release it.
    PyObject* none_ref = decrement_working_threads(tstate);
    Py_XDECREF(none_ref);
  } else if (old_callback == Py_None && new_callback != Py_None) {
    PyObject* none_ref = increment_working_threads(tstate);
    Py_XDECREF(none_ref);
  }

  // The slot takes a reference to a non-None new callback...
  if (new_callback != Py_None) {
    Py_INCREF(new_callback);
  }
  // ...and gives up the one it held on a non-None old callback (the caller's
  // reference taken above keeps it alive).
  if (old_callback != Py_None) {
    Py_DECREF(old_callback);
  }

  // Set thread local callback. This will drive behavior of our shim, if/when it
  // is installed.
  eval_frame_callback_set(new_callback);

  return old_callback;
}
783
// Python-facing wrapper around set_eval_frame for the current thread.
// Accepts None (disable), False (run-only) or any callable (enable). Any other
// value raises TypeError and returns NULL. On success it returns whatever
// set_eval_frame returns: the previous callback, owned by the caller.
static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* callback) {
  const bool disabling = (callback == Py_None);
  const bool run_only = (callback == Py_False);
  // PyCallable_Check is only consulted for values that are neither None nor
  // False, exactly as in a short-circuited && chain.
  const bool acceptable = disabling || run_only || PyCallable_Check(callback);

  if (!acceptable) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a callable");
    return NULL;
  }

  DEBUG_TRACE(
      "python enabled=%d and is run_only=%d",
      !disabling,
      (int)run_only);
  return set_eval_frame(callback, PyThreadState_GET());
}
797
// Clear the per-code extra state attached to `code` by storing NULL in it.
// Raises TypeError (returns NULL) when `code` is not a code object.
static PyObject* reset_code(PyObject* dummy, PyObject* code) {
  if (PyCode_Check(code)) {
    // set_extra_state destroys the existing object on extra scratch space.
    set_extra_state((PyCodeObject*)code, NULL);
    Py_RETURN_NONE;
  }

  DEBUG_TRACE0("arg error");
  PyErr_SetString(PyExc_TypeError, "expected a code object");
  return NULL;
}
809
// A dummy C function used in testing: parses exactly two positional objects
// and returns a new reference to the second one.
static PyObject* unsupported(PyObject* dummy, PyObject* args) {
  PyObject* first_arg = NULL;
  PyObject* second_arg = NULL;
  const int parsed_ok = PyArg_ParseTuple(args, "OO", &first_arg, &second_arg);
  if (!parsed_ok) {
    // PyArg_ParseTuple has already set the exception.
    return NULL;
  }
  Py_INCREF(second_arg);
  return second_arg;
}
820
// Mark code object `obj` with the SKIP_CODE extra state.
// Raises TypeError (returns NULL) when `obj` is not a code object.
static PyObject* skip_code(PyObject* dummy, PyObject* obj) {
  if (PyCode_Check(obj)) {
    // set_extra_state destroys the existing object on extra scratch space.
    set_extra_state((PyCodeObject*)obj, SKIP_CODE);
    Py_RETURN_NONE;
  }

  PyErr_SetString(PyExc_TypeError, "expected a code object");
  return NULL;
}
831
// Install `obj` as the global guard_error_hook, keeping a strong reference,
// and release the previously stored hook. Passing None clears the hook
// (stores NULL).
static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
  PyObject* new_hook = (obj == Py_None) ? NULL : obj;
  Py_XSETREF(guard_error_hook, Py_XNewRef(new_hook));
  Py_RETURN_NONE;
}
839
// Record the current compile context string in the static `compile_context`
// buffer. `obj` must be a tuple (frame_id, frame_compile_id, attempt) of ints.
// The result is "frame_id/frame_compile_id" when attempt == 0, and
// "frame_id/frame_compile_id_attempt" otherwise.
// Raises TypeError (returns NULL) if `obj` cannot be parsed as three ints.
static PyObject* set_context_frame(PyObject* dummy, PyObject* obj) {
  int frame_id, frame_compile_id, attempt;
  if (!PyArg_ParseTuple(obj, "iii", &frame_id, &frame_compile_id, &attempt)) {
    PyErr_SetString(PyExc_TypeError, "Expected three integers");
    return NULL;
  }
  // Bound every write by the buffer size. Three ints need at most 35 chars
  // plus the terminator, but snprintf guarantees NUL-termination and no
  // overflow even if the format or buffer size ever changes.
  if (attempt == 0) {
    snprintf(
        compile_context,
        sizeof(compile_context),
        "%d/%d",
        frame_id,
        frame_compile_id);
  } else {
    snprintf(
        compile_context,
        sizeof(compile_context),
        "%d/%d_%d",
        frame_id,
        frame_compile_id,
        attempt);
  }
  Py_RETURN_NONE;
}
853
854 static PyMethodDef _methods[] = {
855 {"set_eval_frame", set_eval_frame_py, METH_O, NULL},
856 {"reset_code", reset_code, METH_O, NULL},
857 {"unsupported", unsupported, METH_VARARGS, NULL},
858 {"skip_code", skip_code, METH_O, NULL},
859 {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
860 {"set_context_frame", set_context_frame, METH_O, NULL},
861 {NULL, NULL, 0, NULL}};
862
863 static struct PyModuleDef _module = {
864 PyModuleDef_HEAD_INIT,
865 "torch._C._dynamo.eval_frame",
866 "Module containing hooks to override eval_frame",
867 -1,
868 _methods};
869
// On 3.12+ builds, route the legacy private spelling to
// PyUnstable_Eval_RequestCodeExtraIndex, so the module init below can call a
// single name on every supported Python version.
#if IS_PYTHON_3_12_PLUS
#define _PyEval_RequestCodeExtraIndex PyUnstable_Eval_RequestCodeExtraIndex
#endif
873
torch_c_dynamo_eval_frame_init(void)874 PyObject* torch_c_dynamo_eval_frame_init(void) {
875 extra_index = _PyEval_RequestCodeExtraIndex(destroy_extra_state);
876 if (extra_index < 0) {
877 PyErr_SetString(PyExc_RuntimeError,
878 "dynamo: unable to register extra index");
879 return NULL;
880 }
881
882 int result = PyThread_tss_create(&eval_frame_callback_key);
883 CHECK(result == 0);
884
885 Py_INCREF(Py_None);
886 eval_frame_callback_set(Py_None);
887
888 PyObject* module = PyModule_Create(&_module);
889 if (module == NULL) {
890 return NULL;
891 }
892
893 #if IS_PYTHON_3_11_PLUS
894 if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
895 return NULL;
896 }
897 Py_INCREF(&THPPyInterpreterFrameType);
898 if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
899 return NULL;
900 }
901 #endif
902
903 skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type);
904 if (skip_code_recursive_flag == NULL) {
905 return NULL;
906 }
907 if (PyModule_AddObject(module, "skip_code_recursive_flag", skip_code_recursive_flag) != 0) {
908 return NULL;
909 }
910
911 return module;
912 }
913