# Owner(s): ["module: inductor"]
import copy
import functools
import os
import unittest
from typing import Tuple

import torch
from torch import nn, Tensor
from torch._dynamo.convert_frame import maybe_cprofile
from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss
from torch._inductor import config, ir, metrics
from torch._inductor.fx_passes import pad_mm as pad_mm_pass
from torch._inductor.runtime.benchmarking import benchmarker
from torch._inductor.utils import ceildiv, run_and_get_code
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    requires_cuda,
    serialTest,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"
DO_ACC_TEST = os.environ.get("DO_ACC_TEST", "1") == "1"
WITH_STACK = os.environ.get("WITH_STACK") == "1"
USE_CUDA_GRAPHS = os.environ.get("USE_CUDA_GRAPHS", "1") == "1"

try:
    import transformers  # noqa: F401

    HAS_TRANSFORMER = True
except ImportError:
    HAS_TRANSFORMER = False


def get_optim(m):
    return torch.optim.Adam(m.parameters(), lr=0.01, capturable=True, foreach=True)


def gen_transformer_inputs(vocab_size, bs, seq_length):
    def geninp():
        return torch.randint(
            0, vocab_size, (bs, seq_length), dtype=torch.int64, requires_grad=False
        )

    input_dict = {"input_ids": geninp(), "labels": geninp()}
    return input_dict


class LinearAndSoftmax(nn.Module):
    """
    It's very common for a transformer model to end with a matmul followed by
    softmax/log_softmax.

    This toy model captures that pattern so we can make sure we do proper
    padding.
    """

    def __init__(self, vocab_size=30523, bias=True):
        """
        The default vocab size for BertForMaskedLM is 30522.
        We run a few test cases with good or bad vocab_size around Bert's
        default value.
67 """ 68 super().__init__() 69 self.vocab_size = vocab_size 70 self.linear = nn.Linear(768, vocab_size, bias=bias) 71 self.ce = nn.CrossEntropyLoss() 72 73 def forward(self, x, label): 74 x = self.linear(x) 75 return self.ce(x.view(-1, self.vocab_size), label.view(-1)) 76 77 def get_example_inputs(self, batch_size=16): 78 return torch.randn(batch_size, 512, 768), torch.randint( 79 0, self.vocab_size, (batch_size, 512) 80 ) 81 82 83def forward_and_backward_pass(m, inputs): 84 m(*inputs).sum().backward() 85 86 87@config.patch( 88 { 89 "benchmark_kernel": True, 90 "triton.unique_kernel_names": True, 91 "triton.cudagraphs": USE_CUDA_GRAPHS, 92 } 93) 94@requires_cuda 95class TestCaseBase(TestCase): 96 @classmethod 97 def setUpClass(cls): 98 if HAS_CUDA: 99 cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision() 100 cls.prior_default_device = torch.get_default_device() 101 torch.set_float32_matmul_precision("high") 102 torch.set_default_device("cuda") 103 104 @classmethod 105 def tearDownClass(cls): 106 if HAS_CUDA: 107 torch.set_float32_matmul_precision(cls.prior_float32_matmul_precision) 108 torch.set_default_device(cls.prior_default_device) 109 110 cls.prior_float32_matmul_precision = None 111 cls.prior_default_device = None 112 113 def check_close(self, ref, act, tol=1e-3): 114 if type(ref).__name__ == "LongformerMaskedLMOutput": 115 ref = ref.loss 116 act = act.loss 117 if type(ref).__name__ == "SequenceClassifierOutput": 118 ref = ref.logits 119 act = act.logits 120 if isinstance(ref, dict) and "loss" in ref: 121 ref = ref["loss"] 122 act = act["loss"] 123 self.assertTrue( 124 torch.allclose(ref, act, atol=tol, rtol=tol), f"ref:\n{ref}\nact:\n{act}" 125 ) 126 127 def common_numeric_check(self, f, *args, tol=1e-3, **kwargs): 128 ref = f(*args, **kwargs) 129 opt_f = torch.compile(f) 130 act = opt_f(*args, **kwargs) 131 self.check_close(ref, act, tol) 132 133 def do_profiling( 134 self, 135 f_lhs, 136 f_rhs, 137 tag_lhs="With padding", 138 tag_rhs="Without padding", 139 args=(), 140 kwargs=None, 141 ): 142 if kwargs is None: 143 kwargs = {} 144 torch.cuda.synchronize() 145 with torch.profiler.profile(with_stack=WITH_STACK) as p: 146 niter = 3 147 for _ in range(niter): 148 with torch.profiler.record_function(tag_lhs): 149 f_lhs(*args, **kwargs) 150 151 with torch.profiler.record_function(tag_rhs): 152 f_rhs(*args, **kwargs) 153 torch.cuda.synchronize() 154 155 profile_path = "/tmp/chrome.json" 156 p.export_chrome_trace(profile_path) 157 print(f"Chrome trace is written to {profile_path}") 158 159 160class PerfTestBetweenGoodAndBadShape(TestCaseBase): 161 @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") 162 def test_nobias_LinearAndSoftmax_both_shapes(self): 163 self.test_LinearAndSoftmax_both_shapes(bias=False) 164 165 @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") 166 def test_LinearAndSoftmax_both_shapes(self, bias=True): 167 """ 168 Compare the perf with good and bad shape. 
169 """ 170 m_bad_shape = LinearAndSoftmax(vocab_size=30523, bias=bias) 171 inptus_bad_shape = m_bad_shape.get_example_inputs() 172 m_good_shape = LinearAndSoftmax(vocab_size=30528, bias=bias) 173 inputs_good_shape = m_good_shape.get_example_inputs() 174 175 m_bad_shape_opt = torch.compile(m_bad_shape) 176 m_good_shape_opt = torch.compile(m_good_shape) 177 178 latency_good_shape = benchmarker.benchmark_gpu( 179 lambda: forward_and_backward_pass(m_good_shape_opt, inputs_good_shape) 180 ) 181 latency_bad_shape = benchmarker.benchmark_gpu( 182 lambda: forward_and_backward_pass(m_bad_shape_opt, inptus_bad_shape) 183 ) 184 print( 185 f"Latency for good shape v.s. bad shape: {latency_good_shape:.3f}ms v.s. {latency_bad_shape:.3f}ms" 186 ) 187 188 @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled") 189 def test_BertForMaskedLM(self, num_layers=1): 190 """ 191 Compare the perf between doing padding and good shape. 192 """ 193 from transformers import BertForMaskedLM 194 195 config_cls = BertForMaskedLM.config_class 196 bs = 16 197 seq_length = 512 198 199 def create_model(vocab_size): 200 config = config_cls() 201 config.num_hidden_layers = num_layers 202 config.vocab_size = vocab_size 203 inputs = gen_transformer_inputs(config.vocab_size, bs, seq_length) 204 model = BertForMaskedLM(config) 205 206 optim = get_optim(model) 207 208 def f(**inputs): 209 optim.zero_grad(True) 210 with torch.cuda.amp.autocast(): 211 pred = model(**inputs) 212 loss = pred[0] 213 loss.backward() 214 optim.step() 215 216 return torch.compile(f), inputs 217 218 f_good_shape, inputs_good_shape = create_model(30528) 219 f_bad_shape, inputs_bad_shape = create_model(30522) 220 221 print("benchmark for good shape") 222 latency_good_shape = benchmarker.benchmark_gpu( 223 lambda: f_good_shape(**inputs_good_shape) 224 ) 225 print("benchmark for bad shape") 226 latency_bad_shape = benchmarker.benchmark_gpu( 227 lambda: f_bad_shape(**inputs_bad_shape) 228 ) 229 print( 230 f"Latency with good and bad shape: {latency_good_shape:.3f} v.s. {latency_bad_shape:.3f}" 231 ) 232 233 self.do_profiling( 234 lambda: f_good_shape(**inputs_good_shape), 235 lambda: f_bad_shape(**inputs_bad_shape), 236 tag_lhs="With good shape", 237 tag_rhs="With bad shape", 238 ) 239 240 241class PerfTestWithAndWithoutPadding(TestCaseBase): 242 @maybe_cprofile 243 def run_acc_and_perf_test(self, model, inputs, perf_inputs=None, tol=1e-3): 244 """ 245 Run accuracy test. 246 247 Also compare the perf with and without the comprehensive padding if 248 DO_PERF_TEST is true. 
249 """ 250 if perf_inputs is None: 251 perf_inputs = inputs 252 253 def _process_inputs(x): 254 """ 255 return args and kwargs 256 """ 257 if isinstance(x, dict): 258 return [], x 259 260 if not isinstance(inputs, (tuple, list)): 261 x = [x] 262 263 return x, {} 264 265 args, kwargs = _process_inputs(inputs) 266 perf_args, perf_kwargs = _process_inputs(perf_inputs) 267 268 if DO_ACC_TEST: 269 model.eval() 270 self.common_numeric_check(model, *args, **kwargs, tol=tol) 271 else: 272 print("Accuracy test skipped") 273 274 model.train() 275 276 if DO_PERF_TEST: 277 print("Do performance test") 278 279 def get_f(m, optim): 280 def f(*args, **kwargs): 281 optim.zero_grad(True) 282 with torch.cuda.amp.autocast(): 283 pred = m(*args, **kwargs) 284 loss = reduce_to_scalar_loss(pred) 285 loss.backward() 286 optim.step() 287 288 return f 289 290 latency_with_padding = None 291 print("Benchmark with padding") 292 with config.patch(comprehensive_padding=True): 293 m_copy_with_padding = copy.deepcopy(model) 294 optim_with_padding = get_optim(m_copy_with_padding) 295 opt_f_with_padding = torch.compile( 296 get_f(m_copy_with_padding, optim_with_padding) 297 ) 298 latency_with_padding = benchmarker.benchmark_gpu( 299 lambda: opt_f_with_padding(*perf_args, **perf_kwargs) 300 ) 301 latency_without_padding = None 302 print("bencmark without padding") 303 with config.patch(comprehensive_padding=False): 304 m_copy_without_padding = copy.deepcopy(model) 305 optim_without_padding = get_optim(m_copy_without_padding) 306 opt_f_without_padding = torch.compile( 307 get_f(m_copy_without_padding, optim_without_padding) 308 ) 309 latency_without_padding = benchmarker.benchmark_gpu( 310 lambda: opt_f_without_padding(*perf_args, **perf_kwargs) 311 ) 312 print( 313 f"Latency with and without padding: {latency_with_padding:.3f} v.s. {latency_without_padding:.3f}" 314 ) 315 316 # profiling 317 self.do_profiling( 318 opt_f_with_padding, 319 opt_f_without_padding, 320 args=perf_args, 321 kwargs=perf_kwargs, 322 ) 323 324 def test_nvidia_deeprecommender(self): 325 """ 326 Compared the perf with and without comprehensive padding. 327 """ 328 layer_sizes = [197951, 512, 512, 1024, 512, 512, 197951] 329 x = torch.randn(4, layer_sizes[0]) 330 331 class Model(nn.Module): 332 def __init__(self) -> None: 333 super().__init__() 334 mod_list = [] 335 for i in range(len(layer_sizes) - 1): 336 mod_list.append(nn.Linear(layer_sizes[i], layer_sizes[i + 1])) 337 mod_list.append(nn.SELU()) 338 339 if i == 2: 340 mod_list.append(nn.Dropout(0.8)) 341 self.seq = nn.Sequential(*mod_list) 342 343 def forward(self, x): 344 return self.seq(x) 345 346 m = Model() 347 perf_inputs = torch.randn(256, layer_sizes[0]) 348 self.run_acc_and_perf_test(m, x, perf_inputs) 349 350 @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled") 351 def test_longformer(self, bs=4): 352 from transformers import AutoConfig, AutoModelForMaskedLM 353 354 config = AutoConfig.from_pretrained("allenai/longformer-base-4096") 355 model = AutoModelForMaskedLM.from_config(config) 356 357 vocab_size = model.config.vocab_size 358 seq_length = 1024 359 input_dict = gen_transformer_inputs(vocab_size, bs, seq_length) 360 361 self.run_acc_and_perf_test(model, input_dict) 362 363 @unittest.skipIf(not DO_PERF_TEST or not HAS_TRANSFORMER, "Perf test not enabled") 364 def test_longformer_small_bs(self): 365 """ 366 The model exists in both HF and TB. In TB it uses a samller batch size. 
367 """ 368 self.test_longformer(bs=2) 369 370 371@instantiate_parametrized_tests 372class PaddingTest(TestCaseBase): 373 @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") 374 def test_mm_padding_perf(self): 375 def naive_mm(a, b): 376 return a @ b 377 378 def _compute_padding(s, align): 379 return (s + align - 1) // align * align - s 380 381 @torch.compile 382 def pad_mm(a, b, align=16): 383 """ 384 NOTE: this function only pad a single dimension which is good 385 enough for testing. 386 """ 387 m_padding = _compute_padding(a.size(0), align) 388 k_padding = _compute_padding(a.size(1), align) 389 n_padding = _compute_padding(b.size(1), align) 390 return pad_mm_pass.pad_mm(a, b, m_padding, k_padding, n_padding) 391 392 for M, K, N, f in ( 393 (8192, 768, 30523, naive_mm), 394 (8192, 768, 30523, pad_mm), 395 (8192, 768, 30528, naive_mm), 396 (30523, 8192, 768, naive_mm), 397 (30528, 8192, 768, naive_mm), 398 ): 399 a = torch.randn(M, K) 400 b = torch.randn(K, N) 401 ms = benchmarker.benchmark_gpu(lambda: f(a, b)) 402 print(f"MxKxN {M}x{K}x{N} {f.__name__}: {ms:.3f}ms") 403 404 @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") 405 def test_padmm(self): 406 """ 407 Latency between origional matmul and padded matmul: 2.717 v.s. 2.356 408 """ 409 mat1_pad = torch.randn(8192, 30522, dtype=torch.float16) 410 mat2_pad = torch.randn(30522, 768, dtype=torch.float16) 411 412 def f(): 413 return mat1_pad @ mat2_pad 414 415 def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: 416 pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) 417 return torch.cat([x, pad], dim=dim) 418 419 @torch.compile(fullgraph=True, options={"triton.cudagraphs": False}) 420 def g(): 421 mat1 = mat1_pad 422 mat2 = mat2_pad 423 mat1 = pad_dim(mat1, 6, 1) 424 mat2 = pad_dim(mat2, 6, 0) 425 return torch.ops.aten.mm(mat1, mat2) 426 427 ori_time = benchmarker.benchmark_gpu(f) 428 pad_time = benchmarker.benchmark_gpu(g) 429 430 print( 431 f"Latency between origional matmul and padded matmul: {ori_time:.3f} v.s. {pad_time:.3f}" 432 ) 433 self.do_profiling(f, g, "No MM Padding", "With mm padding") 434 435 @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled") 436 def test_matmul(self): 437 """ 438 Latency with good and bad shapes: 1.705 v.s. 2.625 439 """ 440 x_good_shape = torch.randn(8192, 30528, dtype=torch.float16) 441 weight_good_shape = torch.randn(30528, 768, dtype=torch.float16) 442 out_good_shape = torch.randn(8192, 768, dtype=torch.float16) 443 444 # Using stride (30522, 1) does not make a difference here. 445 x_bad_shape = rand_strided( 446 (8192, 30522), (30528, 1), device="cuda", dtype=torch.float16 447 ) 448 weight_bad_shape = torch.randn(30522, 768, dtype=torch.float16) 449 out_bad_shape = torch.randn(8192, 768, dtype=torch.float16) 450 451 def f(x, weight, out): 452 torch.mm(x, weight, out=out) 453 return out 454 455 f1 = torch.compile( 456 functools.partial(f, x_good_shape, weight_good_shape, out_good_shape) 457 ) 458 f2 = torch.compile( 459 functools.partial(f, x_bad_shape, weight_bad_shape, out_bad_shape) 460 ) 461 latency_good_shape = benchmarker.benchmark_gpu(f1) 462 latency_bad_shape = benchmarker.benchmark_gpu(f2) 463 print( 464 f"Latency with good and bad shapes: {latency_good_shape:.3f} v.s. 
        )
        self.do_profiling(f1, f2)

    @serialTest()
    def test_nobias_LinearAndSoftmax_codegen(self):
        self.test_LinearAndSoftmax_codegen(bias=False)

    def test_LinearAndSoftmax_codegen(self, bias=True):
        m_bad_shape = LinearAndSoftmax(vocab_size=30523, bias=bias)
        inputs_bad_shape = m_bad_shape.get_example_inputs()
        m_bad_shape_opt = torch.compile(copy.deepcopy(m_bad_shape))

        _, wrapper_codes = run_and_get_code(
            forward_and_backward_pass, m_bad_shape_opt, inputs_bad_shape
        )
        forward_and_backward_pass(m_bad_shape, inputs_bad_shape)
        self.assertEqual(
            m_bad_shape.linear.weight.grad, m_bad_shape_opt.linear.weight.grad
        )
        self.assertTrue(len(wrapper_codes) == 2)  # one for forward and one for backward
        forward_wrapper = wrapper_codes[0]

        # make sure the load for softmax is aligned
        self.assertTrue(
            "tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper,
            f"forward_wrapper: {forward_wrapper}",
        )

        if DO_PERF_TEST:
            latency = benchmarker.benchmark_gpu(
                lambda: forward_and_backward_pass(m_bad_shape_opt, inputs_bad_shape)
            )
            print(f"latency: {latency:.3f}ms")

    @config.patch(pattern_matcher=False)
    def test_attention(self):
        batch_size, seq_len, num_heads, hidden_size = 1, 4, 1, 16
        inv_scale = (num_heads / hidden_size) ** 0.5

        class Attention(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.query = nn.Linear(hidden_size, hidden_size)
                self.key = nn.Linear(hidden_size, hidden_size)
                self.value = nn.Linear(hidden_size, hidden_size)

            @staticmethod
            def reshape(x):
                return x.view(batch_size, seq_len, num_heads, -1).permute(0, 2, 1, 3)

            @staticmethod
            def cancel_reshape(x):
                return x.permute(0, 2, 1, 3).view(batch_size, seq_len, hidden_size)

            def forward(self, x):
                query, key, value = self.query(x), self.key(x), self.value(x)
                weights = (
                    torch.matmul(
                        self.reshape(query), self.reshape(key).permute(0, 1, 3, 2)
                    )
                    * inv_scale
                ).softmax(dim=-1)
                return self.cancel_reshape(torch.matmul(weights, self.reshape(value)))

        attn = Attention()
        x = torch.randn(batch_size, seq_len, hidden_size)

        self.common_numeric_check(attn, x)

    def test_view(self):
        def f(x):
            return x.view(3, 3, 3)

        x = torch.randn(3, 9)
        self.common_numeric_check(f, x)

    def test_pad_strides(self):
        """
        Note that dim0's stride is also padded even though its previous value
        is already a multiple of 16. The reason is that we padded dim1's stride,
        so we have to correspondingly increase the stride for dim0.
        """
        sizes = [2, 16, 2047]
        in_strides = [2047 * 16, 2047, 1]
        out_strides = list(ir.Layout._pad_strides(in_strides, sizes, torch.float32))
        expected_strides = [2048 * 16, 2048, 1]
        self.assertEqual(
            expected_strides, out_strides, f"{expected_strides} vs. {out_strides}"
        )

    def test_pad_strides_skip(self):
        """
        The padding is skipped to avoid too much memory overhead.
        """
        sizes = [2, 32, 127]
        in_strides = [4064, 127, 1]
        out_strides = list(ir.Layout._pad_strides(in_strides, sizes, torch.float32))
        expected_strides = [4064, 127, 1]
        self.assertEqual(
            expected_strides, out_strides, f"{expected_strides} vs. {out_strides}"
        )

    def test_pad_3d_tensor(self):
        """
        This test case is constructed around the fact that we don't pad the
        strides of placeholders or user-visible outputs.

        Add a matmul at the beginning and the end so we can pad strides for the
        intermediate tensors.
        """

        def f(x, y):
            x = torch.matmul(x, y)
            x = x + 1
            return torch.matmul(x, y)

        x = torch.randn(2, 16, 2047)
        y = torch.randn(2047, 2047)
        self.common_numeric_check(f, x, y, tol=1e-2)
        self.assertTrue(metrics.num_comprehensive_padding > 0)

    def test_conv(self):
        """
        Padding the input for convolution may cause an extra copy kernel to be
        called. Check this example trace:
        https://gist.github.com/shunting314/ce45398f7d51a63ce05fc8d411faddb3
        """
        x_shape = (1, 128, 640, 959)
        x1 = torch.randn(*x_shape)

        padded_stride = ir.Layout._pad_strides(x1.stride(), x1.shape, torch.float32)
        x2 = rand_strided(x_shape, padded_stride, device="cuda")
        x2.copy_(x1)

        weight = torch.randn(64, 128, 3, 3)

        def fun(x, weight):
            return torch.convolution(
                x,
                weight,
                stride=(1, 1),
                padding=(1, 1),
                dilation=(1, 1),
                transposed=False,
                output_padding=(0, 0),
                groups=1,
                bias=None,
            )

        ref = fun(x1, weight)
        act = fun(x2, weight)
        self.check_close(ref, act)
        if DO_PERF_TEST:
            latency_with_padding = benchmarker.benchmark_gpu(lambda: fun(x2, weight))
            latency_without_padding = benchmarker.benchmark_gpu(lambda: fun(x1, weight))
            print(
                f"Latency with and without padding: {latency_with_padding:.3f} vs. {latency_without_padding:.3f}"
            )

            self.do_profiling(lambda: fun(x2, weight), lambda: fun(x1, weight))

    @unittest.skipIf(not DO_PERF_TEST, "Perf test not enabled")
    def test_cat(self):
        """
        Compare the perf between aten cat and compiled cat.

        Latency between eager and compiled: 1.596 vs. 0.601

        Eager cat can be 2.66x slower than the inductor kernel.
        """
        x = torch.randn(8192, 30522, dtype=torch.float16)

        def f(x):
            pad = x.new_zeros(x.size(0), 6)
            return torch.cat([x, pad], dim=1)

        # Disable cudagraphs since cudagraphs need to copy the input, which
        # distorts the latency a lot (it doubles the latency here for the
        # compiled version).
        with config.patch("triton.cudagraphs", False):
            opt_f = torch.compile(f)
            opt_f(x)

            eager_time = benchmarker.benchmark_gpu(lambda: f(x))
            opt_time = benchmarker.benchmark_gpu(lambda: opt_f(x))
            print(
                f"Latency between eager and compiled: {eager_time:.3f} vs. {opt_time:.3f}"
            )
            self.do_profiling(
                lambda: f(x), lambda: opt_f(x), "Eager Cat", "Compiled Cat"
            )

    def test_pad_channels_last(self):
        t = torch.randn(2, 3, 5, 1025)
        in_strides = t.stride()
        out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32)
        self.assertTrue(in_strides != out_strides)

        t = t.to(memory_format=torch.channels_last)
        in_strides = t.stride()
        out_strides = ir.Layout._pad_strides(in_strides, t.shape, torch.float32)
        self.assertTrue(in_strides == out_strides)

    @parametrize("alignment_bytes", (32, 128))
    @parametrize("shape", [(21, 19), (3, 5, 71)])
    @parametrize("dtype", (torch.float16, torch.float32))
    def test_pad_outputs(
        self, dtype: torch.dtype, shape: Tuple[int, ...], alignment_bytes: int
    ):
        """
        Tests padding output tensors to a specific alignment.
        This is enabled by a config flag.
673 """ 674 func = torch.add 675 inputs = tuple(torch.randn(*shape, dtype=dtype) for input_idx in range(2)) 676 677 # Compile and run 678 with config.patch( 679 { 680 "comprehensive_padding": True, 681 "padding_alignment_bytes": alignment_bytes, 682 "padding_stride_threshold": 0, 683 "pad_outputs": True, 684 } 685 ): 686 compiled_func = torch.compile(func) 687 compiled_out = compiled_func(*inputs) 688 689 # Check numerics 690 eager_out = func(*inputs) 691 self.check_close(eager_out, compiled_out) 692 693 # Compute the expected padding 694 element_size = torch.tensor([], dtype=dtype).element_size() 695 self.assertGreater(alignment_bytes, element_size) 696 self.assertEqual(alignment_bytes % element_size, 0) 697 alignment_elements = alignment_bytes // element_size 698 contiguous_stride = inputs[0].stride() 699 expected_stride = [1] 700 for dim in reversed(shape[1:]): 701 slice_size = dim * expected_stride[0] 702 new_stride = alignment_elements * ceildiv(slice_size, alignment_elements) 703 expected_stride.insert(0, new_stride) 704 expected_stride = tuple(expected_stride) 705 self.assertNotEqual(expected_stride, contiguous_stride) 706 707 # Check strides 708 self.assertFalse(compiled_out.is_contiguous()) 709 self.assertEqual(compiled_out.stride(), expected_stride) 710 711 712if __name__ == "__main__": 713 if HAS_CUDA: 714 run_tests() 715