// /aosp_15_r20/external/pytorch/aten/src/ATen/TensorUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/Config.h>
#include <ATen/TensorUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#include <algorithm> // for std::find in checkScalarTypes
#include <ostream>
#include <sstream>

namespace at {

std::ostream& operator<<(std::ostream& out, const TensorGeometryArg& t) {
  if (t.pos == 0) {
    // 0 is distinguished; it usually indicates 'self' or the return
    // tensor
    out << "'" << t.name << "'";
  } else {
    out << "argument #" << t.pos << " '" << t.name << "'";
  }
  return out;
}

void checkDim(
    CheckedFrom c,
    const Tensor& tensor,
    const char* name,
    int pos, // 1-indexed
    int64_t dim) {
  TORCH_CHECK(
      tensor.dim() == dim,
      "Expected ",
      dim,
      "-dimensional tensor, but got ",
      tensor.dim(),
      "-dimensional tensor for ",
      TensorGeometryArg(TensorArg({tensor, name, pos})),
      " (while checking arguments for ",
      c,
      ")");
}

void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim) {
  TORCH_CHECK(t->dim() == dim,
    "Expected ", dim, "-dimensional tensor, but got ", t->dim(),
    "-dimensional tensor for ", t, " (while checking arguments for ", c, ")");
}
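
// Illustrative usage (hypothetical caller, not part of this file): operators
// wrap each argument in a TensorArg so the error message can name it and give
// its position, then run the relevant checks:
//
//   void my_op_check(const Tensor& self, const Tensor& weight) {
//     TensorArg self_arg{self, "self", 0};       // pos 0 prints as 'self'
//     TensorArg weight_arg{weight, "weight", 2}; // pos 2 prints as argument #2
//     CheckedFrom c = "my_op";
//     checkDim(c, self_arg, 4);
//     checkAllSameType(c, {self_arg, weight_arg});
//   }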

void checkDimRange(CheckedFrom c, const TensorGeometryArg& t, int64_t dim_start, int64_t dim_end) {
  TORCH_CHECK(
    t->dim() >= dim_start && t->dim() < dim_end,
    "Expected ", dim_start, " to ", (dim_end - 1), " dimensions, but got ",
    t->dim(), "-dimensional tensor for ", t, " (while checking arguments for ",
    c, ")");
}

void checkContiguous(CheckedFrom c, const TensorGeometryArg& t) {
  TORCH_CHECK(
    t->is_contiguous(),
    "Expected contiguous tensor, but got non-contiguous tensor for ", t,
    " (while checking arguments for ", c, ")");
}

void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts) {
  for (auto& t : ts) {
    if (!t->defined()) continue;
    checkContiguous(c, t);
  }
}

void checkSize(CheckedFrom c, const TensorGeometryArg& t, IntArrayRef sizes) {
  checkDim(c, t, static_cast<int64_t>(sizes.size()));
  TORCH_CHECK(
    t->sizes().equals(sizes),
    "Expected tensor of size ", sizes, ", but got tensor of size ", t->sizes(),
    " for ", t, " (while checking arguments for ", c, ")");
}

void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, c10::SymIntArrayRef sizes) {
  checkDim(c, t, static_cast<int64_t>(sizes.size()));
  TORCH_CHECK(
    t->sym_sizes().equals(sizes),
    "Expected tensor of size ", sizes, ", but got tensor of size ", t->sym_sizes(),
    " for ", t, " (while checking arguments for ", c, ")");
}

void checkSize(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, int64_t size) {
  TORCH_CHECK(
    t->size(dim) == size,
    "Expected tensor to have size ", size, " at dimension ", dim,
    ", but got size ", t->size(dim), " for ", t,
    " (while checking arguments for ", c, ")");
}

void checkSize_symint(CheckedFrom c, const TensorGeometryArg& t, int64_t dim, const c10::SymInt& size) {
  TORCH_CHECK(
    t->sym_size(dim) == size,
    "Expected tensor to have size ", size, " at dimension ", dim,
    ", but got size ", t->sym_size(dim), " for ", t,
    " (while checking arguments for ", c, ")");
}

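// Checks every defined tensor in `tensors` against the first defined one, so
// a mismatch reports a single offending pair rather than every combination.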
static void checkAllSame(CheckedFrom c, ArrayRef<TensorArg> tensors, void(*fn)(CheckedFrom, const TensorArg&, const TensorArg&)) {
  const TensorArg* t0 = nullptr;
  for (auto& t : tensors) {
    if (!t->defined()) continue;
    if (t0 != nullptr) {
      fn(c, *t0, t);
    } else {
      t0 = &t;
    }
  }
}

void checkSameSize(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
    t1->sizes().equals(t2->sizes()),
    "Expected tensor for ", t1, " to have same size as tensor for ", t2,
    "; but ", t1->sizes(), " does not equal ", t2->sizes(),
    " (while checking arguments for ", c, ")");
}

void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameSize);
}

void checkNumel(CheckedFrom c, const TensorGeometryArg& t, int64_t numel) {
  TORCH_CHECK(
    t->numel() == numel,
    "Expected tensor for ", t, " to have ", numel,
    " elements; but it actually has ", t->numel(), " elements",
    " (while checking arguments for ", c, ")");
}

void checkSameNumel(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
    t1->numel() == t2->numel(),
    "Expected tensor for ", t1,
    " to have same number of elements as tensor for ", t2, "; but ",
    t1->numel(), " does not equal ", t2->numel(),
    " (while checking arguments for ", c, ")");
}

void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameNumel);
}

void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  if (t1->is_cpu() || t2->is_cpu()) {
    std::ostringstream oss;
    if (t1->is_cpu()) {
      oss << "Tensor for " << t1 << " is on CPU, ";
    }
    if (t2->is_cpu()) {
      oss << "Tensor for " << t2 << " is on CPU, ";
    }
    // say "them" only when both tensors are on CPU
    oss << "but expected " << ((t1->is_cpu() && t2->is_cpu()) ? "them" : "it")
        << " to be on GPU (while checking arguments for " << c << ")";
    AT_ERROR(oss.str());
  }
  TORCH_CHECK(
    t1->get_device() == t2->get_device(),
    "Expected tensor for ", t1, " to have the same device as tensor for ", t2,
    "; but device ", t1->get_device(), " does not equal ", t2->get_device(),
    " (while checking arguments for ", c, ")");
}

void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameGPU);
}

void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
  TORCH_CHECK(
    t1->options().type_equal(t2->options()),
    "Expected tensor for ", t1, " to have the same type as tensor for ", t2,
    "; but type ", t1->toString(), " does not equal ", t2->toString(),
    " (while checking arguments for ", c, ")");
}

void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType ty) {
  TORCH_CHECK(
    t->scalar_type() == ty,
    "Expected tensor for ", t, " to have scalar type ", toString(ty),
    "; but got ", t->toString(), " instead (while checking arguments for ", c,
    ")");
}

void checkScalarTypes(CheckedFrom c, const TensorArg& t,
                      at::ArrayRef<ScalarType> l) {
  if (std::find(l.begin(), l.end(), t->scalar_type()) == l.end()) {
    std::ostringstream oss;
    oss << "Expected tensor for " << t << " to have one of the following "
        << "scalar types: ";
    size_t i = 0;
    for (auto ty : l) {
      if (i != 0) {
        oss << ", ";
      }
      oss << toString(ty);
      i++;
    }
    oss << "; but got " << t->toString()
        << " instead (while checking arguments for " << c << ")";
    AT_ERROR(oss.str());
  }
}
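
// Illustrative call (the argument values are hypothetical): accept either
// 64-bit or 32-bit integer indices, reporting any other dtype together with
// the full list of accepted types:
//
//   checkScalarTypes("embedding", indices_arg, {kLong, kInt});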

void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors) {
  checkAllSame(c, tensors, checkSameType);
}

void checkSameDim(CheckedFrom c, const TensorGeometryArg& t1, const TensorGeometryArg& t2) {
  TORCH_CHECK(
    t1->dim() == t2->dim(),
    "Expected tensor for ", t1, " to have the same dimension as tensor for ",
    t2, "; but ", t1->dim(), " does not equal ", t2->dim(),
    " (while checking arguments for ", c, ")");
}

void checkDefined(CheckedFrom c, const TensorArg& t) {
  TORCH_CHECK(
    t->defined(),
    "Expected tensor for ", t, " to be non-null, but it was undefined",
    " (while checking arguments for ", c, ")");
}

void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {
  // NB: don't filter defined here
  for (auto t : ts) {
    checkDefined(c, t);
  }
}

static void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) {
  TORCH_CHECK(
    !t.defined() || t.options().backend() == backend,
    "Expected tensor to have ", toString(backend),
    " Backend, but got tensor with ", toString(t.options().backend()), " Backend ",
    "(while checking arguments for ", c, ")");
}

void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Backend backend) {
  for (auto& t : tensors) {
    checkBackend(c, t, backend);
  }
}

static void checkDeviceType(CheckedFrom c, const Tensor& t, DeviceType device_type) {
  TORCH_CHECK(
      !t.defined() || t.device().type() == device_type,
      "Expected tensor to have ", device_type,
      " DeviceType, but got tensor with ", t.device().type(), " DeviceType ",
      "(while checking arguments for ", c, ")");
}

void checkDeviceType(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::DeviceType device_type) {
  for (auto& t : tensors) {
    checkDeviceType(c, t, device_type);
  }
}

void checkLayout(CheckedFrom c, const Tensor& t, Layout layout) {
  TORCH_CHECK(
    !t.defined() || t.layout() == layout,
    "Expected tensor to have ", layout,
    " Layout, but got tensor with ", t.layout(), " Layout ",
    "(while checking arguments for ", c, ")");
}

void checkLayout(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Layout layout) {
  for (auto& t : tensors) {
    checkLayout(c, t, layout);
  }
}

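// Convenience for C-style kernels that take optional tensors: an undefined
// tensor maps to nullptr rather than throwing when its data pointer is taken.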
void* maybe_data_ptr(const Tensor& tensor) {
  return tensor.defined() ? (void*)tensor.data_ptr() : nullptr;
}

void* maybe_data_ptr(const TensorArg& tensor) {
  return tensor->defined() ? (void*)tensor->data_ptr() : nullptr;
}

void check_dim_size(
    const Tensor& tensor,
    int64_t dim,
    int64_t dim_size,
    int64_t size) {
  /* Check that the tensor is `dim`-dimensional and that dimension
     `dim_size` (an index, despite the name) has extent `size`. */
  TORCH_CHECK(
      tensor.dim() == dim && tensor.size(dim_size) == size,
      "Expected a tensor of dimension ",
      dim,
      " and tensor.size[",
      dim_size,
      "] == ",
      size,
      " but got: dimension ",
      tensor.dim(),
      " and tensor.size[",
      dim_size,
      "] = ",
      tensor.size(dim_size));
}
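
// Illustrative call (names are hypothetical): require a 4-D tensor whose
// leading dimension matches the batch size:
//
//   check_dim_size(grad_output, 4, 0, batch_size);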

namespace detail {

std::vector<int64_t> defaultStrides(IntArrayRef sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride = 1;
  for (size_t i = sizes.size(); i > 0; --i) {
    strides[i-1] = stride;
    stride *= sizes[i-1];
  }
  return strides;
}
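
// Example (illustrative): for sizes [2, 3, 4] the loop above walks
// right-to-left and yields the contiguous row-major strides [12, 4, 1]:
// the last dimension gets stride 1 and each earlier stride is the product
// of all later sizes.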

// On a high level,
// 1. separate `oldshape` into chunks of dimensions, where the dimensions are
//    ``contiguous'' in each chunk, i.e., oldstride[i] = oldshape[i+1] *
//    oldstride[i+1]
// 2. `newshape` must be able to be separated into the same number of chunks
//    as `oldshape` was separated into, where each chunk of newshape has
//    matching ``numel'', i.e., number of subspaces, as the corresponding
//    chunk of `oldshape`.
//
// templatized for DimVector and IntArrayRef use cases,
// see overloads of computeStride() below.
//
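// Worked example (illustrative, not from the original source): a transposed
// 2-D tensor with oldshape = [4, 6] and oldstride = [1, 4] splits into two
// chunks, because oldstride[0] (= 1) != oldshape[1] * oldstride[1] (= 24).
// newshape = [24] forms a single chunk, so no valid stride assignment exists
// and the function returns std::nullopt: the caller must copy (reshape)
// rather than view. By contrast, a contiguous tensor with oldshape = [4, 6]
// and oldstride = [6, 1] is one chunk, and newshape = [24] succeeds with
// newstride = [1].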
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline std::optional<ResultVec> computeStride_impl(
    const NewShapeVec& oldshape,
    const NewShapeVec& oldstride,
    const NewShapeVec& newshape,
    ResultVec toResult(const NewShapeVec&)
) {
  if (oldshape.empty()) {
    return ResultVec(newshape.size(), 1);
  }

  // NOTE: stride is arbitrary in the numel() == 0 case;
  // to match NumPy behavior we copy the strides if the size matches, otherwise
  // we use the stride as if it were computed via resize.
  // This could perhaps be combined with the below code, but the complexity
  // didn't seem worth it.
  const Numel numel = c10::multiply_integers(oldshape);
  bool zero_numel = TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0));
  if (zero_numel && oldshape.equals(newshape)) {
    return toResult(oldstride);
  }

  ResultVec newstride(newshape.size());
  if (zero_numel) {
    for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
      if (view_d == (int64_t)(newshape.size() - 1)) {
        newstride[view_d] = 1;
      } else {
        newstride[view_d] =
          std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
      }
    }
    return newstride;
  }

  int64_t view_d = (int64_t)newshape.size() - 1;
  // stride for each subspace in the chunk
  Numel chunk_base_stride = oldstride.back();
  // numel in current chunk
  Numel tensor_numel = 1;
  Numel view_numel = 1;
  for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= oldshape[tensor_d];
    // if end of tensor size chunk, check view
    if ((tensor_d == 0) ||
        (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(oldshape[tensor_d - 1], 1)) &&
         oldstride[tensor_d - 1] != tensor_numel * chunk_base_stride)) {
      while (view_d >= 0 &&
            (TORCH_GUARD_SIZE_OBLIVIOUS(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(newshape[view_d], 1)))) {
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
      if (view_numel != tensor_numel) {
        return std::nullopt;
      }
      if (tensor_d > 0) {
        chunk_base_stride = oldstride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  if (view_d != -1) {
    return std::nullopt;
  }
  return newstride;
}

std::optional<std::vector<int64_t>> computeStride(
    IntArrayRef oldshape,
    IntArrayRef oldstride,
    IntArrayRef newshape) {
  auto toResult = [](const IntArrayRef& a) { return a.vec(); };
  return computeStride_impl<std::vector<int64_t>, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}

std::optional<SymDimVector> computeStride(
    c10::SymIntArrayRef oldshape,
    c10::SymIntArrayRef oldstride,
    c10::SymIntArrayRef newshape) {
  auto toResult = [](const SymIntArrayRef& a) { return SymDimVector(a); };
  return computeStride_impl<SymDimVector, c10::SymIntArrayRef, c10::SymInt>(oldshape, oldstride, newshape, toResult);
}

std::optional<DimVector> computeStride(
    IntArrayRef oldshape,
    IntArrayRef oldstride,
    const DimVector& newshape) {
  auto toResult = [](const IntArrayRef& a) { return DimVector(a); };
  return computeStride_impl<DimVector, IntArrayRef, int64_t>(oldshape, oldstride, newshape, toResult);
}
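
// Illustrative usage (hypothetical, mirroring how view-style ops consult this
// helper): a view is only legal when a stride assignment exists.
//
//   auto stride = at::detail::computeStride(
//       self.sizes(), self.strides(), inferred_size);
//   TORCH_CHECK(stride.has_value(),
//       "view size is not compatible with input tensor's size and stride");
//   // ... construct the view using *stride as its strides ...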

}  // namespace detail
}  // namespace at