#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ExpandUtils.h>
#include <ATen/ExpandBase.h>

#include <c10/util/irange.h>

#include <numeric> // std::iota

namespace at {
namespace internal {
TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) {
  return OptionalTensorRef(self)->expand(size);
}
} // namespace internal

namespace {
// NOTE: are_expandable does a similar check; please keep the two in sync if this
// logic changes.
template <typename Container, typename ArrayType>
Container infer_size_impl(ArrayType a, ArrayType b) {
  // Use ptrdiff_t to ensure signed comparison.
  auto dimsA = static_cast<ptrdiff_t>(a.size());
  auto dimsB = static_cast<ptrdiff_t>(b.size());
  auto ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(ndim);

  for (ptrdiff_t i = ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
    auto sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // Sizes of 1 map to the other operand's size (even 0).
    expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA);
  }

  return expandedSizes;
}
} // namespace

std::vector<int64_t> infer_size(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<std::vector<int64_t>>(a, b);
}

std::vector<SymInt> infer_size_symint(SymIntArrayRef a, SymIntArrayRef b) {
  return infer_size_impl<std::vector<SymInt>>(a, b);
}

DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b) {
  return infer_size_impl<DimVector, IntArrayRef>(a, b);
}

SymDimVector infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b) {
  return infer_size_impl<SymDimVector, SymIntArrayRef>(a, b);
}
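
// Illustrative examples (not exercised in this file): these helpers broadcast two
// shapes, aligning from the trailing dimension and treating missing leading
// dimensions as 1, e.g.
//   infer_size({3, 1, 4}, {5, 4})  -> {3, 5, 4}
//   infer_size({2, 3},    {4, 3})  -> error: 2 vs. 4 at non-singleton dimension 0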

template <typename Container>
C10_ALWAYS_INLINE InferExpandGeometryResult<Container> inferExpandGeometryImpl(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  int64_t ndim = static_cast<int64_t>(sizes.size());
  int64_t tensor_dim = static_cast<int64_t>(tensor_sizes.size());

  if (tensor_dim == 0) {
    return InferExpandGeometryResult<Container>(sizes, ndim);
  }

  InferExpandGeometryResult<Container> result(ndim);
  auto& expandedSizes = result.sizes;
  auto& expandedStrides = result.strides;

  // create a new geometry for the tensors
  for (int64_t i = ndim - 1; i >= 0; --i) {
    int64_t offset = ndim - 1 - i;
    int64_t dim = tensor_dim - 1 - offset;
    int64_t size = (dim >= 0) ? tensor_sizes[dim] : 1;
    int64_t stride = (dim >= 0) ? tensor_strides[dim]
                                : expandedSizes[i + 1] * expandedStrides[i + 1];
    int64_t targetSize = sizes[i];
    if (targetSize == -1) {
      TORCH_CHECK(
          dim >= 0,
          "The expanded size of the tensor (",
          targetSize,
          ") isn't allowed in a leading, non-existing dimension ",
          i);
      targetSize = size;
    }
    if (size != targetSize) {
      TORCH_CHECK(
          size == 1,
          "The expanded size of the tensor (",
          targetSize,
          ") must match the existing size (",
          size,
          ") at non-singleton dimension ",
          i,
          ". Target sizes: ",
          sizes,
          ". Tensor sizes: ",
          tensor_sizes);
      size = targetSize;
      stride = 0;
    }
    expandedSizes[i] = size;
    expandedStrides[i] = stride;
  }
  return result;
}

std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  auto result = inferExpandGeometryImpl<std::vector<int64_t>>(
      tensor_sizes, tensor_strides, sizes);
  return std::make_tuple(std::move(result.sizes), std::move(result.strides));
}

InferExpandGeometryResult<DimVector> inferExpandGeometry_dimvector(
    IntArrayRef tensor_sizes,
    IntArrayRef tensor_strides,
    IntArrayRef sizes) {
  return inferExpandGeometryImpl<DimVector>(
      tensor_sizes, tensor_strides, sizes);
}
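
// Worked example (illustrative): for a tensor with sizes (3, 1) and strides (1, 1),
// expanding to target sizes (2, 3, 4) yields expanded sizes (2, 3, 4) and strides
// (0, 1, 0); the new leading dimension and the broadcast size-1 dimension get
// stride 0, so they reuse the same memory.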

// This function returns dense, non-overlapping strides that keep the same layout
// permutation as the input `tensor_strides`, computed from the input `tensor_sizes`.
// Note:
// 1. This function expects the input `tensor_strides` and `tensor_sizes` to be
//    non-dense or overlapping. If the inputs are already dense and non-overlapping,
//    the output strides will be the same as `tensor_strides`. However, this function
//    does not check whether the inputs are dense or overlapping, so it runs in full
//    even when the inputs are already dense and non-overlapping, which is wasted work.
//
//    If possible, verify that the inputs are non-dense or overlapping before calling
//    this function; if they come from a tensor, this can be checked via
//    `is_non_overlapping_and_dense()`.
//
// 2. The stride propagation rule used in this function is exactly the same as the one
//    used in TensorIterator. Please refer to https://github.com/pytorch/pytorch/pull/42922
//    for more details.

std::vector<int64_t> infer_dense_strides(IntArrayRef tensor_sizes, IntArrayRef tensor_strides) {

  TORCH_CHECK(tensor_sizes.size() == tensor_strides.size(),
      "Input sizes and strides should have the same size but got ", tensor_sizes.size(), " and ", tensor_strides.size());

  size_t ndim = tensor_sizes.size();
  if (ndim == 0) {
    return {};
  }
  if (ndim == 1) {
    return {1};
  }

  std::vector<int64_t> perm(ndim);
  // initialize perm with n-1, n-2, ..., 1, 0
  std::iota(perm.rbegin(), perm.rend(), 0);

  // The following sorting algorithm has exactly the same behavior as TensorIterator.
  // This is to make sure we have the same stride propagation everywhere.

  // return -1 if dim0 should come before dim1,
  // return 1 if dim0 should come after dim1,
  // return 0 if the comparison is ambiguous
  auto should_swap = [&](size_t dim0, size_t dim1) {
    int64_t stride0 = tensor_strides[dim0];
    int64_t stride1 = tensor_strides[dim1];

    // if either stride is 0, treat the comparison as ambiguous to
    // keep the same behavior as TensorIterator
    if (stride0 == 0 || stride1 == 0) {
      return 0;
    }
    if (stride0 < stride1) {
      return -1;
    }
    if (stride0 > stride1) {
      return 1;
    }
    // for equal strides, the dimension with the smaller size goes first
    if (tensor_sizes[dim0] > tensor_sizes[dim1]) {
      return 1;
    }
    return 0;
  };
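
  // For illustration, with strides (6, 0, 120, 0, 1) (the example in the note below):
  // any comparison involving dims 1 or 3 returns 0 because their stride is 0, so
  // those dims never move; should_swap(2, 0) returns 1 (stride 120 > 6), and
  // should_swap(4, 0) returns -1 (stride 1 < 6).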

  // Insertion-sort (stable) the indices in `perm` based on the input tensor's strides
  // and sizes; dimensions with stride 0 never move. This is the same behavior as
  // TensorIterator.
  // E.g. given a tensor with sizes/strides (6, 5, 4, 3, 2)/(6, 0, 120, 0, 1), the
  // initial `perm` is (4, 3, 2, 1, 0) and the sorted `perm` is (4, 3, 0, 1, 2).
  for (const auto i : c10::irange(1, ndim)) {
    auto dim1 = i;
    for (const auto j : c10::irange(1, i + 1)) {
      auto dim0 = i - j;
      int comparison = should_swap(perm[dim0], perm[dim1]);
      if (comparison > 0) {
        std::swap(perm[dim0], perm[dim1]);
        dim1 = dim0;
      } else if (comparison < 0) {
        break;
      }
    }
  }

  // compute output strides that preserve the input tensor's memory layout
  std::vector<int64_t> out_strides(ndim);
  int64_t curr_stride = 1;
  for (const auto i : c10::irange(ndim)) {
    int64_t idx = perm[i];
    out_strides[idx] = curr_stride;
    // Note: a size of 0 is simply treated as 1; it doesn't matter here since the
    // total number of elements is 0.
    if (tensor_sizes[idx] > 1) {
      curr_stride *= tensor_sizes[idx];
    }
  }
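
  // Continuing the example above (sizes (6, 5, 4, 3, 2), sorted perm (4, 3, 0, 1, 2)):
  // out_strides receives 1, 2, 6, 36, 180 in perm order, i.e.
  // out_strides = (6, 36, 180, 2, 1).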
  return out_strides;
}

} // namespace at