xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/input_metadata.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/NestedTensorImpl.h>
5 #include <ATen/core/Tensor.h>
6 #include <c10/core/Device.h>
7 #include <c10/core/DeviceType.h>
8 #include <c10/core/Stream.h>
9 #include <c10/core/SymIntArrayRef.h>
10 #include <c10/core/TensorImpl.h>
11 #include <c10/core/impl/DeviceGuardImplInterface.h>
12 #include <c10/util/DimVector.h>
13 #include <c10/util/Exception.h>
14 #include <c10/util/SmallVector.h>
15 
16 #ifndef AT_PER_OPERATOR_HEADERS
17 #include <ATen/Functions.h>
18 #else
19 #include <ATen/ops/zeros.h>
20 #endif
21 
22 namespace torch::autograd {
23 
24 using SymIntSmallVec = c10::SmallVector<c10::SymInt, c10::kDimVectorStaticSize>;
25 using MetadataShape = std::variant<SymIntSmallVec, at::Tensor>;
26 
27 /**
28  * Records TensorOptions, shape of the tensor, whether or not the Python
29  * dispatch key is set (tensor subclass), and, where applicable, the stream the
30  * corresponding operation took place on.
31  *
32  * If is_valid() is false, then the corresponding input is not used and may be
33  * an undefined tensor.
34  */
35 struct TORCH_API InputMetadata {
36   InputMetadata() = default;
37   InputMetadata(
38       const at::TensorOptions& options,
39       MetadataShape input_shape,
40       bool is_tensor_subclass,
41       bool is_nested);
42   InputMetadata(const at::Tensor& t);
43 
optionsInputMetadata44   const at::TensorOptions& options() const {
45     return options_;
46   }
47 
dtypeInputMetadata48   caffe2::TypeMeta dtype() const {
49     return options_.dtype();
50   }
51 
deviceInputMetadata52   at::Device device() const {
53     return options_.device();
54   }
55 
layoutInputMetadata56   at::Layout layout() const {
57     return options_.layout();
58   }
59 
streamInputMetadata60   c10::Stream stream() const {
61     return stream_;
62   }
63 
is_tensor_subclassInputMetadata64   bool is_tensor_subclass() const {
65     return is_tensor_subclass_;
66   }
67 
68   at::Tensor zeros_like() const;
69 
70   bool is_same_shape(const at::Tensor& grad) const;
71 
72   bool is_expandable_to_shape(const at::Tensor& grad) const;
73 
74   at::Tensor reduce_grad(at::Tensor& grad) const;
75 
76   at::Tensor maybe_reduce(
77       const size_t index,
78       at::Tensor grad,
79       const std::function<std::string(const std::string&)>& format_error) const;
80 
81   std::stringstream incompatible_shape_error_message(
82       const size_t index,
83       const at::Tensor& grad) const;
84 
was_default_constructedInputMetadata85   bool was_default_constructed() const {
86     return was_default_constructed_;
87   }
88 
89   bool is_cpp_nested_tensor() const;
90 
is_nested_tensorInputMetadata91   bool is_nested_tensor() const {
92     return is_nested_;
93   }
94 
95   c10::SymIntArrayRef shape_as_dim_vector() const;
96 
97   // Danger: not thread safe, caller must protect with lock
98   SymIntSmallVec& mutable_shape_as_dim_vector();
99 
100  private:
101   at::Tensor shape_as_tensor() const;
102   bool is_nestedness_same(const at::Tensor& grad) const;
103   bool maybe_expandable_to(const at::Tensor& grad) const;
104 
105   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
106   const at::TensorOptions options_;
107   MetadataShape shape_;
108   c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
109   bool is_tensor_subclass_ = false;
110   bool is_nested_ = false;
111   bool was_default_constructed_ = true;
112 };
113 } // namespace torch::autograd
114