1
2 #include <ATen/FunctionalInverses.h>
3
4 #include <ATen/ATen.h>
5 #include <ATen/ExpandUtils.h>
6 #include <ATen/WrapDimUtilsMulti.h>
7
8 #include <utility>
9 namespace at::functionalization {
10
11 // This logic is similar to autograd code for view backwards calls.
12 // We can't easily share it though, because (eventually) these functions
13 // will all call `permute/unsqueeze_copy()` instead of `permute/unsqueeze`.
14
// Undoes a permute(dims) on `self` by applying the inverse permutation.
// Each entry of `dims` is wrapped with at::maybe_wrap_dim before use.
// at::permute_copy is used when inverse_return_mode is NeverView;
// at::permute is used for every other mode.
static Tensor permute_inverse(const Tensor& self, IntArrayRef dims, InverseReturnMode inverse_return_mode) {
  const auto rank = static_cast<int64_t>(dims.size());
  // inverse[p] == i exactly when (wrapped) dims[i] == p
  std::vector<int64_t> inverse(rank);
  for (const auto i : c10::irange(rank)) {
    const auto target = at::maybe_wrap_dim(dims[i], rank);
    inverse[target] = i;
  }
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return at::permute_copy(self, inverse);
  }
  return at::permute(self, inverse);
}
28
// Re-inserts a dimension into `self` at every position where `sizes` is 1.
// unsqueeze_copy is used when inverse_return_mode is NeverView, unsqueeze otherwise.
// If nothing was unsqueezed and the mode is AlwaysView, an alias of `self` is
// returned so that the output is still a view.
static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  bool alias_pending = (inverse_return_mode == InverseReturnMode::AlwaysView);
  Tensor out = self;
  const auto rank = static_cast<int64_t>(sizes.size());
  for (const auto d : c10::irange(rank)) {
    if (sizes[d] == 1) {
      // an unsqueeze happened, so no extra alias is needed at the end
      alias_pending = false;
      out = copy_only ? at::unsqueeze_copy(out, d) : at::unsqueeze(out, d);
    }
  }

  // return an alias to ensure the output is a view when necessary
  if (alias_pending) {
    return at::alias(out);
  }
  return out;
}
47
// Variant of unsqueeze_copy_to restricted to the dimensions listed in `dim`:
// a dimension d is re-inserted only if it is selected by `dim` and sizes[d] is 1.
// unsqueeze_copy is used when inverse_return_mode is NeverView, unsqueeze otherwise.
// If nothing was unsqueezed and the mode is AlwaysView, an alias of `self` is
// returned so that the output is still a view.
static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) {
  const auto rank = static_cast<int64_t>(sizes.size());
  const auto selected = at::dim_list_to_bitset(dim, rank);
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  bool alias_pending = (inverse_return_mode == InverseReturnMode::AlwaysView);
  Tensor out = self;

  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
  // unsqueezing in the backward.
  if (rank == 0) {
    // return an alias to ensure the output is a view when necessary
    if (alias_pending) {
      return at::alias(out);
    }
    return out;
  }

  for (const auto d : c10::irange(rank)) {
    if (selected.test(d) && sizes[d] == 1) {
      alias_pending = false;
      out = copy_only ? at::unsqueeze_copy(out, d) : at::unsqueeze(out, d);
    }
  }

  // return an alias to ensure the output is a view when necessary
  if (alias_pending) {
    return at::alias(out);
  }
  return out;
}
74
75 // Note [Functionalization Pass: View Inverses].
76 // This file contains the implementation of each "view inverse".
// These aren't really true inverses in the mathematical sense: each view inverse describes how to undo
78 // the original view (although it takes in different arguments).
79 //
80 // E.g. Below is an example of a program that has alias operations removed, and the role that view inverses play:
81 //
82 // normal program with views and mutations:
83 // view1 = input1.view_op(args...)
//   view1.add_(1)  (perform a mutation on the view, which should also modify input1)
85
86 // version of the program with no aliasing, that instead uses view_inverse functions:
87 // view_copy1 = input1.view_copy_op(args...)
88 // view_copy1.add_(1) (perform a mutation on view_copy1. At this point, input1 is NOT modified)
89 // x = view_op_inverse(input1, view_copy1, args...)
90 //
91 // at this point, input1 and x should be equal
92 //
93 // Note that input1 is also passed as an argument to view_op_inverse in the above example.
94 // This isn't actually required for most view operators: it's only required for view ops
95 // where you can't figure out what the size of the base tensor is given just the view tensor and arguments.
96 // Examples are slice/select/scatter/squeeze/as_strided.
97 // We happen to be passing in the base tensor in all cases, mostly to make the codegen simpler.
98 // But you'll see below that the "base" argument is ignored by most view_inverse implementations.
99
100 // ----------------------------------------------------------
101 // Implementations of each view_inverse() function are below.
102 // One of these needs to be implemented for every existing non-composite view operator.
103 // The codegen automatically generates the corresponding function declaration.
104 // ----------------------------------------------------------
105
_fw_primal_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,int64_t level)106 Tensor FunctionalInverses::_fw_primal_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t level) {
107 TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported.");
108 return Tensor();
109 }
110
_make_dual_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,const at::Tensor & tangent,int64_t level)111 Tensor FunctionalInverses::_make_dual_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor& tangent, int64_t level) {
112 TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. For now, this is not supported.");
113 return Tensor();
114 }
115
// Inverse of view_as_real(): reinterprets `mutated_view` as complex.
// `base` is unused. The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::view_as_real_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  return copy_only ? at::view_as_complex_copy(mutated_view)
                   : at::view_as_complex(mutated_view);
}
123
// Inverse of view_as_complex(): calls resolve_conj() on `mutated_view` and
// reinterprets the result as real. `base` is unused.
// The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::view_as_complex_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const auto resolved = mutated_view.resolve_conj();
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return at::view_as_real_copy(resolved);
  }
  return at::view_as_real(resolved);
}
131
// Inverse of _conj(): applies _conj to `mutated_view` again. `base` is unused.
// The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::_conj_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  return copy_only ? at::_conj_copy(mutated_view) : at::_conj(mutated_view);
}
139
// Inverse of _neg_view(): applies _neg_view to `mutated_view` again. `base` is unused.
// The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::_neg_view_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  return copy_only ? at::_neg_view_copy(mutated_view) : at::_neg_view(mutated_view);
}
147
// Inverse of as_strided().
//  - AlwaysView: re-strides `mutated_view` to base's sizes/strides/storage offset.
//  - otherwise:  scatters `mutated_view` into `base` with as_strided_scatter.
Tensor FunctionalInverses::as_strided_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::SymIntArrayRef size,
    at::SymIntArrayRef stride,
    std::optional<c10::SymInt> storage_offset) {
  if (inverse_return_mode != InverseReturnMode::AlwaysView) {
    return base.as_strided_scatter_symint(mutated_view, size, stride, std::move(storage_offset));
  }
  // NB: assumes mutated_view is a narrowed view of base.
  // We should NOT do this for functionalization
  return mutated_view.as_strided_symint(
      base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
158
// Inverse of diagonal().
//  - AlwaysView: re-strides `mutated_view` to base's sizes/strides/storage offset.
//  - otherwise:  scatters `mutated_view` into `base` with diagonal_scatter.
Tensor FunctionalInverses::diagonal_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t offset,
    int64_t dim1,
    int64_t dim2) {
  if (inverse_return_mode != InverseReturnMode::AlwaysView) {
    return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
  }
  // NB: assumes mutated_view is a narrowed view of base.
  // We should NOT do this for functionalization
  return mutated_view.as_strided_symint(
      base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
169
// Inverse of expand().
//  - AlwaysView: re-strides `mutated_view` to base's sizes/strides/storage offset.
//  - otherwise:  reduces (mutated_view - base) to base's shape with at::sum_to
//                and adds the result to `base`. sum_to is asked for a non-view
//                result only when the mode is NeverView.
// The `size` and `implicit` arguments are unused.
Tensor FunctionalInverses::expand_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::SymIntArrayRef size,
    bool implicit) {
  if (inverse_return_mode == InverseReturnMode::AlwaysView) {
    // NB: assumes mutated_view is an expanded view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
  }
  const bool force_non_view = (inverse_return_mode == InverseReturnMode::NeverView);
  const auto reduced_delta = at::sum_to(
      mutated_view - base,
      base.sym_sizes(),
      /*always_return_non_view=*/force_non_view);
  return base + reduced_delta;
}
184
// Inverse of permute(): forwards to the file-local permute_inverse helper,
// which applies the inverse permutation of `dims` to `mutated_view`. `base` is unused.
Tensor FunctionalInverses::permute_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::IntArrayRef dims) {
  // Fully qualified so the free helper is selected rather than this member.
  return at::functionalization::permute_inverse(
      mutated_view, dims, inverse_return_mode);
}
188
// Inverse of _reshape_alias(): reshapes `mutated_view` back to base's symbolic
// sizes and strides. The `size` and `stride` arguments of the original view
// call are not used here.
// _reshape_alias() isn't available from user code, and is an implementation detail of reshape().
// NOTE(review): the historical rationale for not forwarding the caller's
// size/stride is a program like
//   b = a[0]; c = b.reshape(...); c.add_(1); print(a)
// where passing them through would fail because `mutated_view` doesn't have
// enough bytes of storage -- confirm this still matches the intent, since the
// code does pass base's strides.
Tensor FunctionalInverses::_reshape_alias_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::SymIntArrayRef size,
    at::SymIntArrayRef stride) {
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return at::_reshape_alias_copy_symint(mutated_view, base.sym_sizes(), base.sym_strides());
  }
  return at::_reshape_alias_symint(mutated_view, base.sym_sizes(), base.sym_strides());
}
202
// Inverse of select.int().
//  - AlwaysView: re-strides `mutated_view` to base's sizes/strides/storage offset.
//  - otherwise:  scatters `mutated_view` into `base` with select_scatter.
Tensor FunctionalInverses::select_int_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t dim,
    c10::SymInt index) {
  if (inverse_return_mode != InverseReturnMode::AlwaysView) {
    return base.select_scatter_symint(mutated_view, dim, std::move(index));
  }
  // NB: assumes mutated_view is a narrowed view of base.
  // We should NOT do this for functionalization
  return mutated_view.as_strided_symint(
      base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
213
// Inverse of detach(): returns `mutated_view` unchanged for every return mode.
// The functionalization pass doesn't care about autograd metadata, so as a view
// detach() is treated as an identity function.
Tensor FunctionalInverses::detach_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  return mutated_view;
}
218
// Inverse of lift_fresh(): returns `mutated_view` unchanged for every return mode.
// `base` and `inverse_return_mode` are unused.
Tensor FunctionalInverses::lift_fresh_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  return mutated_view;
}
222
// Inverse of slice.Tensor().
//  - AlwaysView: calls slice_inverse on `mutated_view` with `base` as the source.
//  - otherwise:  scatters `mutated_view` into `base` with slice_scatter.
// start/end/step are forwarded unchanged to whichever op is used.
Tensor FunctionalInverses::slice_Tensor_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t dim,
    std::optional<c10::SymInt> start,
    std::optional<c10::SymInt> end,
    c10::SymInt step) {
  if (inverse_return_mode != InverseReturnMode::AlwaysView) {
    return base.slice_scatter_symint(
        mutated_view, dim, std::move(start), std::move(end), std::move(step));
  }
  // NB: assumes mutated_view is a narrowed view of base.
  // We should NOT do this for functionalization
  return mutated_view.slice_inverse_symint(
      base, dim, std::move(start), std::move(end), std::move(step));
}
233
// Inverse of split.Tensor() for the output at index `mutated_view_idx`.
// The output covers [split_size * idx, split_size * idx + split_size) along
// `dim`, with the upper bound clamped to base's size in that dimension.
//
// It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
// For functionalization, we only have one of the tensors from the TensorList output by split(),
// and we want to layer it on top of the base tensor.
// For autograd, we have all of the tensors output by split() and we just want to stack them.
Tensor FunctionalInverses::split_Tensor_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t mutated_view_idx,
    c10::SymInt split_size,
    int64_t dim) {
  dim = at::maybe_wrap_dim(dim, base.dim());
  auto extent = base.sym_size(dim);
  auto lo = split_size * mutated_view_idx;
  auto hi = split_size + lo;
  if (hi > extent) {
    hi = extent;
  }

  if (inverse_return_mode != InverseReturnMode::AlwaysView) {
    return base.slice_scatter_symint(mutated_view, dim, lo, hi, 1);
  }
  // NB: assumes mutated_view is a narrowed view of base.
  // We should NOT do this for functionalization
  return mutated_view.slice_inverse_symint(base, dim, lo, hi, 1);
}
253
// Inverse of split_with_sizes() for the output at index `mutated_view_idx`.
// The output covers [start, start + split_sizes[idx]) along `dim`, where
// `start` is the sum of the preceding split sizes; the upper bound is clamped
// to base's size in that dimension.
//  - AlwaysView: calls slice_inverse on `mutated_view` with `base` as the source.
//  - otherwise:  scatters `mutated_view` into `base` with slice_scatter.
Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
  // split_sizes is indexed with mutated_view_idx below; reject an
  // out-of-range index instead of reading past the end of the array.
  TORCH_INTERNAL_ASSERT(
      mutated_view_idx >= 0 && mutated_view_idx < static_cast<int64_t>(split_sizes.size()),
      "split_with_sizes_inverse: mutated_view_idx is out of range for split_sizes");
  dim = at::maybe_wrap_dim(dim, base.dim());
  auto dim_size = base.sym_size(dim);
  c10::SymInt start = 0;
  // irange deduces an int64_t index, matching the type of mutated_view_idx.
  for (const auto i : c10::irange(mutated_view_idx)) {
    start += split_sizes[i];
  }
  auto end = start + split_sizes[mutated_view_idx];
  if (end > dim_size) end = dim_size;

  if (inverse_return_mode == InverseReturnMode::AlwaysView) {
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
  } else {
    return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
  }
}
272
// Inverse of squeeze(): re-inserts every dimension of size 1 found in base's
// symbolic sizes, via unsqueeze_copy_to.
Tensor FunctionalInverses::squeeze_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const auto base_sizes = base.sym_sizes();
  return unsqueeze_copy_to(mutated_view, base_sizes, inverse_return_mode);
}
276
// Inverse of squeeze.dim(): re-inserts `dim` if base's size there is 1,
// via the dims-restricted unsqueeze_copy_to overload.
Tensor FunctionalInverses::squeeze_dim_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t dim) {
  const auto base_sizes = base.sym_sizes();
  return unsqueeze_copy_to(mutated_view, dim, base_sizes, inverse_return_mode);
}
280
// Inverse of squeeze.dims(): re-inserts each dimension listed in `dim` whose
// size in base is 1, via the dims-restricted unsqueeze_copy_to overload.
Tensor FunctionalInverses::squeeze_dims_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    IntArrayRef dim) {
  const auto base_sizes = base.sym_sizes();
  return unsqueeze_copy_to(mutated_view, dim, base_sizes, inverse_return_mode);
}
284
// Inverse of t(): applies t to `mutated_view` again. `base` is unused.
// The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::t_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  return copy_only ? at::t_copy(mutated_view) : at::t(mutated_view);
}
292
// Inverse of transpose.int(): transposes `mutated_view` over the same pair of
// dimensions again. `base` is unused. The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::transpose_int_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t dim0,
    int64_t dim1) {
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return transpose_copy(mutated_view, dim0, dim1);
  }
  return transpose(mutated_view, dim0, dim1);
}
300
// _nested_view_from_buffer() is not supported by the functionalization pass:
// this inverse unconditionally trips an internal assert.
Tensor FunctionalInverses::_nested_view_from_buffer_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    const Tensor& nested_sizes,
    const Tensor& nested_strides,
    const Tensor& storage_offsets) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization");
  // Kept so that every path through the function has a return statement.
  return Tensor{};
}
305
// Inverse of _nested_view_from_jagged(): extracts the values tensor from
// `mutated_view` with _nested_get_values. For NeverView the result is cloned
// with contiguous memory format; otherwise it is returned as-is.
// All other arguments are unused.
Tensor FunctionalInverses::_nested_view_from_jagged_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    const Tensor& offsets,
    const Tensor& dummy,
    const std::optional<Tensor>& lengths,
    int64_t ragged_idx,
    const std::optional<Tensor>& min_seqlen,
    const std::optional<Tensor>& max_seqlen) {
  auto values = at::_nested_get_values(mutated_view);
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return values.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
  return values;
}
314
// Inverse of _nested_get_values(): rebuilds a jagged nested view around
// `mutated_view` using the offsets, lengths, ragged index, dummy and
// min/max sequence length read from `base`. A min/max seqlen tensor is passed
// on only if it is defined. For NeverView the result is cloned with contiguous
// memory format; otherwise it is returned as-is.
Tensor FunctionalInverses::_nested_get_values_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  auto offsets = at::_nested_get_offsets(base);
  auto lengths = at::_nested_get_lengths(base);
  auto ragged_idx = at::_nested_get_ragged_idx(base);
  auto dummy = at::_nested_get_jagged_dummy(base);
  auto min_seqlen = at::_nested_get_min_seqlen(base);
  auto max_seqlen = at::_nested_get_max_seqlen(base);

  // Undefined tensors are forwarded as empty optionals.
  std::optional<Tensor> min_seqlen_arg = std::nullopt;
  if (min_seqlen.defined()) {
    min_seqlen_arg = min_seqlen;
  }
  std::optional<Tensor> max_seqlen_arg = std::nullopt;
  if (max_seqlen.defined()) {
    max_seqlen_arg = max_seqlen;
  }

  auto rebuilt = at::_nested_view_from_jagged(
      mutated_view, offsets, dummy, lengths, ragged_idx,
      min_seqlen_arg, max_seqlen_arg);

  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return rebuilt.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
  return rebuilt;
}
333
// Inverse of unsqueeze(): squeezes `dim` out of `mutated_view`. `base` is unused.
// The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::unsqueeze_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    int64_t dim) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  return copy_only ? at::squeeze_copy(mutated_view, dim)
                   : at::squeeze(mutated_view, dim);
}
341
// _indices() is not supported by the functionalization pass:
// this inverse unconditionally trips an internal assert.
Tensor FunctionalInverses::_indices_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  // Kept so that every path through the function has a return statement.
  return Tensor{};
}
346
// _values() is not supported by the functionalization pass:
// this inverse unconditionally trips an internal assert.
Tensor FunctionalInverses::_values_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  // Kept so that every path through the function has a return statement.
  return Tensor{};
}
351
// indices() is not supported by the functionalization pass:
// this inverse unconditionally trips an internal assert.
Tensor FunctionalInverses::indices_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  // Kept so that every path through the function has a return statement.
  return Tensor{};
}
356
// values() is not supported by the functionalization pass:
// this inverse unconditionally trips an internal assert.
Tensor FunctionalInverses::values_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  // Kept so that every path through the function has a return statement.
  return Tensor{};
}
361
// _sparse_broadcast_to() is not supported by the functionalization pass:
// this inverse unconditionally trips an internal assert.
Tensor FunctionalInverses::_sparse_broadcast_to_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::IntArrayRef size) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  // Kept so that every path through the function has a return statement.
  return Tensor{};
}
366
crow_indices_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode)367 Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
368 TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
369 return Tensor();
370 }
371
col_indices_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode)372 Tensor FunctionalInverses::col_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
373 TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
374 return Tensor();
375 }
376
ccol_indices_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode)377 Tensor FunctionalInverses::ccol_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
378 TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
379 return Tensor();
380 }
381
row_indices_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode)382 Tensor FunctionalInverses::row_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
383 TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
384 return Tensor();
385 }
386
// Inverse of unbind.int() for the output at index `mutated_view_idx`.
//  - AlwaysView: re-strides `mutated_view` to base's sizes/strides/storage offset.
//  - otherwise:  wraps `dim` against base's rank and scatters `mutated_view`
//                into `base` at that index with select_scatter.
Tensor FunctionalInverses::unbind_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int64_t dim) {
  if (inverse_return_mode == InverseReturnMode::AlwaysView) {
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
  } else {
    // Use base.dim() for the rank, as the split inverses do, rather than
    // materializing base.sizes() just to count its entries.
    dim = at::maybe_wrap_dim(dim, base.dim());
    return base.select_scatter(mutated_view, dim, mutated_view_idx);
  }
}
398
// Inverse of view(): views `mutated_view` back to base's symbolic sizes.
// The `size` argument is unused. The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::view_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::SymIntArrayRef size) {
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return at::view_copy_symint(mutated_view, base.sym_sizes());
  }
  return mutated_view.view_symint(base.sym_sizes());
}
406
407
// Inverse of view.dtype(): reinterprets `mutated_view` with base's scalar type.
// The `dtype` argument is unused. The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::view_dtype_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode,
    at::ScalarType dtype) {
  const auto base_dtype = base.scalar_type();
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    return at::view_copy(mutated_view, base_dtype);
  }
  return mutated_view.view(base_dtype);
}
415
// Inverse of unfold().
//  - AlwaysView: re-strides `mutated_view` to base's sizes/strides/storage offset.
//  - otherwise:  maps `mutated_view` back to base's shape with unfold_backward.
//                In ViewOrScatterInverse mode this is rejected with an error
//                when size > step.
Tensor FunctionalInverses::unfold_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dimension, int64_t size, int64_t step) {
  if (inverse_return_mode == InverseReturnMode::AlwaysView) {
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
  } else {
    // I think autograd and the functionalization pass want the exact same thing here, but need to test to confirm.
    // unfold_backward() is safe to use here because it is NOT a view op.
    // (note: technically, we'll have an extra memory copy.
    // We'd need to add an aliasing version of unfold_backward to fix that though).
    //
    // The message is assembled from adjacent string literals rather than
    // backslash line continuations, so its text does not depend on how the
    // source lines are indented.
    TORCH_CHECK(
        !(inverse_return_mode == InverseReturnMode::ViewOrScatterInverse && size > step),
        "While executing unfold, functionalization encountered a tensor being mutated that has internal overlap. "
        "When using torch.compile (or running functionalization directly), this is banned "
        "as the behavior is not well defined. Consider cloning the tensor before mutating it, "
        "or removing the mutation from your model."
    );
    // Pass symbolic sizes, as every other inverse in this file does.
    return at::unfold_backward_symint(mutated_view, base.sym_sizes(), dimension, size, step);
  }
}
437
// Inverse of alias(): aliases `mutated_view` again. `base` is unused.
// The *_copy variant is used only for NeverView.
Tensor FunctionalInverses::alias_inverse(
    const Tensor& base,
    const Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  const bool copy_only = (inverse_return_mode == InverseReturnMode::NeverView);
  return copy_only ? at::alias_copy(mutated_view) : at::alias(mutated_view);
}
445
chunk_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,int64_t mutated_view_idx,int chunks,int dim)446 Tensor FunctionalInverses::chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim) {
447 // TODO: Can the logic from TensorShape.cpp be reused here somehow?
448 const auto dim_size = base.sym_size(dim);
449 auto split_size = (dim_size + chunks - 1) / chunks;
450 std::vector<c10::SymInt> split_sizes(chunks, split_size);
451 split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
452 return split_with_sizes_inverse(base, mutated_view, inverse_return_mode, mutated_view_idx, split_sizes, dim);
453 }
454
narrow_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,int dim,c10::SymInt start,c10::SymInt length)455 Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length) {
456 if (inverse_return_mode == InverseReturnMode::AlwaysView) {
457 // NB: assumes mutated_view is a narrowed view of base.
458 // We should NOT do this for functionalization
459 return mutated_view.slice_inverse_symint(base, dim, start, start + length, 1);
460 } else {
461 return base.slice_scatter_symint(
462 mutated_view, dim, start, start + length, 1);
463 }
464 }
465
slice_inverse_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,const at::Tensor & src,int64_t dim,std::optional<c10::SymInt> start,std::optional<c10::SymInt> end,c10::SymInt step)466 Tensor FunctionalInverses::slice_inverse_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & src, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
467 // slice_inverse() inverse is just slice()
468 if (inverse_return_mode == InverseReturnMode::NeverView) {
469 return at::slice_copy_symint(
470 mutated_view, dim, std::move(start), std::move(end), std::move(step));
471 } else {
472 return mutated_view.slice_symint(
473 dim, std::move(start), std::move(end), std::move(step));
474 }
475 }
476
477 } // namespace at::functionalization
478