#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOps.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAException.h>

#include <thrust/swap.h>

#include <limits>

namespace at::native {

namespace {
void addr_kernel_cuda(TensorIterator& iter, const Scalar& beta, const Scalar& alpha) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto beta_val = beta.to<scalar_t>();
    auto alpha_val = alpha.to<scalar_t>();

    // For bool tensors the affine combination degenerates to logical ops:
    // out = (beta && self) || (alpha && vec1 && vec2).
    // When beta is false, values in self should be ignored
    // (mirroring the generic path, where nans and infs in self
    // must not propagate when beta == 0).
    if (beta_val == false) {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (scalar_t self_val,
                        scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
          return alpha_val && vec1_val && vec2_val;
        }
      );
    } else {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (scalar_t self_val,
                        scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
          return (beta_val && self_val) || (alpha_val && vec1_val && vec2_val);
        }
      );
    }
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf,
    iter.dtype(), "addr_cuda", [&] {
      auto beta_val = beta.to<scalar_t>();
      auto alpha_val = alpha.to<scalar_t>();

      scalar_t zero_val(0);
      // When beta == 0, values in self should be ignored;
      // nans and infs in self should not propagate.
      if (beta_val == zero_val) {
        gpu_kernel(
          iter,
          [=] GPU_LAMBDA (scalar_t self_val,
                          scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
            return alpha_val * vec1_val * vec2_val;
          }
        );
      } else {
        gpu_kernel(
          iter,
          [=] GPU_LAMBDA (scalar_t self_val,
                          scalar_t vec1_val, scalar_t vec2_val) -> scalar_t {
            return beta_val * self_val + alpha_val * vec1_val * vec2_val;
          }
        );
      }
  });
}
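
// Illustrative sketch (not part of the dispatched code): for the non-bool path
// above, each output element of addr is computed over the broadcasted outer
// product as
//
//   out[i][j] = beta * self[i][j] + alpha * vec1[i] * vec2[j]
//
// e.g. with hypothetical values beta = 1, alpha = 2, vec1 = {1, 3},
// vec2 = {4, 5} and self = zeros, the result is {{8, 10}, {24, 30}}.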


// Elementwise launcher body: each block of n_threads threads processes
// n_elems_per_thread elements per thread, strided by n_threads so that
// consecutive threads touch consecutive elements on every iteration.
template <int n_threads, int n_elems_per_thread, typename func_t>
C10_LAUNCH_BOUNDS_2(n_threads, n_elems_per_thread)
__global__ void _elementwise_kernel(int total_n_elems, func_t f) {
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  int idx = total_work_block * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < n_elems_per_thread; ++i) {
    if (idx < total_n_elems) {
      f(idx);
      idx += n_threads;
    }
  }
}
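
// Illustrative sketch of the indexing above (hypothetical template arguments):
// with n_threads = 128 and n_elems_per_thread = 4, each block covers
// total_work_block = 512 elements, and thread t of block b visits
//
//   512 * b + t, 512 * b + t + 128, 512 * b + t + 256, 512 * b + t + 384
//
// (skipping any index >= total_n_elems), so global memory accesses from a warp
// stay coalesced.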

template <int n_threads, int n_elems_per_thread, typename func_t>
static void _launch_kernel(int total_n_elems, func_t f) {
  // The kernel indexes elements with a 32-bit int, so the count must fit.
  TORCH_INTERNAL_ASSERT(
    total_n_elems >= 0 && total_n_elems <= std::numeric_limits<int32_t>::max()
  );

  dim3 block(n_threads);
  constexpr int total_work_block = n_threads * n_elems_per_thread;
  // Round the grid size up so every element is covered.
  dim3 grid((total_n_elems + total_work_block - 1) / total_work_block);

  auto stream = at::cuda::getCurrentCUDAStream();
  _elementwise_kernel<n_threads, n_elems_per_thread, func_t>
    <<<grid, block, 0, stream>>>(total_n_elems, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
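
// Hypothetical usage sketch (not compiled here): any __device__-callable functor
// taking a flat element index can be handed to the launcher, e.g.
//
//   int64_t* out_ptr = ...;  // hypothetical device pointer with n elements
//   auto fill_iota = [=] C10_DEVICE (const int idx) {
//     out_ptr[idx] = idx;
//   };
//   _launch_kernel<num_threads(), thread_work_size()>(n, fill_iota);
//
// which is exactly how unpack_pivots_cuda_kernel below passes its `loop` lambda.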

void unpack_pivots_cuda_kernel(TensorIterator& iter, const int64_t dim_size, const int64_t max_pivot) {
  if (iter.numel() == 0) {
    return;
  }

  // Split into sub-iterators until the offsets fit into 32-bit indexing.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      unpack_pivots_cuda_kernel(sub_iter, dim_size, max_pivot);
    }
    return;
  }

  const auto offset_calculator = make_offset_calculator<2>(iter);

  const auto perm_ptr = reinterpret_cast<char*>(iter.data_ptr(0));
  const auto pivots_ptr = reinterpret_cast<const char*>(iter.data_ptr(1));

  auto loop = [=]C10_DEVICE(const int idx) {
    const auto offsets = offset_calculator.get(idx);

    int64_t* const __restrict__ perm_data = reinterpret_cast<int64_t*>(perm_ptr + offsets[0]);
    const int32_t* const __restrict__ pivots_data = reinterpret_cast<const int32_t*>(pivots_ptr + offsets[1]);

    // QUESTION: can we mix 64bit offsets with 32bit Iterator indexing?
    // Pivots follow the 1-based LAPACK convention, hence the -1 below.
    for (int64_t i = 0; i < dim_size; ++i) {
      thrust::swap(
        perm_data[i],
        perm_data[pivots_data[i] - 1]
      );
    }
  };

  _launch_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
}
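
// Illustrative sketch (hypothetical values): LAPACK-style pivots are 1-based and
// mean "row i was swapped with row pivots[i] - 1". Replaying them in order on a
// perm buffer that starts out as the identity recovers the row permutation, e.g.
// with dim_size = 3, pivots = {2, 3, 3} and perm = {0, 1, 2}:
//
//   i = 0: swap(perm[0], perm[1]) -> {1, 0, 2}
//   i = 1: swap(perm[1], perm[2]) -> {1, 2, 0}
//   i = 2: swap(perm[2], perm[2]) -> {1, 2, 0}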
} // anonymous namespace

REGISTER_DISPATCH(unpack_pivots_stub, &unpack_pivots_cuda_kernel);
REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda);
} // namespace at::native