1# Owner(s): ["module: inductor"] 2 3import os 4import re 5import unittest 6 7import torch 8from torch import nn 9from torch._dynamo.testing import reset_rng_state 10from torch._inductor import config, test_operators 11from torch._inductor.codegen.multi_kernel import MultiKernelCall 12from torch._inductor.test_case import TestCase 13from torch._inductor.utils import run_and_get_code 14from torch.nn import functional as F 15from torch.testing import make_tensor 16from torch.testing._internal.common_utils import ( 17 instantiate_parametrized_tests, 18 parametrize, 19 skipIfXpu, 20) 21from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU 22 23 24class TransformerSnippet(nn.Module): 25 def __init__(self) -> None: 26 super().__init__() 27 self.ln1 = nn.LayerNorm(64) 28 self.ln2 = nn.LayerNorm(64) 29 30 def forward(self, x1, x2): 31 x1 = F.dropout(x1, 0.1) 32 x2 = F.dropout(self.ln1(x2), 0.1) 33 34 return self.ln2(x1 + x2) 35 36 def example_inputs(self): 37 return (torch.randn(2, 64).to(GPU_TYPE), torch.randn(2, 64).to(GPU_TYPE)) 38 39 40def _contains_multi_kernel_code(wrapper_code: str): 41 return ( 42 re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code) 43 is not None 44 ) 45 46 47def make_cpp_wrapper_test(orig_test, **extra_args): 48 """ 49 Wrap an existing test into a new test with cpp-wrapper enabled. 50 51 Make this as a free function rather than staticmethod in MultiKernelTest. 52 Otherwise we get 'TypeError: 'staticmethod' object is not callable' 53 error in py3.8. (py3.10 works) 54 """ 55 56 @config.patch("cpp_wrapper", True) 57 @skipIfXpu(msg="cpp wrapper doesn't currently work on the XPU stack") 58 def fn(self): 59 # The same kernel may have been compiled by previous tests with 60 # cpp_wrapper disabled. Clear the cache so we go ahead to re-compile 61 # the kernel with cpp_wrapper enabled. 62 from torch._inductor import codecache 63 64 codecache.PyCodeCache.cache_clear() 65 return orig_test(self, **extra_args) 66 67 return fn 68 69 70@config.patch( 71 { 72 "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")), 73 "benchmark_kernel": True, 74 } 75) 76@instantiate_parametrized_tests 77class MultiKernelTest(TestCase): 78 def test_softmax(self, expect_multi_kernel=True): 79 x = torch.rand(2, 1024).to(GPU_TYPE) 80 ref = torch.softmax(x, -1) 81 compiled_fn = torch.compile(torch.softmax) 82 act, wrapper_code = run_and_get_code(compiled_fn, x, -1) 83 84 # wrapper_code will contains 2 entries if cpp_wrapper=True. 85 # One for the first pass and one for the second pass. 86 # We mainly care about the wrapper for the final pass here. 87 wrapper_code = wrapper_code[-1] 88 self.assertEqual(ref, act) 89 if expect_multi_kernel: 90 self.assertTrue(_contains_multi_kernel_code(wrapper_code)) 91 else: 92 # Skip verifying the wrapper_code in fbcode since we may fail 93 # compiling the cpp wrapper cuda code due to lacking proper setup of 94 # cuda compiler in fbcode environment. In that case, the last 95 # collected wrapper_code will corresponds to the first pass 96 # cpp-wrapper codegen which contains the multi-kernel. 97 if not config.is_fbcode(): 98 self.assertFalse(_contains_multi_kernel_code(wrapper_code)) 99 100 @parametrize("force_kernel", (0, 1)) 101 @unittest.mock.patch.dict( 102 os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"} 103 ) 104 def test_softmax_force_non_persistent_reduction(self, force_kernel): 105 """ 106 Force a specific sub-kernel being picked by mocking the benchmark result. 107 """ 108 x = torch.rand(2, 1024).to(GPU_TYPE) 109 mock_latency = [0.2, 0.2] 110 mock_latency[force_kernel] = 0.1 # this make sure force_kernel will be picked 111 112 def f(x): 113 return torch.softmax(x, -1) + force_kernel 114 115 orig_run = MultiKernelCall.run 116 picked_kernel = None 117 118 def mock_run(self, *args, **kwargs): 119 out = orig_run(self, *args, **kwargs) 120 nonlocal picked_kernel 121 picked_kernel = self.picked_kernel 122 return out 123 124 with unittest.mock.patch.object( 125 MultiKernelCall, "run", mock_run 126 ), unittest.mock.patch.object( 127 MultiKernelCall, 128 "benchmark_sub_kernels", 129 lambda *args, **kwargs: mock_latency, 130 ): 131 torch.compile(f)(x) 132 self.assertEqual(picked_kernel, force_kernel) 133 134 @config.patch("warn_mix_layout", True) 135 def test_softmax_warn_mixed_layout(self): 136 self.test_softmax() 137 138 test_softmax_cpp_wrapper = make_cpp_wrapper_test( 139 test_softmax, expect_multi_kernel=False 140 ) 141 142 def test_layernorm(self): 143 ln = nn.LayerNorm(1024).to(GPU_TYPE) 144 x = torch.rand(2, 1024).to(GPU_TYPE) 145 ref = ln(x) 146 act = torch.compile(ln)(x) 147 self.assertEqual(ref, act, atol=1e-4, rtol=1e-4) 148 149 def test_inplace_update(self): 150 """ 151 Inductor generate inplace kernel for mul. 152 """ 153 154 def f(x, y): 155 return x.sum(dim=-1, keepdims=True) * (y @ y) 156 157 x = torch.rand(1024, 1024).to(GPU_TYPE) 158 y = torch.rand(1024, 1024).to(GPU_TYPE) 159 ref = f(x, y) 160 act = torch.compile(f)(x, y) 161 self.assertEqual(ref, act) 162 163 def test_transformer_snippet(self): 164 model = TransformerSnippet().to(GPU_TYPE) 165 x = model.example_inputs() 166 167 def f(*x): 168 y = model(*x) 169 return y 170 171 reset_rng_state() 172 ref = f(*x) 173 174 opt_f = torch.compile(f) 175 reset_rng_state() 176 act = opt_f(*x) 177 178 # don't compare tensor if using inductor random number generator. 179 # inductor random number implementation is different to eager. 180 # We should fallback to eager if we want to test accuracy. 181 if config.fallback_random: 182 self.assertEqual(ref, act, atol=1e-4, rtol=1e-4) 183 184 def test_transformer_snippet_with_fallback_random(self): 185 """ 186 Same as test_transformer_snippet but fallback the random number 187 generator to eager so we can check accuracy. 188 """ 189 with config.patch("fallback_random", True): 190 self.test_transformer_snippet() 191 192 def test_batchnorm_training(self): 193 """ 194 For training, batchnorm will tracking running mean/variance during forward pass. 195 The kernel generated by inductor currently will pass in those tensors twice as arguments: 196 once for input and once for output. They are ruled out as in-out argument because 197 they are considered as graph inputs. 198 199 Multi-kernel previously assumes that we never pass the same argument mutli times 200 for a kernel. No mater if we change inductor behavior to assure that, it's better 201 to make multi-kernel being able to handle those cases. 202 """ 203 bn = nn.BatchNorm2d(3).to(GPU_TYPE) 204 205 @torch.compile 206 def f(x): 207 bn(x).sum().backward() 208 209 _, (wrapper_code, _) = run_and_get_code( 210 f, torch.randn(2, 3, 8, 8, device=GPU_TYPE) 211 ) 212 self.assertTrue(_contains_multi_kernel_code(wrapper_code)) 213 214 def test_pass_same_arg_multi_times(self): 215 """ 216 A super simple example that simulate how BatchNorm update the running 217 stats. 218 219 Inductor currently pass the same tensor multiple times for the generated 220 kernel: once for input and once for output. 221 222 Here is a paster for the generated kernel (without multi-kernel enabled): 223 https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9 224 """ 225 226 def f(x, y): 227 x = x.sum(dim=1, keepdim=False) 228 y.copy_(y * 0.9 + x * 0.1) 229 230 x = torch.randn(8, 16, device=GPU_TYPE) 231 y = torch.randn(8, device=GPU_TYPE) 232 y_ref = y.clone() 233 234 ref = f(x, y_ref) 235 act = torch.compile(f)(x, y) 236 self.assertEqual(y_ref, y) 237 238 def test_reduction_scratch_buffer(self, force_multi_kernel=1): 239 """ 240 The explicited realized buffer in the test function will be passed in 241 as a scratch buffer for the non-persistent reduction kernel but 242 can be skipped for the persistent reduction kernel. 243 244 This causes different argument lists for non-persistent reduction kernel and 245 persistent reduction kernel. 246 247 Check documentation around torch._inductor.config.triton.multi_kernel about 248 how to interpret the force_multi_kernel argument. 249 """ 250 251 def f(x): 252 x = x.sum(dim=-1, keepdim=True) + x 253 x = test_operators.realize(x) 254 x = x.sum(dim=-1, keepdim=True) + x 255 return x 256 257 x = torch.rand(16, 16, device=GPU_TYPE) 258 ref = f(x) 259 with config.patch("triton.multi_kernel", force_multi_kernel): 260 act = torch.compile(f)(x) 261 self.assertEqual(ref, act) 262 263 def test_split_scan(self, force_multi_kernel=1): 264 def f(x): 265 x = x.view(-1) 266 return torch.cumsum(x, 0) 267 268 x = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=GPU_TYPE) 269 expect = f(x) 270 with config.patch("triton.multi_kernel", force_multi_kernel): 271 actual = torch.compile(f)(x) 272 self.assertEqual(expect, actual) 273 274 def test_sort_disables_multi_kernel(self, force_multi_kernel=1): 275 """ 276 Sort currently requires a persistent kernel, so multi-kernel is not 277 possible. Make sure this falls back gracefully. 278 """ 279 280 def f(x): 281 return x.sort(-1).values 282 283 x = torch.rand(32, 32, device=GPU_TYPE) 284 expect = f(x) 285 with config.patch("triton.multi_kernel", force_multi_kernel): 286 actual = torch.compile(f)(x) 287 self.assertEqual(expect, actual) 288 289 # Use benchmarking to pick the faster kernel 290 test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test( 291 test_reduction_scratch_buffer, force_multi_kernel=1 292 ) 293 # force pick persistent reduction. This can be a good test since this persistent 294 # reduction uses less call arguments than the corresponding non-persistent 295 # reduction. 296 test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = ( 297 make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2) 298 ) 299 # force pick non-persistent reduction 300 test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = ( 301 make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3) 302 ) 303 304 305if __name__ == "__main__": 306 from torch._inductor.test_case import run_tests 307 308 if HAS_GPU: 309 run_tests() 310