xref: /aosp_15_r20/external/pytorch/aten/src/ATen/FunctionalInverses.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker 
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/FunctionalInverses.h>
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/ATen.h>
5*da0073e9SAndroid Build Coastguard Worker #include <ATen/ExpandUtils.h>
6*da0073e9SAndroid Build Coastguard Worker #include <ATen/WrapDimUtilsMulti.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <utility>
9*da0073e9SAndroid Build Coastguard Worker namespace at::functionalization {
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker // This logic is similar to autograd code for view backwards calls.
12*da0073e9SAndroid Build Coastguard Worker // We can't easily share it though, because (eventually) these functions
13*da0073e9SAndroid Build Coastguard Worker // will all call `permute/unsqueeze_copy()` instead of `permute/unsqueeze`.
14*da0073e9SAndroid Build Coastguard Worker 
// Undoes a permute(dims) call: builds the inverse permutation and applies it
// to `self`, returning a view unless `inverse_return_mode` forbids it.
static Tensor permute_inverse(const Tensor& self, IntArrayRef dims, InverseReturnMode inverse_return_mode) {
  // If the forward permute mapped position i -> dims[i], the inverse maps
  // dims[i] -> i. Negative dims are wrapped before being used as indices.
  const auto ndims = static_cast<int64_t>(dims.size());
  std::vector<int64_t> inverted_dims(ndims);
  for (const auto i : c10::irange(ndims)) {
    inverted_dims[at::maybe_wrap_dim(dims[i], ndims)] = i;
  }
  return (inverse_return_mode == InverseReturnMode::NeverView)
      ? at::permute_copy(self, inverted_dims)
      : at::permute(self, inverted_dims);
}
28*da0073e9SAndroid Build Coastguard Worker 
// Undoes a full squeeze(): re-inserts every dimension of `sizes` that was 1,
// using unsqueeze views or copies depending on `inverse_return_mode`.
static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) {
  auto result = self;
  // If no dimension gets re-inserted below, AlwaysView still requires the
  // output to alias the input.
  bool need_alias = (inverse_return_mode == InverseReturnMode::AlwaysView);
  const auto num_dims = static_cast<int64_t>(sizes.size());
  for (const auto dim : c10::irange(num_dims)) {
    if (sizes[dim] != 1) {
      continue;
    }
    need_alias = false;
    result = (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::unsqueeze_copy(result, dim)
        : at::unsqueeze(result, dim);
  }

  // return an alias to ensure the output is a view when necessary
  return need_alias ? at::alias(result) : result;
}
47*da0073e9SAndroid Build Coastguard Worker 
// Undoes squeeze(dim)/squeeze(dims): re-inserts only the dimensions listed in
// `dim` that had size 1 in the original shape `sizes`.
static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) {
  const auto ndim = static_cast<int64_t>(sizes.size());
  const auto mask = at::dim_list_to_bitset(dim, ndim);
  Tensor result = self;
  // If nothing gets unsqueezed, AlwaysView still requires an aliasing output.
  bool need_alias = (inverse_return_mode == InverseReturnMode::AlwaysView);
  // In NumPy it's not an error to unsqueeze a scalar, but we still need to
  // avoid unsqueezing in the backward.
  if (ndim == 0) {
    // return an alias to ensure the output is a view when necessary
    return need_alias ? at::alias(result) : result;
  }

  for (const auto d : c10::irange(ndim)) {
    if (!mask.test(d) || sizes[d] != 1) {
      continue;
    }
    need_alias = false;
    result = (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::unsqueeze_copy(result, d)
        : at::unsqueeze(result, d);
  }

  // return an alias to ensure the output is a view when necessary
  return need_alias ? at::alias(result) : result;
}
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker // Note [Functionalization Pass: View Inverses].
76*da0073e9SAndroid Build Coastguard Worker // This file contains the implementation of each "view inverse".
// These aren't really true inverses in the mathematical sense: each view inverse describes how to undo
78*da0073e9SAndroid Build Coastguard Worker // the original view (although it takes in different arguments).
79*da0073e9SAndroid Build Coastguard Worker //
80*da0073e9SAndroid Build Coastguard Worker // E.g. Below is an example of a program that has alias operations removed, and the role that view inverses play:
81*da0073e9SAndroid Build Coastguard Worker //
82*da0073e9SAndroid Build Coastguard Worker // normal program with views and mutations:
83*da0073e9SAndroid Build Coastguard Worker // view1 = input1.view_op(args...)
84*da0073e9SAndroid Build Coastguard Worker // view1.add_(1) (perform a mutation on the view, which should also modify input)
85*da0073e9SAndroid Build Coastguard Worker 
86*da0073e9SAndroid Build Coastguard Worker // version of the program with no aliasing, that instead uses view_inverse functions:
87*da0073e9SAndroid Build Coastguard Worker // view_copy1 = input1.view_copy_op(args...)
88*da0073e9SAndroid Build Coastguard Worker // view_copy1.add_(1) (perform a mutation on view_copy1. At this point, input1 is NOT modified)
89*da0073e9SAndroid Build Coastguard Worker // x = view_op_inverse(input1, view_copy1, args...)
90*da0073e9SAndroid Build Coastguard Worker //
91*da0073e9SAndroid Build Coastguard Worker // at this point, input1 and x should be equal
92*da0073e9SAndroid Build Coastguard Worker //
93*da0073e9SAndroid Build Coastguard Worker // Note that input1 is also passed as an argument to view_op_inverse in the above example.
94*da0073e9SAndroid Build Coastguard Worker // This isn't actually required for most view operators: it's only required for view ops
95*da0073e9SAndroid Build Coastguard Worker // where you can't figure out what the size of the base tensor is given just the view tensor and arguments.
96*da0073e9SAndroid Build Coastguard Worker // Examples are slice/select/scatter/squeeze/as_strided.
97*da0073e9SAndroid Build Coastguard Worker // We happen to be passing in the base tensor in all cases, mostly to make the codegen simpler.
98*da0073e9SAndroid Build Coastguard Worker // But you'll see below that the "base" argument is ignored by most view_inverse implementations.
99*da0073e9SAndroid Build Coastguard Worker 
100*da0073e9SAndroid Build Coastguard Worker // ----------------------------------------------------------
101*da0073e9SAndroid Build Coastguard Worker // Implementations of each view_inverse() function are below.
102*da0073e9SAndroid Build Coastguard Worker // One of these needs to be implemented for every existing non-composite view operator.
103*da0073e9SAndroid Build Coastguard Worker // The codegen automatically generates the corresponding function declaration.
104*da0073e9SAndroid Build Coastguard Worker // ----------------------------------------------------------
105*da0073e9SAndroid Build Coastguard Worker 
_fw_primal_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,int64_t level)106*da0073e9SAndroid Build Coastguard Worker Tensor FunctionalInverses::_fw_primal_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t level) {
107*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported.");
108*da0073e9SAndroid Build Coastguard Worker     return Tensor();
109*da0073e9SAndroid Build Coastguard Worker }
110*da0073e9SAndroid Build Coastguard Worker 
_make_dual_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,const at::Tensor & tangent,int64_t level)111*da0073e9SAndroid Build Coastguard Worker Tensor FunctionalInverses::_make_dual_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor& tangent, int64_t level) {
112*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. For now, this is not supported.");
113*da0073e9SAndroid Build Coastguard Worker     return Tensor();
114*da0073e9SAndroid Build Coastguard Worker }
115*da0073e9SAndroid Build Coastguard Worker 
// The inverse of viewing a complex tensor as real is viewing the real tensor
// back as complex (as a view, unless NeverView forces a copy).
Tensor FunctionalInverses::view_as_real_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    return (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::view_as_complex_copy(mutated_view)
        : at::view_as_complex(mutated_view);
}
123*da0073e9SAndroid Build Coastguard Worker 
// The inverse of viewing a real tensor as complex is viewing the complex
// tensor back as real. Conjugation is materialized first via resolve_conj().
Tensor FunctionalInverses::view_as_complex_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    auto resolved = mutated_view.resolve_conj();
    return (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::view_as_real_copy(resolved)
        : at::view_as_real(resolved);
}
131*da0073e9SAndroid Build Coastguard Worker 
// _conj is its own inverse: re-apply the lazy conjugation (or its copy form).
Tensor FunctionalInverses::_conj_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    return (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::_conj_copy(mutated_view)
        : at::_conj(mutated_view);
}
139*da0073e9SAndroid Build Coastguard Worker 
// _neg_view is its own inverse: re-apply the lazy negation (or its copy form).
Tensor FunctionalInverses::_neg_view_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    return (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::_neg_view_copy(mutated_view)
        : at::_neg_view(mutated_view);
}
147*da0073e9SAndroid Build Coastguard Worker 
// Inverse of as_strided(): scatter the mutated view back into base, or — for
// AlwaysView — restride the view to base's full geometry.
Tensor FunctionalInverses::as_strided_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size, at::SymIntArrayRef stride, std::optional<c10::SymInt> storage_offset) {
    if (inverse_return_mode != InverseReturnMode::AlwaysView) {
      return base.as_strided_scatter_symint(mutated_view, size, stride, std::move(storage_offset));
    }
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
158*da0073e9SAndroid Build Coastguard Worker 
// Inverse of diagonal(): scatter the diagonal values back into base, or — for
// AlwaysView — restride the view to base's full geometry.
Tensor FunctionalInverses::diagonal_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t offset, int64_t dim1, int64_t dim2) {
    if (inverse_return_mode != InverseReturnMode::AlwaysView) {
      return base.diagonal_scatter(mutated_view, offset, dim1, dim2);
    }
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
169*da0073e9SAndroid Build Coastguard Worker 
// Inverse of expand(): reduce the mutation back down to base's shape.
// base + sum_to(mutated_view - base, ...) sums the broadcasted delta over the
// expanded dimensions and adds it onto base.
Tensor FunctionalInverses::expand_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size, bool implicit) {
    if (inverse_return_mode == InverseReturnMode::AlwaysView) {
      // NB: assumes mutated_view is an expanded view of base.
      // We should NOT do this for functionalization
      return mutated_view.as_strided_symint(
          base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
    }
    auto delta = mutated_view - base;
    return base + at::sum_to(
        delta,
        base.sym_sizes(),
        /*always_return_non_view=*/inverse_return_mode == InverseReturnMode::NeverView
    );
}
184*da0073e9SAndroid Build Coastguard Worker 
// Delegates to the file-local helper, which applies the inverse permutation.
Tensor FunctionalInverses::permute_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::IntArrayRef dims) {
    return at::functionalization::permute_inverse(mutated_view, dims, inverse_return_mode);
}
188*da0073e9SAndroid Build Coastguard Worker 
// Inverse of _reshape_alias(): reshape the mutated view back to base's
// geometry, deliberately taking the target strides from `base` itself.
Tensor FunctionalInverses::_reshape_alias_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size, at::SymIntArrayRef stride) {
    // Note that I'm directly calling reshape(), and ignoring the strides.
    // _reshape_alias() isn't available from user code, and is an implementation detail of reshape().
    // Specifically, passing in the strides directly can get us into trouble in cases like:
    // b = a[0]; c = b.reshape(...); c.add_(1); print(a)
    // When we eventually run the _reshape_alias_inverse() call here, if we were to pass in both sizes and strides,
    // The call would fail because `mutated_view` doesn't have enough bytes of storage.
    return (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::_reshape_alias_copy_symint(mutated_view, base.sym_sizes(), base.sym_strides())
        : at::_reshape_alias_symint(mutated_view, base.sym_sizes(), base.sym_strides());
}
202*da0073e9SAndroid Build Coastguard Worker 
// Inverse of select(dim, index): scatter the selected slice back into base,
// or — for AlwaysView — restride the view to base's full geometry.
Tensor FunctionalInverses::select_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim, c10::SymInt index) {
    if (inverse_return_mode != InverseReturnMode::AlwaysView) {
      return base.select_scatter_symint(mutated_view, dim, std::move(index));
    }
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
213*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::detach_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // The functionalization pass doesn't care about autograd metadata — as a
    // view, detach() acts as the identity, so its inverse is the identity too.
    return mutated_view;
}
218*da0073e9SAndroid Build Coastguard Worker 
// lift_fresh produces an identical tensor, so its inverse is the identity.
Tensor FunctionalInverses::lift_fresh_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    return mutated_view;
}
222*da0073e9SAndroid Build Coastguard Worker 
// Inverse of slice(): scatter the slice back into base, or — for AlwaysView —
// use slice_inverse, which assumes the view actually aliases base.
Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
    if (inverse_return_mode != InverseReturnMode::AlwaysView) {
      return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step));
    }
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.slice_inverse_symint(
        base, dim, std::move(start), std::move(end), std::move(step));
}
233*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) {
    // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can.
    // For functionalization, we only have one of the tensors from the TensorList output by split(),
    // and we want to layer it on top of the base tensor.
    // For autograd, we have all of the tensors outputted by split() and we just want to stack them.
    dim = at::maybe_wrap_dim(dim, base.dim());
    const auto dim_size = base.sym_size(dim);
    // Chunk `mutated_view_idx` covers [start, end) along `dim`; the last chunk
    // may be shorter, so clamp to the dimension size.
    auto start = split_size * mutated_view_idx;
    auto end = start + split_size;
    if (end > dim_size) {
      end = dim_size;
    }

    if (inverse_return_mode == InverseReturnMode::AlwaysView) {
      // NB: assumes mutated_view is a narrowed view of base.
      // We should NOT do this for functionalization
      return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
    }
    return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
}
253*da0073e9SAndroid Build Coastguard Worker 
// Inverse of split_with_sizes(): re-insert the single mutated chunk at its
// original offset along `dim` (slice_scatter), or — for AlwaysView — use
// slice_inverse, which assumes the chunk actually aliases base.
//
// base: the tensor that was originally split.
// mutated_view: the (possibly mutated) chunk at index `mutated_view_idx`.
// split_sizes/dim: the arguments of the original split_with_sizes() call.
Tensor FunctionalInverses::split_with_sizes_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) {
    dim = at::maybe_wrap_dim(dim, base.dim());
    auto dim_size = base.sym_size(dim);
    // The chunk's start offset is the sum of all earlier split sizes.
    // Use c10::irange so the loop index matches mutated_view_idx's int64_t
    // width (the original used a plain `int` index).
    c10::SymInt start = 0;
    for (const auto i : c10::irange(mutated_view_idx)) {
        start += split_sizes[i];
    }
    auto end = start + split_sizes[mutated_view_idx];
    // The recorded split size may overshoot the dimension; clamp like split() does.
    if (end > dim_size) end = dim_size;

    if (inverse_return_mode == InverseReturnMode::AlwaysView) {
      // NB: assumes mutated_view is a narrowed view of base.
      // We should NOT do this for functionalization
      return mutated_view.slice_inverse_symint(base, dim, start, end, 1);
    } else {
      return base.slice_scatter_symint(mutated_view, dim, start, end, 1);
    }
}
272*da0073e9SAndroid Build Coastguard Worker 
// Inverse of squeeze(): restore every size-1 dimension of base's shape.
Tensor FunctionalInverses::squeeze_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    return unsqueeze_copy_to(mutated_view, base.sym_sizes(), inverse_return_mode);
}
276*da0073e9SAndroid Build Coastguard Worker 
// Inverse of squeeze(dim): restore that dimension if base had size 1 there.
Tensor FunctionalInverses::squeeze_dim_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) {
    return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), inverse_return_mode);
}
280*da0073e9SAndroid Build Coastguard Worker 
// Inverse of squeeze(dims): restore each listed dimension that had size 1.
Tensor FunctionalInverses::squeeze_dims_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, IntArrayRef dim) {
    return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), inverse_return_mode);
}
284*da0073e9SAndroid Build Coastguard Worker 
// t() is its own inverse (swap of the first two dims of a <=2-D tensor).
Tensor FunctionalInverses::t_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    return (inverse_return_mode == InverseReturnMode::NeverView)
        ? at::t_copy(mutated_view)
        : at::t(mutated_view);
}
292*da0073e9SAndroid Build Coastguard Worker 
// Inverse of transpose(dim0, dim1): transposing the same pair again undoes it.
// Returns a view unless NeverView forces the copy variant.
Tensor FunctionalInverses::transpose_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim0, int64_t dim1) {
    if (inverse_return_mode != InverseReturnMode::NeverView) {
      // Qualify with at:: for consistency with the other inverses in this
      // file (the original relied on ADL to find the same functions).
      return at::transpose(mutated_view, dim0, dim1);
    } else {
      return at::transpose_copy(mutated_view, dim0, dim1);
    }
}
300*da0073e9SAndroid Build Coastguard Worker 
// Buffer-backed nested-tensor views are not supported by functionalization;
// hitting this is always an internal error.
Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& nested_sizes, const Tensor& nested_strides, const Tensor& storage_offsets) {
    TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization");
    return {};
}
305*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional<Tensor>& lengths, int64_t ragged_idx, const std::optional<Tensor>& min_seqlen, const std::optional<Tensor>& max_seqlen) {
  // The inverse of building a jagged nested tensor from a values buffer is
  // simply extracting that values buffer back out of the (mutated) view.
  auto extracted_values = at::_nested_get_values(mutated_view);
  if (inverse_return_mode == InverseReturnMode::NeverView) {
    // Caller requires a non-aliasing result: materialize a contiguous copy.
    return extracted_values.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
  return extracted_values;
}
314*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::_nested_get_values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
  // Rebuild a jagged nested tensor around the (possibly mutated) values
  // buffer, reusing all of the jagged metadata stored on the base.
  auto offsets = at::_nested_get_offsets(base);
  auto lengths = at::_nested_get_lengths(base);
  auto ragged_idx = at::_nested_get_ragged_idx(base);
  auto dummy = at::_nested_get_jagged_dummy(base);
  auto min_seqlen = at::_nested_get_min_seqlen(base);
  auto max_seqlen = at::_nested_get_max_seqlen(base);
  // min/max seqlen may come back undefined; pass them along as optionals.
  const auto as_optional = [](const Tensor& t) -> std::optional<Tensor> {
    if (t.defined()) {
      return t;
    }
    return std::nullopt;
  };
  auto nt = at::_nested_view_from_jagged(
      mutated_view, offsets, dummy, lengths, ragged_idx,
      as_optional(min_seqlen), as_optional(max_seqlen));

  if (inverse_return_mode == InverseReturnMode::NeverView) {
    // Caller requires a non-aliasing result: materialize a contiguous copy.
    return nt.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
  }
  return nt;
}
333*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) {
    // Undo unsqueeze(dim) by squeezing that same dimension back out.
    const bool want_copy = inverse_return_mode == InverseReturnMode::NeverView;
    return want_copy ? at::squeeze_copy(mutated_view, dim)
                     : at::squeeze(mutated_view, dim);
}
341*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::_indices_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
346*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::_values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
351*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::indices_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
356*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
361*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::_sparse_broadcast_to_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::IntArrayRef size) {
    // Unimplemented: sparse tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
366*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse (CSR) tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
371*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::col_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse (CSR) tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
376*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::ccol_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse (CSC) tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
381*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::row_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // Unimplemented: sparse (CSC) tensors are rejected by the functionalization pass.
    TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
    return Tensor();
}
386*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::unbind_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int64_t dim) {
    if (inverse_return_mode != InverseReturnMode::AlwaysView) {
      // Scatter the mutated slice back into base at the index this view came from.
      auto wrapped_dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(base.sizes().size()));
      return base.select_scatter(mutated_view, wrapped_dim, mutated_view_idx);
    }
    // NB: assumes mutated_view is a narrowed view of base.
    // We should NOT do this for functionalization
    return mutated_view.as_strided_symint(
        base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
}
398*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::view_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size) {
    // Undo a view by reshaping the mutated view back to base's (symbolic) sizes.
    auto target_sizes = base.sym_sizes();
    if (inverse_return_mode == InverseReturnMode::NeverView) {
      return at::view_copy_symint(mutated_view, target_sizes);
    }
    return mutated_view.view_symint(target_sizes);
}
406*da0073e9SAndroid Build Coastguard Worker 
407*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::view_dtype_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::ScalarType dtype) {
    // Undo a dtype reinterpretation by viewing back as base's original dtype.
    if (inverse_return_mode == InverseReturnMode::NeverView) {
      return at::view_copy(mutated_view, base.scalar_type());
    }
    return mutated_view.view(base.scalar_type());
}
415*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::unfold_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dimension, int64_t size, int64_t step) {
    if (inverse_return_mode == InverseReturnMode::AlwaysView) {
      // NB: assumes mutated_view is a narrowed view of base.
      // We should NOT do this for functionalization
      return mutated_view.as_strided_symint(
          base.sym_sizes(), base.sym_strides(), base.sym_storage_offset());
    } else {
      // I think autograd and the functionalization pass want the exact same thing here, but need to test to confirm.
      // unfold_backward() is safe to use here because it is NOT a view op.
      // (note: technically, we'll have an extra memory copy.
      // We'd need to add an aliasing version of unfold_backward to fix that though).
      // size > step means consecutive unfold windows overlap, i.e. the view has
      // internal overlap; a scatter-style inverse is not well defined then.
      TORCH_CHECK(
        !(inverse_return_mode == InverseReturnMode::ViewOrScatterInverse && size > step),
        "While executing unfold, functionalization encountered a tensor being mutated that has internal overlap. \
When using torch.compile (or running functionalization directly), this is banned \
as the behavior is not well defined. Consider cloning the tensor before mutating it, \
or removing the mutation from your model."
          );
      return unfold_backward(mutated_view, base.sizes(), dimension, size, step);
    }
}
437*da0073e9SAndroid Build Coastguard Worker 
Tensor FunctionalInverses::alias_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) {
    // alias is its own inverse; only the view-vs-copy flavor differs.
    if (inverse_return_mode == InverseReturnMode::NeverView) {
      return at::alias_copy(mutated_view);
    }
    return at::alias(mutated_view);
}
445*da0073e9SAndroid Build Coastguard Worker 
chunk_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,int64_t mutated_view_idx,int chunks,int dim)446*da0073e9SAndroid Build Coastguard Worker Tensor FunctionalInverses::chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim) {
447*da0073e9SAndroid Build Coastguard Worker     // TODO: Can the logic from TensorShape.cpp be reused here somehow?
448*da0073e9SAndroid Build Coastguard Worker     const auto dim_size = base.sym_size(dim);
449*da0073e9SAndroid Build Coastguard Worker     auto split_size = (dim_size + chunks - 1) / chunks;
450*da0073e9SAndroid Build Coastguard Worker     std::vector<c10::SymInt> split_sizes(chunks, split_size);
451*da0073e9SAndroid Build Coastguard Worker     split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size);
452*da0073e9SAndroid Build Coastguard Worker     return split_with_sizes_inverse(base, mutated_view, inverse_return_mode, mutated_view_idx, split_sizes, dim);
453*da0073e9SAndroid Build Coastguard Worker }
454*da0073e9SAndroid Build Coastguard Worker 
narrow_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,int dim,c10::SymInt start,c10::SymInt length)455*da0073e9SAndroid Build Coastguard Worker Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length) {
456*da0073e9SAndroid Build Coastguard Worker     if (inverse_return_mode == InverseReturnMode::AlwaysView) {
457*da0073e9SAndroid Build Coastguard Worker       // NB: assumes mutated_view is a narrowed view of base.
458*da0073e9SAndroid Build Coastguard Worker       // We should NOT do this for functionalization
459*da0073e9SAndroid Build Coastguard Worker       return mutated_view.slice_inverse_symint(base, dim, start, start + length, 1);
460*da0073e9SAndroid Build Coastguard Worker     } else {
461*da0073e9SAndroid Build Coastguard Worker       return base.slice_scatter_symint(
462*da0073e9SAndroid Build Coastguard Worker           mutated_view, dim, start, start + length, 1);
463*da0073e9SAndroid Build Coastguard Worker     }
464*da0073e9SAndroid Build Coastguard Worker }
465*da0073e9SAndroid Build Coastguard Worker 
slice_inverse_inverse(const at::Tensor & base,const at::Tensor & mutated_view,InverseReturnMode inverse_return_mode,const at::Tensor & src,int64_t dim,std::optional<c10::SymInt> start,std::optional<c10::SymInt> end,c10::SymInt step)466*da0073e9SAndroid Build Coastguard Worker Tensor FunctionalInverses::slice_inverse_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor & src, int64_t dim, std::optional<c10::SymInt> start, std::optional<c10::SymInt> end, c10::SymInt step) {
467*da0073e9SAndroid Build Coastguard Worker     // slice_inverse() inverse is just slice()
468*da0073e9SAndroid Build Coastguard Worker     if (inverse_return_mode == InverseReturnMode::NeverView) {
469*da0073e9SAndroid Build Coastguard Worker       return at::slice_copy_symint(
470*da0073e9SAndroid Build Coastguard Worker           mutated_view, dim, std::move(start), std::move(end), std::move(step));
471*da0073e9SAndroid Build Coastguard Worker     } else {
472*da0073e9SAndroid Build Coastguard Worker       return mutated_view.slice_symint(
473*da0073e9SAndroid Build Coastguard Worker           dim, std::move(start), std::move(end), std::move(step));
474*da0073e9SAndroid Build Coastguard Worker     }
475*da0073e9SAndroid Build Coastguard Worker }
476*da0073e9SAndroid Build Coastguard Worker 
477*da0073e9SAndroid Build Coastguard Worker } // namespace at::functionalization
478