#pragma once #include #include #include namespace torch::dynamo { PyObject* torch_c_dynamo_guards_init(); // interfaces for extra_state and eval_frame.c because RootGuardManager class is // not visible there. void* convert_to_root_guard_manager(py::object root); bool run_root_guard_manager(void* root, PyObject* f_locals); struct LocalState { // TLS state that changes operators c10::impl::LocalDispatchKeySet dispatch_modifier; c10::DispatchKeySet override_dispatch_key_set; bool grad_mode_enabled; at::DispatchKeySet apply(at::DispatchKeySet ks) const { if (override_dispatch_key_set.empty()) { return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_; } else { return override_dispatch_key_set; } } LocalState() : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()), override_dispatch_key_set(c10::BackendComponent::InvalidBit), grad_mode_enabled(at::GradMode::is_enabled()) {} void overrideDispatchKeySet(c10::DispatchKeySet ks) { override_dispatch_key_set = ks; } }; class TensorCheck { public: TensorCheck( const LocalState& state, PyTypeObject* pt, const at::Tensor& v, std::vector> dynamic_dims_sizes, std::vector> dynamic_dims_strides); TensorCheck( const LocalState& state, PyTypeObject* pt, c10::DispatchKeySet dispatch_key_set, at::ScalarType dtype, at::DeviceIndex device_index, bool requires_grad, std::vector> dynamic_dims_sizes, std::vector> dynamic_dims_strides); bool check(const LocalState& state, const at::Tensor& v); bool check( const LocalState& state, const c10::DispatchKeySet& dispatch_key_set, const at::ScalarType& dtype, const c10::Device& device, const c10::SymIntArrayRef& dynamic_dims_sizes, const c10::SymIntArrayRef& dynamic_dims_strides, const bool& requires_grad); std::string check_verbose( const LocalState& state, const at::Tensor& v, const std::string& tensor_name); PyTypeObject* pytype; private: uint64_t dispatch_key_; // DispatchKeySet includes device/layout at::ScalarType dtype_; // Note(voz): While dispatch_key_ is sufficiently representative of a device // In that keys are more granular AND device specific - they do not // necessarily capture device indices correctly. at::DeviceIndex device_index_; bool requires_grad_; // NB: These are unset if dynamic shapes is enabled. std::vector> sizes_; std::vector> strides_; // Not strictly required for dense tensors, but nested tensors need it. int64_t dim_; }; } // namespace torch::dynamo