xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/input_metadata.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/autograd/input_metadata.h>
2 
// TODO: we may be able to move some #includes from input_metadata.h into this
// file, but function.h appears to depend on some of them transitively.
5 
6 namespace torch::autograd {
7 
8 namespace {
9 
compute_variant_shape(const at::Tensor & input)10 MetadataShape compute_variant_shape(const at::Tensor& input) {
11   if (input.is_nested() && !input.unsafeGetTensorImpl()->is_python_dispatch()) {
12     auto nested_size = input._nested_tensor_size();
13     return MetadataShape{std::in_place_type<at::Tensor>, nested_size};
14   }
15   return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
16 }
17 
is_python_dispatch(const at::Tensor & tensor)18 bool is_python_dispatch(const at::Tensor& tensor) {
19   return tensor.unsafeGetTensorImpl()->is_python_dispatch();
20 }
21 
is_cpp_nested_tensor(const at::Tensor & tensor)22 bool is_cpp_nested_tensor(const at::Tensor& tensor) {
23   return tensor.is_nested() && !is_python_dispatch(tensor);
24 }
25 
26 } // namespace
27 
InputMetadata(const at::TensorOptions & options,MetadataShape input_shape,bool is_tensor_subclass,bool is_nested)28 InputMetadata::InputMetadata(
29     const at::TensorOptions& options,
30     MetadataShape input_shape,
31     bool is_tensor_subclass,
32     bool is_nested)
33     : options_{options},
34       shape_{std::move(input_shape)},
35       is_tensor_subclass_{is_tensor_subclass},
36       is_nested_{is_nested},
37       was_default_constructed_{false} {
38   auto device_ = options.device();
39   stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
40 }
41 
InputMetadata(const at::Tensor & t)42 InputMetadata::InputMetadata(const at::Tensor& t)
43     : InputMetadata(
44           t.options(),
45           compute_variant_shape(t),
46           is_python_dispatch(t),
47           t.is_nested()) {}
48 
zeros_like() const49 at::Tensor InputMetadata::zeros_like() const {
50   TORCH_CHECK(
51       !is_nested_, "Zeros is not currently supported for nested tensors.")
52   return at::zeros_symint(shape_as_dim_vector(), options_);
53 }
54 
maybe_reduce(const size_t i,at::Tensor grad,const std::function<std::string (const std::string &)> & format_error) const55 at::Tensor InputMetadata::maybe_reduce(
56     const size_t i,
57     at::Tensor grad,
58     const std::function<std::string(const std::string&)>& format_error) const {
59   auto fail = [&]() {
60     const auto message = incompatible_shape_error_message(i, grad);
61     TORCH_CHECK(false, format_error(message.str()));
62   };
63 
64   // Nested tensor makes my brain explode, so I've just hard-coded the logic
65   // for this case, at risk of code duplication.  This logic does NOT do the
66   // careful oblivious logic as seen below
67   if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
68       ::torch::autograd::is_cpp_nested_tensor(grad)) {
69     if (!is_same_shape(grad)) {
70       if (is_expandable_to_shape(grad)) {
71         return reduce_grad(grad);
72       } else {
73         fail();
74       }
75     } else {
76       return grad;
77     }
78   }
79 
80   auto shape = shape_as_dim_vector();
81   auto desired = grad.sym_sizes();
82 
83   size_t ndim = shape.size();
84   size_t target_dim = desired.size();
85   if (ndim > target_dim) {
86     fail();
87   }
88   bool needs_reduce = false;
89   for (const auto i : c10::irange(ndim)) {
90     const auto& size = shape[ndim - i - 1];
91     const auto& target = desired[target_dim - i - 1];
92     // The conditions here are written carefully so that we are able to
93     // infer deferred runtime asserts
94     if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
95       // NB: we could short circuit this once needs_reduce is true but there's
96       // no point since the reduction function will guard on this anyway
97       if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
98         needs_reduce = true;
99       }
100     } else {
101       if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
102         fail();
103       }
104     }
105   }
106   if (ndim != target_dim) {
107     needs_reduce = true;
108   }
109 
110   if (needs_reduce) {
111     return reduce_grad(grad);
112   } else {
113     return grad;
114   }
115 }
116 
is_same_shape(const at::Tensor & grad) const117 bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
118   if (!is_nestedness_same(grad)) {
119     return false;
120   }
121   if (is_cpp_nested_tensor()) {
122     return grad._nested_tensor_size().is_same_size(shape_as_tensor());
123   }
124   return grad.sym_sizes().equals(shape_as_dim_vector());
125 }
126 
is_expandable_to_shape(const at::Tensor & grad) const127 bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
128   if (!maybe_expandable_to(grad)) {
129     return false;
130   }
131   return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
132 }
133 
reduce_grad(at::Tensor & grad) const134 at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
135   // reduce_grad should only be called if is_expandable_to_shape returns true.
136   TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
137   return at::sum_to(std::move(grad), shape_as_dim_vector());
138 }
139 
incompatible_shape_error_message(const size_t index,const at::Tensor & grad) const140 std::stringstream InputMetadata::incompatible_shape_error_message(
141     const size_t index,
142     const at::Tensor& grad) const {
143   std::stringstream ss{};
144   ss << "invalid gradient at index " << index << " - got ";
145   if (::torch::autograd::is_cpp_nested_tensor(grad)) {
146     ss << grad._nested_tensor_size();
147   } else {
148     ss << grad.sym_sizes();
149   }
150   ss << " but expected shape compatible with ";
151   if (is_cpp_nested_tensor()) {
152     ss << shape_as_tensor();
153   } else {
154     ss << shape_as_dim_vector();
155   }
156   return ss;
157 }
158 
is_cpp_nested_tensor() const159 bool InputMetadata::is_cpp_nested_tensor() const {
160   bool ret = std::holds_alternative<at::Tensor>(shape_);
161   TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_))
162   return ret;
163 }
164 
shape_as_dim_vector() const165 c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
166   const auto& dim_shape = std::get<SymIntSmallVec>(shape_);
167   return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
168 }
169 
170 // Danger: not thread safe, caller must protect with lock
mutable_shape_as_dim_vector()171 SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
172   return std::get<SymIntSmallVec>(shape_);
173 }
174 
is_nestedness_same(const at::Tensor & grad) const175 bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
176   return (
177       grad.is_nested() == is_nested_ &&
178       ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
179 }
180 
shape_as_tensor() const181 at::Tensor InputMetadata::shape_as_tensor() const {
182   return std::get<at::Tensor>(shape_);
183 }
184 
maybe_expandable_to(const at::Tensor & grad) const185 bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
186   // This is the initial step to determine whether or not the tensor represented
187   // by input_metadata is expandable to grad based on is-nestedness information
188   // alone. If this function returns true, then is_expandable_to_shape will be
189   // called. We support the following 3 types of expansion:
190   bool grad_is_nested = grad.is_nested();
191   if (!is_nested_ && !grad_is_nested) {
192     // Normal case (no NestedTensors are involved)
193     // (1) plain Tensor -> plain Tensor
194     return true;
195   } else {
196     // (2) python NT -> python NT
197     // (3) plain Tensor -> python NT
198     return (
199         grad_is_nested && is_python_dispatch(grad) &&
200         (!is_nested_ || is_tensor_subclass_));
201   }
202 }
203 
204 } // namespace torch::autograd
205