#ifdef USE_C10D_NCCL

#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/torch.h>
#include <algorithm>
#include <torch/csrc/distributed/c10d/NanCheck.hpp>

namespace c10d {

// CUDA kernel to check whether the data contains any NaN; a device-side
// assert is raised if a NaN is found.
template <typename T>
__global__ void checkForNaN(T* data, size_t size) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  size_t stride = blockDim.x * gridDim.x;

  // Grid-stride loop: each thread checks elements tid, tid + stride, ...,
  // so any launch configuration covers the whole buffer.
  for (size_t i = tid; i < size; i += stride) {
    CUDA_KERNEL_ASSERT(!isnan(data[i]));
  }
}
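
// Note on failure behavior: when CUDA_KERNEL_ASSERT fires, the kernel traps;
// the error typically surfaces as a device-side assert (cudaErrorAssert) on a
// later CUDA API call or stream synchronization, not as a C++ exception here.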

// Check if a tensor contains NaN in any of its elements
void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) {
  // Skip the check for non-floating-point types
  if (!torch::is_floating_point(tensor)) {
    return;
  }
  // Launch configuration: cap at 24 blocks of 256 threads; larger tensors are
  // handled by the grid-stride loop inside the kernel.
  const size_t maxNumThreadsPerBlock = 256;
  const size_t maxNumBlocks = 24;
  const size_t numThreadsPerBlock =
      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

  const size_t numBlocks = std::min<size_t>(
      maxNumBlocks,
      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);
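
  // Worked example (illustrative): for a tensor with 1,000,000 elements,
  // numThreadsPerBlock = min(256, 1000000) = 256 and
  // numBlocks = min(24, (1000000 + 255) / 256) = min(24, 3907) = 24, so the
  // launch runs 24 * 256 = 6144 threads and the grid-stride loop covers the
  // remaining elements.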

  // Dispatch over floating-point dtypes (including Half and BFloat16) and
  // launch the check kernel on the caller-provided stream.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      tensor.scalar_type(),
      "checkForNaN",
      [&] {
        checkForNaN<scalar_t><<<numBlocks, numThreadsPerBlock, 0, stream>>>(
            tensor.data_ptr<scalar_t>(), tensor.numel());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}
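
// Usage sketch (illustrative; the caller, tensor, and stream here are
// assumptions, not part of this file): a collective backend such as
// ProcessGroupNCCL can run the check on the stream it will use for
// communication, so the kernel is ordered before the collective:
//
//   at::Tensor input = ...;                                         // assumed
//   at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); // assumed
//   c10d::checkForNan(input, stream); // device-side assert fires on any NaN
//   // ... enqueue the NCCL collective on `stream` ...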

} // namespace c10d

#endif // USE_C10D_NCCL