1 #pragma once 2 #include <c10/core/GradMode.h> 3 #include <torch/csrc/python_headers.h> 4 #include <torch/csrc/utils/pybind.h> 5 6 namespace torch::dynamo { 7 8 PyObject* torch_c_dynamo_guards_init(); 9 10 // interfaces for extra_state and eval_frame.c because RootGuardManager class is 11 // not visible there. 12 void* convert_to_root_guard_manager(py::object root); 13 bool run_root_guard_manager(void* root, PyObject* f_locals); 14 15 struct LocalState { 16 // TLS state that changes operators 17 c10::impl::LocalDispatchKeySet dispatch_modifier; 18 c10::DispatchKeySet override_dispatch_key_set; 19 bool grad_mode_enabled; 20 applyLocalState21 at::DispatchKeySet apply(at::DispatchKeySet ks) const { 22 if (override_dispatch_key_set.empty()) { 23 return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_; 24 } else { 25 return override_dispatch_key_set; 26 } 27 } 28 LocalStateLocalState29 LocalState() 30 : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()), 31 override_dispatch_key_set(c10::BackendComponent::InvalidBit), 32 grad_mode_enabled(at::GradMode::is_enabled()) {} 33 overrideDispatchKeySetLocalState34 void overrideDispatchKeySet(c10::DispatchKeySet ks) { 35 override_dispatch_key_set = ks; 36 } 37 }; 38 39 class TensorCheck { 40 public: 41 TensorCheck( 42 const LocalState& state, 43 PyTypeObject* pt, 44 const at::Tensor& v, 45 std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes, 46 std::vector<std::optional<c10::SymInt>> dynamic_dims_strides); 47 48 TensorCheck( 49 const LocalState& state, 50 PyTypeObject* pt, 51 c10::DispatchKeySet dispatch_key_set, 52 at::ScalarType dtype, 53 at::DeviceIndex device_index, 54 bool requires_grad, 55 std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes, 56 std::vector<std::optional<c10::SymInt>> dynamic_dims_strides); 57 58 bool check(const LocalState& state, const at::Tensor& v); 59 bool check( 60 const LocalState& state, 61 const c10::DispatchKeySet& dispatch_key_set, 62 const at::ScalarType& dtype, 63 const c10::Device& device, 64 const c10::SymIntArrayRef& dynamic_dims_sizes, 65 const c10::SymIntArrayRef& dynamic_dims_strides, 66 const bool& requires_grad); 67 std::string check_verbose( 68 const LocalState& state, 69 const at::Tensor& v, 70 const std::string& tensor_name); 71 72 PyTypeObject* pytype; 73 74 private: 75 uint64_t dispatch_key_; // DispatchKeySet includes device/layout 76 at::ScalarType dtype_; 77 // Note(voz): While dispatch_key_ is sufficiently representative of a device 78 // In that keys are more granular AND device specific - they do not 79 // necessarily capture device indices correctly. 80 at::DeviceIndex device_index_; 81 bool requires_grad_; 82 // NB: These are unset if dynamic shapes is enabled. 83 std::vector<std::optional<c10::SymInt>> sizes_; 84 std::vector<std::optional<c10::SymInt>> strides_; 85 // Not strictly required for dense tensors, but nested tensors need it. 86 int64_t dim_; 87 }; 88 89 } // namespace torch::dynamo 90