# Owner(s): ["module: cuda"]

import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
    skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol)

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811


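# create_jit_fn builds an elementwise CUDA kernel from this C++ functor string; the kernel
# is JIT-compiled on first use, and alpha/beta become keyword arguments with these defaults.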
code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)

def ref_fn(x, y, alpha=1, beta=1):
    return alpha * x + beta * y

class TestPythonJiterator(TestCase):
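    # Each shape_strides entry is a pair of (size, stride) argument lists; the flat input
    # buffers are reinterpreted via as_strided(size, stride) before being passed to the kernel.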
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    # On CUDA 11.3, nvrtcCompileProgram takes too long to compile jiterator-generated
    # kernels for non-contiguous inputs that require dynamic casting.
    @skipCUDAIfVersionLessThan((11, 6))
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

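        # Only forward the overrides that were actually given; None means "use the
        # alpha=1 / beta=1 defaults baked into the jitted kernel".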
        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=is_train)
        result = jitted_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

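    # A code string may define helper functors; jiterator uses the last function defined
    # as the kernel entry point, so main_fn is what the jitted function invokes.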
    def test_multiple_functors(self, device):
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_fn(a, b, c)
        result = jitted_fn(a, b, c)
        self.assertEqual(expected, result)

    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

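        # Generate a kernel of the form:
        #   template <typename T> T my_kernel(T i0, T i1, ...) { return i0+i1+...; }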
        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        input = torch.rand(3, device='cuda')

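        # Generate a kernel that writes each result through a reference parameter, e.g.
        #   template <typename T> void my_kernel(T input, T& out0, T& out1, ...) { out0 = input + 0; ... }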
        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = ""
        for i in range(num_outputs):
            function_body += f"out{i} = input + {i};\n"
        # NB: return type must be void, otherwise ROCm silently fails
        code_string = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"

        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs)

        def ref_fn(input):
            outputs = []
            for i in range(num_outputs):
                outputs.append(input + i)

            if num_outputs == 1:
                return outputs[0]
            return tuple(outputs)

        expected = ref_fn(input)
        result = jitted_fn(input)

        for i in range(num_outputs):
            self.assertEqual(expected[i], result[i])

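    # Both code strings are malformed: the first has a space inside the function name and
    # the second is missing the space between the return type and the name, so
    # create_jit_fn is expected to raise.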
    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)


instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == '__main__':
    run_tests()