# Owner(s): ["module: cuda"] import torch from torch.cuda.jiterator import _create_jit_fn as create_jit_fn from torch.cuda.jiterator import _create_multi_output_jit_fn as create_multi_output_jit_fn import sys from itertools import product from torch.testing._internal.common_utils import TestCase, parametrize, run_tests, TEST_CUDA, NoTest from torch.testing._internal.common_dtype import all_types_and_complex_and from torch.testing._internal.common_device_type import ( skipCUDAIfVersionLessThan, instantiate_device_type_tests, dtypes, toleranceOverride, tol) if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) TestCase = NoTest # noqa: F811 code_string = "template T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }" jitted_fn = create_jit_fn(code_string, alpha=1, beta=1) def ref_fn(x, y, alpha=1, beta=1): return alpha * x + beta * y class TestPythonJiterator(TestCase): @parametrize("shape_strides", [ (([3, 3], [3, 1]), ([3, 3], [3, 1])), # contiguous ]) @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16), all_types_and_complex_and(torch.half, torch.bfloat16))) def test_all_dtype_contiguous(self, device, dtypes, shape_strides): a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0]) b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1]) a = a_buffer.as_strided(*shape_strides[0]) b = b_buffer.as_strided(*shape_strides[1]) expected = ref_fn(a, b) result = jitted_fn(a, b) self.assertEqual(expected, result) # See https://github.com/pytorch/pytorch/pull/76394#issuecomment-1118018287 for details # On cuda 11.3, nvrtcCompileProgram is taking too long to # compile jiterator generated kernels for non-contiguous input that requires dynamic-casting. @skipCUDAIfVersionLessThan((11, 6)) @parametrize("shape_strides", [ (([3, 3], [1, 3]), ([3, 1], [1, 3])), # non-contiguous ]) @dtypes(*product(all_types_and_complex_and(torch.half, torch.bfloat16), all_types_and_complex_and(torch.half, torch.bfloat16))) def test_all_dtype_noncontiguous(self, device, dtypes, shape_strides): a_buffer = torch.rand(9, device=device).mul(10).type(dtypes[0]) b_buffer = torch.rand(9, device=device).mul(10).type(dtypes[1]) a = a_buffer.as_strided(*shape_strides[0]) b = b_buffer.as_strided(*shape_strides[1]) expected = ref_fn(a, b) result = jitted_fn(a, b) self.assertEqual(expected, result) @dtypes(torch.float, torch.double, torch.float16, torch.bfloat16) @parametrize("alpha", [-1, 2.0, None]) @parametrize("beta", [3, -4.2, None]) @toleranceOverride({torch.float16 : tol(atol=1e-2, rtol=1e-3)}) def test_extra_args(self, device, dtype, alpha, beta): a = torch.rand(3, device=device).mul(10).type(dtype) b = torch.rand(3, device=device).mul(10).type(dtype) extra_args = {} if alpha is not None: extra_args["alpha"] = alpha if beta is not None: extra_args["beta"] = beta expected = ref_fn(a, b, **extra_args) result = jitted_fn(a, b, **extra_args) self.assertEqual(expected, result) @parametrize("is_train", [True, False]) def test_bool_extra_args(self, device, is_train): code_string = "template T conditional(T x, T mask, bool is_train) { return is_train ? 
x * mask : x; }" jitted_fn = create_jit_fn(code_string, is_train=False) def ref_fn(x, mask, is_train): return x * mask if is_train else x a = torch.rand(3, device=device) b = torch.rand(3, device=device) expected = ref_fn(a, b, is_train=is_train) result = jitted_fn(a, b, is_train=is_train) self.assertEqual(expected, result) def test_multiple_functors(self, device): code_string = ''' template T fn(T x, T mask) { return x * mask; } template T main_fn(T x, T mask, T y) { return fn(x, mask) + y; } ''' jitted_fn = create_jit_fn(code_string) def ref_fn(x, mask, y): return x * mask + y a = torch.rand(3, device=device) b = torch.rand(3, device=device) c = torch.rand(3, device=device) expected = ref_fn(a, b, c) result = jitted_fn(a, b, c) self.assertEqual(expected, result) @parametrize("num_inputs", [1, 5, 8]) def test_various_num_inputs(self, num_inputs): inputs = [] for i in range(num_inputs): inputs.append(torch.rand(3, device='cuda').mul(10)) input_string = ",".join([f"T i{i}" for i in range(num_inputs)]) function_body = "+".join([f"i{i}" for i in range(num_inputs)]) code_string = f"template T my_kernel({input_string}) {{ return {function_body}; }}" jitted_fn = create_jit_fn(code_string) def ref_fn(*inputs): return torch.sum(torch.stack(inputs), dim=0) expected = ref_fn(*inputs) result = jitted_fn(*inputs) self.assertEqual(expected, result) @parametrize("num_outputs", [1, 4, 8]) def test_various_num_outputs(self, num_outputs): input = torch.rand(3, device='cuda') output_string = ", ".join([f"T& out{i}" for i in range(num_outputs)]) function_body = "" for i in range(num_outputs): function_body += f"out{i} = input + {i};\n" # NB: return type must be void, otherwise ROCm silently fails code_string = f"template void my_kernel(T input, {output_string}) {{ {function_body} }}" jitted_fn = create_multi_output_jit_fn(code_string, num_outputs) def ref_fn(input): outputs = [] for i in range(num_outputs): outputs.append(input + i) if num_outputs == 1: return outputs[0] return tuple(outputs) expected = ref_fn(input) result = jitted_fn(input) for i in range(num_outputs): self.assertEqual(expected[i], result[i]) @parametrize("code_string", [ "template T my _kernel(T x) { return x; }", "template Tmy_kernel(T x) { return x; }", ]) def test_invalid_function_name(self, code_string): with self.assertRaises(Exception): jitted_fn = create_jit_fn(code_string) instantiate_device_type_tests(TestPythonJiterator, globals(), only_for="cuda") if __name__ == '__main__': run_tests()