1 #include <ATen/ATen.h>
2 #include <ATen/Config.h>
3 #include <ATen/TensorUtils.h>
4 #include <c10/util/accumulate.h>
5 #include <c10/util/irange.h>
6
7 #include <ostream>
8 #include <sstream>
9
10 namespace at {
11
operator <<(std::ostream & out,const TensorGeometryArg & t)12 std::ostream& operator<<(std::ostream & out, const TensorGeometryArg& t) {
13 if (t.pos == 0) {
14 // 0 is distinguished; it usually indicates 'self' or the return
15 // tensor
16 out << "'" << t.name << "'";
17 } else {
18 out << "argument #" << t.pos << " '" << t.name << "'";
19 }
20 return out;
21 }
22
checkDim(CheckedFrom c,const Tensor & tensor,const char * name,int pos,int64_t dim)23 void checkDim(
24 CheckedFrom c,
25 const Tensor& tensor,
26 const char* name,
27 int pos, // 1-indexed
28 int64_t dim) {
29 TORCH_CHECK(
30 tensor.dim() == dim,
31 "Expected ",
32 dim,
33 "-dimensional tensor, but got ",
34 tensor.dim(),
35 "-dimensional tensor for ",
36 TensorGeometryArg(TensorArg({tensor, name, pos})),
37 " (while checking arguments for ",
38 c,
39 ")");
40 }
41
checkDim(CheckedFrom c,const TensorGeometryArg & t,int64_t dim)42 void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) {
43 TORCH_CHECK(t->dim() == dim,
44 "Expected ", dim, "-dimensional tensor, but got ", t->dim(),
45 "-dimensional tensor for ", t," (while checking arguments for ", c, ")");
46 }
47
checkDimRange(CheckedFrom c,const TensorGeometryArg & t,int64_t dim_start,int64_t dim_end)48 void checkDimRange(CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end) {
49 TORCH_CHECK(
50 t->dim() >= dim_start && t->dim() < dim_end,
51 "Expected ", dim_start, " to ", (dim_end - 1), " dimensions, but got ",
52 t->dim(), "-dimensional tensor for ", t, " (while checking arguments for ",
53 c, ")");
54 }
55
checkContiguous(CheckedFrom c,const TensorGeometryArg & t)56 void checkContiguous(CheckedFrom c, const TensorGeometryArg& t) {
57 TORCH_CHECK(
58 t->is_contiguous(),
59 "Expected contiguous tensor, but got non-contiguous tensor for ", t,
60 " (while checking arguments for ", c, ")");
61 }
62
checkAllContiguous(CheckedFrom c,at::ArrayRef<TensorArg> ts)63 void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
64 for (auto& t : ts) {
65 if (!t->defined()) continue;
66 checkContiguous(c, t);
67 }
68 }
69
checkSize(CheckedFrom c,const TensorGeometryArg & t,IntArrayRef sizes)70 void checkSize(CheckedFrom c, const TensorGeometryArg& t, IntArrayRef sizes) {
71 checkDim(c, t, static_cast<int64_t>(sizes.size()));
72 TORCH_CHECK(
73 t->sizes().equals(sizes),
74 "Expected tensor of size ", sizes, ", but got tensor of size ", t->sizes(),
75 " for ", t, " (while checking arguments for ", c, ")");
76 }
77
checkSize_symint(CheckedFrom c,const TensorGeometryArg & t,c10::SymIntArrayRef sizes)78 void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, c10::SymIntArrayRef sizes) {
79 checkDim(c, t, static_cast<int64_t>(sizes.size()));
80 TORCH_CHECK(
81 t->sym_sizes().equals(sizes),
82 "Expected tensor of size ", sizes, ", but got tensor of size ", t->sizes(),
83 " for ", t, " (while checking arguments for ", c, ")");
84 }
85
checkSize(CheckedFrom c,const TensorGeometryArg & t,int64_t dim,int64_t size)86 void checkSize(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size) {
87 TORCH_CHECK(
88 t->size(dim) == size,
89 "Expected tensor to have size ", size, " at dimension ", dim,
90 ", but got size ", t->size(dim), " for ", t,
91 " (while checking arguments for ", c, ")");
92 }
93
checkSize_symint(CheckedFrom c,const TensorGeometryArg & t,int64_t dim,const c10::SymInt & size)94 void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, const c10::SymInt& size) {
95 TORCH_CHECK(
96 t->sym_size(dim) == size,
97 "Expected tensor to have size ", size, " at dimension ", dim,
98 ", but got size ", t->size(dim), " for ", t,
99 " (while checking arguments for ", c, ")");
100 }
101
checkAllSame(CheckedFrom c,ArrayRef<TensorArg> tensors,void (* fn)(CheckedFrom,const TensorArg &,const TensorArg &))102 static void checkAllSame(CheckedFrom c, ArrayRef<TensorArg> tensors, void(*fn)(CheckedFrom, const TensorArg&, const TensorArg&)) {
103 const TensorArg* t0 = nullptr;
104 for (auto& t : tensors) {
105 if (!t->defined()) continue;
106 if (t0 != nullptr) {
107 fn(c, *t0, t);
108 } else {
109 t0 = &t;
110 }
111 }
112 }
113
checkSameSize(CheckedFrom c,const TensorArg & t1,const TensorArg & t2)114 void checkSameSize(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
115 TORCH_CHECK(
116 t1->sizes().equals(t2->sizes()),
117 "Expected tensor for ", t1, " to have same size as tensor for ", t2,
118 "; but ", t1->sizes(), " does not equal ", t2->sizes(),
119 " (while checking arguments for ", c, ")");
120 }
121
checkAllSameSize(CheckedFrom c,ArrayRef<TensorArg> tensors)122 void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors) {
123 checkAllSame(c, tensors, checkSameSize);
124 }
125
checkNumel(CheckedFrom c,const TensorGeometryArg & t,int64_t numel)126 void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel) {
127 TORCH_CHECK(
128 t->numel() == numel,
129 "Expected tensor for ", t, " to have ", numel,
130 " elements; but it actually has ", t->numel(), " elements",
131 " (while checking arguments for ", c, ")");
132 }
133
checkSameNumel(CheckedFrom c,const TensorArg & t1,const TensorArg & t2)134 void checkSameNumel(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
135 TORCH_CHECK(
136 t1->numel() == t2->numel(),
137 "Expected tensor for ", t1,
138 " to have same number of elements as tensor for ", t2, "; but ",
139 t1->numel(), " does not equal ", t2->numel(),
140 " (while checking arguments for ", c, ")");
141 }
142
checkAllSameNumel(CheckedFrom c,ArrayRef<TensorArg> tensors)143 void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors) {
144 checkAllSame(c, tensors, checkSameNumel);
145 }
146
checkSameGPU(CheckedFrom c,const TensorArg & t1,const TensorArg & t2)147 void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
148 if (t1->is_cpu() || t2->is_cpu()) {
149 std::ostringstream oss;
150 if (t1->is_cpu()) {
151 oss << "Tensor for " << t1 << " is on CPU, ";
152 }
153 if (t2->is_cpu()) {
154 oss << "Tensor for " << t2 << " is on CPU, ";
155 }
156 oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it")
157 << " to be on GPU (while checking arguments for " << c << ")";
158 AT_ERROR(oss.str());
159 }
160 TORCH_CHECK(
161 t1->get_device() == t2->get_device(),
162 "Expected tensor for ", t1, " to have the same device as tensor for ", t2,
163 "; but device ", t1->get_device(), " does not equal ", t2->get_device(),
164 " (while checking arguments for ", c, ")");
165 }
166
checkAllSameGPU(CheckedFrom c,ArrayRef<TensorArg> tensors)167 void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors) {
168 checkAllSame(c, tensors, checkSameGPU);
169 }
170
checkSameType(CheckedFrom c,const TensorArg & t1,const TensorArg & t2)171 void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
172 TORCH_CHECK(
173 t1->options().type_equal(t2->options()),
174 "Expected tensor for ", t1, " to have the same type as tensor for ", t2,
175 "; but type ", t1->toString(), " does not equal ", t2->toString(),
176 " (while checking arguments for ", c, ")");
177 }
178
checkScalarType(CheckedFrom c,const TensorArg & t,ScalarType ty)179 void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) {
180 TORCH_CHECK(
181 t->scalar_type() == ty,
182 "Expected tensor for ", t, " to have scalar type ", toString(ty),
183 "; but got ", t->toString(), " instead (while checking arguments for ", c,
184 ")");
185 }
186
checkScalarTypes(CheckedFrom c,const TensorArg & t,at::ArrayRef<ScalarType> l)187 void checkScalarTypes(CheckedFrom c, const TensorArg& t,
188 at::ArrayRef<ScalarType> l) {
189 if (std::find(l.begin(), l.end(), t->scalar_type()) == l.end()) {
190 std::ostringstream oss;
191 oss << "Expected tensor for " << t << " to have one of the following "
192 << "scalar types: ";
193 size_t i = 0;
194 for (auto ty : l) {
195 if (i != 0) {
196 oss << ", ";
197 }
198 oss << toString(ty);
199 i++;
200 }
201 oss << "; but got " << t->toString()
202 << " instead (while checking arguments for " << c << ")";
203 AT_ERROR(oss.str());
204 }
205 }
206
checkAllSameType(CheckedFrom c,ArrayRef<TensorArg> tensors)207 void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
208 checkAllSame(c, tensors, checkSameType);
209 }
210
checkSameDim(CheckedFrom c,const TensorGeometryArg & t1,const TensorGeometryArg & t2)211 void checkSameDim(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2) {
212 TORCH_CHECK(
213 t1->dim() == t2->dim(),
214 "Expected tensor for ", t1, " to have the same dimension as tensor for ",
215 t2, "; but ", t1->dim(), " does not equal ", t2->dim(),
216 " (while checking arguments for ", c, ")");
217 }
218
checkDefined(CheckedFrom c,const TensorArg & t)219 void checkDefined(CheckedFrom c, const TensorArg& t) {
220 TORCH_CHECK(
221 t->defined(),
222 "Expected tensor for ", t, " to be non-null, but it was undefined ",
223 " (while checking arguments for ", c, ")");
224 }
225
checkAllDefined(CheckedFrom c,ArrayRef<TensorArg> ts)226 void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {
227 // NB: don't filter defined here
228 for (auto t : ts) {
229 checkDefined(c, t);
230 }
231 }
232
checkBackend(CheckedFrom c,const Tensor & t,Backend backend)233 static void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) {
234 TORCH_CHECK(
235 !t.defined() || t.options().backend() == backend,
236 "Expected tensor to have ", toString(backend),
237 " Backend, but got tensor with ", toString(t.options().backend()), " Backend ",
238 "(while checking arguments for ", c, ")");
239 }
240
checkBackend(CheckedFrom c,at::ArrayRef<Tensor> tensors,at::Backend backend)241 void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Backend backend) {
242 for (auto &t : tensors) {
243 checkBackend(c, t, backend);
244 }
245 }
246
checkDeviceType(CheckedFrom c,const Tensor & t,DeviceType device_type)247 static void checkDeviceType(CheckedFrom c, const Tensor& t, DeviceType device_type) {
248 TORCH_CHECK(
249 !t.defined() || t.device().type() == device_type,
250 "Expected tensor to have ", device_type,
251 " DeviceType, but got tensor with ", t.device().type(), " DeviceType ",
252 "(while checking arguments for ", c, ")");
253 }
254
checkDeviceType(CheckedFrom c,at::ArrayRef<Tensor> tensors,at::DeviceType device_type)255 void checkDeviceType(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::DeviceType device_type) {
256 for (auto &t : tensors) {
257 checkDeviceType(c, t, device_type);
258 }
259 }
260
checkLayout(CheckedFrom c,const Tensor & t,Layout layout)261 void checkLayout(CheckedFrom c, const Tensor& t, Layout layout) {
262 TORCH_CHECK(
263 !t.defined() || t.layout() == layout,
264 "Expected tensor to have ", layout,
265 " Layout, but got tensor with ", t.layout(), " Layout ",
266 "(while checking arguments for ", c, ")");
267 }
268
checkLayout(CheckedFrom c,at::ArrayRef<Tensor> tensors,at::Layout layout)269 void checkLayout(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Layout layout) {
270 for (auto &t : tensors) {
271 checkLayout(c, t, layout);
272 }
273 }
274
maybe_data_ptr(const Tensor & tensor)275 void * maybe_data_ptr(const Tensor& tensor) {
276 return tensor.defined() ? (void *)tensor.data_ptr() : nullptr;
277 }
278
maybe_data_ptr(const TensorArg & tensor)279 void * maybe_data_ptr(const TensorArg& tensor) {
280 return tensor->defined() ? (void *)tensor->data_ptr() : nullptr;
281 }
282
check_dim_size(const Tensor & tensor,int64_t dim,int64_t dim_size,int64_t size)283 void check_dim_size(
284 const Tensor& tensor,
285 int64_t dim,
286 int64_t dim_size,
287 int64_t size) {
288 /* Check dimension size of a tensor */
289 TORCH_CHECK(
290 tensor.dim() == dim && tensor.size(dim_size) == size,
291 "Expected a tensor of dimension ",
292 dim,
293 " and tensor.size[",
294 dim_size,
295 "] == ",
296 size,
297 " but got: dimension ",
298 tensor.dim(),
299 " and tensor.size[",
300 dim_size,
301 "] = ",
302 tensor.size(dim_size));
303 }
304
305 namespace detail {
306
defaultStrides(IntArrayRef sizes)307 std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
308 std::vector<int64_t> strides(sizes.size());
309 int64_t stride = 1;
310 for(size_t i = sizes.size(); i > 0; --i) {
311 strides[i-1] = stride;
312 stride *= sizes[i-1];
313 }
314 return strides;
315 }
316
// Computes strides that let a tensor with shape `oldshape` and strides
// `oldstride` be viewed with shape `newshape` without copying. Returns
// std::nullopt when no such strides exist.
//
// On a high level,
// 1. separate `oldshape` into chunks of dimensions, where the dimensions are
//    ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] *
//    oldstride[i+1]. A dimension of size 1 never starts a new chunk,
//    whatever its stride.
// 2. `newshape` must be able to be separated into the same number of chunks
//    as `oldshape` was separated into, where each chunk of `newshape` has
//    matching ``numel'', i.e., number of subspaces, as the corresponding
//    chunk of `oldshape`.
//
// Templatized for the IntArrayRef, SymIntArrayRef and DimVector use cases,
// see overloads of computeStride() below. `toResult` converts `oldstride`
// into a ResultVec. It is only used when `oldshape` has zero elements and
// equals `newshape`. It is declared with function type, so it is received as
// a function pointer; the overloads pass capture-less lambdas.
//
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline std::optional<ResultVec> computeStride_impl(
    const NewShapeVec& oldshape,
    const NewShapeVec& oldstride,
    const NewShapeVec& newshape,
    ResultVec toResult(const NewShapeVec&)
) {
  // Zero-dimensional input: every dimension of the view gets stride 1.
  if (oldshape.empty()) {
    return ResultVec(newshape.size(), 1);
  }

  // NOTE: stride is arbitrary in the numel() == 0 case;
  // to match NumPy behavior we copy the strides if the size matches, otherwise
  // we use the stride as if it were computed via resize.
  // This could perhaps be combined with the below code, but the complexity
  // didn't seem worth it.
  const Numel numel = c10::multiply_integers(oldshape);
  bool zero_numel = TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0));
  if (zero_numel && oldshape.equals(newshape)) {
    return toResult(oldstride);
  }

  ResultVec newstride(newshape.size());
  if (zero_numel) {
    // Contiguous strides for `newshape`, built from the last dimension
    // backwards. Size-0 dimensions are treated as size 1 when accumulating.
    for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
      if (view_d == (int64_t)(newshape.size() - 1)) {
        newstride[view_d] = 1;
      } else {
        newstride[view_d] =
          std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
      }
    }
    return newstride;
  }

  // Walk both shapes from their last dimension towards the first. Chunk
  // numels are accumulated for `oldshape`, then matched by consuming
  // dimensions of `newshape`.
  int64_t view_d = (int64_t)newshape.size() - 1;
  // stride for each subspace in the chunk
  Numel chunk_base_stride = oldstride.back();
  // numel in current chunk
  Numel tensor_numel = 1;
  Numel view_numel = 1;
  for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= oldshape[tensor_d];
    // if end of tensor size chunk, check view
    // (the chunk ends at dimension 0, or when the preceding dimension is not
    // size 1 and its stride breaks contiguity with the current chunk)
    if ((tensor_d == 0) ||
        (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldshape[tensor_d - 1], 1)) &&
         oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
      // Assign strides to view dimensions while the view chunk is still
      // smaller than the tensor chunk. Size-1 view dimensions are also
      // absorbed into the current chunk.
      while (view_d >= 0 &&
            (TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(newshape[view_d], 1)))) {
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
      // The consumed view dimensions cannot be grouped to match this chunk.
      if (view_numel != tensor_numel) {
        return std::nullopt;
      }
      // Start the next chunk at the preceding tensor dimension.
      if (tensor_d > 0) {
        chunk_base_stride = oldstride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  // Some view dimensions were never assigned to a chunk.
  if (view_d != -1) {
    return std::nullopt;
  }
  return newstride;
}
397
computeStride(IntArrayRef oldshape,IntArrayRef oldstride,IntArrayRef newshape)398 std::optional<std::vector<int64_t>> computeStride(
399 IntArrayRef oldshape,
400 IntArrayRef oldstride,
401 IntArrayRef newshape) {
402 auto toResult = [](const IntArrayRef& a) { return a.vec(); };
403 return computeStride_impl<std::vector<int64_t>, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
404 }
405
computeStride(c10::SymIntArrayRef oldshape,c10::SymIntArrayRef oldstride,c10::SymIntArrayRef newshape)406 std::optional<SymDimVector> computeStride(
407 c10::SymIntArrayRef oldshape,
408 c10::SymIntArrayRef oldstride,
409 c10::SymIntArrayRef newshape) {
410 auto toResult = [](const SymIntArrayRef& a) { return SymDimVector(a); };
411 return computeStride_impl<SymDimVector, c10::SymIntArrayRef, c10::SymInt>(oldshape, oldstride, newshape, toResult);
412 }
413
computeStride(IntArrayRef oldshape,IntArrayRef oldstride,const DimVector & newshape)414 std::optional<DimVector> computeStride(
415 IntArrayRef oldshape,
416 IntArrayRef oldstride,
417 const DimVector& newshape) {
418 auto toResult = [](const IntArrayRef& a) { return DimVector(a); };
419 return computeStride_impl<DimVector, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
420 }
421
422 } // namespace detail
423 } // namespace at
424