xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/detail/IndexUtils.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <algorithm>
#include <vector>
3 
4 namespace at {
5 namespace cuda {
6 namespace detail {
7 
// One tensor dimension's extent and stride (as reported by t.size(i) and
// t.stride(i)), kept together so dimensions can be reordered as a unit.
struct SizeAndStride {
  int64_t size, stride;
};
12 
13 /*
14  A comparator that will sort SizeAndStride structs by stride,
15  in ascending order.
16  */
compareSizeAndStride(const void * a,const void * b)17  int compareSizeAndStride(const void* a, const void* b) {
18   const SizeAndStride* aS = (const SizeAndStride*) a;
19   const SizeAndStride* bS = (const SizeAndStride*) b;
20 
21   if (aS->stride < bS->stride) return -1;
22   if (aS->stride == bS->stride) return 0;
23   return 1;
24 }
25 
26 /*
27 Returns false if there is no possibility that the tensor
28 has "overlapping" indices and true otherwise.
29 "Overlapping" indices are two+ valid indices that specify
30 the same offset within the tensor.
31 The function does this by checking for a sufficient but not
32 necessary condition of no overlap. In particular, that
33 that there exists an ordering of the tensor's dimensions
34 that is nicely "nested," with each dimension contained
35 within the next one.
36 */
maybeOverlappingIndices(const TensorBase & t)37 bool maybeOverlappingIndices(const TensorBase& t) {
38   /* Extract size/stride arrays; only consider size >1 dims. */
39   std::vector<SizeAndStride> info(t.dim());
40   int dims = t.dim();
41   int nonSize1Dims = 0;
42   for (int i = 0; i < dims; ++i) {
43     int64_t size = t.size(i);
44     if (size > 1) {
45       info[nonSize1Dims].size = size;
46       info[nonSize1Dims].stride = t.stride(i);
47 
48       if (info[nonSize1Dims].stride < 1) {
49         return true;
50       }
51 
52       ++nonSize1Dims;
53     }
54   }
55 
56   // Short-circuits if tensor is a single element.
57   if (nonSize1Dims == 0) {
58     return false;
59   }
60 
61   /* Ascending order (innermost dimension in sorted view is at [0]) */
62   qsort(info.data(), nonSize1Dims, sizeof(SizeAndStride), compareSizeAndStride);
63 
64   for (int i = 0; i < (nonSize1Dims - 1); ++i) {
65     if (((info[i].size - 1) * info[i].stride) >= info[i + 1].stride) {
66       return true;
67     }
68   }
69 
70   return false;
71 }
72 
73 } // detail
74 } // cuda
75 } // at
76