# Owner(s): ["module: cuda"]

import torch
from torch.cuda.jiterator import _create_jit_fn as create_jit_fn
from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn
import sys
from itertools import product
from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_device_type import (
    skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol)

if not TEST_CUDA:
    print('CUDA not available, skipping tests', file=sys.stderr)
    TestCase = NoTest  # noqa: F811


code_string = "template <typename T> T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1, beta=1)

def ref_fn(x, y, alpha=1, beta=1):
    return alpha * x + beta * y

class TestPythonJiterator(TestCase):
    @parametrize("shape_strides", [
        (([3, 3], [3, 1]), ([3, 3], [3, 1])),  # contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_contiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details
    # On CUDA 11.3, nvrtcCompileProgram takes too long to compile jiterator-generated kernels
    # for non-contiguous inputs that require dynamic casting.
    @skipCUDAIfVersionLessThan((11, 6))
    @parametrize("shape_strides", [
        (([3, 3], [1, 3]), ([3, 1], [1, 3])),  # non-contiguous
    ])
    @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16),
                     all_types_and_complex_and(torch.half, torch.bfloat16)))
    def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides):
        a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0])
        b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1])

        a = a_buffer.as_strided(*shape_strides[0])
        b = b_buffer.as_strided(*shape_strides[1])

        expected = ref_fn(a, b)
        result = jitted_fn(a, b)

        self.assertEqual(expected, result)

    @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
    @parametrize("alpha", [-1, 2.0, None])
    @parametrize("beta", [3, -4.2, None])
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=1e-3)})
    def test_extra_args(self, device, dtype, alpha, beta):
        a = torch.rand(3, device=device).mul(10).type(dtype)
        b = torch.rand(3, device=device).mul(10).type(dtype)

        extra_args = {}
        if alpha is not None:
            extra_args["alpha"] = alpha
        if beta is not None:
            extra_args["beta"] = beta

        expected = ref_fn(a, b, **extra_args)
        result = jitted_fn(a, b, **extra_args)

        self.assertEqual(expected, result)

    @parametrize("is_train", [True, False])
    def test_bool_extra_args(self, device, is_train):
        code_string = "template <typename T> T conditional(T x, T mask, bool is_train) { return is_train ? x * mask : x; }"
        jitted_fn = create_jit_fn(code_string, is_train=False)

        def ref_fn(x, mask, is_train):
            return x * mask if is_train else x

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)

        expected = ref_fn(a, b, is_train=is_train)
        result = jitted_fn(a, b, is_train=is_train)
        self.assertEqual(expected, result)

    def test_multiple_functors(self, device):
        code_string = '''
        template <typename T> T fn(T x, T mask) { return x * mask; }
        template <typename T> T main_fn(T x, T mask, T y) { return fn(x, mask) + y; }
        '''
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(x, mask, y):
            return x * mask + y

        a = torch.rand(3, device=device)
        b = torch.rand(3, device=device)
        c = torch.rand(3, device=device)

        expected = ref_fn(a, b, c)
        result = jitted_fn(a, b, c)
        self.assertEqual(expected, result)

    @parametrize("num_inputs", [1, 5, 8])
    def test_various_num_inputs(self, num_inputs):
        inputs = []
        for i in range(num_inputs):
            inputs.append(torch.rand(3, device='cuda').mul(10))

        input_string = ",".join([f"T i{i}" for i in range(num_inputs)])
        function_body = "+".join([f"i{i}" for i in range(num_inputs)])
        code_string = f"template <typename T> T my_kernel({input_string}) {{ return {function_body}; }}"
        jitted_fn = create_jit_fn(code_string)

        def ref_fn(*inputs):
            return torch.sum(torch.stack(inputs), dim=0)

        expected = ref_fn(*inputs)
        result = jitted_fn(*inputs)

        self.assertEqual(expected, result)

    @parametrize("num_outputs", [1, 4, 8])
    def test_various_num_outputs(self, num_outputs):
        input = torch.rand(3, device='cuda')

        output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)])
        function_body = ""
        for i in range(num_outputs):
            function_body += f"out{i} = input + {i};\n"
        # NB: the return type must be void, otherwise ROCm silently fails
        code_string = f"template <typename T> void my_kernel(T input, {output_string}) {{ {function_body} }}"

        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs)

        def ref_fn(input):
            outputs = []
            for i in range(num_outputs):
                outputs.append(input + i)

            if num_outputs == 1:
                return outputs[0]
            return tuple(outputs)

        expected = ref_fn(input)
        result = jitted_fn(input)

        for i in range(num_outputs):
            self.assertEqual(expected[i], result[i])

    @parametrize("code_string", [
        "template <typename T> T my _kernel(T x) { return x; }",
        "template <typename T> Tmy_kernel(T x) { return x; }",
    ])
    def test_invalid_function_name(self, code_string):
        with self.assertRaises(Exception):
            jitted_fn = create_jit_fn(code_string)


instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda")

if __name__ == '__main__':
    run_tests()