#include <torch/csrc/autograd/input_metadata.h>

// TODO: we may be able to move some imports from input_metadata.h to here, but
// it seems that function.h transitively depends on some of them.

namespace torch::autograd {

namespace {

MetadataShape compute_variant_shape(const at::Tensor& input) {
  if (input.is_nested() && !input.unsafeGetTensorImpl()->is_python_dispatch()) {
    auto nested_size = input._nested_tensor_size();
    return MetadataShape{std::in_place_type<at::Tensor>, nested_size};
  }
  return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}

bool is_python_dispatch(const at::Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}

bool is_cpp_nested_tensor(const at::Tensor& tensor) {
  return tensor.is_nested() && !is_python_dispatch(tensor);
}

} // namespace

InputMetadata::InputMetadata(
    const at::TensorOptions& options,
    MetadataShape input_shape,
    bool is_tensor_subclass,
    bool is_nested)
    : options_{options},
      shape_{std::move(input_shape)},
      is_tensor_subclass_{is_tensor_subclass},
      is_nested_{is_nested},
      was_default_constructed_{false} {
  auto device_ = options.device();
  stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
}
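// Editorial note (not part of the original source): the constructor above
// records the current stream on the input's device at the time the metadata
// is created and keeps it in stream_ for later use (e.g. by the autograd
// engine's stream synchronization).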

InputMetadata::InputMetadata(const at::Tensor& t)
    : InputMetadata(
          t.options(),
          compute_variant_shape(t),
          is_python_dispatch(t),
          t.is_nested()) {}

at::Tensor InputMetadata::zeros_like() const {
  TORCH_CHECK(
      !is_nested_, "Zeros is not currently supported for nested tensors.");
  return at::zeros_symint(shape_as_dim_vector(), options_);
}

at::Tensor InputMetadata::maybe_reduce(
    const size_t i,
    at::Tensor grad,
    const std::function<std::string(const std::string&)>& format_error) const {
  auto fail = [&]() {
    const auto message = incompatible_shape_error_message(i, grad);
    TORCH_CHECK(false, format_error(message.str()));
  };

  // Nested tensor makes my brain explode, so I've just hard-coded the logic
  // for this case, at risk of code duplication. This logic does NOT do the
  // careful oblivious logic as seen below
  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
      ::torch::autograd::is_cpp_nested_tensor(grad)) {
    if (!is_same_shape(grad)) {
      if (is_expandable_to_shape(grad)) {
        return reduce_grad(grad);
      } else {
        fail();
      }
    } else {
      return grad;
    }
  }

  auto shape = shape_as_dim_vector();
  auto desired = grad.sym_sizes();

  size_t ndim = shape.size();
  size_t target_dim = desired.size();
  if (ndim > target_dim) {
    fail();
  }
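  // Worked example (editorial illustration, not from the original source):
  // suppose the input was saved with shape [1, 3] and the incoming grad has
  // shape [4, 2, 3]. The loop below compares trailing dims right to left:
  // 3 == 3 passes; the saved size-1 dim is not definitely equal to 2, so
  // needs_reduce becomes true; the grad's extra leading dim (ndim !=
  // target_dim) would force a reduction anyway. reduce_grad() then sums the
  // grad back down to shape [1, 3].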
  bool needs_reduce = false;
  for (const auto i : c10::irange(ndim)) {
    const auto& size = shape[ndim - i - 1];
    const auto& target = desired[target_dim - i - 1];
    // The conditions here are written carefully so that we are able to
    // infer deferred runtime asserts
    if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
      // NB: we could short circuit this once needs_reduce is true but there's
      // no point since the reduction function will guard on this anyway
      if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
        needs_reduce = true;
      }
    } else {
      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
        fail();
      }
    }
  }
  if (ndim != target_dim) {
    needs_reduce = true;
  }

  if (needs_reduce) {
    return reduce_grad(grad);
  } else {
    return grad;
  }
}

bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
  if (!is_nestedness_same(grad)) {
    return false;
  }
  if (is_cpp_nested_tensor()) {
    return grad._nested_tensor_size().is_same_size(shape_as_tensor());
  }
  return grad.sym_sizes().equals(shape_as_dim_vector());
}

bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
  if (!maybe_expandable_to(grad)) {
    return false;
  }
  return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}

at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
  // reduce_grad should only be called if is_expandable_to_shape returns true.
  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
  return at::sum_to(std::move(grad), shape_as_dim_vector());
}
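// Editorial note (not part of the original source): grad is taken by
// non-const reference and moved into at::sum_to, so callers must not reuse
// it afterwards. at::sum_to sums away the extra leading dimensions and any
// dimension where the saved shape is 1, e.g. a [4, 2, 3] grad reduced
// against a saved shape of [1, 3] comes back as a [1, 3] tensor.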

std::stringstream InputMetadata::incompatible_shape_error_message(
    const size_t index,
    const at::Tensor& grad) const {
  std::stringstream ss{};
  ss << "invalid gradient at index " << index << " - got ";
  if (::torch::autograd::is_cpp_nested_tensor(grad)) {
    ss << grad._nested_tensor_size();
  } else {
    ss << grad.sym_sizes();
  }
  ss << " but expected shape compatible with ";
  if (is_cpp_nested_tensor()) {
    ss << shape_as_tensor();
  } else {
    ss << shape_as_dim_vector();
  }
  return ss;
}

bool InputMetadata::is_cpp_nested_tensor() const {
  bool ret = std::holds_alternative<at::Tensor>(shape_);
  TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_));
  return ret;
}

c10::SymIntArrayRef InputMetadata::shape_as_dim_vector() const {
  const auto& dim_shape = std::get<SymIntSmallVec>(shape_);
  return c10::SymIntArrayRef(dim_shape.data(), dim_shape.size());
}

// Danger: not thread safe, caller must protect with lock
SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
  return std::get<SymIntSmallVec>(shape_);
}

bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
  return (
      grad.is_nested() == is_nested_ &&
      ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}

at::Tensor InputMetadata::shape_as_tensor() const {
  return std::get<at::Tensor>(shape_);
}

bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
  // This is the initial step to determine whether or not the tensor
  // represented by input_metadata is expandable to grad based on
  // is-nestedness information alone. If this function returns true, then
  // is_expandable_to_shape will be called. We support the following 3 types
  // of expansion:
  bool grad_is_nested = grad.is_nested();
  if (!is_nested_ && !grad_is_nested) {
    // Normal case (no NestedTensors are involved)
    // (1) plain Tensor -> plain Tensor
    return true;
  } else {
    // (2) python NT -> python NT
    // (3) plain Tensor -> python NT
    return (
        grad_is_nested && is_python_dispatch(grad) &&
        (!is_nested_ || is_tensor_subclass_));
  }
}
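// Editorial note (not part of the original source): C++ NestedTensors are
// never considered expandable here (neither as the input nor as the grad),
// and neither is a python NT input paired with a plain-tensor grad; when the
// shapes also differ, those combinations hit fail() in maybe_reduce() and
// produce the incompatible-shape error above.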

} // namespace torch::autograd