#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/ExpandBase.h>

#include <c10/util/irange.h>

#include <numeric>

namespace at {
namespace internal {
TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size) {
  return OptionalTensorRef(self)->expand(size);
}
} // namespace internal

namespace {
// NOTE: are_expandable performs a similar check; please keep the two in sync
// if either needs to change.
template <typename Container, typename ArrayType>
Container infer_size_impl(ArrayType a, ArrayType b) {
  // Use ptrdiff_t to ensure signed comparison.
  auto dimsA = static_cast<ptrdiff_t>(a.size());
  auto dimsB = static_cast<ptrdiff_t>(b.size());
  auto ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(ndim);

  for (ptrdiff_t i = ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
    auto sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // 1s map to the other size (even 0).
    expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA);
  }

  return expandedSizes;
}
} // namespace

std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<std::vector<int64_t>>(a, b);
}

std::vector<SymInt> infer_size_symint(SymIntArrayRef a, SymIntArrayRef b) {
  return infer_size_impl<std::vector<SymInt>>(a, b);
}

DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<DimVector, IntArrayRef>(a, b);
}

SymDimVector infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b) {
  return infer_size_impl<SymDimVector, SymIntArrayRef>(a, b);
}
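
// For illustration, a couple of results implied by the broadcasting rule in
// infer_size_impl above: shapes are aligned from the trailing dimension and a
// size of 1 expands to match the other operand.
//   infer_size({3, 1, 5}, {4, 5}) -> {3, 4, 5}
//   infer_size({2, 1}, {1, 3})    -> {2, 3}
// Mismatched non-singleton sizes in the same position trip the "must match"
// check instead.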

template <typename Container>
C10_ALWAYS_INLINE InferExpandGeometryResult<Container> inferExpandGeometryImpl(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  int64_t ndim = static_cast<int64_t>(sizes.size());
  int64_t tensor_dim = static_cast<int64_t>(tensor_sizes.size());

  if (tensor_dim == 0) {
    return InferExpandGeometryResult<Container>(sizes, ndim);
  }

  InferExpandGeometryResult<Container> result(ndim);
  auto& expandedSizes = result.sizes;
  auto& expandedStrides = result.strides;

  // create a new geometry for the tensors
  for (int64_t i = ndim - 1; i >= 0; --i) {
    int64_t offset = ndim - 1 - i;
    int64_t dim = tensor_dim - 1 - offset;
    int64_t size = (dim >= 0) ? tensor_sizes[dim] : 1;
    int64_t stride = (dim >= 0) ? tensor_strides[dim]
                                : expandedSizes[i + 1] * expandedStrides[i + 1];
    int64_t targetSize = sizes[i];
    if (targetSize == -1) {
      TORCH_CHECK(
          dim >= 0,
          "The expanded size of the tensor (",
          targetSize,
          ") isn't allowed in a leading, non-existing dimension ",
          i);
      targetSize = size;
    }
    if (size != targetSize) {
      TORCH_CHECK(
          size == 1,
          "The expanded size of the tensor (",
          targetSize,
          ") must match the existing size (",
          size,
          ") at non-singleton dimension ",
          i,
          ".  Target sizes: ",
          sizes,
          ".  Tensor sizes: ",
          tensor_sizes);
      size = targetSize;
      stride = 0;
    }
    expandedSizes[i] = size;
    expandedStrides[i] = stride;
  }
  return result;
}

std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  auto result = inferExpandGeometryImpl<std::vector<int64_t>>(
      tensor_sizes, tensor_strides, sizes);
  return std::make_tuple(std::move(result.sizes), std::move(result.strides));
}
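
// For illustration, one expansion traced through the loop above: expanding a
// contiguous tensor of sizes {3, 1} (strides {1, 1}) to {2, 3, 4} repeats data
// through zero strides in the broadcast dimensions.
//   inferExpandGeometry({3, 1}, {1, 1}, {2, 3, 4})
//     -> sizes {2, 3, 4}, strides {0, 1, 0}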

InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  return inferExpandGeometryImpl<DimVector>(
      tensor_sizes, tensor_strides, sizes);
}

// This function returns dense, non-overlapping strides that keep the same layout
// permutation as the input `tensor_strides`, computed from the input `tensor_sizes`.
// Note:
// 1. This function expects the inputs `tensor_strides` and `tensor_sizes` to be
//    non-dense or overlapping. If the inputs are already dense and non-overlapping,
//    the output strides will be the same as `tensor_strides`. However, this function
//    does not check whether the inputs are dense or overlapping, so the whole
//    function still runs even when the inputs are already dense and non-overlapping,
//    which causes unnecessary slowness.
//
//    If possible, please verify that the inputs are non-dense or overlapping before
//    calling this function; if the inputs come from a tensor, you can check this
//    via `is_non_overlapping_and_dense()`.
//
// 2. The stride propagation rule used in this function is exactly the same as the
//    one used in TensorIterator. Please refer to
//    https://github.com/pytorch/pytorch/pull/42922 for more details.
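//
// For illustration, one example traced through the algorithm below (illustrative
// input values): a non-dense input with sizes {2, 3} and strides {10, 2}, e.g.
// produced by slicing a larger tensor, keeps its row-major dimension order but
// is re-packed densely:
//   infer_dense_strides({2, 3}, {10, 2}) -> {3, 1}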
std::vector<int64_t> infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) {

  TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(),
    "Input sizes and strides should have same size but got ", tensor_sizes.size(), " and ", tensor_strides.size());

  size_t ndim = tensor_sizes.size();
  if (ndim == 0) {
    return {};
  }
  if (ndim == 1) {
    return {1};
  }

  std::vector<int64_t> perm(ndim);
  // initialize perm with n-1, n-2, ..., 1, 0
  std::iota(perm.rbegin(), perm.rend(), 0);

  // The following sorting algorithm has exactly the same behavior as TensorIterator.
  // This is to make sure we have the same stride propagation everywhere.

  // return -1 if dim0 should come before dim1
  // return  1 if dim0 should come after dim1
  // return  0 if the comparison is ambiguous
  auto should_swap = [&](size_t dim0, size_t dim1) {
    int64_t stride0 = tensor_strides[dim0];
    int64_t stride1 = tensor_strides[dim1];

    // if either stride is 0, treat the comparison as ambiguous to
    // keep the same behavior as TensorIterator
    if (stride0 == 0 || stride1 == 0) {
      return 0;
    }
    if (stride0 < stride1) {
      return -1;
    }
    if (stride0 > stride1) {
      return 1;
    }
    // for equal strides, the dimension with the smaller size comes first
    if (tensor_sizes[dim0] > tensor_sizes[dim1]) {
      return 1;
    }
    return 0;
  };

  // Insertion-sort (stable) the indices in `perm` based on the input tensor's
  // strides and sizes; dimensions with stride 0 don't move. This is the same
  // behavior as TensorIterator.
  // e.g. given a tensor with sizes/strides (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1),
  //      the initial `perm` is (4, 3, 2, 1, 0) and the sorted `perm` is
  //      (4, 3, 0, 1, 2)
  for (const auto i : c10::irange(1, ndim)) {
    auto dim1 = i;
    for (const auto j : c10::irange(1, i + 1)) {
      auto dim0 = i - j;
      int comparison = should_swap(perm[dim0], perm[dim1]);
      if (comparison > 0) {
        std::swap(perm[dim0], perm[dim1]);
        dim1 = dim0;
      } else if (comparison < 0) {
        break;
      }
    }
  }

  // compute output strides which preserve the input tensor's memory layout
  std::vector<int64_t> out_strides(ndim);
  int64_t curr_stride = 1;
  for (const auto i : c10::irange(ndim)) {
    int64_t idx = perm[i];
    out_strides[idx] = curr_stride;
    // Note: for size 0, we simply treat it as 1; it really doesn't matter here
    // since the total number of elements is 0.
    if (tensor_sizes[idx] > 1) {
      curr_stride *= tensor_sizes[idx];
    }
  }
  return out_strides;
}

} // namespace at