1 #pragma once
2
3 #ifndef AT_PER_OPERATOR_HEADERS
4 #include <ATen/Functions.h>
5 #else
6 #include <ATen/ops/view.h>
7 #include <ATen/ops/view_copy.h>
8 #endif
9
10 #include <ATen/Tensor.h>
11 #include <ATen/core/DimVector.h>
12 #include <c10/util/Exception.h>
13 #include <c10/util/MaybeOwned.h>
14 #include <c10/util/irange.h>
15
16 #include <functional>
17 #include <tuple>
18 #include <utility>
19
20 namespace at {
21
// Infers the common shape that both `a` and `b` can be expanded to (the
// expand_outplace() helpers below expand their inputs to this result).
// Defined out of line. NOTE: are_expandable() performs a similar
// compatibility check; keep the two in sync.
TORCH_API std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b);
// Same as infer_size(), for shapes whose sizes may be symbolic (SymInt).
TORCH_API std::vector<SymInt> infer_size_symint(
    SymIntArrayRef a,
    SymIntArrayRef b);
// Same as infer_size(), but returns the shape as a DimVector.
TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b);
// SymInt counterpart of infer_size_dimvector(); returns a SymDimVector.
TORCH_API SymDimVector
infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b);
29
30 // Named type instead of a pair/tuple so that we can be sure to
31 // construct the vectors in place and get NRVO.
32 template <typename Container>
33 struct InferExpandGeometryResult {
34 Container sizes;
35 Container strides;
InferExpandGeometryResultInferExpandGeometryResult36 explicit InferExpandGeometryResult(size_t ndim)
37 : sizes(ndim), strides(ndim) {}
InferExpandGeometryResultInferExpandGeometryResult38 explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim)
39 : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {}
40 };
41
// Infers the (sizes, strides) geometry obtained by expanding a tensor whose
// current geometry is (`tensor_sizes`, `tensor_strides`) to `sizes`.
// Defined out of line.
// NOTE(review): presumably expanded (broadcast) dimensions receive stride 0 —
// confirm against the definition.
TORCH_API std::tuple<std::vector<int64_t>, std::vector<int64_t>>
inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);

// Same as inferExpandGeometry(), but returns an
// InferExpandGeometryResult<DimVector> so that the result containers are
// constructed in place rather than packed into a tuple.
TORCH_API InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes);

// Returns a stride vector for a tensor of size `tensor_sizes`, derived from
// `tensor_strides`. Defined out of line.
// NOTE(review): judging by the name, the result is a dense (non-overlapping)
// stride layout that follows the dimension ordering of `tensor_strides` —
// confirm against the definition.
TORCH_API std::vector<int64_t> infer_dense_strides(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides);
56
57 // True if input shapes are expandable
58 // NOTE: infer_size did a similar check, please keep them sync if change is
59 // needed
are_expandable(IntArrayRef shape1,IntArrayRef shape2)60 inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) {
61 size_t ndim1 = shape1.size();
62 size_t ndim2 = shape2.size();
63 size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2;
64
65 for (int64_t i = static_cast<int64_t>(ndim) - 1; i >= 0; --i) {
66 if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 ||
67 shape2[ndim2] == 1) {
68 continue;
69 }
70 return false;
71 }
72 return true;
73 }
74
75 // avoid copy-construction of Tensor by using a reference_wrapper.
check_defined(std::initializer_list<std::reference_wrapper<const Tensor>> tensors,const char * api_name)76 inline void check_defined(
77 std::initializer_list<std::reference_wrapper<const Tensor>> tensors,
78 const char* api_name) {
79 for (auto& t : tensors) {
80 if (!t.get().defined()) {
81 AT_ERROR(api_name, "(...) called with an undefined Tensor");
82 }
83 }
84 }
85
86 // NOTE [ ExpandUtils Borrowing ]
87 //
88 // Functions in ExpandUtils return `c10::MaybeOwned<Tensor>` because
89 // expansion may not actually be needed, in which case we can improve
90 // efficiency by returning
91 // `c10::MaybeOwned<Tensor>::borrowed(to_expand)`. However, this means
92 // that you need to be careful: the returned `c10::MaybeOwned<Tensor>`
93 // must not outlive the original `Tensor` object that `to_expand`
94 // referred to! The deleted rvalue reference overloads of these
95 // functions help with this by preventing trivial use of a temporary
96 // resulting from a function call, but it is still possible to make a
97 // mistake.
98
expand_inplace(const Tensor & tensor,const Tensor & to_expand)99 inline c10::MaybeOwned<Tensor> expand_inplace(
100 const Tensor& tensor,
101 const Tensor& to_expand) {
102 if (tensor.sym_sizes().equals(to_expand.sym_sizes())) {
103 return c10::MaybeOwned<Tensor>::borrowed(to_expand);
104 }
105 return c10::MaybeOwned<Tensor>::owned(
106 to_expand.expand_symint(tensor.sym_sizes()));
107 }
108
109 inline c10::MaybeOwned<Tensor> expand_inplace(
110 const Tensor& tensor,
111 Tensor&& to_expand) = delete;
112
expand_inplace(const Tensor & tensor,const Tensor & to_expand,const char * api_name)113 inline c10::MaybeOwned<Tensor> expand_inplace(
114 const Tensor& tensor,
115 const Tensor& to_expand,
116 const char* api_name) {
117 check_defined({tensor, to_expand}, api_name);
118 return expand_inplace(tensor, to_expand);
119 }
120
121 inline c10::MaybeOwned<Tensor> expand_inplace(
122 const Tensor& tensor,
123 Tensor&& to_expand,
124 const char* api_name) = delete;
125
126 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor & tensor,const Tensor & to_expand1,const Tensor & to_expand2)127 expand_inplace(
128 const Tensor& tensor,
129 const Tensor& to_expand1,
130 const Tensor& to_expand2) {
131 if (tensor.sizes().equals(to_expand1.sizes()) &&
132 tensor.sizes().equals((to_expand2.sizes()))) {
133 return std::make_tuple(
134 c10::MaybeOwned<Tensor>::borrowed(to_expand1),
135 c10::MaybeOwned<Tensor>::borrowed(to_expand2));
136 }
137
138 return std::make_tuple(
139 c10::MaybeOwned<Tensor>::owned(to_expand1.expand(tensor.sizes())),
140 c10::MaybeOwned<Tensor>::owned(to_expand2.expand(tensor.sizes())));
141 }
142
143 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
144 expand_inplace(
145 const Tensor& tensor,
146 Tensor&& to_expand1,
147 const Tensor& to_expand2) = delete;
148 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
149 expand_inplace(
150 const Tensor& tensor,
151 const Tensor& to_expand1,
152 Tensor&& to_expand2) = delete;
153 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
154 expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) =
155 delete;
156
157 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_inplace(const Tensor & tensor,const Tensor & to_expand1,const Tensor & to_expand2,const char * api_name)158 expand_inplace(
159 const Tensor& tensor,
160 const Tensor& to_expand1,
161 const Tensor& to_expand2,
162 const char* api_name) {
163 check_defined({tensor, to_expand1, to_expand2}, api_name);
164 return expand_inplace(tensor, to_expand1, to_expand2);
165 }
166
167 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
168 expand_inplace(
169 const Tensor& tensor,
170 Tensor&& to_expand1,
171 const Tensor& to_expand2,
172 const char* api_name) = delete;
173 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
174 expand_inplace(
175 const Tensor& tensor,
176 const Tensor& to_expand1,
177 Tensor&& to_expand2,
178 const char* api_name) = delete;
179 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
180 expand_inplace(
181 const Tensor& tensor,
182 Tensor&& to_expand1,
183 Tensor&& to_expand2,
184 const char* api_name) = delete;
185
186 // See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation.
187 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2)188 expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) {
189 auto s1 = to_expand1.sym_sizes();
190 auto s2 = to_expand2.sym_sizes();
191 if (s1.equals(s2)) {
192 return std::make_tuple(
193 c10::MaybeOwned<Tensor>::borrowed(to_expand1),
194 c10::MaybeOwned<Tensor>::borrowed(to_expand2));
195 }
196
197 auto expanded_size = infer_size_symdimvector(s1, s2);
198 return std::make_tuple(
199 c10::MaybeOwned<Tensor>::owned(to_expand1.expand_symint(expanded_size)),
200 c10::MaybeOwned<Tensor>::owned(to_expand2.expand_symint(expanded_size)));
201 }
202
203 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
204 expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete;
205 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
206 expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete;
207 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
208 expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete;
209
210 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const char * api_name)211 expand_outplace(
212 const Tensor& to_expand1,
213 const Tensor& to_expand2,
214 const char* api_name) {
215 check_defined({to_expand1, to_expand2}, api_name);
216 return expand_outplace(to_expand1, to_expand2);
217 }
218
219 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
220 expand_outplace(
221 Tensor&& to_expand1,
222 const Tensor& to_expand2,
223 const char* api_name) = delete;
224 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
225 expand_outplace(
226 const Tensor& to_expand1,
227 Tensor&& to_expand2,
228 const char* api_name) = delete;
229 inline std::tuple<c10::MaybeOwned<Tensor>, c10::MaybeOwned<Tensor>>
230 expand_outplace(
231 Tensor&& to_expand1,
232 Tensor&& to_expand2,
233 const char* api_name) = delete;
234
235 inline std::tuple<
236 c10::MaybeOwned<Tensor>,
237 c10::MaybeOwned<Tensor>,
238 c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const Tensor & to_expand3)239 expand_outplace(
240 const Tensor& to_expand1,
241 const Tensor& to_expand2,
242 const Tensor& to_expand3) {
243 if (to_expand1.sizes().equals(to_expand2.sizes()) &&
244 to_expand1.sizes().equals(to_expand3.sizes())) {
245 return std::make_tuple(
246 c10::MaybeOwned<Tensor>::borrowed(to_expand1),
247 c10::MaybeOwned<Tensor>::borrowed(to_expand2),
248 c10::MaybeOwned<Tensor>::borrowed(to_expand3));
249 }
250
251 auto expanded_size12 =
252 infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes());
253 auto expanded_size =
254 infer_size_dimvector(expanded_size12, to_expand3.sizes());
255 return std::make_tuple(
256 c10::MaybeOwned<Tensor>::owned(to_expand1.expand(expanded_size)),
257 c10::MaybeOwned<Tensor>::owned(to_expand2.expand(expanded_size)),
258 c10::MaybeOwned<Tensor>::owned(to_expand3.expand(expanded_size)));
259 }
260
261 inline std::tuple<
262 c10::MaybeOwned<Tensor>,
263 c10::MaybeOwned<Tensor>,
264 c10::MaybeOwned<Tensor>>
265 expand_outplace(
266 Tensor&& to_expand1,
267 const Tensor& to_expand2,
268 const Tensor& to_expand3) = delete;
269 inline std::tuple<
270 c10::MaybeOwned<Tensor>,
271 c10::MaybeOwned<Tensor>,
272 c10::MaybeOwned<Tensor>>
273 expand_outplace(
274 const Tensor& to_expand1,
275 Tensor&& to_expand2,
276 const Tensor& to_expand3) = delete;
277 inline std::tuple<
278 c10::MaybeOwned<Tensor>,
279 c10::MaybeOwned<Tensor>,
280 c10::MaybeOwned<Tensor>>
281 expand_outplace(
282 Tensor&& to_expand1,
283 Tensor&& to_expand2,
284 const Tensor& to_expand3) = delete;
285 inline std::tuple<
286 c10::MaybeOwned<Tensor>,
287 c10::MaybeOwned<Tensor>,
288 c10::MaybeOwned<Tensor>>
289 expand_outplace(
290 const Tensor& to_expand1,
291 const Tensor& to_expand2,
292 Tensor&& to_expand3) = delete;
293 inline std::tuple<
294 c10::MaybeOwned<Tensor>,
295 c10::MaybeOwned<Tensor>,
296 c10::MaybeOwned<Tensor>>
297 expand_outplace(
298 Tensor&& to_expand1,
299 const Tensor& to_expand2,
300 Tensor&& to_expand3) = delete;
301 inline std::tuple<
302 c10::MaybeOwned<Tensor>,
303 c10::MaybeOwned<Tensor>,
304 c10::MaybeOwned<Tensor>>
305 expand_outplace(
306 const Tensor& to_expand1,
307 Tensor&& to_expand2,
308 Tensor&& to_expand3) = delete;
309 inline std::tuple<
310 c10::MaybeOwned<Tensor>,
311 c10::MaybeOwned<Tensor>,
312 c10::MaybeOwned<Tensor>>
313 expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) =
314 delete;
315
316 inline std::tuple<
317 c10::MaybeOwned<Tensor>,
318 c10::MaybeOwned<Tensor>,
319 c10::MaybeOwned<Tensor>>
expand_outplace(const Tensor & to_expand1,const Tensor & to_expand2,const Tensor & to_expand3,const char * api_name)320 expand_outplace(
321 const Tensor& to_expand1,
322 const Tensor& to_expand2,
323 const Tensor& to_expand3,
324 const char* api_name) {
325 check_defined({to_expand1, to_expand2, to_expand3}, api_name);
326 return expand_outplace(to_expand1, to_expand2, to_expand3);
327 }
328
329 inline std::tuple<
330 c10::MaybeOwned<Tensor>,
331 c10::MaybeOwned<Tensor>,
332 c10::MaybeOwned<Tensor>>
333 expand_outplace(
334 Tensor&& to_expand1,
335 const Tensor& to_expand2,
336 const Tensor& to_expand3,
337 const char* api_name) = delete;
338 inline std::tuple<
339 c10::MaybeOwned<Tensor>,
340 c10::MaybeOwned<Tensor>,
341 c10::MaybeOwned<Tensor>>
342 expand_outplace(
343 const Tensor& to_expand1,
344 Tensor&& to_expand2,
345 const Tensor& to_expand3,
346 const char* api_name) = delete;
347 inline std::tuple<
348 c10::MaybeOwned<Tensor>,
349 c10::MaybeOwned<Tensor>,
350 c10::MaybeOwned<Tensor>>
351 expand_outplace(
352 Tensor&& to_expand1,
353 Tensor&& to_expand2,
354 const Tensor& to_expand3,
355 const char* api_name) = delete;
356 inline std::tuple<
357 c10::MaybeOwned<Tensor>,
358 c10::MaybeOwned<Tensor>,
359 c10::MaybeOwned<Tensor>>
360 expand_outplace(
361 const Tensor& to_expand1,
362 const Tensor& to_expand2,
363 Tensor&& to_expand3,
364 const char* api_name) = delete;
365 inline std::tuple<
366 c10::MaybeOwned<Tensor>,
367 c10::MaybeOwned<Tensor>,
368 c10::MaybeOwned<Tensor>>
369 expand_outplace(
370 Tensor&& to_expand1,
371 const Tensor& to_expand2,
372 Tensor&& to_expand3,
373 const char* api_name) = delete;
374 inline std::tuple<
375 c10::MaybeOwned<Tensor>,
376 c10::MaybeOwned<Tensor>,
377 c10::MaybeOwned<Tensor>>
378 expand_outplace(
379 const Tensor& to_expand1,
380 Tensor&& to_expand2,
381 Tensor&& to_expand3,
382 const char* api_name) = delete;
383 inline std::tuple<
384 c10::MaybeOwned<Tensor>,
385 c10::MaybeOwned<Tensor>,
386 c10::MaybeOwned<Tensor>>
387 expand_outplace(
388 Tensor&& to_expand1,
389 Tensor&& to_expand2,
390 Tensor&& to_expand3,
391 const char* api_name) = delete;
392
expand_size(const Tensor & to_expand,IntArrayRef sizes)393 inline c10::MaybeOwned<Tensor> expand_size(
394 const Tensor& to_expand,
395 IntArrayRef sizes) {
396 if (to_expand.sizes().equals(sizes)) {
397 return c10::MaybeOwned<Tensor>::borrowed(to_expand);
398 }
399
400 return c10::MaybeOwned<Tensor>::owned(to_expand.expand(sizes));
401 }
402
403 inline c10::MaybeOwned<Tensor> expand_size(
404 Tensor&& to_expand,
405 IntArrayRef sizes) = delete;
406
expand_size(const Tensor & to_expand,IntArrayRef sizes,const char * api_name)407 inline c10::MaybeOwned<Tensor> expand_size(
408 const Tensor& to_expand,
409 IntArrayRef sizes,
410 const char* api_name) {
411 check_defined({to_expand}, api_name);
412 return expand_size(to_expand, sizes);
413 }
414
415 inline c10::MaybeOwned<Tensor> expand_size(
416 Tensor&& to_expand,
417 IntArrayRef sizes,
418 const char* api_name) = delete;
419
expand_outplace(TensorList to_expand)420 inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
421 // expands a list of Tensors; ignores undefined (null) tensors
422 bool first = true;
423 DimVector sizes;
424 for (const auto i : c10::irange(to_expand.size())) {
425 if (!to_expand[i].defined()) {
426 continue;
427 } else if (first) {
428 sizes = to_expand[i].sizes();
429 first = false;
430 } else {
431 sizes = infer_size_dimvector(sizes, to_expand[i].sizes());
432 }
433 }
434
435 std::vector<Tensor> result(to_expand.size());
436 for (const auto i : c10::irange(to_expand.size())) {
437 if (!to_expand[i].defined()) {
438 continue;
439 } else if (to_expand[i].sizes().equals(sizes)) {
440 result[i] = to_expand[i];
441 } else {
442 result[i] = to_expand[i].expand(sizes);
443 }
444 }
445 return result;
446 }
447
// Implementation shared by the sum_to() overloads below, which call it with an
// IntArrayRef or a SymIntArrayRef `shape`. `T` is the shape element type.
//
// Reduces `tensor` by summation so that its shape becomes `shape`:
//   * If `shape` is empty, returns tensor.sum().
//   * Otherwise it sums over every leading dimension of `tensor` that has no
//     counterpart in `shape`.
//   * It also sums over every trailing-aligned dimension where `shape` has
//     size 1 but `tensor` does not.
//   * The summation uses keepdim=true. When leading dimensions exist, the
//     final view / view_copy to `shape` removes them.
// Precondition (documented on sum_to): `shape` must be expandable to the sizes
// of `tensor`. This is not checked here.
// NOTE(review): if shape.size() exceeded the tensor's dimensionality,
// `leading_dims` below would go negative.
//
// If `always_return_non_view` is true, the result is never a view of the
// input: view_copy / clone() are used instead of view / returning `tensor`.
template <typename T>
inline Tensor _sum_to(
    Tensor tensor,
    const c10::ArrayRef<T> shape,
    bool always_return_non_view = false) {
  if (shape.size() == 0) {
    return tensor.sum();
  }

  auto sizes = at::symint::sizes<T>(tensor);
  c10::SmallVector<int64_t, 8> reduce_dims;
  // Number of extra leading dimensions `tensor` has relative to `shape`;
  // all of them are reduced.
  const int64_t leading_dims = sizes.size() - shape.size();
  for (const auto i : c10::irange(leading_dims)) {
    reduce_dims.push_back(i);
  }
  // For the trailing-aligned dimensions, reduce where the target size is 1
  // and the current size is not.
  // NOTE(review): these comparisons go through TORCH_GUARD_SIZE_OBLIVIOUS
  // (relevant when sizes are symbolic). The `&&` short-circuit means the
  // second guard is only evaluated when the first holds; keep that order.
  for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
    if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
        TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
      reduce_dims.push_back(i);
    }
  }

  if (!reduce_dims.empty()) {
    tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
  }

  if (always_return_non_view) {
    // This is only actually used by the functionalization pass.
    // We want to be able to guarantee that this function doesn't return a view
    // of the input.
    return leading_dims > 0 ? at::symint::view_copy<T>(tensor, shape)
                            : tensor.clone();
  } else {
    // Drop the kept (size-1) leading dimensions by viewing as `shape`.
    return leading_dims > 0 ? at::symint::view<T>(tensor, shape) : tensor;
  }
}
484
485 inline Tensor sum_to(
486 Tensor tensor,
487 const c10::SymIntArrayRef shape,
488 bool always_return_non_view = false) {
489 return _sum_to(std::move(tensor), shape, always_return_non_view);
490 }
491
492 // Sums `tensor` repeatedly to produce a tensor of shape `shape`.
493 // Precondition: is_expandable_to(shape, tensor.sizes()) must be true
494 inline Tensor sum_to(
495 Tensor tensor,
496 const IntArrayRef shape,
497 bool always_return_non_view = false) {
498 return _sum_to(std::move(tensor), shape, always_return_non_view);
499 }
500
is_expandable_to(SymIntArrayRef shape,c10::SymIntArrayRef desired)501 inline bool is_expandable_to(
502 SymIntArrayRef shape,
503 c10::SymIntArrayRef desired) {
504 size_t ndim = shape.size();
505 size_t target_dim = desired.size();
506 if (ndim > target_dim) {
507 return false;
508 }
509 for (const auto i : c10::irange(ndim)) {
510 const auto& size = shape[ndim - i - 1];
511 const auto& target = desired[target_dim - i - 1];
512 if (size != target && size != 1) {
513 return false;
514 }
515 }
516 return true;
517 }
518
is_expandable_to(IntArrayRef shape,IntArrayRef desired)519 inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
520 auto sym_shape = c10::SymIntArrayRef(
521 reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
522 auto sym_desired = c10::SymIntArrayRef(
523 reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
524 return is_expandable_to(sym_shape, sym_desired);
525 }
526
527 } // namespace at
528