xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorGeometry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/TensorGeometry.h>
2 
3 namespace at {
4 
5 // See TensorGeometry.h on why this is useful now that we cache is_contiguous.
6 template <typename T>
_geometry_is_contiguous(ArrayRef<T> sizes,ArrayRef<T> strides)7 bool _geometry_is_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides) {
8   assert(!overflows<std::int64_t>(sizes.size()));
9   auto dim = static_cast<std::int64_t>(sizes.size());
10   T expected_stride = 1;
11   bool contig_if_nonempty = true;
12   for (int64_t i = dim - 1; i >= 0; i--) {
13     if (sizes[i] == 0) {
14       return true;
15     }
16     if (contig_if_nonempty) {
17       if (sizes[i] != 1 && strides[i] != expected_stride) {
18         contig_if_nonempty = false;
19       }
20       expected_stride *= sizes[i];
21     }
22   }
23   return contig_if_nonempty;
24 }
25 
geometry_is_contiguous(IntArrayRef sizes,IntArrayRef strides)26 bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides) {
27   return _geometry_is_contiguous(sizes, strides);
28 }
29 
is_contiguous() const30 bool TensorGeometry::is_contiguous() const {
31   if (numel_ == 0) {
32     return true;
33   }
34   return at::_geometry_is_contiguous<c10::SymInt>(sizes_, strides_);
35 }
36 
37 } // namespace at
38