xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/eval_frame.c (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define PY_SSIZE_T_CLEAN
2 #include <torch/csrc/dynamo/cache_entry.h>
3 #include <torch/csrc/dynamo/cpp_shim.h>
4 #include <torch/csrc/dynamo/cpython_defs.h>
5 #include <torch/csrc/dynamo/cpython_includes.h>
6 #include <torch/csrc/dynamo/debug_macros.h>
7 #include <torch/csrc/dynamo/extra_state.h>
8 #include <torch/csrc/dynamo/framelocals_mapping.h>
9 #include <torch/csrc/utils/python_compat.h>
10 #include <opcode.h>
11 #include <stdbool.h>
12 
13 #define MAX_COMPILE_CONTEXT_SIZE 100
14 
15 PyObject* guard_error_hook = NULL;
16 const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
17 static char compile_context[MAX_COMPILE_CONTEXT_SIZE];
18 static int active_dynamo_threads = 0;
19 
20 static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
21 
eval_frame_callback_get(void)22 inline static PyObject* eval_frame_callback_get(void) {
23   void* result = PyThread_tss_get(&eval_frame_callback_key);
24   if (unlikely(result == NULL)) {
25     return (PyObject*)Py_None;
26   } else {
27     return (PyObject*)result;
28   }
29 }
30 
// Records `obj` as the calling thread's eval-frame callback.
// The pointer is stored as-is: reference counts are not touched here and the
// status code returned by PyThread_tss_set is not inspected.
inline static void eval_frame_callback_set(PyObject* obj) {
  void* slot_value = obj;
  PyThread_tss_set(&eval_frame_callback_key, slot_value);
}
34 
35 // 3.14 Not supported at all. See cpython_defs.c for hints
36 #if !(IS_PYTHON_3_14_PLUS)
37 
38 // All the eval APIs change in 3.11 so we need to decide which one to use on the fly
39 // https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
40 #if IS_PYTHON_3_11_PLUS
41 #define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
42 
43 // We need to be able to return the _PyInterpreterFrame to python so create
44 // a python binding for it
45 
46 typedef struct THPPyInterpreterFrame {
47   PyObject_HEAD
48   _PyInterpreterFrame* frame; // Borrowed reference
49   PyObject* locals;
50 } THPPyInterpreterFrame;
51 
52 THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
53 
54 #define DECLARE_PYOBJ_ATTR(name) \
55 static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
56   PyObject* res = (PyObject*)self->frame->name; \
57   Py_XINCREF(res); \
58   return res; \
59 }
60 
61 #if IS_PYTHON_3_12_PLUS
62 DECLARE_PYOBJ_ATTR(f_funcobj)
63 #else
64 DECLARE_PYOBJ_ATTR(f_func)
65 #endif
66 
DECLARE_PYOBJ_ATTR(f_globals)67 DECLARE_PYOBJ_ATTR(f_globals)
68 DECLARE_PYOBJ_ATTR(f_builtins)
69 
70 static PyObject* THPPyInterpreterFrame_f_locals(THPPyInterpreterFrame* self, PyObject* _noargs) {
71   DEBUG_NULL_CHECK(self->locals);
72   Py_XINCREF(self->locals);
73   return self->locals;
74 }
75 
76 #if IS_PYTHON_3_13_PLUS
77 DECLARE_PYOBJ_ATTR(f_executable)
78 #else
79 DECLARE_PYOBJ_ATTR(f_code)
80 #endif
81 
DECLARE_PYOBJ_ATTR(frame_obj)82 DECLARE_PYOBJ_ATTR(frame_obj)
83 
84 #undef DECLARE_PYOBJ_ATTR
85 
86 static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
87   THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
88   return res;
89 }
90 
91 // This is not a true attribute of the class but we do access it in python and it is hard to implement
92 // on the python side, so do it here:
THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame * self,PyObject * _noargs)93 static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
94   return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
95 }
96 
// Getter for `f_lineno`.
//  - No frame object attached to the interpreter frame: returns the code
//    object's co_firstlineno.
//  - Otherwise: returns PyFrame_GetLineNumber of the frame object, or None
//    when that value is negative.
static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) {
  if (self->frame->frame_obj) {
    int current_line = PyFrame_GetLineNumber(self->frame->frame_obj);
    if (current_line >= 0) {
      return PyLong_FromLong(current_line);
    }
    Py_RETURN_NONE;
  }
  return PyLong_FromLong(F_CODE(self->frame)->co_firstlineno);
}
107 
// Getter for `f_back`: the frame object of the next outer frame.
//
// Returns None when the interpreter frame has no frame object attached, and
// also when the frame object has no outer frame. In the latter case
// PyFrame_GetBack returns NULL without setting an exception; returning that
// NULL from a getter would make the interpreter raise SystemError.
// Otherwise returns the new reference produced by PyFrame_GetBack.
static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) {
  if (!self->frame->frame_obj) {
    Py_RETURN_NONE;
  }
  PyObject* outer = (PyObject*)PyFrame_GetBack(self->frame->frame_obj);
  if (outer == NULL) {
    // Outermost frame: there is no caller to report.
    Py_RETURN_NONE;
  }
  return outer;
}
114 
115 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
116 static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
117 #if IS_PYTHON_3_12_PLUS
118     {"f_func", (getter)THPPyInterpreterFrame_f_funcobj, NULL, NULL, NULL},
119 #else
120     {"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
121 #endif
122     {"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
123     {"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
124     {"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
125 #if IS_PYTHON_3_13_PLUS
126     {"f_code", (getter)THPPyInterpreterFrame_f_executable, NULL, NULL, NULL},
127 #else
128     {"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
129 #endif
130     {"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
131     {"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
132     {"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
133     {"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
134     {"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
135     {NULL}};
136 
137 static PyTypeObject THPPyInterpreterFrameType = {
138     PyVarObject_HEAD_INIT(NULL, 0)
139     .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame",
140     .tp_basicsize = sizeof(THPPyInterpreterFrame),
141     .tp_flags = Py_TPFLAGS_DEFAULT,
142     .tp_getset = THPPyInterpreterFrame_properties,
143 };
144 
145 
THPPyInterpreterFrame_New(_PyInterpreterFrame * frame)146 THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
147   PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
148   THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
149   if (!self)
150     return NULL;
151   self->frame = frame;
152   self->locals = NULL;
153   return self;
154 }
155 
156 
157 #else
158 #define THP_EVAL_API_FRAME_OBJECT PyFrameObject
159 
160 static int
THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT * frame,int * free_vars_copied)161 THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_vars_copied) {
162   return PyFrame_FastToLocalsWithError(frame);
163 }
164 #endif
165 
166 static PyObject* _custom_eval_frame_shim(
167     PyThreadState* tstate,
168     THP_EVAL_API_FRAME_OBJECT* frame,
169     int throw_flag);
170 static PyObject* _custom_eval_frame(
171     PyThreadState* tstate,
172     THP_EVAL_API_FRAME_OBJECT* frame,
173     int throw_flag,
174     PyObject* callback,
175     int* should_clear_frame);
176 static PyObject *(*previous_eval_frame)(PyThreadState *tstate,
177                                         THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) = NULL;
178 
179 #if PY_VERSION_HEX >= 0x03090000
custom_eval_frame_shim(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)180 static PyObject* custom_eval_frame_shim(
181     PyThreadState* tstate,
182     THP_EVAL_API_FRAME_OBJECT* frame,
183     int throw_flag) {
184   return _custom_eval_frame_shim(tstate, frame, throw_flag);
185 }
186 #else
custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)187 static PyObject* custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) {
188   PyThreadState* tstate = PyThreadState_GET();
189   return _custom_eval_frame_shim(tstate, frame, throw_flag);
190 }
191 #endif
192 
eval_frame_default(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)193 inline static PyObject* eval_frame_default(
194     PyThreadState* tstate,
195     THP_EVAL_API_FRAME_OBJECT* frame,
196     int throw_flag) {
197 #if PY_VERSION_HEX >= 0x03090000
198   if (tstate == NULL) {
199     tstate = PyThreadState_GET();
200   }
201   if (previous_eval_frame) {
202     return previous_eval_frame(tstate, frame, throw_flag);
203   }
204   else {
205     return _PyEval_EvalFrameDefault(tstate, frame, throw_flag);
206   }
207 #else
208   return _PyEval_EvalFrameDefault(frame, throw_flag);
209 #endif
210 }
211 
// Installs custom_eval_frame_shim as the interpreter's frame-evaluation
// function, unless it is already installed. On 3.9+ the function being
// replaced is remembered in previous_eval_frame so it can be restored later.
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      &custom_eval_frame_shim) {
    // Already installed, nothing to do.
    return;
  }
  DEBUG_CHECK(previous_eval_frame == NULL);
  previous_eval_frame = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp,
                                       &custom_eval_frame_shim);
#else
  if (tstate->interp->eval_frame == &custom_eval_frame_shim) {
    return;
  }
  // First call
  tstate->interp->eval_frame = &custom_eval_frame_shim;
#endif
}
228 
// Uninstalls the shim. On 3.9+ the evaluation function saved in
// previous_eval_frame is put back (when it is not already the active one) and
// the saved pointer is cleared. On older versions the interpreter is pointed
// back at _PyEval_EvalFrameDefault.
inline static void enable_eval_frame_default(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      previous_eval_frame) {
    // The saved function is already active.
    return;
  }
  DEBUG_CHECK(previous_eval_frame != NULL);
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp,
                                       previous_eval_frame);
  previous_eval_frame = NULL;
#else
  if (tstate->interp->eval_frame == &_PyEval_EvalFrameDefault) {
    return;
  }
  // First call
  tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
#endif
}
245 
246 
// Returns the name of the code object executing in `frame` (its co_name) as
// a UTF-8 C string, as produced by PyUnicode_AsUTF8.
inline static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
  DEBUG_CHECK(PyUnicode_Check(F_CODE(frame)->co_name));
  const char* utf8_name = PyUnicode_AsUTF8(F_CODE(frame)->co_name);
  return utf8_name;
}
252 
// Invokes the Dynamo callback as `callable(frame, cache_entry, frame_state)`
// and returns its result (NULL when the call raised or, on 3.11+, when the
// frame wrapper could not be allocated).
//
// On 3.11+ the raw interpreter frame is wrapped in a THPPyInterpreterFrame
// and `locals` is attached to that wrapper (the pointer is stored without
// taking a reference). On older versions the PyFrameObject itself is passed.
// The temporaries created for the call are released before returning.
static inline PyObject* call_callback(
    PyObject* callable,
    THP_EVAL_API_FRAME_OBJECT* _frame,
    PyObject* locals,
    CacheEntry* cache_entry,
    FrameState* frame_state) {

// remember to update the type signature for DynamoCallbackFn.__call__ in torch/_dynamo/types.py
// if this function changes
#if IS_PYTHON_3_11_PLUS
  THPPyInterpreterFrame* wrapper = THPPyInterpreterFrame_New(_frame);
  if (!wrapper) {
    return NULL;
  }
  wrapper->locals = locals;
  PyObject* frame_arg = (PyObject*)wrapper;
#else
  PyObject* frame_arg = Py_NewRef(_frame);
#endif

  PyObject* entry_arg = CacheEntry_to_obj(cache_entry);
  PyObject* outcome =
      PyObject_CallFunction(callable, "OOO", frame_arg, entry_arg, frame_state);
  Py_DECREF(frame_arg);
  Py_DECREF(entry_arg);
  return outcome;
}
283 
// On 3.12+ clears `frame` and pops it from the thread state's frame stack;
// compiles to a no-op on older versions.
static inline void clear_old_frame_if_python_312_plus(
  PyThreadState* tstate,
  THP_EVAL_API_FRAME_OBJECT* frame) {
#if IS_PYTHON_3_12_PLUS
  // Clear first, then pop.
  THP_PyFrame_Clear(frame);
  THP_PyThreadState_PopFrame(tstate, frame);
#else
  (void)tstate;
  (void)frame;
#endif
}
294 
eval_custom_code_impl(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,PyCodeObject * code,int throw_flag,int free_vars_copied)295 inline static PyObject* eval_custom_code_impl(
296     PyThreadState* tstate,
297     THP_EVAL_API_FRAME_OBJECT* frame,
298     PyCodeObject* code,
299     int throw_flag,
300     int free_vars_copied) {
301 
302   DEBUG_NULL_CHECK(tstate);
303   DEBUG_NULL_CHECK(frame);
304   DEBUG_NULL_CHECK(code);
305 
306 #if IS_PYTHON_3_11_PLUS
307 
308   // Generate Python function object and _PyInterpreterFrame in a way similar to
309   // https://github.com/python/cpython/blob/e715da6db1d1d70cd779dc48e1ba8110c51cc1bf/Python/ceval.c#L1130
310 #if IS_PYTHON_3_12_PLUS
311   PyFunctionObject* old_func = (PyFunctionObject*) frame->f_funcobj;
312   size_t size = code->co_framesize;
313 #else
314   PyFunctionObject* old_func = frame->f_func;
315   size_t size = code->co_nlocalsplus + code->co_stacksize + FRAME_SPECIALS_SIZE;
316 #endif
317 
318   PyFunctionObject* func = _PyFunction_CopyWithNewCode(old_func, code);
319   if (func == NULL) {
320     return NULL;
321   }
322 
323   THP_EVAL_API_FRAME_OBJECT* shadow = THP_PyThreadState_BumpFramePointerSlow(tstate, size);
324   if (shadow == NULL) {
325     Py_DECREF(func);
326     return NULL;
327   }
328 
329   Py_INCREF(func);
330   // consumes reference to func
331 #if IS_PYTHON_3_12_PLUS
332   _PyFrame_Initialize(shadow, func, NULL, code, 0);
333 #else
334   _PyFrame_InitializeSpecials(shadow, func, NULL, code->co_nlocalsplus);
335 #endif
336 
337   PyObject** fastlocals_old = frame->localsplus;
338   PyObject** fastlocals_new = shadow->localsplus;
339   Py_ssize_t n_old = F_CODE(frame)->co_nlocalsplus;
340   Py_ssize_t n_new = code->co_nlocalsplus;
341 
342   // localsplus are XINCREF'd by default eval frame, so all values must be valid.
343 #if !(IS_PYTHON_3_12_PLUS)
344   // _PyFrame_Initialize in 3.12 already does this
345   for (int i = 0; i < code->co_nlocalsplus; i++) {
346     fastlocals_new[i] = NULL;
347   }
348 #endif
349 
350   // for 3.11+, if free_vars_copied is true, we do not need to
351   // run the first COPY_FREE_VARS since THP_PyFrame_FastToLocalsWithError
352   // already did the equivalent action.
353   if (free_vars_copied && _Py_OPCODE(_PyCode_CODE(F_CODE(shadow))[0]) == COPY_FREE_VARS) {
354     PREV_INSTR(shadow) = _PyCode_CODE(F_CODE(shadow));
355   }
356 
357 #else
358 
359   THP_EVAL_API_FRAME_OBJECT* shadow = PyFrame_New(tstate, code, frame->f_globals, NULL);
360   if (shadow == NULL) {
361     return NULL;
362   }
363 
364   PyObject** fastlocals_old = frame->f_localsplus;
365   PyObject** fastlocals_new = shadow->f_localsplus;
366   Py_ssize_t n_old = F_CODE(frame)->co_nlocals + PyCode_GetNFreevars(F_CODE(frame)) + PyCode_GetNCellvars(F_CODE(frame));
367   Py_ssize_t n_new = code->co_nlocals + PyCode_GetNFreevars(code) + PyCode_GetNCellvars(code);
368 
369 #endif
370 
371   // ============== Initialize new frame from old frame ============
372   // Python internal for executing a function:
373   //  1. CPython interpreter first creates an empty frame according to the code object
374   //  2. CPython interpreter initializes the frame by filling arguments/free variables into frame and initializing cell variables
375   //  3. CPython interpreter executes the code object
376   //
377   // Dynamo hooks the 3th step: before executing the code object, Dynamo transforms the code object into a new code object. Then, the old frame is not suitable for executing the new code. Therefore, Dynamo needs to manually create and initialize a new frame to execute the new code.
378   // The main task is to copy data in old frame to new frame, concerning a storage space named `localsplus`.
379   //
380   // localsplus storage is an array with the following layout:
381   // |   args   |   new_locals    |    cell_variables |   free_variables    |
382   // | <--- from left to right, index from 0 to n - 1 ---> |
383   // code.co_varnames == args + new_locals, code.co_nlocals == len(code.co_varnames)
384   // code.co_freevars == free_variables
385   // In Python 3.10 and lower, `n == code.co_nlocals + len(code.co_cellvars) + len(code.co_freevars)` (Python expression)
386   // In Python 3.11 and higher, `n <= code.co_nlocals + len(code.co_cellvars) + len(code.co_freevars)` (Python expression). There is an extra field in Python C-API: `n == code->co_nlocalsplus` (C expression) to retrieve the length of array.
387   // The complexity happens if an argument becomes a cell variable:
388   //  In Python 3.10 and lower, `code.co_cellvars == cell_variables`, and the corresponding slot in args becomes `NULL`.
389   //  In Python 3.11 and higher, `code.co_cellvars > cell_variables`, that cell variable is still stored in args, with a flag set in corresponding item's `co_localspluskinds` .
390   //
391   // ideally, we need to look up new localsplus from old localsplus by name:
392   // for i, name, value in enumerate(localsplusnames_old):
393   //   if value != NULL: (NULL happens for new local variables and arguments that becomes cell variables)
394   //     name_to_idx[name] = i
395   // for i, name in enumerate(localsplusnames_new):
396   //  if name in name_to_idx:
397   //    fastlocals_new[i] = fastlocals_old[name_to_idx[name]]
398   //
399   // The above process of building a `name_to_idx` mapping is expensive.
400   // Dynamo makes the following assumptions:
401   //  1. new code has the same arguments as the old code (both the number and the order)
402   //  2. new code has the same cell variables as the old code (both the number and the order)
403   //  3. new code has the same free variables as the old code (both the number and the order)
404   //  The only flexibility lies in new local variables: new code can introduce their own variables.
405   // With these assumptions, Dynamo can copy data directly by index. Dynamo just needs to take care of copying cell variables correctly.
406   // To avoid runtime cost, the assumptions are checked when we first generate the code object in pytorch/torch/_dynamo/convert_frame.py .
407 
408 
409   // copy args
410   // according to https://docs.python.org/3/library/inspect.html , `co_argcount` is the number of arguments (not including keyword only arguments, * or ** args). so we need to add `co_kwonlyargcount` and `co_flags` to get the total number of arguments.
411   // !!(F_CODE(frame)->co_flags & CO_VARARGS) is 1 if the function has *args, 0 otherwise
412   // !!(F_CODE(frame)->co_flags & CO_VARKEYWORDS) is 1 if the function has **kwargs, 0 otherwise
413   // they convert bit flags to 0 or 1, and avoid branching.
414   // This is performance critical code, so we really care about performance.
415   Py_ssize_t total_argcount_old = F_CODE(frame)->co_argcount + F_CODE(frame)->co_kwonlyargcount + !!(F_CODE(frame)->co_flags & CO_VARARGS) + !!(F_CODE(frame)->co_flags & CO_VARKEYWORDS);
416 
417   for (Py_ssize_t i = 0; i < total_argcount_old; i++) {
418     Py_XINCREF(fastlocals_old[i]);
419     fastlocals_new[i] = fastlocals_old[i];
420   }
421 
422   // copy free vars
423   Py_ssize_t nfrees_old = PyCode_GetNFreevars(F_CODE(frame));
424 
425   for (Py_ssize_t i = 0; i < nfrees_old; i++) {
426     Py_XINCREF(fastlocals_old[n_old - 1 - i]);
427     fastlocals_new[n_new - 1 - i] = fastlocals_old[n_old - 1 - i];
428   }
429 
430   // copy cell vars, from high index to low index, until it meets a variable that is not cell variable.
431   for (Py_ssize_t i = n_old - nfrees_old - 1, j = n_new - nfrees_old - 1; i >= total_argcount_old; i--, j--) {
432 
433   // conditional test to tell if a variable is not a cell variable
434   // this is straightforward in Python 3.11 and higher, as there are bit flags in `co_localspluskinds` to tell if a variable is a cell variable.
435   // in Python 3.10 and lower, essentially we are checking if a variable is a new local variable (because of the layout mentioned above, the first variable that is not cell variable is the first new local variable). the corresponding slot in `flocalsplus` is NULL for new local variables.
436 #if IS_PYTHON_3_11_PLUS
437     if(!(_PyLocals_GetKind(F_CODE(frame)->co_localspluskinds, i) & CO_FAST_CELL))
438     {
439       break;
440     }
441 #else
442     if(fastlocals_old[i] == NULL)
443     {
444       break;
445     }
446 #endif
447 
448     Py_XINCREF(fastlocals_old[i]);
449     fastlocals_new[j] = fastlocals_old[i];
450   }
451 
452   // NOTE: if you want to evaluate frame instead of shadow in 3.12+,
453   // you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
454   // calling eval_frame_default (i.e. here) and comment out the
455   // clear_old_frame_if_python_312_plus call on the original frame.
456 
457   PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
458 
459 #if IS_PYTHON_3_12_PLUS
460 
461   // frame is cleared by caller
462   Py_DECREF(func);
463 
464 #elif IS_PYTHON_3_11_PLUS
465 
466   // In 3.11, shadow has is_entry set to true, so _PyEvalFrameClearAndPop is not called,
467   // so we manually clear and pop the shadow frame.
468   THP_PyFrame_Clear(shadow);
469   THP_PyThreadState_PopFrame(tstate, shadow);
470   Py_DECREF(func);
471 
472 #else
473 
474   Py_DECREF(shadow);
475 
476 #endif
477 
478   return result;
479 }
480 
481 // This wrapper function adds a profiler event
eval_custom_code(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,PyCodeObject * code,int throw_flag,int free_vars_copied)482 inline static PyObject* eval_custom_code(
483     PyThreadState* tstate,
484     THP_EVAL_API_FRAME_OBJECT* frame,
485     PyCodeObject* code,
486     int throw_flag,
487     int free_vars_copied) {
488   const char* trace_id = compile_context;
489   _PytorchRecordFunctionState* rf = _pytorch_record_function_enter_with_context("Torch-Compiled Region", trace_id);
490   PyObject* result = eval_custom_code_impl(
491     tstate,
492     frame,
493     code,
494     throw_flag,
495     free_vars_copied
496   );
497   _pytorch_record_function_exit(rf);
498   return result;
499 }
500 
_custom_eval_frame_shim(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag)501 static PyObject* _custom_eval_frame_shim(
502     PyThreadState* tstate,
503     THP_EVAL_API_FRAME_OBJECT* frame,
504     int throw_flag) {
505   // Shims logic into one of three states. Can probably be refactored into a
506   // single func, later:
507   //  - None: disables TorchDynamo
508   //  - False: run-only mode (reuse existing compiles)
509   //  - Python callable(): enables TorchDynamo
510   PyObject* callback = eval_frame_callback_get();
511 
512   if (callback == Py_None) {
513     return eval_frame_default(tstate, frame, throw_flag);
514   }
515 
516   int should_clear_frame = 0;
517   PyObject* result = _custom_eval_frame(tstate, frame, throw_flag, callback, &should_clear_frame);
518   if (should_clear_frame) {
519     clear_old_frame_if_python_312_plus(tstate, frame);
520   }
521   return result;
522 }
523 
524 static PyObject* skip_code_recursive_flag;
525 
526 // NOTE: In 3.12+, the frame evaluation function (callee) is responsible for clearing/popping
527 // the frame, meaning that unless we default evaluate the original frame,
528 // we are responsible for clearing it - via clear_old_frame_if_python_312_plus.
529 // The should_clear_frame flag is used to indicate whether the frame should be
530 // cleared by _custom_eval_frame's caller.
531 // Generally should_clear_frame should be set if and only we don't eval_frame_default.
_custom_eval_frame(PyThreadState * tstate,THP_EVAL_API_FRAME_OBJECT * frame,int throw_flag,PyObject * callback,int * should_clear_frame)532 static PyObject* _custom_eval_frame(
533     PyThreadState* tstate,
534     THP_EVAL_API_FRAME_OBJECT* frame,
535     int throw_flag,
536     PyObject* callback,
537     int* should_clear_frame) {
538 #if IS_PYTHON_3_11_PLUS
539   DEBUG_TRACE(
540       "begin %s %s %i %i",
541       get_frame_name(frame),
542       PyUnicode_AsUTF8(F_CODE(frame)->co_filename),
543       F_CODE(frame)->co_firstlineno,
544       _PyInterpreterFrame_LASTI(frame));
545 #else
546   DEBUG_TRACE(
547       "begin %s %s %i %i %i",
548       get_frame_name(frame),
549       PyUnicode_AsUTF8(F_CODE(frame)->co_filename),
550       frame->f_lineno,
551       frame->f_lasti,
552       frame->f_iblock);
553 #endif
554 
555   if (throw_flag) {
556     // When unwinding generators, eval frame is called with throw_flag ==
557     // true.  Frame evaluation is supposed to continue unwinding by propagating
558     // the exception.  Dynamo doesn't really know how to do this, nor does it
559     // really want to do this, because there's unlikely any code to capture
560     // (you're going to immediately quit out of the frame, perhaps running
561     // some unwinding logic along the way).  So we just run the default
562     // handler in this case.
563     //
564     // NB: A previous version of this patch returned NULL.  This is wrong,
565     // because returning NULL is *different* from unwinding an exception.
566     // In particular, you will not execute things like context manager
567     // __exit__ if you just return NULL.
568     //
569     // NB: It's /conceivable/ that you might want to actually still call the
570     // Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
571     // do any stack unwinding code.  But this is not really useful because
572     // (1) Dynamo doesn't actually know how to do stack unwinding, so it would
573     // immediately skip the frame, and (2) even if it did, this would only
574     // be profitable if there was tensor code in the unwinding code.  Seems
575     // unlikely.
576     DEBUG_TRACE("throw %s", get_frame_name(frame));
577     return eval_frame_default(tstate, frame, throw_flag);
578   }
579 
580   ExtraState* extra = get_extra_state(F_CODE(frame));
581   if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
582     DEBUG_TRACE("skip %s", get_frame_name(frame));
583     return eval_frame_default(tstate, frame, throw_flag);
584   }
585   if (extra == SKIP_CODE_RECURSIVE) {
586     DEBUG_TRACE("skip recursive %s", get_frame_name(frame));
587     eval_frame_callback_set(Py_None);
588     PyObject* result = eval_frame_default(tstate, frame, throw_flag);
589     eval_frame_callback_set(callback);
590     return result;
591   }
592 
593   if (extra == NULL) {
594     extra = init_and_set_extra_state(F_CODE(frame));
595   }
596 
597 
598   int free_vars_copied = 0;
599   #if IS_PYTHON_3_12_PLUS
600   PyObject *locals = get_framelocals_mapping(frame);
601   #else
602   if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) {
603     DEBUG_TRACE("error %s", get_frame_name(frame));
604     *should_clear_frame = 1;
605     return NULL;
606   }
607   PyObject *locals = frame->f_locals;
608   Py_INCREF(locals);
609   #endif
610 
611   PyObject* backend = get_backend(callback);
612 
613   // A callback of Py_False indicates "run only" mode, the cache is checked, but
614   // we never compile.
615   if (callback == Py_False) {
616     DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
617     _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
618     PyObject* maybe_cached_code = lookup(extra, locals, backend);
619     _pytorch_record_function_exit(rf);
620 
621     Py_DECREF(locals);
622 
623     if (maybe_cached_code == NULL) {
624       // guard eval failed, keep propagating
625       *should_clear_frame = 1;
626       return NULL;
627     } else if (maybe_cached_code == Py_None) {
628       DEBUG_TRACE("cache miss %s", get_frame_name(frame));
629       return eval_frame_default(tstate, frame, throw_flag);
630     }
631     PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
632     // used cached version
633     DEBUG_TRACE("cache hit %s", get_frame_name(frame));
634     *should_clear_frame = 1;
635     return eval_custom_code(tstate, frame, cached_code, throw_flag, 0);
636   }
637   DEBUG_CHECK(PyDict_CheckExact(locals));
638   DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
639   DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));
640 
641   // We don't run the current custom_eval_frame behavior for guards.
642   // So we temporarily set the callback to Py_None to drive the correct behavior
643   // in the shim.
644   eval_frame_callback_set(Py_None);
645 
646   _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
647   PyObject* maybe_cached_code = lookup(extra, locals, backend);
648   _pytorch_record_function_exit(rf);
649   if (maybe_cached_code == NULL) {
650     // Python error
651     *should_clear_frame = 1;
652     Py_DECREF(locals);
653     return NULL;
654   } else if (maybe_cached_code != Py_None) {
655     PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
656     // used cached version
657     DEBUG_TRACE("cache hit %s", get_frame_name(frame));
658     // Re-enable custom behavior
659     eval_frame_callback_set(callback);
660     *should_clear_frame = 1;
661     Py_DECREF(locals);
662     return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied);
663   }
664   // cache miss
665   CacheEntry* cache_entry = extract_cache_entry(extra);
666   FrameState* frame_state = extract_frame_state(extra);
667   PyObject* result =
668       call_callback(callback, frame, locals, cache_entry, frame_state);
669   Py_DECREF(locals);
670   if (result == NULL) {
671     // internal exception, returning here will leak the exception into user code
672     // this is useful for debugging -- but we dont want it to happen outside of
673     // testing
674     // NB: we intentionally DO NOT re-enable custom behavior to prevent
675     // cascading failure from internal exceptions.  The upshot is if
676     // Dynamo barfs, that's it for Dynamo, even if you catch the exception
677     // inside the torch.compile block we won't try to Dynamo anything else.
678     *should_clear_frame = 1;
679     return NULL;
680   } else if (result == skip_code_recursive_flag) {
681     // Dynamo returned skip_code_recursive_flag, so we should recursively skip code.
682     DEBUG_TRACE("create skip recursive %s", get_frame_name(frame));
683     set_extra_state(F_CODE(frame), SKIP_CODE_RECURSIVE);
684     PyObject* r = eval_frame_default(tstate, frame, throw_flag);
685     // Re-enable custom behavior
686     eval_frame_callback_set(callback);
687     return r;
688   } else if (result != Py_None) {
689     DEBUG_TRACE("create cache %s", get_frame_name(frame));
690 
691     // NB: We could use extract_cache_entry to get the cache_entry, but
692     // extract_cache_entry returns a borrowed reference. Modifying a borrowed
693     // reference seems wrong. Therefore, we directly access the
694     // extra->cache_entry. extra wont be NULL here.
695     CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
696     Py_DECREF(result);
697 
698     // Update the existing cache_entry on the extra object. This extra object is
699     // sitting on the extra scratch space, we are just changing the cache_entry
700     // ptr. As a result, extra now becomes the owner of CacheEntry object. This
701     // will be cleaned up when set_extra_state is called.
702     // Re-enable custom behavior
703     eval_frame_callback_set(callback);
704     *should_clear_frame = 1;
705     return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag, free_vars_copied);
706   } else {
707     DEBUG_TRACE("create skip %s", get_frame_name(frame));
708     Py_DECREF(result);
709     set_extra_state(F_CODE(frame), SKIP_CODE);
710     // Re-enable custom behavior
711     eval_frame_callback_set(callback);
712     return eval_frame_default(tstate, frame, throw_flag);
713   }
714 }
715 
716 #else // IS_PYTHON_3_14_PLUS
717 
718 // Fake definitions for everything we removed
719 
720 typedef struct THPPyInterpreterFrame {
721   PyObject_HEAD
722   _PyInterpreterFrame* frame; // Borrowed reference
723 } THPPyInterpreterFrame;
724 
enable_eval_frame_shim(PyThreadState * tstate)725 inline static void enable_eval_frame_shim(PyThreadState* tstate) {}
enable_eval_frame_default(PyThreadState * tstate)726 inline static void enable_eval_frame_default(PyThreadState* tstate) {}
727 
728 static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL};
729 
730 static PyTypeObject THPPyInterpreterFrameType = {
731     PyVarObject_HEAD_INIT(NULL, 0)
732     .tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame",
733     .tp_basicsize = sizeof(THPPyInterpreterFrame),
734     .tp_flags = Py_TPFLAGS_DEFAULT,
735     .tp_getset = THPPyInterpreterFrame_properties,
736 };
737 
738 #endif // CPython 3.14
739 
increment_working_threads(PyThreadState * tstate)740 static PyObject* increment_working_threads(PyThreadState* tstate) {
741   active_dynamo_threads = active_dynamo_threads + 1;
742   if (active_dynamo_threads > 0) {
743     enable_eval_frame_shim(tstate);
744   }
745   Py_RETURN_NONE;
746 }
747 
decrement_working_threads(PyThreadState * tstate)748 static PyObject* decrement_working_threads(PyThreadState* tstate) {
749   if (active_dynamo_threads > 0) {
750     active_dynamo_threads = active_dynamo_threads - 1;
751     if (active_dynamo_threads == 0) {
752       enable_eval_frame_default(tstate);
753     }
754   }
755   Py_RETURN_NONE;
756 }
757 
// Swaps the thread-local eval frame callback and hands back the previous one.
// Meaning of the callback value:
//  - None: disables TorchDynamo
//  - False: run-only mode (reuse existing compiles)
//  - Python callable(): enables TorchDynamo
//
// The working-thread counter is adjusted only when the callback moves between
// None and non-None.
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate) {
  PyObject* old_callback = eval_frame_callback_get();

  // The returned value is owned by the caller.
  Py_INCREF(old_callback);

  const bool was_enabled = (old_callback != Py_None);
  const bool will_be_enabled = (new_callback != Py_None);
  // NOTE(review): both helpers return a new reference to None that is not
  // used here.
  if (was_enabled && !will_be_enabled) {
    decrement_working_threads(tstate);
  } else if (!was_enabled && will_be_enabled) {
    increment_working_threads(tstate);
  }

  // Take a reference on the incoming callback before storing it, and drop one
  // on the outgoing callback (presumably the one the thread-local slot held --
  // confirm against eval_frame_callback_get()).
  Py_INCREF(new_callback);
  Py_DECREF(old_callback);

  // Store the thread-local callback. This drives the behavior of our shim,
  // if/when it is installed.
  eval_frame_callback_set(new_callback);

  return old_callback;
}
783 
// Python entry point for set_eval_frame (registered as METH_O).
//
// Accepts None, False, or any callable; anything else raises TypeError and
// returns NULL. Otherwise forwards to set_eval_frame() with the current
// thread state and returns its result (the previous callback).
static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* callback) {
  const bool is_none = (callback == Py_None);
  const bool is_run_only = (callback == Py_False);

  if (!is_none && !is_run_only && !PyCallable_Check(callback)) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a callable");
    return NULL;
  }

  DEBUG_TRACE(
      "python enabled=%d and is run_only=%d",
      !is_none,
      is_run_only);
  return set_eval_frame(callback, PyThreadState_GET());
}
797 
// Python entry point: resets the extra state attached to a code object.
//
// `code` must be a code object; otherwise TypeError is raised and NULL is
// returned. Returns None on success.
static PyObject* reset_code(PyObject* dummy, PyObject* code) {
  if (PyCode_Check(code)) {
    // set_extra_state destroys the existing object on extra scratch space.
    set_extra_state((PyCodeObject*)code, NULL);
    Py_RETURN_NONE;
  }

  DEBUG_TRACE0("arg error");
  PyErr_SetString(PyExc_TypeError, "expected a code object");
  return NULL;
}
809 
// A dummy C function used in testing: takes two positional objects and
// returns a new reference to the second one. Returns NULL (with the error set
// by PyArg_ParseTuple) if the arguments do not match.
static PyObject* unsupported(PyObject* dummy, PyObject* args) {
  PyObject* first = NULL;
  PyObject* second = NULL;
  if (PyArg_ParseTuple(args, "OO", &first, &second)) {
    Py_INCREF(second);
    return second;
  }
  return NULL;
}
820 
// Python entry point: stores SKIP_CODE as the extra state of a code object.
//
// `obj` must be a code object; otherwise TypeError is raised and NULL is
// returned. Returns None on success.
static PyObject* skip_code(PyObject* dummy, PyObject* obj) {
  if (PyCode_Check(obj)) {
    // set_extra_state destroys the existing object on extra scratch space.
    set_extra_state((PyCodeObject*)obj, SKIP_CODE);
    Py_RETURN_NONE;
  }

  PyErr_SetString(PyExc_TypeError, "expected a code object");
  return NULL;
}
831 
// Python entry point: replaces the global guard_error_hook.
//
// Passing None clears the hook (the global becomes NULL); any other object is
// stored with a new reference. Always returns None.
static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
  PyObject* replacement = (obj == Py_None) ? NULL : obj;
  Py_XSETREF(guard_error_hook, Py_XNewRef(replacement));
  Py_RETURN_NONE;
}
839 
// Python entry point: records the current compile context as text in the
// file-level `compile_context` buffer.
//
// `obj` is the single METH_O argument and must be a tuple of three integers
// (frame_id, frame_compile_id, attempt). The stored text is
// "<frame_id>/<frame_compile_id>" when attempt is 0, and
// "<frame_id>/<frame_compile_id>_<attempt>" otherwise.
//
// Returns None on success; on a parse failure raises TypeError and returns
// NULL.
static PyObject* set_context_frame(PyObject* dummy, PyObject* obj) {
  int frame_id = 0;
  int frame_compile_id = 0;
  int attempt = 0;
  if (!PyArg_ParseTuple(obj, "iii", &frame_id, &frame_compile_id, &attempt)) {
    PyErr_SetString(PyExc_TypeError, "Expected three integers");
    return NULL;
  }

  // Bound the write by the buffer's actual size so the formatted text can
  // never run past compile_context, whatever the buffer size is changed to.
  if (attempt == 0) {
    snprintf(
        compile_context,
        sizeof(compile_context),
        "%d/%d",
        frame_id,
        frame_compile_id);
  } else {
    snprintf(
        compile_context,
        sizeof(compile_context),
        "%d/%d_%d",
        frame_id,
        frame_compile_id,
        attempt);
  }
  Py_RETURN_NONE;
}
853 
854 static PyMethodDef _methods[] = {
855     {"set_eval_frame", set_eval_frame_py, METH_O, NULL},
856     {"reset_code", reset_code, METH_O, NULL},
857     {"unsupported", unsupported, METH_VARARGS, NULL},
858     {"skip_code", skip_code, METH_O, NULL},
859     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
860     {"set_context_frame", set_context_frame, METH_O, NULL},
861     {NULL, NULL, 0, NULL}};
862 
863 static struct PyModuleDef _module = {
864     PyModuleDef_HEAD_INIT,
865     "torch._C._dynamo.eval_frame",
866     "Module containing hooks to override eval_frame",
867     -1,
868     _methods};
869 
// On Python 3.12+, route the older _PyEval_RequestCodeExtraIndex spelling to
// PyUnstable_Eval_RequestCodeExtraIndex so that code written against the old
// name keeps compiling.
#if IS_PYTHON_3_12_PLUS
#define _PyEval_RequestCodeExtraIndex PyUnstable_Eval_RequestCodeExtraIndex
#endif
873 
torch_c_dynamo_eval_frame_init(void)874 PyObject* torch_c_dynamo_eval_frame_init(void) {
875   extra_index = _PyEval_RequestCodeExtraIndex(destroy_extra_state);
876   if (extra_index < 0) {
877     PyErr_SetString(PyExc_RuntimeError,
878                     "dynamo: unable to register extra index");
879     return NULL;
880   }
881 
882   int result = PyThread_tss_create(&eval_frame_callback_key);
883   CHECK(result == 0);
884 
885   Py_INCREF(Py_None);
886   eval_frame_callback_set(Py_None);
887 
888   PyObject* module = PyModule_Create(&_module);
889   if (module == NULL) {
890     return NULL;
891   }
892 
893 #if IS_PYTHON_3_11_PLUS
894   if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
895     return NULL;
896   }
897   Py_INCREF(&THPPyInterpreterFrameType);
898   if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
899     return NULL;
900   }
901 #endif
902 
903   skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type);
904   if (skip_code_recursive_flag == NULL) {
905     return NULL;
906   }
907   if (PyModule_AddObject(module, "skip_code_recursive_flag", skip_code_recursive_flag) != 0) {
908     return NULL;
909   }
910 
911   return module;
912 }
913