xref: /aosp_15_r20/external/pytorch/aten/src/ATen/InferSize.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/DimVector.h>
4 #include <c10/core/ScalarType.h>
5 #include <c10/core/SymIntArrayRef.h>
6 #include <c10/util/DimVector.h>
7 #include <optional>
8 #include <sstream>
9 #include <vector>
10 
11 namespace at {
12 
// Infers the size of the (at most one) dimension of `shape` given as -1, so
// that the product of all dimensions equals `numel`. Also checks that `shape`
// is compatible with `numel` even when nothing needs to be inferred.
//
// Templated to handle the std::vector<int64_t>, DimVector and SymDimVector
// use cases; see the wrappers below.
//
// @param shape  Requested shape. Each entry must be >= 0 or exactly -1, and at
//               most one entry may be -1.
// @param numel  Number of elements the shape must account for.
// @param res    Output sizes, indexed like `shape`. Every caller in this header
//               initializes it as a copy of `shape`; only the entry at the
//               inferred position (if any) is overwritten.
// @throws std::runtime_error if more than one entry is -1, or if `shape` is
//         incompatible with `numel`.
// @throws a TORCH_CHECK error if an entry is negative but not -1, or if the
//         -1 dimension is ambiguous because the other dimensions multiply
//         to 0.
template <typename InputArrayRef, typename NumelType, typename ResultVec>
inline void infer_size_impl(
    InputArrayRef shape,
    NumelType numel,
    ResultVec& res) {
  // Product of all explicitly specified (non -1) dimensions.
  NumelType newsize = 1;
  // N.B. this is an index, not a sym dim!
  std::optional<int64_t> infer_dim;
  for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
    if (shape[dim] == -1) {
      if (infer_dim) {
        throw std::runtime_error("only one dimension can be inferred");
      }
      infer_dim = dim;
    } else if (shape[dim] >= 0) {
      newsize *= shape[dim];
    } else {
      // AT_ERROR is a deprecated alias for TORCH_CHECK(false, ...); use the
      // direct form (same message) to avoid deprecation warnings.
      TORCH_CHECK(false, "invalid shape dimension ", shape[dim]);
    }
  }

  // The shape is valid in either of two cases:
  //  - the explicit dims already account for every element;
  //  - a -1 dim exists and the element count divides evenly by the product of
  //    the explicit dims.
  if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
      (infer_dim && newsize > 0 && numel % newsize == 0)) {
    if (infer_dim) {
      // We have a degree of freedom here to select the dimension size; follow
      // NumPy semantics and just bail.  However, a nice error message is needed
      // because users often use `view` as a way to flatten & unflatten
      // dimensions and will otherwise be confused why
      //   empty_tensor.view( 0, 0)
      // works yet
      //   empty_tensor.view(-1, 0)
      // doesn't.
      TORCH_CHECK(
          newsize != 0,
          "cannot reshape tensor of 0 elements into shape ",
          shape,
          " because the unspecified dimension size -1 can be any "
          "value and is ambiguous");
      res[*infer_dim] = numel / newsize;
    }
    return;
  }

  std::ostringstream ss;
  ss << "shape '" << shape << "' is invalid for input of size " << numel;
  throw std::runtime_error(ss.str());
}
66 
infer_size(IntArrayRef shape,int64_t numel)67 inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
68   auto res = shape.vec();
69   infer_size_impl(shape, numel, res);
70   return res;
71 }
72 
infer_size_dv(IntArrayRef shape,int64_t numel)73 inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
74   auto res = at::DimVector(shape);
75   infer_size_impl(shape, numel, res);
76   return res;
77 }
78 
infer_size_dv(c10::SymIntArrayRef shape,c10::SymInt numel)79 inline at::SymDimVector infer_size_dv(
80     c10::SymIntArrayRef shape,
81     c10::SymInt numel) {
82   auto res = at::SymDimVector(shape);
83   infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
84       shape, std::move(numel), res);
85   return res;
86 }
87 
88 } // namespace at
89