# xref: /aosp_15_r20/external/pytorch/test/test_kernel_launch_checks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: unknown"]

from torch.testing._internal.check_kernel_launches import (
    check_code_for_cuda_kernel_launches,
    check_cuda_kernel_launches,
)
from torch.testing._internal.common_utils import TestCase, run_tests
7
8
9class AlwaysCheckCudaLaunchTest(TestCase):
10    def test_check_code(self):
11        """Verifies that the regex works for a few different situations"""
12
13        # Try some different spacings
14        self.assertEqual(2, check_code_for_cuda_kernel_launches("""
15some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
16C10_CUDA_KERNEL_LAUNCH_CHECK();
17some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
18
19some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
20C10_CUDA_KERNEL_LAUNCH_CHECK();
21some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
22some_other_stuff;
23some_function_call<TemplateArg><<<1,2,0,stream>>>(arg1,arg2,arg3);
24C10_CUDA_KERNEL_LAUNCH_CHECK();
25some_function_call<TemplateArg><<<1,2,0,stream>>> (arg1,arg2,arg3);
26C10_CUDA_KERNEL_LAUNCH_CHECK();
27some_function_call<TemplateArg><<<1,2,0,stream>>> ( arg1 , arg2 , arg3 ) ;
28
29    C10_CUDA_KERNEL_LAUNCH_CHECK();
30        """))
31
32        # Does it work for macros?
33        self.assertEqual(0, check_code_for_cuda_kernel_launches(r"""
34#define SOME_MACRO(x) some_function_call<<<1,2>>> ( x ) ;  \
35    C10_CUDA_KERNEL_LAUNCH_CHECK();
36
37#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM)  \
38  indexAddSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
39    <<<smallIndexGrid, smallIndexBlock, 0, stream>>>(                                \
40      selfInfo, sourceInfo, indexInfo,                                               \
41      selfAddDim, sourceAddDim, sliceSize, selfAddDimSize);                          \
42  C10_CUDA_KERNEL_LAUNCH_CHECK();
43        """))
44
45        # Does it work for lambdas?
46        self.assertEqual(1, check_code_for_cuda_kernel_launches(r"""
47            rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
48                    numel,
49                    rng_engine_inputs,
50                    output_data,
51                    input_data,
52                    noise_data,
53                    lower,
54                    upper,
55                    [] __device__ (curandStatePhilox4_32_10_t* state) {
56                    return curand_uniform2_double(state);
57                    });
58                    C10_CUDA_KERNEL_LAUNCH_CHECK();
59
60            rrelu_with_noise_cuda_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
61                    numel,
62                    rng_engine_inputs,
63                    output_data,
64                    input_data,
65                    noise_data,
66                    lower,
67                    upper,
68                    [] __device__ (curandStatePhilox4_32_10_t* state) {
69                    return curand_uniform2_double(state);
70                    });
71                    uh oh;
72                    C10_CUDA_KERNEL_LAUNCH_CHECK();
73        """))
74
75    def test_check_cuda_launches(self):
76        unsafeLaunchesCount = check_cuda_kernel_launches()
77        self.assertTrue(unsafeLaunchesCount == 0)
78
79
if __name__ == '__main__':
    # Run this module's tests through common_utils.run_tests when the file is
    # executed directly as a script.
    run_tests()