#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <math.h>

#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/Distance.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#endif

#include <c10/macros/Macros.h>

namespace at::native {

namespace {

constexpr int kCUDANumThreads = 256;

template <typename scalar_t>
struct dists {

  static __forceinline__ __device__ scalar_t sign(scalar_t val) {
    return (0 < val) - (val < 0);
  }

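  // Each norm below implements the same interface: inc() accumulates one
  // |x - y| term into a running aggregate, finish() turns the aggregate into
  // the final distance, agg() merges partial aggregates across threads, and
  // backward() (where present) gives the per-element gradient contribution.
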
  // Zero norm
  struct zero {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff != 0.0; }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
  };

  // One norm
  struct one {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff; }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t /*dist*/, const scalar_t /*p*/) { return grad * sign(diff); }
  };

  // Special case backward when p is less than two
  struct lt_two {
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) {
      return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1));
    }
  };

  // Two norm
  struct two {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff * diff; }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return device_sqrt<scalar_t>(agg); }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t /*p*/) { return dist == 0.0 ? 0 : grad * diff / dist; }
  };

  // General p norm
  struct p {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += std::pow(diff, p); }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, static_cast<scalar_t>(1) / p); }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : diff * std::pow(std::abs(diff), p - 2) * grad / std::pow(dist, p - 1); }
  };

  // Inf norm
  struct inf {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { if (diff > agg) { agg = diff; } }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { if (other > update) { update = other; } }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t /*p*/) { return grad * sign(diff) * (std::abs(diff) == dist); }
  };

};

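// Adapter so cuda_utils::BlockReduce can use a norm's agg() as its pairwise
// combine step; warp_shfl_down handles the intra-warp shuffle reduction.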
template <typename scalar_t, typename F>
struct DistReduceOp {
    __forceinline__ __device__ scalar_t combine(scalar_t a, scalar_t b) const {
        F::agg(a, b);
        return a;
    }

    __forceinline__ __device__ scalar_t warp_shfl_down(scalar_t data, int offset) const {
        return WARP_SHFL_DOWN(data, offset);
    }
};

template <typename scalar_t, typename F>
__global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
                                              const double n2, const double n2_squared_minus_1) {
  const int64_t k = blockIdx.x;
  const int stride = blockDim.x;

  // The -1 accounts for floating point truncation issues
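  // Invert the condensed pair index k into upper-triangular coordinates
  // (i, j) with i < j, using the closed-form solution of the triangular
  // numbering (n2 = n - 0.5, n2_squared_minus_1 = n2 * n2 - 1).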
  int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k)));
  int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;

  const scalar_t * const start = self + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * a = start + threadIdx.x;
  const scalar_t * b = self + j * m + threadIdx.x;
  scalar_t agg = 0.0;
  for (; a < end; a += stride, b += stride) {
    F::inc(agg, std::abs(*a - *b), p);
  }

  __shared__ scalar_t agg_smem[kCUDANumThreads];
  scalar_t agg_init{0.0};
  agg = cuda_utils::BlockReduce(agg, DistReduceOp<scalar_t, F>{}, agg_init, agg_smem);
  if (threadIdx.x == 0) {
    result[k] = F::finish(agg, p);
  }
}

template <typename scalar_t, typename F>
__global__ static void cdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * x1, const scalar_t * x2, const scalar_t * dist,
                                                       const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
  const int y = (blockIdx.y * gridDim.z + blockIdx.z) * blockDim.y + threadIdx.y;
  const int init = blockIdx.x * blockDim.x + threadIdx.x;
  if (y >= count || init >= m) {
    return;
  }
  const int l = y / r_size;
  const int k = y % r_size;
  const int stride = blockDim.x * gridDim.x;
  const int l_size = r_size * m;

  int64_t i = k / r2;
  int64_t j = k % r2;

  const scalar_t grad_k = grad[y];
  const scalar_t dist_k = dist[y];

  const scalar_t * const start = x1 + l * l1_size + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * self_i = start + init;
  const scalar_t * self_j = x2 + l * l2_size + j * m + init;

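  // buffer has shape {batch, r2, r1, m}; the (i, j) pair of batch l writes its
  // contribution into slab [l][j][i][:], which the host later reduces over the
  // r2 dimension (dim 1).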
  scalar_t * buff_i = buffer + l * l_size + (r1 * j + i) * m + init;

  for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride) {
    const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
    *buff_i = res;
  }
}

template <typename scalar_t, typename F>
__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p,
                                                       const double n2, const double n2_squared_minus_1) {
  const int64_t k = blockIdx.x * blockDim.x + threadIdx.x;
  const int init = blockIdx.y * blockDim.y + threadIdx.y;
  const int stride = blockDim.y * gridDim.y;

  if (k >= combs) {
    return;
  }

  // The -1 accounts for floating point truncation issues
  int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k)));
  int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
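  // ib and jb select disjoint slabs of the {n - 1, n, m} buffer for this
  // pair's contributions to rows i and j, so no atomics are needed; the host
  // sums the buffer over dim 0 afterwards.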
  int64_t ib = j - i - 1;
  int64_t jb = n - 2 - i;

  const scalar_t grad_k = grad[k * gs];
  const scalar_t dist_k = dist[k];

  const scalar_t * const start = self + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * self_i = start + init;
  const scalar_t * self_j = self + j * m + init;
  scalar_t * buff_i = buffer + (ib * n + i) * m + init;
  scalar_t * buff_j = buffer + (jb * n + j) * m + init;
  for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) {
    const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
    *buff_i = res;
    *buff_j = -res;
  }
}

template <typename scalar_t, typename F>
__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2,
    const scalar_t p, const int64_t r2, const int64_t m, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
  const int64_t l = blockIdx.x / r_size;
  const int64_t k = blockIdx.x % r_size;
  const int64_t i = k / r2;
  const int64_t j = k % r2;
  const int stride = blockDim.x;

  const scalar_t * const start = x1 + l * l1_size + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * a = start + threadIdx.x;
  const scalar_t * b = x2 + l * l2_size + j * m + threadIdx.x;

  scalar_t agg = 0.0;
  for (; a < end; a += stride, b += stride) {
    F::inc(agg, std::abs(*a - *b), p);
  }
  __shared__ scalar_t agg_smem[kCUDANumThreads];
  scalar_t agg_init{0.0};
  agg = cuda_utils::BlockReduce(agg, DistReduceOp<scalar_t, F>{}, agg_init, agg_smem);
  if (threadIdx.x == 0) {
    result[blockIdx.x] = F::finish(agg, p);
  }
}

void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
  const int64_t r1 = x1.size(-2);
  const int64_t r2 = x2.size(-2);
  const int64_t m = x1.size(-1);
  const int64_t r_size = r1 * r2;
  const int64_t l1_size = r1 * m;
  const int64_t l2_size = r2 * m;
  const dim3 grid(result.numel());
  const dim3 block(kCUDANumThreads);

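  // Specialize the kernel for the common values of p (0, 1, 2, inf);
  // everything else falls back to the general p-norm functor.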
  AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] {
    auto impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 0.0) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero>;
    } else if (p == 1.0) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p == 2.0) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), x1.const_data_ptr<scalar_t>(), x2.const_data_ptr<scalar_t>(), p, r2, m, r_size, l1_size, l2_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
  const dim3 grid(result.numel());
  const dim3 block(kCUDANumThreads);
  int64_t n = self.size(0);
  int64_t m = self.size(1);
  // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
  // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device.
  const double n2 = n - .5;
  const double n2_squared_minus_1 = n2 * n2 - 1;

  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] {
    auto impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 0.0) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero>;
    } else if (p == 1.0) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p == 2.0) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
  if (p == 0.0 || grad.numel() == 0 || self.numel() == 0) {
    result.fill_(0);
    return;
  }

  const int64_t n = result.size(0);
  int64_t m = self.size(1);
  const int block_x = 16;
  // NB: be careful with changing block_y; as it's currently written, grid_y is limited to be 2^16.
  // block_y of 64 gives us max pdist dim1 of 2**24
  const int block_y = 64;
  const int grid_x = (dist.numel() + block_x - 1) / block_x;
  const int grid_y = (m + block_y * 8 - 1) / (block_y * 8);
  const dim3 grid(grid_x, grid_y);
  const dim3 block(block_x, block_y);
  // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
  // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device.
  const double n2 = n - .5;
  const double n2_squared_minus_1 = n2 * n2 - 1;

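  // Each pair writes its two per-row contributions into disjoint slots of this
  // {n - 1, n, m} buffer (see ib/jb in the kernel); the final gradient is the
  // sum over the first dimension.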
  Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options());
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] {
    auto impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 1.0) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p < 2.0) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two>;
    } else if (p == 2.0) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(buffer.mutable_data_ptr<scalar_t>(), grad.const_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(), dist.const_data_ptr<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  at::sum_out(result, buffer, 0);
}

void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
  if (p == 0.0 || grad.numel() == 0 || x1.numel() == 0 || x2.numel() == 0) {
    result.fill_(0);
    return;
  }

  const int64_t r1 = x1.size(-2);
  const int64_t r2 = x2.size(-2);
  const int64_t m = x1.size(-1);
  // Just like we do in the CPU code, assume that result is always batched
  int64_t batch = result.size(0);
  const int block_x = 64;
  const int block_y = 16;
  const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);

  const int64_t count = dist.numel();
  const int64_t grid_temp = (count + block_y - 1) / block_y;

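  // gridDim.y is capped at 65535, so the linear block index over (batch, pair)
  // is split across grid.y and grid.z and recombined inside the kernel.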
  const int grid_y = (grid_temp - 1) / 65535 + 1;
  const int grid_z = (grid_temp - 1) / grid_y + 1;

  const dim3 grid(grid_x, grid_y, grid_z);
  const dim3 block(block_x, block_y);

  const int64_t r_size = r1 * r2;
  const int64_t l1_size = r1 * m;
  const int64_t l2_size = r2 * m;
  // The current implementation supports only gradients that can be collapsed to 1D. However, to avoid checking this
  // assumption, we call grad.contiguous() before backward, so the stride is guaranteed to be 1.

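  // Per-element gradients are accumulated into this buffer without atomics and
  // then reduced with at::sum_out over dim 1 to form the gradient w.r.t. x1.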
  Tensor buffer = at::empty({batch, r2, r1, m}, result.options());
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] {
    auto impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 1.0) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p < 2.0) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two>;
    } else if (p == 2.0) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(buffer.mutable_data_ptr<scalar_t>(),
      grad.const_data_ptr<scalar_t>(), x1.const_data_ptr<scalar_t>(), x2.const_data_ptr<scalar_t>(), dist.const_data_ptr<scalar_t>(),
      p, r1, r2, m, count, r_size, l1_size, l2_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  at::sum_out(result, buffer, 1);
}

} // anonymous namespace

REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);

} // at::native