1 #pragma once 2 3 #include <ATen/ExpandUtils.h> 4 #include <ATen/NestedTensorImpl.h> 5 #include <ATen/core/Tensor.h> 6 #include <c10/core/Device.h> 7 #include <c10/core/DeviceType.h> 8 #include <c10/core/Stream.h> 9 #include <c10/core/SymIntArrayRef.h> 10 #include <c10/core/TensorImpl.h> 11 #include <c10/core/impl/DeviceGuardImplInterface.h> 12 #include <c10/util/DimVector.h> 13 #include <c10/util/Exception.h> 14 #include <c10/util/SmallVector.h> 15 16 #ifndef AT_PER_OPERATOR_HEADERS 17 #include <ATen/Functions.h> 18 #else 19 #include <ATen/ops/zeros.h> 20 #endif 21 22 namespace torch::autograd { 23 24 using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>; 25 using MetadataShape = std::variant<SymIntSmallVec, at::Tensor>; 26 27 /** 28 * Records TensorOptions, shape of the tensor, whether or not the Python 29 * dispatch key is set (tensor subclass), and, where applicable, the stream the 30 * corresponding operation took place on. 31 * 32 * If is_valid() is false, then the corresponding input is not used and may be 33 * an undefined tensor. 34 */ 35 struct TORCH_API InputMetadata { 36 InputMetadata() = default; 37 InputMetadata( 38 const at::TensorOptions& options, 39 MetadataShape input_shape, 40 bool is_tensor_subclass, 41 bool is_nested); 42 InputMetadata(const at::Tensor& t); 43 optionsInputMetadata44 const at::TensorOptions& options() const { 45 return options_; 46 } 47 dtypeInputMetadata48 caffe2::TypeMeta dtype() const { 49 return options_.dtype(); 50 } 51 deviceInputMetadata52 at::Device device() const { 53 return options_.device(); 54 } 55 layoutInputMetadata56 at::Layout layout() const { 57 return options_.layout(); 58 } 59 streamInputMetadata60 c10::Stream stream() const { 61 return stream_; 62 } 63 is_tensor_subclassInputMetadata64 bool is_tensor_subclass() const { 65 return is_tensor_subclass_; 66 } 67 68 at::Tensor zeros_like() const; 69 70 bool is_same_shape(const at::Tensor& grad) const; 71 72 bool is_expandable_to_shape(const at::Tensor& grad) const; 73 74 at::Tensor reduce_grad(at::Tensor& grad) const; 75 76 at::Tensor maybe_reduce( 77 const size_t index, 78 at::Tensor grad, 79 const std::function<std::string(const std::string&)>& format_error) const; 80 81 std::stringstream incompatible_shape_error_message( 82 const size_t index, 83 const at::Tensor& grad) const; 84 was_default_constructedInputMetadata85 bool was_default_constructed() const { 86 return was_default_constructed_; 87 } 88 89 bool is_cpp_nested_tensor() const; 90 is_nested_tensorInputMetadata91 bool is_nested_tensor() const { 92 return is_nested_; 93 } 94 95 c10::SymIntArrayRef shape_as_dim_vector() const; 96 97 // Danger: not thread safe, caller must protect with lock 98 SymIntSmallVec& mutable_shape_as_dim_vector(); 99 100 private: 101 at::Tensor shape_as_tensor() const; 102 bool is_nestedness_same(const at::Tensor& grad) const; 103 bool maybe_expandable_to(const at::Tensor& grad) const; 104 105 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 106 const at::TensorOptions options_; 107 MetadataShape shape_; 108 c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device()); 109 bool is_tensor_subclass_ = false; 110 bool is_nested_ = false; 111 bool was_default_constructed_ = true; 112 }; 113 } // namespace torch::autograd 114