xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAAlgorithm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
2*da0073e9SAndroid Build Coastguard Worker #include <thrust/binary_search.h>
3*da0073e9SAndroid Build Coastguard Worker #include <thrust/device_vector.h>
4*da0073e9SAndroid Build Coastguard Worker #include <thrust/execution_policy.h>
5*da0073e9SAndroid Build Coastguard Worker #include <thrust/functional.h>
6*da0073e9SAndroid Build Coastguard Worker #endif
namespace c10::cuda {

// lower_bound(start, end, value)
//
// Device-side binary search over the half-open range [start, end).
//
// Returns: an iterator to the first element `e` for which `e < value` is
// false, i.e. the first position where `value` could be inserted without
// breaking the ordering. Returns `end` if every element compares less than
// `value`, and `start` if the range is empty.
//
// Requirements:
//   - `Iter` supports random-access arithmetic (`end - start`, `start + k`,
//     `mid + 1`) and ordering (`start < end`).
//   - `*it < value` must be a valid expression.
//   - The range must be sorted (more precisely: partitioned) with respect to
//     `operator<` against `value`; otherwise the result is unspecified.
//
// Only callable from device code. Both variants below are declared
// `__forceinline__ __device__`, so callers see the same qualifiers whether or
// not THRUST_DEVICE_LOWER_BOUND_WORKS is defined.
#ifdef THRUST_DEVICE_LOWER_BOUND_WORKS
template <typename Iter, typename Scalar>
__forceinline__ __device__ Iter
lower_bound(Iter start, Iter end, Scalar value) {
  return thrust::lower_bound(thrust::device, start, end, value);
}
#else
// thrust::lower_bound is broken on device, see
// https://github.com/NVIDIA/thrust/issues/1734. Implementation inspired by
// https://github.com/pytorch/pytorch/blob/805120ab572efef66425c9f595d9c6c464383336/aten/src/ATen/native/cuda/Bucketization.cu#L28
template <typename Iter, typename Scalar>
__forceinline__ __device__ Iter
lower_bound(Iter start, Iter end, Scalar value) {
  // Invariant: every element before `start` is < value, and every element
  // from `end` up to the original end is not < value. Each iteration shrinks
  // [start, end) by at least half, so the loop performs O(log(end - start))
  // comparisons.
  while (start < end) {
    // Midpoint computed as start + len/2 (not (start + end)/2): adding two
    // iterators is ill-formed, and this form never exceeds the range.
    auto mid = start + ((end - start) >> 1);
    if (*mid < value) {
      start = mid + 1;
    } else {
      end = mid;
    }
  }
  // On exit start == end, which is the partition point.
  return end;
}
#endif // THRUST_DEVICE_LOWER_BOUND_WORKS
} // namespace c10::cuda
32