1 #include <ATen/cuda/detail/IndexUtils.cuh> 2 #include <vector> 3 4 namespace at { 5 namespace cuda { 6 namespace detail { 7 8 struct SizeAndStride { 9 int64_t size; 10 int64_t stride; 11 }; 12 13 /* 14 A comparator that will sort SizeAndStride structs by stride, 15 in ascending order. 16 */ compareSizeAndStride(const void * a,const void * b)17 int compareSizeAndStride(const void* a, const void* b) { 18 const SizeAndStride* aS = (const SizeAndStride*) a; 19 const SizeAndStride* bS = (const SizeAndStride*) b; 20 21 if (aS->stride < bS->stride) return -1; 22 if (aS->stride == bS->stride) return 0; 23 return 1; 24 } 25 26 /* 27 Returns false if there is no possibility that the tensor 28 has "overlapping" indices and true otherwise. 29 "Overlapping" indices are two+ valid indices that specify 30 the same offset within the tensor. 31 The function does this by checking for a sufficient but not 32 necessary condition of no overlap. In particular, that 33 that there exists an ordering of the tensor's dimensions 34 that is nicely "nested," with each dimension contained 35 within the next one. 36 */ maybeOverlappingIndices(const TensorBase & t)37bool maybeOverlappingIndices(const TensorBase& t) { 38 /* Extract size/stride arrays; only consider size >1 dims. */ 39 std::vector<SizeAndStride> info(t.dim()); 40 int dims = t.dim(); 41 int nonSize1Dims = 0; 42 for (int i = 0; i < dims; ++i) { 43 int64_t size = t.size(i); 44 if (size > 1) { 45 info[nonSize1Dims].size = size; 46 info[nonSize1Dims].stride = t.stride(i); 47 48 if (info[nonSize1Dims].stride < 1) { 49 return true; 50 } 51 52 ++nonSize1Dims; 53 } 54 } 55 56 // Short-circuits if tensor is a single element. 57 if (nonSize1Dims == 0) { 58 return false; 59 } 60 61 /* Ascending order (innermost dimension in sorted view is at [0]) */ 62 qsort(info.data(), nonSize1Dims, sizeof(SizeAndStride), compareSizeAndStride); 63 64 for (int i = 0; i < (nonSize1Dims - 1); ++i) { 65 if (((info[i].size - 1) * info[i].stride) >= info[i + 1].stride) { 66 return true; 67 } 68 } 69 70 return false; 71 } 72 73 } // detail 74 } // cuda 75 } // at 76