1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #ifndef AT_PER_OPERATOR_HEADERS
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/Functions.h>
5*da0073e9SAndroid Build Coastguard Worker #else
6*da0073e9SAndroid Build Coastguard Worker #include <ATen/ops/view.h>
7*da0073e9SAndroid Build Coastguard Worker #include <ATen/ops/view_copy.h>
8*da0073e9SAndroid Build Coastguard Worker #endif
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker #include <ATen/Tensor.h>
11*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/DimVector.h>
12*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
13*da0073e9SAndroid Build Coastguard Worker #include <c10/util/MaybeOwned.h>
14*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker #include <functional>
17*da0073e9SAndroid Build Coastguard Worker #include <tuple>
18*da0073e9SAndroid Build Coastguard Worker #include <utility>
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker namespace at {
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
23*da0073e9SAndroid Build Coastguard Worker TORCH_API std::vector<SymInt> infer_size_symint(
24*da0073e9SAndroid Build Coastguard Worker SymIntArrayRef a,
25*da0073e9SAndroid Build Coastguard Worker SymIntArrayRef b);
26*da0073e9SAndroid Build Coastguard Worker TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
27*da0073e9SAndroid Build Coastguard Worker TORCH_API SymDimVector
28*da0073e9SAndroid Build Coastguard Worker infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker // Named type instead of a pair/tuple so that we can be sure to
31*da0073e9SAndroid Build Coastguard Worker // construct the vectors in place and get NRVO.
32*da0073e9SAndroid Build Coastguard Worker template <typename Container>
33*da0073e9SAndroid Build Coastguard Worker struct InferExpandGeometryResult {
34*da0073e9SAndroid Build Coastguard Worker Container sizes;
35*da0073e9SAndroid Build Coastguard Worker Container strides;
InferExpandGeometryResultInferExpandGeometryResult36*da0073e9SAndroid Build Coastguard Worker explicit InferExpandGeometryResult(size_t ndim)
37*da0073e9SAndroid Build Coastguard Worker : sizes(ndim), strides(ndim) {}
InferExpandGeometryResultInferExpandGeometryResult38*da0073e9SAndroid Build Coastguard Worker explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
39*da0073e9SAndroid Build Coastguard Worker : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
40*da0073e9SAndroid Build Coastguard Worker };
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
43*da0073e9SAndroid Build Coastguard Worker inferExpandGeometry(
44*da0073e9SAndroid Build Coastguard Worker IntArrayRef tensor_sizes,
45*da0073e9SAndroid Build Coastguard Worker IntArrayRef tensor_strides,
46*da0073e9SAndroid Build Coastguard Worker IntArrayRef sizes);
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
49*da0073e9SAndroid Build Coastguard Worker IntArrayRef tensor_sizes,
50*da0073e9SAndroid Build Coastguard Worker IntArrayRef tensor_strides,
51*da0073e9SAndroid Build Coastguard Worker IntArrayRef sizes);
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker TORCH_API std::vector<int64_t> infer_dense_strides(
54*da0073e9SAndroid Build Coastguard Worker IntArrayRef tensor_sizes,
55*da0073e9SAndroid Build Coastguard Worker IntArrayRef tensor_strides);
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker // True if input shapes are expandable
58*da0073e9SAndroid Build Coastguard Worker // NOTE: infer_size did a similar check, please keep them sync if change is
59*da0073e9SAndroid Build Coastguard Worker // needed
are_expandable(IntArrayRef shape1,IntArrayRef shape2)60*da0073e9SAndroid Build Coastguard Worker inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
61*da0073e9SAndroid Build Coastguard Worker size_t ndim1 = shape1.size();
62*da0073e9SAndroid Build Coastguard Worker size_t ndim2 = shape2.size();
63*da0073e9SAndroid Build Coastguard Worker size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
66*da0073e9SAndroid Build Coastguard Worker if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
67*da0073e9SAndroid Build Coastguard Worker shape2[ndim2] == 1) {
68*da0073e9SAndroid Build Coastguard Worker continue;
69*da0073e9SAndroid Build Coastguard Worker }
70*da0073e9SAndroid Build Coastguard Worker return false;
71*da0073e9SAndroid Build Coastguard Worker }
72*da0073e9SAndroid Build Coastguard Worker return true;
73*da0073e9SAndroid Build Coastguard Worker }
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker // avoid copy-construction of Tensor by using a reference_wrapper.
check_defined(std::initializer_list<std::reference_wrapper<const Tensor>> tensors,const char * api_name)76*da0073e9SAndroid Build Coastguard Worker inline void check_defined(
77*da0073e9SAndroid Build Coastguard Worker std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
78*da0073e9SAndroid Build Coastguard Worker const char* api_name) {
79*da0073e9SAndroid Build Coastguard Worker for (auto& t : tensors) {
80*da0073e9SAndroid Build Coastguard Worker if (!t.get().defined()) {
81*da0073e9SAndroid Build Coastguard Worker AT_ERROR(api_name, "(...) called with an undefined Tensor");
82*da0073e9SAndroid Build Coastguard Worker }
83*da0073e9SAndroid Build Coastguard Worker }
84*da0073e9SAndroid Build Coastguard Worker }
85*da0073e9SAndroid Build Coastguard Worker
// NOTE [ ExpandUtils Borrowing ]
//
// Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
// expansion may not actually be needed, in which case we can improve
// efficiency by returning `c10::MaybeOwned<Tensor>::borrowed(to_expand)`
// instead of creating a new Tensor. However, this means that you need to
// be careful: the returned `c10::MaybeOwned<Tensor>` must not outlive the
// original `Tensor` object that `to_expand` referred to! The deleted rvalue
// reference overloads of these functions help with this by rejecting calls
// that pass a temporary (e.g. the result of a function call) directly, but
// it is still possible to make a mistake.
98*da0073e9SAndroid Build Coastguard Worker
expand_inplace(const Tensor & tensor,const Tensor & to_expand)99*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_inplace(
100*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
101*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand) {
102*da0073e9SAndroid Build Coastguard Worker if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
103*da0073e9SAndroid Build Coastguard Worker return c10::MaybeOwned<Tensor>::borrowed(to_expand);
104*da0073e9SAndroid Build Coastguard Worker }
105*da0073e9SAndroid Build Coastguard Worker return c10::MaybeOwned<Tensor>::owned(
106*da0073e9SAndroid Build Coastguard Worker to_expand.expand_symint(tensor.sym_sizes()));
107*da0073e9SAndroid Build Coastguard Worker }
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_inplace(
110*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
111*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand) = delete;
112*da0073e9SAndroid Build Coastguard Worker
expand_inplace(const Tensor & tensor,const Tensor & to_expand,const char * api_name)113*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_inplace(
114*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
115*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand,
116*da0073e9SAndroid Build Coastguard Worker const char* api_name) {
117*da0073e9SAndroid Build Coastguard Worker check_defined({tensor, to_expand}, api_name);
118*da0073e9SAndroid Build Coastguard Worker return expand_inplace(tensor, to_expand);
119*da0073e9SAndroid Build Coastguard Worker }
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_inplace(
122*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
123*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand,
124*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor & tensor,const Tensor & to_expand1,const Tensor & to_expand2)127*da0073e9SAndroid Build Coastguard Worker expand_inplace(
128*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
129*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
130*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2) {
131*da0073e9SAndroid Build Coastguard Worker if (tensor.sizes().equals(to_expand1.sizes()) &&
132*da0073e9SAndroid Build Coastguard Worker tensor.sizes().equals((to_expand2.sizes()))) {
133*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
134*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand1),
135*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand2));
136*da0073e9SAndroid Build Coastguard Worker }
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
139*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
140*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
141*da0073e9SAndroid Build Coastguard Worker }
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
144*da0073e9SAndroid Build Coastguard Worker expand_inplace(
145*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
146*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
147*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2) = delete;
148*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
149*da0073e9SAndroid Build Coastguard Worker expand_inplace(
150*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
151*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
152*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2) = delete;
153*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
154*da0073e9SAndroid Build Coastguard Worker expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
155*da0073e9SAndroid Build Coastguard Worker delete;
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor & tensor,const Tensor & to_expand1,const Tensor & to_expand2,const char * api_name)158*da0073e9SAndroid Build Coastguard Worker expand_inplace(
159*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
160*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
161*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
162*da0073e9SAndroid Build Coastguard Worker const char* api_name) {
163*da0073e9SAndroid Build Coastguard Worker check_defined({tensor, to_expand1, to_expand2}, api_name);
164*da0073e9SAndroid Build Coastguard Worker return expand_inplace(tensor, to_expand1, to_expand2);
165*da0073e9SAndroid Build Coastguard Worker }
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
168*da0073e9SAndroid Build Coastguard Worker expand_inplace(
169*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
170*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
171*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
172*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
173*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
174*da0073e9SAndroid Build Coastguard Worker expand_inplace(
175*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
176*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
177*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
178*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
179*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
180*da0073e9SAndroid Build Coastguard Worker expand_inplace(
181*da0073e9SAndroid Build Coastguard Worker const Tensor& tensor,
182*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
183*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
184*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
185*da0073e9SAndroid Build Coastguard Worker
186*da0073e9SAndroid Build Coastguard Worker // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
187*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2)188*da0073e9SAndroid Build Coastguard Worker expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
189*da0073e9SAndroid Build Coastguard Worker auto s1 = to_expand1.sym_sizes();
190*da0073e9SAndroid Build Coastguard Worker auto s2 = to_expand2.sym_sizes();
191*da0073e9SAndroid Build Coastguard Worker if (s1.equals(s2)) {
192*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
193*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand1),
194*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand2));
195*da0073e9SAndroid Build Coastguard Worker }
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker auto expanded_size = infer_size_symdimvector(s1, s2);
198*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
199*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
200*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
201*da0073e9SAndroid Build Coastguard Worker }
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
204*da0073e9SAndroid Build Coastguard Worker expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
205*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
206*da0073e9SAndroid Build Coastguard Worker expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
207*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
208*da0073e9SAndroid Build Coastguard Worker expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const char * api_name)211*da0073e9SAndroid Build Coastguard Worker expand_outplace(
212*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
213*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
214*da0073e9SAndroid Build Coastguard Worker const char* api_name) {
215*da0073e9SAndroid Build Coastguard Worker check_defined({to_expand1, to_expand2}, api_name);
216*da0073e9SAndroid Build Coastguard Worker return expand_outplace(to_expand1, to_expand2);
217*da0073e9SAndroid Build Coastguard Worker }
218*da0073e9SAndroid Build Coastguard Worker
219*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
220*da0073e9SAndroid Build Coastguard Worker expand_outplace(
221*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
222*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
223*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
224*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
225*da0073e9SAndroid Build Coastguard Worker expand_outplace(
226*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
227*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
228*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
229*da0073e9SAndroid Build Coastguard Worker inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
230*da0073e9SAndroid Build Coastguard Worker expand_outplace(
231*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
232*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
233*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
236*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
237*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
238*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const Tensor & to_expand3)239*da0073e9SAndroid Build Coastguard Worker expand_outplace(
240*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
241*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
242*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3) {
243*da0073e9SAndroid Build Coastguard Worker if (to_expand1.sizes().equals(to_expand2.sizes()) &&
244*da0073e9SAndroid Build Coastguard Worker to_expand1.sizes().equals(to_expand3.sizes())) {
245*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
246*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand1),
247*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand2),
248*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::borrowed(to_expand3));
249*da0073e9SAndroid Build Coastguard Worker }
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker auto expanded_size12 =
252*da0073e9SAndroid Build Coastguard Worker infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
253*da0073e9SAndroid Build Coastguard Worker auto expanded_size =
254*da0073e9SAndroid Build Coastguard Worker infer_size_dimvector(expanded_size12, to_expand3.sizes());
255*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(
256*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
257*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
258*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
259*da0073e9SAndroid Build Coastguard Worker }
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
262*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
263*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
264*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
265*da0073e9SAndroid Build Coastguard Worker expand_outplace(
266*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
267*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
268*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3) = delete;
269*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
270*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
271*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
272*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
273*da0073e9SAndroid Build Coastguard Worker expand_outplace(
274*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
275*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
276*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3) = delete;
277*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
278*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
279*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
280*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
281*da0073e9SAndroid Build Coastguard Worker expand_outplace(
282*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
283*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
284*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3) = delete;
285*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
286*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
287*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
288*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
289*da0073e9SAndroid Build Coastguard Worker expand_outplace(
290*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
291*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
292*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3) = delete;
293*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
294*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
295*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
296*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
297*da0073e9SAndroid Build Coastguard Worker expand_outplace(
298*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
299*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
300*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3) = delete;
301*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
302*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
303*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
304*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
305*da0073e9SAndroid Build Coastguard Worker expand_outplace(
306*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
307*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
308*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3) = delete;
309*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
310*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
311*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
312*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
313*da0073e9SAndroid Build Coastguard Worker expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
314*da0073e9SAndroid Build Coastguard Worker delete;
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
317*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
318*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
319*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const Tensor & to_expand3,const char * api_name)320*da0073e9SAndroid Build Coastguard Worker expand_outplace(
321*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
322*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
323*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3,
324*da0073e9SAndroid Build Coastguard Worker const char* api_name) {
325*da0073e9SAndroid Build Coastguard Worker check_defined({to_expand1, to_expand2, to_expand3}, api_name);
326*da0073e9SAndroid Build Coastguard Worker return expand_outplace(to_expand1, to_expand2, to_expand3);
327*da0073e9SAndroid Build Coastguard Worker }
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
330*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
331*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
332*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
333*da0073e9SAndroid Build Coastguard Worker expand_outplace(
334*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
335*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
336*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3,
337*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
338*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
339*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
340*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
341*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
342*da0073e9SAndroid Build Coastguard Worker expand_outplace(
343*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
344*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
345*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3,
346*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
347*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
348*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
349*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
350*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
351*da0073e9SAndroid Build Coastguard Worker expand_outplace(
352*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
353*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
354*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand3,
355*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
356*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
357*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
358*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
359*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
360*da0073e9SAndroid Build Coastguard Worker expand_outplace(
361*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
362*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
363*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3,
364*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
365*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
366*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
367*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
368*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
369*da0073e9SAndroid Build Coastguard Worker expand_outplace(
370*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
371*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand2,
372*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3,
373*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
374*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
375*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
376*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
377*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
378*da0073e9SAndroid Build Coastguard Worker expand_outplace(
379*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand1,
380*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
381*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3,
382*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
383*da0073e9SAndroid Build Coastguard Worker inline std::tuple<
384*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
385*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>,
386*da0073e9SAndroid Build Coastguard Worker c10::MaybeOwned<Tensor>>
387*da0073e9SAndroid Build Coastguard Worker expand_outplace(
388*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand1,
389*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand2,
390*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand3,
391*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
392*da0073e9SAndroid Build Coastguard Worker
expand_size(const Tensor & to_expand,IntArrayRef sizes)393*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_size(
394*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand,
395*da0073e9SAndroid Build Coastguard Worker IntArrayRef sizes) {
396*da0073e9SAndroid Build Coastguard Worker if (to_expand.sizes().equals(sizes)) {
397*da0073e9SAndroid Build Coastguard Worker return c10::MaybeOwned<Tensor>::borrowed(to_expand);
398*da0073e9SAndroid Build Coastguard Worker }
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
401*da0073e9SAndroid Build Coastguard Worker }
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_size(
404*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand,
405*da0073e9SAndroid Build Coastguard Worker IntArrayRef sizes) = delete;
406*da0073e9SAndroid Build Coastguard Worker
expand_size(const Tensor & to_expand,IntArrayRef sizes,const char * api_name)407*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_size(
408*da0073e9SAndroid Build Coastguard Worker const Tensor& to_expand,
409*da0073e9SAndroid Build Coastguard Worker IntArrayRef sizes,
410*da0073e9SAndroid Build Coastguard Worker const char* api_name) {
411*da0073e9SAndroid Build Coastguard Worker check_defined({to_expand}, api_name);
412*da0073e9SAndroid Build Coastguard Worker return expand_size(to_expand, sizes);
413*da0073e9SAndroid Build Coastguard Worker }
414*da0073e9SAndroid Build Coastguard Worker
415*da0073e9SAndroid Build Coastguard Worker inline c10::MaybeOwned<Tensor> expand_size(
416*da0073e9SAndroid Build Coastguard Worker Tensor&& to_expand,
417*da0073e9SAndroid Build Coastguard Worker IntArrayRef sizes,
418*da0073e9SAndroid Build Coastguard Worker const char* api_name) = delete;
419*da0073e9SAndroid Build Coastguard Worker
expand_outplace(TensorList to_expand)420*da0073e9SAndroid Build Coastguard Worker inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
421*da0073e9SAndroid Build Coastguard Worker // expands a list of Tensors; ignores undefined (null) tensors
422*da0073e9SAndroid Build Coastguard Worker bool first = true;
423*da0073e9SAndroid Build Coastguard Worker DimVector sizes;
424*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(to_expand.size())) {
425*da0073e9SAndroid Build Coastguard Worker if (!to_expand[i].defined()) {
426*da0073e9SAndroid Build Coastguard Worker continue;
427*da0073e9SAndroid Build Coastguard Worker } else if (first) {
428*da0073e9SAndroid Build Coastguard Worker sizes = to_expand[i].sizes();
429*da0073e9SAndroid Build Coastguard Worker first = false;
430*da0073e9SAndroid Build Coastguard Worker } else {
431*da0073e9SAndroid Build Coastguard Worker sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
432*da0073e9SAndroid Build Coastguard Worker }
433*da0073e9SAndroid Build Coastguard Worker }
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker std::vector<Tensor> result(to_expand.size());
436*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(to_expand.size())) {
437*da0073e9SAndroid Build Coastguard Worker if (!to_expand[i].defined()) {
438*da0073e9SAndroid Build Coastguard Worker continue;
439*da0073e9SAndroid Build Coastguard Worker } else if (to_expand[i].sizes().equals(sizes)) {
440*da0073e9SAndroid Build Coastguard Worker result[i] = to_expand[i];
441*da0073e9SAndroid Build Coastguard Worker } else {
442*da0073e9SAndroid Build Coastguard Worker result[i] = to_expand[i].expand(sizes);
443*da0073e9SAndroid Build Coastguard Worker }
444*da0073e9SAndroid Build Coastguard Worker }
445*da0073e9SAndroid Build Coastguard Worker return result;
446*da0073e9SAndroid Build Coastguard Worker }
447*da0073e9SAndroid Build Coastguard Worker
448*da0073e9SAndroid Build Coastguard Worker template <typename T>
449*da0073e9SAndroid Build Coastguard Worker inline Tensor _sum_to(
450*da0073e9SAndroid Build Coastguard Worker Tensor tensor,
451*da0073e9SAndroid Build Coastguard Worker const c10::ArrayRef<T> shape,
452*da0073e9SAndroid Build Coastguard Worker bool always_return_non_view = false) {
453*da0073e9SAndroid Build Coastguard Worker if (shape.size() == 0) {
454*da0073e9SAndroid Build Coastguard Worker return tensor.sum();
455*da0073e9SAndroid Build Coastguard Worker }
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker auto sizes = at::symint::sizes<T>(tensor);
458*da0073e9SAndroid Build Coastguard Worker c10::SmallVector<int64_t, 8> reduce_dims;
459*da0073e9SAndroid Build Coastguard Worker const int64_t leading_dims = sizes.size() - shape.size();
460*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(leading_dims)) {
461*da0073e9SAndroid Build Coastguard Worker reduce_dims.push_back(i);
462*da0073e9SAndroid Build Coastguard Worker }
463*da0073e9SAndroid Build Coastguard Worker for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
464*da0073e9SAndroid Build Coastguard Worker if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
465*da0073e9SAndroid Build Coastguard Worker TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
466*da0073e9SAndroid Build Coastguard Worker reduce_dims.push_back(i);
467*da0073e9SAndroid Build Coastguard Worker }
468*da0073e9SAndroid Build Coastguard Worker }
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker if (!reduce_dims.empty()) {
471*da0073e9SAndroid Build Coastguard Worker tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
472*da0073e9SAndroid Build Coastguard Worker }
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker if (always_return_non_view) {
475*da0073e9SAndroid Build Coastguard Worker // This is only actually used by the functionalization pass.
476*da0073e9SAndroid Build Coastguard Worker // We want to be able to guarantee that this function doesn't return a view
477*da0073e9SAndroid Build Coastguard Worker // of the input.
478*da0073e9SAndroid Build Coastguard Worker return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
479*da0073e9SAndroid Build Coastguard Worker : tensor.clone();
480*da0073e9SAndroid Build Coastguard Worker } else {
481*da0073e9SAndroid Build Coastguard Worker return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
482*da0073e9SAndroid Build Coastguard Worker }
483*da0073e9SAndroid Build Coastguard Worker }
484*da0073e9SAndroid Build Coastguard Worker
485*da0073e9SAndroid Build Coastguard Worker inline Tensor sum_to(
486*da0073e9SAndroid Build Coastguard Worker Tensor tensor,
487*da0073e9SAndroid Build Coastguard Worker const c10::SymIntArrayRef shape,
488*da0073e9SAndroid Build Coastguard Worker bool always_return_non_view = false) {
489*da0073e9SAndroid Build Coastguard Worker return _sum_to(std::move(tensor), shape, always_return_non_view);
490*da0073e9SAndroid Build Coastguard Worker }
491*da0073e9SAndroid Build Coastguard Worker
492*da0073e9SAndroid Build Coastguard Worker // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
493*da0073e9SAndroid Build Coastguard Worker // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
494*da0073e9SAndroid Build Coastguard Worker inline Tensor sum_to(
495*da0073e9SAndroid Build Coastguard Worker Tensor tensor,
496*da0073e9SAndroid Build Coastguard Worker const IntArrayRef shape,
497*da0073e9SAndroid Build Coastguard Worker bool always_return_non_view = false) {
498*da0073e9SAndroid Build Coastguard Worker return _sum_to(std::move(tensor), shape, always_return_non_view);
499*da0073e9SAndroid Build Coastguard Worker }
500*da0073e9SAndroid Build Coastguard Worker
is_expandable_to(SymIntArrayRef shape,c10::SymIntArrayRef desired)501*da0073e9SAndroid Build Coastguard Worker inline bool is_expandable_to(
502*da0073e9SAndroid Build Coastguard Worker SymIntArrayRef shape,
503*da0073e9SAndroid Build Coastguard Worker c10::SymIntArrayRef desired) {
504*da0073e9SAndroid Build Coastguard Worker size_t ndim = shape.size();
505*da0073e9SAndroid Build Coastguard Worker size_t target_dim = desired.size();
506*da0073e9SAndroid Build Coastguard Worker if (ndim > target_dim) {
507*da0073e9SAndroid Build Coastguard Worker return false;
508*da0073e9SAndroid Build Coastguard Worker }
509*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(ndim)) {
510*da0073e9SAndroid Build Coastguard Worker const auto& size = shape[ndim - i - 1];
511*da0073e9SAndroid Build Coastguard Worker const auto& target = desired[target_dim - i - 1];
512*da0073e9SAndroid Build Coastguard Worker if (size != target && size != 1) {
513*da0073e9SAndroid Build Coastguard Worker return false;
514*da0073e9SAndroid Build Coastguard Worker }
515*da0073e9SAndroid Build Coastguard Worker }
516*da0073e9SAndroid Build Coastguard Worker return true;
517*da0073e9SAndroid Build Coastguard Worker }
518*da0073e9SAndroid Build Coastguard Worker
is_expandable_to(IntArrayRef shape,IntArrayRef desired)519*da0073e9SAndroid Build Coastguard Worker inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
520*da0073e9SAndroid Build Coastguard Worker auto sym_shape = c10::SymIntArrayRef(
521*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
522*da0073e9SAndroid Build Coastguard Worker auto sym_desired = c10::SymIntArrayRef(
523*da0073e9SAndroid Build Coastguard Worker reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
524*da0073e9SAndroid Build Coastguard Worker return is_expandable_to(sym_shape, sym_desired);
525*da0073e9SAndroid Build Coastguard Worker }
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker } // namespace at
528