xref: /aosp_15_r20/external/pytorch/aten/src/ATen/ExpandUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifndef AT_PER_OPERATOR_HEADERS
4 #include <ATen/Functions.h>
5 #else
6 #include <ATen/ops/view.h>
7 #include <ATen/ops/view_copy.h>
8 #endif
9 
10 #include <ATen/Tensor.h>
11 #include <ATen/core/DimVector.h>
12 #include <c10/util/Exception.h>
13 #include <c10/util/MaybeOwned.h>
14 #include <c10/util/irange.h>
15 
16 #include <functional>
17 #include <tuple>
18 #include <utility>
19 
20 namespace at {
21 
22 TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
23 TORCH_API std::vector<SymInt> infer_size_symint(
24     SymIntArrayRef a,
25     SymIntArrayRef b);
26 TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
27 TORCH_API SymDimVector
28 infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
29 
30 // Named type instead of a pair/tuple so that we can be sure to
31 // construct the vectors in place and get NRVO.
32 template <typename Container>
33 struct InferExpandGeometryResult {
34   Container sizes;
35   Container strides;
InferExpandGeometryResultInferExpandGeometryResult36   explicit InferExpandGeometryResult(size_t ndim)
37       : sizes(ndim), strides(ndim) {}
InferExpandGeometryResultInferExpandGeometryResult38   explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
39       : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
40 };
41 
42 TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
43 inferExpandGeometry(
44     IntArrayRef tensor_sizes,
45     IntArrayRef tensor_strides,
46     IntArrayRef sizes);
47 
48 TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
49     IntArrayRef tensor_sizes,
50     IntArrayRef tensor_strides,
51     IntArrayRef sizes);
52 
53 TORCH_API std::vector<int64_t> infer_dense_strides(
54     IntArrayRef tensor_sizes,
55     IntArrayRef tensor_strides);
56 
57 // True if input shapes are expandable
58 // NOTE: infer_size did a similar check, please keep them sync if change is
59 // needed
are_expandable(IntArrayRef shape1,IntArrayRef shape2)60 inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
61   size_t ndim1 = shape1.size();
62   size_t ndim2 = shape2.size();
63   size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
64 
65   for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
66     if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
67         shape2[ndim2] == 1) {
68       continue;
69     }
70     return false;
71   }
72   return true;
73 }
74 
75 // avoid copy-construction of Tensor by using a reference_wrapper.
check_defined(std::initializer_list<std::reference_wrapper<const Tensor>> tensors,const char * api_name)76 inline void check_defined(
77     std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
78     const char* api_name) {
79   for (auto& t : tensors) {
80     if (!t.get().defined()) {
81       AT_ERROR(api_name, "(...) called with an undefined Tensor");
82     }
83   }
84 }
85 
86 // NOTE [ ExpandUtils Borrowing ]
87 //
88 // Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
89 // expansion may not actually be needed, in which case we can improve
90 // efficiency by returning
91 // `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
92 // that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
93 // must not outlive the original `Tensor` object that `to_expand`
94 // referred to! The deleted rvalue reference overloads of these
95 // functions help with this by preventing trivial use of a temporary
96 // resulting from a function call, but it is still possible to make a
97 // mistake.
98 
expand_inplace(const Tensor & tensor,const Tensor & to_expand)99 inline c10::MaybeOwned<Tensor> expand_inplace(
100     const Tensor& tensor,
101     const Tensor& to_expand) {
102   if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
103     return c10::MaybeOwned<Tensor>::borrowed(to_expand);
104   }
105   return c10::MaybeOwned<Tensor>::owned(
106       to_expand.expand_symint(tensor.sym_sizes()));
107 }
108 
109 inline c10::MaybeOwned<Tensor> expand_inplace(
110     const Tensor& tensor,
111     Tensor&& to_expand) = delete;
112 
expand_inplace(const Tensor & tensor,const Tensor & to_expand,const char * api_name)113 inline c10::MaybeOwned<Tensor> expand_inplace(
114     const Tensor& tensor,
115     const Tensor& to_expand,
116     const char* api_name) {
117   check_defined({tensor, to_expand}, api_name);
118   return expand_inplace(tensor, to_expand);
119 }
120 
121 inline c10::MaybeOwned<Tensor> expand_inplace(
122     const Tensor& tensor,
123     Tensor&& to_expand,
124     const char* api_name) = delete;
125 
126 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor & tensor,const Tensor & to_expand1,const Tensor & to_expand2)127 expand_inplace(
128     const Tensor& tensor,
129     const Tensor& to_expand1,
130     const Tensor& to_expand2) {
131   if (tensor.sizes().equals(to_expand1.sizes()) &&
132       tensor.sizes().equals((to_expand2.sizes()))) {
133     return std::make_tuple(
134         c10::MaybeOwned<Tensor>::borrowed(to_expand1),
135         c10::MaybeOwned<Tensor>::borrowed(to_expand2));
136   }
137 
138   return std::make_tuple(
139       c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
140       c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
141 }
142 
143 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
144 expand_inplace(
145     const Tensor& tensor,
146     Tensor&& to_expand1,
147     const Tensor& to_expand2) = delete;
148 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
149 expand_inplace(
150     const Tensor& tensor,
151     const Tensor& to_expand1,
152     Tensor&& to_expand2) = delete;
153 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
154 expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
155     delete;
156 
157 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor & tensor,const Tensor & to_expand1,const Tensor & to_expand2,const char * api_name)158 expand_inplace(
159     const Tensor& tensor,
160     const Tensor& to_expand1,
161     const Tensor& to_expand2,
162     const char* api_name) {
163   check_defined({tensor, to_expand1, to_expand2}, api_name);
164   return expand_inplace(tensor, to_expand1, to_expand2);
165 }
166 
167 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
168 expand_inplace(
169     const Tensor& tensor,
170     Tensor&& to_expand1,
171     const Tensor& to_expand2,
172     const char* api_name) = delete;
173 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
174 expand_inplace(
175     const Tensor& tensor,
176     const Tensor& to_expand1,
177     Tensor&& to_expand2,
178     const char* api_name) = delete;
179 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
180 expand_inplace(
181     const Tensor& tensor,
182     Tensor&& to_expand1,
183     Tensor&& to_expand2,
184     const char* api_name) = delete;
185 
186 // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
187 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2)188 expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
189   auto s1 = to_expand1.sym_sizes();
190   auto s2 = to_expand2.sym_sizes();
191   if (s1.equals(s2)) {
192     return std::make_tuple(
193         c10::MaybeOwned<Tensor>::borrowed(to_expand1),
194         c10::MaybeOwned<Tensor>::borrowed(to_expand2));
195   }
196 
197   auto expanded_size = infer_size_symdimvector(s1, s2);
198   return std::make_tuple(
199       c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
200       c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
201 }
202 
203 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
204 expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
205 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
206 expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
207 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
208 expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
209 
210 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const char * api_name)211 expand_outplace(
212     const Tensor& to_expand1,
213     const Tensor& to_expand2,
214     const char* api_name) {
215   check_defined({to_expand1, to_expand2}, api_name);
216   return expand_outplace(to_expand1, to_expand2);
217 }
218 
219 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
220 expand_outplace(
221     Tensor&& to_expand1,
222     const Tensor& to_expand2,
223     const char* api_name) = delete;
224 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
225 expand_outplace(
226     const Tensor& to_expand1,
227     Tensor&& to_expand2,
228     const char* api_name) = delete;
229 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
230 expand_outplace(
231     Tensor&& to_expand1,
232     Tensor&& to_expand2,
233     const char* api_name) = delete;
234 
235 inline std::tuple<
236     c10::MaybeOwned<Tensor>,
237     c10::MaybeOwned<Tensor>,
238     c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const Tensor & to_expand3)239 expand_outplace(
240     const Tensor& to_expand1,
241     const Tensor& to_expand2,
242     const Tensor& to_expand3) {
243   if (to_expand1.sizes().equals(to_expand2.sizes()) &&
244       to_expand1.sizes().equals(to_expand3.sizes())) {
245     return std::make_tuple(
246         c10::MaybeOwned<Tensor>::borrowed(to_expand1),
247         c10::MaybeOwned<Tensor>::borrowed(to_expand2),
248         c10::MaybeOwned<Tensor>::borrowed(to_expand3));
249   }
250 
251   auto expanded_size12 =
252       infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
253   auto expanded_size =
254       infer_size_dimvector(expanded_size12, to_expand3.sizes());
255   return std::make_tuple(
256       c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
257       c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
258       c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
259 }
260 
261 inline std::tuple<
262     c10::MaybeOwned<Tensor>,
263     c10::MaybeOwned<Tensor>,
264     c10::MaybeOwned<Tensor>>
265 expand_outplace(
266     Tensor&& to_expand1,
267     const Tensor& to_expand2,
268     const Tensor& to_expand3) = delete;
269 inline std::tuple<
270     c10::MaybeOwned<Tensor>,
271     c10::MaybeOwned<Tensor>,
272     c10::MaybeOwned<Tensor>>
273 expand_outplace(
274     const Tensor& to_expand1,
275     Tensor&& to_expand2,
276     const Tensor& to_expand3) = delete;
277 inline std::tuple<
278     c10::MaybeOwned<Tensor>,
279     c10::MaybeOwned<Tensor>,
280     c10::MaybeOwned<Tensor>>
281 expand_outplace(
282     Tensor&& to_expand1,
283     Tensor&& to_expand2,
284     const Tensor& to_expand3) = delete;
285 inline std::tuple<
286     c10::MaybeOwned<Tensor>,
287     c10::MaybeOwned<Tensor>,
288     c10::MaybeOwned<Tensor>>
289 expand_outplace(
290     const Tensor& to_expand1,
291     const Tensor& to_expand2,
292     Tensor&& to_expand3) = delete;
293 inline std::tuple<
294     c10::MaybeOwned<Tensor>,
295     c10::MaybeOwned<Tensor>,
296     c10::MaybeOwned<Tensor>>
297 expand_outplace(
298     Tensor&& to_expand1,
299     const Tensor& to_expand2,
300     Tensor&& to_expand3) = delete;
301 inline std::tuple<
302     c10::MaybeOwned<Tensor>,
303     c10::MaybeOwned<Tensor>,
304     c10::MaybeOwned<Tensor>>
305 expand_outplace(
306     const Tensor& to_expand1,
307     Tensor&& to_expand2,
308     Tensor&& to_expand3) = delete;
309 inline std::tuple<
310     c10::MaybeOwned<Tensor>,
311     c10::MaybeOwned<Tensor>,
312     c10::MaybeOwned<Tensor>>
313 expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
314     delete;
315 
316 inline std::tuple<
317     c10::MaybeOwned<Tensor>,
318     c10::MaybeOwned<Tensor>,
319     c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const Tensor & to_expand3,const char * api_name)320 expand_outplace(
321     const Tensor& to_expand1,
322     const Tensor& to_expand2,
323     const Tensor& to_expand3,
324     const char* api_name) {
325   check_defined({to_expand1, to_expand2, to_expand3}, api_name);
326   return expand_outplace(to_expand1, to_expand2, to_expand3);
327 }
328 
329 inline std::tuple<
330     c10::MaybeOwned<Tensor>,
331     c10::MaybeOwned<Tensor>,
332     c10::MaybeOwned<Tensor>>
333 expand_outplace(
334     Tensor&& to_expand1,
335     const Tensor& to_expand2,
336     const Tensor& to_expand3,
337     const char* api_name) = delete;
338 inline std::tuple<
339     c10::MaybeOwned<Tensor>,
340     c10::MaybeOwned<Tensor>,
341     c10::MaybeOwned<Tensor>>
342 expand_outplace(
343     const Tensor& to_expand1,
344     Tensor&& to_expand2,
345     const Tensor& to_expand3,
346     const char* api_name) = delete;
347 inline std::tuple<
348     c10::MaybeOwned<Tensor>,
349     c10::MaybeOwned<Tensor>,
350     c10::MaybeOwned<Tensor>>
351 expand_outplace(
352     Tensor&& to_expand1,
353     Tensor&& to_expand2,
354     const Tensor& to_expand3,
355     const char* api_name) = delete;
356 inline std::tuple<
357     c10::MaybeOwned<Tensor>,
358     c10::MaybeOwned<Tensor>,
359     c10::MaybeOwned<Tensor>>
360 expand_outplace(
361     const Tensor& to_expand1,
362     const Tensor& to_expand2,
363     Tensor&& to_expand3,
364     const char* api_name) = delete;
365 inline std::tuple<
366     c10::MaybeOwned<Tensor>,
367     c10::MaybeOwned<Tensor>,
368     c10::MaybeOwned<Tensor>>
369 expand_outplace(
370     Tensor&& to_expand1,
371     const Tensor& to_expand2,
372     Tensor&& to_expand3,
373     const char* api_name) = delete;
374 inline std::tuple<
375     c10::MaybeOwned<Tensor>,
376     c10::MaybeOwned<Tensor>,
377     c10::MaybeOwned<Tensor>>
378 expand_outplace(
379     const Tensor& to_expand1,
380     Tensor&& to_expand2,
381     Tensor&& to_expand3,
382     const char* api_name) = delete;
383 inline std::tuple<
384     c10::MaybeOwned<Tensor>,
385     c10::MaybeOwned<Tensor>,
386     c10::MaybeOwned<Tensor>>
387 expand_outplace(
388     Tensor&& to_expand1,
389     Tensor&& to_expand2,
390     Tensor&& to_expand3,
391     const char* api_name) = delete;
392 
expand_size(const Tensor & to_expand,IntArrayRef sizes)393 inline c10::MaybeOwned<Tensor> expand_size(
394     const Tensor& to_expand,
395     IntArrayRef sizes) {
396   if (to_expand.sizes().equals(sizes)) {
397     return c10::MaybeOwned<Tensor>::borrowed(to_expand);
398   }
399 
400   return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
401 }
402 
403 inline c10::MaybeOwned<Tensor> expand_size(
404     Tensor&& to_expand,
405     IntArrayRef sizes) = delete;
406 
expand_size(const Tensor & to_expand,IntArrayRef sizes,const char * api_name)407 inline c10::MaybeOwned<Tensor> expand_size(
408     const Tensor& to_expand,
409     IntArrayRef sizes,
410     const char* api_name) {
411   check_defined({to_expand}, api_name);
412   return expand_size(to_expand, sizes);
413 }
414 
415 inline c10::MaybeOwned<Tensor> expand_size(
416     Tensor&& to_expand,
417     IntArrayRef sizes,
418     const char* api_name) = delete;
419 
expand_outplace(TensorList to_expand)420 inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
421   // expands a list of Tensors; ignores undefined (null) tensors
422   bool first = true;
423   DimVector sizes;
424   for (const auto i : c10::irange(to_expand.size())) {
425     if (!to_expand[i].defined()) {
426       continue;
427     } else if (first) {
428       sizes = to_expand[i].sizes();
429       first = false;
430     } else {
431       sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
432     }
433   }
434 
435   std::vector<Tensor> result(to_expand.size());
436   for (const auto i : c10::irange(to_expand.size())) {
437     if (!to_expand[i].defined()) {
438       continue;
439     } else if (to_expand[i].sizes().equals(sizes)) {
440       result[i] = to_expand[i];
441     } else {
442       result[i] = to_expand[i].expand(sizes);
443     }
444   }
445   return result;
446 }
447 
448 template <typename T>
449 inline Tensor _sum_to(
450     Tensor tensor,
451     const c10::ArrayRef<T> shape,
452     bool always_return_non_view = false) {
453   if (shape.size() == 0) {
454     return tensor.sum();
455   }
456 
457   auto sizes = at::symint::sizes<T>(tensor);
458   c10::SmallVector<int64_t, 8> reduce_dims;
459   const int64_t leading_dims = sizes.size() - shape.size();
460   for (const auto i : c10::irange(leading_dims)) {
461     reduce_dims.push_back(i);
462   }
463   for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
464     if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
465         TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
466       reduce_dims.push_back(i);
467     }
468   }
469 
470   if (!reduce_dims.empty()) {
471     tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
472   }
473 
474   if (always_return_non_view) {
475     // This is only actually used by the functionalization pass.
476     // We want to be able to guarantee that this function doesn't return a view
477     // of the input.
478     return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
479                             : tensor.clone();
480   } else {
481     return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
482   }
483 }
484 
485 inline Tensor sum_to(
486     Tensor tensor,
487     const c10::SymIntArrayRef shape,
488     bool always_return_non_view = false) {
489   return _sum_to(std::move(tensor), shape, always_return_non_view);
490 }
491 
492 // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
493 // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
494 inline Tensor sum_to(
495     Tensor tensor,
496     const IntArrayRef shape,
497     bool always_return_non_view = false) {
498   return _sum_to(std::move(tensor), shape, always_return_non_view);
499 }
500 
is_expandable_to(SymIntArrayRef shape,c10::SymIntArrayRef desired)501 inline bool is_expandable_to(
502     SymIntArrayRef shape,
503     c10::SymIntArrayRef desired) {
504   size_t ndim = shape.size();
505   size_t target_dim = desired.size();
506   if (ndim > target_dim) {
507     return false;
508   }
509   for (const auto i : c10::irange(ndim)) {
510     const auto& size = shape[ndim - i - 1];
511     const auto& target = desired[target_dim - i - 1];
512     if (size != target && size != 1) {
513       return false;
514     }
515   }
516   return true;
517 }
518 
is_expandable_to(IntArrayRef shape,IntArrayRef desired)519 inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
520   auto sym_shape = c10::SymIntArrayRef(
521       reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
522   auto sym_desired = c10::SymIntArrayRef(
523       reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
524   return is_expandable_to(sym_shape, sym_desired);
525 }
526 
527 } // namespace at
528