// xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/guards.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/GradMode.h>
3 #include <torch/csrc/python_headers.h>
4 #include <torch/csrc/utils/pybind.h>
5 
6 namespace torch::dynamo {
7 
8 PyObject* torch_c_dynamo_guards_init();
9 
10 // interfaces for extra_state and eval_frame.c because RootGuardManager class is
11 // not visible there.
12 void* convert_to_root_guard_manager(py::object root);
13 bool run_root_guard_manager(void* root, PyObject* f_locals);
14 
15 struct LocalState {
16   // TLS state that changes operators
17   c10::impl::LocalDispatchKeySet dispatch_modifier;
18   c10::DispatchKeySet override_dispatch_key_set;
19   bool grad_mode_enabled;
20 
applyLocalState21   at::DispatchKeySet apply(at::DispatchKeySet ks) const {
22     if (override_dispatch_key_set.empty()) {
23       return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
24     } else {
25       return override_dispatch_key_set;
26     }
27   }
28 
LocalStateLocalState29   LocalState()
30       : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
31         override_dispatch_key_set(c10::BackendComponent::InvalidBit),
32         grad_mode_enabled(at::GradMode::is_enabled()) {}
33 
overrideDispatchKeySetLocalState34   void overrideDispatchKeySet(c10::DispatchKeySet ks) {
35     override_dispatch_key_set = ks;
36   }
37 };
38 
39 class TensorCheck {
40  public:
41   TensorCheck(
42       const LocalState& state,
43       PyTypeObject* pt,
44       const at::Tensor& v,
45       std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
46       std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
47 
48   TensorCheck(
49       const LocalState& state,
50       PyTypeObject* pt,
51       c10::DispatchKeySet dispatch_key_set,
52       at::ScalarType dtype,
53       at::DeviceIndex device_index,
54       bool requires_grad,
55       std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
56       std::vector<std::optional<c10::SymInt>> dynamic_dims_strides);
57 
58   bool check(const LocalState& state, const at::Tensor& v);
59   bool check(
60       const LocalState& state,
61       const c10::DispatchKeySet& dispatch_key_set,
62       const at::ScalarType& dtype,
63       const c10::Device& device,
64       const c10::SymIntArrayRef& dynamic_dims_sizes,
65       const c10::SymIntArrayRef& dynamic_dims_strides,
66       const bool& requires_grad);
67   std::string check_verbose(
68       const LocalState& state,
69       const at::Tensor& v,
70       const std::string& tensor_name);
71 
72   PyTypeObject* pytype;
73 
74  private:
75   uint64_t dispatch_key_; // DispatchKeySet includes device/layout
76   at::ScalarType dtype_;
77   // Note(voz): While dispatch_key_ is sufficiently representative of a device
78   // In that keys are more granular AND device specific - they do not
79   // necessarily capture device indices correctly.
80   at::DeviceIndex device_index_;
81   bool requires_grad_;
82   // NB: These are unset if dynamic shapes is enabled.
83   std::vector<std::optional<c10::SymInt>> sizes_;
84   std::vector<std::optional<c10::SymInt>> strides_;
85   // Not strictly required for dense tensors, but nested tensors need it.
86   int64_t dim_;
87 };
88 
89 } // namespace torch::dynamo
90