xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/guards.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/PythonTorchFunctionTLS.h>
2 #include <c10/core/SafePyObject.h>
3 #include <c10/core/impl/PyInterpreter.h>
4 #define PY_SSIZE_T_CLEAN
5 #include <ATen/EmptyTensor.h>
6 #include <ATen/SparseCsrTensorUtils.h>
7 #include <c10/util/flat_hash_map.h>
8 #include <torch/csrc/autograd/grad_mode.h>
9 #include <torch/csrc/autograd/utils/wrap_outputs.h>
10 #include <torch/csrc/dynamo/guards.h>
11 #include <torch/csrc/inductor/inductor_ops.h>
12 #include <torch/csrc/utils/disable_torch_function.h>
13 #include <torch/csrc/utils/python_arg_parser.h>
14 #include <torch/csrc/utils/python_compat.h>
15 #include <torch/csrc/utils/python_numbers.h>
16 #include <torch/csrc/utils/python_symnode.h>
17 #include <torch/csrc/utils/pythoncapi_compat.h>
18 #include <torch/extension.h>
19 
20 #ifdef USE_CUDA
21 #include <ATen/cuda/EmptyTensor.h>
22 #endif
23 
24 #ifdef USE_XPU
25 #include <ATen/xpu/EmptyTensor.h>
26 #endif
27 
28 #include <sstream>
29 #include <utility>
30 
// For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
// underlying tuple and access the item. Before Python 3.12, the tuple
// iterator struct is defined privately in tupleobject.c -
// https://github.com/python/cpython/blob/9afc6d102d16080535325f645849cd84eb04d57d/Objects/tupleobject.c#L1058-L1062
// To handle this, we manually copy the struct definition here and cast the
// iterator to it. From Python 3.12 onwards, the struct is provided by a header
// file.
37 #if IS_PYTHON_3_12_PLUS
38 
39 #define Py_BUILD_CORE
40 // Bring _PyTupleIterObject from the header file
41 #include <internal/pycore_tuple.h>
42 #undef Py_BUILD_CORE
43 
44 #else
45 
// Manually create _PyTupleIterObject struct
// This definition is only compiled when IS_PYTHON_3_12_PLUS is false. It is a
// hand-written copy of the struct referenced in the comment above, so the
// field order must not be changed.
typedef struct {
  // PyObject_HEAD expands to the object header declaration; `it_index`
  // follows it on the same line only because of automatic formatting.
  PyObject_HEAD Py_ssize_t it_index;
  PyTupleObject* it_seq; /* Set to NULL when iterator is exhausted */
} _PyTupleIterObject;
51 
52 #endif // IS_PYTHON_3_12_PLUS
53 
54 namespace torch::dynamo {
55 
// Macro to skip addition of duplicate guards like EQUALS_MATCH.
// Expands to statements that return early from the enclosing function when
// `self.is_leaf_guard_present(name)` is true, and otherwise call
// `self.insert_leaf_guard(name)`. A variable named `self` must be in scope at
// the expansion site.
// NOTE(review): the expansion is not wrapped in do { } while (0), so it should
// not be used as the sole statement of an unbraced if/else.
#define SKIP_IF_GUARD_ALREADY_PRESENT(name) \
  if (self.is_leaf_guard_present(name)) {   \
    return;                                 \
  }                                         \
  self.insert_leaf_guard(name);
62 
TensorCheck(const LocalState & state,PyTypeObject * pt,const at::Tensor & v,std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)63 TensorCheck::TensorCheck(
64     const LocalState& state,
65     PyTypeObject* pt,
66     const at::Tensor& v,
67     std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
68     std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
69     : pytype(pt),
70       dispatch_key_(state.apply(v.key_set()).raw_repr()),
71       dtype_(v.dtype().toScalarType()),
72       device_index_(v.device().index()),
73       requires_grad_(v.requires_grad()),
74       sizes_(std::move(dynamic_dims_sizes)),
75       strides_(std::move(dynamic_dims_strides)),
76       dim_(static_cast<int64_t>(sizes_.size())) {
77   // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
78   // we just treat this as optional?
79 }
80 
// Builds a check from explicitly supplied tensor properties.
//   state                - applied to `dispatch_key_set` before its raw
//                          representation is stored.
//   pt                   - Python type recorded in `pytype`.
//   dtype, device_index,
//   requires_grad        - stored verbatim for later comparison.
//   dynamic_dims_sizes,
//   dynamic_dims_strides - expected per-dimension sizes/strides; the vectors
//                          are moved into the object.
TensorCheck::TensorCheck(
    const LocalState& state,
    PyTypeObject* pt,
    c10::DispatchKeySet dispatch_key_set,
    at::ScalarType dtype,
    at::DeviceIndex device_index,
    bool requires_grad,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
    : pytype(pt),
      dispatch_key_(state.apply(dispatch_key_set).raw_repr()),
      dtype_(dtype),
      device_index_(device_index),
      requires_grad_(requires_grad),
      sizes_(std::move(dynamic_dims_sizes)),
      strides_(std::move(dynamic_dims_strides)),
      // The expected rank is taken from the (already moved-into) sizes_
      // member, not from the moved-from parameter.
      dim_(static_cast<int64_t>(sizes_.size())) {}
98 
99 // See note in guards.py [Note - On Export Tensor Guards]
100 // Logic parallel to here must be maintained in python
check(const LocalState & state,const at::Tensor & v)101 bool TensorCheck::check(const LocalState& state, const at::Tensor& v) {
102   // In terms of a sparse_csr tensor, it does not support strides informatio
103   c10::SymIntArrayRef sym_strides(std::vector<SymInt>(v.ndimension(), -1));
104   bool does_not_support_stride = v.layout() == c10::kSparseCsr ||
105       v.layout() == c10::kSparseCsc || v.layout() == c10::kSparseBsc ||
106       v.layout() == c10::kSparseBsr;
107   if (!does_not_support_stride) {
108     sym_strides = v.sym_strides();
109   }
110 
111   return check(
112       state,
113       v.key_set(),
114       v.dtype().toScalarType(),
115       v.device(),
116       v.sym_sizes(),
117       sym_strides,
118       v.requires_grad());
119 }
120 
check(const LocalState & state,const c10::DispatchKeySet & dispatch_key_set,const at::ScalarType & dtype,const c10::Device & device,const c10::SymIntArrayRef & sym_sizes,const c10::SymIntArrayRef & sym_strides,const bool & requires_grad)121 bool TensorCheck::check(
122     const LocalState& state,
123     const c10::DispatchKeySet& dispatch_key_set,
124     const at::ScalarType& dtype,
125     const c10::Device& device,
126     const c10::SymIntArrayRef& sym_sizes,
127     const c10::SymIntArrayRef& sym_strides,
128     const bool& requires_grad) {
129   if (dispatch_key_ != state.apply(dispatch_key_set).raw_repr() ||
130       dtype_ != dtype || device_index_ != device.index() ||
131       requires_grad_ != requires_grad) {
132     return false;
133   }
134 
135   auto ndim = sym_sizes.size();
136   if (ndim != static_cast<size_t>(dim_)) {
137     return false;
138   }
139 
140   const auto& sizes = sym_sizes;
141   const auto& strides = sym_strides;
142   for (auto i : c10::irange(ndim)) {
143     auto known_size = sizes_[i];
144     auto known_stride = strides_[i];
145     if (known_size.has_value()) {
146       if (known_size.value() != sizes[i]) {
147         return false;
148       }
149     }
150     if (known_stride.has_value()) {
151       if (known_stride.value() != strides[i]) {
152         return false;
153       }
154     }
155   }
156   return true;
157 }
158 
check_verbose(const LocalState & state,const at::Tensor & v,const std::string & tensor_name)159 std::string TensorCheck::check_verbose(
160     const LocalState& state,
161     const at::Tensor& v,
162     const std::string& tensor_name) {
163   std::stringstream fail_reason;
164   fail_reason << "tensor '" << tensor_name << "' ";
165   if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
166     // return fmt::format("tensor dispatch key mismatch. expected {}, actual
167     // {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
168     fail_reason << "dispatch key set mismatch. expected "
169                 << c10::DispatchKeySet(c10::DispatchKeySet::RAW, dispatch_key_)
170                 << ", actual " << state.apply(v.key_set());
171     return fail_reason.str();
172   } else if (dtype_ != v.dtype().toScalarType()) {
173     // return fmt::format("tensor dtype mismatch. expected {}, actual {}",
174     // dtype_, v.dtype().toScalarType());
175     fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
176                 << v.dtype().toScalarType();
177     return fail_reason.str();
178   } else if (device_index_ != v.device().index()) {
179     fail_reason << "Tensor device index mismatch. Expected device index to be "
180                 << device_index_ << ", actual " << v.device().index();
181     return fail_reason.str();
182   } else if (requires_grad_ != v.requires_grad()) {
183     // return fmt::format("tensor requires_grad mismatch. expected {}",
184     // requires_grad_);
185     fail_reason << "requires_grad mismatch. expected requires_grad="
186                 << requires_grad_;
187     return fail_reason.str();
188   }
189   auto ndim = v.ndimension();
190   if (ndim != dim_) {
191     // return fmt::format("tensor rank mismatch. expected {}, actual {}",
192     // sizes_.size(), ndim);
193     fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
194                 << ndim;
195     return fail_reason.str();
196   }
197   const auto& sizes = v.sym_sizes();
198   for (auto i : c10::irange(ndim)) {
199     auto known_size = sizes_[i];
200     if (known_size.has_value() && (known_size.value() != sizes[i])) {
201       fail_reason << "size mismatch at index " << i << ". expected "
202                   << known_size.value() << ", actual " << sizes[i];
203       return fail_reason.str();
204     }
205   }
206   const bool supports_stride =
207       !v.is_sparse() && !at::sparse_csr::is_sparse_compressed(v);
208   if (supports_stride) {
209     const auto& strides = v.sym_strides();
210     for (auto i : c10::irange(ndim)) {
211       auto known_stride = strides_[i];
212       if (known_stride.has_value() && known_stride.value() != strides[i]) {
213         fail_reason << "stride mismatch at index " << i << ". expected "
214                     << known_stride.value() << ", actual " << strides[i];
215         return fail_reason.str();
216       }
217     }
218   }
219   return "";
220 }
221 
222 namespace {
223 
// One TensorCheck per guarded tensor, in positional order.
typedef std::vector<TensorCheck> ChecksList;

// Python object wrapping a heap-allocated list of tensor checks. The list is
// allocated in TensorGuards_new and released in TensorGuards_dealloc.
typedef struct {
  PyObject_HEAD;
  ChecksList* checks;
} TensorGuards;
230 
TensorGuards_dealloc(TensorGuards * self)231 static void TensorGuards_dealloc(TensorGuards* self) {
232   if (self->checks != nullptr) {
233     delete self->checks;
234     self->checks = nullptr;
235   }
236   Py_TYPE(self)->tp_free((PyObject*)self);
237 }
238 
TensorGuards_new(PyTypeObject * type,PyObject * args,PyObject * kwds)239 static PyObject* TensorGuards_new(
240     PyTypeObject* type,
241     PyObject* args,
242     PyObject* kwds) {
243   TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
244   if (self != nullptr) {
245     self->checks = new ChecksList();
246   }
247   return (PyObject*)self;
248 }
249 
wrapIntegersInOptional(const c10::SymIntArrayRef & intArray)250 static std::vector<std::optional<c10::SymInt>> wrapIntegersInOptional(
251     const c10::SymIntArrayRef& intArray) {
252   std::vector<std::optional<c10::SymInt>> optVec(intArray.size());
253   std::transform(
254       intArray.begin(),
255       intArray.end(),
256       optVec.begin(),
257       [](const c10::SymInt& value) { return std::make_optional(value); });
258   return optVec;
259 }
260 
pyListToVecOptInt(PyObject * pyList)261 static std::vector<std::optional<c10::SymInt>> pyListToVecOptInt(
262     PyObject* pyList) {
263   std::vector<std::optional<c10::SymInt>> vec;
264   Py_ssize_t size = PyList_Size(pyList);
265   for (Py_ssize_t i = 0; i < size; i++) {
266     PyObject* item = PyList_GetItem(pyList, i);
267     auto handle = py::handle(item);
268     if (item == Py_None) {
269       vec.emplace_back(std::nullopt);
270     } else if (torch::is_symint(handle)) {
271       vec.emplace_back(py::cast<c10::SymInt>(handle));
272     } else {
273       int64_t value = PyLong_AsLongLong(item);
274       if (value == -1 && PyErr_Occurred()) {
275         PyErr_SetString(
276             PyExc_TypeError,
277             "Size or stride list item is not a valid integer.");
278         TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
279       }
280       vec.emplace_back(c10::SymInt(value));
281     }
282   }
283   return vec;
284 }
285 
get_dynamic_dims(PyObject * dynamic_dims_py)286 static std::vector<std::vector<std::optional<c10::SymInt>>> get_dynamic_dims(
287     PyObject* dynamic_dims_py) {
288   std::vector<std::vector<std::optional<c10::SymInt>>> per_tensor_dynamic_dims;
289   if (dynamic_dims_py != Py_None) {
290     Py_ssize_t size = PyList_Size(dynamic_dims_py);
291     for (Py_ssize_t i = 0; i < size; i++) {
292       PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
293       std::vector<std::optional<c10::SymInt>> vec = pyListToVecOptInt(py_list);
294       per_tensor_dynamic_dims.push_back(std::move(vec));
295     }
296   }
297   return per_tensor_dynamic_dims;
298 }
299 
TensorGuards_init(TensorGuards * self,PyObject * args,PyObject * kwds)300 static int TensorGuards_init(
301     TensorGuards* self,
302     PyObject* args,
303     PyObject* kwds) {
304   if (!PyTuple_CheckExact(args)) {
305     PyErr_SetString(PyExc_TypeError, "expected tuple()");
306     return -1;
307   }
308   // Top level structure is List[List[Union[int, None]]]
309   PyObject* dynamic_dims_sizes_py =
310       PyDict_GetItemString(kwds, "dynamic_dims_sizes");
311   if (dynamic_dims_sizes_py == nullptr) {
312     PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
313     return -1;
314   }
315   PyObject* dynamic_dims_strides_py =
316       PyDict_GetItemString(kwds, "dynamic_dims_strides");
317   if (dynamic_dims_strides_py == nullptr) {
318     PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
319     return -1;
320   }
321 
322   // dynamic_dims_strides/sizes_py is None when dynamic_shapes=False - this is
323   // an optimization to avoid invoking .size()/.stride() in python needlessly
324   std::vector<std::vector<std::optional<c10::SymInt>>>
325       per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
326   std::vector<std::vector<std::optional<c10::SymInt>>>
327       per_tensor_dynamic_dims_strides =
328           get_dynamic_dims(dynamic_dims_strides_py);
329 
330   auto& checks = *self->checks;
331   auto len = PyTuple_GET_SIZE(args);
332   checks.reserve(len);
333   LocalState state;
334 
335   for (auto i : c10::irange(len)) {
336     PyObject* item = PyTuple_GET_ITEM(args, i);
337     if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
338       PyErr_SetString(PyExc_TypeError, "expected Tensor()");
339       return -1;
340     }
341     auto tensor = THPVariable_Unpack(item);
342     std::vector<std::optional<c10::SymInt>> tensor_dims_size =
343         per_tensor_dynamic_dims_sizes.empty()
344         ? wrapIntegersInOptional(tensor.sym_sizes())
345         : per_tensor_dynamic_dims_sizes[i];
346     std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
347         per_tensor_dynamic_dims_strides.empty()
348         ? wrapIntegersInOptional(tensor.sym_strides())
349         : per_tensor_dynamic_dims_strides[i];
350 
351     checks.emplace_back(
352         state,
353         Py_TYPE(item),
354         std::move(tensor),
355         std::move(tensor_dims_size),
356         std::move(tensor_dims_stride));
357   }
358   return 0;
359 }
360 
TensorGuards_check(TensorGuards * self,PyObject * args,PyObject * kwargs)361 PyObject* TensorGuards_check(
362     TensorGuards* self,
363     PyObject* args,
364     PyObject* kwargs) {
365   if (!PyTuple_CheckExact(args)) {
366     PyErr_SetString(PyExc_TypeError, "expected tuple()");
367     return nullptr;
368   }
369   auto& checks = *self->checks;
370   auto len = PyTuple_GET_SIZE(args);
371 
372   // kwargs is just ignored here
373 
374   if (static_cast<decltype(len)>(checks.size()) != len) {
375     PyErr_SetString(PyExc_TypeError, "wrong length");
376     return nullptr;
377   }
378 
379   LocalState state;
380   // Note - all the tensors that make it to guards must be unique. Dynamo
381   // builder handles guarding for positive aliases (X is Y). However, we do not
382   // create guards for negative alias (X is not Y) as that is an N^2
383   // relationship. Instead, we rely on the uniqueness upstream to verify, at
384   // check_fn time (this function).
385   ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
386   for (auto i : c10::irange(len)) {
387     PyObject* item = PyTuple_GET_ITEM(args, i);
388 
389     if (Py_TYPE(item) != checks[i].pytype) {
390       Py_RETURN_FALSE;
391     }
392     auto insertion = unique_tensors.insert({item, nullptr});
393     if (!insertion.second) {
394       // Violates uniqueness
395       Py_RETURN_FALSE;
396     }
397     if (!checks[i].check(state, THPVariable_Unpack(item))) {
398       Py_RETURN_FALSE;
399     }
400   }
401 
402   Py_RETURN_TRUE;
403 }
404 
TensorGuards_check_verbose(TensorGuards * self,PyObject * args,PyObject * kwargs)405 PyObject* TensorGuards_check_verbose(
406     TensorGuards* self,
407     PyObject* args,
408     PyObject* kwargs) {
409   if (!PyTuple_CheckExact(args)) {
410     PyErr_SetString(PyExc_TypeError, "expected tuple()");
411     return nullptr;
412   }
413   auto& checks = *self->checks;
414   auto len = PyTuple_GET_SIZE(args);
415 
416   if (static_cast<decltype(len)>(checks.size()) != len) {
417     PyErr_SetString(PyExc_TypeError, "wrong length");
418     return nullptr;
419   }
420 
421   PyObject* tensor_check_names_py =
422       PyDict_GetItemString(kwargs, "tensor_check_names");
423   if (tensor_check_names_py == nullptr) {
424     PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
425     return nullptr;
426   }
427 
428   if (!PyList_Check(tensor_check_names_py)) {
429     PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
430     return nullptr;
431   }
432 
433   auto names_size = PyList_Size(tensor_check_names_py);
434   if (names_size != static_cast<decltype(names_size)>(checks.size())) {
435     PyErr_SetString(
436         PyExc_TypeError,
437         "tensor_check_names should be the same size as # tensors");
438     return nullptr;
439   }
440 
441   std::vector<std::string> tensor_check_names;
442   tensor_check_names.reserve(names_size);
443   for (auto i : c10::irange(names_size)) {
444     PyObject* value = PyList_GetItem(tensor_check_names_py, i);
445     if (!PyUnicode_Check(value)) {
446       PyErr_SetString(
447           PyExc_TypeError, "tensor_check_names must only contain strings");
448       return nullptr;
449     }
450     tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
451   }
452 
453   LocalState state;
454   ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
455   for (auto i : c10::irange(len)) {
456     PyObject* item = PyTuple_GET_ITEM(args, i);
457     if (Py_TYPE(item) != checks[i].pytype) {
458       std::stringstream fail_reason;
459       PyObject* type_str = PyObject_Str(PyObject_Type(item));
460       fail_reason << "expected type of '" << tensor_check_names[i]
461                   << "' to be a tensor type, ";
462       if (!type_str) {
463         fail_reason << "but found a different type";
464       } else {
465         fail_reason << "' but found " << PyUnicode_AsUTF8(type_str);
466       }
467       return Py_BuildValue("s", fail_reason.str().c_str());
468     }
469 
470     auto insertion = unique_tensors.insert({item, nullptr});
471     if (!insertion.second) {
472       std::stringstream fail_reason;
473       fail_reason << "Duplicate tensor found where not expected! ";
474       fail_reason << tensor_check_names[i]
475                   << "should not alias to anything, but is aliased";
476       return Py_BuildValue("s", fail_reason.str().c_str());
477     }
478     std::string fail_reason = checks[i].check_verbose(
479         state, THPVariable_Unpack(item), tensor_check_names[i]);
480     if (fail_reason.length() > 0) {
481       return Py_BuildValue("s", fail_reason.c_str());
482     }
483   }
484 
485   Py_RETURN_TRUE;
486 }
487 
// Method table for the TensorGuards Python type. Both methods are registered
// as accepting positional and keyword arguments.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef TensorGuards_methods[] = {
    {"check",
     (PyCFunction)(void*)TensorGuards_check,
     METH_VARARGS | METH_KEYWORDS,
     ""},
    {"check_verbose",
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {nullptr} /* Sentinel */
};

// Type object for TensorGuards; only the object header is initialised here.
static PyTypeObject TensorGuardsType = {PyVarObject_HEAD_INIT(nullptr, 0)};
502 
503 // TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is
504 // merged.
505 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
506 struct GlobalStateGuard {
507   PyObject_HEAD;
508 
inittorch::dynamo::__anon296b09360211::GlobalStateGuard509   inline void init() {
510     auto& ctx = at::globalContext();
511     _grad_mode = at::GradMode::is_enabled();
512     // The below two flags disambiguate
513     // if torch function disabled state is
514     // 1) enabled, 2) all disabled, 3) subclasses disabled
515     // we guard on the stack separately
516     _torch_function = torch::torch_function_enabled();
517     _torch_function_all_disabled = at::impl::torch_function_all_disabled();
518     _deterministic_algorithms = ctx.deterministicAlgorithms();
519     _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
520     _allow_tf32 = ctx.allowTF32CuBLAS();
521     _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
522     _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
523     _num_threads = at::get_num_threads();
524     _default_dtype = at::get_default_dtype();
525   }
526 
checktorch::dynamo::__anon296b09360211::GlobalStateGuard527   inline bool check() const {
528     auto& ctx = at::globalContext();
529     return (_grad_mode == at::GradMode::is_enabled() &&
530             _torch_function == torch::torch_function_enabled() &&
531             _torch_function_all_disabled ==
532                 at::impl::torch_function_all_disabled() &&
533             _deterministic_algorithms == ctx.deterministicAlgorithms() &&
534             _deterministic_algorithms_warn_only ==
535                 ctx.deterministicAlgorithmsWarnOnly() &&
536             _allow_tf32 == ctx.allowTF32CuBLAS() &&
537             _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
538             _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
539             _num_threads == at::get_num_threads()) &&
540         _default_dtype == at::get_default_dtype();
541   }
542 
reasontorch::dynamo::__anon296b09360211::GlobalStateGuard543   inline std::string reason() const {
544     std::ostringstream os;
545     auto& ctx = at::globalContext();
546     if (_grad_mode != at::GradMode::is_enabled())
547       os << "grad_mode ";
548     if (_torch_function != torch::torch_function_enabled())
549       os << "torch_function ";
550     if (_deterministic_algorithms != ctx.deterministicAlgorithms())
551       os << "deterministic_algorithms ";
552     if (_deterministic_algorithms_warn_only !=
553         ctx.deterministicAlgorithmsWarnOnly())
554       os << "deterministic_algorithms_warn_only ";
555     if (_allow_tf32 != ctx.allowTF32CuBLAS())
556       os << "allow_tf32 ";
557     if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
558       os << "allow_fp16_reduce ";
559     if (_allow_bf16_reduce != ctx.allowBF16ReductionCuBLAS())
560       os << "allow_bf16_reduce ";
561     if (_num_threads != at::get_num_threads())
562       os << "num_threads ";
563     if (_default_dtype != at::get_default_dtype())
564       os << "default_dtype ";
565     return os.str();
566   }
567 
568   bool _grad_mode;
569   bool _torch_function;
570   bool _torch_function_all_disabled;
571   bool _deterministic_algorithms;
572   bool _deterministic_algorithms_warn_only;
573   bool _allow_tf32;
574   bool _allow_fp16_reduce;
575   bool _allow_bf16_reduce;
576   int _num_threads;
577   caffe2::TypeMeta _default_dtype;
578   // TODO(jansel): we should guard on more state as inductor starts using it
579 };
580 
// tp_init for GlobalStateGuard: records the current global state in `self`.
// `args` and `kwargs` are not inspected. Always returns 0.
int GlobalStateGuard_init(
    GlobalStateGuard* self,
    PyObject* args,
    PyObject* kwargs) {
  self->init();
  return 0;
}
588 
GlobalStateGuard_check(GlobalStateGuard * self,PyObject * args,PyObject * kwargs)589 PyObject* GlobalStateGuard_check(
590     GlobalStateGuard* self,
591     PyObject* args,
592     PyObject* kwargs) {
593   if (self->check()) {
594     Py_RETURN_TRUE;
595   } else {
596     Py_RETURN_FALSE;
597   }
598 }
599 
GlobalStateGuard_reason(GlobalStateGuard * self,PyObject * args,PyObject * kwargs)600 PyObject* GlobalStateGuard_reason(
601     GlobalStateGuard* self,
602     PyObject* args,
603     PyObject* kwargs) {
604   return PyUnicode_FromString(self->reason().c_str());
605 }
606 
// Method table for the GlobalStateGuard Python type. Both methods are
// registered as taking no arguments.
// NOLINTNEXTLINE(*array*)
static PyMethodDef GlobalStateGuard_methods[] = {
    {"check",
     (PyCFunction)(void*)GlobalStateGuard_check,
     METH_NOARGS,
     "Return true if global state was the same as at creation time"},
    {"reason",
     (PyCFunction)(void*)GlobalStateGuard_reason,
     METH_NOARGS,
     "Return string reason for guard check failing"},
    {nullptr}};
// Type object for GlobalStateGuard; only the object header is initialised
// here.
static PyTypeObject GlobalStateGuardType = {PyVarObject_HEAD_INIT(nullptr, 0)};
619 
check_type_id(PyObject * dummy,PyObject * args)620 static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
621   // faster `lambda obj, expected: id(type(obj)) == expected`
622   PyObject* obj = nullptr;
623   unsigned long long expected = 0;
624   if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
625     return nullptr;
626   }
627   // NOLINTNEXTLINE(performance-no-int-to-ptr)
628   if (Py_TYPE(obj) == (void*)expected) {
629     Py_RETURN_TRUE;
630   } else {
631     Py_RETURN_FALSE;
632   }
633 }
634 
check_obj_id(PyObject * dummy,PyObject * args)635 static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
636   // faster `lambda obj, expected: id(obj) == expected`
637   PyObject* obj = nullptr;
638   unsigned long long expected = 0;
639   if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
640     return nullptr;
641   }
642   // NOLINTNEXTLINE(performance-no-int-to-ptr)
643   if (obj == (void*)expected) {
644     Py_RETURN_TRUE;
645   } else {
646     Py_RETURN_FALSE;
647   }
648 }
649 
650 #if IS_PYTHON_3_12_PLUS
651 
652 static std::unordered_map<PyObject*, uint64_t> dict_version_map;
653 static int dict_version_watcher_id;
654 static uint64_t global_dict_version_id = 0;
dict_version_watch_callback(PyDict_WatchEvent event,PyObject * dict,PyObject * key,PyObject * new_value)655 static int dict_version_watch_callback(
656     PyDict_WatchEvent event,
657     PyObject* dict,
658     PyObject* key,
659     PyObject* new_value) noexcept {
660   if (event == PyDict_EVENT_DEALLOCATED) {
661     dict_version_map.erase(dict);
662   } else if (event != PyDict_EVENT_CLONED) {
663     dict_version_map[dict] = global_dict_version_id++;
664   }
665   return 0;
666 }
667 
668 #endif
669 
get_dict_version_unchecked(PyObject * dict)670 static uint64_t get_dict_version_unchecked(PyObject* dict) {
671 #if IS_PYTHON_3_12_PLUS
672 
673   if (PyDict_Watch(dict_version_watcher_id, dict)) {
674     throw std::runtime_error("failed to add version watcher to dict!");
675   }
676   if (!dict_version_map.count(dict)) {
677     dict_version_map[dict] = global_dict_version_id++;
678   }
679   return dict_version_map[dict];
680 
681 #else
682 
683   return ((PyDictObject*)dict)->ma_version_tag;
684 
685 #endif
686 }
687 
dict_version(PyObject * dummy,PyObject * args)688 static PyObject* dict_version(PyObject* dummy, PyObject* args) {
689   // Retrieves the version of a dictionary.
690   PyObject* obj = nullptr;
691   if (!PyArg_ParseTuple(args, "O", &obj)) {
692     return nullptr;
693   }
694   if (!PyDict_Check(obj)) {
695     return nullptr;
696   }
697   return THPUtils_packUInt64(get_dict_version_unchecked(obj));
698 }
699 
assert_size_stride(PyObject * dummy,PyObject * args)700 static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
701   /*
702    Assert that a given tensor has a given size/stride, but ignore strides
703    of size==1 dimensions.  Implemented in C++ as this is on the hot path.
704   */
705   PyObject* item = nullptr;
706   PyObject* size = nullptr;
707   PyObject* stride = nullptr;
708   if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
709     return nullptr;
710   }
711   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
712     PyErr_SetString(PyExc_TypeError, "expected Tensor()");
713     return nullptr;
714   }
715   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
716     PyErr_SetString(PyExc_TypeError, "expected tuple()");
717     return nullptr;
718   }
719   at::Tensor tensor = THPVariable_Unpack(item);
720   int64_t ndim = tensor.ndimension();
721   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
722     PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
723     return nullptr;
724   }
725   std::stringstream msg;
726   int num_errors = 0;
727   for (auto i : c10::irange(ndim)) {
728     int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
729     int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
730     int64_t actual_size = tensor.size(i);
731     int64_t actual_stride = tensor.stride(i);
732     if (want_size != actual_size ||
733         // ignore stride differences when size is 1
734         (want_stride != actual_stride && actual_size > 1)) {
735       if (num_errors > 0)
736         msg << "; ";
737       msg << "expected size " << actual_size << "==" << want_size << ", stride "
738           << actual_stride << "==" << want_stride << " at dim=" << i;
739       num_errors++;
740     }
741   }
742 
743   if (num_errors) {
744     PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
745     return nullptr;
746   }
747 
748   Py_RETURN_TRUE;
749 }
750 
751 template <typename T>
unwrap_size_tuple(PyObject * obj,T & output)752 inline static void unwrap_size_tuple(PyObject* obj, T& output) {
753   TORCH_CHECK(PyTuple_CheckExact(obj));
754   size_t len = PyTuple_GET_SIZE(obj);
755   output.reserve(len);
756   for (size_t i = 0; i < len; ++i) {
757     auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(obj, i));
758     TORCH_CHECK(result >= 0);
759     output.emplace_back(result);
760   }
761 }
762 
763 template <typename T>
_parse_empty_strided_args(PyObject * args,T & sizes,T & strides,at::ScalarType & dtype)764 inline static void _parse_empty_strided_args(
765     PyObject* args,
766     T& sizes,
767     T& strides,
768     at::ScalarType& dtype) {
769   TORCH_CHECK(PyTuple_CheckExact(args));
770   TORCH_CHECK(PyTuple_GET_SIZE(args) == 3);
771   // note PyTuple_GET_ITEM returns a borrowed ref, so no need for refcounts
772   unwrap_size_tuple(PyTuple_GET_ITEM(args, 0), sizes);
773   unwrap_size_tuple(PyTuple_GET_ITEM(args, 1), strides);
774   PyObject* py_dtype = PyTuple_GET_ITEM(args, 2);
775   TORCH_CHECK(THPDtype_Check(py_dtype));
776   dtype = reinterpret_cast<THPDtype*>(py_dtype)->scalar_type;
777 }
778 
_empty_strided_device(PyObject * dummy,PyObject * args,c10::DeviceType device_type)779 inline static PyObject* _empty_strided_device(
780     PyObject* dummy,
781     PyObject* args,
782     c10::DeviceType device_type) {
783   HANDLE_TH_ERRORS;
784   at::SmallVector<int64_t, 8> sizes;
785   at::SmallVector<int64_t, 8> strides;
786   at::ScalarType dtype{at::ScalarType::Undefined};
787   _parse_empty_strided_args(args, sizes, strides, dtype);
788   if (device_type == c10::DeviceType::CPU) {
789     return THPVariable_Wrap(
790         at::detail::empty_strided_cpu(sizes, strides, dtype));
791   }
792 #ifdef USE_CUDA
793   else if (device_type == c10::DeviceType::CUDA) {
794     return THPVariable_Wrap(at::detail::empty_strided_cuda(
795         sizes, strides, dtype, c10::DeviceType::CUDA));
796   }
797 #endif
798 #ifdef USE_XPU
799   else if (device_type == c10::DeviceType::XPU) {
800     return THPVariable_Wrap(at::detail::empty_strided_xpu(
801         sizes, strides, dtype, c10::DeviceType::XPU));
802   }
803 #endif
804   else {
805     TORCH_CHECK(
806         false, "PyTorch compiled without support for the specified device.");
807   }
808 
809   END_HANDLE_TH_ERRORS;
810 }
811 
_empty_strided_cpu(PyObject * dummy,PyObject * args)812 static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) {
813   // at::empty_strided is surprising slow.  This is a lower-overhead
814   // version that saves ~2us on every allocation.
815   return _empty_strided_device(dummy, args, c10::DeviceType::CPU);
816 }
817 
_empty_strided_cuda(PyObject * dummy,PyObject * args)818 static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) {
819   // at::empty_strided is surprising slow.  This is lower-overhead.
820   return _empty_strided_device(dummy, args, c10::DeviceType::CUDA);
821 }
822 
_empty_strided_xpu(PyObject * dummy,PyObject * args)823 static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
824   // at::empty_strided is surprising slow.  This is lower-overhead.
825   return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
826 }
827 
_reinterpret_tensor(PyObject * dummy,PyObject * args)828 static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
829   HANDLE_TH_ERRORS;
830   static PythonArgParser parser(
831       {"_reinterpret_tensor(Tensor base, IntArrayRef sizes, IntArrayRef strides, int64_t offset_increment=0)"},
832       /*traceable=*/true);
833 
834   ParsedArgs<4> parsed_args;
835   auto r = parser.parse(args, /*kwargs=*/nullptr, parsed_args);
836 
837   Tensor self = r.tensor(0);
838   auto sizes = r.intlist(1);
839   auto strides = r.intlist(2);
840   auto offset_increment = r.toInt64(3);
841 
842   auto res = torch::inductor::_reinterpret_tensor(
843       self, sizes, strides, offset_increment);
844   return torch::autograd::utils::wrap(res);
845 
846   END_HANDLE_TH_ERRORS;
847 }
848 
849 // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
850 static PyMethodDef _methods[] = {
851     {"check_type_id", check_type_id, METH_VARARGS, nullptr},
852     {"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
853     {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
854     {"dict_version", dict_version, METH_VARARGS, nullptr},
855     {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
856     {"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
857     {"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
858     {"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
859     {nullptr, nullptr, 0, nullptr}};
860 
861 static struct PyModuleDef _module = {
862     PyModuleDef_HEAD_INIT,
863     "torch._C._dynamo.guards",
864     "Module containing checks on tensors",
865     -1,
866     _methods};
867 
get_exception_message()868 std::string get_exception_message() {
869   PyObject *ptype = nullptr, *pvalue = nullptr, *ptraceback = nullptr;
870   PyErr_Fetch(&ptype, &pvalue, &ptraceback);
871 
872   PyObject* exc_message_pyobj = PyObject_Str(pvalue);
873   const char* exc_message = PyUnicode_AsUTF8(exc_message_pyobj);
874 
875   Py_DECREF(exc_message_pyobj);
876   Py_XDECREF(ptype);
877   Py_XDECREF(pvalue);
878   Py_XDECREF(ptraceback);
879   return std::string(exc_message);
880 }
881 
is_immutable_object(py::handle example_value)882 bool is_immutable_object(py::handle example_value) {
883   if (PyTuple_Check(example_value.ptr())) {
884     // Check that each element is immutable
885     for (Py_ssize_t i = 0; i < PyTuple_Size(example_value.ptr()); ++i) {
886       if (!is_immutable_object(
887               py::handle(PyTuple_GetItem(example_value.ptr(), i)))) {
888         return false;
889       }
890     }
891     return true;
892   }
893   return PyLong_Check(example_value.ptr()) ||
894       PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) ||
895       PyUnicode_Check(example_value.ptr()) ||
896       THPVariable_Check(example_value.ptr());
897 }
898 
is_parameter(py::handle tensor)899 bool is_parameter(py::handle tensor) {
900   py::object parameter = py::module::import("torch.nn").attr("Parameter");
901   return py::isinstance(tensor, parameter);
902 }
903 
904 /**
905  * Stores relevant guard debug information, e.g., failure str for a LeafGuard
906  * failure. The data structure is also accessible in Python.
907  */
908 
909 class GuardDebugInfo {
910  public:
GuardDebugInfo(bool result,py::list verbose_code_parts,int num_guards_executed)911   GuardDebugInfo(
912       bool result,
913       py::list verbose_code_parts,
914       int num_guards_executed)
915       : result(result),
916         verbose_code_parts(std::move(verbose_code_parts)),
917         num_guards_executed(num_guards_executed) {}
918 
919   // This constructor is used when guard succeeds.
GuardDebugInfo(bool result,int num_guards_executed)920   GuardDebugInfo(bool result, int num_guards_executed)
921       : result(result), num_guards_executed(num_guards_executed) {}
922 
GuardDebugInfo(bool result,const std::string & failed_reason,int num_guards_executed)923   GuardDebugInfo(
924       bool result,
925       const std::string& failed_reason,
926       int num_guards_executed)
927       : GuardDebugInfo(result, num_guards_executed) {
928     verbose_code_parts.append(failed_reason);
929   }
930 
to_string()931   std::string to_string() {
932     std::stringstream ss;
933     ss << "GuardDebugInfo(\n"
934        << "result=" << result << ",\n"
935        << "verbose_code_parts=" << verbose_code_parts << ",\n"
936        << "num_guards_executed=" << num_guards_executed << ")\n";
937     return ss.str();
938   }
939 
940   // Whether the guard passed or failed.
941   bool result;
942 
943   // This is a list of verbose_code_parts for the failed guard. When there are
944   // more than one verbose_code_parts, then recompilation reasoning infra on the
945   // Python side can iterate over this list and eval each string to pinpoint the
946   // exact code part that failed.
947   py::list verbose_code_parts;
948 
949   // Total number of executed guards so far. This is helpful in debugging if
950   // shuffling is working.
951   int num_guards_executed;
952 };
953 
954 class GuardManager;
955 class RootGuardManager;
956 class DictGuardManager;
957 
958 /**
959  * Base class for the leaf guard in the GuardManager hierarchy.
960  */
961 class LeafGuard {
962  public:
963   // Most guards do not need root guard manager.
LeafGuard(py::object verbose_code_parts)964   LeafGuard(py::object verbose_code_parts)
965       : _verbose_code_parts(std::move(verbose_code_parts)) {}
966 
967   // Guards like TENSOR_MATCH require root_guard_manager to access local_state
968   // shared across all leaf guards.
LeafGuard(RootGuardManager * root_guard_manager,py::object verbose_code_parts)969   LeafGuard(RootGuardManager* root_guard_manager, py::object verbose_code_parts)
970       : _root_guard_manager(root_guard_manager),
971         _verbose_code_parts(std::move(verbose_code_parts)) {}
972 
973   // check function could be called from python. This is useful for debugging
974   // purpose.
check(py::handle value)975   bool check(py::handle value) {
976     return check_nopybind(value.ptr());
977   }
978 
check_verbose(py::handle value)979   GuardDebugInfo check_verbose(py::handle value) {
980     return check_verbose_nopybind(value.ptr());
981   }
982 
check_verbose_nopybind(PyObject * value)983   virtual GuardDebugInfo check_verbose_nopybind(
984       PyObject* value) { // borrowed ref
985     bool result = check_nopybind(value);
986     if (!result) {
987       return GuardDebugInfo(result, _verbose_code_parts, 0);
988     }
989     return GuardDebugInfo(true, 0);
990   }
991 
verbose_code_parts()992   py::list verbose_code_parts() {
993     return _verbose_code_parts;
994   }
995 
996   // This is on the hot path and avoids any refcounting code from pybind. This
997   // is not exposed to Python and can only be called from C++.
998   virtual bool check_nopybind(PyObject* value) = 0;
999   virtual ~LeafGuard() = default;
1000 
1001  protected:
1002   // RootGuardManager has state that is common across all guards like
1003   // LocalState.
1004   RootGuardManager* _root_guard_manager{nullptr};
1005 
1006  private:
1007   // This is set while constructing the leaf guard. This is used for identifying
1008   // the cause of recompilation.
1009   py::list _verbose_code_parts;
1010 };
1011 
1012 /**
1013  * Represents a leaf guard that accepts the python guard check function. We
1014  * would like to have most of the guards in C++ (to avoid a Python function
1015  * call).  But, it will take some time to reach that goal. Also, there might be
1016  * cases where its too tedious to write an equivalent C++ guard.
1017  *
1018  * LAMBDA_GUARD allows us to gradually move to C++. We can start from all
1019  * guards of type PythonLambaGuard and incrementally move expensive guards to
1020  * C++.
1021  */
1022 class LAMBDA_GUARD : public LeafGuard {
1023  public:
LAMBDA_GUARD(py::object guard_check_fn,py::object verbose_code_parts)1024   LAMBDA_GUARD(py::object guard_check_fn, py::object verbose_code_parts)
1025       : LeafGuard(std::move(verbose_code_parts)) {
1026     if (py::isinstance<py::function>(guard_check_fn)) {
1027       _guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
1028     } else {
1029       throw py::type_error("LAMBDA_GUARD expects (callable, str)");
1030     }
1031   }
1032 
1033   // Runs the lambda function with the current f_locals value.
check_nopybind(PyObject * value)1034   bool check_nopybind(PyObject* value) override { // borrowed ref
1035     PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
1036     if (x == nullptr) {
1037       // An exception is caught in the lambda function.
1038       PyErr_Clear();
1039       return false;
1040     }
1041     bool result = PyObject_IsTrue(x);
1042     Py_DECREF(x);
1043     return result;
1044   }
1045 
check_verbose_nopybind(PyObject * value)1046   GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
1047     PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
1048     if (x == nullptr) {
1049       // An exception is caught in the lambda function.
1050       std::string exc_message = get_exception_message();
1051       PyErr_Clear();
1052       return GuardDebugInfo(false, exc_message, 0);
1053     }
1054     bool result = PyObject_IsTrue(x);
1055     Py_DECREF(x);
1056     if (result) {
1057       return GuardDebugInfo(true, 0);
1058     }
1059     return GuardDebugInfo(false, verbose_code_parts(), 0);
1060   }
1061 
1062  private:
1063   // The user provided lambda function for check_fn.
1064   py::function _guard_check_fn;
1065 };
1066 
1067 class TYPE_MATCH : public LeafGuard {
1068  public:
1069   // type_id = id(type(obj))
TYPE_MATCH(py::object type_id,py::object verbose_code_parts)1070   TYPE_MATCH(py::object type_id, py::object verbose_code_parts)
1071       : LeafGuard(std::move(verbose_code_parts)),
1072         _expected(py::cast<intptr_t>(std::move(type_id))) {}
1073 
check_nopybind(PyObject * value)1074   bool check_nopybind(PyObject* value) override { // borrowed ref
1075     // NOLINTNEXTLINE(performance-no-int-to-ptr)
1076     return Py_TYPE(value) == (void*)_expected;
1077   }
1078 
1079  private:
1080   // id of the type of the original object.
1081   intptr_t _expected;
1082 };
1083 
1084 class ID_MATCH : public LeafGuard {
1085  public:
1086   // obj_id = id(obj)
ID_MATCH(py::object obj_id,py::object verbose_code_parts)1087   ID_MATCH(py::object obj_id, py::object verbose_code_parts)
1088       : LeafGuard(std::move(verbose_code_parts)),
1089         _expected(py::cast<intptr_t>(std::move(obj_id))) {}
1090 
check_nopybind(PyObject * value)1091   bool check_nopybind(PyObject* value) override { // borrowed ref
1092     // NOLINTNEXTLINE(performance-no-int-to-ptr)
1093     return value == (void*)_expected;
1094   }
1095 
1096  private:
1097   // id of the original object.
1098   intptr_t _expected;
1099 };
1100 
1101 class EQUALS_MATCH : public LeafGuard {
1102  public:
EQUALS_MATCH(py::object value,py::object verbose_code_parts)1103   EQUALS_MATCH(py::object value, py::object verbose_code_parts)
1104       : LeafGuard(std::move(verbose_code_parts)),
1105         _value(value),
1106         _value_type(Py_TYPE(value.ptr())) {}
1107 
check_nopybind(PyObject * value)1108   bool check_nopybind(PyObject* value) override { // borrowed ref
1109     // Fast path - pointer equality check. Pointer equality checks are ok
1110     // because objects guarded with EQUALS_MATCH are immutable.
1111     if (value != _value.ptr()) {
1112       // Check type
1113       if (Py_TYPE(value) != _value_type) {
1114         return false;
1115       }
1116       int result = PyObject_RichCompareBool(value, _value.ptr(), Py_EQ);
1117       // Check for exception
1118       if (result == -1) {
1119         PyErr_Clear();
1120         return false;
1121       }
1122       return result;
1123     }
1124     return true;
1125   }
1126 
1127  private:
1128   // value to compare against. This is py::object so that we hold on to the
1129   // original value and prevent garbage collection. We run EQUALS_MATCH only on
1130   // selected objects which do not have high memory footprint, so holding on to
1131   // these objects is ok.
1132   py::object _value;
1133 
1134   // Type of the value
1135   PyTypeObject* _value_type;
1136 };
1137 
1138 class TUPLE_ITERATOR_LEN : public LeafGuard {
1139  public:
TUPLE_ITERATOR_LEN(py::object length,py::object type_id,py::object verbose_code_parts)1140   TUPLE_ITERATOR_LEN(
1141       py::object length,
1142       py::object type_id,
1143       py::object verbose_code_parts)
1144       : LeafGuard(std::move(verbose_code_parts)),
1145         _length(py::cast<Py_ssize_t>(std::move(length))),
1146         _type_id(py::cast<intptr_t>(std::move(type_id))) {}
1147 
check_nopybind(PyObject * value)1148   bool check_nopybind(PyObject* value) override { // borrowed ref
1149     // Do a type match first.
1150     // NOLINTNEXTLINE(performance-no-int-to-ptr)
1151     if (Py_TYPE(value) != (void*)_type_id) {
1152       return false;
1153     }
1154     _PyTupleIterObject* it = (_PyTupleIterObject*)value;
1155     Py_ssize_t length = 0;
1156     if (it->it_seq)
1157       length = PyTuple_GET_SIZE(it->it_seq) - it->it_index;
1158     return length == _length;
1159   }
1160 
1161  private:
1162   // Length of the guarded list
1163   Py_ssize_t _length;
1164   intptr_t _type_id;
1165 };
1166 
1167 class LENGTH_CHECK : public LeafGuard {
1168  public:
LENGTH_CHECK(py::object value,py::object verbose_code_parts)1169   LENGTH_CHECK(py::object value, py::object verbose_code_parts)
1170       : LeafGuard(std::move(verbose_code_parts)),
1171         _length(py::cast<Py_ssize_t>(std::move(value))) {}
1172 
check_nopybind(PyObject * value)1173   bool check_nopybind(PyObject* value) override { // borrowed ref
1174     // PySequence_Length returns -1 if the object is not a sequence. So, we
1175     // don't have to test for PySequence_Check.
1176     return PySequence_Length(value) == _length;
1177   }
1178 
1179  private:
1180   // Length of the guarded list
1181   Py_ssize_t _length;
1182 };
1183 
1184 class DICT_LENGTH : public LeafGuard {
1185  public:
DICT_LENGTH(py::object value,py::object verbose_code_parts)1186   DICT_LENGTH(py::object value, py::object verbose_code_parts)
1187       : LeafGuard(std::move(verbose_code_parts)),
1188         _length(py::cast<Py_ssize_t>(std::move(value))) {}
1189 
check_nopybind(PyObject * value)1190   bool check_nopybind(PyObject* value) override { // borrowed ref
1191     return PyDict_Check(value) && PyDict_Size(value) == _length;
1192   }
1193 
1194  private:
1195   // Length of the guarded dict
1196   Py_ssize_t _length;
1197 };
1198 
1199 class NOT_NONE : public LeafGuard {
1200  public:
NOT_NONE(py::object verbose_code_parts)1201   NOT_NONE(py::object verbose_code_parts)
1202       : LeafGuard(std::move(verbose_code_parts)) {}
1203 
check_nopybind(PyObject * value)1204   bool check_nopybind(PyObject* value) override { // borrowed ref
1205     return value != Py_None;
1206   }
1207 };
1208 
1209 class DEFAULT_DEVICE : public LeafGuard {
1210  public:
DEFAULT_DEVICE(py::object verbose_code_parts)1211   DEFAULT_DEVICE(py::object verbose_code_parts)
1212       : LeafGuard(std::move(verbose_code_parts)) {
1213     py::handle device_module = py::module::import("torch.utils._device");
1214     // Save the dict using py::object
1215     _utils_device_dict = device_module.attr("__dict__");
1216     _device = _utils_device_dict["CURRENT_DEVICE"];
1217   }
1218 
check_nopybind(PyObject * value)1219   bool check_nopybind(PyObject* value) override { // borrowed ref
1220     // Create a static interned string. Interned string is faster than creating
1221     // a new string every time. Even though its a new reference, we don't dec
1222     // ref it. Interned strings are used for things like variable names and are
1223     // leaked by design.
1224     static PyObject* current_device_str =
1225         PyUnicode_InternFromString("CURRENT_DEVICE");
1226     PyObject* device = PyDict_GetItem(
1227         _utils_device_dict.ptr(), current_device_str); // borrowed ref
1228     if (device != _device.ptr()) {
1229       int result = PyObject_RichCompareBool(device, _device.ptr(), Py_EQ);
1230       if (result == -1) {
1231         PyErr_Clear();
1232         return false;
1233       }
1234       return result;
1235     }
1236     return true;
1237   }
1238 
1239  private:
1240   // Save the current device and the module dict during the guard construction.
1241   py::object _utils_device_dict;
1242   py::object _device;
1243 };
1244 
1245 class GLOBAL_STATE : public LeafGuard {
1246  public:
GLOBAL_STATE(py::object verbose_code_parts)1247   GLOBAL_STATE(py::object verbose_code_parts)
1248       : LeafGuard(std::move(verbose_code_parts)) {
1249     _guard = std::make_unique<GlobalStateGuard>();
1250     _guard->init();
1251   }
1252 
check_nopybind(PyObject * value)1253   bool check_nopybind(PyObject* value) override { // borrowed ref
1254     // Ignore value arg, this is just to satisfy the interface.
1255     return _guard->check();
1256   }
1257 
check_verbose_nopybind(PyObject * value)1258   GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
1259     if (!_guard->check()) {
1260       return GuardDebugInfo(
1261           false, "GLOBAL_STATE changed: " + _guard->reason(), 0);
1262     }
1263     return GuardDebugInfo(true, 1);
1264   }
1265 
1266  private:
1267   std::unique_ptr<GlobalStateGuard> _guard;
1268 };
1269 
1270 class DATA_PTR_MATCH : public LeafGuard {
1271  public:
DATA_PTR_MATCH(py::object tensor,py::object verbose_code_parts)1272   DATA_PTR_MATCH(py::object tensor, py::object verbose_code_parts)
1273       : LeafGuard(std::move(verbose_code_parts)) {
1274     PyObject* value = tensor.ptr();
1275     if (!THPVariable_CheckExact(value) && !THPVariable_Check(value)) {
1276       throw std::runtime_error("DATA_PTR_MATCH guard requires a tensor");
1277     }
1278     _data_ptr = THPVariable_Unpack(value).data_ptr();
1279   }
1280 
check_nopybind(PyObject * value)1281   bool check_nopybind(PyObject* value) override { // borrowed ref
1282     if (!THPVariable_CheckExact(value) && !THPVariable_Check(value)) {
1283       return false;
1284     }
1285     void* data_ptr = THPVariable_Unpack(value).data_ptr();
1286     return data_ptr == _data_ptr;
1287   }
1288 
1289  private:
1290   // Original tensor data pointer.
1291   void* _data_ptr;
1292 };
1293 
1294 // Checks that an attr is absent in the object. We don't need the opposite
1295 // HASATTR guard because we can just rely on GetAttrGuardAccessor to act as
1296 // HASATTR guard.
1297 class NO_HASATTR : public LeafGuard {
1298  public:
NO_HASATTR(py::object attr_name,py::object verbose_code_parts)1299   NO_HASATTR(py::object attr_name, py::object verbose_code_parts)
1300       : LeafGuard(std::move(verbose_code_parts)),
1301         _attr_name(std::move(attr_name)) {}
1302 
check_nopybind(PyObject * value)1303   bool check_nopybind(PyObject* value) override { // borrowed ref
1304     return PyObject_HasAttr(value, _attr_name.ptr()) == 0;
1305   }
1306 
1307  private:
1308   py::object _attr_name;
1309 };
1310 
1311 // Checks that dict contains or does not contain a key. This happens for
1312 // PythonSysModulesVariable tracker.
1313 // TODO(janimesh) - Check if we can use DictGuardManager. The downside could be
1314 // large number of keys for sys module, so DICT_CONTAINS might still end up
1315 // being faster.
1316 class DICT_CONTAINS : public LeafGuard {
1317  public:
DICT_CONTAINS(bool contains,py::object key,py::object verbose_code_parts)1318   DICT_CONTAINS(bool contains, py::object key, py::object verbose_code_parts)
1319       : LeafGuard(std::move(verbose_code_parts)),
1320         _contains(contains ? 1 : 0),
1321         _key(std::move(key)) {}
1322 
check_nopybind(PyObject * value)1323   bool check_nopybind(PyObject* value) override { // borrowed ref
1324     int result = PyDict_Contains(value, _key.ptr());
1325     if (result == -1) {
1326       PyErr_Clear();
1327       return false;
1328     }
1329     return result == _contains;
1330   }
1331 
1332  private:
1333   int _contains;
1334   py::object _key;
1335 };
1336 
1337 /**
1338  * Relational guards compare more than one value. We implement Relational
1339  * guards by capturing some state in the guard object. For example for tensor
1340  * aliasing guards - tensor X is not tensor Y - we construct one leaf guard
1341  * and and install it at as a leaf of two guard managers (one for X and
1342  * another for Y). Therefore, this guard is run twice. In the first
1343  * invocation, it saves the first value (state) and returns True. In the
1344  * second invocation, it compares the saved value with the new value and
1345  * returns True if they do not alias.
1346  *
1347  * We have to be careful about resetting in case the other guards fail and we
1348  * have some state in the relational guard. This is done by virtual method
1349  * reset_state(). This is called by the RootGuardManager before it exits.
1350  *
1351  */
1352 class RelationalGuard : public LeafGuard {
1353  public:
RelationalGuard(py::object verbose_code_parts)1354   RelationalGuard(py::object verbose_code_parts)
1355       : LeafGuard(std::move(verbose_code_parts)) {}
1356 
1357   // reset the relational guard state on guard failure. This is called by the
1358   // guard manager.
1359   virtual void reset_state() = 0;
1360 };
1361 
1362 /**
1363  * Checks that object x is object y.
1364  */
1365 class OBJECT_ALIASING : public RelationalGuard {
1366  public:
OBJECT_ALIASING(py::object verbose_code_parts)1367   OBJECT_ALIASING(py::object verbose_code_parts)
1368       : RelationalGuard(std::move(verbose_code_parts)) {}
1369 
check_nopybind(PyObject * value)1370   bool check_nopybind(PyObject* value) override { // borrowed ref
1371     if (_is_first_call) {
1372       _first_tensor = value;
1373       _is_first_call = false;
1374       return true;
1375     }
1376     return _first_tensor == value;
1377   }
1378 
reset_state()1379   void reset_state() final {
1380     _is_first_call = true;
1381   }
1382 
1383  private:
1384   bool _is_first_call{true};
1385   PyObject* _first_tensor{nullptr};
1386 };
1387 
1388 /**
1389  * Checks that none of the tensors alias.
1390  */
1391 class NO_TENSOR_ALIASING : public RelationalGuard {
1392  public:
NO_TENSOR_ALIASING(const py::list & tensor_names,py::object verbose_code_parts)1393   NO_TENSOR_ALIASING(
1394       const py::list& tensor_names,
1395       py::object verbose_code_parts)
1396       : RelationalGuard(std::move(verbose_code_parts)),
1397         _tensor_names(tensor_names) {
1398     _unique_tensors.reserve(tensor_names.size());
1399   }
1400 
check_nopybind(PyObject * value)1401   bool check_nopybind(PyObject* value) override { // borrowed ref
1402     // Typically we don't have to increment the ref count here because the
1403     // tensors are held in f_locals. But there is a special case for
1404     // `from_numpy` source. `from_numpy` converts integers and such into tensors
1405     // and these tensors are ephemeral. If we don't incref, those tensors can be
1406     // garbage collected, and the next time from_numpy can reuse the memory
1407     // address. Therefore, we incref here. They are decref'd in reset_state.
1408     Py_INCREF(value);
1409     auto insertion = _unique_tensors.insert({value, nullptr});
1410     if (!insertion.second) {
1411       // No need to clear _unique_tensors, reset_state will do
1412       // it.
1413       return false;
1414     }
1415     return true;
1416   }
1417 
check_verbose_nopybind(PyObject * value)1418   GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
1419     bool result = check_nopybind(value);
1420 
1421     if (!result) {
1422       return GuardDebugInfo(
1423           false, "Duplicate tensor found where not expected!", 0);
1424     }
1425     return GuardDebugInfo(true, 1);
1426   }
1427 
reset_state()1428   void reset_state() final {
1429     for (auto item : _unique_tensors) {
1430       Py_DECREF(item.first);
1431     }
1432     _unique_tensors.clear();
1433   }
1434 
1435  private:
1436   py::list _tensor_names;
1437   ska::flat_hash_map<PyObject*, std::nullptr_t> _unique_tensors;
1438 };
1439 
1440 class DYNAMIC_INDICES : public LeafGuard {
1441   // C++ equivalent of
1442   //  code.append(
1443   //      f"(({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices}))
1444   //      if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"  #
1445   //      noqa: B950
1446   //  )
1447  public:
DYNAMIC_INDICES(py::set dynamic_indices,py::object verbose_code_parts)1448   DYNAMIC_INDICES(py::set dynamic_indices, py::object verbose_code_parts)
1449       : LeafGuard(std::move(verbose_code_parts)),
1450         _dynamic_indices(std::move(dynamic_indices)) {}
1451 
check_nopybind(PyObject * value)1452   bool check_nopybind(PyObject* value) override { // borrowed ref
1453     // Make an interned string
1454     static PyObject* dynamic_indices_str =
1455         PyUnicode_InternFromString("_dynamo_dynamic_indices");
1456     PyObject* indices = PyObject_GetAttr(value, dynamic_indices_str); // new ref
1457     if (indices == nullptr) {
1458       // Attr absent. Clear exception.
1459       PyErr_Clear();
1460       // This is true deliberately. If hasattr fails, we return true.
1461       return true;
1462     }
1463 
1464     static PyObject* issubset_str = PyUnicode_InternFromString("issubset");
1465     PyObject* call_result = PyObject_CallMethodOneArg(
1466         indices, issubset_str, _dynamic_indices.ptr()); // new ref
1467     bool result = PyObject_IsTrue(call_result);
1468     Py_DECREF(call_result);
1469     Py_DECREF(indices);
1470     return result;
1471   }
1472 
1473  private:
1474   py::set _dynamic_indices;
1475 };
1476 
1477 class DICT_VERSION : public LeafGuard {
1478  public:
DICT_VERSION(py::object value,py::object verbose_code_parts)1479   DICT_VERSION(py::object value, py::object verbose_code_parts)
1480       : LeafGuard(std::move(verbose_code_parts)) {
1481     if (!PyDict_Check(value.ptr())) {
1482       throw py::type_error("DICT_VERSION expects a dict");
1483     }
1484     _tag = get_dict_version_unchecked(value.ptr());
1485   }
check_nopybind(PyObject * value)1486   bool check_nopybind(PyObject* value) override { // borrowed ref
1487     return PyDict_Check(value) && get_dict_version_unchecked(value) == _tag;
1488   }
1489 
1490   // Saved dict version.
1491   uint64_t _tag;
1492 };
1493 
1494 // GuardManager can be a pointer to DictGuardManager, but at this point the
1495 // compiler does not know that DictGuardManager is a derived class of
1496 // GuardManager (no way to define inheritance relationships in forward
1497 // declarations), so we forward declare a factory function and define it when
1498 // both DictGuardManager and GuardManager are fully defined.
1499 std::unique_ptr<GuardManager> make_guard_manager(
1500     RootGuardManager* root,
1501     std::string source,
1502     py::handle example_value,
1503     py::handle guard_manager_enum);
1504 
1505 /**
1506  * Base class representing a pair of accessor and the associated guard
1507  * manager. The accessor defines how to access the child value from the
1508  * py::object given to the parent check function.
1509  *
1510  * GuardAccessors can be considered equivalent to name() method of Source
1511  * objects in guards.py. In python, name() method returns a str which we can
1512  * then eval in f_locals and f_globals to retrieve the actual py object.
1513  * GuardAccessor serves the same purpose. The minor difference is that
1514  * GuardManager is a tree structure, so a GuardAccessor just has to retrieve
1515  * the value in the next level in this tree and pass it to the child
1516  * GuardAccessor.
1517  *
1518  * GuardAccessor also owns the GuardManager associated with the retrieved
1519  * value from the GuardAccessor.
1520  */
1521 class GuardAccessor {
1522  public:
GuardAccessor(RootGuardManager * root,py::object accessor_key,std::string source,py::handle example_value,py::handle guard_manager_enum)1523   GuardAccessor(
1524       RootGuardManager* root,
1525       py::object accessor_key,
1526       std::string source,
1527       py::handle example_value,
1528       py::handle guard_manager_enum)
1529       : _guard_manager(make_guard_manager(
1530             root,
1531             source,
1532             example_value,
1533             guard_manager_enum)),
1534         _accessor_key(std::move(accessor_key)),
1535         _source(std::move(source)) {}
1536 
1537   // Return by reference as GuardAccessor owns the GuardManager.
get_guard_manager()1538   std::unique_ptr<GuardManager>& get_guard_manager() {
1539     return _guard_manager;
1540   }
1541 
matches_key(const py::handle & key) const1542   bool matches_key(const py::handle& key) const {
1543     return _accessor_key.equal(key);
1544   }
1545 
get_source()1546   std::string get_source() {
1547     return _source;
1548   }
1549 
1550   // matches_dict_tag is used by the DictGetItemGuardAccessor to skip the guard
1551   // subtree on immutable dict getitems.
1552   virtual bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) = 0;
1553   virtual GuardDebugInfo check_verbose_nopybind(PyObject* obj) = 0;
1554   virtual std::string repr() const = 0;
1555 
1556   virtual ~GuardAccessor() = default;
1557 
1558  protected:
1559   // Guard manager corresponding to the retrieved value from the
1560   // GuardAccessor.
1561   std::unique_ptr<GuardManager> _guard_manager;
1562   // accessor key could be py::str for getattr, getitem or py::function for
1563   // lambda accessor. It is a py::object because we need to keep these accessor
1564   // keys alive.
1565   py::object _accessor_key;
1566 
1567   // A string that can be eval'd on f_locals or f_globals to access the variable
1568   // value. Only used for debugging.
1569   std::string _source;
1570 };
1571 
1572 /**
1573  * GuardManager encapsulates all the guards related to a particular
1574  * py::object. It is a tree structure and consists of 1) Leaf guards - Guards
1575  * that are run on the user given object 2) Accessors - Guard accessors (like
1576  * getattr, getitem) to access the next value in the tree hierarchy. Accessor
1577  * object also holds the child GuardManager.
1578  *
1579  * Lets look at an example to understand how it works.
1580  * class Pair:
1581  *     int x = 1;
1582  *     int y = 2;
1583  *
1584  * At compile time
1585  * >> guard_manager = GuardManager()
1586  * >> guard_manager.x.add_lambda_guard(
1587  *        lambda x: isinstance(x, Pair),
1588  *        lambda x: f"expected Pair, found {type(x)}"
1589  *    )
1590  * >> guard_manager.x.add_lambda_guard(lambda x: x == 1, lambda x: f"found
1591  * {x}, expected 1")
1592  * >> guard_manager.y.add_lambda_guard(lambda x: x == 2, lambda x: f"found
1593  * {x}, expected 2")
1594  *
1595  * At runtime
1596  * >> guard_manager.check(Pair())
1597  *
1598  * At compile time we build the tree structure. When we do `guard_manager.x`,
1599  * it creates an AttrGuardAccessorNode, initializes a child guard manager with
1600  * this accessor node, and adds it as a child. When we do
1601  * `guard_manager.x.add_lambda_guard`, we call add_lambda_guard on the newly
1602  * created guard manager and register a new leaf guard on it.
1603  *
1604  * At runtime, the accessor node has an important function of providing a way
1605  * to access the value for the child guard. In the above example,
1606  * guard_manager.x adds an AttrGuardAccessorNode with attr_name x. When check
1607  * function is called, parent GuardManager calls getattr(value, "x") on its
1608  * value passed to the check function to call the check function of the child
1609  * guard manager.
1610  *
 * Performance optimization for fail fast - An optimization for runtime here is
1612  * to sort the execution of child guards depending on the failure count.  This
1613  * ensures that we run the guards that are more prone to fail statistically
1614  * first. This can improve the cache lookup time when we have multiple cache
1615  * entries.
1616  */
1617 
1618 class GuardManager {
1619  public:
1620   GuardManager() = delete;
GuardManager(RootGuardManager * root,std::string source)1621   GuardManager(RootGuardManager* root, std::string source)
1622       : _root(root), _source(std::move(source)), _is_dict(false) {}
1623 
GuardManager(RootGuardManager * root,std::string source,py::handle example_value)1624   GuardManager(
1625       RootGuardManager* root,
1626       std::string source,
1627       py::handle example_value)
1628       : _root(root),
1629         _source(std::move(source)),
1630         _is_dict(py::isinstance<py::dict>(example_value)) {
1631     if (_is_dict) {
1632       _dict_tag = get_dict_version_unchecked(example_value.ptr());
1633     }
1634   }
1635 
1636   GuardManager(const GuardManager& m) = delete;
1637   GuardManager& operator=(const GuardManager&) = delete;
1638   virtual ~GuardManager() = default;
1639 
get_root()1640   RootGuardManager* get_root() {
1641     return _root;
1642   }
1643 
get_source()1644   std::string get_source() {
1645     return _source;
1646   }
1647 
add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard)1648   virtual void add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) {
1649     _leaf_guards.emplace_back(std::move(leaf_guard));
1650   }
1651 
1652   /**
1653    * Adds a new guard manager with appropriate Accessor. If the accessor is
1654    * already present, we just return the guard manager.
1655    */
1656   template <typename GuardAccessorT>
get_child_manager(py::object accessor_key,std::string source,py::handle example_value,py::handle guard_manager_enum)1657   GuardManager* get_child_manager(
1658       py::object accessor_key,
1659       std::string source,
1660       py::handle example_value,
1661       py::handle guard_manager_enum) {
1662     // accessor_key type depends on the GuardAccessorT
1663     // for example for GetAttrGuardAccessor - py::str name
1664 
1665     // Return the manager if the guard accessor exists
1666     for (const auto& accessor : _accessors) {
1667       if (accessor->matches_key(accessor_key)) {
1668         return accessor->get_guard_manager().get();
1669       }
1670     }
1671 
1672     // Construct a new guard accessor
1673     _accessors.emplace_back(std::make_unique<GuardAccessorT>(
1674         _root,
1675         std::move(accessor_key),
1676         source,
1677         example_value,
1678         guard_manager_enum));
1679     return _accessors.back()->get_guard_manager().get();
1680   }
1681 
1682   // Runs the leaf guards check and then child managers check function.
1683   //
1684   // NB: There is some code DUPLICATION between this and check_verbose
1685   // function. This is intentional. check function is in the hot path and is
1686   // kept very simple. The purpose of check_verbose function is to get guard
1687   // failure reasoning to understand recompilations. check_verbose function
1688   // does not change the state of the guard, e.g., it does not shuffle the
1689   // guards and does not change the fail count. For simplicity, we duplicate
1690   // the code here.
check_nopybind(PyObject * value)1691   virtual bool check_nopybind(PyObject* value) { // borrowed ref
1692     // Iterate over leaf guards
1693     for (const auto& guard : _leaf_guards) {
1694       if (!guard->check_nopybind(value)) { // early exit
1695         _fail_count += 1;
1696         // no need of sorting, just return.
1697         return false;
1698       }
1699     }
1700 
1701     bool matches_dict_tag = false;
1702     uint64_t new_tag = 0;
1703     if (_is_dict) {
1704       // Check if the dict tag matches. If it does, propagate to the child
1705       // accessors. This will pass to the child manager via
1706       // DictGetItemGuardManager.
1707       new_tag = get_dict_version_unchecked(value);
1708       matches_dict_tag = new_tag == _dict_tag;
1709     }
1710 
1711     // Iterate over accessors.
1712     bool result = true;
1713     bool failed_on_first = true;
1714     for (const auto& accessor : _accessors) {
1715       if (!accessor->check_nopybind(value, matches_dict_tag)) { // early exit
1716         _fail_count += 1;
1717         result = false;
1718         // need to sort, so break the loop.
1719         break;
1720       }
1721       failed_on_first = false;
1722     }
1723 
1724     // failed_on_first is just an optimization to avoid sorting if we are
1725     // failing on the first accessor itself. This is helpful when we have
1726     // already sorted the guards once, and dont need to sort again.
1727     if (!result && !failed_on_first) {
1728       // Inplace sort the child guards by fail count. This moves the guard
1729       // with higher fail count earlier in the queue, and enables fail fast
1730       // for the next check_verbose.
1731 
1732       // An alternate implementation was to use priority queue directly on
1733       // _accessors, but it was rejected because of the complexity of
1734       // popping and creating a new pq on each run_guards. Moreover, this sort
1735       // is happening on the unhappy path when check_verbose guard
1736       // fails. So, its probably ok.
1737       std::sort(
1738           _accessors.begin(),
1739           _accessors.end(),
1740           [](const std::unique_ptr<GuardAccessor>& a,
1741              const std::unique_ptr<GuardAccessor>& b) {
1742             return a->get_guard_manager()->fail_count() >
1743                 b->get_guard_manager()->fail_count();
1744           });
1745     }
1746 
1747     if (_is_dict && result) {
1748       // If result is true, reset the _dict_tag. This is useful if there is a
1749       // mutation on the dict but it does not change the attr values (like
1750       // swapping).
1751       _dict_tag = new_tag;
1752     }
1753     return result;
1754   }
1755 
1756   // This function has some code duplication with function check. This is
1757   // deliberate to keep check function simple and fast.
check_verbose_nopybind(PyObject * value)1758   virtual GuardDebugInfo check_verbose_nopybind(
1759       PyObject* value) { // borrowed ref
1760     int num_guards_executed = 0;
1761     // Iterate over leaf guards
1762     for (const auto& guard : _leaf_guards) {
1763       const GuardDebugInfo& debug_info = guard->check_verbose_nopybind(value);
1764       num_guards_executed++;
1765       if (!debug_info.result) {
1766         return GuardDebugInfo(
1767             false, debug_info.verbose_code_parts, num_guards_executed);
1768       }
1769     }
1770 
1771     // Iterate over accessors
1772     for (const auto& accessor : _accessors) {
1773       const GuardDebugInfo& debug_info =
1774           accessor->check_verbose_nopybind(value);
1775       num_guards_executed += debug_info.num_guards_executed;
1776       if (!debug_info.result) {
1777         return GuardDebugInfo(
1778             false, debug_info.verbose_code_parts, num_guards_executed);
1779       }
1780     }
1781 
1782     return GuardDebugInfo(true, num_guards_executed);
1783   }
1784 
fail_count() const1785   int64_t fail_count() const {
1786     return _fail_count;
1787   }
1788 
1789   // DEBUG function - Returning raw pointers because we can't return unique_ptr
1790   // and pybind does not accept a unique_ptr reference return type.
get_accessors() const1791   virtual std::vector<GuardAccessor*> get_accessors() const {
1792     std::vector<GuardAccessor*> ret;
1793     ret.reserve(_accessors.size());
1794     for (const auto& accessor : _accessors) {
1795       ret.emplace_back(accessor.get());
1796     }
1797     return ret;
1798   }
1799 
1800   // DEBUG function - Returning raw pointers because we can't return unique_ptr
1801   // and pybind does not accept a unique_ptr reference return type.
get_child_managers()1802   virtual std::vector<GuardManager*> get_child_managers() {
1803     std::vector<GuardManager*> ret;
1804     ret.reserve(_accessors.size());
1805     for (const auto& accessor : _accessors) {
1806       ret.emplace_back(accessor->get_guard_manager().get());
1807     }
1808     return ret;
1809   }
1810 
1811   // DEBUG function - Returning raw pointers because we can't return unique_ptr
1812   // and pybind does not accept a unique_ptr reference return type.
get_leaf_guards() const1813   std::vector<LeafGuard*> get_leaf_guards() const {
1814     std::vector<LeafGuard*> ret;
1815     ret.reserve(_leaf_guards.size());
1816     for (const auto& guard : _leaf_guards) {
1817       ret.push_back(guard.get());
1818     }
1819     return ret;
1820   }
1821 
is_leaf_guard_present(const std::string & guard_name)1822   bool is_leaf_guard_present(const std::string& guard_name) {
1823     return _inserted_leaf_guards.find(guard_name) !=
1824         _inserted_leaf_guards.end();
1825   }
1826 
insert_leaf_guard(const std::string & guard_name)1827   void insert_leaf_guard(const std::string& guard_name) {
1828     _inserted_leaf_guards.insert(guard_name);
1829   }
1830 
add_permitted_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard)1831   void add_permitted_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) {
1832     // Selectively called for permitted guards. This is used by DictGuardManager
1833     // which overrides the add_leaf_guard manager to throw runtime error.
1834     GuardManager::add_leaf_guard(std::move(leaf_guard));
1835   }
1836 
1837  protected:
1838   // Keeps a count of how many times this guard manager check function returns
1839   // False. This is used for sorting optimization.
1840   int64_t _fail_count{0};
1841 
1842  private:
1843   // Root of the guard manager, this is the used to install the relational
1844   // guard resetters.
1845   RootGuardManager* _root;
1846 
1847   // A string that can be used to eval on f_locals or f_globals to get the
1848   // value. This is used only to pass on debugging information.
1849   std::string _source;
1850 
1851   // A map of which leaf guards are inserted. This is to prevent duplicate
1852   // guards like TYPE_MATCH.
1853   std::unordered_set<std::string> _inserted_leaf_guards;
1854 
1855   // Leaf guards are the terminal guards on this object, e.g, type check on a
1856   // list. These guards have to be run before any children are run.
1857   //
1858   // These leaf guards are not shufflable. In almost all cases, these guards
1859   // will have an order, e,g., type(x) is int guard and x == 5 guard. We also
1860   // expect very few leaf guards per GuardManager node.
1861   //
1862   // NB: Why are leaf guards shared ptr? This is primarily to enable relational
1863   // guards like `tensor X is not tensor Y`. These guards require multiple
1864   // values. We handle it by creating one guard object that holds state and this
1865   // guard is installed in many guard managers, hence a shared ptr.
1866   std::vector<std::shared_ptr<LeafGuard>> _leaf_guards;
1867 
1868   // GuardAccessors nodes to access the child guards. These guards are
1869   // shufflable. On a guard failure, they are sorted based on their fail count
1870   // to enable fail fast for the next check.
1871   std::vector<std::unique_ptr<GuardAccessor>> _accessors;
1872 
1873   bool _is_dict;
1874   uint64_t _dict_tag{0};
1875 };
1876 
1877 /**
1878  * RootGuardManager is the root of the guard tree. This is primarily
1879  * constructed to hold the relational guard pointers so that we can reset the
1880  * state of those guards on guard failure. All the other important
1881  * implementation is in GuardManager class.
1882  */
1883 
1884 class RootGuardManager : public GuardManager {
1885  public:
1886   // This is the root node, set its _root member to nullptr
RootGuardManager()1887   RootGuardManager() : GuardManager(this, "L") {}
1888 
1889   // Adds the relational guard resetter
add_relational_guard_resetter(std::shared_ptr<RelationalGuard> relational_guard)1890   void add_relational_guard_resetter(
1891       std::shared_ptr<RelationalGuard> relational_guard) {
1892     _relational_guard_resetters.emplace_back(std::move(relational_guard));
1893   }
1894 
1895   // Python visible API to check guard function.
check(py::handle value)1896   bool check(py::handle value) {
1897     return check_nopybind(value.ptr());
1898   }
1899 
1900   // Python visible API to check_verbose guard function.
check_verbose(py::handle value)1901   GuardDebugInfo check_verbose(py::handle value) {
1902     return check_verbose_nopybind(value.ptr());
1903   }
1904 
1905   // Fast check function.
check_nopybind(PyObject * value)1906   bool check_nopybind(PyObject* value) override { // borrowed ref
1907     // Check [Note on GIL interaction with mutex lock] for details on why we
1908     // need mutex and its interactions wth GIL.
1909     PyThreadState* _save = nullptr;
1910     Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting
1911     std::lock_guard<std::mutex> lock_guard(_lock);
1912     Py_BLOCK_THREADS; // ; is added to avoid clang-formatting
1913 
1914     // Get the local state. This will be used for TENSOR_MATCH guards.
1915     if (_init_local_state) {
1916       LocalState state;
1917       _local_state = state;
1918     }
1919 
1920     if (!GuardManager::check_nopybind(value)) {
1921       _reset_relational_guard_state();
1922       return false;
1923     }
1924 
1925     // Iterate over epilogue leaf guards.
1926     for (const auto& guard : _epilogue_lambda_guards) {
1927       if (!guard->check_nopybind(value)) { // early exit
1928         _reset_relational_guard_state();
1929         return false;
1930       }
1931     }
1932     _reset_relational_guard_state();
1933     return true;
1934   }
1935 
1936   // Fast check_verbose function.
check_verbose_nopybind(PyObject * value)1937   GuardDebugInfo check_verbose_nopybind(
1938       PyObject* value) override { // borrowed ref
1939     // Check [Note on GIL interaction with mutex lock] for details on why we
1940     // need mutex and its interactions wth GIL.
1941     PyThreadState* _save = nullptr;
1942     Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting
1943     std::lock_guard<std::mutex> lock_guard(_lock);
1944     Py_BLOCK_THREADS; // ; is added to avoid clang-formatting
1945 
1946     // Get the local state. This will be used for TENSOR_MATCH guards.
1947     if (_init_local_state) {
1948       LocalState state;
1949       _local_state = state;
1950     }
1951 
1952     GuardDebugInfo debug_info = GuardManager::check_verbose_nopybind(value);
1953     if (!debug_info.result) {
1954       _reset_relational_guard_state();
1955       return debug_info;
1956     }
1957 
1958     int num_guards_executed = debug_info.num_guards_executed;
1959 
1960     // Iterate over epilogue leaf guards
1961     for (const auto& guard : _epilogue_lambda_guards) {
1962       const GuardDebugInfo& tmp_debug_info =
1963           guard->check_verbose_nopybind(value);
1964       num_guards_executed++;
1965       if (!tmp_debug_info.result) {
1966         _reset_relational_guard_state();
1967         return GuardDebugInfo(
1968             false, tmp_debug_info.verbose_code_parts, num_guards_executed);
1969       }
1970     }
1971     _reset_relational_guard_state();
1972     return GuardDebugInfo(true, num_guards_executed);
1973   }
1974 
add_epilogue_lambda_guard(std::unique_ptr<LeafGuard> leaf_guard)1975   void add_epilogue_lambda_guard(std::unique_ptr<LeafGuard> leaf_guard) {
1976     _epilogue_lambda_guards.emplace_back(std::move(leaf_guard));
1977   }
1978 
set_init_local_state_flag()1979   void set_init_local_state_flag() {
1980     _init_local_state = true;
1981   }
1982 
1983   // DEBUG function - Returning raw pointers because we can't return unique_ptr
1984   // and pybind does not accept a unique_ptr reference return type.
get_epilogue_lambda_guards() const1985   std::vector<LeafGuard*> get_epilogue_lambda_guards() const {
1986     std::vector<LeafGuard*> ret;
1987     ret.reserve(_epilogue_lambda_guards.size());
1988     for (const auto& guard : _epilogue_lambda_guards) {
1989       ret.push_back(guard.get());
1990     }
1991     return ret;
1992   }
1993 
1994  private:
1995   // Reset the state of all the relational guards on failure.
_reset_relational_guard_state()1996   void _reset_relational_guard_state() {
1997     for (auto& guard : _relational_guard_resetters) {
1998       guard->reset_state();
1999     }
2000   }
2001 
2002  public:
2003   // Local state for TENSOR_MATCH guards.
2004   LocalState _local_state;
2005 
2006  private:
2007   // All the relational guards under this guard mananger. We only use these
2008   // when the guard evaluates to False. This ensures that guard state is reset
2009   // on guard failure so that next invocation is clean.
2010   std::vector<std::shared_ptr<RelationalGuard>> _relational_guard_resetters;
2011 
2012   // These guards are lambda guards, i.e., the guards that lack C++
2013   // implementation. For simplicity, we add these guards at the root. They
2014   // MUST be run after all other guard managers have finished to ensure that
2015   // the epilogue guards do not step on some nonexistent getattr or getitem.
2016   std::vector<std::unique_ptr<LeafGuard>> _epilogue_lambda_guards;
2017 
2018   // [Note on GIL interaction with mutex lock]
2019   // We use std::mutex to prevent multiple threads from running
2020   // check/check_verbose simultaneously. This is to prevent race condition due
2021   // to state changes in RelationalGuard.
2022   //
2023   // However, we also need to be careful about GIL interaction with mutex. There
2024   // is a chance of deadlock
2025   //
2026   //    Thread 1: has GIL, waiting for lock
2027   //    Thread 2: has lock, waiting for GIL
2028   //
2029   // This can happen when Thread 2 earlier acquired the mutex lock, starting
2030   // running the critical section of check function and then called some python
2031   // function (like LAMBDA_GUARD) and reached Cpython codebase that checks if it
2032   // should release the GIL (typically happens after every few bytecode
2033   // instructions). Thread 2 here can decide to release the GIL. Thread 1 can
2034   // acquire GIL and reach the mutex, where it will wait forever.
2035   //
2036   // To avoid this, each thread releases the GIL before acquiring the mutex and
2037   // then acquires the GIL again after acquiring the mutex lock by using
2038   // Py_BLOCK_THREADS and Py_UNBLOCK_THREADS. This avoids the deadlock.
2039   std::mutex _lock;
2040 
2041   // We init LocalState only when this flag it set. This flag is set during
2042   // TENSOR_MATCH guard init.
2043   bool _init_local_state = false;
2044 };
2045 
2046 /*
2047  * Dicts are common in python code. Therefore, we handle guards for dicts
2048  * differently and use PyDict_* APIs which are faster than PyObject_* APIs
2049  * because of no ref count increments/decrements.
2050  *
2051  * DictGuardManager relies on the order of dict.keys(). It keeps track of the
2052  * indices of dict.keys() to access the key, value pair.
2053  */
2054 typedef std::pair<std::unique_ptr<GuardManager>, std::unique_ptr<GuardManager>>
2055     KeyValueManager;
2056 class DictGuardManager : public GuardManager {
2057  public:
DictGuardManager(RootGuardManager * root,std::string source,py::handle example_value)2058   DictGuardManager(
2059       RootGuardManager* root,
2060       std::string source,
2061       py::handle example_value)
2062       : GuardManager(root, std::move(source)),
2063         _size(PyDict_Size(example_value.ptr())),
2064         _expected_type(Py_TYPE(example_value.ptr())),
2065         _is_exact_dict_type(PyDict_CheckExact(example_value.ptr())) {}
2066 
get_key_manager(py::object key_index,std::string source,py::handle example_value,py::handle guard_manager_enum)2067   GuardManager* get_key_manager(
2068       py::object key_index,
2069       std::string source,
2070       py::handle example_value,
2071       py::handle guard_manager_enum) {
2072     KeyValueManager& key_value_manager =
2073         _get_index_manager(std::move(key_index));
2074     if (!key_value_manager.first) {
2075       key_value_manager.first = make_guard_manager(
2076           this->get_root(),
2077           std::move(source),
2078           example_value,
2079           guard_manager_enum);
2080     };
2081     return key_value_manager.first.get();
2082   }
2083 
get_value_manager(py::object key_index,std::string source,py::handle example_value,py::handle guard_manager_enum)2084   GuardManager* get_value_manager(
2085       py::object key_index,
2086       std::string source,
2087       py::handle example_value,
2088       py::handle guard_manager_enum) {
2089     KeyValueManager& key_value_manager =
2090         _get_index_manager(std::move(key_index));
2091     if (!key_value_manager.second) {
2092       key_value_manager.second = make_guard_manager(
2093           this->get_root(),
2094           std::move(source),
2095           example_value,
2096           guard_manager_enum);
2097     };
2098     return key_value_manager.second.get();
2099   }
2100 
check_nopybind(PyObject * obj)2101   bool check_nopybind(PyObject* obj) override { // borrowed ref
2102     // TODO(janimesh) - Implement a fast-path using dict versions.
2103 
2104     if (Py_TYPE(obj) != _expected_type) {
2105       _fail_count += 1;
2106       return false;
2107     }
2108 
2109     if (PyDict_Size(obj) != _size) {
2110       _fail_count += 1;
2111       return false;
2112     }
2113 
2114     // Early return
2115     if (_size == 0) {
2116       return true;
2117     }
2118 
2119     // Invokes the base class's check_nopybind method. We permit a limited set
2120     // of leaf guards and accessors within the DictGuardManager framework.
2121     // Integrating certain guards or accessors directly within the
2122     // DictGuardManager can be challenging. For instance, `type(dict_object)` as
2123     // an accessor is permissible, which otherwise would be hard to integrate
2124     // directly into DictGuardManager.  Similarly, incorporating guards such as
2125     // DICT_CONTAINS and DICT_VERSION as leaf guards offers a simpler solution
2126     // than embedding these functionalities within the DictGuardManager itself.
2127     if (!GuardManager::check_nopybind(obj)) {
2128       _fail_count += 1;
2129       // No need to shuffle the child guards, just return.
2130       return false;
2131     }
2132 
2133     PyObject *key = nullptr, *value = nullptr;
2134     Py_ssize_t pos = 0;
2135 
2136     // Points to an element in the _indices vector.
2137     size_t index_pointer = 0;
2138     // Points to the key index in the dict
2139     Py_ssize_t dict_pointer = 0;
2140 
2141     while (index_pointer < _indices.size() &&
2142            PyDict_Next(obj, &pos, &key, &value)) {
2143       // Skip if dict_pointer is not a saved index.
2144       if (dict_pointer == _indices[index_pointer]) {
2145         index_pointer += 1;
2146         KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2147         std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2148         if (key_manager && !key_manager->check_nopybind(key)) {
2149           return false;
2150         }
2151         std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2152         if (value_manager && !value_manager->check_nopybind(value)) {
2153           return false;
2154         }
2155       }
2156       dict_pointer += 1;
2157     }
2158     return true;
2159   }
2160 
check_verbose_nopybind(PyObject * obj)2161   GuardDebugInfo check_verbose_nopybind(
2162       PyObject* obj) override { // borrowed ref
2163     if (Py_TYPE(obj) != _expected_type) {
2164       return GuardDebugInfo(false, "TYPE_MISMATCH(" + get_source() + ")", 0);
2165     }
2166 
2167     if (PyDict_Size(obj) != _size) {
2168       return GuardDebugInfo(
2169           false, "len(" + get_source() + ") != " + std::to_string(_size), 0);
2170     }
2171 
2172     // Early return
2173     if (_size == 0) {
2174       return GuardDebugInfo(true, 0);
2175     }
2176 
2177     // Invokes the base class's check_nopybind method. We permit a limited set
2178     // of leaf guards and accessors within the DictGuardManager framework.
2179     // Integrating certain guards or accessors directly within the
2180     // DictGuardManager can be challenging. For instance, `type(dict_object)` as
2181     // an accessor is permissible, which otherwise would be hard to integrate
2182     // directly into DictGuardManager.  Similarly, incorporating guards such as
2183     // DICT_CONTAINS and DICT_VERSION as leaf guards offers a simpler solution
2184     // than embedding these functionalities within the DictGuardManager itself.
2185     GuardDebugInfo debug_info = GuardManager::check_verbose_nopybind(obj);
2186     if (!debug_info.result) {
2187       return debug_info;
2188     }
2189 
2190     PyObject *key = nullptr, *value = nullptr;
2191     Py_ssize_t pos = 0;
2192 
2193     // Points to an element in the _indices vector.
2194     size_t index_pointer = 0;
2195     Py_ssize_t dict_pointer = 0;
2196 
2197     int num_guards_executed = 0;
2198     while (index_pointer < _indices.size() &&
2199            PyDict_Next(obj, &pos, &key, &value)) {
2200       // Skip if pos is not a saved index.
2201       if (dict_pointer == _indices[index_pointer]) {
2202         index_pointer += 1;
2203         KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2204         std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2205         if (key_manager) {
2206           GuardDebugInfo debug_info = key_manager->check_verbose_nopybind(key);
2207           num_guards_executed += debug_info.num_guards_executed;
2208           if (!debug_info.result) {
2209             return GuardDebugInfo(
2210                 false, debug_info.verbose_code_parts, num_guards_executed);
2211           }
2212         }
2213         std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2214         if (value_manager) {
2215           GuardDebugInfo debug_info =
2216               value_manager->check_verbose_nopybind(value);
2217           num_guards_executed += debug_info.num_guards_executed;
2218           if (!debug_info.result) {
2219             return GuardDebugInfo(
2220                 false, debug_info.verbose_code_parts, num_guards_executed);
2221           }
2222         }
2223       }
2224       dict_pointer += 1;
2225     }
2226     return GuardDebugInfo(true, num_guards_executed);
2227   }
2228 
skip_adding_guard(const py::object & a,const py::object & b)2229   void skip_adding_guard(const py::object& a, const py::object& b) {
2230     // The `add_leaf_guard` method in `DictGuardManager` is overridden to block
2231     // the addition of leaf guards. However, this is too strict. Python side of
2232     // guard management frequently adds TYPE_MATCH and DICT_LENGTH on
2233     // DictGuardManager. We could refactor Python side to never call these
2234     // guards on dict objects, but that results in messy code. Instead, we just
2235     // override these two guards to not go through add_leaf_guard code path and
2236     // skip adding guards. This makes the python side easy.
2237   }
2238 
fail_on_get_child_manager(const py::object & a,const std::string & source,const py::object & b)2239   void fail_on_get_child_manager(
2240       const py::object& a,
2241       const std::string& source,
2242       const py::object& b) {
2243     throw std::runtime_error("Can not add an accessor to DictGuardManager");
2244   }
2245 
add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard)2246   void add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) override {
2247     // If you are calling this, you probably want to go through a key, value
2248     // child manager and then add a leaf guard on them. DictGuardManager already
2249     // has TYPE_MATCH and LENGTH_CHECK built in.
2250     throw std::runtime_error("DictGuardManager does not support a leaf_guard");
2251   }
2252 
2253   // Debug helper - Returning raw pointers because we can't return unique_ptr
2254   // and pybind does not accept a unique_ptr reference return type.
2255   std::unordered_map<Py_ssize_t, std::pair<GuardManager*, GuardManager*>>
get_key_value_managers()2256   get_key_value_managers() {
2257     std::unordered_map<Py_ssize_t, std::pair<GuardManager*, GuardManager*>> ret;
2258     for (auto index : _indices) {
2259       ret[index] = std::make_pair(
2260           _key_value_managers[index].first.get(),
2261           _key_value_managers[index].second.get());
2262     }
2263     return ret;
2264   }
2265 
is_exact_dict_type()2266   bool is_exact_dict_type() {
2267     return _is_exact_dict_type;
2268   }
2269 
2270  private:
2271   /**
2272    * Adds a new KeyDictGuardAccessor. If the accessor is already present, we
2273    * just return the guard manager.
2274    */
_get_index_manager(py::object key_index)2275   KeyValueManager& _get_index_manager(py::object key_index) {
2276     // Check if the accessor is already present.
2277     Py_ssize_t index = py::cast<Py_ssize_t>(std::move(key_index));
2278     auto it = _key_value_managers.find(index);
2279     if (it != _key_value_managers.end()) {
2280       return it->second;
2281     }
2282     _indices.push_back(index);
2283     // Always keep the _indices array sorted
2284     std::sort(_indices.begin(), _indices.end());
2285     _key_value_managers[index] = std::make_pair(nullptr, nullptr);
2286     return _key_value_managers[index];
2287   }
2288 
2289  protected: // also used by DictSubclassGuardManager
2290   Py_ssize_t _size;
2291   // DictGuardManager supports both exact dict type and non-exact dict type.
2292   // Therefore, we have to compare the type to early exit.
2293   PyTypeObject* _expected_type;
2294   bool _is_exact_dict_type; // Useful to check getattr_manager validity.
2295   std::vector<Py_ssize_t> _indices;
2296   std::unordered_map<Py_ssize_t, KeyValueManager> _key_value_managers;
2297 };
2298 
2299 /**
2300  * The DictSubclassGuardManager is designed to work with dict subclasses,
2301  * specifically focusing on OrderedDicts. Standard dictionaries leverage the
2302  * PyDict_Next function to iterate over keys, values, and items. OrderedDicts,
2303  * on the other hand, rely on an additional linked list structure to maintain
2304  * keys order. Although PyDict_Next and OrderedDict generally yield the same
2305  * order, discrepancies arise when using OrderedDict's move_to_end method (used
2306  * in Pytorch hooks). `move_to_end` method only updates the linked list, leaving
2307  * PyDict_Next unaffected. Therefore, to accurately capture key ordering in such
2308  * cases, DictSubclassGuardManager directly invoke the .keys() method.
2309  */
2310 
2311 class DictSubclassGuardManager : public DictGuardManager {
2312  public:
DictSubclassGuardManager(RootGuardManager * root,std::string source,py::handle example_value)2313   DictSubclassGuardManager(
2314       RootGuardManager* root,
2315       std::string source,
2316       py::handle example_value)
2317       : DictGuardManager(root, std::move(source), example_value) {}
2318 
2319  public:
check_nopybind(PyObject * obj)2320   bool check_nopybind(PyObject* obj) override { // borrowed ref
2321     // TODO(janimesh) - Implement a fast-path using dict versions.
2322 
2323     if (Py_TYPE(obj) != _expected_type) {
2324       _fail_count += 1;
2325       return false;
2326     }
2327 
2328     if (PyDict_Size(obj) != _size) {
2329       _fail_count += 1;
2330       return false;
2331     }
2332 
2333     // Early return
2334     if (_size == 0) {
2335       return true;
2336     }
2337 
2338     if (!GuardManager::check_nopybind(obj)) { // NOLINT
2339       _fail_count += 1;
2340       // No need to shuffle the child guards, just return.
2341       return false;
2342     }
2343 
2344     // Points to an element in the _indices vector.
2345     size_t index_pointer = 0;
2346     // Points to the key index in the dict
2347     Py_ssize_t dict_pointer = 0;
2348 
2349     // Use iter(dict.keys()) to iterate over the keys
2350     py::object keys =
2351         py::handle(obj).attr("keys")(); // py::object handles the references
2352     PyObject* iterator = PyObject_GetIter(keys.ptr()); // new reference
2353     PyObject* key = nullptr;
2354 
2355     while (index_pointer < _indices.size() &&
2356            (key = PyIter_Next(iterator))) { // new reference
2357       if (dict_pointer == _indices[index_pointer]) {
2358         KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2359         std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2360         if (key_manager && !key_manager->check_nopybind(key)) {
2361           Py_DECREF(key);
2362           Py_DECREF(iterator);
2363           return false;
2364         }
2365 
2366         PyObject* value = PyDict_GetItem(obj, key); // borrowed ref
2367         std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2368         if (value_manager && !value_manager->check_nopybind(value)) {
2369           Py_DECREF(key);
2370           Py_DECREF(iterator);
2371           return false;
2372         }
2373 
2374         index_pointer++;
2375       }
2376       dict_pointer++;
2377       Py_DECREF(key);
2378     }
2379 
2380     Py_DECREF(iterator);
2381     return true;
2382   }
2383 
check_verbose_nopybind(PyObject * obj)2384   GuardDebugInfo check_verbose_nopybind(
2385       PyObject* obj) override { // borrowed ref
2386     if (Py_TYPE(obj) != _expected_type) {
2387       return GuardDebugInfo(false, "TYPE_MISMATCH(" + get_source() + ")", 0);
2388     }
2389 
2390     if (PyDict_Size(obj) != _size) {
2391       return GuardDebugInfo(
2392           false, "len(" + get_source() + ") != " + std::to_string(_size), 0);
2393     }
2394 
2395     // Early return
2396     if (_size == 0) {
2397       return GuardDebugInfo(true, 0);
2398     }
2399 
2400     GuardDebugInfo debug_info =
2401         GuardManager::check_verbose_nopybind(obj); // NOLINT
2402     if (!debug_info.result) {
2403       return debug_info;
2404     }
2405 
2406     // Points to an element in the _indices vector.
2407     size_t index_pointer = 0;
2408     // Points to the key index in the dict
2409     Py_ssize_t dict_pointer = 0;
2410 
2411     int num_guards_executed = 0;
2412 
2413     // Use iter(dict.keys()) to iterate over the keys
2414     py::object keys =
2415         py::handle(obj).attr("keys")(); // py::object handles the references
2416     PyObject* iterator = PyObject_GetIter(keys.ptr()); // new reference
2417     PyObject* key = nullptr;
2418 
2419     while (index_pointer < _indices.size() &&
2420            (key = PyIter_Next(iterator))) { // new reference
2421       if (dict_pointer == _indices[index_pointer]) {
2422         KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
2423         std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
2424         if (key_manager) {
2425           GuardDebugInfo debug_info = key_manager->check_verbose_nopybind(key);
2426           num_guards_executed += debug_info.num_guards_executed;
2427           if (!debug_info.result) {
2428             Py_DECREF(key);
2429             Py_DECREF(iterator);
2430             return GuardDebugInfo(
2431                 false, debug_info.verbose_code_parts, num_guards_executed);
2432           }
2433         }
2434 
2435         PyObject* value = PyDict_GetItem(obj, key); // borrowed ref
2436         std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
2437         if (value_manager) {
2438           GuardDebugInfo debug_info =
2439               value_manager->check_verbose_nopybind(value);
2440           num_guards_executed += debug_info.num_guards_executed;
2441           if (!debug_info.result) {
2442             Py_DECREF(key);
2443             Py_DECREF(iterator);
2444             return GuardDebugInfo(
2445                 false, debug_info.verbose_code_parts, num_guards_executed);
2446           }
2447         }
2448         index_pointer++;
2449       }
2450       Py_DECREF(key);
2451       dict_pointer++;
2452     }
2453 
2454     Py_DECREF(iterator);
2455     return GuardDebugInfo(true, num_guards_executed);
2456   }
2457 };
2458 
make_guard_manager(RootGuardManager * root,std::string source,py::handle example_value,py::handle guard_manager_enum)2459 std::unique_ptr<GuardManager> make_guard_manager(
2460     RootGuardManager* root,
2461     std::string source,
2462     py::handle example_value,
2463     py::handle guard_manager_enum) {
2464   static py::object guard_manager_enum_class =
2465       py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
2466   static py::object base_guard_manager_enum =
2467       guard_manager_enum_class.attr("GUARD_MANAGER");
2468   static py::object dict_guard_manager_enum =
2469       guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
2470   static py::object dict_subclass_guard_manager_enum =
2471       guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
2472   if (py::isinstance<py::dict>(example_value)) {
2473     // The purpose of having both DictGuardManager and DictSubclassGuardManager
2474     // is to handle the variability in how dictionaries and their subclasses
2475     // manage key ordering.
2476 
2477     // While inserting dictionary guards (check guards.py), we rely on the
2478     // list(d.keys()) ordering. Therefore, the cpp guard equivalent must have
2479     // the same keys ordering. For standard dictionaries, .keys() API internally
2480     // uses PyDict_Next. So, DictGuardManager directly uses PyDict_Next to
2481     // speedup the key fetches.
2482 
2483     // But PyDict_Next might not give correct ordering for subclasses of dict.
2484     // For example, OrderedDict override the .keys() API without changing the
2485     // underlying datastructure. This leads to different keys ordering than the
2486     // one given by PyDict_Next. We use DictSubclassGuardManager to account for
2487     // this discrepancy. DictSubclassGuardManager directly calls the .keys() API
2488     // to accurately capture key ordering. This approach is less efficient than
2489     // using PyDict_Next (handled by DictGuardManager), but it ensures
2490     // correctness.
2491 
2492     // Since regular dicts are more common than subclasses of dicts with
2493     // overridden keys method, we still optimize for the common case with
2494     // DictGuardManager by relying on PyDict_Next.
2495 
2496     if (guard_manager_enum.is(base_guard_manager_enum)) {
2497       // For dicts that don't need to guard on keys, we can just rely on the
2498       // base GuardManager.
2499       return std::make_unique<GuardManager>(
2500           root, std::move(source), example_value);
2501     } else if (guard_manager_enum.is(dict_guard_manager_enum)) {
2502       return std::make_unique<DictGuardManager>(
2503           root, std::move(source), example_value);
2504     } else if (guard_manager_enum.is(dict_subclass_guard_manager_enum))
2505       return std::make_unique<DictSubclassGuardManager>(
2506           root, std::move(source), example_value);
2507     else {
2508       throw py::type_error("Invalid guard manager enum");
2509     }
2510   }
2511   return std::make_unique<GuardManager>(root, std::move(source));
2512 }
2513 
2514 class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
2515  public:
TORCH_FUNCTION_MODE_STACK(const py::list & initial_stack,const py::list & ignored_types,py::object verbose_code_parts)2516   TORCH_FUNCTION_MODE_STACK(
2517       const py::list& initial_stack,
2518       const py::list& ignored_types,
2519       py::object verbose_code_parts)
2520       : LeafGuard(std::move(verbose_code_parts)),
2521         _ref_stack(),
2522         _ignored_types() {
2523     Py_ssize_t len = PyList_Size(initial_stack.ptr());
2524     for (Py_ssize_t idx = 0; idx < len; idx++) {
2525       PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
2526       this->_ref_stack.push_back(Py_TYPE(mode));
2527     }
2528 
2529     len = PyList_Size(ignored_types.ptr());
2530     for (Py_ssize_t idx = 0; idx < len; idx++) {
2531       PyObject* type_obj =
2532           PyList_GetItem(ignored_types.ptr(), idx); // borrowed ref
2533       if (PyType_Check(type_obj) == 0) {
2534         PyErr_SetString(
2535             PyExc_TypeError, "ignored_types should contain a list of types");
2536         return;
2537       }
2538       PyTypeObject* type = (PyTypeObject*)type_obj;
2539       this->_ignored_types.insert(type);
2540     }
2541   }
2542 
check_nopybind(PyObject * value)2543   bool check_nopybind(PyObject* value) override {
2544     // Ignore value arg, only used to satisfy the interface
2545     size_t ref_ind = 0;
2546     int64_t len = at::impl::PythonTorchFunctionTLS::stack_len();
2547     const size_t ref_stack_size = this->_ref_stack.size();
2548 
2549     for (int64_t idx = 0; idx < len; idx++) {
2550       std::shared_ptr<c10::SafePyObject> mode =
2551           at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
2552 
2553       PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
2554       // skip ignored types
2555       if (this->_ignored_types.count(mode_type) > 0) {
2556         continue;
2557       }
2558       // if we already have more non-ignored modes than the ref stack
2559       // or if the mode doesn't match at the current index, return false
2560       else if (
2561           (ref_stack_size == 0) || (ref_ind > ref_stack_size - 1) ||
2562           mode_type != _ref_stack[ref_ind]) {
2563         return false;
2564       }
2565       ref_ind++;
2566     }
2567 
2568     return ref_ind == this->_ref_stack.size();
2569   }
2570 
2571  private:
2572   std::vector<PyTypeObject*> _ref_stack;
2573   std::set<PyTypeObject*> _ignored_types;
2574 };
2575 
2576 class TENSOR_MATCH : public LeafGuard {
2577  public:
TENSOR_MATCH(RootGuardManager * root_guard_manager,py::object value,py::object dynamic_dims_sizes_py,py::object dynamic_dims_strides_py,py::object tensor_name,py::object verbose_code_parts)2578   TENSOR_MATCH(
2579       RootGuardManager* root_guard_manager,
2580       py::object value,
2581       py::object dynamic_dims_sizes_py,
2582       py::object dynamic_dims_strides_py,
2583       py::object tensor_name,
2584       py::object verbose_code_parts)
2585       : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
2586         _tensor_name(py::cast<py::str>(std::move(tensor_name))) {
2587     root_guard_manager->set_init_local_state_flag();
2588     PyObject* item = value.ptr();
2589     if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
2590       PyErr_SetString(PyExc_TypeError, "expected Tensor()");
2591       return;
2592     }
2593     auto tensor = THPVariable_Unpack(item);
2594 
2595     std::vector<std::optional<c10::SymInt>> tensor_dims_size =
2596         pyListToVecOptInt(dynamic_dims_sizes_py.ptr());
2597     std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
2598         pyListToVecOptInt(dynamic_dims_strides_py.ptr());
2599 
2600     tensor_dims_size = tensor_dims_size.empty()
2601         ? wrapIntegersInOptional(tensor.sym_sizes())
2602         : tensor_dims_size;
2603     tensor_dims_stride = tensor_dims_stride.empty()
2604         ? wrapIntegersInOptional(tensor.sym_strides())
2605         : tensor_dims_stride;
2606     LocalState state;
2607     _tensor_check = std::make_unique<TensorCheck>(
2608         state,
2609         Py_TYPE(item),
2610         std::move(tensor),
2611         std::move(tensor_dims_size),
2612         std::move(tensor_dims_stride));
2613   }
2614 
check_nopybind(PyObject * value)2615   bool check_nopybind(PyObject* value) override { // borrowed ref
2616     if (Py_TYPE(value) != _tensor_check->pytype) {
2617       return false;
2618     }
2619     return _tensor_check->check(
2620         _root_guard_manager->_local_state, THPVariable_Unpack(value));
2621   }
2622 
check_verbose_nopybind(PyObject * value)2623   GuardDebugInfo check_verbose_nopybind(
2624       PyObject* value) override { // borrowed ref
2625 
2626     if (Py_TYPE(value) != _tensor_check->pytype) {
2627       std::stringstream fail_reason;
2628       PyObject* type_str = PyObject_Str(PyObject_Type(value));
2629       fail_reason << "expected type of '" << _tensor_name
2630                   << "' to be a tensor type, ";
2631       if (!type_str) {
2632         fail_reason << "but found a different type";
2633       } else {
2634         fail_reason << "' but found " << PyUnicode_AsUTF8(type_str);
2635       }
2636       return GuardDebugInfo(false, fail_reason.str(), 0);
2637     }
2638 
2639     std::string fail_reason = _tensor_check->check_verbose(
2640         _root_guard_manager->_local_state,
2641         THPVariable_Unpack(value),
2642         _tensor_name);
2643 
2644     if (!fail_reason.empty()) {
2645       if (is_parameter(py::handle(value))) {
2646         fail_reason += ". Guard failed on a parameter, consider using ";
2647         fail_reason +=
2648             "torch._dynamo.config.force_parameter_static_shapes = False ";
2649         fail_reason += "to allow dynamism on parameters.";
2650       }
2651       return GuardDebugInfo(false, fail_reason, 0);
2652     }
2653     return GuardDebugInfo(true, 1);
2654   }
2655 
2656  private:
2657   std::string _tensor_name;
2658   std::unique_ptr<TensorCheck> _tensor_check;
2659 };
2660 
2661 /**
2662  * Represents __getattr__ acccessor.
2663  */
2664 class GetAttrGuardAccessor : public GuardAccessor {
2665  public:
GetAttrGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)2666   GetAttrGuardAccessor(
2667       RootGuardManager* root,
2668       py::str name,
2669       std::string source,
2670       py::handle example_value,
2671       py::handle guard_manager_enum)
2672       : GuardAccessor(
2673             root,
2674             name,
2675             std::move(source),
2676             example_value,
2677             guard_manager_enum),
2678         _attr_name(name.ptr()) {}
2679 
2680   // NB: Intentional duplication between check_nopybind and
2681   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2682   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2683       override { // borrowed ref
2684     PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
2685     if (x == nullptr) {
2686       // Attribute absent, clear the exception and return false.
2687       PyErr_Clear();
2688       return false;
2689     }
2690     bool result = _guard_manager->check_nopybind(x);
2691     Py_DECREF(x);
2692     return result;
2693   }
2694 
check_verbose_nopybind(PyObject * obj)2695   GuardDebugInfo check_verbose_nopybind(
2696       PyObject* obj) override { // borrowed ref
2697     PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
2698     if (x == nullptr) {
2699       // Attribute absent, clear the exception and return false.
2700       PyErr_Clear();
2701       return GuardDebugInfo(
2702           false, "getattr failed on source " + get_source(), 0);
2703     }
2704     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2705     Py_DECREF(x);
2706     return result;
2707   }
2708 
repr() const2709   std::string repr() const override {
2710     // Helpful when priting GuardManager tree structure.
2711     return "GetAttrGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
2712         ")";
2713   }
2714 
2715  private:
2716   // no need of py::object here because the attr_name is already passed on to
2717   // the base class as accessor_key which is a py::object.
2718   PyObject* _attr_name;
2719 };
2720 
2721 /**
2722  * Represents x.__dict__ acccessor.
2723  */
2724 class GetGenericDictGuardAccessor : public GuardAccessor {
2725  public:
GetGenericDictGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)2726   GetGenericDictGuardAccessor(
2727       RootGuardManager* root,
2728       py::str name,
2729       std::string source,
2730       py::handle example_value,
2731       py::handle guard_manager_enum)
2732       : GuardAccessor(
2733             root,
2734             std::move(name),
2735             std::move(source),
2736             example_value,
2737             guard_manager_enum) {}
2738 
2739   // NB: Intentional duplication between check_nopybind and
2740   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2741   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2742       override { // borrowed ref
2743     PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref
2744     if (x == nullptr) {
2745       // Attribute absent, clear the exception and return false.
2746       PyErr_Clear();
2747       return false;
2748     }
2749     bool result = _guard_manager->check_nopybind(x);
2750     Py_DECREF(x);
2751     return result;
2752   }
2753 
check_verbose_nopybind(PyObject * obj)2754   GuardDebugInfo check_verbose_nopybind(
2755       PyObject* obj) override { // borrowed ref
2756     PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref
2757     if (x == nullptr) {
2758       // Attribute absent, clear the exception and return false.
2759       PyErr_Clear();
2760       return GuardDebugInfo(
2761           false, "getattr failed on source " + get_source(), 0);
2762     }
2763     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2764     Py_DECREF(x);
2765     return result;
2766   }
2767 
repr() const2768   std::string repr() const override {
2769     // Helpful when priting GuardManager tree structure.
2770     return "GetGenericDictGuardAccessor";
2771   }
2772 };
2773 
2774 /**
2775  * Represents __getitem__ acccessor.
2776  */
2777 class GetItemGuardAccessor : public GuardAccessor {
2778  public:
GetItemGuardAccessor(RootGuardManager * root,py::object name,std::string source,py::handle example_value,py::handle guard_manager_enum)2779   GetItemGuardAccessor(
2780       RootGuardManager* root,
2781       py::object name,
2782       std::string source,
2783       py::handle example_value,
2784       py::handle guard_manager_enum)
2785       : GuardAccessor(
2786             root,
2787             name,
2788             std::move(source),
2789             example_value,
2790             guard_manager_enum),
2791         _attr_name(name.ptr()) {}
2792 
2793   // NB: Intentional duplication between check_nopybind and
2794   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2795   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2796       override { // borrowed ref
2797     PyObject* x = PyObject_GetItem(obj, _attr_name); // new ref
2798     if (x == nullptr) {
2799       PyErr_Clear();
2800       return false;
2801     }
2802     bool result = _guard_manager->check_nopybind(x);
2803     Py_DECREF(x);
2804     return result;
2805   }
2806 
check_verbose_nopybind(PyObject * obj)2807   GuardDebugInfo check_verbose_nopybind(
2808       PyObject* obj) override { // borrowed ref
2809     PyObject* x = PyObject_GetItem(obj, _attr_name); // new ref
2810     if (x == nullptr) {
2811       PyErr_Clear();
2812       return GuardDebugInfo(
2813           false, std::string("KeyError on ") + get_source(), 0);
2814     }
2815     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2816     Py_DECREF(x);
2817     return result;
2818   }
2819 
repr() const2820   std::string repr() const override {
2821     return "GetItemGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
2822         ")";
2823   }
2824 
2825  private:
2826   // no need of py::object here because the attr_name is already passed on to
2827   // the base class as accessor_key which is a py::object.
2828   PyObject* _attr_name;
2829 };
2830 
2831 /**
2832  * Represents dict[name] acccessor. This is ONLY used for f_locals because its a
2833  * dict, and DictGuardManager does not support sorting. We differentiate it from
2834  * GetItemGuardAccessor because PyDict_GetItem should be fasten the
2835  * PyObject_GetItem.
2836  */
2837 class DictGetItemGuardAccessor : public GuardAccessor {
2838  public:
DictGetItemGuardAccessor(RootGuardManager * root,py::object key,std::string source,py::handle example_value,py::handle guard_manager_enum)2839   DictGetItemGuardAccessor(
2840       RootGuardManager* root,
2841       py::object key,
2842       std::string source,
2843       py::handle example_value,
2844       py::handle guard_manager_enum)
2845       : GuardAccessor(
2846             root,
2847             key,
2848             std::move(source),
2849             example_value,
2850             guard_manager_enum),
2851         _key(key.ptr()),
2852         _is_immutable_object(is_immutable_object(example_value)) {}
2853 
2854   // NB: Intentional duplication between check_nopybind and
2855   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2856   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2857       override { // borrowed ref
2858     if (matches_dict_tag && _is_immutable_object) {
2859       // immutable object and dict tag matches, we can skip the guard subtree.
2860       return true;
2861     }
2862     PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref
2863     if (x == nullptr) {
2864       PyErr_Clear();
2865       return false;
2866     }
2867     bool result = _guard_manager->check_nopybind(x);
2868     return result;
2869   }
2870 
check_verbose_nopybind(PyObject * obj)2871   GuardDebugInfo check_verbose_nopybind(
2872       PyObject* obj) override { // borrowed ref
2873     PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref
2874     if (x == nullptr) {
2875       PyErr_Clear();
2876       return GuardDebugInfo(
2877           false, std::string("KeyError on ") + get_source(), 0);
2878     }
2879     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2880     return result;
2881   }
2882 
repr() const2883   std::string repr() const override {
2884     return "DictGetItemGuardAccessor(" + py::str(_key).cast<std::string>() +
2885         ")";
2886   }
2887 
2888  private:
2889   PyObject* _key;
2890 
2891   // If immutable object and dict tag matches, we can skip the guard subtree and
2892   // return true.
2893   bool _is_immutable_object;
2894 };
2895 
2896 /**
2897  * Represents list[index] accessor. It is faster than generic
2898  * GetItemGuardAccessor.
2899  */
2900 class ListGetItemGuardAccessor : public GuardAccessor {
2901  public:
ListGetItemGuardAccessor(RootGuardManager * root,const py::object & index,std::string source,py::handle example_value,py::handle guard_manager_enum)2902   ListGetItemGuardAccessor(
2903       RootGuardManager* root,
2904       const py::object& index,
2905       std::string source,
2906       py::handle example_value,
2907       py::handle guard_manager_enum)
2908       : GuardAccessor(
2909             root,
2910             index,
2911             std::move(source),
2912             example_value,
2913             guard_manager_enum),
2914         _index(py::cast<Py_ssize_t>(index)) {}
2915 
2916   // NB: Intentional duplication between check_nopybind and
2917   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2918   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2919       override { // borrowed ref
2920     PyObject* x = PyList_GetItem(obj, _index); // borrowed ref
2921     if (x == nullptr) {
2922       PyErr_Clear();
2923       return false;
2924     }
2925     bool result = _guard_manager->check_nopybind(x);
2926     return result;
2927   }
2928 
check_verbose_nopybind(PyObject * obj)2929   GuardDebugInfo check_verbose_nopybind(
2930       PyObject* obj) override { // borrowed ref
2931     PyObject* x = PyList_GetItem(obj, _index); // borrowed ref
2932     if (x == nullptr) {
2933       PyErr_Clear();
2934       return GuardDebugInfo(
2935           false, std::string("IndexError on ") + get_source(), 0);
2936     }
2937     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2938     return result;
2939   }
2940 
repr() const2941   std::string repr() const override {
2942     return "ListGetItemGuardAccessor(" + std::to_string(_index) + ")";
2943   }
2944 
2945  private:
2946   Py_ssize_t _index;
2947 };
2948 
2949 /**
2950  * Represents tuple[index] accessor. It is faster than generic
2951  * GetItemGuardAccessor.
2952  */
2953 class TupleGetItemGuardAccessor : public GuardAccessor {
2954  public:
TupleGetItemGuardAccessor(RootGuardManager * root,const py::object & index,std::string source,py::handle example_value,py::handle guard_manager_enum)2955   TupleGetItemGuardAccessor(
2956       RootGuardManager* root,
2957       const py::object& index,
2958       std::string source,
2959       py::handle example_value,
2960       py::handle guard_manager_enum)
2961       : GuardAccessor(
2962             root,
2963             index,
2964             std::move(source),
2965             example_value,
2966             guard_manager_enum),
2967         _index(py::cast<Py_ssize_t>(index)) {}
2968 
2969   // NB: Intentional duplication between check_nopybind and
2970   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)2971   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
2972       override { // borrowed ref
2973     PyObject* x = PyTuple_GetItem(obj, _index); // borrowed ref
2974     if (x == nullptr) {
2975       PyErr_Clear();
2976       return false;
2977     }
2978     bool result = _guard_manager->check_nopybind(x);
2979     return result;
2980   }
2981 
check_verbose_nopybind(PyObject * obj)2982   GuardDebugInfo check_verbose_nopybind(
2983       PyObject* obj) override { // borrowed ref
2984     PyObject* x = PyTuple_GetItem(obj, _index); // borrowed ref
2985     if (x == nullptr) {
2986       PyErr_Clear();
2987       return GuardDebugInfo(
2988           false, std::string("IndexError on ") + get_source(), 0);
2989     }
2990     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
2991     return result;
2992   }
2993 
repr() const2994   std::string repr() const override {
2995     return "TupleGetItemGuardAccessor(" + std::to_string(_index) + ")";
2996   }
2997 
2998  private:
2999   Py_ssize_t _index;
3000 };
3001 
3002 /**
3003  * Represents tensor.grad acccessor.
3004  */
3005 class GradGuardAccessor : public GuardAccessor {
3006  public:
GradGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)3007   GradGuardAccessor(
3008       RootGuardManager* root,
3009       py::str name,
3010       std::string source,
3011       py::handle example_value,
3012       py::handle guard_manager_enum)
3013       : GuardAccessor(
3014             root,
3015             std::move(name),
3016             std::move(source),
3017             example_value,
3018             guard_manager_enum) {}
3019 
3020   // NB: Intentional duplication between check_nopybind and
3021   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3022   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3023       override { // borrowed ref
3024     // check that its a tensor
3025     if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
3026       return false;
3027     }
3028     PyObject* grad =
3029         THPVariable_Wrap(THPVariable_Unpack(obj).grad()); // New reference
3030     bool result = _guard_manager->check_nopybind(grad);
3031     // For undefined tensor, THPVariable_Wrap returns Py_RETURN_NONE. So, no
3032     // need of Py_XDECREF.
3033     Py_DECREF(grad);
3034     return result;
3035   }
3036 
check_verbose_nopybind(PyObject * obj)3037   GuardDebugInfo check_verbose_nopybind(
3038       PyObject* obj) override { // borrowed ref
3039     // check that its a tensor
3040     if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
3041       return GuardDebugInfo(
3042           false, "not a tensor - grad field is accessed " + get_source(), 0);
3043     }
3044     PyObject* grad =
3045         THPVariable_Wrap(THPVariable_Unpack(obj).grad()); // New reference
3046     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(grad);
3047     // For undefined tensor, THPVariable_Wrap returns Py_RETURN_NONE. So, no
3048     // need of Py_XDECREF.
3049     Py_DECREF(grad);
3050     return result;
3051   }
3052 
repr() const3053   std::string repr() const override {
3054     // Helpful when priting GuardManager tree structure.
3055     return "GradGuardAccessor(grad)";
3056   }
3057 };
3058 
3059 /**
3060  * Represents func.__defaults__ accessor.
3061  */
3062 class FuncDefaultsGuardAccessor : public GuardAccessor {
3063  public:
FuncDefaultsGuardAccessor(RootGuardManager * root,py::object name,std::string source,py::handle example_value,py::handle guard_manager_enum)3064   FuncDefaultsGuardAccessor(
3065       RootGuardManager* root,
3066       py::object name,
3067       std::string source,
3068       py::handle example_value,
3069       py::handle guard_manager_enum)
3070       : GuardAccessor(
3071             root,
3072             std::move(name),
3073             std::move(source),
3074             example_value,
3075             guard_manager_enum) {}
3076 
3077   // NB: Intentional duplication between check_nopybind and
3078   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3079   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3080       override { // borrowed ref
3081     PyObject* func = obj;
3082     if (PyMethod_Check(obj)) {
3083       func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3084     } else if (PyInstanceMethod_Check(obj)) {
3085       func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3086     }
3087     PyObject* x = PyFunction_GetDefaults(func); // borrowed ref
3088     if (x == nullptr) {
3089       PyErr_Clear();
3090       return false;
3091     }
3092     return _guard_manager->check_nopybind(x);
3093   }
3094 
check_verbose_nopybind(PyObject * obj)3095   GuardDebugInfo check_verbose_nopybind(
3096       PyObject* obj) override { // borrowed ref
3097     PyObject* func = obj;
3098     if (PyMethod_Check(obj)) {
3099       func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3100     } else if (PyInstanceMethod_Check(obj)) {
3101       func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3102     }
3103     PyObject* x = PyFunction_GetDefaults(func);
3104     if (x == nullptr) {
3105       PyErr_Clear();
3106       return GuardDebugInfo(
3107           false,
3108           std::string(repr() + ": Not a function on ") + get_source(),
3109           0);
3110     }
3111 
3112     return _guard_manager->check_verbose_nopybind(x);
3113   }
3114 
repr() const3115   std::string repr() const override {
3116     return "FuncDefaultsGuardAccessor";
3117   }
3118 };
3119 
3120 /**
3121  * Represents func.__kwdefaults__ accessor.
3122  */
3123 class FuncKwDefaultsGuardAccessor : public GuardAccessor {
3124  public:
FuncKwDefaultsGuardAccessor(RootGuardManager * root,py::object name,std::string source,py::handle example_value,py::handle guard_manager_enum)3125   FuncKwDefaultsGuardAccessor(
3126       RootGuardManager* root,
3127       py::object name,
3128       std::string source,
3129       py::handle example_value,
3130       py::handle guard_manager_enum)
3131       : GuardAccessor(
3132             root,
3133             std::move(name),
3134             std::move(source),
3135             example_value,
3136             guard_manager_enum) {}
3137 
3138   // NB: Intentional duplication between check_nopybind and
3139   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3140   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3141       override { // borrowed ref
3142     PyObject* func = obj;
3143     if (PyMethod_Check(obj)) {
3144       func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3145     } else if (PyInstanceMethod_Check(obj)) {
3146       func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3147     }
3148     PyObject* x = PyFunction_GetKwDefaults(func); // borrowed ref
3149     if (x == nullptr) {
3150       PyErr_Clear();
3151       return false;
3152     }
3153     return _guard_manager->check_nopybind(x);
3154   }
3155 
check_verbose_nopybind(PyObject * obj)3156   GuardDebugInfo check_verbose_nopybind(
3157       PyObject* obj) override { // borrowed ref
3158     PyObject* func = obj;
3159     if (PyMethod_Check(obj)) {
3160       func = PyMethod_GET_FUNCTION(obj); // borrowed ref
3161     } else if (PyInstanceMethod_Check(obj)) {
3162       func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
3163     }
3164     PyObject* x = PyFunction_GetKwDefaults(func);
3165     if (x == nullptr) {
3166       PyErr_Clear();
3167       return GuardDebugInfo(
3168           false,
3169           std::string(repr() + ": Not a function on ") + get_source(),
3170           0);
3171     }
3172 
3173     return _guard_manager->check_verbose_nopybind(x);
3174   }
3175 
repr() const3176   std::string repr() const override {
3177     return "FuncKwDefaultsGuardAccessor";
3178   }
3179 };
3180 
3181 /**
3182  * Represents f_globals acccessor. This sits as a child accessor of the
3183  * RootGuardManager.
3184  */
3185 class GlobalsGuardAccessor : public GuardAccessor {
3186  public:
GlobalsGuardAccessor(RootGuardManager * root,py::dict globals_dict,std::string source,py::handle example_value,py::handle guard_manager_enum)3187   GlobalsGuardAccessor(
3188       RootGuardManager* root,
3189       py::dict globals_dict,
3190       std::string source,
3191       py::handle example_value,
3192       py::handle guard_manager_enum)
3193       : GuardAccessor(
3194             root,
3195             globals_dict,
3196             std::move(source),
3197             example_value,
3198             guard_manager_enum),
3199         _globals_dict(globals_dict.ptr()) {}
3200 
3201   // NB: Intentional duplication between check_nopybind and
3202   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3203   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3204       override { // borrowed ref
3205     // Ignore the obj arg. This is required to satisfy the function signature.
3206     // Just pass on the globals dict to the child manager.
3207     return _guard_manager->check_nopybind(_globals_dict);
3208   }
3209 
check_verbose_nopybind(PyObject * obj)3210   GuardDebugInfo check_verbose_nopybind(
3211       PyObject* obj) override { // borrowed ref
3212     // Ignore the obj arg. This is required to satisfy the function signature.
3213     // Just pass on the globals dict to the child manager.
3214     return _guard_manager->check_verbose_nopybind(_globals_dict);
3215   }
3216 
repr() const3217   std::string repr() const override {
3218     return "GlobalsGuardAccessor";
3219   }
3220 
3221  private:
3222   // no need of py::object here because the globals_dict is already passed on to
3223   // the base class as accessor_key which is a py::object.
3224   PyObject* _globals_dict;
3225 };
3226 
3227 /**
3228  * Represent type(...) accessor.
3229  */
3230 class TypeGuardAccessor : public GuardAccessor {
3231  public:
3232   // name = __type_accessor__, a unique string used as attribute name.
TypeGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)3233   TypeGuardAccessor(
3234       RootGuardManager* root,
3235       py::str name,
3236       std::string source,
3237       py::handle example_value,
3238       py::handle guard_manager_enum)
3239       : GuardAccessor(
3240             root,
3241             std::move(name),
3242             std::move(source),
3243             example_value,
3244             guard_manager_enum) {}
3245 
3246   // NB: Intentional duplication between check_nopybind and
3247   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3248   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3249       override { // borrowed ref
3250     PyObject* x = (PyObject*)Py_TYPE(obj); // borrowed ref
3251     return _guard_manager->check_nopybind(x);
3252   }
3253 
check_verbose_nopybind(PyObject * obj)3254   GuardDebugInfo check_verbose_nopybind(
3255       PyObject* obj) override { // borrowed ref
3256     PyObject* x = (PyObject*)Py_TYPE(obj); // borrowed ref
3257     return _guard_manager->check_verbose_nopybind(x);
3258   }
3259 
repr() const3260   std::string repr() const override {
3261     return "TypeGuardAccessor";
3262   }
3263 };
3264 
3265 /**
3266  * Getitem tuple_iterator accessor.
3267  */
3268 class TupleIteratorGetItemAccessor : public GuardAccessor {
3269  public:
TupleIteratorGetItemAccessor(RootGuardManager * root,py::object index,std::string source,py::handle example_value,py::handle guard_manager_enum)3270   TupleIteratorGetItemAccessor(
3271       RootGuardManager* root,
3272       py::object index,
3273       std::string source,
3274       py::handle example_value,
3275       py::handle guard_manager_enum)
3276       : GuardAccessor(
3277             root,
3278             index,
3279             std::move(source),
3280             example_value,
3281             guard_manager_enum),
3282         _index(py::cast<Py_ssize_t>(std::move(index))) {}
3283 
3284   // NB: Intentional duplication between check_nopybind and
3285   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3286   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3287       override { // borrowed ref
3288     _PyTupleIterObject* it = (_PyTupleIterObject*)obj;
3289     PyObject* x =
3290         PyTuple_GET_ITEM(it->it_seq, it->it_index + _index); // borrowed ref
3291     if (x == nullptr) {
3292       // Out of range.
3293       PyErr_Clear();
3294       return false;
3295     }
3296     bool result = _guard_manager->check_nopybind(x);
3297     return result;
3298   }
3299 
check_verbose_nopybind(PyObject * obj)3300   GuardDebugInfo check_verbose_nopybind(
3301       PyObject* obj) override { // borrowed ref
3302     _PyTupleIterObject* it = (_PyTupleIterObject*)obj;
3303     PyObject* x =
3304         PyTuple_GET_ITEM(it->it_seq, it->it_index + _index); // borrowed ref
3305     if (x == nullptr) {
3306       // Out of range.
3307       PyErr_Clear();
3308       return GuardDebugInfo(false, std::string("IndexError ") + repr(), 0);
3309     }
3310     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
3311     return result;
3312   }
3313 
repr() const3314   std::string repr() const override {
3315     return "TupleIteratorGetItemAccessor(" + std::to_string(_index) + ")";
3316   }
3317 
3318  private:
3319   Py_ssize_t _index;
3320 };
3321 
3322 /**
3323  * GlobalWeakRef accessor. Dynamo can insert a weakref object into the frame
3324  * globals. This accessor reads the globals and then calls the weakref object
3325  * to get the underlying object. This is a child of GlobalsGuardAccessor.
3326  * Therefore, we will get the globals dict while caling check_nopybind.
3327  */
3328 class GlobalWeakRefGuardAccessor : public GuardAccessor {
3329  public:
GlobalWeakRefGuardAccessor(RootGuardManager * root,py::object global_name,std::string source,py::handle example_value,py::handle guard_manager_enum)3330   GlobalWeakRefGuardAccessor(
3331       RootGuardManager* root,
3332       py::object global_name,
3333       std::string source,
3334       py::handle example_value,
3335       py::handle guard_manager_enum)
3336       : GuardAccessor(
3337             root,
3338             global_name,
3339             std::move(source),
3340             example_value,
3341             guard_manager_enum),
3342         _global_name(global_name.ptr()) {}
3343 
3344   // NB: Intentional duplication between check_nopybind and
3345   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3346   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3347       override { // borrowed ref
3348     // obj is globals dict because GlobalWeakRefGuardAccessor has to be a
3349     // child of GlobalsGuardAccessor.
3350     PyObject* weakref = PyDict_GetItem(obj, _global_name); // borrowed ref
3351     if (weakref == nullptr) {
3352       // The weakref is not in the globals dict.
3353       PyErr_Clear();
3354       return false;
3355     }
3356 
3357     if (!PyWeakref_Check(weakref)) {
3358       return false;
3359     }
3360 
3361     PyObject* x = PyWeakref_GetObject(weakref); // borrowed ref
3362     return _guard_manager->check_nopybind(x);
3363   }
3364 
check_verbose_nopybind(PyObject * obj)3365   GuardDebugInfo check_verbose_nopybind(
3366       PyObject* obj) override { // borrowed ref
3367     // obj is globals dict because GlobalWeakRefGuardAccessor has to be a
3368     // child of GlobalsGuardAccessor.
3369     PyObject* weakref = PyDict_GetItem(obj, _global_name); // borrowed ref
3370     if (weakref == nullptr) {
3371       // The weakref is not in the globals dict.
3372       PyErr_Clear();
3373       return GuardDebugInfo(
3374           false, std::string("KeyError on ") + get_source(), 0);
3375     }
3376 
3377     if (!PyWeakref_Check(weakref)) {
3378       return GuardDebugInfo(
3379           false, std::string("Not a weakref ") + get_source(), 0);
3380     }
3381 
3382     PyObject* x = PyWeakref_GetObject(weakref); // borrowed ref
3383     return _guard_manager->check_verbose_nopybind(x);
3384   }
3385 
repr() const3386   std::string repr() const override {
3387     return "GlobalWeakRefGuardAccessor(" +
3388         py::str(_global_name).cast<std::string>() + ")";
3389   }
3390 
3391  private:
3392   PyObject* _global_name;
3393 };
3394 
3395 /**
3396  * Implements weakref call - x_weak()
3397  */
3398 class WeakRefCallGuardAccessor : public GuardAccessor {
3399  public:
WeakRefCallGuardAccessor(RootGuardManager * root,py::str name,std::string source,py::handle example_value,py::handle guard_manager_enum)3400   WeakRefCallGuardAccessor(
3401       RootGuardManager* root,
3402       py::str name,
3403       std::string source,
3404       py::handle example_value,
3405       py::handle guard_manager_enum)
3406       : GuardAccessor(
3407             root,
3408             std::move(name),
3409             std::move(source),
3410             example_value,
3411             guard_manager_enum) {}
3412 
3413   // NB: Intentional duplication between check_nopybind and
3414   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3415   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3416       override { // borrowed ref
3417     if (!PyWeakref_Check(obj)) {
3418       return false;
3419     }
3420 
3421     PyObject* x = PyWeakref_GetObject(obj); // borrowed ref
3422     return _guard_manager->check_nopybind(x);
3423   }
3424 
check_verbose_nopybind(PyObject * obj)3425   GuardDebugInfo check_verbose_nopybind(
3426       PyObject* obj) override { // borrowed ref
3427     if (!PyWeakref_Check(obj)) {
3428       return GuardDebugInfo(
3429           false, std::string("Not a weakref obj ") + get_source(), 0);
3430     }
3431 
3432     PyObject* x = PyWeakref_GetObject(obj); // borrowed ref
3433     return _guard_manager->check_verbose_nopybind(x);
3434   }
3435 
repr() const3436   std::string repr() const override {
3437     return "WeakRefCallGuardAccessor()";
3438   }
3439 };
3440 
3441 /**
3442  * Similar to PythonLambdaLeafGuard, this class is a way to allow developers to
3443  * supply accessor as a python function. This is useful for from_numpy source.
3444  */
3445 class PythonLambdaGuardAccessor : public GuardAccessor {
3446  public:
PythonLambdaGuardAccessor(RootGuardManager * root,py::function accessor_fn,std::string source,py::handle example_value,py::handle guard_manager_enum)3447   PythonLambdaGuardAccessor(
3448       RootGuardManager* root,
3449       py::function accessor_fn,
3450       std::string source,
3451       py::handle example_value,
3452       py::handle guard_manager_enum)
3453       : GuardAccessor(
3454             root,
3455             accessor_fn,
3456             std::move(source),
3457             example_value,
3458             guard_manager_enum),
3459         _accessor_fn(std::move(accessor_fn)) {}
3460 
3461   // NB: Intentional duplication between check_nopybind and
3462   // check_verbose_nopybind.
check_nopybind(PyObject * obj,bool matches_dict_tag=false)3463   bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
3464       override { // borrowed ref
3465     PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
3466     if (x == nullptr) {
3467       // The accessor function failed.
3468       PyErr_Clear();
3469       return false;
3470     }
3471     bool result = _guard_manager->check_nopybind(x);
3472     Py_DECREF(x);
3473     return result;
3474   }
3475 
check_verbose_nopybind(PyObject * obj)3476   GuardDebugInfo check_verbose_nopybind(
3477       PyObject* obj) override { // borrowed ref
3478     PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
3479     if (x == nullptr) {
3480       // The accessor function failed.
3481       std::string exc_message = get_exception_message();
3482       PyErr_Clear();
3483       return GuardDebugInfo(false, exc_message, 0);
3484     }
3485     GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
3486     Py_DECREF(x);
3487     return result;
3488   }
3489 
repr() const3490   std::string repr() const override {
3491     return "PythonLambdaGuardAccessor";
3492   }
3493 
3494  private:
3495   py::object _accessor_fn;
3496 };
3497 
install_object_aliasing_guard(GuardManager * x,GuardManager * y,py::object verbose_code_parts)3498 void install_object_aliasing_guard(
3499     GuardManager* x,
3500     GuardManager* y,
3501     py::object verbose_code_parts) {
3502   // Adds tensor X is tensor Y guard. This is a an example of relational guard.
3503   // There is one guard object that is shared between two guard managers.
3504   std::shared_ptr<RelationalGuard> guard =
3505       std::make_shared<OBJECT_ALIASING>(std::move(verbose_code_parts));
3506 
3507   // Register the resetter on the toor guard mananger, so that it can reset
3508   // the newly added relational guard when the guard eval fails.
3509   x->get_root()->add_relational_guard_resetter(guard);
3510 
3511   // In case the guard is a DictGuardManager, OBJECT_ALIASING guard is a
3512   // permitted guard.
3513   x->add_permitted_leaf_guard(guard);
3514   y->add_permitted_leaf_guard(guard);
3515 }
3516 
install_no_tensor_aliasing_guard(const py::list & guard_managers,const py::list & tensor_names,py::object verbose_code_parts)3517 void install_no_tensor_aliasing_guard(
3518     const py::list& guard_managers,
3519     const py::list& tensor_names,
3520     py::object verbose_code_parts) {
3521   // Adds a guard that checks none of tensors alias. This is a an example of
3522   // relational guard. There is one guard object that is shared between multiple
3523   // guard managers.
3524   std::shared_ptr<RelationalGuard> guard = std::make_shared<NO_TENSOR_ALIASING>(
3525       tensor_names, std::move(verbose_code_parts));
3526 
3527   // Register the resetter on the toor guard mananger, so that it can reset
3528   // the newly added relational guard when the guard eval fails.
3529   py::cast<GuardManager*>(guard_managers[0])
3530       ->get_root()
3531       ->add_relational_guard_resetter(guard);
3532   for (const auto& guard_manager : guard_managers) {
3533     py::cast<GuardManager*>(guard_manager)->add_leaf_guard(guard);
3534   }
3535 }
3536 
3537 } // namespace
3538 
_torchinductor_pyobject_tensor_data_ptr(PyObject * obj)3539 static void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
3540   if (C10_UNLIKELY(
3541           obj == nullptr ||
3542           (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)))) {
3543     throw std::runtime_error(
3544         "_torchinductor_pyobject_tensor_data_ptr: non-tensor input");
3545   }
3546   return THPVariable_Unpack(obj).data_ptr();
3547 }
3548 
convert_to_root_guard_manager(py::object root)3549 void* convert_to_root_guard_manager(py::object root) {
3550   RootGuardManager* root_mgr = std::move(root).cast<RootGuardManager*>();
3551   return (void*)root_mgr;
3552 }
3553 
run_root_guard_manager(void * root,PyObject * f_locals)3554 bool run_root_guard_manager(void* root, PyObject* f_locals) {
3555   return ((RootGuardManager*)root)->check_nopybind(f_locals);
3556 }
3557 
torch_c_dynamo_guards_init()3558 PyObject* torch_c_dynamo_guards_init() {
3559   // initialize TensorGuardsType
3560   TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
3561   TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
3562   TensorGuardsType.tp_itemsize = 0;
3563   TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
3564   TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
3565   TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
3566   TensorGuardsType.tp_methods = TensorGuards_methods;
3567   TensorGuardsType.tp_init = (initproc)TensorGuards_init;
3568   TensorGuardsType.tp_new = TensorGuards_new;
3569 
3570   if (PyType_Ready(&TensorGuardsType) < 0)
3571     return nullptr;
3572 
3573   GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
3574   GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
3575   GlobalStateGuardType.tp_itemsize = 0;
3576   GlobalStateGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
3577   GlobalStateGuardType.tp_doc = "Guard on PyTorch global flags such as no_grad";
3578   GlobalStateGuardType.tp_methods = GlobalStateGuard_methods;
3579   GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init;
3580   GlobalStateGuardType.tp_new = PyType_GenericNew;
3581 
3582   if (PyType_Ready(&GlobalStateGuardType) < 0)
3583     return nullptr;
3584 
3585   auto m = PyModule_Create(&_module);
3586   if (m == nullptr)
3587     return nullptr;
3588 
3589   Py_INCREF(&TensorGuardsType);
3590   if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
3591     Py_DECREF(&TensorGuardsType);
3592     Py_DECREF(m);
3593     return nullptr;
3594   }
3595 
3596   Py_INCREF(&GlobalStateGuardType);
3597   if (PyModule_AddObject(
3598           m, "GlobalStateGuard", (PyObject*)&GlobalStateGuardType) < 0) {
3599     Py_DECREF(&GlobalStateGuardType);
3600     Py_DECREF(m);
3601     return nullptr;
3602   }
3603 
3604   // We expose the address of _torchinductor_pyobject_tensor_data_ptr in order
3605   // to allow manual linking in our generated TorchInductor Python bindings.
3606   // While regular linking works in most cases, it does not work properly in
3607   // fbcode due to janky build setup there.
3608   if (PyModule_AddObject(
3609           m,
3610           "_torchinductor_pyobject_tensor_data_ptr",
3611           PyLong_FromVoidPtr(reinterpret_cast<void*>(
3612               &_torchinductor_pyobject_tensor_data_ptr))) < 0) {
3613     return nullptr;
3614   }
3615 
3616   auto py_m = py::handle(m).cast<py::module>();
3617   py::class_<GuardDebugInfo, std::unique_ptr<GuardDebugInfo>>(
3618       py_m, "GuardDebugInfo")
3619       .def(py::init<bool, py::list, int>())
3620       .def("__str__", &GuardDebugInfo::to_string)
3621       .def_readonly("result", &GuardDebugInfo::result)
3622       .def_readonly("verbose_code_parts", &GuardDebugInfo::verbose_code_parts)
3623       .def_readonly(
3624           "num_guards_executed", &GuardDebugInfo::num_guards_executed);
3625 
3626   // Leaf Guards
3627   py::class_<LeafGuard, std::shared_ptr<LeafGuard>>(py_m, "LeafGuard")
3628       .def("verbose_code_parts", &LeafGuard::verbose_code_parts);
3629   py::class_<LAMBDA_GUARD, LeafGuard, std::shared_ptr<LAMBDA_GUARD>>(
3630       py_m, "LAMBDA_GUARD")
3631       .def(py::init<py::function, py::list>())
3632       .def("__call__", &LAMBDA_GUARD::check);
3633   py::class_<TYPE_MATCH, LeafGuard, std::shared_ptr<TYPE_MATCH>>(
3634       py_m, "TYPE_MATCH")
3635       .def(py::init<py::object, py::list>())
3636       .def("__call__", &TYPE_MATCH::check);
3637   py::class_<ID_MATCH, LeafGuard, std::shared_ptr<ID_MATCH>>(py_m, "ID_MATCH")
3638       .def(py::init<py::object, py::list>())
3639       .def("__call__", &ID_MATCH::check);
3640   py::class_<EQUALS_MATCH, LeafGuard, std::shared_ptr<EQUALS_MATCH>>(
3641       py_m, "EQUALS_MATCH")
3642       .def(py::init<py::object, py::list>())
3643       .def("__call__", &EQUALS_MATCH::check);
3644   py::class_<LENGTH_CHECK, LeafGuard, std::shared_ptr<LENGTH_CHECK>>(
3645       py_m, "LENGTH_CHECK")
3646       .def(py::init<py::object, py::list>())
3647       .def("__call__", &LENGTH_CHECK::check);
3648   py::class_<DICT_LENGTH, LeafGuard, std::shared_ptr<DICT_LENGTH>>(
3649       py_m, "DICT_LENGTH")
3650       .def(py::init<py::object, py::list>())
3651       .def("__call__", &DICT_LENGTH::check);
3652   py::class_<DEFAULT_DEVICE, LeafGuard, std::shared_ptr<DEFAULT_DEVICE>>(
3653       py_m, "DEFAULT_DEVICE")
3654       .def(py::init<py::list>())
3655       .def("__call__", &DEFAULT_DEVICE::check);
3656   py::class_<NOT_NONE, LeafGuard, std::shared_ptr<NOT_NONE>>(py_m, "NOT_NONE")
3657       .def(py::init<py::list>())
3658       .def("__call__", &NOT_NONE::check);
3659   py::class_<
3660       TUPLE_ITERATOR_LEN,
3661       LeafGuard,
3662       std::shared_ptr<TUPLE_ITERATOR_LEN>>(py_m, "TUPLE_ITERATOR_LEN")
3663       .def(py::init<py::object, py::object, py::list>())
3664       .def("__call__", &TUPLE_ITERATOR_LEN::check);
3665   py::class_<GLOBAL_STATE, LeafGuard, std::shared_ptr<GLOBAL_STATE>>(
3666       py_m, "GLOBAL_STATE")
3667       .def(py::init<py::list>())
3668       .def("check_verbose", &GLOBAL_STATE::check_verbose)
3669       .def("__call__", &GLOBAL_STATE::check);
3670   py::class_<
3671       TORCH_FUNCTION_MODE_STACK,
3672       LeafGuard,
3673       std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
3674       py_m, "TORCH_FUNCTION_MODE_STACK")
3675       .def(py::init<py::list, py::list, py::list>())
3676       .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
3677   py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
3678       py_m, "DATA_PTR_MATCH")
3679       .def(py::init<py::object, py::list>())
3680       .def("__call__", &DATA_PTR_MATCH::check);
3681   py::class_<NO_HASATTR, LeafGuard, std::shared_ptr<NO_HASATTR>>(
3682       py_m, "NO_HASATTR")
3683       .def(py::init<py::object, py::list>())
3684       .def("__call__", &NO_HASATTR::check);
3685   py::class_<DICT_CONTAINS, LeafGuard, std::shared_ptr<DICT_CONTAINS>>(
3686       py_m, "DICT_CONTAINS")
3687       .def(py::init<bool, py::object, py::list>())
3688       .def("__call__", &DICT_CONTAINS::check);
3689   py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
3690       py_m, "DYNAMIC_INDICES")
3691       .def(py::init<py::set, py::list>())
3692       .def("__call__", &DYNAMIC_INDICES::check);
3693   py::class_<DICT_VERSION, LeafGuard, std::shared_ptr<DICT_VERSION>>(
3694       py_m, "DICT_VERSION")
3695       .def(py::init<py::object, py::list>())
3696       .def("__call__", &DICT_VERSION::check);
3697   py::class_<TENSOR_MATCH, LeafGuard, std::shared_ptr<TENSOR_MATCH>>(
3698       py_m, "TENSOR_MATCH")
3699       .def(py::init<
3700            RootGuardManager*,
3701            py::object,
3702            py::object,
3703            py::object,
3704            py::str,
3705            py::list>())
3706       .def("__call__", &TENSOR_MATCH::check);
3707   // NOLINTNEXTLINE(bugprone-unused-raii)
3708   py::class_<OBJECT_ALIASING, LeafGuard, std::shared_ptr<OBJECT_ALIASING>>(
3709       py_m, "OBJECT_ALIASING");
3710   // NOLINTNEXTLINE(bugprone-unused-raii)
3711   py::class_<
3712       NO_TENSOR_ALIASING,
3713       LeafGuard,
3714       std::shared_ptr<NO_TENSOR_ALIASING>>(py_m, "NO_TENSOR_ALIASING");
3715 
3716   // Guard Accessors - These are present so that we can iterate over the
3717   // GuardManager hierarchy. We intentionally do not provide even an init
3718   // function on these, because these should be constructed from within C++.
3719   py::class_<GuardAccessor, std::unique_ptr<GuardAccessor>>(
3720       py_m, "GuardAccessor")
3721       .def("repr", &GuardAccessor::repr);
3722   // NOLINTNEXTLINE(bugprone-unused-raii)
3723   py::class_<
3724       GetAttrGuardAccessor,
3725       GuardAccessor,
3726       std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor");
3727   // NOLINTNEXTLINE(bugprone-unused-raii)
3728   py::class_<
3729       GetGenericDictGuardAccessor,
3730       GuardAccessor,
3731       std::unique_ptr<GetGenericDictGuardAccessor>>(
3732       py_m, "GetGenericDictGuardAccessor");
3733   // NOLINTNEXTLINE(bugprone-unused-raii)
3734   py::class_<
3735       GetItemGuardAccessor,
3736       GuardAccessor,
3737       std::unique_ptr<GetItemGuardAccessor>>(py_m, "GetItemGuardAccessor");
3738   // NOLINTNEXTLINE(bugprone-unused-raii)
3739   py::class_<
3740       DictGetItemGuardAccessor,
3741       GuardAccessor,
3742       std::unique_ptr<DictGetItemGuardAccessor>>(
3743       py_m, "DictGetItemGuardAccessor");
3744   // NOLINTNEXTLINE(bugprone-unused-raii)
3745   py::class_<
3746       ListGetItemGuardAccessor,
3747       GuardAccessor,
3748       std::unique_ptr<ListGetItemGuardAccessor>>(
3749       py_m, "ListGetItemGuardAccessor");
3750   // NOLINTNEXTLINE(bugprone-unused-raii)
3751   py::class_<
3752       TupleGetItemGuardAccessor,
3753       GuardAccessor,
3754       std::unique_ptr<TupleGetItemGuardAccessor>>(
3755       py_m, "TupleGetItemGuardAccessor");
3756   // NOLINTNEXTLINE(bugprone-unused-raii)
3757   py::class_<
3758       FuncDefaultsGuardAccessor,
3759       GuardAccessor,
3760       std::unique_ptr<FuncDefaultsGuardAccessor>>(
3761       py_m, "FuncDefaultsGuardAccessor");
3762   // NOLINTNEXTLINE(bugprone-unused-raii)
3763   py::class_<
3764       FuncKwDefaultsGuardAccessor,
3765       GuardAccessor,
3766       std::unique_ptr<FuncKwDefaultsGuardAccessor>>(
3767       py_m, "FuncKwDefaultsGuardAccessor");
3768   // NOLINTNEXTLINE(bugprone-unused-raii)
3769   py::class_<
3770       GlobalsGuardAccessor,
3771       GuardAccessor,
3772       std::unique_ptr<GlobalsGuardAccessor>>(py_m, "GlobalsGuardAccessor");
3773   // NOLINTNEXTLINE(bugprone-unused-raii)
3774   py::class_<
3775       TypeGuardAccessor,
3776       GuardAccessor,
3777       std::unique_ptr<TypeGuardAccessor>>(py_m, "TypeGuardAccessor");
3778   // NOLINTNEXTLINE(bugprone-unused-raii)
3779   py::class_<
3780       WeakRefCallGuardAccessor,
3781       GuardAccessor,
3782       std::unique_ptr<WeakRefCallGuardAccessor>>(
3783       py_m, "WeakRefCallGuardAccessor");
3784   // NOLINTNEXTLINE(bugprone-unused-raii)
3785   py::class_<
3786       TupleIteratorGetItemAccessor,
3787       GuardAccessor,
3788       std::unique_ptr<TupleIteratorGetItemAccessor>>(
3789       py_m, "TupleIteratorGetItemAccessor");
3790   // NOLINTNEXTLINE(bugprone-unused-raii)
3791   py::class_<
3792       GlobalWeakRefGuardAccessor,
3793       GuardAccessor,
3794       std::unique_ptr<GlobalWeakRefGuardAccessor>>(
3795       py_m, "GlobalWeakRefGuardAccessor");
3796 
3797   // Guard Manager - No constructor in python, python should use
3798   // RootGuardManager.
3799   py::class_<GuardManager, std::unique_ptr<GuardManager>>(py_m, "GuardManager")
3800       // return by reference because GuardManager has the ownership of accessors
3801       .def("get_source", &GuardManager::get_source)
3802       .def(
3803           "get_accessors",
3804           &GuardManager::get_accessors,
3805           py::return_value_policy::reference)
3806       // return by reference because GuardManager has the ownership of child
3807       // managers
3808       .def(
3809           "get_child_managers",
3810           &GuardManager::get_child_managers,
3811           py::return_value_policy::reference)
3812       // return by reference because GuardManager has the ownership of leaf
3813       // guards
3814       .def(
3815           "get_leaf_guards",
3816           &GuardManager::get_leaf_guards,
3817           py::return_value_policy::reference)
3818       .def(
3819           "add_lambda_guard",
3820           [](GuardManager& self,
3821              py::object lambda,
3822              py::object verbose_code_parts) -> void {
3823             self.add_leaf_guard(std::make_shared<LAMBDA_GUARD>(
3824                 std::move(lambda), std::move(verbose_code_parts)));
3825           })
3826       .def(
3827           "add_type_match_guard",
3828           [](GuardManager& self,
3829              py::object value,
3830              py::object verbose_code_parts) -> void {
3831             SKIP_IF_GUARD_ALREADY_PRESENT("TYPE_MATCH");
3832             self.add_leaf_guard(std::make_shared<TYPE_MATCH>(
3833                 std::move(value), std::move(verbose_code_parts)));
3834           })
3835       .def(
3836           "add_id_match_guard",
3837           [](GuardManager& self,
3838              py::object value,
3839              py::object verbose_code_parts) -> void {
3840             SKIP_IF_GUARD_ALREADY_PRESENT("ID_MATCH");
3841             self.add_leaf_guard(std::make_shared<ID_MATCH>(
3842                 std::move(value), std::move(verbose_code_parts)));
3843           })
3844       .def(
3845           "add_equals_match_guard",
3846           [](GuardManager& self,
3847              py::object value,
3848              py::object verbose_code_parts) -> void {
3849             SKIP_IF_GUARD_ALREADY_PRESENT("EQUALS_MATCH");
3850             self.add_leaf_guard(std::make_shared<EQUALS_MATCH>(
3851                 std::move(value), std::move(verbose_code_parts)));
3852           })
3853       .def(
3854           "add_length_check_guard",
3855           [](GuardManager& self,
3856              py::object value,
3857              py::object verbose_code_parts) -> void {
3858             SKIP_IF_GUARD_ALREADY_PRESENT("LENGTH_CHECK");
3859             self.add_leaf_guard(std::make_shared<LENGTH_CHECK>(
3860                 std::move(value), std::move(verbose_code_parts)));
3861           })
3862       .def(
3863           "add_dict_length_check_guard",
3864           [](GuardManager& self,
3865              py::object value,
3866              py::object verbose_code_parts) -> void {
3867             SKIP_IF_GUARD_ALREADY_PRESENT("DICT_LENGTH");
3868             self.add_leaf_guard(std::make_shared<DICT_LENGTH>(
3869                 std::move(value), std::move(verbose_code_parts)));
3870           })
3871       .def(
3872           "add_tuple_iterator_length_guard",
3873           [](GuardManager& self,
3874              py::object length,
3875              py::object type_id,
3876              py::object verbose_code_parts) -> void {
3877             SKIP_IF_GUARD_ALREADY_PRESENT("TUPLE_ITERATOR_LEN");
3878             self.add_leaf_guard(std::make_shared<TUPLE_ITERATOR_LEN>(
3879                 std::move(length),
3880                 std::move(type_id),
3881                 std::move(verbose_code_parts)));
3882           })
3883       .def(
3884           "add_default_device_guard",
3885           [](GuardManager& self, py::object verbose_code_parts) -> void {
3886             self.add_leaf_guard(std::make_shared<DEFAULT_DEVICE>(
3887                 std::move(verbose_code_parts)));
3888           })
3889       .def(
3890           "add_not_none_guard",
3891           [](GuardManager& self, py::object verbose_code_parts) -> void {
3892             SKIP_IF_GUARD_ALREADY_PRESENT("NOT_NONE");
3893             self.add_leaf_guard(
3894                 std::make_shared<NOT_NONE>(std::move(verbose_code_parts)));
3895           })
3896       .def(
3897           "add_global_state_guard",
3898           [](GuardManager& self, py::object verbose_code_parts) -> void {
3899             self.add_leaf_guard(
3900                 std::make_shared<GLOBAL_STATE>(std::move(verbose_code_parts)));
3901           })
3902       .def(
3903           "add_torch_function_mode_stack_guard",
3904           [](GuardManager& self,
3905              const py::list& initial_stack,
3906              const py::list& ignored_types,
3907              py::object verbose_code_parts) -> void {
3908             self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
3909                 initial_stack, ignored_types, std::move(verbose_code_parts)));
3910           })
3911       .def(
3912           "add_data_ptr_guard",
3913           [](GuardManager& self,
3914              py::object data_ptr,
3915              py::object verbose_code_parts) -> void {
3916             SKIP_IF_GUARD_ALREADY_PRESENT("DATA_PTR_MATCH");
3917             self.add_leaf_guard(std::make_shared<DATA_PTR_MATCH>(
3918                 std::move(data_ptr), std::move(verbose_code_parts)));
3919           })
3920       .def(
3921           "add_no_hasattr_guard",
3922           [](GuardManager& self,
3923              py::object attr_name,
3924              py::object verbose_code_parts) -> void {
3925             self.add_leaf_guard(std::make_shared<NO_HASATTR>(
3926                 std::move(attr_name), std::move(verbose_code_parts)));
3927           })
3928       .def(
3929           "add_dict_contains_guard",
3930           [](GuardManager& self,
3931              bool contains,
3932              py::object key,
3933              py::object verbose_code_parts) -> void {
3934             self.add_leaf_guard(std::make_shared<DICT_CONTAINS>(
3935                 contains, std::move(key), std::move(verbose_code_parts)));
3936           })
3937       .def(
3938           "add_dynamic_indices_guard",
3939           [](GuardManager& self,
3940              py::set value,
3941              py::object verbose_code_parts) -> void {
3942             self.add_leaf_guard(std::make_shared<DYNAMIC_INDICES>(
3943                 std::move(value), std::move(verbose_code_parts)));
3944           })
3945       .def(
3946           "add_dict_version_guard",
3947           [](GuardManager& self,
3948              py::object value,
3949              py::object verbose_code_parts) -> void {
3950             self.add_leaf_guard(std::make_shared<DICT_VERSION>(
3951                 std::move(value), std::move(verbose_code_parts)));
3952           })
3953       .def(
3954           "add_tensor_match_guard",
3955           [](GuardManager& self,
3956              py::object value,
3957              py::object sizes,
3958              py::object strides,
3959              py::object tensor_name,
3960              py::object verbose_code_parts) -> void {
3961             SKIP_IF_GUARD_ALREADY_PRESENT("TENSOR_MATCH");
3962             self.add_leaf_guard(std::make_shared<TENSOR_MATCH>(
3963                 self.get_root(),
3964                 std::move(value),
3965                 std::move(sizes),
3966                 std::move(strides),
3967                 std::move(tensor_name),
3968                 std::move(verbose_code_parts)));
3969           })
3970 
3971       // return by reference because GuardManager has the ownership of accessors
3972       // and guard managers
3973       .def(
3974           "getitem_manager",
3975           &GuardManager::get_child_manager<GetItemGuardAccessor>,
3976           py::arg("key"),
3977           py::arg("source"),
3978           py::arg("example_value"),
3979           py::arg("guard_manager_enum"),
3980           py::return_value_policy::reference)
3981       // return by reference because GuardManager has the ownership of accessors
3982       // and guard managers
3983       .def(
3984           "dict_getitem_manager",
3985           &GuardManager::get_child_manager<DictGetItemGuardAccessor>,
3986           py::arg("key"),
3987           py::arg("source"),
3988           py::arg("example_value"),
3989           py::arg("guard_manager_enum"),
3990           py::return_value_policy::reference)
3991       // return by reference because GuardManager has the ownership of accessors
3992       // and guard managers
3993       .def(
3994           "list_getitem_manager",
3995           &GuardManager::get_child_manager<ListGetItemGuardAccessor>,
3996           py::arg("key"),
3997           py::arg("source"),
3998           py::arg("example_value"),
3999           py::arg("guard_manager_enum"),
4000           py::return_value_policy::reference)
4001       // return by reference because GuardManager has the ownership of accessors
4002       // and guard managers
4003       .def(
4004           "tuple_getitem_manager",
4005           &GuardManager::get_child_manager<TupleGetItemGuardAccessor>,
4006           py::arg("key"),
4007           py::arg("source"),
4008           py::arg("example_value"),
4009           py::arg("guard_manager_enum"),
4010           py::return_value_policy::reference)
4011       // return by reference because GuardManager has the ownership of accessors
4012       // and guard managers
4013       .def(
4014           "func_defaults_manager",
4015           [](GuardManager& self,
4016              std::string source,
4017              py::object example_value,
4018              py::handle guard_manager_enum) -> GuardManager* {
4019             // A unique key is used to save as the accessor key.
4020             py::str unique_key("__defaults_accessor__");
4021             return self.get_child_manager<FuncDefaultsGuardAccessor>(
4022                 std::move(unique_key),
4023                 std::move(source),
4024                 std::move(example_value),
4025                 guard_manager_enum);
4026           },
4027           py::arg("source"),
4028           py::arg("example_value"),
4029           py::arg("guard_manager_enum"),
4030           py::return_value_policy::reference)
4031 
4032       // return by reference because GuardManager has the ownership of accessors
4033       // and guard managers
4034       .def(
4035           "func_kwdefaults_manager",
4036           [](GuardManager& self,
4037              std::string source,
4038              py::object example_value,
4039              py::handle guard_manager_enum) -> GuardManager* {
4040             // A unique key is used to save as the accessor key.
4041             py::str unique_key("__kwdefaults_accessor__");
4042             return self.get_child_manager<FuncKwDefaultsGuardAccessor>(
4043                 std::move(unique_key),
4044                 std::move(source),
4045                 std::move(example_value),
4046                 guard_manager_enum);
4047           },
4048           py::arg("source"),
4049           py::arg("example_value"),
4050           py::arg("guard_manager_enum"),
4051           py::return_value_policy::reference)
4052       // return by reference because GuardManager has the ownership of accessors
4053       // and guard managers
4054       .def(
4055           "globals_dict_manager",
4056           &GuardManager::get_child_manager<GlobalsGuardAccessor>,
4057           py::arg("f_globals"),
4058           py::arg("source"),
4059           py::arg("example_value"),
4060           py::arg("guard_manager_enum"),
4061           py::return_value_policy::reference)
4062       // return by reference because GuardManager has the ownership of accessors
4063       // and guard managers
4064       .def(
4065           "type_manager",
4066           [](GuardManager& self,
4067              std::string source,
4068              py::handle example_value,
4069              py::handle guard_manager_enum) -> GuardManager* {
4070             // A unique key is used to save as the accessor key.
4071             py::str unique_key("__type_accessor__");
4072             return self.get_child_manager<TypeGuardAccessor>(
4073                 std::move(unique_key),
4074                 std::move(source),
4075                 example_value,
4076                 guard_manager_enum);
4077           },
4078           py::arg("source"),
4079           py::arg("example_value"),
4080           py::arg("guard_manager_enum"),
4081           py::return_value_policy::reference)
4082       // return by reference because GuardManager has the ownership of accessors
4083       // and guard managers
4084       .def(
4085           "weakref_call_manager",
4086           [](GuardManager& self,
4087              std::string source,
4088              py::handle example_value,
4089              py::handle guard_manager_enum) -> GuardManager* {
4090             // A unique key is used to save as the accessor key.
4091             py::str unique_key("__weakref_call_accessor__");
4092             return self.get_child_manager<WeakRefCallGuardAccessor>(
4093                 std::move(unique_key),
4094                 std::move(source),
4095                 example_value,
4096                 guard_manager_enum);
4097           },
4098           py::arg("source"),
4099           py::arg("example_value"),
4100           py::arg("guard_manager_enum"),
4101           py::return_value_policy::reference)
4102       // return by reference because GuardManager has the ownership of accessors
4103       // and guard managers
4104       .def(
4105           "tuple_iterator_getitem_manager",
4106           &GuardManager::get_child_manager<TupleIteratorGetItemAccessor>,
4107           py::arg("index"),
4108           py::arg("source"),
4109           py::arg("example_value"),
4110           py::arg("guard_manager_enum"),
4111           py::return_value_policy::reference)
4112       // return by reference because GuardManager has the ownership of accessors
4113       // and guard managers
4114       .def(
4115           "global_weakref_manager",
4116           &GuardManager::get_child_manager<GlobalWeakRefGuardAccessor>,
4117           py::arg("global_name"),
4118           py::arg("source"),
4119           py::arg("example_value"),
4120           py::arg("guard_manager_enum"),
4121           py::return_value_policy::reference)
4122       // return by reference because GuardManager has the ownership of accessors
4123       // and guard managers
4124       .def(
4125           "lambda_manager",
4126           &GuardManager::get_child_manager<PythonLambdaGuardAccessor>,
4127           py::arg("python_lambda"),
4128           py::arg("source"),
4129           py::arg("example_value"),
4130           py::arg("guard_manager_enum"),
4131           py::return_value_policy::reference)
4132       // return by reference because GuardManager has the ownership of accessors
4133       // and guard managers
4134       .def(
4135           "grad_manager",
4136           [](GuardManager& self,
4137              std::string source,
4138              py::handle example_value,
4139              py::handle guard_manager_enum) -> GuardManager* {
4140             // A unique key is used to save as the accessor key.
4141             py::str unique_key("__grad_accessor__");
4142             return self.get_child_manager<GradGuardAccessor>(
4143                 std::move(unique_key),
4144                 std::move(source),
4145                 example_value,
4146                 guard_manager_enum);
4147           },
4148           py::arg("source"),
4149           py::arg("example_value"),
4150           py::arg("guard_manager_enum"),
4151           py::return_value_policy::reference)
4152       // return by reference because GuardManager has the ownership of accessors
4153       // and guard managers
4154       .def(
4155           "get_generic_dict_manager",
4156           [](GuardManager& self,
4157              std::string source,
4158              py::handle example_value,
4159              py::handle guard_manager_enum) -> GuardManager* {
4160             // A unique key is used to save as the accessor key.
4161             py::str unique_key("__generic_dict_accessor__");
4162             return self.get_child_manager<GetGenericDictGuardAccessor>(
4163                 std::move(unique_key),
4164                 std::move(source),
4165                 example_value,
4166                 guard_manager_enum);
4167           },
4168           py::arg("source"),
4169           py::arg("example_value"),
4170           py::arg("guard_manager_enum"),
4171           py::return_value_policy::reference)
4172       // return by reference because C++ GuardManager has the ownership of
4173       // accessors and guard managers
4174       .def(
4175           "getattr_manager",
4176           &GuardManager::get_child_manager<GetAttrGuardAccessor>,
4177           py::arg("attr"),
4178           py::arg("source"),
4179           py::arg("example_value"),
4180           py::arg("guard_manager_enum"),
4181           py::return_value_policy::reference);
4182 
4183   // Root Guard Manager
4184   py::class_<RootGuardManager, GuardManager, std::unique_ptr<RootGuardManager>>(
4185       py_m, "RootGuardManager")
4186       .def(py::init<>())
4187       .def("check", &RootGuardManager::check)
4188       .def("check_verbose", &RootGuardManager::check_verbose)
4189       // return by reference because GuardManager has the ownership of leaf
4190       // guards
4191       .def(
4192           "get_epilogue_lambda_guards",
4193           &RootGuardManager::get_epilogue_lambda_guards,
4194           py::return_value_policy::reference)
4195       .def(
4196           "add_epilogue_lambda_guard",
4197           [](RootGuardManager& self,
4198              py::object lambda,
4199              py::object verbose_code_parts) -> void {
4200             self.add_epilogue_lambda_guard(std::make_unique<LAMBDA_GUARD>(
4201                 std::move(lambda), std::move(verbose_code_parts)));
4202           });
4203 
4204   // Dict Guard Manager
4205   py::class_<DictGuardManager, GuardManager, std::unique_ptr<DictGuardManager>>(
4206       py_m, "DictGuardManager")
4207       // return by reference because GuardManager has the ownership of accessors
4208       // and guard managers
4209       .def(
4210           "get_key_manager",
4211           [](DictGuardManager& self,
4212              py::object index,
4213              std::string source,
4214              py::handle example_value,
4215              py::handle guard_manager_enum) -> GuardManager* {
4216             return self.get_key_manager(
4217                 std::move(index),
4218                 std::move(source),
4219                 example_value,
4220                 guard_manager_enum);
4221           },
4222           py::arg("index"),
4223           py::arg("source"),
4224           py::arg("example_value"),
4225           py::arg("guard_manager_enum"),
4226           py::return_value_policy::reference)
4227       // return by reference because GuardManager has the ownership of accessors
4228       // and guard managers
4229       .def(
4230           "get_value_manager",
4231           [](DictGuardManager& self,
4232              py::object index,
4233              std::string source,
4234              py::handle example_value,
4235              py::handle guard_manager_enum) -> GuardManager* {
4236             return self.get_value_manager(
4237                 std::move(index),
4238                 std::move(source),
4239                 example_value,
4240                 guard_manager_enum);
4241           },
4242           py::arg("index"),
4243           py::arg("source"),
4244           py::arg("example_value"),
4245           py::arg("guard_manager_enum"),
4246           py::return_value_policy::reference)
4247       // return by reference because GuardManager has the ownership of leaf
4248       // guards
4249       .def(
4250           "get_key_value_managers",
4251           &DictGuardManager::get_key_value_managers,
4252           py::return_value_policy::reference)
4253       // Skipped leaf guards
4254       .def("add_type_match_guard", &DictGuardManager::skip_adding_guard)
4255       .def("add_dict_length_check_guard", &DictGuardManager::skip_adding_guard)
4256       // Permitted leaf guards
4257       .def(
4258           "add_dict_contains_guard",
4259           [](DictGuardManager& self,
4260              bool contains,
4261              py::object key,
4262              py::object verbose_code_parts) -> void {
4263             self.add_permitted_leaf_guard(std::make_shared<DICT_CONTAINS>(
4264                 contains, std::move(key), std::move(verbose_code_parts)));
4265           })
4266       .def(
4267           "add_dict_version_guard",
4268           [](DictGuardManager& self,
4269              py::object value,
4270              py::object verbose_code_parts) -> void {
4271             // DICT_VERSION is used in a very narrow context today to guard on
4272             // pytree SUPPPORTED_NODES. We can remove this once we have tags in
4273             // DictGuardManager.
4274             self.add_permitted_leaf_guard(std::make_shared<DICT_VERSION>(
4275                 std::move(value), std::move(verbose_code_parts)));
4276           })
4277       // Not permitted accesssors
4278       .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager)
4279       .def("getitem_manager", &DictGuardManager::fail_on_get_child_manager)
4280       .def("dict_getitem_manager", &DictGuardManager::fail_on_get_child_manager)
4281       .def("globals_dict_manager", &DictGuardManager::fail_on_get_child_manager)
4282       .def(
4283           "tuple_iterator_getitem_manager",
4284           &DictGuardManager::fail_on_get_child_manager)
4285       .def(
4286           "global_weakref_manager",
4287           &DictGuardManager::fail_on_get_child_manager)
4288       .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager)
4289       // Permitted accessors (and also type_manager)
4290       // return by reference because GuardManager has the ownership of accessors
4291       // and guard managers
4292       .def(
4293           "getattr_manager",
4294           [](DictGuardManager& self,
4295              py::object attr_name,
4296              std::string source,
4297              py::handle example_value,
4298              py::handle guard_manager_enum) -> GuardManager* {
4299             if (self.is_exact_dict_type()) {
4300               throw std::runtime_error(
4301                   "getattr_manager on a DictGuardManager is supported only for dict subclasses");
4302             }
4303             return self.get_child_manager<GetAttrGuardAccessor>(
4304                 std::move(attr_name),
4305                 std::move(source),
4306                 example_value,
4307                 guard_manager_enum);
4308           },
4309           py::arg("attr"),
4310           py::arg("source"),
4311           py::arg("example_value"),
4312           py::arg("guard_manager_enum"),
4313           py::return_value_policy::reference);
4314 
4315   // Dict Guard Manager
4316   py::class_< // NOLINT
4317       DictSubclassGuardManager,
4318       DictGuardManager,
4319       std::unique_ptr<DictSubclassGuardManager>>(
4320       py_m, "DictSubclassGuardManager") // NOLINT
4321       .def(
4322           "add_no_hasattr_guard",
4323           [](DictSubclassGuardManager& self,
4324              py::object attr_name,
4325              py::object verbose_code_parts) -> void {
4326             self.add_permitted_leaf_guard(std::make_shared<NO_HASATTR>(
4327                 std::move(attr_name), std::move(verbose_code_parts)));
4328           });
4329 
4330   py_m.def("install_object_aliasing_guard", install_object_aliasing_guard);
4331   py_m.def(
4332       "install_no_tensor_aliasing_guard", install_no_tensor_aliasing_guard);
4333 
4334 // initialize dict_version_map watcher for 3.12
4335 #if IS_PYTHON_3_12_PLUS
4336 
4337   dict_version_watcher_id = PyDict_AddWatcher(dict_version_watch_callback);
4338   if (dict_version_watcher_id == -1) {
4339     throw std::runtime_error("Failed to install dict_version_watch_callback");
4340   }
4341 
4342 #endif
4343 
4344   return m;
4345 }
4346 
4347 } // namespace torch::dynamo
4348