1# Owner(s): ["module: inductor"] 2import contextlib 3import copy 4import dataclasses 5import functools 6import gc 7import importlib 8import itertools 9import math 10import operator 11import os 12import random 13import re 14import subprocess 15import sys 16import threading 17import time 18import typing 19import unittest 20import unittest.mock 21import weakref 22from pathlib import Path 23from typing import Tuple 24from unittest.mock import patch 25 26import numpy as np 27 28import torch 29 30import torch._dynamo.config as dynamo_config 31import torch.nn as nn 32from torch._dispatch.python import enable_python_dispatcher 33from torch._dynamo.debug_utils import aot_graph_input_parser 34from torch._dynamo.testing import ( 35 CompileCounterWithBackend, 36 expectedFailureCodegenDynamic, 37 rand_strided, 38 same, 39 skipIfPy312, 40) 41from torch._dynamo.utils import ifdynstaticdefault 42from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext 43from torch._inductor.fx_passes import pad_mm 44from torch._inductor.test_case import TestCase as InductorTestCase 45from torch._inductor.utils import ( 46 add_scheduler_init_hook, 47 aoti_compile_with_persistent_cache, 48 aoti_eager_cache_dir, 49 load_aoti_eager_cache, 50 run_and_get_code, 51 run_and_get_cpp_code, 52 run_and_get_triton_code, 53) 54from torch._inductor.virtualized import V 55from torch._prims_common import is_integer_dtype 56from torch.fx.experimental.proxy_tensor import make_fx 57from torch.library import _scoped_library 58from torch.nn import functional as F 59from torch.testing import FileCheck, make_tensor 60from torch.testing._internal.common_cuda import ( 61 PLATFORM_SUPPORTS_FLASH_ATTENTION, 62 PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, 63 SM80OrLater, 64 TEST_CUDNN, 65 with_tf32_off, 66) 67 68from torch.testing._internal.common_device_type import ( 69 _has_sufficient_memory, 70 expectedFailureXPU, 71) 72from torch.testing._internal.common_dtype import all_types, get_all_dtypes 73from torch.testing._internal.common_utils import ( 74 DeterministicGuard, 75 instantiate_parametrized_tests, 76 IS_CI, 77 IS_FBCODE, 78 IS_MACOS, 79 IS_WINDOWS, 80 IS_X86, 81 parametrize, 82 serialTest, 83 skipIfNNModuleInlined, 84 skipIfRocm, 85 skipIfXpu, 86 subtest, 87 TEST_WITH_ASAN, 88 TEST_WITH_ROCM, 89) 90from torch.utils import _pytree as pytree 91from torch.utils._python_dispatch import TorchDispatchMode 92from torch.utils._pytree import tree_flatten, tree_unflatten 93from torch.utils.weak import WeakTensorKeyDictionary 94 95DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" 96 97if IS_WINDOWS and IS_CI: 98 sys.stderr.write( 99 "Windows CI does not have necessary dependencies for test_torchinductor yet\n" 100 ) 101 if __name__ == "__main__": 102 sys.exit(0) 103 raise unittest.SkipTest("requires sympy/functorch/filelock") 104 105importlib.import_module("functorch") 106importlib.import_module("filelock") 107 108from torch._inductor import config, test_operators 109 110from torch._inductor.compile_fx import ( 111 compile_fx, 112 compile_fx_inner, 113 complex_memory_overlap, 114) 115from torch._inductor.utils import has_torchvision_roi_align 116 117from torch.testing._internal.common_utils import slowTest 118from torch.testing._internal.inductor_utils import ( 119 GPU_TYPE, 120 HAS_CPU, 121 HAS_GPU, 122 HAS_MULTIGPU, 123 skipCPUIf, 124 skipCUDAIf, 125) 126 127HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines 128 129aten = torch.ops.aten 130requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires 
gpu") 131 132requires_multigpu = functools.partial( 133 unittest.skipIf, not HAS_MULTIGPU, f"requires multiple {GPU_TYPE} devices" 134) 135skip_if_x86_mac = functools.partial( 136 unittest.skipIf, IS_MACOS and IS_X86, "Does not work on x86 Mac" 137) 138vec_dtypes = [torch.float, torch.bfloat16, torch.float16] 139 140libtest = torch.library.Library("test", "FRAGMENT") # noqa: TOR901 141ids = set() 142 143f32 = torch.float32 144i64 = torch.int64 145i32 = torch.int32 146 147 148def _large_cumprod_input(shape, dim, dtype, device): 149 # Construct a cumprod input which guaruntees not to overflow or underflow 150 if is_integer_dtype(dtype): 151 # Large products don't fit in integers, the best we can do 152 # is random +/-1 values to test the sign of the result 153 x = torch.randint(0, 1, shape, dtype=dtype, device=device) 154 return x * 2 - 1 155 156 comp_dtype = torch._prims_common.get_computation_dtype(dtype) 157 batch_size = 256 158 if comp_dtype != dtype: 159 batch_size = math.floor(math.log2(torch.finfo(dtype).max) / 3) 160 161 # Create random values with a uniform magnitude and uniform exponent 162 num_batches = (shape[dim] + 2 * batch_size - 1) // (2 * batch_size) 163 batch_shape = ( 164 shape[:dim] 165 + ( 166 num_batches, 167 batch_size, 168 ) 169 + shape[dim + 1 :] 170 ) 171 magnitude = 1 + torch.rand(batch_shape, dtype=comp_dtype, device=device) 172 exponent = torch.randint(-1, 1, batch_shape, device=device).to(comp_dtype) 173 batch = magnitude * exponent.exp2() 174 175 # Alternate each batch of values with their reciprocals so the product 176 # never gets too far away from 1 177 t = torch.cat((batch, batch.reciprocal()), dim=dim + 1) 178 t = t.flatten(dim, dim + 1) 179 t = aten.slice(t, dim=dim, start=0, end=shape[dim]) 180 181 # Randomize sign 182 sign = torch.randint(0, 1, shape, device=device) * 2 - 1 183 return (t * sign).to(dtype) 184 185 186def define_custom_op_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()): 187 global libtest 188 global ids 189 if id_ not in ids: 190 libtest.define(f"{id_}(Tensor self) -> Tensor", tags=tags) 191 libtest.impl(id_, fn_cpu, "CPU") 192 libtest.impl(id_, fn_cuda, "CUDA") 193 libtest.impl(id_, fn_xpu, "XPU") 194 libtest.impl(id_, fn_meta, "Meta") 195 ids.add(id_) 196 197 198def define_custom_op_2_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()): 199 global libtest 200 global ids 201 if id_ not in ids: 202 libtest.define( 203 f"{id_}(Tensor self, float scale) -> (Tensor, Tensor)", tags=tags 204 ) 205 libtest.impl(id_, fn_cpu, "CPU") 206 libtest.impl(id_, fn_cuda, "CUDA") 207 libtest.impl(id_, fn_xpu, "XPU") 208 libtest.impl(id_, fn_meta, "Meta") 209 ids.add(id_) 210 211 212def define_custom_op_3_for_test(id_, fn_cpu, fn_cuda, fn_xpu, fn_meta, tags=()): 213 global libtest 214 global ids 215 if id_ not in ids: 216 libtest.define(f"{id_}(Tensor[] x) -> Tensor", tags=tags) 217 libtest.impl(id_, fn_cpu, "CPU") 218 libtest.impl(id_, fn_cuda, "CUDA") 219 libtest.impl(id_, fn_xpu, "XPU") 220 libtest.impl(id_, fn_meta, "Meta") 221 ids.add(id_) 222 223 224f32 = torch.float32 225 226 227def run_fw_bw_and_get_code(fn): 228 def run_with_backward(): 229 result = fn() 230 result.sum().backward() 231 return result 232 233 return run_and_get_code(run_with_backward) 234 235 236def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_lib_impl): 237 for _op_name in op_set: 238 qualified_op_name = f"{ns}::{_op_name}" 239 _, overload_names = torch._C._jit_get_operation(qualified_op_name) 240 for overload_name in overload_names: 241 
try: 242 reg_op_name = qualified_op_name 243 schema = torch._C._get_schema(qualified_op_name, overload_name) 244 if schema.overload_name: 245 reg_op_name = f"{qualified_op_name}.{schema.overload_name}" 246 torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 247 reg_op_name, dispatch_key 248 ) 249 except Exception as e: 250 continue 251 252 253class TestCase(InductorTestCase): 254 @classmethod 255 def setUpClass(cls): 256 super().setUpClass() 257 cls._stack = contextlib.ExitStack() 258 cls._stack.enter_context( 259 config.patch( 260 { 261 "debug": True, 262 "debug_index_asserts": True, 263 "cpp.min_chunk_size": 1, 264 "triton.autotune_pointwise": False, # too slow 265 "implicit_fallbacks": False, 266 "generate_intermediate_hooks": True, 267 } 268 ) 269 ) 270 271 @classmethod 272 def tearDownClass(cls): 273 cls._stack.close() 274 super().tearDownClass() 275 276 def setUp(self): 277 torch._dynamo.reset() 278 torch._inductor.metrics.reset() 279 super().setUp() 280 self._start = time.perf_counter() 281 282 def tearDown(self): 283 super().tearDown() 284 torch._dynamo.reset() 285 if os.environ.get("ERROR_ON_SLOW") == "1": 286 elapsed = time.perf_counter() - self._start 287 assert elapsed < 120 288 289 290class ToTuple(torch.nn.Module): 291 def forward(self, x): 292 return (x,) 293 294 295@dataclasses.dataclass 296class InputGen: 297 n: int 298 device: str 299 300 def dense(self): 301 return torch.randn((self.n, self.n), device=self.device) 302 303 def transposed(self): 304 return self.dense().transpose(0, 1) 305 306 def strided(self): 307 return torch.randn((self.n * 2, self.n * 3), device=self.device)[ 308 self.n :, self.n :: 2 309 ] 310 311 def broadcast1(self): 312 return torch.randn((self.n,), device=self.device) 313 314 def broadcast2(self): 315 return torch.randn((1, self.n, 1), device=self.device) 316 317 def broadcast3(self): 318 return torch.randn((1,), device=self.device) 319 320 def double(self): 321 return torch.randn((self.n, self.n), device=self.device, dtype=torch.double) 322 323 def int(self): 324 return torch.arange(self.n, device=self.device, dtype=torch.int32) 325 326 327def compute_grads(args, kwrags, results, grads): 328 def gather_leaf_tensors(args, kwargs): 329 args = pytree.arg_tree_leaves(*args, **kwargs) 330 leaf_tensors = [ 331 arg for arg in args if isinstance(arg, torch.Tensor) and arg.requires_grad 332 ] 333 return leaf_tensors 334 335 flat_results = pytree.tree_leaves(results) 336 flat_diff_results = [ 337 r for r in flat_results if isinstance(r, torch.Tensor) and r.requires_grad 338 ] 339 assert len(flat_diff_results) > 0 340 341 leaf_tensors = gather_leaf_tensors(args, kwrags) 342 assert len(leaf_tensors) > 0 343 return torch.autograd.grad( 344 flat_diff_results, 345 leaf_tensors, 346 grads, 347 allow_unused=True, 348 retain_graph=True, 349 ) 350 351 352def clone_preserve_strides(x, device=None): 353 if not isinstance(x, torch.Tensor): 354 return x 355 buffer = torch.as_strided( 356 x, (x.untyped_storage().size() // x.element_size(),), (1,), 0 357 ) 358 if not device: 359 buffer = buffer.clone() 360 else: 361 buffer = buffer.to(device, copy=True) 362 out = torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) 363 return out 364 365 366def check_model( 367 self: TestCase, 368 model, 369 example_inputs, 370 kwargs=None, 371 *, 372 atol=None, 373 rtol=None, 374 grad_atol=None, 375 grad_rtol=None, 376 check_lowp=True, 377 exact_dtype=True, 378 nopython=True, 379 copy_to_gpu=True, 380 reference_in_float=True, 381 assert_equal=True, 382 
check_gradient=False, 383 check_has_compiled=True, 384 output_process_fn_grad=lambda x: x, 385): 386 kwargs = kwargs or {} 387 torch._dynamo.reset() 388 389 ref_inputs = [clone_preserve_strides(x) for x in example_inputs] 390 ref_kwargs = kwargs 391 has_lowp_args = False 392 393 if reference_in_float and exact_dtype: 394 # Store expected dtypes so we can check actual result gives the correct types 395 torch.manual_seed(0) 396 try: 397 eager_result = model(*ref_inputs, **ref_kwargs) 398 except RuntimeError: 399 # Eager model may fail if the dtype is not supported 400 eager_result = None 401 402 ref_inputs = [clone_preserve_strides(x) for x in example_inputs] 403 expect_dtypes = [ 404 x.dtype if isinstance(x, torch.Tensor) else None 405 for x in pytree.tree_leaves(eager_result) 406 ] 407 del eager_result 408 409 ref_model = model 410 if reference_in_float: 411 # check_lowp is ignored here, it's kept just to be able to call `common` with extra arg 412 def upcast_fn(x): 413 nonlocal has_lowp_args 414 if isinstance(x, torch.Tensor) and ( 415 x.dtype == torch.float16 or x.dtype == torch.bfloat16 416 ): 417 has_lowp_args = True 418 return x.float() 419 else: 420 return x 421 422 ref_inputs = list(map(upcast_fn, example_inputs)) 423 ref_kwargs = {k: upcast_fn(v) for k, v in kwargs.items()} 424 if has_lowp_args and hasattr(model, "to"): 425 ref_model = copy.deepcopy(model).to(torch.float) 426 427 torch.manual_seed(0) 428 429 correct = ref_model(*ref_inputs, **ref_kwargs) 430 431 torch._inductor.metrics.reset() 432 433 called = False 434 435 def compile_fx_wrapper(model_, example_inputs_): 436 nonlocal called 437 called = True 438 return compile_fx(model_, example_inputs_) 439 440 def run(*ex, **kwargs): 441 return model(*ex, **kwargs) 442 443 run = torch._dynamo.optimize(compile_fx_wrapper, nopython=nopython)(run) 444 445 torch.manual_seed(0) 446 actual = run(*example_inputs, **kwargs) 447 # if not called: 448 # exp = torch._dynamo.explain(run)(*example_inputs) 449 # print("Explain:", exp[0]) 450 # for graph in exp[2]: 451 # print("Graph", graph) 452 if check_has_compiled: 453 assert called, "Ran graph without calling compile_fx" 454 assert type(actual) == type(correct) 455 if isinstance(actual, (tuple, list)): 456 assert len(actual) == len(correct) 457 assert all( 458 type(actual_item) == type(correct_item) 459 for actual_item, correct_item in zip(actual, correct) 460 ) 461 462 correct_flat, correct_spec = tree_flatten(correct) 463 actual_flat = pytree.tree_leaves(actual) 464 465 def reference_to_expect(actual_flat, correct_flat): 466 return tuple( 467 ( 468 y.to(x.dtype) 469 if isinstance(y, torch.Tensor) and y.dtype.is_floating_point 470 else y 471 ) 472 for x, y in zip(actual_flat, correct_flat) 473 ) 474 475 if reference_in_float and exact_dtype: 476 for expect_dtype, actual_result in zip(expect_dtypes, actual_flat): 477 if expect_dtype is not None: 478 assert ( 479 actual_result.dtype == expect_dtype 480 ), f"dtype mismatch, expected {expect_dtype} but got {actual_result.dtype}" 481 482 if reference_in_float: 483 correct_flat = reference_to_expect(actual_flat, correct_flat) 484 correct = tree_unflatten(correct_flat, correct_spec) 485 486 if assert_equal: 487 self.assertEqual( 488 actual, 489 correct, 490 atol=atol, 491 rtol=rtol, 492 equal_nan=True, 493 exact_dtype=exact_dtype, 494 ) 495 # In case of input mutations, check that inputs are the same 496 self.assertEqual( 497 ref_inputs, 498 example_inputs, 499 atol=atol, 500 rtol=rtol, 501 equal_nan=True, 502 # our testing sometimes uses higher 
precision inputs for the reference 503 exact_dtype=False, 504 ) 505 else: 506 for correct_val, actual_val in zip(correct_flat, actual_flat): 507 if isinstance(correct_val, torch.Tensor): 508 assert correct_val.device == actual_val.device 509 assert correct_val.size() == actual_val.size() 510 strides_equal, _ = torch._prims_common.check_significant_strides( 511 correct_val, actual_val 512 ) 513 assert strides_equal 514 assert correct_val.layout == actual_val.layout 515 if exact_dtype: 516 assert correct_val.dtype == actual_val.dtype 517 518 if check_gradient: 519 actual = output_process_fn_grad(actual) 520 correct = output_process_fn_grad(correct) 521 actual_flat = pytree.tree_leaves(actual) 522 correct_flat = pytree.tree_leaves(correct) 523 524 # generate random unit norm gradients 525 grads = [ 526 torch.rand(r.shape, device=r.device, dtype=r.dtype) 527 for r in correct_flat 528 if isinstance(r, torch.Tensor) and r.requires_grad 529 ] 530 for g in grads: 531 g /= g.norm() 532 533 correct_grad = compute_grads(ref_inputs, ref_kwargs, correct, grads) 534 all_none_grads = all(x is None for x in correct_grad) 535 if all_none_grads: 536 # See Note [Detaching inputs that never need gradients] 537 # There are a handful of ops that can return None gradients, into of zero gradients. 538 # If all inputs to an AOTAutograd graph are supposed to get None gradients, 539 # AOTAutograd will end up forcing all of the outputs of the forward to not require grad. 540 # There's no easy fix to this (see the note above), although one option is to 541 # force any derivative formulas in core to return tensors of zeros instead of None. 542 flat_results = pytree.tree_leaves(actual) 543 results_that_require_grad = [ 544 x 545 for x in flat_results 546 if isinstance(x, torch.Tensor) and x.requires_grad 547 ] 548 self.assertEqual(len(results_that_require_grad), 0) 549 else: 550 actual_grad = compute_grads(example_inputs, kwargs, actual, grads) 551 552 if reference_in_float: 553 expect_grad = reference_to_expect(actual_grad, correct_grad) 554 else: 555 expect_grad = correct_grad 556 557 self.assertEqual( 558 actual_grad, 559 expect_grad, 560 atol=grad_atol or atol, 561 rtol=grad_rtol or rtol, 562 equal_nan=True, 563 exact_dtype=exact_dtype, 564 ) 565 566 torch._dynamo.reset() 567 568 569@torch._inductor.config.patch("triton.cudagraphs", False) 570def check_model_gpu( 571 self: TestCase, 572 model, 573 example_inputs, 574 kwargs=None, 575 *, 576 atol=None, 577 rtol=None, 578 grad_atol=None, 579 grad_rtol=None, 580 check_lowp=True, 581 exact_dtype=True, 582 nopython=True, 583 copy_to_gpu=True, 584 reference_in_float=True, 585 assert_equal=True, 586 check_gradient=False, 587 check_has_compiled=True, 588 output_process_fn_grad=lambda x: x, 589): 590 kwargs = kwargs or {} 591 if hasattr(model, "to"): 592 model = model.to(device=GPU_TYPE) 593 594 if copy_to_gpu: 595 example_inputs = tuple( 596 clone_preserve_strides(x, device=GPU_TYPE) for x in example_inputs 597 ) 598 599 check_model( 600 self, 601 model, 602 example_inputs, 603 kwargs, 604 atol=atol, 605 rtol=rtol, 606 grad_atol=grad_atol, 607 grad_rtol=grad_rtol, 608 exact_dtype=exact_dtype, 609 nopython=nopython, 610 reference_in_float=reference_in_float, 611 assert_equal=assert_equal, 612 check_gradient=check_gradient, 613 check_has_compiled=check_has_compiled, 614 output_process_fn_grad=output_process_fn_grad, 615 ) 616 617 if check_lowp: 618 619 def downcast_fn(x): 620 if not isinstance(x, torch.Tensor) or not x.dtype == torch.float: 621 return x 622 return 
torch.empty_strided( 623 x.size(), x.stride(), device=GPU_TYPE, dtype=torch.half 624 ).copy_(x) 625 626 example_inputs = list(map(downcast_fn, example_inputs)) 627 if hasattr(model, "to"): 628 model = model.to(torch.half) 629 if rtol is not None: 630 rtol = max(2e-3, rtol) 631 check_model( 632 self, 633 model, 634 example_inputs, 635 kwargs, 636 atol=atol, 637 rtol=rtol, 638 grad_atol=grad_atol, 639 grad_rtol=grad_rtol, 640 exact_dtype=exact_dtype, 641 nopython=nopython, 642 reference_in_float=reference_in_float, 643 assert_equal=assert_equal, 644 check_gradient=check_gradient, 645 check_has_compiled=check_has_compiled, 646 output_process_fn_grad=output_process_fn_grad, 647 ) 648 649 650check_model_cuda = check_model_gpu 651 652 653def _run_and_assert_no_indirect_indexing( 654 test_case, func, *args, has_wrapping=None, has_assert=False, **kwargs 655): 656 result, source_codes = run_and_get_code(func, *args, **kwargs) 657 658 for code in source_codes: 659 for line in code.split("\n"): 660 stmt = None 661 # Find indexing expressions 662 if ".load(" in line: 663 stmt = line.split(".load")[-1] 664 elif "tl.store" in line: 665 stmt = line.split(".store")[-1] 666 stmt = ",".join(stmt.split(",")[:-2]) # Remove store value and mask 667 elif ".store" in line: 668 stmt = line.split(".store")[-1] 669 elif "[" in line: 670 stmt = line.split("[")[-1].split("]")[0] 671 if "tl.make_block_ptr(" in line: 672 continue 673 674 if stmt is None: 675 continue 676 677 # indirect indexing involves a `tmp` variable 678 test_case.assertTrue( 679 "tmp" not in stmt, 680 msg=f"Found indirect indexing in statement '{stmt}' from code:\n{code}", 681 ) 682 if has_wrapping is not None: 683 test_case.assertTrue( 684 ("where" in code or "?" in code) is has_wrapping, 685 msg=f"Wanted {has_wrapping=} but got\n{code}", 686 ) 687 test_case.assertTrue( 688 any( 689 ("device_assert" in code or "TORCH_CHECK" in code) is has_assert 690 for code in source_codes 691 ) 692 ) 693 return result 694 695 696def assertGeneratedKernelCountEqual(self: TestCase, expected: int): 697 if config.triton.multi_kernel: 698 # when multi_kernel is enabled, we generated both persistent reduction 699 # and non-persistent reduction kernels for the same node schedule. 700 # That will mess up with the kernel count. Just don't check it. 
701 return 702 if config.cpp_wrapper: 703 expected *= 2 704 self.assertEqual(torch._inductor.metrics.generated_kernel_count, expected) 705 706 707class SweepInputs2: 708 input_gen_types1 = [ 709 "dense", 710 "transposed", 711 "strided", 712 "broadcast1", 713 "broadcast2", 714 "broadcast3", 715 "double", 716 "int", 717 ] 718 input_gen_types2 = input_gen_types1 719 gen = None 720 721 @staticmethod 722 def kernel(a, b): 723 return (a + b,) 724 725 @classmethod 726 def gen_template(cls, name1, name2): 727 def test(self): 728 check_model( 729 self, 730 cls.kernel, 731 ( 732 getattr(cls.gen, name1)(), 733 getattr(cls.gen, name2)(), 734 ), 735 ) 736 737 test.__name__ = f"test_{cls.gen.device}_{name1}_{name2}" 738 setattr(cls, test.__name__, test) 739 740 @classmethod 741 def populate(cls): 742 for name1 in cls.input_gen_types1: 743 for name2 in cls.input_gen_types2: 744 cls.gen_template(name1, name2) 745 746 747@instantiate_parametrized_tests 748class CommonTemplate: 749 def test_bool(self): 750 def fn(a, b): 751 return ( 752 a + b, 753 a * b, 754 a & b, 755 a | b, 756 a ^ b, 757 torch.logical_and(a, b), 758 torch.logical_or(a, b), 759 torch.logical_not(a), 760 torch.sign(b), 761 ) 762 763 self.common( 764 fn, 765 ( 766 torch.tensor([True, False, True, False]), 767 torch.tensor([False, False, True, True]), 768 ), 769 ) 770 771 @skipCUDAIf(not SM80OrLater, "Requires sm80") 772 def test_eager_aoti_support_out(self): 773 ns = "aten" 774 op_name = "clamp" 775 dispatch_key = "CPU" 776 device = "cpu" 777 if self.device.lower() == "cuda": 778 dispatch_key = "CUDA" 779 device = "cuda" 780 781 inp_tensor = torch.randn(128, dtype=torch.float, device=device).fill_(1.0) 782 min_tensor = inp_tensor - 0.05 783 max_tensor = inp_tensor + 0.05 784 with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: 785 ref_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( 786 -1 787 ) 788 ref_tensor = torch.clamp( 789 max=max_tensor, min=min_tensor, input=inp_tensor, out=ref_out_tensor 790 ) 791 792 ref_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( 793 -1 794 ) 795 ref_tensor1 = torch.clamp( 796 max=max_tensor, out=ref_out_tensor1, min=min_tensor, input=inp_tensor 797 ) 798 799 register_ops_with_aoti_compile( 800 ns, [op_name], dispatch_key, torch_compile_op_lib_impl 801 ) 802 803 res_out_tensor = torch.randn(128, dtype=torch.float, device=device).fill_( 804 -1 805 ) 806 res_tensor = torch.clamp( 807 max=max_tensor, min=min_tensor, input=inp_tensor, out=res_out_tensor 808 ) 809 810 self.assertEqual(ref_tensor, res_tensor) 811 self.assertEqual(ref_out_tensor, res_out_tensor) 812 813 res_out_tensor1 = torch.randn(128, dtype=torch.float, device=device).fill_( 814 -1 815 ) 816 res_tensor1 = torch.clamp( 817 max=max_tensor, out=res_out_tensor1, min=min_tensor, input=inp_tensor 818 ) 819 820 self.assertEqual(ref_tensor1, res_tensor1) 821 self.assertEqual(ref_out_tensor1, res_out_tensor1) 822 823 @skipCUDAIf(not SM80OrLater, "Requires sm80") 824 def test_eager_aoti_cache_hit(self): 825 ns = "aten" 826 op_name = "abs" 827 dispatch_key = "CPU" 828 device = "cpu" 829 if self.device.lower() == "cuda": 830 dispatch_key = "CUDA" 831 device = "cuda" 832 833 input_tensor = torch.randn(128, dtype=torch.float, device=device) 834 kernel_lib_path = aoti_compile_with_persistent_cache( 835 ns, 836 op_name, 837 device, 838 False, 839 getattr(torch.ops.aten, op_name), 840 (input_tensor,), 841 {}, 842 ) 843 self.assertTrue(Path(kernel_lib_path).exists()) 844 845 from unittest import mock 846 
847 # Patch the aoti_compile_with_persistent_cache as None to ensure no new kernel is generated 848 with mock.patch( 849 "torch._inductor.utils.aoti_compile_with_persistent_cache", None 850 ): 851 with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: 852 # Get ref result from eager 853 ref_value = getattr(torch.ops.aten, op_name)(input_tensor) 854 855 register_ops_with_aoti_compile( 856 ns, [op_name], dispatch_key, torch_compile_op_lib_impl 857 ) 858 859 # Invoke the pre-compiled kernel and get result. 860 res_value = getattr(torch.ops.aten, op_name)(input_tensor) 861 862 self.assertEqual(ref_value, res_value) 863 864 @skipCUDAIf(not SM80OrLater, "Requires sm80") 865 def test_eager_aoti_with_persistent_cache(self): 866 def fn(a): 867 return torch.abs(a) 868 869 ns = "aten" 870 op_name = "abs" 871 872 device = "cpu" 873 if self.device.lower() == "cuda": 874 device = "cuda" 875 876 input_tensor = torch.randn(128, dtype=torch.float, device=device) 877 kernel_lib_path = aoti_compile_with_persistent_cache( 878 ns, 879 op_name, 880 input_tensor.device.type, 881 False, 882 fn, 883 args=(input_tensor,), 884 kwargs={}, 885 ) 886 self.assertTrue(len(kernel_lib_path) > 0) 887 888 device_kernel_cache = aoti_eager_cache_dir(ns, device) 889 kernel_conf = device_kernel_cache / f"{op_name}.json" 890 self.assertTrue(kernel_conf.exists()) 891 892 json_data = load_aoti_eager_cache("aten", "abs", input_tensor.device.type) 893 self.assertTrue(json_data is not None) 894 self.assertTrue(isinstance(json_data, list)) 895 self.assertTrue(len(json_data) > 0) 896 897 op_info = json_data[0] 898 self.assertTrue(isinstance(op_info, dict)) 899 self.assertTrue("meta_info" in op_info) 900 self.assertTrue("kernel_path" in op_info) 901 kernel_libs_abs_path = [] 902 for item in json_data: 903 kernel_path = device_kernel_cache / item["kernel_path"] 904 kernel_libs_abs_path.append(kernel_path.as_posix()) 905 906 self.assertTrue(kernel_lib_path in kernel_libs_abs_path) 907 908 @skipCUDAIf(not SM80OrLater, "Requires sm80") 909 def test_eager_aoti_with_scalar(self): 910 namespace_name = "aten" 911 op_name = "add" 912 op_overload_name = "Tensor" 913 op_name_with_overload = f"{op_name}.{op_overload_name}" 914 915 dispatch_key = "CPU" 916 device = torch.device("cpu") 917 if self.device.lower() == "cuda": 918 dispatch_key = "CUDA" 919 device = torch.device("cuda") 920 921 # Test the difference between scalar tensor and scalar 922 a = torch.scalar_tensor(1.0, device=device) 923 b = torch.scalar_tensor(2.0, device=device) 924 925 kernel_lib_path = aoti_compile_with_persistent_cache( 926 namespace_name, 927 op_name_with_overload, 928 a.device.type, 929 False, 930 torch.ops.aten.add, 931 args=(a, b), 932 kwargs={"alpha": 3.0}, 933 ) 934 self.assertTrue(Path(kernel_lib_path).exists()) 935 device_kernel_cache = aoti_eager_cache_dir(namespace_name, device.type) 936 kernel_conf = device_kernel_cache / f"{op_name_with_overload}.json" 937 self.assertTrue(kernel_conf.exists()) 938 json_data = load_aoti_eager_cache( 939 namespace_name, op_name_with_overload, a.device.type 940 ) 941 op_info = json_data[0] 942 self.assertTrue(isinstance(op_info, dict)) 943 self.assertTrue("meta_info" in op_info) 944 self.assertTrue(len(op_info["meta_info"]) == 3) 945 self.assertTrue(op_info["meta_info"][0]["sizes"] == []) 946 self.assertTrue(op_info["meta_info"][0]["strides"] == []) 947 # Scalar Tensor 948 self.assertTrue("scalar_value" not in op_info["meta_info"][0]) 949 self.assertTrue(op_info["meta_info"][1]["sizes"] == []) 950 
self.assertTrue(op_info["meta_info"][1]["strides"] == []) 951 # Scalar Tensor 952 self.assertTrue("scalar_value" not in op_info["meta_info"][1]) 953 self.assertTrue(op_info["meta_info"][2]["sizes"] == []) 954 self.assertTrue(op_info["meta_info"][2]["strides"] == []) 955 # Scalar 956 self.assertTrue("scalar_value" in op_info["meta_info"][2]) 957 958 with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: 959 a = torch.randn(128, device=device) 960 b = torch.randn(128, device=device) 961 962 scalar_values = [1.0, 2.0, 3.0] 963 ref_values = [] 964 for scalar_value in scalar_values: 965 ref_values.append(torch.add(a, b, alpha=scalar_value)) 966 967 register_ops_with_aoti_compile( 968 namespace_name, [op_name], dispatch_key, torch_compile_op_lib_impl 969 ) 970 971 res_values = [] 972 for scalar_value in scalar_values: 973 res_values.append(torch.add(a, b, alpha=scalar_value)) 974 975 self.assertEqual(len(ref_values), len(res_values)) 976 self.assertEqual(ref_values, res_values) 977 978 @skipCUDAIf(not SM80OrLater, "Requires sm80") 979 def test_eager_aoti_override_registration(self): 980 namespace_name = "aten" 981 dispatch_key = "CPU" 982 device = torch.device("cpu") 983 if self.device.lower() == "cuda": 984 dispatch_key = "CUDA" 985 device = torch.device("cuda") 986 987 unary_op_set = ["abs", "acos"] 988 989 def fn(x, op_name=""): 990 return getattr(torch, op_name)(x) 991 992 # Invoke torch.compile directly to get referent results 993 x = torch.randn(3, 4, device=device) 994 995 ref_array = [] 996 for unary_op_name in unary_op_set: 997 opt_fn = torch.compile(functools.partial(fn, op_name=unary_op_name)) 998 ref = opt_fn(x) 999 ref_array.append(ref) 1000 1001 with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: 1002 register_ops_with_aoti_compile( 1003 namespace_name, unary_op_set, dispatch_key, torch_compile_op_lib_impl 1004 ) 1005 1006 res_array = [] 1007 for unary_op_name in unary_op_set: 1008 res_array.append(getattr(torch, unary_op_name)(x)) 1009 1010 for ref, res in zip(ref_array, res_array): 1011 self.assertEqual(ref, res) 1012 1013 a = torch.randn(128, device=device) 1014 min_tensor = torch.randn(128, device=device) 1015 max_tensor = min_tensor + 0.5 1016 1017 ref_with_min = torch.ops.aten.clamp(a, min_tensor) 1018 ref_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) 1019 1020 with _scoped_library("aten", "IMPL") as torch_compile_op_lib_impl: 1021 register_ops_with_aoti_compile( 1022 namespace_name, ["clamp"], dispatch_key, torch_compile_op_lib_impl 1023 ) 1024 res_with_min = torch.ops.aten.clamp(a, min_tensor) 1025 res_with_min_max = torch.ops.aten.clamp(a, min_tensor, max_tensor) 1026 self.assertEqual(ref_with_min, res_with_min) 1027 self.assertEqual(ref_with_min_max, res_with_min_max) 1028 1029 def test_add_const_int(self): 1030 def fn(a): 1031 return (a + 1, torch.add(a, 1, alpha=2)) 1032 1033 for dtype in [torch.float32, torch.int32, torch.int64]: 1034 self.common(fn, (torch.arange(32, dtype=dtype),)) 1035 1036 def test_add_const_float(self): 1037 def fn(a): 1038 return (a + 1.5,) 1039 1040 self.common(fn, (torch.randn(32),)) 1041 1042 def test_add_inplace_permuted(self): 1043 def fn(x, y): 1044 return x.add_(y) 1045 1046 x = torch.ones([2, 12, 13, 17]).transpose(1, 2) 1047 y = torch.randn([2, 13, 1, 17]) 1048 1049 self.common(fn, (x, y)) 1050 1051 def test_add_complex(self): 1052 def fn(a, b, alpha): 1053 return torch.add(a, b, alpha=alpha) 1054 1055 x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1]) 1056 y = torch.tensor([1 + 
1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1]) 1057 1058 self.common(fn, (x, y, 2)) 1059 1060 def test_add_complex3(self): 1061 # fix https://github.com/pytorch/pytorch/issues/115071 1062 @torch.compile 1063 def fn(*args): 1064 a = torch.neg(args[0]) 1065 b = torch.add(args[0], args[0]) 1066 return (a, b) 1067 1068 x = torch.randn(41, dtype=torch.complex64) 1069 y = x.clone() 1070 # should not inplace write to the input 1071 fn(x) 1072 self.assertEqual(x, y) 1073 1074 def test_add_complex4(self): 1075 @torch.compile 1076 def fn(a, b): 1077 c = a + b 1078 d = a + b 1079 return c + d 1080 1081 for dtype in [torch.complex32, torch.complex64, torch.complex128]: 1082 x = torch.tensor( 1083 [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], 1084 dtype=dtype, 1085 device=self.device, 1086 ) 1087 y = torch.tensor( 1088 [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], 1089 dtype=dtype, 1090 device=self.device, 1091 ) 1092 _, code = run_and_get_code(fn, x, y) 1093 self.assertEqual( 1094 " ".join(code).count( 1095 "view_dtype" if config.cpp_wrapper else "aten.view" 1096 ), 1097 3, 1098 ) 1099 1100 def test_concat_add_inplace(self): 1101 def fn(x, y, z): 1102 return torch.cat([x, y], dim=1).add_(z) 1103 1104 x = torch.randn([2, 12, 14, 14]) 1105 y = torch.randn([2, 12, 14, 14]) 1106 z = torch.randn([2, 24, 14, 14]) 1107 1108 self.common(fn, (x, y, z)) 1109 1110 def test_abs(self): 1111 def fn(a): 1112 return (a / (torch.abs(a) + 1),) 1113 1114 self.common(fn, (torch.randn(17),)) 1115 1116 def test_angle(self): 1117 def fn(a, b, c): 1118 return torch.angle(a), torch.angle(b), torch.angle(c) 1119 1120 complex_input = torch.tensor( 1121 [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1, float("nan")] 1122 ) 1123 real_input = torch.tensor([-1.0, 0.0, 1.0, float("nan")]) 1124 interger_real_input = torch.tensor([-1, 0, 1]) 1125 self.common(fn, (complex_input, real_input, interger_real_input)) 1126 1127 def test_sgn(self): 1128 def fn(a): 1129 return torch.sgn(a), torch.sgn(a + 1) - 1 1130 1131 self.common(fn, [torch.linspace(-10, 10, 41)]) 1132 1133 @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") 1134 def test_scatter_bf16(self): 1135 def fn(inp, src, index): 1136 return inp.scatter_add(0, index, src) 1137 1138 for dtype in [torch.int64, torch.bool, torch.bfloat16]: 1139 self.common( 1140 fn, 1141 [ 1142 torch.zeros(3, 5, dtype=dtype), 1143 torch.ones((2, 5), dtype=dtype), 1144 torch.tensor([[0, 1, 2, 0, 0]]), 1145 ], 1146 ) 1147 1148 def test_randn_generator(self): 1149 def fn(a, generator): 1150 return torch.randn([20, 20], generator=generator, device=a.device) 1151 1152 self.common(fn, (torch.linspace(-10, 10, 41), None), assert_equal=False) 1153 1154 # generator not yet supported in dynamo 1155 with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "Generator"): 1156 self.common(fn, (torch.linspace(-10, 10, 41), torch.Generator(self.device))) 1157 1158 def test_sgn_extremal(self): 1159 def fn(a): 1160 return (torch.sgn(a),) 1161 1162 self.common(fn, [torch.tensor([np.nan, np.inf, -np.inf, 0])]) 1163 1164 def test_max_min(self): 1165 def fn(a, b): 1166 return (torch.maximum(a, b), torch.minimum(a, b)) 1167 1168 self.common(fn, (torch.randn(8), torch.randn(8))) 1169 t1 = torch.randn(8) 1170 t1[0] = float("nan") 1171 t2 = torch.randn(8) 1172 t2[1] = float("nan") 1173 self.common(fn, (t1, t2)) 1174 1175 def test_neg_max_uint8(self): 1176 # https://github.com/pytorch/pytorch/issues/93380 1177 def fn(a, b): 1178 c = torch.neg(a) 1179 return torch.maximum(b, c) 1180 1181 a = 
torch.randint(256, (1,), dtype=torch.uint8) 1182 b = torch.randint(256, (8390,), dtype=torch.uint8) 1183 self.common(fn, (a, b)) 1184 1185 def test_compar(self): 1186 def fn(x): 1187 return x.gt(3.5), x.ge(3.5), x.eq(3.5), x.le(2.5), x.lt(3.5), x.ne(3.5) 1188 1189 a = torch.tensor([3]) 1190 self.common(fn, (a,)) 1191 1192 def test_horizonal_fusion1(self): 1193 def fn(a, b, c): 1194 return (a + b, a - c, b * c) 1195 1196 self.common( 1197 fn, (torch.randn(8, 16, 16), torch.randn(8, 16, 16), torch.randn(1, 16, 1)) 1198 ) 1199 1200 def test_horizonal_fusion2(self): 1201 def fn(a, b, c): 1202 return a + 1, b + 2, c + 3 1203 1204 self.common(fn, (torch.randn(8, 16, 8), torch.randn(8, 16), torch.randn(16, 8))) 1205 1206 def test_vertical_fusion1(self): 1207 def fn(sa, ct, p): 1208 # From torchbench.pyhpc_equation_of_state 1209 v17 = -3.087032500374211e-7 1210 v18 = -1.988366587925593e-8 1211 v19 = -1.061519070296458e-11 1212 v20 = 1.550932729220080e-10 1213 t15 = v19 * ct 1214 t19 = v17 + ct * (v18 + t15) + v20 * sa 1215 t20 = 1.0 / t19 1216 t128 = t19 * p 1217 return t20 + t128 1218 1219 self.common( 1220 fn, 1221 ( 1222 torch.randn(204, 204, 26), 1223 torch.randn(204, 204, 26), 1224 torch.randn(26), 1225 ), 1226 ) 1227 assertGeneratedKernelCountEqual(self, 1) 1228 1229 @config.patch({"fx_graph_cache": False}) 1230 def test_forced_buffer_realize(self): 1231 # Test torch._test_inductor_realize forces a buffer to be realized 1232 def fn(a): 1233 b = test_operators.realize(a * 2) 1234 return (b * 2,) 1235 1236 self.common(fn, (torch.randn(10),)) 1237 self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 2) 1238 1239 @config.patch({"fx_graph_cache": False}) 1240 def test_scheduler_vertical_fusion1(self): 1241 realize = test_operators.realize 1242 1243 def fn(sa, ct, p): 1244 # From torchbench.pyhpc_equation_of_state 1245 v17 = -3.087032500374211e-7 1246 v18 = -1.988366587925593e-8 1247 v19 = -1.061519070296458e-11 1248 v20 = 1.550932729220080e-10 1249 t15 = realize(v19 * ct) 1250 t19 = realize(v17 + ct * (v18 + t15) + v20 * sa) 1251 t20 = realize(1.0 / t19) 1252 t128 = realize(t19 * p) 1253 return t20 + t128 1254 1255 self.common( 1256 fn, 1257 ( 1258 torch.randn(204, 204, 26), 1259 torch.randn(204, 204, 26), 1260 torch.randn(26), 1261 ), 1262 ) 1263 self.assertEqual(torch._inductor.metrics.ir_nodes_pre_fusion, 5) 1264 assertGeneratedKernelCountEqual(self, 1 if self.device == GPU_TYPE else 2) 1265 1266 def test_index_propagation(self): 1267 def copy(x): 1268 i = torch.arange(x.size(0), device=x.device) 1269 return x[i] 1270 1271 x = torch.randn(8, device=self.device) 1272 copy_opt = torch._dynamo.optimize("inductor")(copy) 1273 1274 expect = copy(x) 1275 actual = _run_and_assert_no_indirect_indexing(self, copy_opt, x) 1276 self.assertEqual(expect, actual) 1277 1278 def test_index_propagation_flip(self): 1279 def flip(x): 1280 i = torch.arange(x.size(0) - 1, -1, -1, device=x.device) 1281 return x[i] 1282 1283 x = torch.randn(8, device=self.device) 1284 flip_opt = torch._dynamo.optimize("inductor")(flip) 1285 1286 expect = flip(x) 1287 actual = _run_and_assert_no_indirect_indexing(self, flip_opt, x) 1288 self.assertEqual(expect, actual) 1289 1290 def test_index_propagation_floordiv(self): 1291 def repeat_interleave(x, n): 1292 # e.g. 
x=[1, 2, 3], n=2 => returns [1, 1, 2, 2, 3, 3] 1293 i = torch.arange(x.shape[0] * n, device=x.device) 1294 return x[i // n] 1295 1296 x = torch.randn(8, 16, device=self.device) 1297 repeat_interleave_opt = torch._dynamo.optimize("inductor")(repeat_interleave) 1298 # With static shapes we can prove the bound, our dynamic shapes reasoning is not good enough 1299 has_assert = ifdynstaticdefault(False, True) 1300 # this should be collapsed to direct indexing 1301 actual = _run_and_assert_no_indirect_indexing( 1302 self, repeat_interleave_opt, x, 3, has_assert=has_assert 1303 ) 1304 expect = torch.repeat_interleave(x, 3, dim=0) 1305 self.assertEqual(expect, actual) 1306 self.assertEqual(actual, repeat_interleave(x, 3)) 1307 1308 def test_index_propagation_remainder(self): 1309 def repeat(x, n): 1310 # e.g. x=[1, 2, 3], n=2 => returns [1, 2, 3, 1, 2, 3] 1311 i = torch.arange(x.shape[0] * n, device=x.device) 1312 return x[i % x.shape[0]] 1313 1314 x = torch.randn(8, 16, device=self.device) 1315 repeat_opt = torch._dynamo.optimize("inductor")(repeat) 1316 1317 # With static shapes we can prove the bound, our dynamic shapes reasoning is not good enough 1318 has_assert = ifdynstaticdefault(False, True) 1319 # this should be collapsed to direct indexing 1320 actual = _run_and_assert_no_indirect_indexing( 1321 self, repeat_opt, x, 3, has_wrapping=False, has_assert=has_assert 1322 ) 1323 expect = x.repeat(3, 1) 1324 self.assertEqual(expect, actual) 1325 self.assertEqual(actual, repeat(x, 3)) 1326 1327 def test_index_propagation_abs(self): 1328 def reflection_pad_left(x, n): 1329 # e.g. x=[1, 2, 3], n=2 => returns [3, 2, 1, 2, 3] 1330 i = torch.arange(x.shape[0] + n, device=x.device) 1331 return x[(i - n).abs()] 1332 1333 x = torch.randn(8, device=self.device) 1334 opt_fn = torch._dynamo.optimize("inductor")(reflection_pad_left) 1335 1336 # With static shapes we can prove the bound, our dynamic shapes reasoning is not good enough 1337 has_assert = ifdynstaticdefault(False, True) 1338 # this should be collapsed to direct indexing 1339 actual = _run_and_assert_no_indirect_indexing( 1340 self, opt_fn, x, 3, has_wrapping=False, has_assert=has_assert 1341 ) 1342 expect = reflection_pad_left(x, 3) 1343 self.assertEqual(expect, actual) 1344 1345 def test_index_propagation_device_assert_masked(self): 1346 def fn(a): 1347 idx = torch.arange(a.size(0), device=a.device) 1348 padded_idx = torch.constant_pad_nd(idx, (1050, 0)) 1349 padded_idx = torch.where(padded_idx >= 0, padded_idx, padded_idx) 1350 return a[padded_idx] 1351 1352 self.common(fn, (torch.randn(1024),)) 1353 1354 @skipIfRocm 1355 @config.patch(debug_index_asserts=False) 1356 def test_neg_index(self): 1357 def test( 1358 fn, inps, has_assert: bool, has_wrapping: bool, vectorize: bool = True 1359 ): 1360 fn_opt = torch.compile(fn) 1361 if self.device == "cpu": 1362 _, code = run_and_get_cpp_code(fn_opt, *inps) 1363 self.assertTrue(("?" 
in code or "blendv" in code) is has_wrapping) 1364 self.assertTrue(("TORCH_CHECK" in code) is has_assert) 1365 # Assert that we always vectorize the kernel regardless of wrapping / checks 1366 self.assertTrue(("loadu" in code) is vectorize) 1367 else: 1368 code = run_and_get_triton_code(fn_opt, *inps) 1369 self.assertTrue(("tl.where" in code) is has_wrapping) 1370 self.assertTrue(("device_assert" in code) is has_assert) 1371 1372 def indirect(a, b): 1373 return a[b - 1] 1374 1375 a = torch.rand(1024, device=self.device) 1376 b = torch.zeros(256, dtype=torch.long, device=self.device) 1377 test(indirect, (a, b), has_assert=True, has_wrapping=True) 1378 1379 def direct(x): 1380 return x[:, -1] 1381 1382 a = torch.rand(1, 64, 32, device=self.device) 1383 # Does not even generate a kernel as it's a view 1384 test(direct, (a,), has_assert=False, has_wrapping=False, vectorize=False) 1385 1386 def flip(a, b): 1387 return a[b] 1388 1389 a = torch.rand(1024, device=self.device) 1390 b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device=self.device) 1391 test(flip, (a, b), has_assert=True, has_wrapping=True) 1392 1393 # Constant propagate a constant that's negative 1394 def flip_with_index_constant(a): 1395 b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device=a.device) 1396 return a[b] 1397 1398 # Wrapping is constant-folded 1399 test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False) 1400 1401 # Operation where we can't prove that the index is always positive or negative 1402 def pos_and_neg(a): 1403 b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device=a.device) 1404 return a[b] 1405 1406 # It has wrapping but no assert 1407 test(pos_and_neg, (a,), has_assert=False, has_wrapping=True) 1408 1409 # We currently don't do constant propagation with float constants 1410 # We cannot prove this kind of asserts just with bounds. We would need 1411 # to lift IndexPropagation.shape_env to be accessible in all of Inductor 1412 def flip_with_index(a): 1413 b = 1.0 * torch.arange( 1414 start=-1, end=-a.numel() - 1, step=-1, device=a.device 1415 ) 1416 b = b.int() 1417 return a[b] 1418 1419 test( 1420 flip_with_index, 1421 (a,), 1422 has_assert=ifdynstaticdefault(False, True), 1423 has_wrapping=False, 1424 vectorize=False, # Constant propagation off -> indirect indexing -> no vec 1425 ) 1426 1427 def unsafe_index(a, b): 1428 return aten._unsafe_index(a, (b,)) 1429 1430 test(unsafe_index, (a, b), has_assert=False, has_wrapping=True) 1431 1432 def constant_propagation(a): 1433 b = torch.tensor([2], device=a.device) 1434 return a[b] 1435 1436 test( 1437 constant_propagation, 1438 (a,), 1439 has_assert=ifdynstaticdefault(False, True), 1440 has_wrapping=False, 1441 vectorize=False, # There's no loop to vectorize! 1442 ) 1443 1444 def constant_propagation_neg(a): 1445 b = torch.tensor([-2], device=a.device) 1446 return a[b] 1447 1448 # In symbolic shapes, we know that we can access -2, so no assert is necessary! 1449 test( 1450 constant_propagation_neg, 1451 (a,), 1452 has_assert=False, 1453 has_wrapping=False, 1454 vectorize=False, # There's no loop to vectorize! 
1455 ) 1456 1457 def test_computed_buffer_inlining(self): 1458 def flip(x): 1459 idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device) 1460 return x[idx], idx 1461 1462 flip_opt = torch._dynamo.optimize("inductor")(flip) 1463 x = torch.randn(8, device=self.device) 1464 1465 expect = flip(x) 1466 actual = _run_and_assert_no_indirect_indexing(self, flip_opt, x) 1467 self.assertEqual(expect, actual) 1468 1469 def test_sum1(self): 1470 def fn(a, b): 1471 return ((a + b).sum(-1),) 1472 1473 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 1474 1475 def test_sum2(self): 1476 def fn(a, b): 1477 return ((a + b).sum([1, 2]), (a + b).sum(-1)) 1478 1479 self.common(fn, (torch.randn(8, 9, 3, 21), torch.randn(8, 9, 3, 21))) 1480 1481 def test_sum3(self): 1482 def fn(a, b): 1483 r1 = a + b 1484 r2 = r1.sum(-1) 1485 r3 = torch.squeeze(b) + 10 1486 return (r1, r2, r3) 1487 1488 # Mismatched elements: 2 / 10 (20.0%) 1489 # Greatest absolute difference: 0.0029296875 at index (8,) (up to 1e-05 allowed) 1490 # Greatest relative difference: 0.0017482517482517483 at index (6,) (up to 0.001 allowed) 1491 self.common(fn, (torch.randn(10, 10), torch.randn(1, 10)), atol=1e-5, rtol=2e-3) 1492 1493 def test_sum4(self): 1494 def fn(a): 1495 b = a + 1 1496 c = b.sum(-1) 1497 d = c + 3 1498 e = d.sum(-1) 1499 f = e + 5 1500 return (f, e, d, c, b) 1501 1502 self.common(fn, (torch.randn(1, 16, 8, 8),)) 1503 1504 def test_sum5(self): 1505 def fn(a): 1506 b = a + 1 1507 c = b.sum(-1) 1508 d = c + 3 1509 e = d.sum(-1) 1510 f = e + 5 1511 return (f,) 1512 1513 self.common(fn, (torch.randn(1, 17, 8, 9),)) 1514 1515 def test_reduction1(self): 1516 def fn(a): 1517 return (a.sum(), a.max(), a.min(), a.argmax(), a.argmin()) 1518 1519 self.common(fn, (torch.tensor([float("-inf"), 0.0, float("inf")]),)) 1520 1521 @skip_if_x86_mac() 1522 def test_reduction2(self): 1523 def fn(a): 1524 # FIXME: a.argmax 1525 return (a.sum(), a.max(), a.min(), a.argmin()) 1526 1527 self.common(fn, (torch.full((4,), float("inf")),)) 1528 1529 @skip_if_x86_mac() 1530 def test_reduction3(self): 1531 def fn(a): 1532 # FIXME: a.argmin 1533 return (a.sum(), a.max(), a.min(), a.argmax()) 1534 1535 self.common(fn, (torch.full((4,), float("-inf")),)) 1536 1537 def test_reduction4(self): 1538 if self.device == "cpu": 1539 raise unittest.SkipTest("Non-deterministic CPU results") 1540 1541 def fn(a): 1542 return (a.argmax(-1), a.argmin(-1)) 1543 1544 inputs = (torch.ones(128), torch.ones(4, 4, 1)) 1545 for i in inputs: 1546 self.common(fn, (i,)) 1547 1548 @config.patch(unroll_reductions_threshold=1) 1549 def test_reduction5(self): 1550 if self.device == "cpu": 1551 raise unittest.SkipTest("Non-deterministic CPU results") 1552 1553 def fn(a): 1554 return (a.sum(), a.max(), a.min(), a.argmax()) 1555 1556 self.common(fn, (torch.full((4,), float("-inf")),)) 1557 1558 def test_prod(self): 1559 def fn(a): 1560 return a.prod(0), a.prod(1), a.prod() 1561 1562 self.common(fn, (torch.rand((10, 10)),)) 1563 self.common(fn, (torch.rand((1, 2050)),)) 1564 1565 def test_unroll_small_reduction(self): 1566 def fn(x): 1567 val1, index1 = x.min(-1) 1568 val2, index2 = x.max(-1) 1569 return ( 1570 val1, 1571 index1, 1572 val2, 1573 index2, 1574 x.sum(-1), 1575 (x > 1).any(-1), 1576 (x > 0).all(-1), 1577 x.argmin(-1), 1578 x.argmax(-1), 1579 x.amin(-1), 1580 x.amax(-1), 1581 x.aminmax(), 1582 ) 1583 1584 with config.patch(unroll_reductions_threshold=8): 1585 # small sized reductions will get unrolled 1586 self.common(fn, (torch.randn(8, 3),)) 1587 torch._dynamo.reset() 
1588 with config.patch(unroll_reductions_threshold=1): 1589 # make sure things also work if they aren't unrolled 1590 self.common(fn, (torch.randn(8, 3),)) 1591 1592 def test_multilayer_sum_low_prec(self): 1593 # fp16 nyi for cpu 1594 if self.device == "cpu": 1595 raise unittest.SkipTest(f"requires {GPU_TYPE}") 1596 1597 def fn(a): 1598 return torch.mean(a) 1599 1600 self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),))) 1601 1602 def test_multilayer_prime_size(self): 1603 def fn(a): 1604 return torch.max(a), torch.sum(a) 1605 1606 # Requires masked loading for the intermediate reduction 1607 sample = torch.full((3999971,), 0, dtype=torch.int64) 1608 sample[-1] = 1 1609 self.common(fn, (sample,)) 1610 1611 @skipCPUIf(IS_MACOS, "fails on macos") 1612 def test_multilayer_var(self): 1613 def fn(a): 1614 return torch.var(a) 1615 1616 self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float32),))) 1617 self.common(fn, ((torch.rand((14923), dtype=torch.float32),))) 1618 1619 @skipCPUIf(IS_MACOS, "fails on macos") 1620 def test_multilayer_var_lowp(self): 1621 def fn(a): 1622 return torch.var(a) 1623 1624 self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),)) 1625 self.common(fn, (torch.rand((14923), dtype=torch.float16),)) 1626 1627 def test_split_cumsum(self): 1628 def fn(a): 1629 return torch.cumsum(a, -1) 1630 1631 for dtype in get_all_dtypes( 1632 include_bfloat16=False, 1633 include_bool=True, 1634 include_complex=False, 1635 include_half=False, 1636 ): 1637 # Use low=0 since when the mean value is 0, cumsum at all points 1638 # tends towards zero which makes the relative error term blow up 1639 inp = make_tensor(10, 3, 352, 352, low=0, dtype=dtype, device=self.device) 1640 self.common(fn, (inp.view(-1),), rtol=1e-5, atol=1e-5, check_lowp=False) 1641 self.common(fn, (inp.view(10, -1),), rtol=1e-5, atol=1e-5, check_lowp=False) 1642 1643 @skipCUDAIf(not SM80OrLater, "Requires sm80") 1644 @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm") 1645 def test_split_cumsum_low_prec(self): 1646 if self.device == "cpu": 1647 raise unittest.SkipTest("ir.Scan nyi on CPU") 1648 1649 def fn(a): 1650 return torch.cumsum(a.view(-1), 0) 1651 1652 self.common( 1653 fn, 1654 (torch.rand((10, 3, 352, 352), dtype=torch.float16),), 1655 reference_in_float=True, 1656 check_lowp=False, 1657 ) 1658 1659 def test_consecutive_split_cumsum(self): 1660 def fn(a, b): 1661 a = a.view(-1) 1662 b = b.view(-1) 1663 return torch.cumsum(a, 0) + torch.cumsum(b, 0) 1664 1665 a = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float32, device=self.device) 1666 b = make_tensor(10, 3, 352, 352, low=0, dtype=torch.float64, device=self.device) 1667 self.common(fn, (a, b), rtol=1e-5, atol=1e-5, check_lowp=False) 1668 1669 def test_split_cumprod(self): 1670 def fn(a): 1671 return torch.cumprod(a, -1) 1672 1673 for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: 1674 inp = _large_cumprod_input( 1675 (10, 10000), dim=1, dtype=dtype, device=self.device 1676 ) 1677 self.common(fn, (inp,), atol=1e-5, rtol=1e-4, check_lowp=False) 1678 1679 @skipCUDAIf(not SM80OrLater, "Requires sm80") 1680 @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm") 1681 def test_split_cumprod_low_prec(self): 1682 if self.device == "cpu": 1683 raise unittest.SkipTest("ir.Scan nyi on CPU") 1684 1685 def fn(a): 1686 return torch.cumprod(a.view(-1), 0) 1687 1688 for dtype in [torch.float16, torch.bfloat16]: 1689 inp = _large_cumprod_input( 1690 (10, 10000), dim=1, dtype=dtype, 
device=self.device 1691 ) 1692 self.common( 1693 fn, 1694 (inp,), 1695 reference_in_float=True, 1696 check_lowp=False, 1697 ) 1698 1699 def test_consecutive_split_cumprod(self): 1700 def fn(a, b): 1701 return torch.cumprod(a, 0) + torch.cumprod(b, 0) 1702 1703 a = _large_cumprod_input( 1704 (10000,), dim=0, dtype=torch.float32, device=self.device 1705 ) 1706 b = _large_cumprod_input( 1707 (10000,), dim=0, dtype=torch.float64, device=self.device 1708 ) 1709 self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False) 1710 1711 @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") 1712 def test_custom_scan_op(self): 1713 if self.device != "cuda": 1714 raise unittest.SkipTest("associative_scan only supported on GPU") 1715 1716 def sum_combine(a, b): 1717 return a + b 1718 1719 from torch._higher_order_ops.associative_scan import associative_scan 1720 1721 a = torch.randn(100, 100, device=self.device) 1722 expect = torch.cumsum(a, 0) 1723 actual = associative_scan(sum_combine, a, 0) 1724 self.assertEqual(expect, actual) 1725 1726 def logcumsum_combine(a, b): 1727 min_v = torch.minimum(a, b) 1728 max_v = torch.maximum(a, b) 1729 mask = (min_v != max_v) | ~min_v.isinf() 1730 return torch.where(mask, max_v + (min_v - max_v).exp().log1p(), a) 1731 1732 expect = torch.logcumsumexp(a, 0) 1733 actual = associative_scan(logcumsum_combine, a, 0) 1734 self.assertEqual(expect, actual) 1735 1736 def test_custom_scan_op_compiled(self): 1737 if self.device != "cuda": 1738 raise unittest.SkipTest("associative_scan only supported on GPU") 1739 1740 from torch._higher_order_ops.associative_scan import associative_scan 1741 1742 def sum_combine(a, b): 1743 return a + b 1744 1745 def fn(a, b, dim): 1746 diff = (a - b).abs() 1747 sad = associative_scan(sum_combine, diff, dim) 1748 return sad.sum(dim) 1749 1750 a = torch.randn(100, 100, device=self.device) 1751 b = torch.randn(100, 100, device=self.device) 1752 self.common(fn, (a, b, 0)) 1753 cfn = torch.compile(fn) 1754 _, code = run_and_get_code(cfn, a, b, 0) 1755 1756 # Check everything is fused into a single kernel 1757 FileCheck().check_not("run(").check_regex( 1758 r"triton_.*\.run\(arg[01]_1, arg[12]_1, buf1," 1759 ).check_not("run(").run(code[0]) 1760 1761 @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") 1762 def test_custom_scan_op_multi_input(self): 1763 if self.device != "cuda": 1764 raise unittest.SkipTest("associative_scan only supported on GPU") 1765 1766 def argmax_combine(a, b): 1767 a_value, a_index = a 1768 b_value, b_index = b 1769 mask = (a_value > b_value) | ((a_value == b_value) & (a_index > b_index)) 1770 return ( 1771 torch.where(mask, a_value, b_value), 1772 torch.where(mask, a_index, b_index), 1773 ) 1774 1775 from torch._higher_order_ops.associative_scan import associative_scan 1776 1777 a = torch.randn(100, 100, device=self.device) 1778 expect = torch.cummax(a, 0) 1779 1780 idx = torch.arange(100, device=self.device).view(100, 1).expand(100, 100) 1781 actual = associative_scan(argmax_combine, (a, idx), 0) 1782 self.assertEqual(expect, actual) 1783 1784 def test_embedding_bag_byte_unpack(self): 1785 if self.device != "cpu": 1786 raise unittest.SkipTest(f"No {GPU_TYPE} implementation (it returns empty)") 1787 1788 def fn(a): 1789 return torch.ops.quantized.embedding_bag_byte_unpack(a) 1790 1791 M, N = 32, 64 1792 scales = torch.randn(M, 1).view(torch.uint8) 1793 offsets = torch.randn(M, 1).view(torch.uint8) 1794 data = torch.randint(0, 255, (M, N), dtype=torch.uint8) 1795 packed = 
torch.cat([data, scales, offsets], dim=-1) 1796 self.common(fn, [packed]) 1797 1798 def test_expanded_reduction(self): 1799 def fn(x, y): 1800 z = x * y 1801 return z.sum((0, 1)) 1802 1803 atol = None 1804 rtol = None 1805 1806 # By default, inductor generate non-persistent reduction kernels in this 1807 # case. But when multi-kernel is enabled, inductor will pick the faster 1808 # of persistent reduction and non-persistent-reduction kernel. 1809 # In this case, inductor picked the persistent-reduction kernel. 1810 # The persistent reduction kernel happens to need looser tolerance. 1811 if config.triton.multi_kernel: 1812 atol = 1e-5 1813 rtol = 1e-5 1814 self.common( 1815 fn, (torch.randn(2, 197, 256), torch.randn(2, 1, 256)), atol=atol, rtol=rtol 1816 ) 1817 1818 def test_min_max_reduction(self): 1819 def fn(a, b): 1820 return ( 1821 (a + b).max(), 1822 (a + b).min(), 1823 torch.amax(a + 1, keepdim=True), 1824 torch.amin(b + 1, keepdim=True), 1825 ) 1826 1827 dtypes = [torch.float, torch.float16] 1828 if not (self.device == "cuda" and not SM80OrLater): 1829 dtypes += [torch.bfloat16] 1830 for dtype in dtypes: 1831 self.common(fn, (torch.randn(8, 8).to(dtype), torch.randn(8, 8).to(dtype))) 1832 1833 def test_min_max_reduction_nan(self): 1834 def fn(a): 1835 return (torch.max(a), torch.min(a)) 1836 1837 t1 = torch.randn(32) 1838 t1[16] = float("nan") 1839 self.common(fn, (t1,)) 1840 1841 def test_fmin_fmax(self): 1842 def fn(a, b): 1843 return ( 1844 torch.fmin(a, b), 1845 torch.fmax(a, b), 1846 torch.fmax(a + 1, torch.tensor(0.0)), 1847 ) 1848 1849 self.common( 1850 fn, 1851 ( 1852 torch.tensor( 1853 [-10.0, 10.0, float("nan"), float("nan"), float("nan"), 3, 4] 1854 ), 1855 torch.tensor( 1856 [float("nan"), float("nan"), -10.0, 10.0, float("nan"), 4, 3] 1857 ), 1858 ), 1859 ) 1860 1861 def test_sum_int(self): 1862 def fn(x): 1863 return 2 * x.sum(-1) + x.sum() 1864 1865 dtypes = torch.bool, torch.uint8, torch.int 1866 inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes] 1867 for i in inps: 1868 self.common(fn, (i,), check_lowp=False) 1869 1870 def test_sum_dtype(self): 1871 def fn(x): 1872 return x * x.sum(-1, dtype=torch.double) + x.sum(dtype=torch.double) 1873 1874 self.common(fn, (torch.ones(32, 32) * 70,)) 1875 1876 def test_cumsum(self): 1877 def fn(x): 1878 return x.cumsum(0), x.cumsum(1) 1879 1880 # Persistent reductions 1881 self.common(fn, (torch.rand(16, 32),), check_lowp=True) 1882 self.common(fn, (torch.rand(20, 30),), check_lowp=True) 1883 1884 # Non-persistent reduction 1885 self.common(fn, (torch.rand(100, 4000),), check_lowp=True) 1886 1887 def test_cumsum_zero_dim(self): 1888 def fn(x): 1889 return x.cumsum(0), x.cumsum(-1) 1890 1891 a = torch.rand(()) 1892 self.common(fn, (a,)) 1893 1894 def test_cumsum_no_mask(self): 1895 def fn(x): 1896 return x.cumsum(-1) 1897 1898 # Persistent reduction 1899 a = torch.rand((1, 1024)) 1900 self.common(fn, (a,), check_lowp=not TEST_WITH_ROCM) 1901 1902 # Non-persistent reduction 1903 b = torch.rand((1, 8192)) 1904 self.common(fn, (b,), check_lowp=not TEST_WITH_ROCM) 1905 1906 def test_cumprod_zero_dim(self): 1907 def fn(x): 1908 return x.cumprod(0), x.cumprod(-1) 1909 1910 a = torch.rand(()) 1911 self.common(fn, (a,)) 1912 1913 def test_logcumsumexp(self): 1914 def fn(x): 1915 return x.logcumsumexp(0), x.logcumsumexp(1) 1916 1917 # Persistent reductions 1918 self.common(fn, (torch.rand(16, 32),), check_lowp=not TEST_WITH_ROCM) 1919 self.common(fn, (torch.rand(20, 30),), check_lowp=not TEST_WITH_ROCM) 1920 1921 # 
Non-persistent reduction 1922 self.common(fn, (torch.rand(100, 4000),), check_lowp=not TEST_WITH_ROCM) 1923 1924 def test_logcumsumexp_zero_dim(self): 1925 def fn(x): 1926 return x.logcumsumexp(0), x.logcumsumexp(-1) 1927 1928 a = torch.rand(()) 1929 self.common(fn, (a,)) 1930 1931 def test_clamp(self): 1932 def fn(a, b): 1933 return (a.clamp(-0.1, 0.1), b.clamp(0), torch.clamp(a + b, max=0)) 1934 1935 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 1936 1937 def test_clamp_type_promotion(self): 1938 def fn(a): 1939 b = torch.tensor(1.0, dtype=torch.double, device=self.device) 1940 c = torch.full((4,), 2, device=self.device) 1941 return a.clamp(min=b, max=c) 1942 1943 self.common(fn, (torch.randint(4, (4,)),)) 1944 1945 def test_dist(self): 1946 def fn(a, b): 1947 return ( 1948 torch.dist(a, b), 1949 torch.dist(a, b, p=1.2), 1950 ) 1951 1952 self.common(fn, (torch.randn(4, 4), torch.randn(4, 4))) 1953 1954 @skipCUDAIf(not SM80OrLater, "Requires sm80") 1955 def test_dist_bf16(self): 1956 def fn(a, b): 1957 return torch.dist(a.to(torch.bfloat16), b.to(torch.bfloat16)) 1958 1959 self.common(fn, (torch.randn(4, 4), torch.randn(4, 4))) 1960 1961 def test_arange1(self): 1962 def fn(x): 1963 rng1 = torch.arange(8 * 8, dtype=torch.float32, device=x.device).view(8, 8) 1964 rng2 = torch.arange(10, 18, device=x.device) 1965 tmp = x * rng1 1966 return tmp, tmp + rng2 1967 1968 self.common(fn, (torch.randn(8, 8),)) 1969 1970 def test_arange2(self): 1971 def fn(x): 1972 rng1 = torch.arange(8, device=x.device) 1973 return (x + rng1,) 1974 1975 self.common(fn, (torch.randint(4, (8, 8)),), check_lowp=False) 1976 1977 def test_arange3(self): 1978 def fn(x): 1979 return x + torch.ops.aten.arange.start_step( 1980 0, 53, 4, dtype=torch.int64, device=x.device 1981 ) 1982 1983 self.common(fn, (torch.randn(14),)) 1984 1985 def test_arange4(self): 1986 def fn(x): 1987 return x - torch.arange(512, -512, -1.0, device=x.device) 1988 1989 self.common(fn, (torch.randn(1024),)) 1990 1991 def test_arange5(self): 1992 def fn(step, device): 1993 return torch.arange(512, -512, step, device=device) 1994 1995 compiled_fn = torch._dynamo.optimize()(fn) 1996 1997 # NOTE: use assertEqual to check dtypes which self.common doesn't do 1998 for step in (-1, -1.0): 1999 expect = fn(step, self.device) 2000 actual = compiled_fn(step, self.device) 2001 self.assertEqual(expect, actual) 2002 self.assertEqual(expect, actual) 2003 2004 def test_arange6(self): 2005 def fn(x): 2006 return torch.arange(0.1, 8.0001, 1, dtype=x.dtype, device=x.device) 2007 2008 # Test that float arguments are truncated to int when dtype is set explicitly 2009 make_arg = functools.partial( 2010 make_tensor, device=self.device, requires_grad=False 2011 ) 2012 self.common(fn, (make_arg(1, dtype=torch.float32),)) 2013 self.common(fn, (make_arg(1, dtype=torch.int64),)) 2014 2015 def test_linspace1(self): 2016 def fn(x): 2017 return torch.linspace(0.125, 0.875, 7, device=x.device) + x 2018 2019 self.common(fn, (torch.randn(1, 7),)) 2020 2021 def test_linspace2(self): 2022 def fn(x): 2023 return torch.linspace(0, 2, 1, device=x.device) + x 2024 2025 self.common(fn, (torch.randn(1, 1),)) 2026 2027 def test_linspace3(self): 2028 def fn(x): 2029 return torch.linspace(0, 2, 0, device=x.device) 2030 2031 self.common(fn, (torch.Tensor([]),)) 2032 2033 def test_tensor1(self): 2034 def fn(x): 2035 return torch.tensor([1], device=x.device) + x, torch.tensor( 2036 5, device=x.device 2037 ) 2038 2039 self.common(fn, (torch.randn(10),)) 2040 2041 def test_tensor2(self): 
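        # Illustrative sketch (added commentary, not part of the original suite):
        # the arange/linspace/tensor cases above compare values through
        # self.common, and test_arange5 notes that dtype checks need assertEqual
        # because self.common does not verify dtypes. A standalone way to check
        # both values and dtype of a compiled constant-producing function is
        # sketched below; the helper name `_check_arange_dtype` is ours and is
        # not defined or used elsewhere in this file.
        def _check_arange_dtype(device):
            def f(step):
                # an integer step yields an integer result, a float step a float one
                return torch.arange(512, -512, step, device=device)

            compiled = torch.compile(f)
            for step in (-1, -1.0):
                # assert_close validates dtype as well as values by default
                torch.testing.assert_close(compiled(step), f(step))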
2042 def fn(x): 2043 return torch.tensor(list(range(2, 40, 2)), device=x.device) + x 2044 2045 self.common(fn, (torch.randn(1),)) 2046 2047 def test_tensor3(self): 2048 def fn(x): 2049 return ( 2050 torch.tensor([], device=x.device), 2051 torch.tensor([1, 2], device=x.device) + 1, 2052 torch.tensor([1, 2, 3], device=x.device) + 2, 2053 torch.tensor([1, 2, 3, 4], device=x.device) + x, 2054 ) 2055 2056 self.common(fn, [torch.randn(4)]) 2057 2058 def test_views1(self): 2059 def fn1(x, y): 2060 return (x.view(size2) + y,) 2061 2062 def fn2(x, y): 2063 return ((x + 1).view(size2) + y,) 2064 2065 views = [ 2066 ([5 * 7], [5, 7]), 2067 ([2 * 3 * 4 * 5 * 6 * 7], [2, 3, 4, 5, 6, 7]), 2068 ([2 * 3, 4, 5, 6 * 7], [2, 3, 4, 5, 6, 7]), 2069 ([10 * 5, 20], [10, 5, 20]), 2070 ([1, 10, 1], [10]), 2071 ([10, 1, 10, 1, 10], [10, 100]), 2072 ([2, 2, 2, 2], [4, 4]), 2073 ] 2074 for size1, size2 in views: 2075 self.common(fn1, (torch.randn(size1), torch.randn(size2))) 2076 self.common(fn2, (torch.randn(size1), torch.randn(size2))) 2077 2078 for size2, size1 in views: 2079 self.common(fn1, (torch.randn(size1), torch.randn(size2))) 2080 self.common(fn2, (torch.randn(size1), torch.randn(size2))) 2081 2082 def test_views2(self): 2083 def fn1(x): 2084 return (x.view(size2) + 1,) 2085 2086 def fn2(x): 2087 return ((x * 2).view(size2) + 1,) 2088 2089 for size1, size2 in [ 2090 ([2, 2, 2, 2], [4, -1]), 2091 ([10, 1, 10, 1, 10], [-1, 100]), 2092 ([10 * 5, 20], [10, -1, 20]), 2093 ]: 2094 self.common(fn1, (torch.randn(size1),)) 2095 self.common(fn2, (torch.randn(size1),)) 2096 2097 def test_views3(self): 2098 # example taken from hf_BigBird 2099 def forward(arg1, arg2): 2100 index = torch.ops.aten.index(arg1, [arg2]) 2101 view_1 = torch.ops.aten.view(index, [1, 2232, 64]) 2102 view_2 = torch.ops.aten.view(view_1, [1, 12, 62, 192]) 2103 return view_2 2104 2105 self.common( 2106 forward, 2107 ( 2108 rand_strided((64, 64), (64, 1), torch.float32), 2109 rand_strided((2232,), (1,), torch.int64), 2110 ), 2111 ) 2112 2113 def test_views4(self): 2114 # example taken from hf_BigBird 2115 def forward(arg1, arg2): 2116 arg1 = arg1.index_select(0, arg2) 2117 arg1 = torch.ops.aten.view(arg1, [2, 3, 4, 5, 5]) 2118 arg1 = torch.ops.aten.view(arg1, [2, 3, 2, 10, -1]) 2119 return arg1 2120 2121 self.common( 2122 forward, 2123 ( 2124 torch.randn(12, 5, 5), 2125 torch.randint(0, 11, (24,)), 2126 ), 2127 ) 2128 2129 def test_views5(self): 2130 # tensor with shape 0 in any dimension 2131 def forward(x): 2132 y = x[:, 4:] 2133 return y.view(len(y), -1, 4) 2134 2135 self.common( 2136 forward, 2137 (torch.randn(4, 4, 4, 4),), 2138 ) 2139 2140 def test_views6(self): 2141 def forward(x): 2142 x = torch.ops.aten.relu(x) 2143 s = torch.ops.aten.slice(x, 0, 0, 9223372036854775807) 2144 s = torch.ops.aten.slice(s, 1, 0, 9223372036854775807) 2145 s = torch.ops.aten.slice(s, 3, 0, 0) 2146 y = torch.ops.aten.view(s, [4, 2, -1]) 2147 return y 2148 2149 self.common( 2150 forward, 2151 (torch.randn(4, 2, 4, 4),), 2152 ) 2153 2154 def test_views7(self): 2155 # x.view(dtype) 2156 def forward(x, y): 2157 x = (x + 1).to(torch.float32) 2158 y = (y + 1).to(torch.int32) 2159 return x.view(torch.int32), y.view(torch.float32) 2160 2161 self.common( 2162 forward, 2163 ( 2164 torch.rand(2, 3, dtype=torch.float32), 2165 torch.randint(10, (2, 3), dtype=torch.int32), 2166 ), 2167 ) 2168 2169 def test_relu(self): 2170 def fn(a, b): 2171 return (torch.relu(a), torch.relu(a + b) / 10) 2172 2173 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2174 2175 def 
test_exp(self): 2176 def fn(a, b): 2177 return (torch.exp(a), torch.exp(a + b)) 2178 2179 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2180 2181 def test_exp2(self): 2182 def fn(a, b): 2183 return (torch.exp2(a), torch.exp2(a + b), torch.pow(2, -torch.abs(a - b))) 2184 2185 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2186 2187 def test_sigmoid(self): 2188 def fn(a, b): 2189 return (torch.sigmoid(a), torch.sigmoid(a + b)) 2190 2191 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2192 2193 def test_round(self): 2194 def fn(a, b): 2195 return torch.round(a), torch.round(b + 1), torch.round(a, decimals=2) 2196 2197 # without manual_seed, there is some chance this test fails due to: 2198 # https://github.com/openai/triton/issues/530 2199 torch.manual_seed(0) 2200 2201 # with *100 we are always getting a number exactly at .5 which we don't do right in half 2202 self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 10)) 2203 2204 def test_round_correctness(self): 2205 if self.device == "cuda": 2206 raise unittest.SkipTest("need to debug tl.libdevice on A100/V100") 2207 2208 def fn(a): 2209 return torch.round(a) 2210 2211 self.common( 2212 fn, 2213 [torch.arange(-10, 10, 0.1, dtype=torch.float64)], 2214 check_lowp=False, 2215 ) 2216 2217 def test_builtins_round(self): 2218 def fn(x, i): 2219 return x[: round(i / 2 + 1)] + round(i / 2) 2220 2221 cfn = torch.compile(fullgraph=True, dynamic=True)(fn) 2222 2223 x = torch.zeros(5, dtype=torch.int, device=self.device) 2224 with torch.no_grad(): 2225 for i in range(1, 6): 2226 self.assertEqual(cfn(x, i), fn(x, i)) 2227 2228 def test_builtins_round_float_ndigits_pos(self): 2229 def fn(x, i): 2230 return x + round(i / 2 * 123.4567, 1) 2231 2232 cfn = torch.compile(fullgraph=True, dynamic=True)(fn) 2233 2234 x = torch.zeros(2, device=self.device) 2235 i = 2 2236 2237 with torch.no_grad(): 2238 self.assertEqual(cfn(x, i), fn(x, i)) 2239 2240 def test_builtins_round_float_ndigits_zero(self): 2241 def fn(x, i): 2242 return x + round(i / 2 * 123.4567, 0) 2243 2244 cfn = torch.compile(fullgraph=True, dynamic=True)(fn) 2245 2246 x = torch.zeros(2, device=self.device) 2247 i = 2 2248 2249 with torch.no_grad(): 2250 self.assertEqual(cfn(x, i), fn(x, i)) 2251 2252 def test_builtins_round_float_ndigits_neg(self): 2253 def fn(x, i): 2254 return x + round(i / 2 * 123.4567, -1) 2255 2256 cfn = torch.compile(fullgraph=True, dynamic=True)(fn) 2257 2258 x = torch.zeros(2, device=self.device) 2259 i = 2 2260 2261 with torch.no_grad(): 2262 self.assertEqual(cfn(x, i), fn(x, i)) 2263 2264 def test_builtins_round_int_ndigits_pos(self): 2265 def fn(x, i): 2266 return x + round(i, 1) 2267 2268 cfn = torch.compile(fullgraph=True, dynamic=True)(fn) 2269 2270 x = torch.zeros(2, device=self.device) 2271 i = 123 2272 2273 with torch.no_grad(): 2274 self.assertEqual(cfn(x, i), fn(x, i)) 2275 2276 def test_builtins_round_int_ndigits_zero(self): 2277 def fn(x, i): 2278 return x + round(i, 0) 2279 2280 cfn = torch.compile(fullgraph=True, dynamic=True)(fn) 2281 2282 x = torch.zeros(2, device=self.device) 2283 i = 123 2284 2285 with torch.no_grad(): 2286 self.assertEqual(cfn(x, i), fn(x, i)) 2287 2288 def test_silu(self): 2289 def fn(a): 2290 return (torch.nn.functional.silu(a),) 2291 2292 self.common(fn, (torch.randn(8, 8),)) 2293 2294 def test_nan_to_num(self): 2295 def fn(a): 2296 return ( 2297 torch.nan_to_num(a), 2298 torch.nan_to_num(a, nan=3.0), 2299 torch.nan_to_num(a, nan=None), 2300 torch.nan_to_num(a, posinf=4.0), 2301 torch.nan_to_num(a, 
neginf=5.0), 2302 torch.nan_to_num(a, nan=3.0, posinf=4.0, neginf=5.0), 2303 ) 2304 2305 self.common( 2306 fn, 2307 (torch.tensor((float("nan"), float("inf"), float("-inf"), 1.0)),), 2308 check_lowp=False, # a much more elaborate test is required to match finfo max's for float and half 2309 ) 2310 2311 def test_one_hot(self): 2312 def fn(a): 2313 return torch.nn.functional.one_hot(a, 8) + 1 2314 2315 self.common( 2316 fn, 2317 (torch.arange(100).view(4, 5, 5) % 8,), 2318 check_lowp=False, 2319 ) 2320 2321 def test_div1(self): 2322 def fn(a, b): 2323 return ( 2324 aten.div(a, b, rounding_mode=None), 2325 aten.div(a, b, rounding_mode="floor"), 2326 aten.div(a, b, rounding_mode="trunc"), 2327 a / b, 2328 a // b, 2329 ) 2330 2331 self.common(fn, (torch.randn(8, 8) * 100, torch.randn(8, 8) * 100)) 2332 2333 def test_div2(self): 2334 def fn(a, b): 2335 return ( 2336 aten.div(a, b, rounding_mode=None), 2337 aten.div(a, b, rounding_mode="floor"), 2338 aten.div(a, b, rounding_mode="trunc"), 2339 a / b, 2340 a // b, 2341 ) 2342 2343 self.common(fn, (torch.randint(-100, 100, [8, 8]), 100 * torch.randn(8, 8))) 2344 2345 def test_div3(self): 2346 def fn(a, b): 2347 return ( 2348 aten.div(a, b, rounding_mode=None), 2349 aten.div(a, b, rounding_mode="floor"), 2350 aten.div(a, b, rounding_mode="trunc"), 2351 a / b, 2352 a // b, 2353 ) 2354 2355 a = torch.randint(1, 100, [8, 8]) 2356 self.common(fn, (a * 2, a)) 2357 2358 def test_div4(self): 2359 def fn(a, b): 2360 return ( 2361 aten.div(a, b, rounding_mode=None), 2362 aten.div(a, b, rounding_mode="floor"), 2363 aten.div(a, b, rounding_mode="trunc"), 2364 a / b, 2365 a // b, 2366 ) 2367 2368 self.common( 2369 fn, 2370 (torch.randint(-100, 0, [8, 8]), torch.randint(1, 10, [8, 8])), 2371 ) 2372 2373 def test_div5(self): 2374 def fn(a, b): 2375 return ( 2376 aten.div(a, b, rounding_mode=None), 2377 aten.div(a, b, rounding_mode="floor"), 2378 aten.div(a, b, rounding_mode="trunc"), 2379 a / b, 2380 a // b, 2381 ) 2382 2383 # divide a scalar 2384 self.common(fn, (torch.randint(-100, 0, [8, 8]), 16)) 2385 2386 def test_div6(self): 2387 def fn(a, b): 2388 return ( 2389 aten.div(a, b, rounding_mode=None), 2390 aten.div(a, b, rounding_mode="floor"), 2391 aten.div(a, b, rounding_mode="trunc"), 2392 a / b, 2393 a // b, 2394 ) 2395 2396 # treat boolean as integer 2397 self.common( 2398 fn, 2399 (torch.ones([8, 8], dtype=torch.bool), torch.randint(-100, -1, [8, 8])), 2400 ) 2401 2402 def test_div7(self): 2403 def fn(a, b): 2404 return ( 2405 aten.div(a, b, rounding_mode=None), 2406 aten.div(a, b, rounding_mode="floor"), 2407 aten.div(a, b, rounding_mode="trunc"), 2408 a / b, 2409 a // b, 2410 ) 2411 2412 self.common( 2413 fn, 2414 ( 2415 torch.randint(2**32, 2**40, [100, 100]), 2416 torch.randint(-10, -1, [100, 100]), 2417 ), 2418 ) 2419 2420 def test_div8(self): 2421 def fn(a, b): 2422 return ( 2423 aten.div(a, b, rounding_mode=None), 2424 aten.div(a * 0.5, b, rounding_mode=None), 2425 aten.div(a, b * 1.0, rounding_mode=None), 2426 aten.div(a, b, rounding_mode="floor"), 2427 aten.div(a, b, rounding_mode="trunc"), 2428 a / b, 2429 a // b, 2430 ) 2431 2432 self.common(fn, (1024, 100)) 2433 2434 def test_div9(self): 2435 def fn(x): 2436 return (torch.div(42, x), aten.true_divide(42, x), aten.div.Tensor(42, x)) 2437 2438 self.common(fn, (torch.randn(8),)) 2439 2440 def test_div_zero_dim(self): 2441 def fn(a, b): 2442 return ( 2443 aten.div(a, b, rounding_mode=None), 2444 aten.div(a, b, rounding_mode="floor"), 2445 aten.div(a, b, rounding_mode="trunc"), 2446 a / b, 2447 a 
// b, 2448 ) 2449 2450 for dtype in (torch.float32, torch.int64): 2451 self.common( 2452 fn, 2453 ( 2454 make_tensor(10, device=self.device, dtype=dtype), 2455 make_tensor((), device=self.device, dtype=dtype, exclude_zero=True), 2456 ), 2457 ) 2458 self.common( 2459 fn, 2460 ( 2461 make_tensor((), device=self.device, dtype=dtype), 2462 make_tensor(10, device=self.device, dtype=dtype, exclude_zero=True), 2463 ), 2464 ) 2465 2466 def test_div_prim(self): 2467 def fn(a, b): 2468 return (torch.ops.prims.div(a, b),) 2469 2470 for dtype in (torch.float32, torch.int64): 2471 self.common( 2472 fn, 2473 ( 2474 make_tensor(100, device=self.device, dtype=dtype), 2475 make_tensor( 2476 100, device=self.device, dtype=dtype, exclude_zero=True 2477 ), 2478 ), 2479 ) 2480 2481 def test_floordiv(self): 2482 def fn_floor_input(a, i): 2483 n = (i * 1.234) // 8.234 2484 return a + n 2485 2486 self.common( 2487 fn_floor_input, 2488 (make_tensor(10, device=self.device, dtype=torch.float32), 33), 2489 ) 2490 2491 def fn_int_input(a, i): 2492 n = i // 8 2493 return a + n 2494 2495 self.common( 2496 fn_int_input, (make_tensor(10, device=self.device, dtype=torch.float32), 33) 2497 ) 2498 2499 def test_div_precision(self): 2500 # Reproducer for https://github.com/pytorch/pytorch/issues/101039 2501 2502 def forward(x, y): 2503 z = x.div(y) 2504 return F.softmax(z, dim=-1) 2505 2506 query = torch.randn(1, 10, 40) 2507 key = torch.randn(1, 2, 40) 2508 x = torch.matmul(query, key.transpose(-2, -1)) 2509 self.common(forward, (x, 1e-6)) 2510 2511 x = torch.tensor( 2512 [ 2513 [ 2514 [ 2515 [-16.1649, 5.6846, -5.1022, -9.1134], 2516 [-11.5552, -2.2615, -12.8913, 10.6538], 2517 [-7.1666, -5.3333, 2.0776, -9.7984], 2518 [7.4469, -2.3948, 2.7371, 0.9201], 2519 ], 2520 [ 2521 [-8.0361, -16.3771, 22.7741, 4.4685], 2522 [20.8047, -0.7771, -2.4355, -2.2299], 2523 [3.8343, -2.0914, -2.4077, 2.2740], 2524 [-15.8663, -2.7015, -12.5241, -3.0040], 2525 ], 2526 [ 2527 [-2.5139, 14.4393, -3.7186, 1.2255], 2528 [5.6742, 14.1842, -8.5976, 16.8366], 2529 [-9.7358, -3.0279, 11.8164, -4.0787], 2530 [-9.0621, 8.2580, 29.9486, -2.4107], 2531 ], 2532 [ 2533 [7.3622, 12.5640, -20.5592, 13.6237], 2534 [-11.5640, 0.8832, 16.7275, -2.5009], 2535 [-2.0953, -12.2276, -26.2633, 4.5268], 2536 [15.3329, -11.7492, 6.5650, -9.2483], 2537 ], 2538 ], 2539 [ 2540 [ 2541 [7.9980, -4.9369, 3.1508, 5.2994], 2542 [3.8052, 3.9514, 8.4987, -10.5045], 2543 [-2.6827, -4.0010, -4.0611, 6.4091], 2544 [-19.0318, 6.4073, 2.8923, 8.0250], 2545 ], 2546 [ 2547 [7.1650, -3.4585, 5.7720, -5.0305], 2548 [-0.9765, -3.0086, 11.7114, 8.0555], 2549 [-3.1027, -3.5514, 9.6182, -8.8526], 2550 [-9.2348, -6.0239, 6.2528, -6.7221], 2551 ], 2552 [ 2553 [11.5936, 22.4139, -0.4089, -4.9889], 2554 [14.8217, -2.3426, -17.6189, 3.7427], 2555 [1.9546, -13.0902, 8.6293, -7.2457], 2556 [-7.6900, -4.5796, 9.6332, -10.2631], 2557 ], 2558 [ 2559 [0.8027, -1.0955, 14.8404, -0.2673], 2560 [3.2143, -1.8640, -2.9678, 6.5165], 2561 [-3.9865, 6.5230, 6.3019, -0.4247], 2562 [8.3185, -13.5076, 27.0986, -1.6792], 2563 ], 2564 ], 2565 ] 2566 ) 2567 x = torch.matmul(x, x) 2568 y = torch.tensor([[[0.6331]], [[1.6358]], [[-0.3459]], [[1.0196]]]) 2569 self.common(forward, (x, y)) 2570 2571 def test_div_by_zero(self): 2572 def fn(x, runtime_zero, runtime_neg_zero): 2573 zero = torch.zeros_like(x) 2574 return ( 2575 x / 0.0, 2576 x / -0.0, 2577 zero / 0.0, 2578 x / zero, 2579 x / -zero, 2580 zero / zero, 2581 x / runtime_zero, 2582 # NOTE: -runtime_zero doesn't work as -(0.0) is broken in triton 2583 x / 
runtime_neg_zero, 2584 runtime_zero / runtime_neg_zero, 2585 ) 2586 2587 a = torch.randn(10) 2588 zero = torch.zeros(10) 2589 neg_zero = -zero 2590 self.common(fn, (a, zero, neg_zero)) 2591 2592 def test_both_scalars(self): 2593 def fn(a, b): 2594 return ( 2595 aten.add(a, b), 2596 aten.add(b, a), 2597 aten.sub(a, b), 2598 aten.sub(b, a), 2599 aten.mul(a, b), 2600 aten.mul(b, a), 2601 ) 2602 2603 self.common(fn, (4, 3.3), reference_in_float=False) 2604 2605 def test_sum_keepdims(self): 2606 def fn(a, b): 2607 return (torch.sum(a + b, -1, keepdim=True),) 2608 2609 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2610 2611 def test_large_tensor_reduction(self): 2612 if not _has_sufficient_memory(self.device, 4.5 * 1024**3): # 4.5 GiB 2613 raise unittest.SkipTest("insufficient memory") 2614 2615 if self.device == "cpu": 2616 raise unittest.SkipTest("Fails on CPU") 2617 2618 # Test 64-bit indexing works correctly 2619 def fn(a): 2620 return torch.max(a) 2621 2622 t = torch.ones(2**32, dtype=torch.int8, device=self.device) 2623 t[-1] = 2 2624 2625 # self.common OOMs here because it copies inputs to check for mutations 2626 compiled_fn = torch._dynamo.optimize()(fn) 2627 actual = compiled_fn(t) 2628 expect = torch.tensor(2, dtype=torch.int8, device=self.device) 2629 self.assertEqual(actual, expect) 2630 2631 def test_large_broadcast_reduction(self): 2632 if self.device == "cpu": 2633 raise unittest.SkipTest("Fails on CPU") 2634 2635 # Test 64-bit indexing works correctly when inputs are less than 32-bit 2636 # but intermediate tensors require 64-bit indexing 2637 def fn(a, b): 2638 return torch.max(a + b) 2639 2640 t1 = torch.ones(1, 2**16, dtype=torch.int8, device=self.device) 2641 t2 = torch.ones(2**16, 1, dtype=torch.int8, device=self.device) 2642 2643 t1[-1, -1] = 2 2644 t2[-1, -1] = 2 2645 2646 # self.common OOMs here because it copies inputs to check for mutations 2647 compiled_fn = torch._dynamo.optimize()(fn) 2648 actual = compiled_fn(t1, t2) 2649 expect = torch.tensor(4, dtype=torch.int8, device=self.device) 2650 self.assertEqual(actual, expect) 2651 2652 def test_large_pointwise(self): 2653 if not _has_sufficient_memory(self.device, 2 * (2**31 + 1)): 2654 raise unittest.SkipTest("insufficient memory") 2655 2656 def fn(a): 2657 return a + 1 2658 2659 t = torch.ones(2**31 + 1, dtype=torch.int8, device=self.device) 2660 compiled_fn = torch._dynamo.optimize()(fn) 2661 actual = compiled_fn(t) 2662 2663 # Can't use assertEqual as it expands broadcasted inputs 2664 del t 2665 if torch.device(self.device).type == GPU_TYPE: 2666 getattr(torch, GPU_TYPE).empty_cache() 2667 2668 self.assertTrue((actual == 2).all()) 2669 2670 def test_large_offset_pointwise(self): 2671 # Test 64-bit indexing is used when input views a tensor that can be 2672 # indexed with 32-bit strides but the storage offset pushes it over 2673 # INT_MAX 2674 if not _has_sufficient_memory(self.device, (2**31 + 1) + (2**30 + 1)): 2675 raise unittest.SkipTest("insufficient memory") 2676 2677 def fn(a): 2678 return a + 4 2679 2680 t = torch.ones(2**31 + 1, dtype=torch.int8, device=self.device) 2681 t[2**30 :] = 0 2682 compiled_fn = torch._dynamo.optimize()(fn) 2683 actual = compiled_fn(t[2**30 :]) 2684 self.assertTrue((actual == 4).all()) 2685 2686 def test_large_strided_reduction(self): 2687 # Test 64-bit indexing is used when input numel is less than INT_MAX 2688 # but stride calculations go above INT_MAX 2689 if not _has_sufficient_memory(self.device, 2**31 + 2): 2690 raise unittest.SkipTest("insufficient memory") 2691 
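        # Illustrative note (added commentary, not original to this test): the
        # strided view built below has a numel well under INT_MAX, but its last
        # element sits at a storage offset of 2**31, one past INT32_MAX, which is
        # what forces 64-bit index arithmetic here. Spelled out:
        #   numel(storage[::32]) = ceil((2**31 + 1) / 32) = 67_108_865
        #   last element offset  = (67_108_865 - 1) * 32  = 2**31 > 2**31 - 1
        assert (-(-(2**31 + 1) // 32) - 1) * 32 > 2**31 - 1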
2692 def fn(a): 2693 return torch.max(a) 2694 2695 storage = torch.ones(2**31 + 1, dtype=torch.int8, device=self.device) 2696 view = storage[::32] 2697 view[-1] = 2 2698 2699 compiled_fn = torch._dynamo.optimize()(fn) 2700 actual = compiled_fn(view) 2701 expect = torch.tensor(2, dtype=torch.int8, device=self.device) 2702 self.assertEqual(actual, expect) 2703 2704 def test_softmax(self): 2705 def fn(a, b): 2706 return (torch.softmax(a + b, -1), torch.softmax(a, 0), torch.softmax(b, 1)) 2707 2708 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2709 2710 def test_log_softmax(self): 2711 def fn(a, b): 2712 return (F.log_softmax(a + b, -1), F.log_softmax(a, 0), F.log_softmax(b, 1)) 2713 2714 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2715 2716 def test_transpose(self): 2717 def fn(a, b): 2718 return ( 2719 torch.t(a) + b, 2720 torch.transpose(b * 2, 0, 1) + 10, 2721 ) 2722 2723 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8))) 2724 2725 def test_permute1(self): 2726 def fn(a): 2727 return ( 2728 torch.permute(a + 1, [2, 1, 4, 0, 3]) + 2, 2729 torch.permute(a, [2, 1, 4, 0, 3]) + 2, 2730 ) 2731 2732 self.common(fn, (torch.randn(2, 2, 2, 2, 2),)) 2733 2734 def test_permute2(self): 2735 def fn(a): 2736 a = a.unfold(0, 2, 1) 2737 a = torch.unsqueeze(a, 1) 2738 a = torch.permute(a, [0, 2, 3, -3]) 2739 return (a,) 2740 2741 self.common(fn, (torch.randn(4, 4),)) 2742 2743 def test_expand(self): 2744 def fn(a): 2745 return ( 2746 (a + 1).expand(3, 4, 2, 3, 2) + 2, 2747 a.expand(2, 1, 2, 3, 2) + 2, 2748 ), a.expand(2, -1, 5, -1) 2749 2750 self.common(fn, (torch.randn(2, 1, 2),)) 2751 2752 def test_squeeze1(self): 2753 def fn(a): 2754 return ((a + 1).squeeze() + 2, a.squeeze() + 2) 2755 2756 self.common(fn, (torch.randn(1, 2, 1, 2, 2, 1, 1),)) 2757 2758 def test_squeeze2(self): 2759 def fn(a): 2760 return ((a + 1).squeeze(-1).squeeze(2) + 2, a.squeeze(0) + 2) 2761 2762 self.common(fn, (torch.randn(1, 2, 1, 2, 2, 2, 1),)) 2763 2764 def test_squeeze_varargs(self): 2765 def fn(x): 2766 return x.squeeze(1, 2).clone() 2767 2768 a = torch.randn(1024, 1, 1) 2769 self.common(fn, (a,)) 2770 2771 def test_simplify_loops(self): 2772 def fn(a, b): 2773 return a + b 2774 2775 self.common( 2776 fn, 2777 ( 2778 torch.randn(2, 3, 4, 5, 6), 2779 torch.randn(4, 2, 3, 5, 6).permute(1, 2, 0, 3, 4), 2780 ), 2781 ) 2782 2783 def test_unsqueeze(self): 2784 def fn(a): 2785 return ( 2786 torch.unsqueeze(a + 1, -1) + 2, 2787 torch.unsqueeze(a, 2) + 2, 2788 torch.unsqueeze(a + 1, 0) + 2, 2789 torch.unsqueeze(a, -2) + 2, 2790 ) 2791 2792 self.common( 2793 fn, 2794 ( 2795 torch.randn( 2796 2, 2797 2, 2798 2, 2799 2, 2800 ), 2801 ), 2802 ) 2803 2804 def test_unsqueeze_inplace(self): 2805 def fn(a): 2806 tmp1 = a + 1 2807 aten.unsqueeze_(tmp1, 2) 2808 tmp2 = aten.unsqueeze_(a + 1, 0) + 2 2809 return (tmp1, tmp2) 2810 2811 self.common( 2812 fn, 2813 ( 2814 torch.randn( 2815 2, 2816 2, 2817 2, 2818 2, 2819 ), 2820 ), 2821 ) 2822 2823 def test_addmm(self): 2824 def fn(a, b, c): 2825 return (torch.addmm(a + 1, b + 2, c + 3) + 4,) 2826 2827 self.common( 2828 fn, 2829 ( 2830 torch.randn(8, 8), 2831 torch.randn(8, 8), 2832 torch.randn(8, 8), 2833 ), 2834 ) 2835 2836 # https://github.com/pytorch/pytorch/issues/98979 2837 @skipCUDAIf(True, "cuda failed for float64 linear") 2838 @skipIfXpu(msg="Double and complex datatype matmul is not supported in oneDNN") 2839 def test_linear_float64(self): 2840 mod = torch.nn.Sequential(torch.nn.Linear(8, 16).to(torch.float64)).eval() 2841 with torch.no_grad(): 2842 self.common(mod, 
(torch.randn(2, 8).to(torch.float64),)) 2843 2844 def test_linear1(self): 2845 mod = torch.nn.Sequential( 2846 torch.nn.Linear(8, 16), 2847 torch.nn.Sigmoid(), 2848 ToTuple(), 2849 ) 2850 self.common(mod, (torch.randn(2, 8),)) 2851 2852 def test_linear2(self): 2853 mod = torch.nn.Sequential( 2854 torch.nn.Linear(8, 8), 2855 torch.nn.ReLU(), 2856 torch.nn.Linear(8, 8), 2857 torch.nn.ReLU(), 2858 torch.nn.Linear(8, 8), 2859 torch.nn.ReLU(), 2860 torch.nn.Linear(8, 8), 2861 torch.nn.ReLU(), 2862 ) 2863 self.common( 2864 mod, 2865 (torch.randn(2, 8),), 2866 atol=1e-3, 2867 rtol=0.01, 2868 ) 2869 2870 def test_bmm1(self): 2871 def fn(a, b): 2872 return ( 2873 torch.bmm(a, b), 2874 torch.bmm(a + 1, b + 2) + 3, 2875 ) 2876 2877 self.common( 2878 fn, 2879 ( 2880 torch.randn(2, 8, 8), 2881 torch.randn(2, 8, 8), 2882 ), 2883 check_lowp=False, 2884 ) 2885 self.common( 2886 fn, 2887 ( 2888 torch.randn(1, 16, 8), 2889 torch.randn(1, 8, 10), 2890 ), 2891 check_lowp=False, 2892 ) 2893 2894 def test_bmm2(self): 2895 def fn(a, b): 2896 return torch.bmm(a.permute(0, 2, 1), b) 2897 2898 self.common( 2899 fn, 2900 ( 2901 torch.randn(1, 8, 8), 2902 torch.randn(1, 8, 8), 2903 ), 2904 check_lowp=False, 2905 ) 2906 2907 @skipIfPy312 # segfaults 2908 @config.patch(force_mixed_mm=True) 2909 def test_mixed_mm(self): 2910 def fn(a, b): 2911 return torch.mm(a, b.to(a.dtype)) 2912 2913 self.common( 2914 fn, 2915 ( 2916 torch.randn(8, 8), 2917 torch.randint(-128, 127, (8, 8), dtype=torch.int8), 2918 ), 2919 check_lowp=True, 2920 ) 2921 2922 @skipIfPy312 # segfaults 2923 @config.patch(force_mixed_mm=True) 2924 def test_mixed_mm2(self): 2925 def fn(a, b, scale, bias): 2926 return torch.mm(a, b.to(a.dtype)) * scale + bias 2927 2928 self.common( 2929 fn, 2930 ( 2931 torch.randn(8, 8), 2932 torch.randint(-128, 127, (8, 8), dtype=torch.int8), 2933 torch.randn(8), 2934 torch.randn(8), 2935 ), 2936 check_lowp=True, 2937 ) 2938 2939 @skipIfPy312 # segfaults 2940 @config.patch(force_mixed_mm=True) 2941 def test_mixed_mm3(self): 2942 def fn(a, b): 2943 return torch.mm(a, b.to(a.dtype)) 2944 2945 # (256, 256) @ (256, 256) so different block sizes are tried out during autotuning 2946 self.common( 2947 fn, 2948 ( 2949 torch.randn(256, 256), 2950 torch.randint(-128, 127, (256, 256), dtype=torch.int8), 2951 ), 2952 check_lowp=True, 2953 rtol=0.01, 2954 atol=0.1, 2955 ) 2956 2957 @with_tf32_off 2958 @config.patch(use_mixed_mm=True) 2959 def test_uint4x2_mixed_mm(self): 2960 def fn(a, b): 2961 return torch.mm( 2962 a, 2963 torch.cat((b & 0xF, b >> 4), 1) 2964 .reshape(-1, b.shape[1]) 2965 .to(a.dtype) 2966 .sub(8), 2967 ) 2968 2969 self.common( 2970 fn, 2971 ( 2972 torch.randn(8, 8), 2973 torch.randint(0, 255, (4, 8), dtype=torch.uint8), 2974 ), 2975 check_lowp=True, 2976 ) 2977 2978 @expectedFailureXPU 2979 def test_mm_mixed_dtype(self): 2980 def fn(a, b): 2981 return torch.mm(a, b) 2982 2983 t1 = torch.arange(6, dtype=torch.float, device=self.device).view(2, 3) 2984 t2 = torch.arange(9, dtype=torch.int64, device=self.device).view(3, 3) 2985 2986 msg = "expected .* and .* to have the same dtype, but got: .* != .*" 2987 with self.assertRaisesRegex(RuntimeError, msg): 2988 torch.compile(fn)(t1, t2) 2989 with self.assertRaisesRegex(RuntimeError, msg): 2990 fn(t1, t2) 2991 2992 @expectedFailureXPU 2993 def test_linear_mixed_dtype(self): 2994 class Net(nn.Module): 2995 def __init__(self): 2996 super(Net, self).__init__() # noqa: UP008 2997 self.fc1 = nn.Linear(3, 3) 2998 2999 def forward(self, x): 3000 x = self.fc1(x.permute(1, 2, 0)) 3001 
return x 3002 3003 fn = Net().to(self.device) 3004 t = torch.arange(27, device=self.device).view(3, 3, 3) 3005 3006 msg = "expected .* and .* to have the same dtype, but got: .* != .*" 3007 with self.assertRaisesRegex(RuntimeError, msg): 3008 fn(t) 3009 with self.assertRaisesRegex(RuntimeError, msg): 3010 with torch.no_grad(): 3011 torch.compile(fn)(t) 3012 # TODO: Autograd internal assertion 3013 msg = r".*isDifferentiableType\(variable.scalar_type\(\)\) INTERNAL ASSERT FAILED.*" 3014 with self.assertRaisesRegex(RuntimeError, msg): 3015 torch.compile(fn)(t) 3016 3017 def test_scalar_input(self): 3018 def fn(x, y): 3019 a = torch.div(x, y, rounding_mode="floor") 3020 return a 3021 3022 self.common(fn, [torch.randint(5, (1, 8)), 5400]) 3023 3024 @torch._dynamo.config.patch(dynamic_shapes=True) 3025 @torch._dynamo.config.patch(assume_static_by_default=False) 3026 def test_scalar_output(self): 3027 def fn(arg0_1, arg2_1): 3028 arg1_1 = arg2_1.size(1) 3029 view = torch.ops.aten.view.default(arg2_1, [-1, arg1_1]) 3030 embedding = torch.ops.aten.embedding.default(arg0_1, view) 3031 full = torch.ops.aten.full.default([1, arg1_1], 1, dtype=torch.float32) 3032 return (full, arg1_1, embedding) 3033 3034 arg0_1 = rand_strided((32128, 768), (768, 1), device="cpu", dtype=torch.float32) 3035 arg2_1 = rand_strided((1, 22), (22, 1), device="cpu", dtype=torch.int64) 3036 self.common(fn, [arg0_1, arg2_1]) 3037 3038 def test_shape_prop_torch_ones(self): 3039 class Model(torch.nn.Module): 3040 def forward(self, attention_scores): 3041 extended_attention_mask = torch.ones( 3042 8, 1, 1, 512, device=attention_scores.device 3043 ) 3044 attention_scores = attention_scores + extended_attention_mask 3045 3046 return attention_scores 3047 3048 mod = Model().eval() 3049 with torch.no_grad(): 3050 self.common( 3051 mod, 3052 (torch.randn(8, 12, 512, 512),), 3053 ) 3054 3055 @slowTest 3056 @expectedFailureCodegenDynamic 3057 @config.patch({"freezing": True}) 3058 def test_conv_bn_fuse(self): 3059 # For gpu path, there is an accuracy issue 3060 if self.device == GPU_TYPE: 3061 raise unittest.SkipTest("only support cpu conv bn test") 3062 3063 # fails dynamic check which bn is fused, and there will not have loops vars. 3064 input_shapes = {1: (112,), 2: (112, 112), 3: (55, 55, 55)} 3065 conv_modules = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} 3066 bn_modules = { 3067 1: torch.nn.BatchNorm1d, 3068 2: torch.nn.BatchNorm2d, 3069 3: torch.nn.BatchNorm3d, 3070 } 3071 options = itertools.product( 3072 [1, 2, 3], 3073 [True, False], 3074 [1, 3], 3075 [1, 2], 3076 [1, 4], 3077 ) 3078 3079 for ( 3080 dim, 3081 bias, 3082 kernel_size, 3083 dilation, 3084 groups, 3085 ) in options: 3086 oC = 32 * groups 3087 iC = 3 * groups 3088 x_shape = (1, iC) + input_shapes[dim] 3089 mod = torch.nn.Sequential( 3090 conv_modules[dim]( 3091 iC, 3092 oC, 3093 kernel_size=kernel_size, 3094 dilation=dilation, 3095 groups=groups, 3096 bias=bias, 3097 ), 3098 bn_modules[dim](oC), 3099 ).eval() 3100 test_memory_format = [torch.contiguous_format] 3101 # TODO: GPU path doesn't support channels_last now. 
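            # Illustrative sketch (added commentary, not part of the original
            # test): with config.freezing enabled, the conv + BatchNorm pair built
            # above is expected to be folded into a single convolution. The
            # eval-mode algebra being exercised is, roughly:
            #   scale = bn.weight / sqrt(bn.running_var + bn.eps)
            #   W'    = conv.weight * scale            (per output channel)
            #   b'    = (conv.bias - bn.running_mean) * scale + bn.bias
            # A reference fold for the 2-D case, using only basic tensor ops
            # (the helper name is ours and is never called by the test):
            def _reference_fold_conv_bn_2d(conv, bn):
                scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
                folded_weight = conv.weight * scale.reshape(-1, 1, 1, 1)
                conv_bias = (
                    conv.bias
                    if conv.bias is not None
                    else torch.zeros_like(bn.running_mean)
                )
                folded_bias = (conv_bias - bn.running_mean) * scale + bn.bias
                return folded_weight, folded_bias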
3102 if not HAS_GPU and dim > 1: 3103 channels_last = ( 3104 torch.channels_last if dim == 2 else torch.channels_last_3d 3105 ) 3106 test_memory_format.append(channels_last) 3107 for memory_format in test_memory_format: 3108 v = torch.randn(x_shape, dtype=torch.float32).to( 3109 memory_format=memory_format 3110 ) 3111 with torch.no_grad(): 3112 self.common( 3113 mod, 3114 (v,), 3115 ) 3116 3117 def test_conv_functional_bn_fuse(self): 3118 # For gpu path, there is an accuracy issue 3119 if self.device == GPU_TYPE: 3120 raise unittest.SkipTest("only support cpu conv bn test") 3121 3122 # Define a BatchNorm using functional BN. 3123 class BatchNorm(torch.nn.BatchNorm2d): 3124 def __init__( 3125 self, 3126 num_features, 3127 eps=1e-5, 3128 momentum=0.1, 3129 affine=True, 3130 track_running_stats=True, 3131 device=None, 3132 dtype=None, 3133 ): 3134 factory_kwargs = {"device": device, "dtype": dtype} 3135 super().__init__( 3136 num_features, 3137 eps=eps, 3138 momentum=momentum, 3139 affine=affine, 3140 track_running_stats=track_running_stats, 3141 **factory_kwargs, 3142 ) 3143 3144 def forward(self, x): 3145 if self.momentum is None: 3146 exponential_average_factor = 0.0 3147 else: 3148 exponential_average_factor = self.momentum 3149 3150 if self.training and self.track_running_stats: 3151 # TODO: if statement only here to tell the jit to skip emitting this when it is None 3152 if self.num_batches_tracked is not None: # type: ignore[has-type] 3153 self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] 3154 if self.momentum is None: # use cumulative moving average 3155 exponential_average_factor = 1.0 / float( 3156 self.num_batches_tracked 3157 ) 3158 else: # use exponential moving average 3159 exponential_average_factor = self.momentum 3160 if self.training: 3161 bn_training = True 3162 else: 3163 bn_training = (self.running_mean is None) and ( 3164 self.running_var is None 3165 ) 3166 x = F.batch_norm( 3167 x, 3168 # If buffers are not to be tracked, ensure that they won't be updated 3169 ( 3170 self.running_mean 3171 if not self.training or self.track_running_stats 3172 else None 3173 ), 3174 ( 3175 self.running_var 3176 if not self.training or self.track_running_stats 3177 else None 3178 ), 3179 self.weight, 3180 self.bias, 3181 bn_training, 3182 exponential_average_factor, 3183 self.eps, 3184 ) 3185 return x 3186 3187 v = torch.randn(1, 3, 556, 56, dtype=torch.float32) 3188 mod = torch.nn.Sequential( 3189 torch.nn.Conv2d( 3190 3, 3191 64, 3192 kernel_size=3, 3193 dilation=1, 3194 groups=1, 3195 bias=True, 3196 ), 3197 BatchNorm(64), 3198 ).eval() 3199 with torch.no_grad(): 3200 self.common( 3201 mod, 3202 (v,), 3203 ) 3204 3205 @skipIfRocm 3206 def test_conv_inference_heuristics(self): 3207 if self.device != GPU_TYPE: 3208 raise unittest.SkipTest(f"{GPU_TYPE} only test") 3209 3210 in_channels = 6 3211 out_channels = 6 3212 kernel_size = 3 3213 groups = 3 3214 3215 grouped_conv = nn.Conv2d( 3216 in_channels, out_channels, kernel_size, groups=groups 3217 ).to(self.device) 3218 3219 input_tensor = torch.randn(1, in_channels, 10, 10).to(self.device) 3220 3221 # Perform the forward pass 3222 @torch.compile() 3223 def foo(m, inp): 3224 return m(inp) 3225 3226 with torch.no_grad(): 3227 _, code = run_and_get_code(foo, grouped_conv, input_tensor) 3228 # no to channels last permuting before kernel 3229 FileCheck().check_not(".run(").check(".convolution(").run(code[0]) 3230 3231 # in out should do channels last in inference 3232 in_channels = 8 3233 out_channels = 4 3234 
kernel_size = 3 3235 3236 # Create the convolution layer 3237 conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size).to(self.device) 3238 3239 input_tensor = torch.randn(1, in_channels, 10, 10).to(self.device) 3240 3241 with torch.no_grad(): 3242 _, code = run_and_get_code(foo, conv_layer, input_tensor) 3243 # should be channels last permuting before kernel 3244 FileCheck().check(".run(").check(".convolution(").run(code[0]) 3245 3246 def test_upsample_cat_conv(self): 3247 if self.device == GPU_TYPE: 3248 raise unittest.SkipTest("only support cpu upsample_cat_conv test") 3249 3250 class M(torch.nn.Module): 3251 def __init__( 3252 self, 3253 **kwargs, 3254 ): 3255 super().__init__() 3256 self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2) 3257 self.conv = torch.nn.Conv2d( 3258 8, 3259 5, 3260 kernel_size=1, 3261 padding=0, 3262 stride=1, 3263 dilation=1, 3264 **kwargs, 3265 ) 3266 3267 def forward(self, x, y): 3268 x = self.upsample(x) 3269 z = torch.cat([x, y], dim=1) 3270 z = self.conv(z) 3271 return z 3272 3273 v1 = torch.randn([8, 2, 12, 26]) 3274 v2 = torch.randn([8, 6, 24, 52]) 3275 3276 with torch.no_grad(): 3277 self.common( 3278 M().eval(), 3279 (v1, v2), 3280 ) 3281 3282 def test_aliased_buffer_reuse(self): 3283 def fn(x, y): 3284 x = 2 * x 3285 y = 2 * y 3286 c = torch.cat([x, y], dim=-1) 3287 d = 1 + c 3288 m = torch.mm(d, d) 3289 return m[:, :2] + x 3290 3291 self.common(fn, (torch.randn(4, 2), torch.randn(4, 2)), check_lowp=False) 3292 3293 def test_slice_view_with_graph_break(self): 3294 def fn(): 3295 a = torch.tensor([1], device=self.device) 3296 a = a[0:1] 3297 b = a.squeeze() 3298 a[0] = 0 3299 if a[0] < 1e5: 3300 pass 3301 a[0] = 2 3302 return b 3303 3304 expect = fn() 3305 opt_fn = torch.compile(fn) 3306 actual = opt_fn() 3307 self.assertEqual(expect, actual) 3308 3309 def test_view_detach(self): 3310 def fn(a): 3311 return a[0].detach() 3312 3313 self.common( 3314 fn, 3315 (torch.randn([4, 4], requires_grad=True),), 3316 ) 3317 3318 def test_gather1(self): 3319 def fn(a, b): 3320 return ( 3321 torch.gather(a.expand([4, 5, 10, 6]), 3, b + 1), 3322 torch.gather(a.expand([4, 5, 10, 6]), -1, b + 1), 3323 ) 3324 3325 self.common( 3326 fn, 3327 ( 3328 torch.randn([1, 1, 10, 6]), 3329 torch.randint(5, [4, 5, 10, 1], dtype=torch.int64), 3330 ), 3331 ) 3332 3333 def test_gather2(self): 3334 # 0d tensor 3335 def fn(a, b): 3336 return torch.gather(a, 0, b) + torch.gather(a, -1, b) 3337 3338 x = torch.tensor(123) 3339 y = torch.tensor(0) 3340 self.assertEqual(fn(x, y), x + x) 3341 3342 def test_gather3(self): 3343 def fn(a, b): 3344 return torch.gather(a, 1, b, sparse_grad=True) 3345 3346 self.common( 3347 fn, 3348 ( 3349 torch.randn([4, 5, 10, 6], requires_grad=True), 3350 torch.randint(5, [4, 5, 10, 1], dtype=torch.int64), 3351 ), 3352 ) 3353 3354 def test_slice1(self): 3355 def fn(a): 3356 return ( 3357 a[:, :10, 0] + a[:, 10:, 0], 3358 (a + 1)[:, :10, 0] + (a + 1)[:, 10:, 0], 3359 a[:, -30:, 0], # negative index out of range 3360 a[:, :-30, 0], # negative index out of range 3361 ) 3362 3363 self.common( 3364 fn, 3365 (torch.randn([2, 20, 2]),), 3366 ) 3367 3368 def test_slice2(self): 3369 def fn(a): 3370 return ( 3371 a[:-1, ::2, -1] + a[-1:, 1::2, -2], 3372 (a + 1)[:-1, ::2, -1] + (a + 2)[-1:, 1::2, -2], 3373 ) 3374 3375 self.common( 3376 fn, 3377 (torch.randn([2, 20, 2]),), 3378 ) 3379 3380 # It's a view so it doens't generate a kernel 3381 @expectedFailureCodegenDynamic 3382 def test_slice3(self): 3383 def fn(a, b): 3384 return torch.ops.aten.slice.Tensor(a, 
0, 0, -b) 3385 3386 x = torch.rand(48, 3, 512, 512) 3387 self.common(fn, (x, 2)) 3388 3389 @expectedFailureCodegenDynamic 3390 def test_slice4(self): 3391 # empty slices that require clamping the start or end 3392 def fn(a): 3393 return ( 3394 aten.slice.Tensor(a, 0, 2, 0, 1), 3395 aten.slice.Tensor(a, 0, a.shape[0], a.shape[0] + 10, 1), 3396 aten.slice.Tensor(a, 0, -20, 0, 1), 3397 aten.slice.Tensor(a, 0, -20, -16, 1), 3398 ) 3399 3400 x = torch.rand(10) 3401 self.common(fn, (x,)) 3402 3403 def test_split_with_list(self): 3404 def fn(a, sizes): 3405 return [t + 1.0 for t in torch.split(a * 2.0, sizes, -1)] 3406 3407 self.common(fn, (torch.randn(2, 2, 10), [3, 3, 4])) 3408 self.common(fn, (torch.randn(2, 2, 10), [4, 3, 3])) 3409 self.common(fn, (torch.randn(2, 2, 10), [1, 2, 3, 4])) 3410 3411 def test_split_with_integer(self): 3412 # argument `split_size_or_sections` is integer 3413 @torch.compile(dynamic=True) 3414 def f(x, sizes): 3415 return torch.split(x, sizes, -1) 3416 3417 # split into equally sized chunks, 10 = 5 + 5 3418 r1, r2 = f(torch.randn(2, 10), 5) 3419 self.assertTrue(r1.size() == (2, 5)) 3420 self.assertTrue(r2.size() == (2, 5)) 3421 3422 # split into equally sized chunks, 12 = 4 + 4 + 4 3423 r1, r2, r3 = f(torch.randn(2, 12), 4) 3424 self.assertTrue(r1.size() == (2, 4)) 3425 self.assertTrue(r2.size() == (2, 4)) 3426 self.assertTrue(r3.size() == (2, 4)) 3427 3428 # split unevenly, 10 = 3 + 3 + 3 + 1 3429 r1, r2, r3, r4 = f(torch.randn(2, 10), 3) 3430 self.assertTrue(r1.size() == (2, 3)) 3431 self.assertTrue(r2.size() == (2, 3)) 3432 self.assertTrue(r3.size() == (2, 3)) 3433 self.assertTrue(r4.size() == (2, 1)) 3434 3435 def test_split_failed(self): 3436 @torch._dynamo.optimize("inductor") 3437 def fn(a): 3438 return torch.split(a, [2, 1, 1], dim=1) 3439 3440 with self.assertRaisesRegex(RuntimeError, ""): 3441 fn(torch.randn(1, 5)) 3442 3443 def test_inductor_assert(self): 3444 @torch._dynamo.optimize("inductor", dynamic=True) 3445 def fn(a): 3446 assert a.shape[0] >= 2 and a.shape[1] >= 4 3447 return a.cos() 3448 3449 inp = torch.randn(2, 4, 6) 3450 torch._dynamo.mark_dynamic(inp, 0) 3451 torch._dynamo.mark_dynamic(inp, 1) 3452 self.assertEqual(fn(inp), inp.cos()) 3453 3454 def test_split(self): 3455 def fn(a): 3456 t = torch.split(a, 3, -1) 3457 return (t[0], t[1], t[2], t[3]) 3458 3459 def fn2(a): 3460 return fn(a + 1) 3461 3462 self.common( 3463 fn, 3464 (torch.randn([2, 2, 10]),), 3465 ) 3466 3467 self.common( 3468 fn2, 3469 (torch.randn([2, 2, 10]),), 3470 ) 3471 3472 def test_to_dtype(self): 3473 def fn(a, b): 3474 return ( 3475 aten._to_copy(a, dtype=6), 3476 aten._to_copy(b + 1, dtype=6), 3477 aten.to(b, torch.float64), 3478 aten.to(b, torch.bool), 3479 ) 3480 3481 self.common( 3482 fn, 3483 ( 3484 torch.randn([2, 2, 10]), 3485 torch.randn([2, 2, 10], dtype=torch.float64), 3486 ), 3487 ) 3488 3489 @requires_gpu() 3490 def test_to_device(self): 3491 def fn(a): 3492 if a.device.type == "cpu": 3493 return aten._to_copy( 3494 a, device=torch.device(GPU_TYPE), dtype=6, layout=0 3495 ) 3496 else: 3497 return aten._to_copy(a, device=torch.device("cpu"), dtype=6, layout=0) 3498 3499 self.common( 3500 fn, 3501 (torch.randn([2, 2, 10]),), 3502 ) 3503 3504 def test_to_memory_format(self): 3505 def fn(a, memory_format): 3506 return a.to(memory_format=memory_format) 3507 3508 self.common( 3509 fn, 3510 (torch.randn([2, 2, 10, 10]), torch.channels_last), 3511 ) 3512 self.common( 3513 fn, 3514 ( 3515 torch.randn([2, 2, 10, 10]).to(memory_format=torch.channels_last), 3516 
torch.contiguous_format, 3517 ), 3518 ) 3519 3520 @requires_gpu() 3521 def test_to_device_constant(self): 3522 def fn(a): 3523 d1 = a.device.type 3524 if d1 == "cpu": 3525 d2 = GPU_TYPE 3526 else: 3527 d2 = "cpu" 3528 3529 const1 = torch.as_tensor(list(range(64)), device=d2) 3530 return ( 3531 torch.arange(10, device=d2).to(d1) + a, 3532 const1.to(d1), 3533 (const1 + 1).to(d1), 3534 ) 3535 3536 self.common( 3537 fn, 3538 (torch.randn([10]),), 3539 ) 3540 3541 @requires_gpu() 3542 def test_multi_device(self): 3543 def fn(x): 3544 x = x + 1 3545 x = x + 2 3546 x = x.to(device=GPU_TYPE) 3547 x = x + 3 3548 x = x + 4 3549 x = x.cpu() 3550 x = x + 5 3551 x = x + 6 3552 x = x.to(device=GPU_TYPE) 3553 x = x + 7 3554 x = x + 8 3555 x = x.cpu() 3556 x = x + 9 3557 x = x + 10 3558 return x 3559 3560 self.common( 3561 fn, 3562 (torch.randn([2, 2, 10]),), 3563 check_lowp=False, # cpu doesn't understand fp16, and there are explicit .cpu() calls 3564 ) 3565 3566 @skipIfRocm 3567 @requires_multigpu() 3568 def test_multi_gpu_device(self): 3569 # TODO: https://github.com/pytorch/pytorch/issues/92627 3570 x = torch.rand([4], device=GPU_TYPE) 3571 3572 def fn(x, y): 3573 r = torch.ops.aten.div(x, y) 3574 r = r.to(f"{GPU_TYPE}:1") 3575 return 2 * r 3576 3577 self.common(fn, (torch.randn(4), torch.randn(4)), check_lowp=False) 3578 3579 @requires_multigpu() 3580 def test_multi_gpu_recompile_on_index(self): 3581 torch.set_float32_matmul_precision("high") 3582 3583 def gemm(x, y): 3584 return x @ y 3585 3586 failed_guard = None 3587 3588 def fail(guard): 3589 nonlocal failed_guard 3590 failed_guard = guard 3591 3592 gemm_opt = torch._dynamo.optimize("inductor", guard_fail_fn=fail)(gemm) 3593 3594 x0 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:0") 3595 y0 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:0") 3596 3597 gemm_opt(x0, y0) 3598 3599 x1 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:1") 3600 y1 = torch.randn(1024, 1024, device=f"{GPU_TYPE}:1") 3601 3602 gemm_opt(x1, y1) 3603 self.assertTrue(failed_guard is not None) 3604 self.assertTrue( 3605 "tensor 'L['x']' Tensor device index mismatch. Expected device index to be" 3606 in failed_guard.reason 3607 ) 3608 3609 def test_unbind(self): 3610 def fn(a): 3611 return torch.unbind(a), torch.unbind(a, -1) 3612 3613 self.common( 3614 fn, 3615 (torch.randn([4, 4, 4]),), 3616 ) 3617 3618 @skipIfRocm 3619 def test_convolution1(self): 3620 m = torch.nn.Sequential( 3621 torch.nn.Conv2d(5, 6, [3, 3]), 3622 torch.nn.ReLU(), 3623 ToTuple(), 3624 ) 3625 3626 self.common( 3627 m, 3628 (torch.randn([2, 5, 16, 16]),), 3629 # Mismatched elements: 10 / 2352 (0.4%) 3630 # Greatest absolute difference: 5.7220458984375e-05 at index (0, 3, 12, 12) (up to 1e-05 allowed) 3631 # Greatest relative difference: 0.06512477175897748 at index (0, 4, 11, 9) (up to 0.001 allowed) 3632 atol=6e-5, 3633 rtol=0.001, 3634 ) 3635 3636 def test_convolution2(self): 3637 def fn(x, w, b): 3638 # transposed conv 3639 return (aten.convolution(x, w, b, [4], [0], [1], True, [0], 1),) 3640 3641 self.common( 3642 fn, 3643 ( 3644 torch.randn([2, 32, 90]), 3645 torch.randn([32, 16, 8]), 3646 torch.randn([16]), 3647 ), 3648 check_lowp=False, 3649 ) 3650 3651 @skipIfRocm 3652 def test_convolution3(self): 3653 # Test stride or padding or dilation is 1 element list. 
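        # Illustrative note (added commentary, not part of the original test):
        # test_convolution2 above drives the low-level aten.convolution overload
        # directly; its positional arguments are
        #   (input, weight, bias, stride, padding, dilation, transposed,
        #    output_padding, groups)
        # so the `True` in the seventh slot selects a transposed convolution.
        # Under those same arguments, an eager reference could be written as
        # follows (sketch only, helper name is ours):
        def _transposed_conv1d_reference(x, w, b):
            return F.conv_transpose1d(x, w, b, stride=4, padding=0, dilation=1)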
3654 m = torch.nn.Sequential( 3655 torch.nn.Conv2d(5, 6, [3, 3], stride=[1], padding=[0], dilation=[1]), 3656 torch.nn.ReLU(), 3657 ToTuple(), 3658 ) 3659 3660 self.common( 3661 m, 3662 (torch.randn([2, 5, 16, 16]),), 3663 atol=6e-5, 3664 rtol=0.001, 3665 ) 3666 3667 @skipIfRocm 3668 def test_convolution4(self): 3669 def fn(x, w): 3670 x = F.conv2d(x, w, groups=w.shape[0]) 3671 return x.sum() 3672 3673 self.common( 3674 fn, 3675 ( 3676 torch.randn([2, 3, 16, 20]), 3677 torch.randn([3, 1, 5, 5]), 3678 ), 3679 ) 3680 3681 def test_conv2d_channels_last(self): 3682 if self.device == GPU_TYPE: 3683 raise unittest.SkipTest("only support cpu conv2d channels_last") 3684 3685 m = torch.nn.Sequential( 3686 torch.nn.Conv2d(3, 3, 1, 1), 3687 ToTuple(), 3688 ) 3689 # only weight is channels_last 3690 self.common( 3691 m.to(memory_format=torch.channels_last), 3692 (torch.randn([2, 3, 16, 16]),), 3693 check_lowp=False, 3694 ) 3695 # only activation is channels_last 3696 self.common( 3697 m, 3698 (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),), 3699 check_lowp=False, 3700 ) 3701 # activation and weight are all channels_last 3702 self.common( 3703 m.to(memory_format=torch.channels_last), 3704 (torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last),), 3705 check_lowp=False, 3706 ) 3707 3708 def test_conv2d_backward_channels_last(self): 3709 def fn(grad_output, inp, weight): 3710 convolution_backward_8 = torch.ops.aten.convolution_backward.default( 3711 grad_output, 3712 inp, 3713 weight, 3714 [320], 3715 [1, 1], 3716 [0, 0], 3717 [1, 1], 3718 False, 3719 [0, 0], 3720 1, 3721 [True, True, True], 3722 ) 3723 return convolution_backward_8 3724 3725 # only weight is channels_last 3726 self.common( 3727 fn, 3728 ( 3729 torch.randn([2, 320, 8, 8]), 3730 torch.randn([2, 2048, 8, 8]), 3731 torch.randn([320, 2048, 1, 1]).to(memory_format=torch.channels_last), 3732 ), 3733 check_lowp=False, 3734 ) 3735 3736 def test_conv3d_channels_last(self): 3737 if self.device == GPU_TYPE: 3738 raise unittest.SkipTest("only support cpu conv3d channels_last") 3739 3740 m = torch.nn.Sequential( 3741 torch.nn.Conv3d(3, 3, 1, 1), 3742 ToTuple(), 3743 ) 3744 # only weight is channels_last 3745 self.common( 3746 m.to(memory_format=torch.channels_last_3d), 3747 (torch.randn([2, 3, 16, 16, 16]),), 3748 ) 3749 # only activation is channels_last 3750 self.common( 3751 m, 3752 (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),), 3753 ) 3754 # activation and weight are all channels_last 3755 self.common( 3756 m.to(memory_format=torch.channels_last_3d), 3757 (torch.randn([2, 3, 16, 16, 16]).to(memory_format=torch.channels_last_3d),), 3758 ) 3759 3760 def test_adaptive_avg_pool2d1(self): 3761 def fn(x): 3762 return aten._adaptive_avg_pool2d(x, (6, 6)), aten._adaptive_avg_pool2d( 3763 x + 1, (2, 5) 3764 ) 3765 3766 self.common( 3767 fn, 3768 (torch.randn(2, 4, 16, 16),), 3769 check_lowp=False, 3770 ) 3771 3772 # lowering to avg_pool2d case 3773 self.common( 3774 fn, 3775 (torch.randn(2, 4, 3, 3),), 3776 ) 3777 3778 # no-op case 3779 self.common( 3780 fn, 3781 (torch.randn(2, 4, 6, 6),), 3782 ) 3783 3784 def test_adaptive_avg_pool2d2(self): 3785 # Big kernel size, use fallback 3786 def fn(x): 3787 return aten._adaptive_avg_pool2d(x, (4, 4)) 3788 3789 torch._inductor.metrics.generated_kernel_count = 0 3790 self.common( 3791 fn, 3792 (torch.randn(2, 4, 21, 21),), 3793 check_lowp=False, 3794 ) 3795 assertGeneratedKernelCountEqual(self, 0) 3796 3797 def test_adaptive_max_pool2d1(self): 3798 def 
fn(x): 3799 return aten.adaptive_max_pool2d(x, (6, 6)) 3800 3801 self.common( 3802 fn, 3803 (torch.randn(2, 4, 16, 16),), 3804 check_lowp=False, 3805 ) 3806 3807 # lowering to max_pool2d case 3808 self.common( 3809 fn, 3810 (torch.randn(2, 4, 3, 3),), 3811 ) 3812 3813 # no-op case 3814 self.common( 3815 fn, 3816 (torch.randn(2, 4, 6, 6),), 3817 ) 3818 3819 def test_adaptive_max_pool2d2(self): 3820 # Big kernel size, use fallback 3821 def fn(x): 3822 return aten.adaptive_max_pool2d(x, (4, 4)) 3823 3824 torch._inductor.metrics.generated_kernel_count = 0 3825 self.common( 3826 fn, 3827 (torch.randn(2, 4, 21, 21),), 3828 check_lowp=False, 3829 ) 3830 assertGeneratedKernelCountEqual(self, 0) 3831 3832 def test_fractional_max_pool2d1(self): 3833 def fn(x, samples): 3834 return aten.fractional_max_pool2d(x, (3, 3), (2, 2), samples) 3835 3836 self.common( 3837 fn, (torch.randn(1, 4, 16, 16), torch.rand(1, 4, 2)), check_lowp=False 3838 ) 3839 3840 def test_fractional_max_pool2d2(self): 3841 # fallback for larger kernel size 3842 3843 def fn(x, samples): 3844 return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples) 3845 3846 torch._inductor.metrics.generated_kernel_count = 0 3847 self.common( 3848 fn, 3849 (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)), 3850 check_lowp=False, 3851 ) 3852 assertGeneratedKernelCountEqual(self, 0) 3853 3854 def test_fractional_max_pool2d3(self): 3855 def fn(x, samples): 3856 return aten.fractional_max_pool2d(x, (1, 1), (16, 16), samples) 3857 3858 self.common( 3859 fn, (torch.randn(2, 4, 16, 16), torch.rand(2, 4, 2)), check_lowp=False 3860 ) 3861 3862 @config.patch(fallback_random=True) 3863 def test_fractional_max_pool2d4(self): 3864 random.seed(1234) 3865 torch.manual_seed(1234) 3866 3867 # check rectangular kernel/output size 3868 3869 def fn(x): 3870 return torch.nn.functional.fractional_max_pool2d_with_indices( 3871 x, (4, 3), (3, 2) 3872 ) 3873 3874 self.common(fn, (torch.randn(1, 4, 16, 16),), check_lowp=False) 3875 3876 def test_multi_threading(self): 3877 model = torch.nn.Linear(2, 3).eval() 3878 inp = torch.randn(4, 2) 3879 3880 num_run = 3 3881 3882 def run_weights_sharing_model(m, inp): 3883 with torch.no_grad(): 3884 for i in range(num_run): 3885 y = m(inp) 3886 3887 numb_instance = 2 3888 threads = [] 3889 compiled_m = torch.compile(model) 3890 for i in range(1, numb_instance + 1): 3891 thread = threading.Thread( 3892 target=run_weights_sharing_model, args=(compiled_m, inp) 3893 ) 3894 threads.append(thread) 3895 thread.start() 3896 for thread in threads: 3897 thread.join() 3898 3899 @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging") 3900 def test_adaptive_avg_pool2d_low_prec(self): 3901 class Model(torch.nn.Module): 3902 def __init__(self): 3903 super().__init__() 3904 self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) 3905 3906 def forward(self, x): 3907 x = self.avgpool(x) 3908 return x 3909 3910 mod = Model().to(self.device) 3911 for dtype in [torch.half, torch.bfloat16]: 3912 x = torch.randn(4, 3, 7, 7, device=self.device).to(dtype=dtype) 3913 opt_mod = torch.compile(mod) 3914 res = opt_mod(x) 3915 expected = mod(x) 3916 self.assertTrue(torch.allclose(res, expected)) 3917 3918 def test_buffer_copied_in_graph(self): 3919 class MyModel(torch.nn.Module): 3920 def __init__(self): 3921 super().__init__() 3922 self.register_buffer("buf", torch.zeros(1)) 3923 self.w1 = torch.nn.Parameter(torch.zeros(1)) 3924 self.w2 = torch.nn.Parameter(torch.zeros(1)) 3925 3926 def forward(self, x): 3927 self.buf.add_(1) 3928 return (self.w1 
* x * self.w2).sum() + self.buf.sum() 3929 3930 model_for_eager = MyModel().to(self.device) 3931 model_for_compile = copy.deepcopy(model_for_eager) 3932 3933 eager_version_counters = [ 3934 buffer._version for _, buffer in model_for_eager.named_buffers() 3935 ] 3936 compile_version_counters = [ 3937 buffer._version for _, buffer in model_for_compile.named_buffers() 3938 ] 3939 3940 compiled_f = torch.compile(model_for_compile, backend="inductor") 3941 3942 inp_ref = torch.ones(1, requires_grad=True, device=self.device) 3943 inp_test = torch.ones(1, requires_grad=True, device=self.device) 3944 3945 out_ref = model_for_eager(inp_ref.clone()) 3946 out_test = compiled_f(inp_test.clone()) 3947 3948 eager_version_counters_after = [ 3949 buffer._version for _, buffer in model_for_eager.named_buffers() 3950 ] 3951 compile_version_counters_after = [ 3952 buffer._version for _, buffer in model_for_compile.named_buffers() 3953 ] 3954 3955 eager_delta = list( 3956 map(operator.sub, eager_version_counters_after, eager_version_counters) 3957 ) 3958 compile_delta = list( 3959 map(operator.sub, compile_version_counters_after, compile_version_counters) 3960 ) 3961 3962 self.assertEqual(eager_delta, compile_delta) 3963 3964 def test_buffer_copied_in_graph_with_different_shapes(self): 3965 class MyModel(torch.nn.Module): 3966 def __init__(self): 3967 super().__init__() 3968 self.register_buffer("buf", torch.ones(4, 4)) 3969 self.w = torch.nn.Parameter( 3970 torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]]) 3971 ) 3972 3973 def forward(self, x): 3974 self.buf.add_(1) 3975 return (self.w @ x).sum() + self.buf.sum() 3976 3977 model_for_eager = MyModel().to(self.device) 3978 model_for_compile = copy.deepcopy(model_for_eager) 3979 3980 eager_version_counters = [ 3981 buffer._version for _, buffer in model_for_eager.named_buffers() 3982 ] 3983 compile_version_counters = [ 3984 buffer._version for _, buffer in model_for_compile.named_buffers() 3985 ] 3986 3987 compiled_f = torch.compile(model_for_compile, backend="inductor") 3988 3989 inp_ref = torch.ones(2, 4, requires_grad=True, device=self.device) 3990 inp_test = torch.ones(2, 4, requires_grad=True, device=self.device) 3991 3992 out_ref = model_for_eager(inp_ref.clone()) 3993 out_test = compiled_f(inp_test.clone()) 3994 3995 eager_version_counters_after = [ 3996 buffer._version for _, buffer in model_for_eager.named_buffers() 3997 ] 3998 compile_version_counters_after = [ 3999 buffer._version for _, buffer in model_for_compile.named_buffers() 4000 ] 4001 4002 eager_delta = list( 4003 map(operator.sub, eager_version_counters_after, eager_version_counters) 4004 ) 4005 compile_delta = list( 4006 map(operator.sub, compile_version_counters_after, compile_version_counters) 4007 ) 4008 4009 self.assertEqual(eager_delta, compile_delta) 4010 4011 @skipIfNNModuleInlined("https://github.com/pytorch/pytorch/issues/128198") 4012 def test_buffer_batch_norm(self): 4013 class MyModel(torch.nn.Module): 4014 def __init__(self): 4015 super().__init__() 4016 self.m = torch.nn.BatchNorm1d(100) 4017 4018 def forward(self, x): 4019 return self.m(x) 4020 4021 model_for_eager = MyModel().to(self.device) 4022 model_for_compile = copy.deepcopy(model_for_eager) 4023 4024 eager_version_counters = [ 4025 buffer._version for _, buffer in model_for_eager.named_buffers() 4026 ] 4027 compile_version_counters = [ 4028 buffer._version for _, buffer in model_for_compile.named_buffers() 4029 ] 4030 4031 compiled_f = torch.compile(model_for_compile, backend="inductor") 4032 4033 inp_ref = torch.ones(20, 
100, requires_grad=True, device=self.device) 4034 inp_test = torch.ones(20, 100, requires_grad=True, device=self.device) 4035 4036 out_ref = model_for_eager(inp_ref.clone()) 4037 out_test = compiled_f(inp_test.clone()) 4038 4039 eager_version_counters_after = [ 4040 buffer._version for _, buffer in model_for_eager.named_buffers() 4041 ] 4042 compile_version_counters_after = [ 4043 buffer._version for _, buffer in model_for_compile.named_buffers() 4044 ] 4045 4046 eager_delta = list( 4047 map(operator.sub, eager_version_counters_after, eager_version_counters) 4048 ) 4049 compile_delta = list( 4050 map(operator.sub, compile_version_counters_after, compile_version_counters) 4051 ) 4052 4053 self.assertEqual(eager_delta, compile_delta) 4054 4055 def test_adaptive_avg_pool_with_output_size_0(self): 4056 m1 = nn.AdaptiveAvgPool1d(0) 4057 self.common(m1, (torch.randn(1, 2),)) 4058 m2 = nn.AdaptiveAvgPool2d(0) 4059 self.common(m2, (torch.randn(1, 2, 3),)) 4060 4061 def test_max_pool2d1(self): 4062 def fn(x): 4063 return aten.max_pool2d_with_indices(x, [3, 3], [2, 2]) 4064 4065 self.common( 4066 fn, 4067 (torch.randn(2, 4, 16, 16),), 4068 ) 4069 4070 def test_max_pool2d2(self): 4071 def fn(x): 4072 return aten.max_pool2d_with_indices(x, [3, 3], [2, 2]) 4073 4074 self.common( 4075 fn, 4076 (torch.randn([16, 64, 55, 55]),), 4077 ) 4078 4079 def test_max_pool2d3(self): 4080 def fn(x): 4081 # with padding 4082 return ( 4083 aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [1, 1]), 4084 aten.max_pool2d_with_indices( 4085 x, 4086 [ 4087 3, 4088 ], 4089 [ 4090 2, 4091 ], 4092 [ 4093 1, 4094 ], 4095 ), 4096 ) 4097 4098 self.common( 4099 fn, 4100 (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), 4101 ) 4102 4103 def test_max_pool2d4(self): 4104 def fn(x): 4105 # with padding 4106 return aten.max_pool2d_with_indices(x, [3, 3], [2, 2], [0, 0], [1, 1], True) 4107 4108 self.common( 4109 fn, 4110 (torch.randn([2, 8, 111, 111]),), 4111 ) 4112 4113 def test_max_pool2d5(self): 4114 def fn(x): 4115 return aten.max_pool2d_with_indices(x, [3, 3], []) 4116 4117 self.common( 4118 fn, 4119 (torch.randn([16, 64, 55, 55]),), 4120 ) 4121 4122 def test_max_pool2d6(self): 4123 # Too big kernel size, use fallback 4124 def fn(x): 4125 return aten.max_pool2d_with_indices(x, [13, 13], []) 4126 4127 torch._inductor.metrics.generated_kernel_count = 0 4128 self.common( 4129 fn, 4130 (torch.randn([16, 64, 55, 55]),), 4131 ) 4132 assertGeneratedKernelCountEqual(self, 0) 4133 4134 # From https://github.com/pytorch/pytorch/issues/94775 4135 def test_max_pool2d7(self): 4136 # ceil mode turns on 4137 def fn(x): 4138 return torch.nn.functional.max_pool2d( 4139 x, 1, stride=(2, 2), padding=0, ceil_mode=True 4140 ) 4141 4142 self.common( 4143 fn, 4144 (torch.randn([1, 1, 6, 7]),), 4145 ) 4146 4147 # From https://github.com/pytorch/pytorch/issues/93384 4148 def test_max_pool2d8(self): 4149 # dialtion is not 1, use fallback 4150 def fn(x): 4151 return aten.max_pool2d_with_indices(x, [3, 2], [2, 1], [1, 1], [1, 2]) 4152 4153 torch._inductor.metrics.generated_kernel_count = 0 4154 self.common( 4155 fn, 4156 (torch.randn([2, 2, 3, 6]),), 4157 ) 4158 assertGeneratedKernelCountEqual(self, 0) 4159 4160 def test_avg_pool2d1(self): 4161 def fn(x): 4162 return aten.avg_pool2d(x, [3, 3], [2, 2]) 4163 4164 self.common( 4165 fn, 4166 (torch.randn(2, 4, 16, 16),), 4167 ) 4168 4169 def test_avg_pool2d2(self): 4170 def fn(x): 4171 return aten.avg_pool2d(x, [3, 3], [2, 2]) 4172 4173 self.common( 4174 fn, 4175 (torch.randn([16, 64, 55, 55]),), 
4176 ) 4177 4178 def test_avg_pool2d3(self): 4179 def fn(x): 4180 return ( 4181 aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1]), 4182 aten.avg_pool2d( 4183 x, 4184 [ 4185 3, 4186 ], 4187 [ 4188 2, 4189 ], 4190 [ 4191 1, 4192 ], 4193 ), 4194 ) 4195 4196 self.common( 4197 fn, 4198 (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), 4199 ) 4200 4201 def test_avg_pool2d4(self): 4202 def fn(x): 4203 return aten.avg_pool2d(x, [3, 3], [2, 2], [0, 0], True) 4204 4205 self.common( 4206 fn, 4207 (torch.randn([2, 8, 111, 111]),), 4208 ) 4209 4210 def test_avg_pool2d5(self): 4211 def fn(x): 4212 return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], count_include_pad=False) 4213 4214 self.common( 4215 fn, 4216 (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), 4217 ) 4218 4219 def test_avg_pool2d6(self): 4220 def fn(x): 4221 return aten.avg_pool2d(x, [3, 3], [2, 2], [1, 1], divisor_override=3) 4222 4223 self.common( 4224 fn, 4225 (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), 4226 ) 4227 4228 def test_avg_pool2d7(self): 4229 # Large kernel size, use fallback 4230 def fn(x): 4231 return aten.avg_pool2d(x, [13, 13], [1, 1], [0, 0]) 4232 4233 torch._inductor.metrics.generated_kernel_count = 0 4234 self.common( 4235 fn, 4236 (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),), 4237 ) 4238 assertGeneratedKernelCountEqual(self, 0) 4239 4240 def test_avg_pool2d8(self): 4241 # https://github.com/pytorch/pytorch/issues/100987 4242 def fn(x): 4243 return aten.avg_pool2d( 4244 x, kernel_size=3, stride=2, padding=1, ceil_mode=True 4245 ) 4246 4247 self.common( 4248 fn, 4249 (torch.randn(1, 3, 6, 6),), 4250 ) 4251 4252 def test_alexnet_prefix(self): 4253 def forward(arg6, arg7, arg16): 4254 convolution = torch.ops.aten.convolution( 4255 arg16, arg7, arg6, [4, 4], [2, 2], [1, 1], False, [0, 0], 1 4256 ) 4257 relu = torch.ops.aten.relu(convolution) 4258 max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices( 4259 relu, [3, 3], [2, 2] 4260 ) 4261 getitem = max_pool2d_with_indices[0] 4262 return (getitem,) 4263 4264 self.common( 4265 forward, 4266 ( 4267 rand_strided((64,), (1,), torch.float32, "cpu"), 4268 rand_strided((64, 3, 11, 11), (363, 121, 11, 1), torch.float32, "cpu"), 4269 rand_strided( 4270 (16, 3, 224, 224), (150528, 50176, 224, 1), torch.float32, "cpu" 4271 ), 4272 ), 4273 # Mismatched elements: 127 / 746496 (0.0%) 4274 # Greatest absolute difference: 0.0009765625 at index (1, 62, 7, 16) (up to 1e-05 allowed) 4275 # Greatest relative difference: 0.05187467899332306 at index (14, 18, 11, 0) (up to 0.001 allowed) 4276 atol=3e-3, 4277 rtol=2, 4278 ) 4279 4280 def test_elu(self): 4281 def fn(x): 4282 return aten.elu(x, 1.6732632423543772, 1.0507009873554805) + 2, aten.elu( 4283 x + 1, 2, 3, 4 4284 ) 4285 4286 self.common( 4287 fn, 4288 (torch.randn([16, 16]),), 4289 ) 4290 4291 def test_tan(self): 4292 def fn(x): 4293 return aten.tan(x) + 2, aten.tan(x + 1) 4294 4295 self.common( 4296 fn, 4297 (torch.randn([16, 16]),), 4298 ) 4299 4300 def test_tanh(self): 4301 def fn(x): 4302 return aten.tanh(x) + 2, aten.tanh(x + 1) 4303 4304 self.common( 4305 fn, 4306 (torch.randn([16, 16]),), 4307 ) 4308 4309 def test_lgamma(self): 4310 def fn(x): 4311 return aten.lgamma(x) + 2, aten.cos(x + 1) 4312 4313 self.common( 4314 fn, 4315 (torch.randn([16, 16]),), 4316 ) 4317 4318 def test_cos(self): 4319 def fn(x): 4320 return aten.cos(x) + 2, aten.cos(x + 1) 4321 4322 self.common( 4323 fn, 4324 (torch.randn([16, 16]),), 4325 ) 4326 4327 def test_sin(self): 
4328 def fn(x): 4329 return aten.sin(x) + 2, aten.sin(x + 1) 4330 4331 self.common( 4332 fn, 4333 (torch.randn([16, 16]),), 4334 ) 4335 4336 def test_repeat(self): 4337 def fn(x): 4338 return ( 4339 x.repeat(0, 1, 1, 1), 4340 x.repeat(2, 2, 3, 1), 4341 x.repeat(8, 1, 1, 1), 4342 x.repeat(2, 1, 1, 1, 1, 1), 4343 ) 4344 4345 self.common( 4346 fn, 4347 (torch.randn([1, 2, 4, 8]),), 4348 ) 4349 4350 def test_repeat_as_strided(self): 4351 # Reproducer for #127474 4352 4353 def fn(x): 4354 view_size = (3, 2) 4355 full = x.repeat((3, 2)) 4356 view = torch.as_strided(full, view_size, full.stride()) 4357 result = view + view 4358 4359 return result 4360 4361 self.common(fn, (torch.randn(1, 1),)) 4362 4363 def test_repeat_interleave(self): 4364 def fn(x): 4365 return ( 4366 x.repeat_interleave(2), 4367 x.repeat_interleave(3, dim=0), 4368 x.repeat_interleave(x.size(1), dim=1), 4369 ) 4370 4371 self.common( 4372 fn, 4373 (torch.randn([1, 2, 4, 8]),), 4374 ) 4375 4376 @config.patch(implicit_fallbacks=True) 4377 def test_repeat_interleave_2(self): 4378 def fn(x): 4379 return torch.ops.aten.repeat_interleave.Tensor(x, output_size=12) 4380 4381 self.common( 4382 fn, 4383 (torch.tensor([2, 4, 6]),), 4384 ) 4385 4386 @config.patch(fallback_random=True) 4387 def test_randn_with_dtype_and_device(self): 4388 if self.device == GPU_TYPE: 4389 raise unittest.SkipTest("only support cpu randn_with_dtype_and_device test") 4390 4391 def fn(vectors): 4392 rotations_shape = (12, vectors.shape[-1], 1, 64) 4393 random_rotations = torch.randn( 4394 rotations_shape, device=vectors.device, dtype=vectors.dtype 4395 ) 4396 random_rotations += 1 4397 return random_rotations 4398 4399 self.common( 4400 fn, 4401 (torch.randn([4, 12, 2, 64]),), 4402 ) 4403 4404 def test_embedding(self): 4405 m = torch.nn.Sequential( 4406 torch.nn.Embedding(10, 4, padding_idx=0), 4407 torch.nn.ReLU(), 4408 ToTuple(), 4409 ) 4410 4411 self.common( 4412 m, 4413 (torch.randint(10, [2, 8]),), 4414 ) 4415 4416 def test_mean(self): 4417 def fn(x): 4418 return ( 4419 x.mean(), 4420 x.mean(-1), 4421 torch.mean(x, -2, keepdim=True), 4422 x.mean([0, 1]), 4423 ) 4424 4425 self.common( 4426 fn, 4427 (torch.randn([1, 2, 4, 8]),), 4428 ) 4429 4430 def test_var_mean(self): 4431 def fn(x): 4432 return ( 4433 *torch.var_mean(x, -1), 4434 *torch.var_mean(x, [1, 3]), 4435 ) 4436 4437 self.common( 4438 fn, 4439 (torch.randn([1, 2, 4, 8]),), 4440 ) 4441 4442 def test_var_correction(self): 4443 def fn(x): 4444 dim = -1 4445 return ( 4446 torch.var(x, dim=dim, correction=1.3), 4447 torch.var(x, dim=dim, correction=3), 4448 torch.var(x, dim=dim, correction=10), 4449 ) 4450 4451 self.common(fn, (torch.randn([2, 8]),)) 4452 # Unrolled reduction 4453 self.common(fn, (torch.randn([2, 4]),)) 4454 4455 @config.patch(pick_loop_orders=True) 4456 def test_transposed_propagates(self): 4457 @torch._dynamo.optimize("inductor", nopython=True) 4458 def fn(x, y): 4459 return x + y 4460 4461 a = torch.randn(1, 4, 4, 4, device=self.device).permute(0, 2, 3, 1) 4462 b = torch.randn(4, 4, 4, device=self.device).permute(1, 2, 0) 4463 c = fn(a, b) 4464 self.assertEqual(a.stride(), c.stride()) 4465 self.assertEqual(c.stride()[2], 1) 4466 4467 def test_std(self): 4468 def fn(x): 4469 return ( 4470 torch.var(x, True), 4471 torch.var(x, False), 4472 torch.var(x, -1, True), 4473 torch.var(x, -1, False), 4474 torch.std(x, False), 4475 torch.std(x, [0, 1], True), 4476 torch.std(x, [0, 1], False), 4477 torch.std(x, -2, True, keepdim=True), 4478 ) 4479 4480 self.common( 4481 fn, 4482 (torch.randn([2, 
4, 4, 8]),), 4483 ) 4484 4485 def test_embedding_bag(self): 4486 def fn(w, i, o): 4487 return aten._embedding_bag(w, i, o, False, 0, False, None) 4488 4489 self.common( 4490 fn, 4491 (torch.randn([10, 4]), torch.randint(10, [8]), torch.tensor([0, 2, 6])), 4492 ) 4493 4494 def test_batch_norm_2d(self): 4495 m = torch.nn.Sequential( 4496 torch.nn.BatchNorm2d(10), 4497 torch.nn.ReLU(), 4498 ) 4499 m.eval() 4500 self.common(m, (torch.randn([2, 10, 8, 8]),), check_lowp=False) 4501 self.common( 4502 m, 4503 (torch.randn([3, 10, 16, 16]),), 4504 check_lowp=False, # too painful to match types of bn model 4505 ) 4506 4507 # From yolov3 4508 @with_tf32_off 4509 def test_batch_norm_2d_2(self): 4510 if self.device == "cpu": 4511 raise unittest.SkipTest(f"requires {GPU_TYPE}") 4512 4513 class Repro(torch.nn.Module): 4514 def __init__(self): 4515 super().__init__() 4516 self.self_0 = torch.nn.Conv2d( 4517 64, 4518 128, 4519 kernel_size=(3, 3), 4520 stride=(2, 2), 4521 padding=(1, 1), 4522 bias=False, 4523 ) 4524 self.self_1 = torch.nn.BatchNorm2d( 4525 128, 4526 eps=0.0001, 4527 momentum=0.03, 4528 affine=True, 4529 track_running_stats=True, 4530 ) 4531 self.self_2 = torch.nn.LeakyReLU(negative_slope=0.1, inplace=True) 4532 4533 def forward(self, l_input_: torch.Tensor): 4534 self_0 = self.self_0(l_input_) 4535 self_1 = self.self_1(self_0) 4536 self_2 = self.self_2(self_1) 4537 return (self_2,) 4538 4539 inp = torch.randn((4, 64, 192, 256), dtype=torch.float32, device=GPU_TYPE) 4540 mod = Repro().to(device=GPU_TYPE) 4541 o1 = mod(inp) 4542 o2 = torch.compile(mod)(inp) 4543 self.assertEqual(o1, o2) 4544 4545 @patch.object(config.trace, "enabled", True) 4546 def test_layer_norm(self): 4547 m = torch.nn.Sequential( 4548 torch.nn.LayerNorm(32), 4549 torch.nn.ReLU(), 4550 ) 4551 m.eval() 4552 with torch.no_grad(): 4553 self.common(m, (torch.randn([16, 32]),), check_lowp=False) 4554 if self.device != "cpu": 4555 assertGeneratedKernelCountEqual(self, 1) 4556 4557 def test_transpose_add(self): 4558 def fn(a, b): 4559 return a.t() + b 4560 4561 self.common( 4562 fn, (torch.randn([16, 32]), torch.randn([32, 16])), check_lowp=False 4563 ) 4564 if self.device != "cpu": 4565 assertGeneratedKernelCountEqual(self, 1) 4566 4567 @patch.object(config.triton, "persistent_reductions", True) 4568 def test_softmax_one_kernel_persist(self): 4569 def fn(x): 4570 dim = 1 4571 x_max = torch.amax(x, dim, keepdim=True) 4572 unnormalized = torch.exp(x - x_max) 4573 result = unnormalized / torch.sum(unnormalized, dim, keepdim=True) 4574 return result 4575 4576 self.common(fn, (torch.randn([16, 32]),), check_lowp=False) 4577 if self.device != "cpu": 4578 assertGeneratedKernelCountEqual(self, 1) 4579 4580 @patch.object(config.triton, "persistent_reductions", False) 4581 def test_softmax_one_kernel_loop(self): 4582 def fn(x): 4583 x_max = torch.amax(x, 1, keepdim=True) 4584 unnormalized = torch.exp(x - x_max) 4585 result = unnormalized / torch.sum(unnormalized, 1, keepdim=True) 4586 return result 4587 4588 self.common(fn, (torch.randn([16, 32]),), check_lowp=False) 4589 if self.device != "cpu": 4590 assertGeneratedKernelCountEqual(self, 1) 4591 4592 def test_complex_fallback(self): 4593 def fn(x): 4594 return x * x + 10 4595 4596 self.common( 4597 fn, 4598 (torch.randn([1, 2, 4, 8]).to(dtype=torch.complex64),), 4599 ) 4600 assertGeneratedKernelCountEqual(self, 0) 4601 4602 class ToComplex(nn.Module): 4603 def forward(self, x): 4604 return (x + x + 12).to(torch.complex64) 4605 4606 self.common(ToComplex(), (torch.rand([1, 2, 4, 8]),), 
check_lowp=False) 4607 4608 if self.device != "cpu": 4609 assertGeneratedKernelCountEqual(self, 1) 4610 4611 def test_view_as_complex(self): 4612 class Repro(torch.nn.Module): 4613 def __init__(self): 4614 super().__init__() 4615 4616 def forward(self, view_2): 4617 clone = torch.ops.aten.clone.default( 4618 view_2, memory_format=torch.contiguous_format 4619 ) 4620 view_2 = None 4621 view_as_complex = torch.ops.aten.view_as_complex.default(clone) 4622 clone = None 4623 return (view_as_complex,) 4624 4625 inp = torch.empty_strided((128, 64, 12, 32, 2), (1, 98304, 8192, 256, 128)).to( 4626 self.device 4627 ) 4628 mod = Repro() 4629 4630 o1 = mod(inp) 4631 o2 = torch.compile(mod)(inp) 4632 4633 self.assertEqual(o1, o2) 4634 4635 def test_view_as_real(self): 4636 def fn(x): 4637 y = torch.view_as_real(x) 4638 return y + 1 4639 4640 x = torch.randn(4, dtype=torch.complex64) 4641 4642 self.common(fn, (x,)) 4643 4644 def test_cauchy(self): 4645 def fn(x, y): 4646 return torch.sum(1 / (torch.unsqueeze(x, -1) - y)) 4647 4648 self.common( 4649 fn, 4650 ( 4651 torch.randn(32), 4652 torch.randn(32), 4653 ), 4654 # Absolute difference: 0.0003662109375 (up to 0.0001 allowed) 4655 # Relative difference: 1.8804297408767818e-05 (up to 1e-05 allowed) 4656 atol=5 * 1e-4, 4657 rtol=5 * 1e-5, 4658 check_lowp=False, 4659 ) 4660 if self.device != "cpu": 4661 assertGeneratedKernelCountEqual(self, 1) 4662 4663 def test_fusing_write_into_disjoint_read(self): 4664 def test_flip(a): 4665 return a.copy_(torch.flip(a, (0,))) 4666 4667 self.common(test_flip, (torch.rand([20]),)) 4668 4669 assertGeneratedKernelCountEqual(self, 2) 4670 4671 # issue only manifests on cuda with large tensors 4672 if self.device != "cpu": 4673 4674 def f(a): 4675 a[:, 20:40] = a[:, 20:40] + 1 4676 a[:, 2:900025] = a[:, 1:900024] + 2 4677 4678 a = torch.rand((1, 1000000), device=GPU_TYPE) 4679 self.common(f, (a,)) 4680 4681 def test_gather_scatter(self): 4682 def fn(node_feat, edge_index): 4683 src_node_feat = node_feat[edge_index[0]] 4684 dst_node_feat = node_feat[edge_index[1]] 4685 edge_feat = src_node_feat - dst_node_feat + 1 4686 new_node_feat = torch.zeros_like(node_feat) 4687 new_node_feat.scatter_add_( 4688 0, edge_index[1].unsqueeze(-1).expand_as(edge_feat), edge_feat 4689 ) 4690 return new_node_feat 4691 4692 num_nodes = 16 4693 num_features = 32 4694 node_feat = torch.randn(num_nodes, num_features) 4695 edge_index = torch.randint(0, num_nodes, size=(2, num_nodes * 5)) 4696 self.common( 4697 fn, 4698 ( 4699 node_feat, 4700 edge_index, 4701 ), 4702 check_lowp=False, 4703 ) 4704 if self.device != "cpu": 4705 assertGeneratedKernelCountEqual(self, 2) 4706 4707 @config.patch(max_fusion_size=1) 4708 def test_no_mega_fusion_during_lowering(self): 4709 n = 50 4710 4711 def fn(*args): 4712 x = args[0] 4713 for i in range(n): 4714 x = torch.add(x, args[i]) 4715 return x 4716 4717 self.common( 4718 fn, 4719 [torch.randn(64) for _ in range(n)], 4720 check_lowp=False, 4721 ) 4722 print("-->", torch._inductor.metrics.generated_kernel_count) 4723 if self.device != "cpu": 4724 self.assertTrue(torch._inductor.metrics.generated_kernel_count > 1) 4725 4726 def test_move_arange(self): 4727 def fn(x): 4728 return torch.arange(len(x), device="cpu").to(x.device) + x 4729 4730 self.common(fn, (torch.randn([32]),), check_lowp=False) 4731 # if we have a copy there will be more than 1 kernel 4732 assertGeneratedKernelCountEqual(self, 1) 4733 4734 def test_leaky_relu(self): 4735 def fn(x): 4736 return aten.leaky_relu(x, 0.2) + 2, aten.leaky_relu(x + 1) 4737 
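# note: the second call, aten.leaky_relu(x + 1), exercises the default negative_slope of 0.01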
4738 self.common( 4739 fn, 4740 (torch.randn([16, 16]),), 4741 ) 4742 4743 def test_gelu(self): 4744 def fn(x): 4745 return aten.gelu(x) + 2, aten.gelu(x + 1) 4746 4747 self.common( 4748 fn, 4749 (torch.randn([16, 16]),), 4750 ) 4751 4752 def test_clone(self): 4753 def fn(x): 4754 return aten.clone(x) + 2, aten.clone(x + 1) 4755 4756 self.common( 4757 fn, 4758 (torch.randn([16, 16]),), 4759 ) 4760 4761 def test_masked_fill(self): 4762 def fn(mask, value): 4763 return aten.masked_fill(value, mask, -10000.0) + 2, aten.masked_fill( 4764 value / 2.0, torch.logical_not(mask), 667 4765 ) 4766 4767 self.common( 4768 fn, 4769 ( 4770 torch.randint(0, 1, [1, 16], dtype=torch.bool), 4771 torch.randn([16, 16]), 4772 ), 4773 ) 4774 4775 def test_masked_fill_promotion(self): 4776 def fn(mask, value): 4777 return aten.masked_fill(value, mask, torch.tensor(3.5)) 4778 4779 opt_fn = torch._dynamo.optimize("inductor")(fn) 4780 for inp in ( 4781 torch.randn( 4782 [16, 16], 4783 dtype=torch.float16 if self.device == GPU_TYPE else torch.float32, 4784 device=self.device, 4785 ), 4786 torch.randint(16, (16, 16), device=self.device), 4787 ): 4788 inputs = ( 4789 torch.randint(0, 1, [1, 16], dtype=torch.bool, device=self.device), 4790 inp, 4791 ) 4792 self.assertEqual(fn(*inputs), opt_fn(*inputs)) 4793 4794 def test_masked_scatter(self): 4795 def fn(value, mask, source): 4796 return torch.masked_scatter(value, mask, source) 4797 4798 value = make_tensor(10, 10, dtype=torch.float32, device=self.device) 4799 mask = make_tensor(10, 10, dtype=torch.bool, device=self.device) 4800 source = make_tensor( 4801 mask.count_nonzero(), dtype=torch.float32, device=self.device 4802 ) 4803 4804 self.common(fn, (value, mask, source)) 4805 4806 def test_fill1(self): 4807 def fn(x): 4808 tmp = torch.ones_like(x) 4809 return tmp, aten.fill.Scalar(tmp, 2) 4810 4811 self.common( 4812 fn, 4813 (torch.randn([16, 16]),), 4814 ) 4815 4816 def test_fill2(self): 4817 def fn(x): 4818 tmp = torch.ones_like(x) 4819 return tmp, aten.fill.Tensor(tmp, torch.tensor(3.0)) 4820 4821 self.common( 4822 fn, 4823 (torch.randn([16, 16]),), 4824 ) 4825 4826 def test_pow1(self): 4827 def fn(x): 4828 return [aten.pow(x, e) for e in range(-8, 9)] 4829 4830 self.common( 4831 fn, 4832 (torch.randn([16, 16]),), 4833 ) 4834 4835 def test_pow2(self): 4836 def fn(x): 4837 return aten.pow(1000, x), aten.pow(x, 1000) 4838 4839 self.common( 4840 fn, 4841 ( 4842 torch.randn( 4843 [16, 16], 4844 dtype=torch.float32, 4845 ), 4846 ), 4847 # Mismatched elements: 9 / 256 (3.5%) 4848 # Greatest absolute difference: 2.491354329061828e+28 at index (6, 6) (up to 1e-05 allowed) 4849 # Greatest relative difference: 2.9793410720160818e-05 at index (4, 5) (up to 1.3e-06 allowed) 4850 atol=1e-5, 4851 rtol=3e-05, 4852 ) 4853 4854 def test_pow3(self): 4855 # power of 0.5 is special-cased, arbitrary power would still produce triton codegen error 4856 def fn(x): 4857 z = torch.tensor(0.123, device=self.device) 4858 w = z + x 4859 return torch.pow(w, 0.5) 4860 4861 opt = torch._dynamo.optimize("inductor")(fn) 4862 input = torch.rand(()) 4863 self.assertTrue(same(opt(input), fn(input))) 4864 4865 def test_pow_int(self): 4866 def fn(x, y): 4867 return torch.pow(x, 0x57), torch.pow(x, y) 4868 4869 for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): 4870 intmax = torch.iinfo(dtype).max 4871 make_arg = functools.partial( 4872 make_tensor, dtype=dtype, device=self.device, requires_grad=False 4873 ) 4874 self.common( 4875 fn, 4876 ( 4877 make_arg(16, 16), 4878 
make_arg(16, 16, high=intmax), 4879 ), 4880 ) 4881 4882 def test_glu(self): 4883 def fn(x): 4884 return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2) 4885 4886 self.common( 4887 fn, 4888 (torch.randn([8, 16, 8, 8]),), 4889 ) 4890 4891 @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) 4892 def test_nonzero_unbacked_refinement(self): 4893 def fn(x): 4894 z = x.nonzero() 4895 torch._check(z.size(0) == 4) 4896 return z + 3 4897 4898 self.common( 4899 fn, 4900 (torch.tensor([0, 1, 3, 4, 2, 0, 0]),), 4901 ) 4902 4903 with self.assertRaises(RuntimeError): 4904 torch.compile(fn)(torch.tensor([0, 0, 0, 0])) 4905 4906 @torch._dynamo.config.patch(capture_scalar_outputs=True) 4907 def test_unbacked_floordiv_simplify(self): 4908 def fn(x, y): 4909 z = y.item() 4910 torch._check(z // 2 == 3) 4911 return x + x.new_zeros(z) 4912 4913 self.common( 4914 fn, 4915 ( 4916 torch.randn(6), 4917 torch.tensor([6]), 4918 ), 4919 ) 4920 4921 self.common( 4922 fn, 4923 ( 4924 torch.randn(7), 4925 torch.tensor([7]), 4926 ), 4927 ) 4928 4929 @torch._dynamo.config.patch(capture_scalar_outputs=True) 4930 def test_unbacked_floordiv_simplify_errors(self): 4931 def fn(x, y): 4932 z = y.item() 4933 torch._check(z // 2 == 3) 4934 return x + x.new_zeros(z) 4935 4936 # This is a little suboptimal: we actually fail /in the compiler/ but 4937 # not in a way that causes Dynamo to graph break 4938 with self.assertRaises(RuntimeError): 4939 torch.compile(fn)(torch.randn(8), torch.tensor(8)) 4940 4941 def test_cat(self): 4942 def fn(a): 4943 tmp = a * 2 4944 return ( 4945 torch.cat((a, a[:, :4] + 1, a + 2), -1), 4946 torch.cat((tmp, tmp), 0), 4947 torch.cat((tmp, tmp.double()), 0), 4948 ) 4949 4950 self.common( 4951 fn, 4952 (torch.randn([8, 16]),), 4953 ) 4954 self.common( 4955 fn, 4956 (torch.randn([1, 3, 3, 16]).to(memory_format=torch.channels_last),), 4957 ) 4958 4959 def test_cat_uint8(self): 4960 def fn(x): 4961 batch_shape = x.shape[:1] 4962 out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1) 4963 return out 4964 4965 self.common( 4966 fn, 4967 (torch.randint(0, 256, size=(3, 255), dtype=torch.uint8),), 4968 ) 4969 4970 def test_cat_empty(self): 4971 def fn_2(*tensors): 4972 return torch.cat(tensors) 4973 4974 self.common( 4975 fn_2, 4976 ( 4977 torch.randn([1, 3, 3, 16]), 4978 torch.ones([0]), 4979 ), 4980 ) 4981 self.common( 4982 fn_2, 4983 ( 4984 torch.randn([1, 3, 3, 16]), 4985 torch.ones([0]), 4986 torch.randn([1, 3, 3, 16]), 4987 ), 4988 ) 4989 self.common( 4990 fn_2, 4991 ( 4992 torch.ones([0]), 4993 torch.randn([1, 3, 3, 16]), 4994 ), 4995 ) 4996 4997 @torch._dynamo.config.patch(capture_scalar_outputs=True) 4998 def test_cat_unbacked_legacy_empty(self): 4999 def fn(x, y): 5000 z = y.item() 5001 return torch.cat([x, x.new_ones(z)]) 5002 5003 self.common( 5004 fn, 5005 ( 5006 torch.randn([2, 3]), 5007 torch.tensor([0]), 5008 ), 5009 ) 5010 5011 @torch._dynamo.config.patch(capture_scalar_outputs=True) 5012 def test_cat_unbacked_empty_1d(self): 5013 def fn(x, y): 5014 z = y.item() 5015 return torch.cat([x, x.new_ones(z)]) 5016 5017 self.common( 5018 fn, 5019 ( 5020 torch.randn([2]), 5021 torch.tensor([0]), 5022 ), 5023 ) 5024 5025 self.common( 5026 fn, 5027 ( 5028 torch.randn([2]), 5029 torch.tensor([3]), 5030 ), 5031 ) 5032 5033 @torch._dynamo.config.patch(capture_scalar_outputs=True) 5034 def test_cat_unbacked_2d(self): 5035 def fn(x, y): 5036 z = y.item() 5037 return torch.cat([x, x.new_ones(z, x.shape[1])]) 5038 5039 self.common( 5040 fn, 5041 ( 5042 torch.randn([2, 3]), 5043 
torch.tensor([0]), 5044 ), 5045 ) 5046 5047 self.common( 5048 fn, 5049 ( 5050 torch.randn([2, 3]), 5051 torch.tensor([4]), 5052 ), 5053 ) 5054 5055 def test_cat_negative_dim(self): 5056 def fn(*tensors): 5057 return torch.cat(tensors, dim=-1) 5058 5059 self.common( 5060 fn, 5061 ( 5062 torch.randn([2, 3]), 5063 torch.randn([2, 4]), 5064 ), 5065 ) 5066 5067 self.common( 5068 fn, 5069 ( 5070 torch.randn([2, 3]), 5071 torch.randn([0]), 5072 torch.randn([2, 4]), 5073 ), 5074 ) 5075 5076 self.common( 5077 fn, 5078 ( 5079 torch.randn([0]), 5080 torch.randn([2, 3]), 5081 torch.randn([2, 4]), 5082 ), 5083 ) 5084 5085 @expectedFailureCodegenDynamic 5086 def test_cat_single_empty(self): 5087 # fails dynamic check for 'has a dynamic dimension' 5088 def fn_2(*tensors): 5089 return torch.cat(tensors) 5090 5091 self.common( 5092 fn_2, 5093 (torch.ones([0]),), 5094 ) 5095 5096 def test_cat_upcasting(self): 5097 def fn(arg4_1, slice_7): 5098 cat_1 = aten.cat.default([arg4_1, slice_7], 1) 5099 return (cat_1,) 5100 5101 self.common( 5102 fn, 5103 ( 5104 torch.randn([8, 16], dtype=torch.float32), 5105 torch.randn([8, 20], dtype=torch.float16), 5106 ), 5107 ) 5108 5109 def test_cat_extern_kernel(self): 5110 def fn(x1, x2, x3, x4): 5111 x = torch.mm(x2, x3) 5112 s = torch.narrow(x, 1, 0, 100) 5113 x = torch.mm(s, x4) 5114 c = torch.cat((x, x1), 1) 5115 return (c,) 5116 5117 if self.device == "xpu": 5118 atol = 3e-4 5119 rtol = 1e-4 5120 else: 5121 # use default 5122 atol = None 5123 rtol = None 5124 self.common( 5125 fn, 5126 ( 5127 torch.randn(256, 256), 5128 torch.randn(256, 1024), 5129 torch.randn(1024, 1600), 5130 torch.randn(100, 256), 5131 ), 5132 atol=atol, 5133 rtol=rtol, 5134 check_lowp=False, # accuracy issues with relatively large matmuls 5135 ) 5136 5137 @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") 5138 # Constant folding was explicitly turned off due to issue #108388 5139 # Turn it back on for test 5140 @torch._inductor.config.patch(joint_graph_constant_folding=True) 5141 def test_remove_no_ops(self): 5142 def matmul_with_op(x, y, fn): 5143 return fn(x @ y) 5144 5145 foo_opt = torch.compile(matmul_with_op) 5146 5147 # test no-op 5148 fns = ( 5149 lambda x: x 5150 + torch.zeros( 5151 [256, 256], dtype=torch.float32, device=x.device 5152 ), # noqa: E731 5153 lambda x: x 5154 - torch.zeros( 5155 [256, 256], dtype=torch.float32, device=x.device 5156 ), # noqa: E731 5157 lambda x: x 5158 * torch.ones( 5159 [256, 256], dtype=torch.float32, device=x.device 5160 ), # noqa: E731 5161 lambda x: x 5162 / torch.ones( 5163 [256, 256], dtype=torch.float32, device=x.device 5164 ), # noqa: E731 5165 ) 5166 5167 inps = [torch.rand([256, 256], device=self.device) for _ in range(2)] 5168 5169 for fn in fns: 5170 out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn) 5171 self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn)) 5172 5173 if self.device == "cpu": 5174 FileCheck().check_not("cpp_fused").run(source_codes[0]) 5175 else: 5176 FileCheck().check_not("triton.jit").run(source_codes[0]) 5177 5178 # test dtype conversion 5179 inps = [ 5180 torch.rand([256, 256], device=self.device, dtype=torch.bfloat16) 5181 for _ in range(2) 5182 ] 5183 for fn in fns: 5184 out, source_codes = run_and_get_code(foo_opt, inps[0], inps[1], fn) 5185 self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn)) 5186 5187 # test broadcasted shape bail 5188 fn = lambda x: x + torch.zeros( # noqa: E731 5189 [256, 256, 256], dtype=torch.bfloat16, device=self.device 5190 ) 5191 out, source_codes = 
run_and_get_code(foo_opt, inps[0], inps[1], fn) 5192 self.assertEqual(out, matmul_with_op(inps[0], inps[1], fn)) 5193 5194 def test_remove_noop_copy(self): 5195 def fn(x, y): 5196 x = x.cos() 5197 a = x.copy_(y) 5198 return a.sin() 5199 5200 self.common(fn, (torch.randn(8, 8), torch.randn(8))) 5201 5202 def fn2(a, b): 5203 abs_max = torch.abs(a).max() 5204 b[0] = abs_max.to(a.dtype) 5205 return b 5206 5207 self.common( 5208 fn2, 5209 ( 5210 torch.randn(8, 8, dtype=torch.float16), 5211 torch.randn(8, dtype=torch.float32), 5212 ), 5213 ) 5214 5215 def test_remove_noop_clone(self): 5216 def fn(x): 5217 y = x.clone().reshape(-1, 4) 5218 y[:, [2, 0]] = y[:, [0, 2]] 5219 return y + x 5220 5221 self.common(fn, (torch.randn(2, 4),)) 5222 5223 def test_cat_of_loops_and_extern_kernel(self): 5224 class M(torch.nn.Module): 5225 def __init__( 5226 self, 5227 **kwargs, 5228 ): 5229 super().__init__() 5230 self.conv = torch.nn.Conv2d( 5231 64, 5232 5, 5233 1, 5234 **kwargs, 5235 ) 5236 self.max_pool2d = torch.nn.MaxPool2d(2) 5237 5238 def forward(self, x, y): 5239 x1 = self.conv(x) 5240 y1 = self.max_pool2d(y) 5241 return torch.cat([x1, y1], 1) 5242 5243 mod = M() 5244 opt_mod = torch._dynamo.optimize("inductor")(mod) 5245 memory_format = torch.channels_last 5246 inputs = ( 5247 torch.randn([1, 64, 16, 16]).to(memory_format=memory_format), 5248 torch.randn([1, 64, 32, 32]).to(memory_format=memory_format), 5249 ) 5250 y = mod(*inputs) 5251 opt_y = opt_mod(*inputs) 5252 self.assertEqual(y, opt_y) 5253 self.assertEqual(y.stride(), opt_y.stride()) 5254 5255 def test_cat_inplace(self): 5256 def fn(x): 5257 rt = torch.cat([x]) 5258 v = x.sin_() 5259 return rt 5260 5261 # can't use self.common because input is modified inplace 5262 inp = torch.ones(2) 5263 opt_fn = torch.compile(fn) 5264 res = opt_fn(inp.clone()) 5265 expected = fn(inp.clone()) 5266 self.assertEqual(res, expected) 5267 5268 def test_stack(self): 5269 def fn(a, b): 5270 return torch.stack( 5271 [ 5272 a.expand(12, 16), 5273 b.expand(12, 16), 5274 ], 5275 2, 5276 ) 5277 5278 self.common(fn, (torch.randn([1, 16]), torch.randn([12, 1]))) 5279 5280 def test_hardtanh(self): 5281 def fn(x): 5282 return F.hardtanh(x), F.hardtanh(x + 1), F.hardtanh(x - 1) 5283 5284 self.common( 5285 fn, 5286 (torch.randn([64]),), 5287 ) 5288 5289 def test_hardsigmoid(self): 5290 def fn(x): 5291 return F.hardsigmoid(x), F.hardsigmoid(x + 3), F.hardsigmoid(x - 3) 5292 5293 self.common( 5294 fn, 5295 (torch.randn([64]),), 5296 ) 5297 5298 def test_hardswish(self): 5299 def fn(x): 5300 return F.hardswish(x), F.hardswish(x + 3), F.hardswish(x - 3) 5301 5302 self.common( 5303 fn, 5304 (torch.randn([64]),), 5305 ) 5306 5307 def test_rsqrt(self): 5308 def fn(x): 5309 return torch.rsqrt(x), torch.rsqrt(x + 1) - 2 5310 5311 self.common( 5312 fn, 5313 (torch.randn([64]),), 5314 ) 5315 5316 def test_expm1(self): 5317 def fn(x): 5318 return torch.expm1(x), torch.expm1(x) * 2 5319 5320 for dtype in (torch.float16, torch.float, torch.double, torch.int, torch.int64): 5321 self.common( 5322 fn, 5323 (torch.randn([64]).to(dtype=dtype),), 5324 ) 5325 self.common( 5326 fn, 5327 (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),), 5328 ) 5329 5330 def test_log1p(self): 5331 def fn(x): 5332 return torch.log1p(x), torch.log1p(x) * 2 5333 5334 for dtype in (torch.float16, torch.float, torch.double, torch.int, torch.int64): 5335 self.common( 5336 fn, 5337 (torch.randn([64]).to(dtype=dtype),), 5338 ) 5339 self.common( 5340 fn, 5341 (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),), 5342 ) 5343 
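# Note on the expm1/log1p coverage above: the torch.arange(-1e-5, 1e-5, 1e-7)
# inputs concentrate values near zero, which is exactly where the naive
# exp(x) - 1 and log(1 + x) formulations lose significant digits to cancellation
# while the dedicated expm1/log1p lowerings must stay precise. A minimal
# illustration (not part of the tests, assumes a float64 scalar input):
#   x = torch.full((), 1e-7, dtype=torch.float64)
#   x.exp() - 1   # suffers cancellation near zero
#   x.expm1()     # accurate to full precision
# so the accuracy comparison on this range is the interesting part of the test.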
5344 def test_flip(self): 5345 def fn(x): 5346 return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2 5347 5348 self.common( 5349 fn, 5350 (torch.randn([1, 2, 6, 6]),), 5351 ) 5352 5353 def test_signbit(self): 5354 def fn(x): 5355 return torch.signbit(x), ~torch.signbit(-x) & 1 5356 5357 self.common( 5358 fn, 5359 (torch.randn([1, 2, 6, 6]),), 5360 ) 5361 5362 def test_sign_dtype(self): 5363 def fn(x): 5364 y = torch.sign(x) 5365 return torch.tanh(y) 5366 5367 self.common(fn, (torch.randn([1, 2, 6, 6]),)) 5368 5369 def test_fmod(self): 5370 def fn(a, b): 5371 return torch.fmod(a, b), torch.fmod(3.0 * a, b) - 2.0 5372 5373 shape = [1, 2, 6, 6] 5374 self.common(fn, (torch.randn(shape), torch.randn(shape))) 5375 5376 def test_fmod_zero_dim(self): 5377 def fn(a, b): 5378 return (torch.fmod(a, b),) 5379 5380 self.common( 5381 fn, 5382 ( 5383 make_tensor(10, device=self.device, dtype=torch.float32), 5384 make_tensor((), device=self.device, dtype=torch.float32), 5385 ), 5386 ) 5387 self.common( 5388 fn, 5389 ( 5390 make_tensor((), device=self.device, dtype=torch.float32), 5391 make_tensor(10, device=self.device, dtype=torch.float32), 5392 ), 5393 ) 5394 5395 def test_log2(self): 5396 def fn(x): 5397 return torch.log2(x), torch.log2(x + 1) - 2 5398 5399 self.common( 5400 fn, 5401 (torch.randn([64]) + 10,), 5402 ) 5403 5404 def test_logsumexp(self): 5405 def fn(x): 5406 return torch.logsumexp(x, -1), torch.logsumexp(x, 0) - 2 5407 5408 self.common( 5409 fn, 5410 (torch.randn([8, 8]) + 10,), 5411 ) 5412 5413 def test_log_fp64(self): 5414 def fn(x): 5415 return torch.log(x), torch.log2(x) 5416 5417 self.common( 5418 fn, 5419 (torch.randn([1024], dtype=torch.float64) + 10,), 5420 ) 5421 5422 def test_bitwise(self): 5423 def fn(x, y): 5424 return ( 5425 torch.bitwise_not(x), 5426 torch.bitwise_or(x, y), 5427 torch.bitwise_xor(x, y), 5428 torch.bitwise_and(x, y), 5429 ) 5430 5431 self.common( 5432 fn, 5433 ( 5434 torch.randint(0, 2**30, [64], dtype=torch.int32), 5435 torch.randint(0, 2**30, [64], dtype=torch.int32), 5436 ), 5437 ) 5438 5439 def test_bitwise2(self): 5440 # again with bool types 5441 def fn(x, y): 5442 return ( 5443 torch.bitwise_not(x), 5444 torch.bitwise_or(x, y), 5445 torch.bitwise_xor(x, y), 5446 torch.bitwise_and(x, y), 5447 ) 5448 5449 self.common( 5450 fn, 5451 ( 5452 torch.randint(0, 2, (2, 20), dtype=torch.bool), 5453 torch.randint(0, 2, (2, 20), dtype=torch.bool), 5454 ), 5455 ) 5456 5457 def test_bitwise3(self): 5458 # Repro for https://github.com/pytorch/pytorch/issues/97968 5459 def fn(x, y): 5460 return ( 5461 torch.max(torch.bitwise_and(x, y), y), 5462 torch.clamp_max(torch.bitwise_or(x, y), y), 5463 torch.clamp_min(torch.bitwise_xor(x, y), y), 5464 ) 5465 5466 self.common( 5467 fn, 5468 ( 5469 torch.rand([5, 10, 1]).to(torch.int8), 5470 torch.rand([10, 1]).to(torch.int8), 5471 ), 5472 ) 5473 5474 def test_inf(self): 5475 def fn(a): 5476 return a + float("inf"), a + float("-inf"), a * -float("inf") 5477 5478 self.common(fn, (torch.randn(8),)) 5479 5480 def test_remainder(self): 5481 def fn(a, b): 5482 return ( 5483 torch.remainder(a, b), 5484 torch.remainder(a + 1, b - 1), 5485 torch.remainder(a - 1, b + 1), 5486 ) 5487 5488 self.common(fn, (torch.randn(64), torch.randn(64))) 5489 5490 def test_zeros(self): 5491 def fn(a): 5492 return ( 5493 a + 1, 5494 torch.zeros( 5495 (1, 8, 64, 64), 5496 dtype=torch.float32, 5497 device=a.device, 5498 ), 5499 torch.zeros( 5500 1, 5501 8, 5502 64, 5503 64, 5504 dtype=torch.float32, 5505 device=a.device, 5506 ), 5507 torch.zeros(2, 3), 
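# torch.zeros(2, 3) above passes no device argument, so it is allocated on the
# default device (normally CPU) rather than on a.device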
5508 a + torch.ones(8, device=a.device), 5509 torch.full((2, 3), 3.1416, device=a.device), 5510 ) 5511 5512 self.common(fn, (torch.randn(8),)) 5513 5514 def test_new_ones(self): 5515 def fn(a): 5516 return ( 5517 aten.new_ones( 5518 a, [], device=a.device, dtype=6, layout=0, pin_memory=False 5519 ), 5520 aten.new_zeros( 5521 a, [], device=a.device, dtype=6, layout=0, pin_memory=False 5522 ), 5523 ) 5524 5525 self.common(fn, (torch.randn(8),)) 5526 5527 def test_full_like(self): 5528 def fn(a): 5529 return torch.full_like(a, 7.777) - 1 5530 5531 self.common(fn, (torch.randn(8),)) 5532 5533 def test_full_truncation(self): 5534 def fn(a): 5535 return a + torch.full_like(a, 7.777) 5536 5537 for dtype in all_types(): 5538 self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),)) 5539 5540 def test_full_boolean(self): 5541 def fn(n): 5542 x = torch.full((1,), n >= 1024, device=self.device) 5543 return x, x + 1 5544 5545 self.common(fn, (1024,)) 5546 self.common(fn, (1023,)) 5547 5548 def test_index1(self): 5549 def fn(a, b, c): 5550 return aten.index(a, [b, c]) 5551 5552 self.common( 5553 fn, 5554 ( 5555 torch.randn(8, 8, 12), 5556 torch.tensor([0, 0, 2, 2], dtype=torch.int64), 5557 torch.tensor([3, 4, 4, 3], dtype=torch.int64), 5558 ), 5559 ) 5560 self.common( 5561 fn, 5562 ( 5563 torch.randn(8, 8, 12), 5564 torch.tensor([[0, 0, 2, 2]], dtype=torch.int64), 5565 torch.tensor([[3], [4], [4], [3]], dtype=torch.int64), 5566 ), 5567 ) 5568 5569 def test_index2(self): 5570 def fn(a, b): 5571 return ( 5572 aten.index(a, [b]), 5573 aten.index(a, [None, b]), 5574 ) 5575 5576 self.common( 5577 fn, 5578 ( 5579 torch.randn(8, 8, 8), 5580 torch.tensor([[0, 0, 2, 2]], dtype=torch.int64), 5581 ), 5582 ) 5583 5584 def test_index3(self): 5585 def fn(x, ia, ib): 5586 return (x[:, ia, None, ib, 0],) 5587 5588 self.common( 5589 fn, 5590 ( 5591 torch.randn(3, 4, 4, 4, 3), 5592 torch.tensor([0, 2, 1], dtype=torch.int64), 5593 torch.tensor([0, 2, 1], dtype=torch.int64), 5594 ), 5595 ) 5596 5597 def test_output_strides(self): 5598 def fn(x): 5599 y = x.permute(0, 2, 3, 1).contiguous() 5600 torch._dynamo.graph_break() 5601 return y.view(-1, 4) 5602 5603 inp = torch.rand([4, 4, 4, 4], device=self.device) 5604 fn_opt = torch._dynamo.optimize("inductor")(fn) 5605 5606 self.assertEqual(fn(inp), fn_opt(inp)) 5607 self.assertEqual(fn(inp).stride(), fn_opt(inp).stride()) 5608 5609 # no redundant copy 5610 def foo(x): 5611 return x[0:2:2].T[3:].squeeze(0) 5612 5613 foo_opt = torch._dynamo.optimize("inductor")(foo) 5614 out = foo_opt(inp) 5615 self.assertEqual(inp.storage(), out.storage()) 5616 5617 def test_index_select(self): 5618 def fn(a, b): 5619 return ( 5620 torch.index_select(a, 0, b), 5621 torch.index_select(a, 1, b), 5622 torch.index_select(torch.index_select(a, 2, b), 1, b), 5623 ) 5624 5625 for ind_dtype in (torch.int32, torch.int64): 5626 self.common( 5627 fn, 5628 ( 5629 torch.randn(8, 8, 8), 5630 torch.tensor([0, 0, 2, 1], dtype=ind_dtype), 5631 ), 5632 ) 5633 5634 @skipCUDAIf(not TEST_CUDNN, "CUDNN not available") 5635 @skipIfXpu 5636 @skipIfRocm 5637 def test_cudnn_rnn(self): 5638 if self.device == "cpu": 5639 raise unittest.SkipTest(f"requires {GPU_TYPE}") 5640 5641 def fn( 5642 a0, 5643 b0, 5644 b1, 5645 b2, 5646 b3, 5647 b4, 5648 b5, 5649 b6, 5650 b7, 5651 b8, 5652 b9, 5653 b10, 5654 b11, 5655 b12, 5656 b13, 5657 b14, 5658 b15, 5659 a3, 5660 a4, 5661 a5, 5662 ): 5663 a1 = [ 5664 b0, 5665 b1, 5666 b2, 5667 b3, 5668 b4, 5669 b5, 5670 b6, 5671 b7, 5672 b8, 5673 b9, 5674 b10, 5675 b11, 5676 b12, 5677 
b13, 5678 b14, 5679 b15, 5680 ] 5681 return aten._cudnn_rnn( 5682 a0, 5683 a1, 5684 4, 5685 a3, 5686 a4, 5687 a5, 5688 2, 5689 2048, 5690 0, 5691 2, 5692 False, 5693 0.0, 5694 False, 5695 True, 5696 [], 5697 None, 5698 ) 5699 5700 self.common( 5701 fn, 5702 ( 5703 torch.randn([92, 8, 2048]), 5704 torch.randn([8192, 2048]), 5705 torch.randn([8192, 2048]), 5706 torch.randn([8192]), 5707 torch.randn([8192]), 5708 torch.randn([8192, 2048]), 5709 torch.randn([8192, 2048]), 5710 torch.randn([8192]), 5711 torch.randn([8192]), 5712 torch.randn([8192, 4096]), 5713 torch.randn([8192, 2048]), 5714 torch.randn([8192]), 5715 torch.randn([8192]), 5716 torch.randn([8192, 4096]), 5717 torch.randn([8192, 2048]), 5718 torch.randn([8192]), 5719 torch.randn([8192]), 5720 torch.randn([167837696]), 5721 torch.randn([4, 8, 2048]), 5722 torch.randn([4, 8, 2048]), 5723 ), 5724 check_lowp=False, # difference in rnn is too large between half and float inputs 5725 ) 5726 5727 def test_upsample_nearest1d(self): 5728 def fn(a): 5729 return ( 5730 aten.upsample_nearest1d(a, [74], None), 5731 aten.upsample_nearest1d(a, [70], None), 5732 aten.upsample_nearest1d(a, [45], None), 5733 aten.upsample_nearest1d(a, [36], None), 5734 aten.upsample_nearest1d(a, None, [2.0]), 5735 ) 5736 5737 self.common(fn, (torch.randn([2, 4, 37]),)) 5738 5739 def test_upsample_nearest2d(self): 5740 def fn(a): 5741 return ( 5742 aten.upsample_nearest2d(a, [74, 76]), 5743 aten.upsample_nearest2d(a, [70, 75]), 5744 aten.upsample_nearest2d(a, [45, 74]), 5745 aten.upsample_nearest2d(a, [36, 39]), 5746 aten.upsample_nearest2d(a, None, [2.0, 2.0]), 5747 ) 5748 5749 self.common(fn, (torch.randn([2, 4, 37, 38]),)) 5750 5751 def test_upsample_nearest3d(self): 5752 def fn(a): 5753 return ( 5754 aten.upsample_nearest3d(a, [74, 76, 78], None), 5755 aten.upsample_nearest3d(a, [70, 75, 80], None), 5756 aten.upsample_nearest3d(a, [45, 74, 103], None), 5757 aten.upsample_nearest3d(a, [36, 39, 40], None), 5758 aten.upsample_nearest3d(a, None, [2.0, 2.0, 2.0]), 5759 ) 5760 5761 self.common(fn, (torch.randn([2, 4, 37, 38, 39]),)) 5762 5763 def test_upsample_nearest2d_backward(self): 5764 func = torch.ops.aten.upsample_nearest2d_backward 5765 5766 def fn(a): 5767 return ( 5768 func(a, output_size=[6, 12], input_size=[3, 3, 3, 6]), 5769 func(a, output_size=[6, 12], input_size=[3, 3, 4, 5]), 5770 func(a, output_size=[6, 12], input_size=[3, 3, 2, 8]), 5771 func(a, output_size=[6, 12], input_size=[3, 3, 2, 8]), 5772 func(a, output_size=[6, 12], input_size=[3, 3, 4, 7]), 5773 ) 5774 5775 self.common(fn, (torch.randn([3, 3, 6, 12]),)) 5776 5777 @skip_if_x86_mac() 5778 def test_upsample_bilinear2d_a(self): 5779 def fn(a): 5780 return ( 5781 aten.upsample_bilinear2d(a, [45, 45], False, None), 5782 aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]), 5783 ) 5784 5785 self.common(fn, (torch.randn([2, 4, 37, 38]),), atol=2.5e-5, rtol=1.3e-6) 5786 5787 def test_upsample_bilinear2d_b(self): 5788 def fn(a): 5789 return aten.upsample_bilinear2d(a, None, True, [2.0, 2.0]) 5790 5791 self.common( 5792 fn, 5793 [ 5794 torch.randn([1, 2, 40, 59]), 5795 ], 5796 atol=2.5e-5, 5797 rtol=1.3e-6, 5798 ) 5799 5800 def test_reflection_pad2d(self): 5801 def fn(a, pad): 5802 return ( 5803 aten.reflection_pad2d(a, [1, 1, 1, 1]), 5804 aten.reflection_pad2d(a, pad), 5805 ) 5806 5807 self.common( 5808 fn, 5809 ( 5810 torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32), 5811 [5, 2, 3, 4], 5812 ), 5813 ) 5814 5815 def test_reflection_pad2d_backward(self): 5816 def template(size, padding): 
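# build grad_output by running the forward reflection_pad2d eagerly, then
# check the backward lowering through self.common below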
5817 def fn(grad_output, x): 5818 return aten.reflection_pad2d_backward(grad_output, x, padding) 5819 5820 x = torch.randint(0, 999, size=size, dtype=torch.float32) 5821 result = aten.reflection_pad2d(x, padding) 5822 grad_output = torch.randn_like(result) 5823 5824 self.common(fn, (grad_output, x)) 5825 5826 template([1, 1, 8, 8], [0, 0, 0, 0]) 5827 template([1, 1, 8, 8], [1, 1, 1, 1]) 5828 template([1, 1, 8, 8], [1, 2, 3, 4]) 5829 template([1, 1, 8, 8], [0, -1, 2, 2]) 5830 template([1, 1, 8, 8], [-1, 0, 2, 2]) 5831 template([1, 1, 8, 8], [2, 2, 0, -1]) 5832 template([1, 1, 8, 8], [2, 2, -1, 0]) 5833 5834 def test_grid_sampler_2d(self): 5835 def fn(a, b): 5836 return ( 5837 aten.grid_sampler_2d(a, b, 0, 0, True), 5838 aten.grid_sampler_2d(a, b, 0, 1, False), 5839 ) 5840 5841 self.common( 5842 fn, 5843 ( 5844 torch.randn([4, 3, 352, 352], dtype=torch.float32), 5845 torch.rand([4, 352, 352, 2], dtype=torch.float32) * 2 - 1, 5846 ), 5847 check_lowp=False, 5848 # Mismatched elements: 154697 / 1486848 (10.4%) 5849 # Greatest absolute difference: 0.0001976490020751953 at index (0, 0, 101, 243) (up to 1e-05 allowed) 5850 # Greatest relative difference: 7.332530120481928 at index (1, 1, 258, 301) (up to 1.3e-06 allowed) 5851 atol=0.0002, 5852 rtol=1.3e-06, 5853 ) 5854 5855 def test_upsample_bicubic2d(self): 5856 def fn(a): 5857 return ( 5858 aten.upsample_bicubic2d(a, (128, 128), True), 5859 aten.upsample_bicubic2d(a, (128, 256), False), 5860 ) 5861 5862 # Mismatched elements: 10 / 196608 (0.0%) 5863 # Greatest absolute difference: 1.3869255781173706e-05 at index (2, 1, 88, 65) (up to 1e-05 allowed) 5864 # Greatest relative difference: 0.0033082996811011046 at index (3, 1, 88, 91) (up to 1.3e-06 allowed) 5865 self.common( 5866 fn, 5867 (torch.randn([4, 3, 64, 32], dtype=torch.float32),), 5868 atol=2e-5, 5869 rtol=1e-3, 5870 ) 5871 5872 def test_float_index_expression(self): 5873 # Test that index propagation doesn't generate bad index_expr calls like 5874 # ops.index_expr(0.5*x, dtype) where the expression is not integral 5875 def fn(x): 5876 return aten.upsample_bicubic2d(x, (256, 256), False) 5877 5878 x = torch.randn(1, 1, 128, 128, dtype=torch.float32, device=self.device) 5879 _, source_codes = run_and_get_code(fn, x) 5880 5881 pattern = r"0\.50*\*[ix][\d]" 5882 for code in source_codes: 5883 self.assertIsNone( 5884 re.search(pattern, code), msg="Found bad index_expr in code:\n" + code 5885 ) 5886 5887 def test_float_index_expression_type_promotion(self): 5888 # Test that float indexing expressions participate in type promotion 5889 def fn(x): 5890 return x + 1.0 / x.size(0) 5891 5892 x = torch.arange(10) 5893 self.common(fn, (x,)) 5894 5895 def test_sort(self): 5896 def fn(a): 5897 return torch.sort(a) 5898 5899 self.common( 5900 fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) 5901 ) 5902 5903 def test_topk(self): 5904 def fn(a): 5905 return torch.topk(a, 2, -1) 5906 5907 self.common( 5908 fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),) 5909 ) 5910 5911 def test_long_tensor(self): 5912 def fn(a): 5913 return ( 5914 torch.LongTensor([294]).to(a.device) - a, 5915 torch.as_tensor([295]).to(a.device) + a, 5916 ) 5917 5918 self.common(fn, (torch.randint(0, 999, size=[8, 8]),)) 5919 5920 def test_constant_pad_1d(self): 5921 def fn(a): 5922 return ( 5923 aten.constant_pad_nd(a, [0, 1], 6.0), 5924 aten.constant_pad_nd(a, [2, 3], 99.0), 5925 ) 5926 5927 self.common(fn, (torch.randint(0, 999, size=[2, 16, 31], dtype=torch.float32),)) 5928 5929 def 
test_constant_pad_fill_dtype(self):
5930 def fn(a, b):
5931 return (
5932 aten.constant_pad_nd(a, (1, 1), 1.0) & b,
5933 aten.constant_pad_nd(a, (1, 1), 0.0) & b,
5934 )
5935
5936 self.common(
5937 fn,
5938 (torch.randint(2, (4,), dtype=torch.bool), torch.ones(6, dtype=torch.bool)),
5939 )
5940
5941 def test_constant_pad_2d(self):
5942 def fn(a):
5943 return (
5944 aten.constant_pad_nd(a, [1, 1, 1, 1], 6.0),
5945 aten.constant_pad_nd(a, [1, 2, 3, 4], 99.0),
5946 )
5947
5948 self.common(
5949 fn, (torch.randint(0, 999, size=[1, 1, 8, 8], dtype=torch.float32),)
5950 )
5951
5952 def test_constant_pad_3d(self):
5953 def fn(a):
5954 return (
5955 aten.constant_pad_nd(a, [1, 2, 3, 4, 5, 6], 6.0),
5956 aten.constant_pad_nd(a, [0, 0, 3, 4, 0, 0], 6.0),
5957 )
5958
5959 self.common(
5960 fn, (torch.randint(0, 999, size=[2, 4, 4, 4], dtype=torch.float32),)
5961 )
5962
5963 def test_constant_pad_float64(self):
5964 # Repro for https://github.com/pytorch/pytorch/issues/93351
5965 def fn(input):
5966 v1 = torch.nn.functional.pad(input, pad=(1, 0))
5967 return torch.gt(v1, input)
5968
5969 x = torch.rand([1, 2, 2, 1], dtype=torch.float64)
5970 self.common(fn, (x,))
5971
5972 def test_constant_pad_nd_inplace(self):
5973 def fn(a):
5974 return aten.constant_pad_nd(a, [0, 0])
5975
5976 x = torch.randn([2], device=self.device)
5977 fn_compiled = torch.compile(fn)
5978 y = fn_compiled(x)
5979 self.assertTrue(y is not x)
5980
5981 def test_l1_loss(self):
5982 def fn(a, b):
5983 return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)
5984
5985 self.common(
5986 fn,
5987 (
5988 torch.randn([2, 3, 16, 16]),
5989 torch.randn([2, 3, 16, 16]),
5990 ),
5991 check_lowp=False,
5992 )
5993
5994 def test_triu(self):
5995 def fn(a):
5996 return aten.triu(a, 1), aten.triu(a, 0), aten.triu(a, 2)
5997
5998 self.common(fn, (torch.randn([2, 10, 10]),))
5999
6000 def test_no_op_reduction(self):
6001 def fn(a):
6002 return a.sum(-1), torch.amax(a + 1, 1, keepdim=True)
6003
6004 self.common(fn, (torch.randn([8, 1, 1]),))
6005
6006 def test_inplace_add(self):
6007 @torch._dynamo.optimize("inductor")
6008 def fn(x, y):
6009 return x.add_(y)
6010
6011 inputs = (
6012 rand_strided((4, 4), (4, 1), device=self.device),
6013 rand_strided((4, 4), (4, 1), device=self.device),
6014 )
6015 inp_clone = inputs[0].clone()
6016 out = fn(*inputs)
6017 self.assertTrue(same(out, inp_clone + inputs[1]))
6018 self.assertTrue(out is inputs[0])
6019
6020 # The following 2 tests are meant to check the logic that drops
6021 # xmask from triton load/store if xnumel = 1
6022 @requires_gpu()
6023 def test_single_elem(self):
6024 def fn(a):
6025 b = a + 1
6026 return (b,)
6027
6028 self.common(fn, (torch.randn(1),))
6029
6030 @requires_gpu()
6031 def test_single_elem_indirect(self):
6032 def fn(a, b):
6033 c = a[b] + 1
6034 return (c,)
6035
6036 a = torch.randn(1)
6037 b = (torch.tensor([0], dtype=torch.int64),)
6038
6039 self.common(fn, (a, b))
6040
6041 # This test is meant to check for issues from the logic
6042 # that drops xmask from triton load/store if XBLOCK divides xnumel
6043
6044 @requires_gpu()
6045 def test_xblock_divides_xnumel(self):
6046 def fn(a):
6047 b = a + 1
6048 return (b,)
6049
6050 # assumption is that XBLOCK is always a divisor of 1024
6051 # so xmask will be dropped iff xnumel is a multiple of 1024
6052 self.common(fn, (torch.randn(1024),))
6053 self.common(fn, (torch.randn(1025),))
6054
6055 def test_inplace_mixed_dtype_ops(self):
6056 @torch._dynamo.optimize("inductor")
6057 def fn(x, y):
6058 z = x + y.float()
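# z is float32 (x + y.float()); the in-place add_/mul_ below must keep z's
# float32 dtype even though y is float64, since in-place ops never change the
# dtype of the tensor they mutate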
6059 w = z.add_(y) 6060 return w.mul_(y) 6061 6062 inputs = ( 6063 rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.float), 6064 rand_strided((4, 4), (4, 1), device=self.device, dtype=torch.double), 6065 ) 6066 out = fn(*inputs) 6067 out_eager = (inputs[0] + inputs[1].float()).add_(inputs[1]).mul_(inputs[1]) 6068 self.assertTrue(same(out, out_eager)) 6069 6070 @config.patch( 6071 {"triton.unique_kernel_names": True, "triton.descriptive_names": False} 6072 ) 6073 def test_kernel_names(self): 6074 @torch._dynamo.optimize("inductor") 6075 def fn(x): 6076 return 2 * x 6077 6078 inputs = (rand_strided((8,), (1,), device=self.device),) 6079 self.assertTrue(same(fn(*inputs), 2 * inputs[0])) 6080 6081 @config.patch({"triton.cudagraphs": True}) 6082 @dynamo_config.patch(automatic_dynamic_shapes=True) 6083 def test_strided_inputs(self): 6084 @torch._dynamo.optimize("inductor") 6085 def fn(x, y): 6086 return x + y 6087 6088 inputs = ( 6089 rand_strided((8, 16), (32, 2), device=self.device), 6090 rand_strided((8, 16), (16, 1), device=self.device), 6091 ) 6092 self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) 6093 6094 @config.patch({"triton.cudagraphs": True}) 6095 @dynamo_config.patch(automatic_dynamic_shapes=True) 6096 def test_input_mutation1(self): 6097 def fn(a): 6098 b = a + 1 6099 a.copy_(b) 6100 c = a + 2 6101 return a * b / c 6102 6103 arg1 = torch.randn(64, device=self.device) 6104 arg2 = arg1.clone() 6105 arg3 = torch.randn(64, device=self.device) 6106 arg4 = arg3.clone() 6107 correct1 = fn(arg1) 6108 correct2 = fn(arg3) 6109 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 6110 actual1 = opt_fn(arg2) 6111 actual2 = opt_fn(arg4) 6112 6113 self.assertTrue(same(actual1, correct1)) 6114 self.assertTrue(same(actual2, correct2)) 6115 self.assertTrue(same(arg1, arg2)) 6116 self.assertTrue(same(arg3, arg4)) 6117 6118 def test_input_mutation2(self): 6119 def fn(a): 6120 b = a + 1 6121 a.view(64).copy_(torch.tensor([66.0], device=a.device)) 6122 c = a + 2 6123 return b, c 6124 6125 # NOTE: this test fails when none of the inputs require grad. 6126 # That seems like an inductor bug. 
6127 arg1 = torch.randn([1, 64], device=self.device).requires_grad_(True).add(1) 6128 arg2 = arg1.clone() 6129 correct1 = fn(arg1) 6130 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 6131 actual1 = opt_fn(arg2) 6132 6133 self.assertTrue(same(actual1, correct1)) 6134 self.assertTrue(same(arg1, arg2)) 6135 6136 def test_input_mutation3(self): 6137 def fn(a): 6138 a += 1 6139 a *= 2 6140 aten.sigmoid_(a) 6141 a = a.view(64) 6142 a += 3 6143 a *= 4 6144 aten.relu_(a) 6145 return a 6146 6147 arg1 = torch.randn([1, 64], device=self.device) 6148 arg2 = arg1.clone() 6149 correct1 = fn(arg1) 6150 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 6151 actual1 = opt_fn(arg2) 6152 6153 self.assertTrue(same(actual1, correct1)) 6154 self.assertTrue(same(arg1, arg2)) 6155 6156 def test_input_mutation4(self): 6157 def fn(a): 6158 torch.relu_(a) 6159 return a 6160 6161 arg1 = torch.randn([1, 64], device=self.device) 6162 arg2 = arg1.clone() 6163 correct1 = fn(arg1) 6164 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 6165 actual1 = opt_fn(arg2) 6166 6167 self.assertTrue(same(actual1, correct1)) 6168 self.assertTrue(same(arg1, arg2)) 6169 6170 def test_input_mutation5(self): 6171 def fn(x): 6172 tmp = x.ceil() 6173 x.add_(10) 6174 return tmp 6175 6176 opt_fn = torch._dynamo.optimize()(fn) 6177 6178 a = torch.zeros((), dtype=torch.int64, device=self.device) 6179 a_expect = a.clone() 6180 expect = fn(a_expect) 6181 6182 a_actual = a.clone() 6183 actual = opt_fn(a_actual) 6184 6185 self.assertEqual(a_expect, a_actual) 6186 self.assertEqual(expect, actual) 6187 6188 def test_slice_mutation1(self): 6189 def fn(a): 6190 x = torch.zeros_like(a) 6191 b = x + 1 6192 x[:, 3] = 3.0 6193 c = torch.clone(x) 6194 x[4, :] = 4.0 6195 d = x + 1 6196 return x, b, c, d 6197 6198 self.common(fn, (torch.randn([8, 8]),)) 6199 6200 def test_slice_mutation2(self): 6201 def fn(a): 6202 a[:, 20:40] = a[:, 20:40] + 1 6203 a[:, 2:11] = a[:, 1:10] + 2 6204 6205 arg1 = torch.randn([1, 64], device=self.device) 6206 arg2 = arg1.clone() 6207 fn(arg1) 6208 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 6209 opt_fn(arg2) 6210 self.assertTrue(same(arg1, arg2)) 6211 6212 def test_slice_mutation3(self): 6213 def fn(a): 6214 a[:2, :2].fill_(10) 6215 6216 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 6217 6218 x1 = torch.randn(8, 8, device=self.device) 6219 x2 = x1.clone() 6220 fn(x1) 6221 opt_fn(x2) 6222 self.assertEqual(x1, x2) 6223 6224 def test_tensor_index_slice(self): 6225 def fn(a): 6226 x = torch.tensor([1, 2], device=self.device) 6227 y = torch.tensor([2, 3], device=self.device) 6228 xx = torch.tensor([1, 2], device=self.device).view(1, 2) 6229 yy = torch.tensor([1, 2, 3], device=self.device).view(3, 1) 6230 return [ 6231 a[x, y], 6232 a[:, x, y], 6233 a[:, x, y, :], 6234 a[x, :, y], 6235 a[:, x, :, y, :], 6236 a[xx, yy], 6237 a[:, xx, yy], 6238 a[xx, :, yy], 6239 a[xx, yy, :], 6240 a[:, xx, :, yy], 6241 ] 6242 6243 a = torch.arange(3 * 4 * 5 * 6 * 7, device=self.device).view(3, 4, 5, 6, 7) 6244 refs = fn(a) 6245 tests = torch.compile(fn)(a) 6246 for ref, test in zip(refs, tests): 6247 torch.testing.assert_close(ref, test) 6248 6249 @torch._dynamo.config.patch(cache_size_limit=10) 6250 def test_tensor_index_put_slice(self): 6251 def fn(a, version): 6252 x = torch.tensor([1, 2], device=self.device, dtype=torch.int32) 6253 y = torch.tensor([2, 3], device=self.device, dtype=torch.int32) 6254 6255 xx = torch.tensor([1, 2], device=self.device).view(1, 2) 6256 yy = torch.tensor([1, 2, 3], 
device=self.device).view(3, 1) 6257 6258 if version == 0: 6259 a[x, y] = torch.zeros_like(a[x, y]) 6260 elif version == 1: 6261 a[:, x, y] = torch.zeros_like(a[:, x, y]) 6262 elif version == 2: 6263 a[:, x, y, :] = torch.zeros_like(a[:, x, y, :]) 6264 elif version == 3: 6265 a[x, :, y] = torch.zeros_like(a[x, :, y]) 6266 elif version == 4: 6267 a[:, x, :, y, :] = torch.zeros_like(a[:, x, :, y, :]) 6268 elif version == 5: 6269 a[xx, yy] = torch.zeros_like(a[xx, yy]) 6270 elif version == 6: 6271 a[:, xx, yy] = torch.zeros_like(a[:, xx, yy]) 6272 elif version == 7: 6273 a[xx, :, yy] = torch.zeros_like(a[xx, :, yy]) 6274 elif version == 8: 6275 a[xx, yy, :] = torch.zeros_like(a[xx, yy, :]) 6276 elif version == 9: 6277 a[:, xx, :, yy] = torch.zeros_like(a[:, xx, :, yy]) 6278 6279 return a 6280 6281 a = torch.arange(3 * 4 * 5 * 6 * 7, device=self.device, dtype=torch.int32).view( 6282 3, 4, 5, 6, 7 6283 ) 6284 for i in range(10): 6285 ref = fn(torch.clone(a), i) 6286 test = torch.compile(fn)(torch.clone(a), i) 6287 torch.testing.assert_close(ref, test) 6288 6289 def test_indirect_load_broadcast(self): 6290 def fn(in_ptr0, in_ptr1, in_ptr2): 6291 return torch.gather(in_ptr1, 0, in_ptr2) + in_ptr0 6292 6293 arg190 = rand_strided((32, 21), (1, 32), device=self.device, dtype=torch.int64) 6294 arg190.fill_(0) 6295 arg111 = rand_strided( 6296 (9521, 512), (512, 1), device=self.device, dtype=torch.float32 6297 ) 6298 self.common( 6299 fn, 6300 ( 6301 torch.randn(32, 1), 6302 arg111, 6303 arg190, 6304 ), 6305 ) 6306 6307 def test_roi_align(self): 6308 if not has_torchvision_roi_align(): 6309 raise unittest.SkipTest("requires torchvision") 6310 6311 def fn(a, b): 6312 return torch.ops.torchvision.roi_align(a, b, 0.25, 7, 7, 2, False) 6313 6314 self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5]))) 6315 6316 def test_nll_loss_forward(self): 6317 def fn(a, b): 6318 return aten.nll_loss_forward(a, b, None, 1, -100) 6319 6320 labels = ( 6321 torch.zeros([5], dtype=torch.int64), 6322 torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64), 6323 ) 6324 inps = (torch.randn(5, 5), torch.randn(5, 5)) 6325 for a, b in zip(inps, labels): 6326 self.common( 6327 fn, 6328 (a, b), 6329 ) 6330 6331 @skipIfXpu 6332 def test_nll_loss_backward(self): 6333 def fn(a, b, c): 6334 return aten.nll_loss_backward( 6335 a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device) 6336 ) 6337 6338 labels = ( 6339 torch.zeros([5], dtype=torch.int64), 6340 torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64), 6341 ) 6342 inps = (torch.randn(5, 5), torch.randn(5, 5)) 6343 grad_outs = (torch.randn(()), torch.randn(())) 6344 for a, b, c in zip(grad_outs, inps, labels): 6345 self.common( 6346 fn, 6347 (a, b, c), 6348 ) 6349 6350 def test_isinf(self): 6351 def fn(x): 6352 return x.isinf(), x.isnan() 6353 6354 self.common( 6355 fn, [torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")])] 6356 ) 6357 self.common( 6358 fn, 6359 [ 6360 torch.tensor( 6361 [1, float("inf"), 2, float("-inf"), float("nan")], 6362 dtype=torch.float64, 6363 ) 6364 ], 6365 ) 6366 6367 def test_isinf2(self): 6368 def fn(x): 6369 y = torch.tensor( 6370 [1, float("inf"), 2, float("-inf"), float("nan")], device=self.device 6371 ) 6372 return x == y 6373 6374 self.common( 6375 fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),) 6376 ) 6377 6378 def test_any(self): 6379 def fn(x): 6380 return ( 6381 x.any(-1), 6382 x.isinf().any(), 6383 torch.all(x.isinf(), dim=0), 6384 torch.all(torch.logical_not(x.isinf())), 
6385 ) 6386 6387 self.common(fn, [-torch.rand(64)]) 6388 tmp = torch.randn(16, 8) 6389 tmp[1, 1] = float("inf") 6390 self.common(fn, [tmp]) 6391 6392 def test_multilayer_any(self): 6393 def fn(x): 6394 return (x.isinf().any(), x.isfinite().all()) 6395 6396 sample = torch.rand(9, 3, 353, 353) 6397 self.common(fn, [sample]) 6398 6399 sample.view(-1)[-1] = float("inf") 6400 self.common(fn, [sample]) 6401 6402 def test_inplace_activations(self): 6403 def fn(x): 6404 a = aten.hardswish_(x + 1) 6405 b = aten.hardtanh_(x + 1) 6406 c = aten.leaky_relu_(x + 1) 6407 d = aten.silu_(x + 1) 6408 e = aten.log1p(x + 1) 6409 f = aten.masked_fill_(x + 1, torch.zeros_like(x, dtype=torch.bool), 99.0) 6410 h = aten.masked_fill_(x + 1, torch.ones_like(x, dtype=torch.bool), 99.0) 6411 return (a, b, c, d, e, f, h) 6412 6413 self.common(fn, [torch.randn(64) * 10]) 6414 6415 def test_baddbmm(self): 6416 def fn(a, b, c, beta): 6417 return aten.baddbmm(a, b, c, beta=beta) 6418 6419 b = torch.randn(6, 128, 64) 6420 c = torch.randn(6, 64, 100) 6421 options = itertools.product( 6422 [torch.randn(6, 1, 100), torch.randn(6, 1, 100).fill_(torch.nan)], 6423 [0.0, 1.0], 6424 ) 6425 for a, beta in options: 6426 self.common( 6427 fn, 6428 [a, b, c, beta], 6429 # Mismatched elements: 1212 / 76800 (1.6%) 6430 # Greatest absolute difference: 0.001953125 at index (0, 0, 93) (up to 1e-05 allowed) 6431 # Greatest relative difference: 1.0 at index (3, 19, 4) (up to 0.001 allowed) 6432 atol=0.002, 6433 rtol=0.001, 6434 ) 6435 6436 @config.patch({"triton.max_tiles": 2}) 6437 def test_fuse_tiled(self): 6438 def fn(a, b, c): 6439 return a + b, c + 1 6440 6441 self.common( 6442 fn, [torch.randn(128, 1), torch.randn(1, 128), torch.randn(128, 128)] 6443 ) 6444 6445 def test_expand_as(self): 6446 def fn(a, b): 6447 return aten.expand_as(a, b), aten.expand_as(a + 1, b + 1) + 1 6448 6449 self.common( 6450 fn, 6451 [ 6452 torch.randn(6, 1, 100), 6453 torch.randn(6, 128, 100), 6454 ], 6455 ) 6456 6457 def test_index_put1(self): 6458 def fn(a, b, c): 6459 return ( 6460 torch.index_put(a, [b], c), 6461 torch.index_put_(a + 1, [b + 1], c + 1) + 1, 6462 ) 6463 6464 self.common( 6465 fn, 6466 [ 6467 torch.randn([800, 256, 7, 7]), 6468 torch.randperm(601), 6469 torch.randn([601, 256, 7, 7]), 6470 ], 6471 ) 6472 self.common( 6473 fn, [torch.randn(1024, 4, 2), torch.arange(4), torch.randn(4, 1, 1)] 6474 ) 6475 6476 def test_index_put2(self): 6477 def fn(a, b, c): 6478 return torch.index_put(a, [b], c, True) 6479 6480 self.common( 6481 fn, 6482 [ 6483 torch.randn([100, 256, 7, 7]), 6484 torch.randint(0, 100, size=[600], dtype=torch.int64), 6485 torch.randn([600, 256, 7, 7]), 6486 ], 6487 # workaround for https://github.com/openai/triton/issues/558 6488 check_lowp=False, 6489 ) 6490 6491 def test_index_put3(self): 6492 def fn(a, b, c): 6493 torch.ops.aten.index_put_(a, (None, b, None), c) 6494 a1 = a + 1 6495 torch.ops.aten.index_put_(a1, (None, b + 1, None), c + 1) 6496 return (a, a1) 6497 6498 self.common( 6499 fn, 6500 [ 6501 torch.randn([1024, 4, 2]), 6502 torch.arange(3), 6503 torch.randn([1024, 1, 2]), 6504 ], 6505 ) 6506 6507 def test_index_put4(self): 6508 # a, b[0] are not broadcastable 6509 # https://github.com/pytorch/pytorch/issues/97104 6510 def fn(a, b, c): 6511 return torch.index_put(a, [b], c) 6512 6513 self.common( 6514 fn, 6515 [ 6516 torch.rand([8, 2]), 6517 torch.rand([8]) > 0.5, 6518 torch.rand([]), 6519 ], 6520 ) 6521 6522 def test_index_put_as_masked_fill(self): 6523 def fn(a, b, c, d): 6524 a = a.clone() 6525 
torch.ops.aten.index_put_(a, [b], c, d) 6526 return a 6527 6528 self.common( 6529 fn, 6530 ( 6531 torch.randn([1024, 4, 2]), 6532 torch.randn([1024, 4, 2]) > 0, 6533 torch.randn([]), 6534 False, 6535 ), 6536 ) 6537 6538 self.common( 6539 fn, 6540 ( 6541 torch.randn([1024, 4, 2]), 6542 torch.randn([1024, 4, 2]) > 0, 6543 torch.randn([]), 6544 True, 6545 ), 6546 ) 6547 6548 def test_index_put_fallback1(self): 6549 def fn(a, b, c, d): 6550 a = a.clone() 6551 torch.ops.aten.index_put_(a, [b], c, d) 6552 return a 6553 6554 self.common( 6555 fn, 6556 ( 6557 torch.randn([3]), 6558 torch.as_tensor([True, True, False]), 6559 torch.randn([2]), 6560 False, 6561 ), 6562 ) 6563 6564 self.common( 6565 fn, 6566 ( 6567 torch.randn([3]), 6568 torch.as_tensor([True, True, False]), 6569 torch.randn([2]), 6570 True, 6571 ), 6572 ) 6573 6574 def test_index_put_fallback2(self): 6575 def fn(a, b, c, d, e): 6576 a = a.clone() 6577 torch.ops.aten.index_put_(a, [None, b, c], d, e) 6578 return a 6579 6580 self.common( 6581 fn, 6582 ( 6583 torch.randn([1, 2, 3]), 6584 torch.as_tensor([0, 1]), 6585 torch.as_tensor([True, True, False]), 6586 torch.randn([]), 6587 False, 6588 ), 6589 ) 6590 self.common( 6591 fn, 6592 ( 6593 torch.randn([1, 2, 3]), 6594 torch.as_tensor([0, 1]), 6595 torch.as_tensor([True, True, False]), 6596 torch.randn([]), 6597 True, 6598 ), 6599 ) 6600 6601 def test_index_put_deterministic_fallback(self): 6602 with DeterministicGuard(True): 6603 6604 def fn(a, b, c): 6605 return torch.index_put(a, [b], c, True) 6606 6607 self.common( 6608 fn, 6609 [ 6610 torch.randn([100, 32]), 6611 torch.randint(0, 100, size=[600], dtype=torch.int64), 6612 torch.randn([600, 32]), 6613 ], 6614 check_lowp=False, 6615 ) 6616 6617 def test_index_put_index(self): 6618 def fn(ind, x, src): 6619 y = torch.ops.aten.index_put.default(x, [ind], src) 6620 return torch.ops.aten.index.Tensor(y, [ind]) 6621 6622 args = [torch.tensor([1], dtype=torch.int64), torch.randn(8, 4), torch.randn(4)] 6623 self.common(fn, args) 6624 6625 def test_index_put_reinplace(self): 6626 def fn(x, idx): 6627 src = torch.ones(idx.size(0), device=x.device) 6628 x.index_put_((idx,), src) 6629 return x.expand((2, x.shape[0])) 6630 6631 a = torch.randn(1024) 6632 idx = torch.arange(10) 6633 torch._inductor.metrics.generated_kernel_count = 0 6634 self.common(fn, (a, idx)) 6635 assertGeneratedKernelCountEqual(self, 1) 6636 6637 def test_index_put_failed_reinplace(self): 6638 def fn(x, idx): 6639 src = torch.ones(idx.size(0), device=x.device) 6640 y = x.index_put((idx,), src) 6641 return x, y 6642 6643 a = torch.randn(1024) 6644 idx = torch.arange(10) 6645 torch._inductor.metrics.generated_kernel_count = 0 6646 self.common(fn, (a, idx)) 6647 assertGeneratedKernelCountEqual(self, 2) 6648 6649 def test_adding_tensor_offsets(self): 6650 @torch.compile(fullgraph=True) 6651 def fn(x): 6652 return x[16:32] 6653 6654 with torch.no_grad(): 6655 x = torch.randn(1024, device=self.device) 6656 self.assertEqual(fn(x[0:]), x[16:][:16]) 6657 self.assertEqual(fn(x[128:]), x[128 + 16 :][:16]) 6658 6659 # from GPT2ForSequenceClassification 6660 def test_index_tensor(self): 6661 def fn(x, y): 6662 ne = torch.ops.aten.ne.Scalar(x, 0) 6663 sum = torch.ops.aten.sum.dim_IntList(ne, [-1]) 6664 sub = torch.ops.aten.sub.Tensor(sum, 1) 6665 iota = torch.ops.prims.iota.default( 6666 1, 6667 start=0, 6668 step=1, 6669 dtype=torch.int64, 6670 device=x.device, 6671 requires_grad=False, 6672 ) 6673 return torch.ops.aten.index.Tensor(y, [iota, sub]) 6674 6675 self.common(fn, 
[torch.randn(1, 1024), torch.randn(1, 1024, 2)]) 6676 6677 @config.patch(fallback_random=True) 6678 def test_bernoulli1(self): 6679 def fn(a): 6680 b = torch.empty_like(a) 6681 return aten.bernoulli_(b), b 6682 6683 self.common( 6684 fn, 6685 [ 6686 torch.randn([100]), 6687 ], 6688 ) 6689 6690 def test_bernoulli2(self): 6691 def fn(a): 6692 return aten.bernoulli(a) 6693 6694 self.common( 6695 fn, 6696 [torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0])], 6697 ) 6698 6699 def test_narrow(self): 6700 def fn(x): 6701 return ( 6702 aten.narrow(x, 1, 10, 16), 6703 aten.narrow(x + 2, 0, 10, 16) + 1, 6704 aten.narrow_copy(x, 1, 10, 16), 6705 ) 6706 6707 self.common(fn, [torch.randn(64, 64)]) 6708 6709 def test_new_cpp_build_logical(self): 6710 from torch._inductor.codecache import validate_new_cpp_commands 6711 6712 validate_new_cpp_commands() 6713 6714 def test_as_strided(self): 6715 def fn(x): 6716 return ( 6717 aten.as_strided(x, (8, 8, 64), (8 * 64, 64, 1), 0), 6718 aten.as_strided(x + 1, (8, 8, 64), (8 * 64, 64, 1), 0) + 2, 6719 ) 6720 6721 def fn_channels_last(x): 6722 return ( 6723 aten.as_strided( 6724 x, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680), 0 6725 ), 6726 aten.as_strided( 6727 x + 1, (8, 384, 2, 20, 12), (153600, 1, 61440, 384, 7680), 0 6728 ) 6729 + 2, 6730 ) 6731 6732 self.common(fn, [torch.randn(64, 64)]) 6733 self.common( 6734 fn_channels_last, 6735 [torch.randn(8, 384, 20, 20).to(memory_format=torch.channels_last)], 6736 ) 6737 6738 def test_like_channels_last(self): 6739 def foo(): 6740 randn = torch.randn((4, 3, 8, 8), device=self.device, dtype=torch.float32) 6741 xc = randn.contiguous(memory_format=torch.channels_last) 6742 clone = torch.zeros_like(xc, memory_format=torch.preserve_format) 6743 rand_like = torch.rand_like(randn) 6744 return (xc, clone, rand_like) 6745 6746 out = foo() 6747 out_comp = torch.compile()(foo)() 6748 6749 for t, t_comp in zip(out, out_comp): 6750 self.assertEqual(t.stride(), t_comp.stride()) 6751 6752 def test_as_strided_scatter(self): 6753 def fn(a, b): 6754 return aten.as_strided_scatter( 6755 a * 8 + 10, 6756 b * 2 - 4, 6757 size=(a.shape[0], a.shape[1] // 2), 6758 stride=(a.shape[1], 2), 6759 storage_offset=0, 6760 ) 6761 6762 self.common(fn, [torch.randn(10, 1024), torch.randn(10, 512)]) 6763 6764 def test_select_scatter(self): 6765 def fn(x, a, b): 6766 return ( 6767 aten.select_scatter(x, a, 1, 0), 6768 aten.select_scatter(x, b, 0, 1), 6769 ) 6770 6771 self.common( 6772 fn, 6773 [ 6774 torch.randn(8, 197, 38), 6775 torch.randn(8, 38), 6776 torch.randn(197, 38), 6777 ], 6778 ) 6779 6780 def test_slice_scatter(self): 6781 def fn(x, a): 6782 return ( 6783 aten.slice_scatter(x, a, 2, 10, -10), 6784 aten.slice_scatter(x, a[:, :, :40], 2, 10, -10, 2), 6785 ) 6786 6787 self.common( 6788 fn, 6789 [ 6790 torch.randn(4, 8, 100), 6791 torch.randn(4, 8, 80), 6792 ], 6793 ) 6794 6795 def test_slice_scatter2(self): 6796 def fn(a, b): 6797 return aten.slice_scatter(a, b, 0, 0, 9223372036854775807) 6798 6799 self.common( 6800 fn, 6801 [ 6802 torch.randn([8, 197, 384]), 6803 torch.randn([8, 197, 384]), 6804 ], 6805 ) 6806 6807 def test_slice_scatter3(self): 6808 def fn(a, b): 6809 return aten.slice_scatter.default(a, b, 1, 1, 9223372036854775807, 2) 6810 6811 self.common( 6812 fn, 6813 [ 6814 torch.randn([1, 4]), 6815 torch.randn([1, 2]), 6816 ], 6817 ) 6818 6819 def test_slice_scatter4(self): 6820 def fn(a, b): 6821 return aten.slice_scatter.default(a, b, 1, 2, 9223372036854775807, 3) 6822 6823 self.common( 6824 fn, 6825 [ 6826 torch.randn([1, 
9]), 6827 torch.randn([1, 3]), 6828 ], 6829 ) 6830 6831 def test_slice_scatter5(self): 6832 # empty slices that require clamping the start or end 6833 def fn(a, b): 6834 return ( 6835 aten.slice_scatter.default(a, b, 0, 2, 0, 1), 6836 aten.slice_scatter.default(a, b, 0, a.shape[0], a.shape[0] + 10, 1), 6837 aten.slice_scatter.default(a, b, 0, -20, 0, 1), 6838 aten.slice_scatter.default(a, b, 0, -20, -16, 1), 6839 ) 6840 6841 a = torch.arange(10, dtype=torch.float) 6842 b = torch.empty(0) 6843 self.common(fn, [a, b]) 6844 6845 def test_slice_scatter_reinplace(self): 6846 class M(nn.Module): 6847 def __init__(self, device): 6848 super().__init__() 6849 self.linear1 = nn.Linear(64, 64, bias=False) 6850 self.cache_k = torch.zeros((56, 384, 8, 64), device=device) 6851 6852 def forward(self, x, start_pos): 6853 bsz, seqlen, _, _ = x.shape 6854 xk = self.linear1(x) 6855 with torch.no_grad(): 6856 self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk 6857 keys = self.cache_k[:bsz, : start_pos + seqlen] 6858 scores = torch.matmul( 6859 xk.transpose(1, 2), keys.transpose(1, 2).transpose(2, 3) 6860 ) 6861 return scores 6862 6863 kv_cache_module = M(self.device) 6864 inp = torch.randn(1, 32, 8, 64) 6865 6866 # Test that the cache update is reinplaced such that the cache is updated inplace 6867 # rather than copy-scatter-copy-back. 6868 6869 torch._inductor.metrics.generated_kernel_count = 0 6870 with torch.no_grad(): 6871 self.common(kv_cache_module, (inp, 1), check_lowp=False) 6872 assertGeneratedKernelCountEqual(self, 1) 6873 6874 def test_scatter1(self): 6875 def fn(a, dim, index, b): 6876 return aten.scatter(a, dim, index, b) 6877 6878 self.common( 6879 fn, 6880 [ 6881 torch.zeros(2, 3), 6882 -1, 6883 torch.tensor([[0]]), 6884 torch.ones(2, 3), 6885 ], 6886 ) 6887 6888 def test_scatter2(self): 6889 if self.device == "cuda": 6890 raise unittest.SkipTest("unstable on sm86") 6891 6892 check_lowp = True 6893 if self.device == "xpu": 6894 check_lowp = False 6895 6896 def fn(a, dim, index, b): 6897 return aten.scatter.reduce(a, dim, index, b, reduce="add") 6898 6899 self.common( 6900 fn, 6901 [ 6902 torch.zeros(64, 512), 6903 0, 6904 torch.zeros((64, 512), dtype=torch.int64), 6905 torch.ones(64, 512), 6906 ], 6907 check_lowp=check_lowp, 6908 ) 6909 6910 def test_scatter3(self): 6911 def fn(a, dim, index, b): 6912 return aten.scatter(a, dim, index, b, reduce="add") 6913 6914 check_lowp = True 6915 if self.device == "xpu": 6916 check_lowp = False 6917 6918 self.common( 6919 fn, 6920 [ 6921 torch.randn(5, 29, 13), 6922 2, 6923 torch.tensor([[[3, 5, 7, 9]]]), 6924 0.8, # src can be a scalar 6925 ], 6926 # Mismatched elements: 1 / 1885 (0.1%) 6927 # Greatest absolute difference: 0.00018310546875 at index (0, 0, 3) (up to 1e-05 allowed) 6928 # Greatest relative difference: 0.0022371364653243847 at index (0, 0, 3) (up to 0.001 allowed) 6929 atol=2e-4, 6930 rtol=1e-3, 6931 check_lowp=check_lowp, 6932 ) 6933 6934 def test_scatter4(self): 6935 def fn(x, ind, src): 6936 return torch.scatter(x, 0, ind, src) 6937 6938 check_lowp = True 6939 if self.device == "xpu": 6940 check_lowp = False 6941 6942 for deterministic in [False, True]: 6943 with DeterministicGuard(deterministic): 6944 self.common( 6945 fn, 6946 [ 6947 torch.randn(196, 992), 6948 torch.randint(196, (1, 992)), 6949 torch.randn(1, 992), 6950 ], 6951 check_lowp=check_lowp, 6952 ) 6953 6954 def test_scatter5(self): 6955 def fn(a, dim, index, b, reduce): 6956 a = a.clone() 6957 a.scatter_(dim, index, b, reduce=reduce) 6958 a1 = a + 1.0 6959 
a1.scatter_(dim, index, b, reduce=reduce) 6960 return (a, a1) 6961 6962 check_lowp = True 6963 if self.device == "xpu": 6964 check_lowp = False 6965 6966 for reduce in ["add", "multiply"]: 6967 self.common( 6968 fn, 6969 [ 6970 torch.ones((4, 5)), 6971 0, 6972 torch.tensor([[1], [2], [3]], dtype=torch.int64), 6973 torch.randn(4, 5), 6974 reduce, 6975 ], 6976 check_lowp=check_lowp, 6977 ) 6978 6979 def test_scatter6(self): 6980 def fn(a, dim, index, b): 6981 return aten.scatter(a, dim, index, b) 6982 6983 check_lowp = True 6984 if self.device == "xpu": 6985 check_lowp = False 6986 6987 for deterministic in [False, True]: 6988 with DeterministicGuard(deterministic): 6989 self.common( 6990 fn, 6991 [ 6992 torch.randn(5, 8, 13), 6993 2, 6994 torch.tensor([[[3, 5, 7, 9]]]), 6995 0.8, # src can be a scalar 6996 ], 6997 check_lowp=check_lowp, 6998 ) 6999 7000 @unittest.skip("Flaky test, needs debugging") 7001 def test_scatter_add1(self): 7002 def fn(a, dim, index, b): 7003 return aten.scatter_add(a, dim, index, b) 7004 7005 check_lowp = True 7006 if self.device == "xpu": 7007 check_lowp = False 7008 7009 self.common( 7010 fn, 7011 [ 7012 torch.randn(2, 3), 7013 0, 7014 torch.tensor([[0]]), 7015 torch.randn(2, 3), 7016 ], 7017 check_lowp=check_lowp, 7018 ) 7019 7020 def test_scatter_add2(self): 7021 def fn(a, dim, index, b): 7022 return aten.scatter_add(a, dim, index, b) 7023 7024 check_lowp = True 7025 if self.device == "xpu": 7026 check_lowp = False 7027 7028 self.common( 7029 fn, 7030 [ 7031 torch.randn(2, 3), 7032 0, 7033 torch.tensor([[0, 0, 0], [1, 1, 1]]), 7034 torch.randn(2, 3), 7035 ], 7036 check_lowp=check_lowp, 7037 ) 7038 7039 def test_scatter_add3(self): 7040 def fn(a, dim, index, b): 7041 return aten.scatter_add(a, dim, index, b) 7042 7043 check_lowp = True 7044 if self.device == "xpu": 7045 check_lowp = False 7046 7047 for deterministic in [False, True]: 7048 with DeterministicGuard(deterministic): 7049 self.common( 7050 fn, 7051 [ 7052 torch.randn(5, 29, 13), 7053 2, 7054 torch.tensor([[[3, 5, 7, 9]]]), 7055 torch.randn(1, 1, 10), 7056 ], 7057 check_lowp=check_lowp, 7058 ) 7059 7060 def test_scatter_reduce1(self): 7061 def fn(a, dim, index, b): 7062 return aten.scatter_reduce(a, dim, index, b, "sum") 7063 7064 check_lowp = True 7065 if self.device == "xpu": 7066 check_lowp = False 7067 7068 self.common( 7069 fn, 7070 [ 7071 torch.randn(5, 29, 13), 7072 2, 7073 torch.tensor([[[3, 5, 7, 9]]]), 7074 torch.randn(1, 1, 10), 7075 ], 7076 check_lowp=check_lowp, 7077 ) 7078 7079 def test_scatter_reduce2(self): 7080 def fn(a, dim, index, b, reduce): 7081 return aten.scatter_reduce(a, dim, index, b, reduce, include_self=False) 7082 7083 check_lowp = True 7084 if self.device == "xpu": 7085 check_lowp = False 7086 7087 for reduce in ["sum", "amax"]: 7088 self.common( 7089 fn, 7090 [ 7091 torch.randn(2, 3), 7092 0, 7093 torch.zeros((2, 3), dtype=torch.int64), 7094 torch.randn(2, 3), 7095 reduce, 7096 ], 7097 check_lowp=check_lowp, 7098 ) 7099 7100 def test_scatter_reduce3(self): 7101 def fn(a, dim, index, b, reduce): 7102 a = a.clone() 7103 a.scatter_reduce_(dim, index, b, reduce=reduce) 7104 a1 = a + 1.0 7105 a1.scatter_reduce_(dim, index, b, reduce=reduce) 7106 return (a, a1) 7107 7108 check_lowp = True 7109 if self.device == "xpu": 7110 check_lowp = False 7111 7112 for reduce in ["sum", "prod"]: 7113 self.common( 7114 fn, 7115 [ 7116 torch.ones((4, 5)), 7117 0, 7118 torch.tensor([[1], [2], [3]], dtype=torch.int64), 7119 torch.randn(4, 5), 7120 reduce, 7121 ], 7122 check_lowp=check_lowp, 7123 
) 7124 7125 def test_dense_mask_index(self): 7126 r""" 7127 There will be a little difference for reduce order between aten and inductor 7128 https://github.com/pytorch/pytorch/pull/122289 7129 Absolute difference: 0.00067138671875 (up to 1e-05 allowed) 7130 Relative difference: 3.1747371732500974e-06 (up to 1.3e-06 allowed) 7131 """ 7132 kwargs = {} 7133 if self.device == "cpu": 7134 kwargs["atol"] = 1e-4 7135 kwargs["rtol"] = 1.3e-5 7136 7137 def fn(x, y): 7138 y = torch.ops.aten.select.int(y, 0, 2) 7139 z = x * y 7140 return z.sum() 7141 7142 self.common(fn, [torch.randn(102400), torch.randn(3)], **kwargs) 7143 7144 def test_empty1(self): 7145 def fn(): 7146 return torch.empty((1, 128, 128)) 7147 7148 self.common(fn, [], assert_equal=False) 7149 7150 def test_empty2(self): 7151 def fn(): 7152 return aten.empty((1, 128, 128)) 7153 7154 self.common(fn, [], assert_equal=False) 7155 7156 def test_new_empty(self): 7157 def fn(a): 7158 return aten.new_empty(a, [1, 128, 128]) 7159 7160 self.common(fn, [torch.randn(55)], assert_equal=False) 7161 7162 def test_empty_strided(self): 7163 def fn(): 7164 return aten.empty_strided([1, 128, 128], [16384, 128, 1]) 7165 7166 self.common(fn, [], assert_equal=False) 7167 7168 def test_new_empty_strided(self): 7169 def fn(a): 7170 return aten.new_empty_strided(a, [1, 128, 128], [16384, 128, 1]) 7171 7172 self.common(fn, [torch.randn(55)], assert_equal=False) 7173 7174 def test_dropout_trivial_0(self): 7175 def fn1(a): 7176 return torch.nn.functional.dropout(a, 0.0, True) + a 7177 7178 self.common(fn1, [torch.randn(55)]) 7179 7180 def test_dropout_trivial_1(self): 7181 def fn2(a): 7182 return torch.nn.functional.dropout(a, 1.0, True) + a 7183 7184 self.common(fn2, [torch.randn(55)]) 7185 7186 @config.patch({"triton.cudagraphs": True}) 7187 @dynamo_config.patch(automatic_dynamic_shapes=True) 7188 def test_dropout(self): 7189 random.seed(1234) 7190 torch.manual_seed(1234) 7191 7192 @torch._dynamo.optimize("inductor") 7193 def fn1(a): 7194 return torch.nn.functional.dropout(a) 7195 7196 x = torch.ones(1000, device=self.device, dtype=torch.float32) 7197 result1 = fn1(x) 7198 self.assertTrue(400 < result1.nonzero().shape[0] < 600) 7199 self.assertTrue(0.9 < result1.mean().item() < 1.1) 7200 7201 random.seed(1234) 7202 torch.manual_seed(1234) 7203 7204 @torch._dynamo.optimize("inductor") 7205 def fn2(a): 7206 return torch.nn.functional.dropout(a, 0.5, True) 7207 7208 result2 = fn2(x) 7209 self.assertTrue(400 < result2.nonzero().shape[0] < 600) 7210 self.assertTrue(0.9 < result2.mean().item() < 1.1) 7211 7212 @dynamo_config.patch(automatic_dynamic_shapes=True) 7213 def test_dropout_deterministic(self): 7214 @torch._dynamo.optimize("inductor") 7215 def fn(a): 7216 return torch.nn.functional.dropout(a, 0.55, True) 7217 7218 for cg in [False, True]: 7219 with patch.object(config.triton, "cudagraphs", cg): 7220 torch._dynamo.reset() 7221 7222 x = torch.ones(1024, device=self.device, dtype=torch.float32) 7223 7224 torch.manual_seed(1234) 7225 a0 = fn(x).clone() 7226 a1 = fn(x).clone() 7227 a2 = fn(x).clone() 7228 7229 torch.manual_seed(1234) 7230 b0 = fn(x).clone() 7231 b1 = fn(x).clone() 7232 b2 = fn(x).clone() 7233 7234 # same seed, same values 7235 self.assertTrue(torch.allclose(a0, b0)) 7236 self.assertTrue(torch.allclose(a1, b1)) 7237 self.assertTrue(torch.allclose(a2, b2)) 7238 7239 # different calls, different values 7240 self.assertFalse(torch.allclose(a0, a1)) 7241 self.assertFalse(torch.allclose(a1, a2)) 7242 7243 def test_rand_like_deterministic(self): 7244 
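        # With the same manual seed, the compiled rand_like must replay identical
        # values across runs, while distinct calls (and the two outputs of a single
        # call) must still differ and stay within [0, 1).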
@torch._dynamo.optimize("inductor") 7245 def fn(a): 7246 return torch.rand_like(a), torch.rand_like(a) 7247 7248 x = torch.ones(1024, device=self.device, dtype=torch.float32) 7249 7250 torch.manual_seed(1234) 7251 a0 = fn(x)[0].clone() 7252 a1 = fn(x)[0].clone() 7253 a2 = fn(x)[0].clone() 7254 7255 torch.manual_seed(1234) 7256 b0 = fn(x)[0].clone() 7257 b1 = fn(x)[0].clone() 7258 b2 = fn(x)[0].clone() 7259 7260 # same seed, same values 7261 self.assertTrue(torch.allclose(a0, b0)) 7262 self.assertTrue(torch.allclose(a1, b1)) 7263 self.assertTrue(torch.allclose(a2, b2)) 7264 7265 # different calls, different values 7266 self.assertFalse(torch.allclose(a0, a1)) 7267 self.assertFalse(torch.allclose(a1, a2)) 7268 7269 c, d = fn(x) 7270 self.assertFalse(torch.allclose(c, d)) 7271 self.assertTrue((c >= 0).all()) 7272 self.assertTrue((c < 1).all()) 7273 self.assertTrue((d >= 0).all()) 7274 self.assertTrue((d < 1).all()) 7275 7276 @config.patch(implicit_fallbacks=True) 7277 def test_fallback_mutable_op_basic(self): 7278 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 7279 7280 def impl(a, b, c, d, e=2): 7281 a.add_(b[0] * c * e), 7282 if d is not None: 7283 d.add_(b[1]) 7284 7285 m.define( 7286 "inplace_(Tensor(a!) a, Tensor[] b, SymInt c, *, Tensor(b!)? d, SymInt e=2) -> ()" 7287 ) 7288 m.impl("inplace_", impl, "CompositeExplicitAutograd") 7289 7290 # We do some clones and copy_ to test that Inductor doesn't reorder 7291 # the copy_ w.r.t. inplace_. 7292 def f(a, b1, b2, c, d): 7293 a_ = a.clone() 7294 d_ = d if d is None else d.clone() 7295 torch.ops.mylib.inplace_(a_, (b1, b2), c, d=d_) 7296 a.copy_(a_) 7297 if d is not None: 7298 d.copy_(d_) 7299 return () 7300 7301 a = torch.tensor([0.0, 1.0, 2]) 7302 b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])] 7303 c = 4 7304 d = torch.tensor([2.0, 1, 0]) 7305 args = (a, b[0], b[1], c, d) 7306 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7307 mod = make_fx(f)(*cloned_args) 7308 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7309 compiled_f = compile_fx_inner(mod, cloned_args) 7310 7311 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7312 compiled_f(list(cloned_args)) 7313 f(*args) 7314 self.assertEqual(cloned_args, args) 7315 7316 @config.patch(implicit_fallbacks=True) 7317 def test_fallback_mutable_op_with_return(self): 7318 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 7319 7320 def impl(a, b, c, d, e=2): 7321 a.add_(b[0] * c * e), 7322 if d is not None: 7323 d.add_(b[1]) 7324 return b[0] + b[1] 7325 7326 m.define( 7327 "inplace_(Tensor(a!) a, Tensor[] b, SymInt c, *, Tensor(b!)? d, SymInt e=2) -> Tensor" 7328 ) 7329 m.impl("inplace_", impl, "CompositeExplicitAutograd") 7330 7331 # We do some clones and copy_ to test that Inductor doesn't reorder 7332 # the copy_ w.r.t. inplace_. 
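            # If the copy_ back into `a`/`d` were hoisted above the fallback inplace_,
            # the mutation would be lost and the assertEqual checks at the end of this
            # test would fail.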
7333 def f(a, b0, b1, c, d): 7334 a_ = a.clone() 7335 d_ = d if d is None else d.clone() 7336 res = torch.ops.mylib.inplace_(a_, (b0, b1), c, d=d_) 7337 a.copy_(a_) 7338 if d is not None: 7339 d.copy_(d_) 7340 return (res,) 7341 7342 a = torch.tensor([0.0, 1.0, 2]) 7343 b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])] 7344 c = 4 7345 d = torch.tensor([2.0, 1, 0]) 7346 args = (a, b[0], b[1], c, d) 7347 7348 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7349 mod = make_fx(f)(*cloned_args) 7350 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7351 compiled_f = compile_fx_inner(mod, cloned_args) 7352 7353 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7354 compiled_out = compiled_f(list(cloned_args)) 7355 out = f(*args) 7356 self.assertEqual(cloned_args, args) 7357 self.assertEqual(compiled_out, out) 7358 7359 @config.patch(implicit_fallbacks=True) 7360 def test_fallback_mutable_op_no_mutated_tensors(self): 7361 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 7362 7363 def impl(a, b): 7364 if b is not None: 7365 b.add_(1) 7366 7367 m.define("inplace_(Tensor a, Tensor(b!)? b) -> ()") 7368 m.impl("inplace_", impl, "CompositeExplicitAutograd") 7369 7370 def f(a): 7371 torch.ops.mylib.inplace_(a, None) 7372 return () 7373 7374 a = torch.tensor([0.0, 1.0, 2]) 7375 args = (a,) 7376 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7377 mod = make_fx(f)(*cloned_args) 7378 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7379 compiled_f = compile_fx_inner(mod, cloned_args) 7380 7381 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7382 compiled_f(list(cloned_args)) 7383 f(*args) 7384 self.assertEqual(cloned_args, args) 7385 7386 @config.patch(implicit_fallbacks=True) 7387 def test_fallback_mutable_op_list(self): 7388 with torch.library._scoped_library("mylib", "FRAGMENT") as m: 7389 7390 def impl(a, b): 7391 for bi in b: 7392 bi.add_(a) 7393 7394 m.define("inplace_(Tensor a, Tensor(a!)[] b) -> ()") 7395 m.impl("inplace_", impl, "CompositeExplicitAutograd") 7396 7397 def f(a, b): 7398 torch.ops.mylib.inplace_(a, b) 7399 return () 7400 7401 a = torch.tensor([0.0, 1.0, 2]) 7402 b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])] 7403 args = (a, b) 7404 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7405 mod = make_fx(f)(*cloned_args) 7406 cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 7407 7408 with self.assertRaisesRegex( 7409 torch._inductor.exc.LoweringException, 7410 "NYI: Can't generate FallbackKernel", 7411 ): 7412 compiled_f = compile_fx_inner(mod, cloned_args) 7413 7414 @expectedFailureXPU 7415 def test_functionalize_rng_wrappers(self): 7416 # Ideally, we would like to use torch.compile for these operators. But 7417 # currently the plan is to introduce these operators at the partitioner 7418 # level, obviating the need to support them fully through the 7419 # torch.compile stack. To ensure that we have good enough debugging with 7420 # minifiers, we have ensure that they work with make_fx. This test uses 7421 # make_fx to do the testing. In future, we can move on torch.compile. 
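        # run_and_save_rng_state captures the RNG state together with the random
        # output; run_with_rng_state replays the same op from that saved state, so
        # each a_i below must match its corresponding b_i exactly.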
7422 def fn(): 7423 rng_state1, a1 = torch._prims.rng_prims.run_and_save_rng_state( 7424 torch.ops.aten.rand.default, 7425 [4, 4], 7426 dtype=torch.float32, 7427 device=self.device, 7428 ) 7429 rng_state2, a2 = torch._prims.rng_prims.run_and_save_rng_state( 7430 torch.ops.aten.rand.default, 7431 [4, 4], 7432 dtype=torch.float32, 7433 device=self.device, 7434 ) 7435 7436 b1 = torch._prims.rng_prims.run_with_rng_state( 7437 rng_state1, 7438 torch.ops.aten.rand.default, 7439 [4, 4], 7440 dtype=torch.float32, 7441 device=self.device, 7442 ) 7443 b2 = torch._prims.rng_prims.run_with_rng_state( 7444 rng_state2, 7445 torch.ops.aten.rand.default, 7446 [4, 4], 7447 dtype=torch.float32, 7448 device=self.device, 7449 ) 7450 7451 return (a1, a2, b1, b2) 7452 7453 mod = make_fx(fn)() 7454 compiled_f = compile_fx_inner(mod, ()) 7455 a1, a2, b1, b2 = compiled_f(()) 7456 self.assertEqual(a1, b1) 7457 self.assertEqual(a2, b2) 7458 7459 @patch.object(torch._functorch.config, "functionalize_rng_ops", True) 7460 @expectedFailureXPU 7461 def test_philox_rand(self): 7462 if self.device == "cpu": 7463 raise unittest.SkipTest( 7464 f"functionalization of rng ops supported only on {GPU_TYPE}" 7465 ) 7466 7467 @torch._dynamo.optimize("inductor") 7468 def fn(x): 7469 a = torch.rand_like(x) * x 7470 a = torch.rand_like(x) * a 7471 return a 7472 7473 def check(x): 7474 torch.manual_seed(123) 7475 a = fn(x) 7476 7477 torch.manual_seed(1234) 7478 b = fn(x) 7479 7480 torch.manual_seed(123) 7481 c = fn(x) 7482 7483 # same seed, same values 7484 self.assertTrue(torch.allclose(a, c)) 7485 7486 # different calls, different values 7487 self.assertFalse(torch.allclose(a, b)) 7488 7489 check(torch.ones(1024, device=self.device, dtype=torch.float32)) 7490 # Need comment: should we add "_get_rng_state_offset" to common device interface? 7491 self.assertEqual(getattr(torch, self.device)._get_rng_state_offset(), 2048) 7492 # Check non-multiple of 4 numel 7493 check(torch.ones(3, device=self.device, dtype=torch.float32)) 7494 self.assertEqual(getattr(torch, self.device)._get_rng_state_offset(), 8) 7495 7496 # Already on by default, just want to make sure 7497 @patch.object(torch._inductor.config, "allow_buffer_reuse", True) 7498 def test_reuse_buffers_with_aliasing(self): 7499 def f(x): 7500 z = x + 1 7501 z = torch.view_as_complex(z) 7502 a = torch.view_as_real(z) 7503 out = a + 1 7504 return out, torch.view_as_real(z + 1) 7505 7506 self.common(f, (torch.zeros((4, 2)),)) 7507 7508 code = run_and_get_triton_code(torch.compile(f), torch.zeros((4, 2))) 7509 # Make sure that we haven't added complex support and made this test 7510 # invalid. 
If we've added complex support please update the test to use 7511 # a different set of view ops we don't lower 7512 self.assertTrue("aten.view_as_real" in code) 7513 7514 def f2(x): 7515 z = x + 1 7516 z = torch.view_as_complex(z) 7517 z = torch.view_as_real(z) 7518 z = torch.view_as_complex(z) 7519 a = torch.view_as_real(z) 7520 out = a + 1 7521 return out, torch.view_as_real(z + 1) 7522 7523 self.common(f, (torch.zeros((4, 2)),)) 7524 7525 def test_randn_like_empty(self): 7526 class Model(torch.nn.Module): 7527 def __init__( 7528 self, 7529 ): 7530 super().__init__() 7531 7532 def forward(self, v1: torch.Tensor): 7533 vx = v1.min(dim=1).values 7534 v2 = torch.randn_like(vx) 7535 return v2 7536 7537 model = Model() 7538 x = torch.rand(10, 3, 0) 7539 7540 self.common(model, (x,)) 7541 7542 def test_randint(self): 7543 @torch.compile(fullgraph=True) 7544 def fn(x): 7545 return ( 7546 torch.randint(10, [1024], device=x.device), 7547 torch.randint(-4, 7, [1024], dtype=torch.int32, device=x.device), 7548 torch.randint_like(x, 2**50), 7549 ) 7550 7551 torch.manual_seed(12345) 7552 a0, b0, c0 = fn(torch.zeros([40, 40], device=self.device)) 7553 self.assertEqual(a0.shape, [1024]) 7554 self.assertEqual(b0.shape, [1024]) 7555 self.assertEqual(c0.shape, [40, 40]) 7556 torch.manual_seed(12345) 7557 a1, b1, c1 = fn(torch.zeros([40, 40], device=self.device)) 7558 self.assertEqual(a0, a1) 7559 self.assertEqual(b0, b1) 7560 self.assertEqual(c0, c1) 7561 7562 self.assertEqual(a0.min(), 0) 7563 self.assertEqual(a0.max(), 9) 7564 7565 self.assertEqual(b0.min(), -4) 7566 self.assertEqual(b0.max(), 6) 7567 7568 self.assertGreaterEqual(c0.min(), 0) 7569 self.assertGreater(c0.max(), 2**40) 7570 self.assertLess(c0.max(), 2**50) 7571 7572 @config.patch(fallback_random=True) 7573 def test_like_rands(self): 7574 def fn(x): 7575 return torch.rand_like(x), torch.randn_like(x) 7576 7577 self.common(fn, [torch.zeros([20, 20])]) 7578 7579 def test_like_rands2(self): 7580 # rand_like with kwargs `device` of str type 7581 d = self.device 7582 assert isinstance(d, str) 7583 7584 @torch.compile 7585 def fn(x): 7586 return torch.rand_like(x, device=d) 7587 7588 x = torch.ones(10, device=self.device, dtype=torch.float32) 7589 a0 = fn(x).clone() 7590 a1 = fn(x).clone() 7591 self.assertFalse(torch.allclose(a0, a1)) 7592 7593 @requires_gpu() 7594 def test_like_rands3(self): 7595 # rand_like with `device` which is different from `x.device` 7596 def test_like_rands_on_different_device(device1, device2): 7597 @torch.compile 7598 def fn(x, device): 7599 return torch.rand_like(x, device=device) 7600 7601 x = torch.ones(10, device=device1, dtype=torch.float32) 7602 return fn(x, device2).clone() 7603 7604 a0 = test_like_rands_on_different_device("cpu", GPU_TYPE) 7605 a1 = test_like_rands_on_different_device(GPU_TYPE, "cpu") 7606 self.assertTrue(a0.device.type == GPU_TYPE) 7607 self.assertTrue(a1.device.type == "cpu") 7608 7609 def test_max_pool2d_with_indices_backward(self): 7610 def fn(a, b, c): 7611 return aten.max_pool2d_with_indices_backward( 7612 a, b, [2, 2], [2, 2], [0, 0], [1, 1], False, c 7613 ) 7614 7615 x = torch.randn([2, 4, 18, 14]) 7616 result, indices = aten.max_pool2d_with_indices( 7617 x, 7618 [2, 2], 7619 [2, 2], 7620 [0, 0], 7621 [1, 1], 7622 False, 7623 ) 7624 7625 self.common( 7626 fn, 7627 [ 7628 torch.randn_like(result), 7629 x, 7630 indices, 7631 ], 7632 ) 7633 7634 def test_max_pool2d_with_indices_backward2(self): 7635 def fn(a, b, c): 7636 return aten.max_pool2d_with_indices_backward( 7637 a, b, [3, 3], [2, 
2], [1, 1], [1, 1], True, c 7638 ) 7639 7640 x = torch.randn([2, 4, 40, 56]) 7641 result, indices = aten.max_pool2d_with_indices( 7642 x, 7643 [3, 3], 7644 [2, 2], 7645 [1, 1], 7646 [1, 1], 7647 True, 7648 ) 7649 7650 self.common( 7651 fn, 7652 [ 7653 torch.randn_like(result), 7654 x, 7655 indices, 7656 ], 7657 ) 7658 7659 # From https://github.com/pytorch/torchdynamo/issues/1200 7660 def test_max_pool2d_with_indices_backward3(self): 7661 def fn(a, b, c): 7662 return aten.max_pool2d_with_indices_backward( 7663 a, b, [1, 1], [2, 2], [0, 0], [1, 1], False, c 7664 ) 7665 7666 x = torch.randn([32, 256, 37, 38]) 7667 result, indices = aten.max_pool2d_with_indices( 7668 x, 7669 [1, 1], 7670 [2, 2], 7671 0, 7672 1, 7673 False, 7674 ) 7675 self.common( 7676 fn, 7677 [ 7678 torch.randn_like(result), 7679 x, 7680 indices, 7681 ], 7682 ) 7683 7684 # From https://github.com/pytorch/torchdynamo/issues/1352 7685 def test_max_pool2d_with_indices_backward4(self): 7686 def fn(a, b, c): 7687 return aten.max_pool2d_with_indices_backward( 7688 a, b, [5, 5], [1, 1], [2, 2], [1, 1], False, c 7689 ) 7690 7691 torch._inductor.metrics.generated_kernel_count = 0 7692 x = torch.randn([2, 64, 3, 4]) 7693 result, indices = aten.max_pool2d_with_indices( 7694 x, 7695 [5, 5], 7696 [1, 1], 7697 2, 7698 1, 7699 False, 7700 ) 7701 self.common( 7702 fn, 7703 [ 7704 torch.randn_like(result), 7705 x, 7706 indices, 7707 ], 7708 ) 7709 assertGeneratedKernelCountEqual(self, 1) 7710 7711 @expectedFailureXPU 7712 def test_max_pool2d_with_indices_backward5(self): 7713 # Window size is too big. Should fallback 7714 def fn(a, b, c): 7715 return aten.max_pool2d_with_indices_backward( 7716 a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c 7717 ) 7718 7719 torch._inductor.metrics.generated_kernel_count = 0 7720 x = torch.randn([2, 64, 20, 20]) 7721 result, indices = aten.max_pool2d_with_indices( 7722 x, 7723 [13, 13], 7724 [1, 1], 7725 2, 7726 1, 7727 False, 7728 ) 7729 self.common( 7730 fn, 7731 [ 7732 torch.randn_like(result), 7733 x, 7734 indices, 7735 ], 7736 ) 7737 assertGeneratedKernelCountEqual(self, 0) 7738 7739 # From https://github.com/pytorch/pytorch/issues/93384 7740 def test_max_pool2d_with_indices_backward6(self): 7741 # dilation is not 1. 
Should fallback 7742 def fn(a, b, c): 7743 return aten.max_pool2d_with_indices_backward( 7744 a, b, [3, 2], [2, 1], [1, 1], [1, 2], False, c 7745 ) 7746 7747 torch._inductor.metrics.generated_kernel_count = 0 7748 x = torch.randn([2, 2, 3, 6]) 7749 result, indices = aten.max_pool2d_with_indices( 7750 x, 7751 [3, 2], 7752 [2, 1], 7753 [1, 1], 7754 [1, 2], 7755 False, 7756 ) 7757 self.common( 7758 fn, 7759 [ 7760 torch.randn_like(result), 7761 x, 7762 indices, 7763 ], 7764 ) 7765 assertGeneratedKernelCountEqual(self, 0) 7766 7767 def test_issue102546(self): 7768 def fn(x): 7769 return x.mean(0) 7770 7771 self.common(fn, [torch.rand(())]) 7772 7773 def test_avg_pool2d_backward(self): 7774 def fn(a, b): 7775 return aten.avg_pool2d_backward( 7776 a, 7777 b, 7778 [2, 2], 7779 [2, 2], 7780 [0, 0], 7781 True, 7782 False, 7783 None, 7784 ) 7785 7786 self.common( 7787 fn, 7788 [ 7789 torch.randn([2, 4, 7, 7]), 7790 torch.randn([2, 4, 14, 14]), 7791 ], 7792 ) 7793 7794 def test_avg_pool2d_backward2(self): 7795 def fn(a, b): 7796 return aten.avg_pool2d_backward( 7797 a, 7798 b, 7799 [3, 3], 7800 [1, 1], 7801 [1, 1], 7802 True, 7803 False, 7804 None, 7805 ) 7806 7807 self.common( 7808 fn, 7809 [ 7810 torch.randn([1, 1, 20, 15]), 7811 torch.randn([1, 1, 20, 15]), 7812 ], 7813 ) 7814 7815 def test_avg_pool2d_backward3(self): 7816 def fn(a, b): 7817 return aten.avg_pool2d_backward( 7818 a, 7819 b, 7820 [1, 1], 7821 [2, 2], 7822 [0, 0], 7823 False, 7824 False, 7825 None, 7826 ) 7827 7828 torch._inductor.metrics.generated_kernel_count = 0 7829 self.common( 7830 fn, 7831 [ 7832 torch.randn([1, 2016, 11, 11]), 7833 torch.randn([1, 2016, 21, 21]), 7834 ], 7835 ) 7836 assertGeneratedKernelCountEqual(self, 1) 7837 7838 def test_avg_pool2d_backward4(self): 7839 def fn(a, b): 7840 return aten.avg_pool2d_backward( 7841 a, 7842 b, 7843 [13, 13], 7844 [1, 1], 7845 [0, 0], 7846 True, 7847 False, 7848 None, 7849 ) 7850 7851 torch._inductor.metrics.generated_kernel_count = 0 7852 self.common( 7853 fn, 7854 [ 7855 torch.randn([1, 16, 12, 12]), 7856 torch.randn([1, 16, 24, 24]), 7857 ], 7858 check_lowp=False, 7859 ) 7860 assertGeneratedKernelCountEqual(self, 0) 7861 7862 def test_avg_pool3d_backward(self): 7863 def fn(a, b): 7864 return aten.avg_pool3d_backward( 7865 a, 7866 b, 7867 [2, 2, 2], 7868 [2, 2, 2], 7869 [0, 0, 0], 7870 True, 7871 False, 7872 None, 7873 ) 7874 7875 self.common( 7876 fn, 7877 [ 7878 torch.randn([2, 4, 7, 7, 7]), 7879 torch.randn([2, 4, 14, 14, 14]), 7880 ], 7881 ) 7882 7883 def test_avg_pool3d_backward2(self): 7884 def fn(a, b): 7885 return aten.avg_pool3d_backward( 7886 a, 7887 b, 7888 [3, 3, 3], 7889 [1, 1, 1], 7890 [1, 1, 1], 7891 True, 7892 False, 7893 None, 7894 ) 7895 7896 self.common( 7897 fn, 7898 [ 7899 torch.randn([1, 1, 20, 20, 15]), 7900 torch.randn([1, 1, 20, 20, 15]), 7901 ], 7902 ) 7903 7904 def test_avg_pool3d_backward3(self): 7905 def fn(a, b): 7906 return aten.avg_pool3d_backward( 7907 a, 7908 b, 7909 [1, 1, 1], 7910 [2, 2, 2], 7911 [0, 0, 0], 7912 False, 7913 False, 7914 None, 7915 ) 7916 7917 torch._inductor.metrics.generated_kernel_count = 0 7918 self.common( 7919 fn, 7920 [ 7921 torch.randn([1, 2016, 11, 11, 11]), 7922 torch.randn([1, 2016, 21, 21, 21]), 7923 ], 7924 ) 7925 assertGeneratedKernelCountEqual(self, 1) 7926 7927 def test_avg_pool3d_backward4(self): 7928 def fn(a, b): 7929 return aten.avg_pool3d_backward( 7930 a, 7931 b, 7932 [13, 13, 13], 7933 [1, 1, 1], 7934 [0, 0, 0], 7935 True, 7936 False, 7937 None, 7938 ) 7939 7940 
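        # Like the 2-D case above, the 13x13x13 window is large enough that this
        # lowering is expected to fall back to ATen, so no Triton kernel should be
        # generated (asserted below).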
torch._inductor.metrics.generated_kernel_count = 0 7941 self.common( 7942 fn, 7943 [ 7944 torch.randn([1, 16, 12, 12, 12]), 7945 torch.randn([1, 16, 24, 24, 24]), 7946 ], 7947 check_lowp=False, 7948 ) 7949 assertGeneratedKernelCountEqual(self, 0) 7950 7951 @config.patch(search_autotune_cache=False) 7952 def test_mm_views(self): 7953 def fn(a, b): 7954 return torch.mm(a.view(32, 32), b.view(32, 32)) 7955 7956 self.common( 7957 fn, 7958 ( 7959 torch.randn([32, 32]).transpose(0, 1), 7960 torch.randn([1, 32, 32]).transpose(0, 1), 7961 ), 7962 check_lowp=False, 7963 ) 7964 expected_kernel = 0 7965 # codegen mm kernel from template 7966 self.assertEqual( 7967 torch._inductor.metrics.generated_kernel_count, expected_kernel 7968 ) 7969 7970 @torch._dynamo.config.patch(assume_static_by_default=False) 7971 def test_dtype_sympy_expr(self): 7972 @torch._dynamo.optimize_assert("inductor") 7973 def fn(a): 7974 y = a[..., :-1, :].contiguous() 7975 return y 7976 7977 result = fn(torch.randn([1, 2, 16, 4]).requires_grad_()) 7978 result.sum().backward() 7979 7980 def test_dropout2(self): 7981 n = 100000 7982 weight = torch.ones( 7983 n, device=self.device, dtype=torch.float32, requires_grad=True 7984 ) 7985 ones = torch.ones(n, device=self.device, dtype=torch.float32) 7986 7987 @torch._dynamo.optimize_assert("inductor") 7988 def run(x, train=True): 7989 return F.dropout(x * weight, 0.33, train) 7990 7991 def check(r, g): 7992 rmean = r.mean().item() 7993 gmean = g.mean().item() 7994 rcount = len(r.nonzero()) 7995 gcount = len(g.nonzero()) 7996 7997 # dropped elements should match 7998 self.assertTrue(same(r.nonzero(), g.nonzero())) 7999 self.assertEqual(rcount, gcount) 8000 8001 # dropped should be close to 0.33 8002 self.assertGreater(rcount, 0.64 * n) 8003 self.assertGreater(0.68 * n, rcount) 8004 8005 self.assertAlmostEqual(rmean, gmean) 8006 self.assertAlmostEqual(rmean, 1.0, places=2) 8007 8008 r1 = run(ones, train=False) 8009 r1.sum().backward() 8010 g1 = weight.grad.clone() 8011 # eval mode should be all ones 8012 self.assertTrue(same(r1, torch.ones_like(r1))) 8013 self.assertTrue(same(g1, torch.ones_like(g1))) 8014 8015 torch.manual_seed(1234) 8016 weight.grad.zero_() 8017 r2, (fw_code, bw_code) = run_fw_bw_and_get_code(lambda: run(ones)) 8018 if self.device == GPU_TYPE: 8019 self.assertEqual(fw_code.count("tl.rand"), 1) 8020 self.assertEqual(bw_code.count("tl.rand"), 0) 8021 g2 = weight.grad.clone() 8022 check(r2, g2) 8023 8024 torch.manual_seed(1234) 8025 weight.grad.zero_() 8026 r3 = run(ones) 8027 r3.sum().backward() 8028 g3 = weight.grad.clone() 8029 check(r3, g3) 8030 8031 # second run is same result as first 8032 self.assertTrue(same(r2, r3)) 8033 self.assertTrue(same(g2, g3)) 8034 8035 @config.patch(search_autotune_cache=False) 8036 def test_dropout3(self): 8037 m = torch.nn.Sequential( 8038 torch.nn.Linear(32, 32, bias=False), 8039 torch.nn.Dropout(), 8040 torch.nn.Linear(32, 32, bias=False), 8041 torch.nn.Dropout(), 8042 ).to(self.device) 8043 8044 @torch._dynamo.optimize_assert("inductor") 8045 def run(x): 8046 return m(x) 8047 8048 torch._inductor.metrics.generated_kernel_count = 0 8049 8050 result, (fw_code, bw_code) = run_fw_bw_and_get_code( 8051 lambda: run(torch.randn([8, 32], device=self.device)) 8052 ) 8053 8054 if self.device == GPU_TYPE: 8055 self.assertEqual(fw_code.count("tl.rand"), 2) 8056 self.assertEqual(bw_code.count("tl.rand"), 0) 8057 expected_kernel = 4 8058 8059 self.assertEqual( 8060 torch._inductor.metrics.generated_kernel_count, expected_kernel 8061 ) 8062 8063 def 
test_randint_kernel_count(self): 8064 @torch._dynamo.optimize_assert("inductor") 8065 def fn1(): 8066 random_tensor1 = torch.randint(10, [32], device=self.device) 8067 random_tensor2 = torch.randint(10, [32], device=self.device) 8068 random_tensor3 = torch.randint(10, [32], device=self.device) 8069 return random_tensor1, random_tensor2, random_tensor3 8070 8071 _, source_codes = run_and_get_code(fn1) 8072 if self.device == GPU_TYPE: 8073 self.assertEqual(len(source_codes), 1) 8074 self.assertEqual(source_codes[0].count("async_compile.triton"), 2) 8075 8076 def test_roll(self): 8077 def fn(a): 8078 return ( 8079 aten.roll(a, [-3, 10], [1, 2]), 8080 aten.roll(a, [5]), 8081 ) 8082 8083 self.common( 8084 fn, 8085 [ 8086 torch.randn([2, 56, 56, 16]), 8087 ], 8088 ) 8089 8090 def test_argmax_min_int32(self): 8091 # https://github.com/pytorch/pytorch/issues/94055 8092 def fn(a, b): 8093 c = a.argmax(3) 8094 return torch.min(b, c) 8095 8096 a = torch.rand(3, 4, 2, 1).int() 8097 b = torch.rand(2, 2, 1, 4, 1).int() 8098 self.common(fn, (a, b)) 8099 8100 def test_argmax_argmin1(self): 8101 def fn(x): 8102 return (aten.argmax(x), aten.argmin(x)) 8103 8104 self.common( 8105 fn, 8106 [ 8107 torch.randn([8, 256, 256]), 8108 ], 8109 ) 8110 8111 def test_argmax_argmin2(self): 8112 def fn(x): 8113 return ( 8114 aten.argmax(x, 0), 8115 aten.argmin(x, 0), 8116 aten.argmax(x, 1), 8117 aten.argmin(x, 1), 8118 ) 8119 8120 self.common(fn, (torch.randn([144, 144]),)) 8121 8122 def test_argmax_argmin_with_duplicates(self): 8123 def fn(x): 8124 return ( 8125 aten.argmax(x, 0), 8126 aten.argmin(x, 0), 8127 aten.argmax(x, 1), 8128 aten.argmin(x, 1), 8129 ) 8130 8131 # Unrolled reduction 8132 t1 = torch.randint(2, size=(6, 6)) 8133 self.common(fn, (t1,)) 8134 8135 # Persistent reduction 8136 t1 = torch.randint(8, size=(32, 32)) 8137 self.common(fn, (t1,)) 8138 8139 # Non-persistent reduction 8140 t1 = torch.randint(8, size=(1028, 1028)) 8141 self.common(fn, (t1,)) 8142 8143 def test_argmax_argmin_with_nan(self): 8144 def fn(x): 8145 return ( 8146 aten.argmax(x, 0), 8147 aten.argmin(x, 0), 8148 aten.argmax(x, 1), 8149 aten.argmin(x, 1), 8150 ) 8151 8152 if self.device == "cpu": 8153 raise unittest.SkipTest("broken on CPU") 8154 8155 # Unrolled reduction 8156 t1 = torch.randn((6, 6)) 8157 t1[:, 1] = float("nan") 8158 t1[:, 3] = float("nan") 8159 self.common(fn, (t1,)) 8160 8161 # Persistent reduction 8162 t1 = torch.randn((32, 32)) 8163 t1[:, 4] = float("nan") 8164 t1[:, 8] = float("nan") 8165 self.common(fn, (t1,)) 8166 8167 # Non-persistent reduction 8168 t1 = torch.randn((1028, 1028)) 8169 t1[:, 40] = float("nan") 8170 t1[:, 100] = float("nan") 8171 self.common(fn, (t1,)) 8172 8173 def test_conv_backward(self): 8174 def fn(rank4_inps, rank3_inps, rank5_inps): 8175 out1 = aten.convolution_backward( 8176 *rank4_inps, 8177 [C], 8178 [1, 1], 8179 [0, 0], 8180 [1, 1], 8181 False, 8182 [0, 0], 8183 1, 8184 [True, True, True], 8185 ) 8186 out2 = aten.convolution_backward( 8187 *rank4_inps, 8188 [C], 8189 [1, 1], 8190 [0, 0], 8191 [1, 1], 8192 False, 8193 [0, 0], 8194 1, 8195 [True, False, False], 8196 ) 8197 out3 = aten.convolution_backward( 8198 *rank3_inps, 8199 [C], 8200 [1], 8201 [0], 8202 [1], 8203 False, 8204 [0], 8205 1, 8206 [True, True, True], 8207 ) 8208 out4 = aten.convolution_backward( 8209 *rank5_inps, 8210 [C], 8211 [1, 1, 1], 8212 [0, 0, 0], 8213 [1, 1, 1], 8214 False, 8215 [0, 0, 0], 8216 1, 8217 [True, True, True], 8218 ) 8219 return (out1, out2, out3, out4) 8220 8221 B = 3 8222 C = 4 8223 H = 5 8224 
grad_out = torch.randn(B, C, H - 2, H - 2, H - 2) 8225 inp = torch.randn(B, C, H, H, H) 8226 weight = torch.randn(C, C, 3, 3, 3) 8227 8228 def shrink_rank(x, rank): 8229 res = x 8230 while res.dim() > rank: 8231 res = torch.select(res, -1, 0) 8232 return res.contiguous() 8233 8234 rank4_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]] 8235 rank3_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]] 8236 rank5_inps = [shrink_rank(x, 5) for x in [grad_out, inp, weight]] 8237 8238 with torch.backends.cudnn.flags(enabled=True, allow_tf32=False): 8239 self.common( 8240 fn, 8241 [rank4_inps, rank3_inps, rank5_inps], 8242 ) 8243 8244 @unittest.skip( 8245 """ 8246 FIXME: In the case of having equally max/min elements, our implementation returns 8247 the last index instead of the first one 8248 """ 8249 ) 8250 def test_argmax_argmin3(self): 8251 def fn(x): 8252 return ( 8253 aten.argmax(x, 0), 8254 aten.argmin(x, 0), 8255 aten.argmax(x, -1), 8256 aten.argmin(x, -1), 8257 ) 8258 8259 self.common( 8260 fn, 8261 [torch.randint(0, 5, [10, 10])], 8262 ) 8263 8264 def test_vdd_clamp(self): 8265 def fn(x): 8266 return torch.clamp_min(x, 3) 8267 8268 self.common( 8269 fn, 8270 [ 8271 torch.randn([16], requires_grad=True) * 10, 8272 ], 8273 ) 8274 8275 def test_tmp_not_defined_issue1(self): 8276 def forward( 8277 primals_3, 8278 primals_4, 8279 add_tensor, 8280 convert_element_type_default, 8281 div_default, 8282 reciprocal_default, 8283 ): 8284 var_default = torch.ops.aten.var( 8285 convert_element_type_default, [2], correction=0 8286 ) 8287 sub_tensor = torch.ops.aten.sub.Tensor(add_tensor, div_default) 8288 mul_tensor_1 = torch.ops.aten.mul.Tensor(sub_tensor, reciprocal_default) 8289 mul_tensor_2 = torch.ops.aten.mul.Tensor(mul_tensor_1, primals_3) 8290 add_tensor_2 = torch.ops.aten.add.Tensor(mul_tensor_2, primals_4) 8291 convert_element_type_default_1 = add_tensor_2.to(dtype=torch.float32) 8292 convert_element_type_default_2 = convert_element_type_default_1.to( 8293 dtype=torch.float32 8294 ) 8295 var_default_1 = torch.ops.aten.var( 8296 convert_element_type_default_2, [2], correction=0 8297 ) 8298 broadcast_in_dim_default_2 = var_default_1.reshape(1, 512, 1) 8299 sum_default_1 = convert_element_type_default_2.sum(2) 8300 add_tensor_3 = torch.ops.aten.add.Tensor(broadcast_in_dim_default_2, 1e-05) 8301 return (var_default, sum_default_1, add_tensor_3) 8302 8303 inps = [ 8304 (torch.Size([1024]), torch.float32), 8305 (torch.Size([1024]), torch.float32), 8306 (torch.Size([1, 512, 1024]), torch.float32), 8307 (torch.Size([1, 512, 1024]), torch.float32), 8308 (torch.Size([1, 512, 1]), torch.float32), 8309 (torch.Size([1, 512, 1]), torch.float32), 8310 ] 8311 inps = [torch.randn(shape, dtype=dtype) for (shape, dtype) in inps] 8312 self.common(forward, inps, atol=1e-05, rtol=2e-05) 8313 8314 @unittest.skipIf( 8315 os.environ.get("BUILD_ENVIRONMENT", "").startswith("parallelnative"), 8316 "TODO: debug this with asan", 8317 ) 8318 def test_tmp_not_defined_issue2(self): 8319 def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4): 8320 div_tensor_7 = torch.ops.aten.div.Tensor(getitem_17, arg81_1) 8321 mul_tensor_24 = torch.ops.aten.mul.Tensor(div_tensor_7, arg38_1) 8322 sum_default_7 = torch.ops.aten.sum.default(mul_tensor_24) 8323 return (new_zeros_default_4, sum_default_7) 8324 8325 dtype = torch.float32 8326 args = [ 8327 ((1, 88, 40, 40), (140800, 1600, 40, 1), dtype), 8328 ((), (), dtype), 8329 ((1, 88, 40, 40), (140800, 1600, 40, 1), dtype), 8330 ((3,), (1,), dtype), 8331 ] 8332 
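        # Materialize the (shape, stride, dtype) specs above into strided,
        # gradient-tracking inputs (shifted by +1).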
args = [ 8333 rand_strided(shape, stride, dtype).requires_grad_(True).add(1) 8334 for shape, stride, dtype in args 8335 ] 8336 self.common(forward, args) 8337 8338 @requires_gpu() 8339 def test_tmp_not_defined_issue3(self): 8340 from torch import device 8341 8342 def forward( 8343 self, 8344 primals_1: "f32[1001, 6]", 8345 primals_2: "f32[1001]", 8346 primals_3: "f32[1001, 64]", 8347 primals_4: "f32[4190]", 8348 primals_5: "f32[4190]", 8349 primals_6: "f32[1739, 4190]", 8350 primals_48: "f32[6144, 4191]", 8351 ): 8352 _tensor_constant0: "i64[4190]" = self._tensor_constant0 8353 lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default( 8354 _tensor_constant0 8355 ) 8356 8357 index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( 8358 primals_48, [None, lift_fresh_copy] 8359 ) 8360 8361 _tensor_constant1: "i64[6]" = self._tensor_constant1 8362 lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default( 8363 _tensor_constant1 8364 ) 8365 index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor( 8366 primals_48, [None, lift_fresh_copy_1] 8367 ) 8368 primals_48 = lift_fresh_copy_1 = None 8369 permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0]) 8370 addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default( 8371 primals_2, index_1, permute 8372 ) 8373 amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True) 8374 sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax) 8375 exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub) 8376 sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True) 8377 div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1) 8378 8379 full_default: "i32[6144, 1001]" = torch.ops.aten.full.default( 8380 [6144, 1001], 8381 1, 8382 dtype=torch.int32, 8383 layout=torch.strided, 8384 device=device(type=GPU_TYPE, index=0), 8385 pin_memory=False, 8386 ) 8387 8388 iota: "i32[1001]" = torch.ops.prims.iota.default( 8389 1001, 8390 start=0, 8391 step=1, 8392 dtype=torch.int32, 8393 device=device(type=GPU_TYPE), 8394 requires_grad=False, 8395 ) 8396 8397 mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota) 8398 iota_1: "i32[6144]" = torch.ops.prims.iota.default( 8399 6144, 8400 start=0, 8401 step=1001, 8402 dtype=torch.int32, 8403 device=device(type=GPU_TYPE, index=0), 8404 requires_grad=False, 8405 ) 8406 view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1]) 8407 view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1]) 8408 _embedding_bag = torch.ops.aten._embedding_bag.default( 8409 primals_3, view, iota_1, False, 0, False, view_1 8410 ) 8411 getitem: "f32[6144, 64]" = _embedding_bag[0] 8412 getitem_1: "i32[6150144]" = _embedding_bag[1] 8413 getitem_2: "i32[6144]" = _embedding_bag[2] 8414 getitem_3: "i32[0]" = _embedding_bag[3] 8415 unsqueeze: "f32[6144, 1, 64]" = torch.ops.aten.unsqueeze.default(getitem, 1) 8416 var_mean = torch.ops.aten.var_mean.correction( 8417 index, [1], correction=0, keepdim=True 8418 ) 8419 getitem_4: "f32[6144, 1]" = var_mean[0] 8420 getitem_5: "f32[6144, 1]" = var_mean[1] 8421 add: "f32[6144, 1]" = torch.ops.aten.add.Tensor(getitem_4, 1e-05) 8422 rsqrt: "f32[6144, 1]" = torch.ops.aten.rsqrt.default(add) 8423 sub_1: "f32[6144, 4190]" = torch.ops.aten.sub.Tensor(index, getitem_5) 8424 mul_1: "f32[6144, 4190]" = torch.ops.aten.mul.Tensor(sub_1, rsqrt) 8425 mul_2: "f32[6144, 4190]" = torch.ops.aten.mul.Tensor(mul_1, primals_4) 8426 add_1: "f32[6144, 4190]" = torch.ops.aten.add.Tensor(mul_2, primals_5) 8427 permute_1: "f32[4190, 1739]" = 
torch.ops.aten.permute.default( 8428 primals_6, [1, 0] 8429 ) 8430 8431 return [ 8432 index, 8433 index_1, 8434 addmm, 8435 amax, 8436 sum_1, 8437 iota_1, 8438 view, 8439 view_1, 8440 getitem_1, 8441 getitem_2, 8442 getitem_3, 8443 unsqueeze, 8444 getitem_5, 8445 rsqrt, 8446 add_1, 8447 permute_1, 8448 ] 8449 8450 kwargs = aot_graph_input_parser(forward, device=GPU_TYPE) 8451 self.common(forward, [], kwargs=kwargs) 8452 8453 def test_misaligned_address_issue1(self): 8454 def forward(sub_tensor_1, unsqueeze_default): 8455 gather_default = torch.ops.aten.gather.default( 8456 sub_tensor_1, 1, unsqueeze_default 8457 ) 8458 return gather_default 8459 8460 args = [ 8461 ((1, 1000), (1000, 1), torch.float32), 8462 ((1, 1), (1, 1), torch.int64), 8463 ] 8464 args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args] 8465 self.common(forward, args) 8466 8467 def test_invalid_operand_issue1(self): 8468 def forward(arg0_1, arg1_1, arg3_1, squeeze, view_1, slice_1): 8469 slice_scatter = torch.ops.aten.slice_scatter.default( 8470 slice_1, arg3_1, 1, 1, 9223372036854775807 8471 ) 8472 slice_scatter_1 = torch.ops.aten.slice_scatter.default( 8473 arg1_1, slice_scatter, 0, 0, 9223372036854775807 8474 ) 8475 slice_2 = torch.ops.aten.slice.Tensor( 8476 slice_scatter_1, 0, 0, 9223372036854775807 8477 ) 8478 select_scatter = torch.ops.aten.select_scatter.default( 8479 slice_2, squeeze, 1, 0 8480 ) 8481 slice_scatter_2 = torch.ops.aten.slice_scatter.default( 8482 slice_scatter_1, select_scatter, 0, 0, 9223372036854775807 8483 ) 8484 view = torch.ops.aten.view.default(slice_scatter_2, [-1, 128]) 8485 embedding = torch.ops.aten.embedding.default(arg0_1, view, 1) 8486 return [embedding, view_1] 8487 8488 args = [ 8489 ((50005, 768), (768, 1), torch.float32), 8490 ((8, 128), (128, 1), torch.int64), 8491 ((8, 127), (127, 1), torch.int64), 8492 ((8,), (1,), torch.int64), 8493 ((1024,), (1,), torch.int64), 8494 ((8, 128), (128, 1), torch.int64), 8495 ] 8496 args = [rand_strided(shape, stride, dtype) for shape, stride, dtype in args] 8497 self.common(forward, args) 8498 8499 def test_sizehint_issue1(self): 8500 def forward(x): 8501 return torch.nn.functional.unfold( 8502 x, kernel_size=[4, 4], dilation=1, padding=0, stride=[4, 4] 8503 ) 8504 8505 args = [((2, 24, 56, 56), (75264, 3136, 56, 1), torch.float32, False)] 8506 args = [ 8507 rand_strided(sh, st, dt).requires_grad_(rg) for (sh, st, dt, rg) in args 8508 ] 8509 self.common(forward, args) 8510 8511 def test_zero_dim_reductions(self): 8512 for kd in [True, False]: 8513 inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd) 8514 failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min] 8515 for fo in failed_ops: 8516 with self.assertRaisesRegex( 8517 IndexError, "Expected reduction dim 1 to have non-zero size" 8518 ): 8519 mod = make_fx(fo)(*inps0) 8520 _ = compile_fx_inner(mod, inps0) 8521 8522 pass_ops = [ 8523 lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all] 8524 ] 8525 for po in pass_ops: 8526 compiled = torch._dynamo.optimize("inductor")(po) 8527 expected = po(*inps0) 8528 actual = compiled(*inps0) 8529 8530 self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)) 8531 8532 def test_unfold_zero_dimension_tensor(self): 8533 def forward(x): 8534 return torch.unfold_copy(dimension=1, input=x, size=0, step=7) 8535 8536 x = torch.rand([1, 0], dtype=torch.float32) 8537 8538 y = forward(x) 8539 compiled_y = torch.compile(forward, fullgraph=True)(x) 8540 8541 self.assertEqual(y, compiled_y) 
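    # Hedged sketch (illustrative only, not part of the upstream suite): the
    # zero-size tests above all follow the same recipe -- run the function eagerly,
    # run it under torch.compile, and compare the results. A minimal standalone form
    # of that recipe, using the hypothetical helper name
    # _assert_compile_matches_eager, would look like this:
    def _assert_compile_matches_eager(self, fn, *args):
        # eager reference
        expected = fn(*args)
        # compiled run; fullgraph=True mirrors the stricter tests above
        actual = torch.compile(fn, fullgraph=True)(*args)
        torch.testing.assert_close(actual, expected)

    # e.g. self._assert_compile_matches_eager(forward, x) would mirror the
    # assertions made in test_unfold_zero_dimension_tensor above.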
8542 8543 def test_zero_element_mutation(self): 8544 class CustomModel(nn.Module): 8545 def __init__(self): 8546 super().__init__() 8547 self.layer1 = nn.LeakyReLU(negative_slope=5.2955089, inplace=True) 8548 8549 def forward(self, inputs): 8550 return self.layer1(inputs) 8551 8552 ip_size = [0] 8553 input_tensor = torch.randn(ip_size) 8554 8555 mymodel = CustomModel() 8556 self.common(mymodel, (input_tensor,)) 8557 8558 def test_lerp(self): 8559 # non-contiguous inputs for lerp 8560 def fn0(i0, i1): 8561 x1 = i0.transpose(-2, -3) 8562 return torch.lerp(i1, x1, 70000) 8563 8564 # contiguous inputs for lerp 8565 def fn1(i0, i1): 8566 return torch.lerp(i1, i0, 70000) 8567 8568 self.common(fn0, [torch.rand(10, 3, 10), torch.rand(3, 10, 10)]) 8569 self.common(fn1, [torch.rand(3, 10, 10), torch.rand(3, 10, 10)]) 8570 8571 def test_unspec_inputs(self): 8572 if self.device == "cpu": 8573 raise unittest.SkipTest("Testing mixed devices") 8574 8575 def fn(x, y): 8576 return x + y, x * y, x / y 8577 8578 opt = torch._dynamo.optimize("inductor")(fn) 8579 dtypes = [ 8580 torch.float16, 8581 torch.bfloat16, 8582 torch.float32, 8583 torch.float64, 8584 torch.int32, 8585 torch.int64, 8586 ] 8587 8588 for d in dtypes: 8589 inputs = ( 8590 rand_strided((2, 3), (3, 1), dtype=torch.float32, device=GPU_TYPE), 8591 rand_strided((), (), dtype=d, device="cpu"), 8592 ) 8593 self.assertTrue(same(opt(*inputs), fn(*inputs))) 8594 inputs = (inputs[1], inputs[0]) 8595 self.assertTrue(same(opt(*inputs), fn(*inputs))) 8596 8597 @dynamo_config.patch(automatic_dynamic_shapes=True) 8598 def test_list_clearing(self): 8599 if self.device == "cpu": 8600 contexts = [contextlib.nullcontext] 8601 else: 8602 contexts = [ 8603 contextlib.nullcontext, 8604 lambda: config.patch({"triton.cudagraphs": True}), 8605 ] 8606 8607 for context in contexts: 8608 with context(): 8609 inps = [ 8610 torch.rand([5, 5]).to(self.device), 8611 torch.rand([5, 5]).to(self.device), 8612 ] 8613 inp_refs = [weakref.ref(inp) for inp in inps] 8614 8615 def fn(x, y): 8616 a = x + y 8617 return (a @ a,) 8618 8619 fn_fx = make_fx(fn)(inps[0], inps[1]) 8620 fn_compiled = compile_fx_inner(fn_fx, inps) 8621 8622 test_self = self 8623 matmul_seen = False 8624 8625 class TestRefMode(TorchDispatchMode): 8626 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 8627 kwargs = kwargs if kwargs else {} 8628 8629 nonlocal inps 8630 nonlocal inp_refs 8631 nonlocal test_self 8632 nonlocal matmul_seen 8633 8634 # by matmul, inputs should be deallocated 8635 # TODO: should not be necessary, ref-cycle ? 
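                    # Collect on every dispatch so stale cycles cannot keep the inputs
                    # alive; by the time the final mm.out runs, the compiled wrapper
                    # should have cleared the input list and dropped the last
                    # references, which is what the checks below verify.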
8636 gc.collect() 8637 if func is aten.mm.out: 8638 matmul_seen = True 8639 test_self.assertEqual(len(inps), 0) 8640 test_self.assertIsNone(inp_refs[0]()) 8641 test_self.assertIsNone(inp_refs[1]()) 8642 8643 return func(*args, **kwargs) 8644 8645 with TestRefMode(): 8646 fn_compiled(inps) 8647 8648 # do an extra run to make sure we are deallocating on warmup and record 8649 if self.device == GPU_TYPE: 8650 inps.extend( 8651 [ 8652 torch.rand([5, 5]).to(self.device), 8653 torch.rand([5, 5]).to(self.device), 8654 ] 8655 ) 8656 inp_refs.extend([weakref.ref(inp) for inp in inps]) 8657 matmul_seen = False 8658 8659 with TestRefMode(): 8660 fn_compiled(inps) 8661 8662 # for some reason, TorchDispatch doesnt capture the 8663 # cuda mm call (even without cudagraphs) 8664 if self.device == "cpu": 8665 self.assertTrue(matmul_seen) 8666 else: 8667 self.assertEqual(len(inps), 0) 8668 8669 def test_dtype_mismatch_issue(self): 8670 def fn(x): 8671 attn = torch.nn.functional.pad(x, [0, 1]) 8672 return attn.softmax(dim=-1) 8673 8674 x = torch.rand(128, 32, 63) 8675 self.common(fn, (x,)) 8676 8677 def test_diagonal_copy(self): 8678 def fn(x): 8679 return torch.diagonal_copy(x) 8680 8681 for x in (torch.randn(2, 3), torch.randn(2, 2), torch.randn(3, 2)): 8682 self.common(fn, (x,)) 8683 8684 def test_kwargs(self): 8685 if self.device == GPU_TYPE: 8686 raise unittest.SkipTest("histogramdd only supports cpu") 8687 8688 def fn(x, y): 8689 return torch.histogramdd( 8690 x, 8691 bins=[3, 3], 8692 weight=y, 8693 ) 8694 8695 self.common( 8696 fn, 8697 [torch.randn((4, 2)), torch.randn(4)], 8698 ) 8699 8700 # Shape padding causes the inputs to all get specialized, so the codegen 8701 # test fails 8702 @expectedFailureCodegenDynamic 8703 @requires_gpu() 8704 @torch._inductor.config.patch("shape_padding", True) 8705 def test_shape_padding(self): 8706 dtypes = [ 8707 torch.float16, 8708 torch.float32, 8709 ] 8710 8711 b, m, n, k = 7, 11, 13, 15 8712 8713 def gen(*shape, dtype=torch.float32): 8714 return torch.randn(*shape, device=GPU_TYPE, dtype=dtype) / k + 1.0 8715 8716 for dtype in dtypes: 8717 x = gen(m, k, dtype=dtype) 8718 y = gen(k, n, dtype=dtype) 8719 z = gen(n, dtype=dtype) 8720 self.common(lambda x, y: torch.mm(x, y), (x, y)) 8721 self.common(lambda x, y: torch.matmul(x, y), (x, y)) 8722 self.common(lambda x, y, z: torch.addmm(z, x, y), (x, y, z)) 8723 8724 for dtype in dtypes: 8725 x = gen(b, m, k, dtype=dtype) 8726 y = gen(b, k, n, dtype=dtype) 8727 z = gen(n, dtype=dtype) 8728 self.common(lambda x, y: torch.bmm(x, y), (x, y)) 8729 self.common(lambda x, y: torch.matmul(x, y), (x, y)) 8730 self.common(lambda x, y, z: torch.baddbmm(z, x, y), (x, y, z)) 8731 8732 @requires_gpu() 8733 @torch._inductor.config.patch("layout_optimization", True) 8734 def test_inductor_layout_optimization_input_mutations(self): 8735 # channel dim must be > 64 for inductor to do layout optimization and use NHWC 8736 mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).to(GPU_TYPE) 8737 8738 def f(x): 8739 x.mul_(2) 8740 out = mod(x) 8741 return out 8742 8743 f_compiled = torch.compile(f) 8744 x_ref = torch.rand(2, 3, 128, 128, device=GPU_TYPE) 8745 x_test = x_ref.clone().detach() 8746 with torch.no_grad(): 8747 out_ref = f(x_ref) 8748 out_test = f_compiled(x_test) 8749 self.assertEqual(out_ref, out_test) 8750 self.assertEqual(out_ref.shape, out_test.shape) 8751 # Importantly, since inductor._config.keep_output_stride is True, 8752 # the outputs should have matching strides here. 
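        # The compiled graph must also apply the in-place x.mul_(2) to its input,
        # even though layout optimization may run the convolution in channels_last;
        # that is what the x_ref == x_test check below verifies.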
8753 self.assertEqual(out_ref.stride(), out_test.stride()) 8754 self.assertEqual(x_ref, x_test) 8755 8756 def test_int_input_dynamic_shapes(self): 8757 @torch.compile(dynamic=True) 8758 def fn(x, i): 8759 y = x * i 8760 return y 8761 8762 # Constant must not get matched as constant 8763 self.common(fn, [torch.randn(3, 1, 1, 1, 1), 9132]) 8764 8765 def test_sqrt_dynamic_shapes(self): 8766 # TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877. 8767 # TODO: support cuda path. 8768 if self.device == GPU_TYPE: 8769 raise unittest.SkipTest("sqrt dynamic shapes only supports cpu") 8770 8771 class Model(torch.nn.Module): 8772 def __init__(self): 8773 super().__init__() 8774 8775 def forward(self, x): 8776 B, N, C = x.shape 8777 return self.get_rel_indices(N) 8778 8779 def get_rel_indices(self, num_patches: int) -> torch.Tensor: 8780 img_size = int(num_patches**0.5) 8781 ind = torch.arange(img_size) 8782 return ind 8783 8784 self.common( 8785 Model(), 8786 [ 8787 torch.randn(8, 4, 4), 8788 ], 8789 ) 8790 8791 def test_rsqrt_dynamic_shapes(self): 8792 # From HF hf_BigBird model. 8793 @torch.compile(dynamic=True) 8794 def fn(a, b): 8795 r = 1 / math.sqrt(a.size(1)) 8796 return torch.bmm(a, b) / r 8797 8798 self.common( 8799 fn, 8800 [ 8801 torch.randn(2, 4, 4), 8802 torch.randn(2, 4, 4), 8803 ], 8804 ) 8805 8806 def test_index_dynamic_shapes(self): 8807 # Repro from vision_maskrcnn 8808 def fn(arg0_1): 8809 unsqueeze = arg0_1.unsqueeze(0) 8810 sym_size = arg0_1.size(1) 8811 ceil = math.ceil(sym_size * 1.8735363483428955) 8812 iota = torch.ops.prims.iota.default( 8813 ceil, 8814 start=0, 8815 step=1, 8816 dtype=torch.int64, 8817 device=arg0_1.device, 8818 requires_grad=False, 8819 ) 8820 convert_element_type_1 = iota.to(torch.float32) 8821 sym_size_1 = arg0_1.size(2) 8822 floor_1 = math.floor(sym_size_1 * 1.8735363483428955) 8823 ceil_1 = math.ceil(floor_1) 8824 iota_1 = torch.ops.prims.iota.default( 8825 ceil_1, 8826 start=0, 8827 step=1, 8828 dtype=torch.int64, 8829 device=arg0_1.device, 8830 requires_grad=False, 8831 ) 8832 convert_element_type_3 = iota_1.to(torch.float32) 8833 sub_2 = (convert_element_type_1 + 0.5) * (sym_size / ceil) - 0.5 8834 clamp_min = sub_2.clamp_min(0.0) 8835 sub_3 = (convert_element_type_3 + 0.5) * (sym_size_1 / floor_1) - 0.5 8836 clamp_min_1 = sub_3.clamp_min(0.0) 8837 convert_element_type_4 = clamp_min.to(torch.int64) 8838 sub_4 = sym_size - 1 8839 clamp_max = clamp_min.ceil().clamp_max(sub_4) 8840 convert_element_type_5 = clamp_max.to(torch.int64) 8841 convert_element_type_6 = clamp_min_1.to(torch.int64) 8842 unsqueeze_2 = convert_element_type_4.unsqueeze(1) 8843 index = torch.ops.aten.index.Tensor( 8844 unsqueeze, [None, None, unsqueeze_2, convert_element_type_6] 8845 ) 8846 index_1 = torch.ops.aten.index.Tensor( 8847 unsqueeze, 8848 [ 8849 None, 8850 None, 8851 convert_element_type_5.unsqueeze(1), 8852 convert_element_type_6, 8853 ], 8854 ) 8855 sub_6 = clamp_min.unsqueeze(1) - unsqueeze_2 8856 mul_10 = (index * (1.0 - sub_6) + index_1 * (sub_6)) * ( 8857 1.0 - (clamp_min_1 - convert_element_type_6) 8858 ) 8859 select = torch.ops.aten.select.int(mul_10, 0, 0) 8860 return (select,) 8861 8862 x = torch.randn(15, 20, 3) 8863 self.common( 8864 fn, 8865 [x], 8866 ) 8867 8868 def test_setitem_with_int_parameter(self): 8869 x = torch.zeros(7, device=self.device) 8870 8871 def fn(n, a): 8872 a[n] = -1 8873 return a 8874 8875 cnts = CompileCounterWithBackend("inductor") 8876 opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 8877 8878 for n in 
range(2, x.shape[0]):
            opt_fn(n, x)
            self.assertEqual(x[n], -1)

        # If assume_static_by_default is set, the calls above will trigger
        # 2 function compilations:
        #   1. assuming 'n' is static (equals 2)
        #   2. making 'n' dynamic, but with the guard 'end <= x.shape[0]'
        #      (from: torch._inductor.ir.SliceView.create)
        frame_count = 2 if torch._dynamo.config.assume_static_by_default else 1
        self.assertEqual(cnts.frame_count, frame_count)

        # Negative index triggers new compilation.
        opt_fn(-x.shape[0], x)
        self.assertEqual(x[0], -1)
        self.assertEqual(cnts.frame_count, frame_count + 1)

    @config.patch(profiler_mark_wrapper_call=True)
    def test_profiler_mark_wrapper_call(self):
        from torch.profiler import profile

        @torch._dynamo.optimize("inductor", nopython=True)
        def fn(a, b):
            return a + b

        a = torch.rand((100,))
        b = torch.rand((100,))
        with profile() as prof:
            fn(a, b)
        assert any(
            "inductor_wrapper_call" in e.name for e in prof.profiler.function_events
        )

    def test_insignificant_strides(self):
        def f(x):
            tmp = x + 1
            return tmp.view(-1, 1, 2)

        x = torch.arange(8, device=self.device, dtype=torch.float32)
        out = f(x)
        compiled_out = torch.compile(f)(x)

        self.assertEqual(out.stride(), compiled_out.stride())
        self.assertEqual(out, compiled_out)

    @unittest.skipIf(IS_X86 and not HAS_AVX2, "Requires AVX2")
    def test_pixel_shuffle_channels_last(self):
        def fn(x):
            x = torch.nn.functional.pixel_shuffle(x, 2)
            x = torch.nn.functional.relu(x)
            return x

        self.common(
            fn,
            (torch.randn(1, 16, 64, 72).to(memory_format=torch.channels_last),),
        )

    def test_where_broadcast(self):
        # https://github.com/pytorch/pytorch/issues/93374
        def fn(x, p1, p0):
            o = torch.where(x, p1, p0)
            return o

        # https://github.com/pytorch/pytorch/issues/94725
        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "_tensor_constant0", torch.randn([], dtype=torch.float32)
                )

            def forward(self, arg0_1, arg1_1):
                convert_element_type = torch.ops.prims.convert_element_type.default(
                    arg1_1, torch.bool
                )
                bitwise_not = torch.ops.aten.bitwise_not.default(convert_element_type)
                _tensor_constant0 = self._tensor_constant0
                lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(
                    _tensor_constant0
                )
                where = torch.ops.aten.where.self(bitwise_not, lift_fresh_copy, arg0_1)
                return (where, bitwise_not)

        self.common(
            fn,
            (torch.tensor([[True]]), torch.rand(13, 7, 3), torch.rand(1, 1)),
        )

        args = [
            torch.randn(1, 4, 64, 64),
            torch.zeros(1, 1, 64, 64, dtype=torch.uint8),
        ]
        args[1][:, :, :32, :32] = 1
        eager_args = [x.clone() for x in args]
        eager_mod = Repro()
        mod = make_fx(eager_mod, tracing_mode="real")(*args)
        compiled = compile_fx_inner(mod, args)
        inductor_out = compiled(args)
        eager_out = eager_mod(*eager_args)
        self.assertEqual(inductor_out, eager_out)

    @skipIfRocm
    def test_require_stride_expanded(self):
        def forward(arg6, arg7, arg16):
            convolution = torch.ops.aten.convolution(
                arg16.unsqueeze(0), arg7, arg6, [4, 4], [2, 2], [1, 1], False, [0, 0], 1
            )
            return (convolution,)

        self.common(
            forward,
            (
                None,
                rand_strided(
(64, 3, 11, 11), 8993 (363, 121, 11, 1), 8994 torch.float32, 8995 device=self.device, 8996 ).to(memory_format=torch.channels_last), 8997 rand_strided( 8998 (1, 3, 224, 224), 8999 (150528, 50176, 224, 1), 9000 torch.float32, 9001 device=self.device, 9002 ) 9003 .to(memory_format=torch.channels_last) 9004 .squeeze(0), 9005 ), 9006 atol=1e-3, 9007 rtol=0.001, 9008 ) 9009 9010 # expanded dim should not cause copy in require_stride_order 9011 assertGeneratedKernelCountEqual(self, 0) 9012 9013 @requires_gpu() 9014 @unittest.skipIf( 9015 not PLATFORM_SUPPORTS_FLASH_ATTENTION, 9016 "Does not support SDPA or pre-SM80 hardware", 9017 ) 9018 @skipIfRocm 9019 def test_sdpa(self): 9020 def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): 9021 view = torch.ops.aten.view.default(arg3_1, [23760, 128]) 9022 arg3_1 = None 9023 mm = torch.ops.aten.mm.default(view, arg4_1) 9024 view = arg4_1 = None 9025 view_1 = torch.ops.aten.view.default(mm, [3, 99, 80, 8]) 9026 mm = None 9027 view_2 = torch.ops.aten.view.default(view_1, [3, 99, 80, 8]) 9028 view_1 = None 9029 permute = torch.ops.aten.permute.default(view_2, [0, 3, 1, 2]) 9030 view_2 = None 9031 view_3 = torch.ops.aten.view.default(permute, [3, 8, 99, 80]) 9032 permute = None 9033 9034 clone = torch.ops.aten.clone.default( 9035 view_3, memory_format=torch.contiguous_format 9036 ) 9037 view_3 = None 9038 9039 expand = torch.ops.aten.expand.default(clone, [3, 8, 99, 80]) 9040 clone = None 9041 _scaled_dot_product_efficient_attention = ( 9042 torch.ops.aten._scaled_dot_product_efficient_attention.default( 9043 arg0_1, arg1_1, arg2_1, expand, False 9044 ) 9045 ) 9046 arg0_1 = arg1_1 = arg2_1 = expand = None 9047 getitem = _scaled_dot_product_efficient_attention[0] 9048 _scaled_dot_product_efficient_attention = None 9049 return (getitem,) 9050 9051 DEVICE = torch.device(f"{GPU_TYPE}:0") 9052 DTYPE = torch.float16 9053 B = 3 9054 H = 8 9055 Q = 99 9056 K = 80 9057 D = 32 9058 C_bias = 128 9059 9060 # inputs 9061 query = torch.randn((B, H, Q, D), device=DEVICE, dtype=DTYPE) 9062 key = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) 9063 value = torch.randn((B, H, K, D), device=DEVICE, dtype=DTYPE) 9064 bias = torch.randn((B, Q, K, C_bias), device=DEVICE, dtype=DTYPE) 9065 weights = torch.randn((C_bias, H), device=DEVICE, dtype=DTYPE) 9066 9067 self.common( 9068 foo, 9069 (query, key, value, bias, weights), 9070 atol=0.02, 9071 rtol=1e4, 9072 ) 9073 9074 @requires_gpu() 9075 @unittest.skipIf( 9076 not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, 9077 "Does not support mem_eff_attention", 9078 ) 9079 @skipIfRocm 9080 def test_sdpa_unaligned_mask(self): 9081 def foo( 9082 arg0_1: "f32[8, 8, 16, 16]", 9083 arg1_1: "f32[8, 8, 15, 16]", 9084 arg2_1: "f32[8, 8, 15, 16]", 9085 arg3_1: "f32[1, 1, 16, 15]", 9086 ): 9087 constant_pad_nd: "f32[1, 1, 16, 16]" = ( 9088 torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0) 9089 ) 9090 arg3_1 = None 9091 slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor( 9092 constant_pad_nd, -1, 0, 15 9093 ) 9094 constant_pad_nd = None 9095 expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default( 9096 slice_1, [8, 8, 16, 15] 9097 ) 9098 slice_1 = None 9099 _scaled_dot_product_efficient_attention = ( 9100 torch.ops.aten._scaled_dot_product_efficient_attention.default( 9101 arg0_1, arg1_1, arg2_1, expand, False 9102 ) 9103 ) 9104 arg0_1 = arg1_1 = arg2_1 = expand = None 9105 getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0] 9106 _scaled_dot_product_efficient_attention = None 9107 return (getitem,) 9108 9109 
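        # foo pads the (1, 1, 16, 15) bias to a last dim of 16 and then slices it back
        # to 15 before expanding it for SDPA; the inputs below exercise that
        # unaligned-last-dim case (presumably checking that inductor preserves the
        # pad-then-slice pattern that keeps the bias storage aligned).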
query = torch.rand(8, 8, 16, 16, device=GPU_TYPE) 9110 key = torch.rand(8, 8, 15, 16, device=GPU_TYPE) 9111 value = torch.rand(8, 8, 15, 16, device=GPU_TYPE) 9112 bias = torch.rand(1, 1, 16, 15, device=GPU_TYPE) 9113 self.common( 9114 foo, 9115 (query, key, value, bias), 9116 atol=0.02, 9117 rtol=1e4, 9118 ) 9119 9120 @requires_gpu() 9121 @unittest.skipIf( 9122 not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, 9123 "Does not support mem_eff_attention", 9124 ) 9125 @skipIfRocm 9126 @config.patch(freezing=True) 9127 def test_sdpa_unaligned_mask_freezing(self): 9128 class Mod(torch.nn.Module): 9129 def __init__(self): 9130 super().__init__() 9131 self.arg3_1 = torch.rand(1, 1, 16, 15, device=GPU_TYPE) 9132 9133 def forward( 9134 self, 9135 arg0_1: "f32[8, 8, 16, 16]", 9136 arg1_1: "f32[8, 8, 15, 16]", 9137 arg2_1: "f32[8, 8, 15, 16]", 9138 ): 9139 arg3_1 = self.arg3_1 9140 constant_pad_nd: "f32[1, 1, 16, 16]" = ( 9141 torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0) 9142 ) 9143 arg3_1 = None 9144 slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor( 9145 constant_pad_nd, -1, 0, 15 9146 ) 9147 constant_pad_nd = None 9148 expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default( 9149 slice_1, [8, 8, 16, 15] 9150 ) 9151 slice_1 = None 9152 _scaled_dot_product_efficient_attention = ( 9153 torch.ops.aten._scaled_dot_product_efficient_attention.default( 9154 arg0_1, arg1_1, arg2_1, expand, False 9155 ) 9156 ) 9157 arg0_1 = arg1_1 = arg2_1 = expand = None 9158 getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[ 9159 0 9160 ] 9161 _scaled_dot_product_efficient_attention = None 9162 return (getitem,) 9163 9164 query = torch.rand(8, 8, 16, 16, device=GPU_TYPE) 9165 key = torch.rand(8, 8, 15, 16, device=GPU_TYPE) 9166 value = torch.rand(8, 8, 15, 16, device=GPU_TYPE) 9167 9168 mod = Mod() 9169 out_eager = mod(query, key, value) 9170 9171 with torch.no_grad(): 9172 out_compiled = torch.compile(mod)(query, key, value) 9173 self.assertEqual(out_eager, out_compiled, atol=0.02, rtol=1e4) 9174 9175 def test_where_with_logical_op(self): 9176 def fn_and(x, y): 9177 return torch.where(torch.logical_and(x, y), 1.0, 0.0) 9178 9179 def fn_or(x, y): 9180 return torch.where(torch.logical_or(x, y), 1.0, 0.0) 9181 9182 self.common( 9183 fn_and, 9184 (torch.randn(32), torch.randn(32)), 9185 ) 9186 self.common( 9187 fn_or, 9188 (torch.randn(32), torch.randn(32)), 9189 ) 9190 9191 @skipIfRocm 9192 def test_conv_with_as_strided(self): 9193 class Model(nn.Module): 9194 def __init__(self): 9195 super().__init__() 9196 self.kv = torch.nn.Conv2d( 9197 256, 384, kernel_size=(1, 1), stride=(1, 1), bias=False 9198 ) 9199 9200 def forward(self, x): 9201 convolution = self.kv(x) 9202 constant_pad_nd = torch.ops.aten.constant_pad_nd.default( 9203 convolution, [2, 2, 2, 2], 0.0 9204 ) 9205 # as_strided inputs are depend on input's size and stide. 
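            # The 1x1 conv maps (8, 256, 16, 16) -> (8, 384, 16, 16); after the
            # [2, 2, 2, 2] padding the buffer is (8, 384, 20, 20) with contiguous
            # strides (153600, 400, 20, 1), which is the storage the hard-coded
            # sizes/strides below re-slice.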
9206 as_strided = torch.ops.aten.as_strided.default( 9207 constant_pad_nd, [8, 384, 2, 20, 12], [153600, 400, 160, 1, 20] 9208 ) 9209 as_strided_1 = torch.ops.aten.as_strided.default( 9210 as_strided, [8, 384, 2, 2, 12, 12], [153600, 400, 160, 8, 20, 1] 9211 ) 9212 clone = torch.ops.aten.clone.default( 9213 as_strided_1, memory_format=torch.contiguous_format 9214 ) 9215 return clone 9216 9217 self.common( 9218 Model(), 9219 (torch.randn(8, 256, 16, 16),), 9220 ) 9221 9222 def test_inplace_where_pointwise(self): 9223 # https://github.com/pytorch/pytorch/issues/96446 9224 def fn(a, b): 9225 a[0] = 2 9226 return a * b 9227 9228 self.common(fn, (torch.rand(1), torch.rand(2))) 9229 9230 def test_view_on_aliased(self): 9231 # https://github.com/pytorch/pytorch/issues/96728 9232 def fn1(a, b): 9233 a = a.max(0).values 9234 c = torch.cat((a, b)) 9235 c = c.round() 9236 b >= a[0] # noqa: B015 9237 return c 9238 9239 some_const = torch.tensor(6324) 9240 9241 def fn2(): 9242 a = torch.tensor([[0.6324]]) 9243 ret = torch.cat((a, a), dim=0) 9244 some_const >= a[0] # noqa: B015 9245 return ret 9246 9247 self.common(fn1, (torch.tensor([[4.0]]), torch.tensor([5.0]))) 9248 self.common(fn2, ()) 9249 9250 def test_argmax_to_float(self): 9251 # https://github.com/pytorch/pytorch/issues/97127 9252 def fn(): 9253 a = torch.zeros([2, 2]) 9254 b = a.argmax(0) 9255 return b.float().mean() 9256 9257 self.common(fn, ()) 9258 9259 def test_const_int32_to_float(self): 9260 # https://github.com/pytorch/pytorch/issues/97124 9261 def fn(): 9262 a = torch.zeros([1, 2], dtype=torch.int32) 9263 a = a + a 9264 b = a.to(dtype=torch.float32) 9265 return b * 0.8 9266 9267 self.common(fn, ()) 9268 9269 def test_getitem(self): 9270 out_features = ["p3", "p4", "p5", "p6", "p7"] 9271 in_feature = "p5" 9272 9273 def fn(a): 9274 return a[out_features.index(in_feature)] 9275 9276 x = [ 9277 torch.rand([1, 256, 100, 152], device=self.device), 9278 torch.rand([1, 256, 50, 76], device=self.device), 9279 torch.rand([1, 256, 25, 38], device=self.device), 9280 ] 9281 opt_fn = torch._dynamo.optimize("inductor")(fn) 9282 same(fn(x), opt_fn(x)) 9283 9284 def test_pad_view(self): 9285 def fn(a): 9286 y = torch.nn.functional.pad(a, (0, 0, 0, 1)) 9287 y = y.view(*y.size()[:-2], y.size(-1), y.size(-2)) 9288 return y 9289 9290 x = torch.rand(48, 3, 512, 512) 9291 self.common(fn, (x,)) 9292 9293 def test_pad_cast(self): 9294 def fn(x): 9295 return torch.nn.functional.pad(x.to(torch.float32), (0, 3, 0, 0)) 9296 9297 for dtype in [torch.int32, torch.int64]: 9298 self.common(fn, (torch.ones(1, 1, 13, dtype=dtype),)) 9299 9300 @unittest.skipIf(not HAS_CPU, "requires C++ compiler") 9301 def test_data_type_propogation(self): 9302 from torch._dynamo.utils import detect_fake_mode 9303 from torch._inductor.codegen.common import boolean_ops 9304 from torch._inductor.compile_fx import _shape_env_from_inputs 9305 from torch._inductor.debug import DebugContext 9306 from torch._inductor.decomposition import decompositions 9307 from torch._inductor.graph import GraphLowering 9308 from torch._inductor.virtualized import V 9309 from torch.fx.passes.fake_tensor_prop import FakeTensorProp 9310 9311 def get_data_type(node: torch.fx.Node): 9312 if OptimizationContext.key in node.meta: 9313 return node.meta[OptimizationContext.key].dtype 9314 else: 9315 return None 9316 9317 def func(arg0_1): 9318 max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default( 9319 arg0_1, [3, 3], [2, 2], [1, 1] 9320 ) 9321 arg0_1 = None 9322 getitem = 
max_pool2d_with_indices[0] 9323 max_pool2d_with_indices = None 9324 return (getitem,) 9325 9326 example_inputs = [ 9327 torch.randn(10, 32, 20, 20, dtype=torch.bfloat16).to( 9328 memory_format=torch.channels_last 9329 ) 9330 ] 9331 9332 gm = make_fx(func, decomposition_table=decompositions, tracing_mode="fake")( 9333 *example_inputs 9334 ) 9335 9336 shape_env = _shape_env_from_inputs(example_inputs) 9337 9338 fake_mode = detect_fake_mode(example_inputs) 9339 if not fake_mode: 9340 fake_mode = torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True) 9341 FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs) 9342 else: 9343 FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs( 9344 *example_inputs 9345 ) 9346 with V.set_fake_mode(fake_mode): 9347 graph = GraphLowering( 9348 gm, 9349 shape_env=shape_env, 9350 ) 9351 with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): 9352 graph.run(*example_inputs) 9353 graph.compile_to_module() 9354 scheduler_node = graph.scheduler.nodes[0] 9355 DataTypePropagation.propagate_scheduler_node(scheduler_node) 9356 root_graph = scheduler_node._body.root_block.graph 9357 for node in root_graph.nodes: 9358 if node.op == "placeholder": 9359 self.assertEqual(get_data_type(node), None) 9360 elif node.target in boolean_ops(): 9361 self.assertEqual(get_data_type(node), torch.bool) 9362 elif node.target in ( 9363 "constant", 9364 "to_dtype", 9365 "index_expr", 9366 ): 9367 self.assertEqual(get_data_type(node), node.args[-1]) 9368 elif node.target in ( 9369 "get_index", 9370 "index_expr", 9371 ): 9372 self.assertEqual(get_data_type(node), torch.int64) 9373 elif node.target in ( 9374 "load", 9375 "store", 9376 ): 9377 self.assertEqual( 9378 get_data_type(node), V.graph.get_dtype(node.args[1]) 9379 ) 9380 elif node.target == "reduction": 9381 _, _, dtype, _, _, _, _ = node.args 9382 self.assertEqual(get_data_type(node), dtype) 9383 elif node.target.startswith("masked_subblock"): 9384 """ 9385 masked_subblocks: 9386 opcode name target args kwargs 9387 ----------- --------- --------- -------------------------- -------- 9388 placeholder ops ops () {} 9389 call_module get_index get_index ('index2',) {} 9390 call_method load load (ops, 'arg0_1', get_index) {} 9391 call_method to_dtype to_dtype (ops, load, torch.float32) {} 9392 output output output (to_dtype,) {} 9393 """ 9394 self.assertEqual(get_data_type(node), torch.float) 9395 elif node.target == "and_": 9396 """ 9397 and_'s input is boolean_ops: 9398 ----------- --------- --------- -------------------------- -------- 9399 call_method and__22 and_ (ops, ge_15, lt_15) 9400 ----------- --------- --------- -------------------------- -------- 9401 """ 9402 self.assertEqual(get_data_type(node), torch.bool) 9403 elif node.target == "maximum": 9404 """ 9405 maximum's input is maximum or masked_subblock: 9406 ----------- --------- --------- -------------------------- -------- 9407 call_method maximum_6 maximum (ops, masked_subblock8, maximum_5) 9408 ----------- --------- --------- -------------------------- -------- 9409 """ 9410 self.assertEqual(get_data_type(node), torch.float) 9411 elif node.target == "output": 9412 self.assertEqual(get_data_type(node), torch.bfloat16) 9413 9414 # Calling div only torch.SymInt arguments is not yet supported. 9415 # To support this behavior, we need to allow const-propping tensors that store symint data. 9416 # For now, dynamo will explicitly graph break when it encounters user code with this behavior. 
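    # (In the repro below, the pattern in question appears to be
    # `torch.div(seq_len, window_overlap, rounding_mode="trunc")`, where `seq_len` is a
    # SymInt under dynamic shapes and `window_overlap` is a plain Python int.)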
9417 @expectedFailureCodegenDynamic 9418 def test_AllenaiLongformerBase_repro(self): 9419 def fn(query, scores, window_overlap): 9420 batch_size, seq_len, num_heads, _ = query.size() 9421 chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 9422 diagonal_attention_scores = scores.new_zeros( 9423 ( 9424 batch_size * num_heads, 9425 chunks_count + 1, 9426 window_overlap, 9427 window_overlap * 2 + 1, 9428 ) 9429 ) 9430 diagonal_attention_scores[:, :-1, :, window_overlap:] = scores[ 9431 :, :, :window_overlap, : window_overlap + 1 9432 ] 9433 input_tensor = diagonal_attention_scores.view( 9434 batch_size, num_heads, seq_len, 2 * window_overlap + 1 9435 ).transpose(2, 1) 9436 beginning_input = input_tensor[:, :window_overlap, :, : window_overlap + 1] 9437 input_tensor[:, :window_overlap, :, : window_overlap + 1] = torch.full_like( 9438 beginning_input, -float("inf") 9439 ) 9440 return input_tensor 9441 9442 args = [ 9443 ((4, 1024, 12, 64), (768, 3072, 64, 1)), 9444 ((48, 3, 512, 513), (787968, 262656, 513, 1)), 9445 ] 9446 args = [rand_strided(sh, st) for (sh, st) in args] 9447 args.append(256) 9448 9449 if self.device == "cpu": 9450 opt_fn = torch._dynamo.optimize("inductor")(fn) 9451 _, code = run_and_get_cpp_code(opt_fn, *args) 9452 print(code) 9453 FileCheck().check_count( 9454 "static_cast<int32_t>(256)", 9455 1, 9456 exactly=True, 9457 ).run(code) 9458 9459 self.common(fn, args) 9460 9461 def test_cumsum_pattern_matcher_issue(self): 9462 def fn(input_ids) -> torch.Tensor: 9463 input_shape = input_ids.size() 9464 input_ids = input_ids.view(-1, input_shape[-1]) 9465 batch_size, seq_length = input_shape 9466 past_key_values_length = 0 9467 mask_seq_length = past_key_values_length + seq_length 9468 attention_mask = torch.ones( 9469 batch_size, mask_seq_length, device=input_ids.device 9470 ) 9471 attention_mask = attention_mask.long() 9472 return torch.cumsum(attention_mask, dim=1) 9473 9474 x = torch.randn(2, 2) 9475 self.common(fn, (x,), atol=0, rtol=0) 9476 9477 @staticmethod 9478 def _check_resize_common( 9479 self, fn, x, size_or_y, memory_format, inplace, deterministic 9480 ): 9481 x_ref_arg = x.clone() 9482 x_opt_arg = x.clone() 9483 x_numel = x.numel() 9484 torch._dynamo.reset_code_caches() 9485 opt_fn = torch._dynamo.optimize_assert(compile_fx)(fn) 9486 correct = fn(x_ref_arg, size_or_y, memory_format) 9487 actual = opt_fn(x_opt_arg, size_or_y, memory_format) 9488 9489 def get_numel(size_or_y): 9490 if isinstance(size_or_y, torch.Tensor): 9491 return size_or_y.numel() 9492 else: 9493 # assume shape 9494 return functools.reduce(lambda x, y: x * y, size_or_y, 1) 9495 9496 if deterministic: 9497 nele_check = correct.numel() 9498 else: 9499 nele_check = min(x_numel, get_numel(size_or_y)) 9500 9501 correct_values = correct.as_strided((nele_check,), (1,)) 9502 actual_values = actual.as_strided((nele_check,), (1,)) 9503 self.assertTrue(same(correct_values, actual_values, equal_nan=deterministic)) 9504 correct_strides = correct.stride() 9505 actual_strides = actual.stride() 9506 self.assertEqual(correct_strides, actual_strides) 9507 9508 @staticmethod 9509 def _cases_resize_common(): 9510 sizes = [ 9511 ((2,), (1, 3, 2, 3)), 9512 ((100,), (1, 3, 2, 3)), 9513 ((1, 3, 2, 3), (1, 3, 2, 3)), 9514 ((2,), (1, 3, 2, 3, 1)), 9515 ((100,), (1, 3, 2, 3, 1)), 9516 ((1, 3, 2, 3, 1), (1, 3, 2, 3, 1)), 9517 ((2, 0, 1), (2, 2)), 9518 ] 9519 for x_size, y_size in sizes: 9520 memory_formats = [torch.contiguous_format] 9521 if len(y_size) == 4: 9522 
memory_formats.append(torch.channels_last) 9523 if len(y_size) == 5: 9524 memory_formats.append(torch.channels_last_3d) 9525 for memory_format in memory_formats: 9526 x = torch.randn(*x_size) 9527 yield x, y_size, memory_format 9528 # check some non-contiguous tensors 9529 if x.numel() == 100: 9530 x_strided = x[::2].reshape(25, 2).transpose(0, 1) 9531 yield x_strided, y_size, memory_format 9532 9533 def test_resize(self): 9534 def fn(x, size, memory_format): 9535 # NOTE: Tensor.resize() =/= aten::resize() 9536 return torch.ops.aten.resize(x, size, memory_format=memory_format) 9537 9538 for deterministic in [True, False]: 9539 with DeterministicGuard( 9540 deterministic, fill_uninitialized_memory=deterministic 9541 ): 9542 for x, y_size, memory_format in CommonTemplate._cases_resize_common(): 9543 CommonTemplate._check_resize_common( 9544 self, 9545 fn, 9546 x, 9547 y_size, 9548 memory_format, 9549 inplace=False, 9550 deterministic=deterministic, 9551 ) 9552 9553 @staticmethod 9554 def _cases_resize_as_common(): 9555 for x, y_size, memory_format in CommonTemplate._cases_resize_common(): 9556 # each sizes /memory_format combintation tested in 2 ways: 9557 # 1. y is contiguous fn gets memory_format kwargs 9558 # 2. y has memory_format contiguity and fn gets preserve kwarg 9559 # 3. y has some other strides (not contiguous or channels last) and fn gets preserve 9560 yield x, torch.randn(*y_size), memory_format 9561 yield x, torch.randn(*y_size).contiguous( 9562 memory_format=memory_format 9563 ), torch.preserve_format 9564 yield x, torch.randn(*y_size).permute( 9565 tuple(reversed(range(len(y_size)))) 9566 ), torch.preserve_format 9567 9568 def test_resize_as(self): 9569 def fn(x, y, memory_format): 9570 return torch.ops.aten.resize_as(x, y, memory_format=memory_format) 9571 9572 for deterministic in [True, False]: 9573 with DeterministicGuard( 9574 deterministic, fill_uninitialized_memory=deterministic 9575 ): 9576 for x, y, memory_format in CommonTemplate._cases_resize_as_common(): 9577 CommonTemplate._check_resize_common( 9578 self, 9579 fn, 9580 x, 9581 y, 9582 memory_format, 9583 inplace=False, 9584 deterministic=deterministic, 9585 ) 9586 9587 def test_inplace_resize_as(self): 9588 def fn(x, y): 9589 x.resize_as_(y) 9590 return x 9591 9592 x = torch.randn(2, 3) 9593 y = torch.randn(200, 300) 9594 x_clone = x.clone() 9595 opt_fn = torch._dynamo.optimize("inductor")(fn) 9596 same(fn(x, y), opt_fn(x_clone, y)) 9597 9598 def test_erfc(self): 9599 def fn(x): 9600 return torch.erfc(x) 9601 9602 self.common(fn, (torch.randn(8, 8),)) 9603 9604 def test_erfinv(self): 9605 def fn(x): 9606 return torch.erfinv(x) 9607 9608 # domain for erfinv is (-1, 1) 9609 x = torch.empty(8, 8).uniform_(-1, 1) 9610 self.common(fn, (x,)) 9611 9612 def test_uint(self): 9613 def fn(z): 9614 x = torch.tensor(5, device=z.device, dtype=torch.uint8) 9615 y = torch.neg(x) 9616 return x < y 9617 9618 self.common(fn, (torch.randn(26),)) 9619 9620 def test_scaled_dot_product_attention(self): 9621 if self.device == "cuda" and not PLATFORM_SUPPORTS_FLASH_ATTENTION: 9622 raise unittest.SkipTest("Can't run flash attention on this platform") 9623 if self.device == "cuda" and TEST_WITH_ROCM: 9624 raise unittest.SkipTest( 9625 "Flash attention support is incomplete on this platform" 9626 ) 9627 9628 def fn(q, k, v): 9629 return torch.nn.functional.scaled_dot_product_attention( 9630 q.transpose(1, 2).contiguous(), 9631 k.transpose(1, 2), 9632 v.transpose(1, 2), 9633 scale=0.125, 9634 )[:2] 9635 9636 self.common( 9637 fn, 9638 ( 
9639 torch.randn(4, 2, 4, 2), 9640 torch.randn(4, 2, 4, 2), 9641 torch.randn(4, 2, 4, 2), 9642 ), 9643 atol=2e-4, # to pass lowp check on GPU 9644 rtol=1e-2, # to pass lowp check on GPU 9645 ) 9646 9647 @skipIfRocm 9648 @expectedFailureXPU 9649 def test_scaled_dot_product_efficient_attention(self): 9650 if self.device == "cpu": 9651 raise unittest.SkipTest(f"requires {GPU_TYPE}") 9652 9653 # The first two values should be the same, attention output 9654 # and logsumexp since dropout is not being set 9655 def fn(q, k, v, attn_bias, compute_log_sumexp): 9656 return aten._scaled_dot_product_efficient_attention( 9657 q, k, v, attn_bias, compute_log_sumexp 9658 )[:2] 9659 9660 self.common( 9661 fn, 9662 ( 9663 torch.randn(4, 4, 36, 36), 9664 torch.randn(4, 4, 36, 36), 9665 torch.randn(4, 4, 36, 36), 9666 torch.randn(4, 4, 36, 36), 9667 False, 9668 ), 9669 check_lowp=False, 9670 ) 9671 9672 def test_fft_real_input(self): 9673 def fn(x): 9674 return torch.fft.fftn(x) 9675 9676 self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False) 9677 9678 def test_fft_real_input_real_output(self): 9679 def fn(x): 9680 return torch.fft.fftn(x).real 9681 9682 self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False) 9683 9684 def test_bucketize(self): 9685 def fn(input, boundaries, out_int32, right): 9686 return torch.bucketize(input, boundaries, out_int32=out_int32, right=right) 9687 9688 input = torch.rand((64, 64)) * 2 - 1 9689 boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9]) 9690 9691 for out_int32 in [True, False]: 9692 for right in [True, False]: 9693 out_int32 = True 9694 right = False 9695 self.common(fn, (input, boundaries, out_int32, right), check_lowp=False) 9696 9697 def test_bucketize_default_kwargs(self): 9698 def fn(input, offsets): 9699 return torch.bucketize(input, offsets) 9700 9701 input = torch.tensor( 9702 [-1.0, -0.9, -0.8, -0.5, 0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 0.91] 9703 ) 9704 offsets = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9]) 9705 9706 self.common(fn, (input, offsets), check_lowp=False) 9707 9708 def test_bucketize_int(self): 9709 def fn(input, offsets, out_int32, right): 9710 return torch.bucketize(input, offsets, out_int32=out_int32, right=right) 9711 9712 input = torch.randint(0, 102, (64, 64)) 9713 offsets = torch.arange(10, dtype=torch.int32) ** 2 + 1 9714 9715 for out_int32 in [True, False]: 9716 for right in [True, False]: 9717 self.common(fn, (input, offsets, out_int32, right), check_lowp=False) 9718 9719 @patch.object(config.triton, "autotune_pointwise", True) 9720 def test_bucketize_add_autotune(self): 9721 # Causes a @pointwise(size_hints) where size_hints is 2D 9722 9723 def fn(input, offsets, add_value): 9724 return torch.bucketize(input, offsets) + add_value 9725 9726 input = torch.rand((16, 16, 64, 64)) 9727 boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9]) 9728 add_value = torch.randint(0, 1024, (16, 16, 64, 64)).to( 9729 memory_format=torch.channels_last 9730 ) 9731 9732 self.common(fn, (input, boundaries, add_value), check_lowp=False) 9733 9734 assertGeneratedKernelCountEqual(self, 1) 9735 9736 def test_bucketize_computed_offsets(self): 9737 def fn(inp, offsets): 9738 return torch.bucketize(inp, offsets + 0.01) 9739 9740 inp = torch.tensor( 9741 [-1.0, -0.9, -0.8, -0.5, 0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 0.91] 9742 ) 9743 offsets = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9]) - 0.01 9744 9745 self.common(fn, (inp, offsets), check_lowp=False) 9746 9747 @requires_gpu() 9748 @config.patch(assume_aligned_inputs=False) 9749 def 
test_config_option_dont_assume_alignment(self): 9750 def fn(x: torch.Tensor) -> torch.Tensor: 9751 return x.sin() + x.cos() 9752 9753 # Inductor specializes on the (unguarded) alignment of the initial input. 9754 # Make sure that for different configurations, nothing breaks. 9755 for offset in (0, 1, 2, 3, 4): 9756 base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE) 9757 inp = torch.as_strided(base, (64, 64), (64, 1), offset) 9758 torch._dynamo.reset() 9759 fn_c = torch.compile(fn) 9760 9761 ref = fn(inp) 9762 res = fn_c(inp) 9763 self.assertEqual(ref, res) 9764 9765 for offset2 in (0, 1, 2, 3, 4): 9766 base2 = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE) 9767 inp2 = torch.as_strided(base, (64, 64), (64, 1), offset2) 9768 ref2 = fn(inp2) 9769 res2 = fn_c(inp2) 9770 self.assertEqual(ref2, res2) 9771 9772 @requires_gpu() 9773 @config.patch(assume_aligned_inputs=False) 9774 def test_config_option_dont_assume_alignment_recompiles(self): 9775 # Inputs: 9776 # 1. (32, 32) shape 9777 # 2. (64, 64) shape -> causes a recompile 9778 # 3. (64, 64) shape with different storage offset -> should NOT cause a recompile 9779 failed_guards = [] 9780 9781 def fail(guard): 9782 nonlocal failed_guards 9783 failed_guards.append(guard) 9784 9785 def fn(x: torch.Tensor) -> torch.Tensor: 9786 return x.sin() + x.cos() 9787 9788 base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE) 9789 9790 inp1 = torch.as_strided(base, (32, 32), (32, 1), 4) 9791 inp2 = torch.as_strided(base, (64, 64), (64, 1), 4) 9792 inp3 = torch.as_strided(base, (64, 64), (64, 1), 5) 9793 9794 torch._dynamo.reset() 9795 9796 fn_c = torch._dynamo.optimize("inductor", guard_fail_fn=fail)(fn) 9797 9798 ref1 = fn(inp1) 9799 res1 = fn_c(inp1) 9800 self.assertEqual(ref1, res1) 9801 self.assertEqual(0, len(failed_guards)) 9802 9803 ref2 = fn(inp2) 9804 res2 = fn_c(inp2) 9805 self.assertEqual(ref2, res2) 9806 # if dynamic shapes isn't already turned on, we might have a guard failure as we turn 9807 # on dynamic shapes 9808 self.assertLessEqual(len(failed_guards), 1) 9809 failed_guard_count_iteration_2 = len(failed_guards) 9810 9811 failed_guards = [] 9812 ref3 = fn(inp3) 9813 res3 = fn_c(inp3) 9814 self.assertEqual(ref3, res3) 9815 # we might still have the dynamics shapes failure, but offset change shouldn't be guarded on 9816 # see Note: [Input Alignment handling in Inductor] 9817 self.assertLessEqual(len(failed_guards), failed_guard_count_iteration_2) 9818 9819 @requires_gpu() 9820 @config.patch(assume_aligned_inputs=False) 9821 def test_config_option_dont_assume_alignment_cudagraphs(self): 9822 def fn(x): 9823 return x.cos() * x.sin() 9824 9825 fn_c = torch.compile(fn, mode="reduce-overhead", dynamic=True) 9826 9827 for size, stride, offset in ( 9828 ((32, 32), (32, 1), 4), 9829 ((48, 48), (48, 1), 4), 9830 ((64, 64), (64, 1), 5), 9831 ): 9832 torch.manual_seed(42) 9833 base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE) 9834 torch.manual_seed(42) 9835 base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE) 9836 9837 inp = torch.as_strided(base, size, stride, offset) 9838 inp_ref = torch.as_strided(base_ref, size, stride, offset) 9839 9840 inp.requires_grad_(True) 9841 inp_ref.requires_grad_(True) 9842 9843 res = fn_c(inp) 9844 ref = fn(inp_ref) 9845 self.assertEqual(ref, res) 9846 9847 res.sum().backward() 9848 ref.sum().backward() 9849 self.assertEqual(base.grad, base_ref.grad) 9850 9851 @config.patch(implicit_fallbacks=True) 9852 def 
test_custom_op_1(self): 9853 import torch.library 9854 9855 def foo_cpu(x): 9856 return 3 * x 9857 9858 def foo_cuda(x): 9859 return 3 * x 9860 9861 def foo_xpu(x): 9862 return 3 * x 9863 9864 def foo_meta(x): 9865 return torch.empty_like(x) 9866 9867 define_custom_op_for_test("foo", foo_cpu, foo_cuda, foo_xpu, foo_meta) 9868 9869 def fn(x): 9870 a = torch.nn.functional.relu(x) 9871 b = torch.ops.test.foo(a) 9872 c = torch.cos(b) 9873 return c 9874 9875 self.common(fn, (torch.randn((16, 32)),), check_lowp=False) 9876 9877 @config.patch(implicit_fallbacks=True) 9878 def test_custom_op_2(self): 9879 import torch.library 9880 9881 def foo_cpu(x, scale: float): 9882 return scale * x, torch.cos(x) 9883 9884 def foo_cuda(x, scale: float): 9885 return scale * x, torch.cos(x) 9886 9887 def foo_xpu(x, scale: float): 9888 return scale * x, torch.cos(x) 9889 9890 def foo_meta(x, scale: float): 9891 return torch.empty_like(x), torch.empty_like(x) 9892 9893 define_custom_op_2_for_test("foo2", foo_cpu, foo_cuda, foo_xpu, foo_meta) 9894 9895 def fn(x, scale: float): 9896 a = torch.nn.functional.relu(x) 9897 return torch.ops.test.foo2(a, scale) 9898 9899 self.common(fn, (torch.randn((16, 32)), 2.0), check_lowp=False) 9900 9901 @config.patch(implicit_fallbacks=True) 9902 def test_custom_op_3(self): 9903 import torch.library 9904 9905 def foo_cpu(x): 9906 result = torch.zeros_like(x[0]) 9907 for t in x: 9908 result += t 9909 return result 9910 9911 def foo_cuda(x): 9912 result = torch.zeros_like(x[0]) 9913 for t in x: 9914 result += t 9915 return result 9916 9917 def foo_xpu(x): 9918 result = torch.zeros_like(x[0]) 9919 for t in x: 9920 result += t 9921 return result 9922 9923 def foo_meta(x): 9924 return torch.empty_like(x[0]) 9925 9926 define_custom_op_3_for_test("foo3", foo_cpu, foo_cuda, foo_xpu, foo_meta) 9927 9928 def fn(x): 9929 return torch.ops.test.foo3(x) 9930 9931 self.common( 9932 fn, 9933 ([torch.randn((16, 32)), torch.randn((16, 32)), torch.randn((16, 32))],), 9934 check_lowp=False, 9935 ) 9936 9937 @requires_gpu() 9938 @torch._inductor.config.patch("layout_optimization", True) 9939 @torch._inductor.config.patch("keep_output_stride", False) 9940 @config.patch(implicit_fallbacks=True) 9941 def test_custom_op_fixed_layout_sequential(self): 9942 import torch.library 9943 9944 mod = nn.Conv2d(3, 128, 1, stride=1, bias=False).to(device=GPU_TYPE) 9945 inp = torch.rand(2, 3, 128, 128, device=GPU_TYPE) 9946 expected_stride = mod(inp).stride() 9947 9948 def bar_cpu(x): 9949 self.assertEqual(x.stride(), expected_stride) 9950 return x.clone() 9951 9952 def bar_cuda(x): 9953 self.assertEqual(x.stride(), expected_stride) 9954 return x.clone() 9955 9956 def bar_xpu(x): 9957 self.assertEqual(x.stride(), expected_stride) 9958 return x.clone() 9959 9960 def bar_meta(x): 9961 return torch.empty_like(x) 9962 9963 define_custom_op_for_test( 9964 "bar", 9965 bar_cpu, 9966 bar_cuda, 9967 bar_xpu, 9968 bar_meta, 9969 tags=[torch._C.Tag.needs_fixed_stride_order], 9970 ) 9971 9972 def fn(x): 9973 z = mod(x) 9974 output = torch.ops.test.bar(z) 9975 return output 9976 9977 with torch.no_grad(): 9978 # With keep_output_stride False, inductor would normally have different layout from eager execution 9979 # But because our custom op needs fixed layout, the assertions in the custom op will pass 9980 self.common(fn, (inp,), check_lowp=False) 9981 9982 @config.patch(implicit_fallbacks=True) 9983 def test_mutable_custom_op_fixed_layout(self): 9984 with torch.library._scoped_library("mylib", "DEF") as lib: 9985 lib.define( 9986 
"copy_(Tensor(a!) dst, Tensor src) -> ()", 9987 tags=torch.Tag.needs_fixed_stride_order, 9988 ) 9989 9990 @torch.library.impl(lib, "copy_", "Meta") 9991 def _(dst, src): 9992 return None 9993 9994 @torch.library.impl(lib, "copy_", "CompositeExplicitAutograd") 9995 def _(dst, src): 9996 dst.copy_(src) 9997 9998 def f(x): 9999 full_default_3 = torch.full([3], 7.0, device="cpu") 10000 chunk_cat_default_1 = torch.ops.mylib.copy_.default(full_default_3, x) 10001 mul_out = torch.mul(full_default_3, full_default_3) 10002 return mul_out 10003 10004 x = torch.arange(3, dtype=torch.float, device="cpu") 10005 eager_out = f(x) 10006 10007 compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True) 10008 compiled_inductor_out = compiled_inductor_f(x) 10009 self.assertEqual(compiled_inductor_out, eager_out) 10010 10011 @requires_gpu() 10012 @config.patch(implicit_fallbacks=True) 10013 def test_custom_op_fixed_layout_channels_last(self): 10014 class Block(nn.Module): 10015 def __init__( 10016 self, 10017 ): 10018 super().__init__() 10019 10020 self.in_layers = nn.Sequential( 10021 nn.Dropout(p=0.1), 10022 ) 10023 10024 def helper(self, x): 10025 out = F.gelu(x) 10026 out = self.in_layers(out) 10027 return out 10028 10029 def forward(self, x): 10030 out = self.helper(x) 10031 out = torch.ops.test.baz(out) 10032 return out 10033 10034 model = Block() 10035 model = model.to(GPU_TYPE).to(memory_format=torch.channels_last) 10036 input_t = torch.randn([1, 320, 128, 128], dtype=torch.float32, device=GPU_TYPE) 10037 input_t = input_t.to(memory_format=torch.channels_last) 10038 expected_strides = model.helper(input_t).stride() 10039 10040 def baz_cpu(x): 10041 self.assertEqual(expected_strides, x.stride()) 10042 return x.clone() 10043 10044 def baz_cuda(x): 10045 self.assertEqual(expected_strides, x.stride()) 10046 return x.clone() 10047 10048 def baz_xpu(x): 10049 self.assertEqual(expected_strides, x.stride()) 10050 return x.clone() 10051 10052 def baz_meta(x): 10053 return torch.empty_like(x) 10054 10055 define_custom_op_for_test( 10056 "baz", 10057 baz_cpu, 10058 baz_cuda, 10059 baz_xpu, 10060 baz_meta, 10061 tags=[torch._C.Tag.needs_fixed_stride_order], 10062 ) 10063 10064 with torch.no_grad(): 10065 net = torch.compile(model) 10066 out = net(input_t) 10067 10068 def test_buffer_use_after_remove(self): 10069 # https://github.com/pytorch/pytorch/issues/102857 10070 10071 def rotvec_to_rotmat(rotvec) -> torch.Tensor: 10072 """Simplified rotvec to rotmat code from RoMa 10073 (https://github.com/naver/roma/blob/06e4b0cdc1c802a60a012bb19c581d6600c63358/roma/mappings.py#L371) 10074 """ 10075 theta = torch.norm(rotvec, dim=-1) 10076 axis = rotvec / theta[..., None] 10077 kx, ky, kz = axis[:, 0], axis[:, 1], axis[:, 2] 10078 sin_theta = torch.sin(theta) 10079 cos_theta = torch.cos(theta) 10080 one_minus_cos_theta = 1 - cos_theta 10081 xs = kx * sin_theta 10082 ys = ky * sin_theta 10083 zs = kz * sin_theta 10084 xyc = kx * ky * one_minus_cos_theta 10085 xzc = kx * kz * one_minus_cos_theta 10086 yzc = ky * kz * one_minus_cos_theta 10087 xxc = kx**2 * one_minus_cos_theta 10088 yyc = ky**2 * one_minus_cos_theta 10089 zzc = kz**2 * one_minus_cos_theta 10090 R_rodrigues = torch.stack( 10091 [ 10092 1 - yyc - zzc, 10093 xyc - zs, 10094 xzc + ys, 10095 xyc + zs, 10096 1 - xxc - zzc, 10097 -xs + yzc, 10098 xzc - ys, 10099 xs + yzc, 10100 1 - xxc - yyc, 10101 ], 10102 dim=-1, 10103 ).reshape(-1, 3, 3) 10104 R = R_rodrigues 10105 return R 10106 10107 def f(coord, rot, trans): 10108 rot_mat = rotvec_to_rotmat(rot) 
10109 coord = torch.einsum("...ij,...bj->...bi", rot_mat, coord) + trans 10110 return coord.sum() 10111 10112 foo_c = torch.compile(f, dynamic=True) 10113 10114 def run(fn): 10115 coord = torch.ones((2, 3), device=self.device) 10116 rot = nn.Parameter(torch.ones((2, 3), device=self.device)) 10117 trans = nn.Parameter(torch.ones((2, 3), device=self.device)) 10118 10119 U = fn(coord, rot, trans) 10120 U.backward() 10121 10122 return U, rot, trans 10123 10124 U_e, rot_e, trans_e = run(f) 10125 U, rot, trans = run(foo_c) 10126 10127 self.assertEqual(U, U_e) 10128 self.assertEqual(rot.grad, rot_e.grad) 10129 self.assertEqual(trans.grad, trans_e.grad) 10130 10131 @config.patch({"fx_graph_cache": False}) 10132 def test_inner_fn_str_and_stride(self): 10133 def f(x): 10134 x = x + 1 10135 x = test_operators.realize(x) 10136 x = x * 2 10137 x = test_operators.realize(x) 10138 return x 10139 10140 x = torch.rand(3, 2, device=self.device).t() 10141 ref = f(x) 10142 called = False 10143 10144 def hook_fn(scheduler, nodes): 10145 nonlocal called 10146 called = True 10147 10148 if self.device != "cpu": 10149 self.assertEqual(len(nodes), 3) 10150 _, mul_buf, _ = nodes 10151 self.assertTrue( 10152 all( 10153 V.graph.sizevars.size_hints(buf.get_stride()) == (1, 2) 10154 for buf in nodes 10155 ) 10156 ) 10157 # before the fix, the wrong index expression 10158 # 'i1 + 3 * i0' is cached. 10159 self.assertTrue( 10160 "i0 + 2 * i1" in mul_buf.data.inner_fn_str() 10161 or "i0 + i1 * s1" in mul_buf.data.inner_fn_str() 10162 ) 10163 10164 with add_scheduler_init_hook(hook_fn): 10165 actual = torch.compile(f, fullgraph=True)(x) 10166 self.assertEqual(ref, actual) 10167 self.assertTrue(called) 10168 10169 def test_mutations_loop_fusion(self): 10170 def fn(tensor, index, source): 10171 out = tensor.index_add(0, index, source, alpha=2.0) / 2 10172 return out 10173 10174 device = "cpu" 10175 tensor = torch.rand((1,), dtype=torch.double, device=device) 10176 index = torch.tensor([0], dtype=torch.long, device=device) 10177 source = torch.rand((1,), dtype=torch.double, device=device) 10178 self.common( 10179 fn, 10180 ( 10181 tensor, 10182 index, 10183 source, 10184 ), 10185 ) 10186 10187 @config.patch( 10188 "triton.autotune_pointwise", True 10189 ) # needed to introduce config that exceed max shared memory usage 10190 @serialTest() 10191 def test_large_block_sizes(self): 10192 """ 10193 Inductor will try triton configs like x = 64 and y = 1024 which will 10194 result in out of shared memory if dtype is fp32. 10195 10196 Currently inductor will skip such bad configs and pick the best one 10197 from the remaining configs. 10198 """ 10199 if not _has_sufficient_memory(self.device, 3 * 2**24 * 65 * 4): 10200 raise unittest.SkipTest("insufficient memory") 10201 10202 @torch.compile 10203 def fn(x, y): 10204 return x.t() + y 10205 10206 # Use shape (2**24, 65) rather than (2**24, 128) potentially avoid OOM in 10207 # CI while still keep the same up-rounded size-hints. 
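        # Rough memory budget for the skip check above: the two inputs and the fp32
        # output are each 2**24 * 65 * 4 bytes (~4 GiB), hence the
        # 3 * 2**24 * 65 * 4 (~12 GiB) requirement passed to _has_sufficient_memory.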
10208 a = torch.randn(2**24, 65, device=self.device) 10209 b = torch.randn(65, 2**24, device=self.device) 10210 fn(a, b) 10211 10212 # Skipped on ROCm until https://github.com/ROCm/triton/issues/443 resolved 10213 @skipIfRocm 10214 def test_fuse_large_params(self): 10215 def pt2_optimizer_step(optimizer): 10216 @torch.compile() 10217 def f(): 10218 optimizer.step() 10219 10220 f() 10221 10222 params = [ 10223 torch.rand(10, 10, dtype=torch.float32, device=self.device) 10224 for _ in range(194) 10225 ] 10226 for p in params: 10227 p.grad = torch.rand_like(p) 10228 10229 o = torch.optim.AdamW(params) 10230 pt2_optimizer_step(o) 10231 10232 def test_adaptive_avg_pool1d_argmax(self): 10233 # https://github.com/pytorch/pytorch/issues/113013 10234 def fn(x): 10235 x = torch.adaptive_avg_pool1d(input=x, output_size=2) 10236 x = torch.argmax(input=x) 10237 return x 10238 10239 x = torch.rand([4, 4, 3], dtype=torch.float64) 10240 self.common(fn, (x,)) 10241 10242 def test_float16_to_int16(self): 10243 def fn(x): 10244 x_view = x.view(dtype=torch.int16) 10245 return x_view.mul(2) 10246 10247 x = torch.ones(4, dtype=torch.float16, device=self.device) 10248 ref = fn(x) 10249 actual = torch.compile(fn)(x) 10250 self.assertEqual(ref, actual) 10251 10252 @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") 10253 def test_bfloat16_to_int16(self): 10254 def fn(a, b): 10255 x = a + b 10256 x_view = x.view(dtype=torch.int16) 10257 return x_view.mul(2) 10258 10259 a = torch.ones(4, dtype=torch.bfloat16, device=self.device) 10260 b = torch.ones(4, dtype=torch.bfloat16, device=self.device) 10261 ref = fn(a, b) 10262 actual = torch.compile(fn)(a, b) 10263 self.assertEqual(ref, actual) 10264 10265 def test_float32_to_int32(self): 10266 def fn(a, b): 10267 x = a + b 10268 x_view = x.view(dtype=torch.int32) 10269 return x_view.mul(2) 10270 10271 a = torch.ones(4, dtype=torch.float32, device=self.device) 10272 b = torch.ones(4, dtype=torch.float32, device=self.device) 10273 ref = fn(a, b) 10274 actual = torch.compile(fn)(a, b) 10275 self.assertEqual(ref, actual) 10276 10277 def test_randint_int64_mod(self): 10278 # This used to not compile due to a wrong return type of randint64_cpu 10279 # See https://github.com/pytorch/pytorch/issues/117435 10280 def fn(n): 10281 return ( 10282 torch.randint( 10283 low=-5, high=5, size=(n,), dtype=torch.int64, device=self.device 10284 ) 10285 % 10 10286 ) 10287 10288 res = torch.compile(fn)(20) 10289 self.assertTrue(torch.all((0 <= res) & (res < 10)).item()) 10290 10291 @torch._inductor.config.patch(force_shape_pad=True) 10292 def test_should_pad_bench_for_bmm(self): 10293 B = 2 10294 M = 1024 10295 N = 1024 10296 K = 1024 + 1 # a size that requires padding 10297 10298 mat1 = torch.rand(B, M, K, device=self.device) 10299 mat2 = torch.rand(B, K, N, device=self.device) 10300 10301 should_pad = pad_mm.should_pad_bench(None, mat1, mat2, torch.ops.aten.bmm) 10302 10303 self.assertTrue(should_pad) 10304 10305 @parametrize( 10306 "name, op", 10307 [ 10308 subtest((name, getattr(torch.special, name)), name=name) 10309 for name in torch.special.__all__ 10310 if name not in {"softmax", "log_softmax", "logsumexp"} 10311 ], 10312 ) 10313 def test_pointwise(self, name, op): 10314 dtype = torch.float32 10315 check_lowp = True 10316 if self.device == GPU_TYPE and name in { 10317 "airy_ai", 10318 "bessel_i0", 10319 "bessel_i1", 10320 "bessel_j0", 10321 "bessel_j1", 10322 "bessel_y0", 10323 "bessel_y1", 10324 "erfcx", 10325 "gammainc", 10326 "gammaincc", 10327 "i1", 10328 "i1e", 
10329 "modified_bessel_i0", 10330 "modified_bessel_i1", 10331 "modified_bessel_k0", 10332 "modified_bessel_k1", 10333 "ndtri", 10334 "scaled_modified_bessel_k0", 10335 "scaled_modified_bessel_k1", 10336 "spherical_bessel_j0", 10337 "zeta", 10338 "chebyshev_polynomial_t", 10339 "chebyshev_polynomial_v", 10340 "chebyshev_polynomial_u", 10341 "chebyshev_polynomial_w", 10342 "legendre_polynomial_p", 10343 "shifted_chebyshev_polynomial_t", 10344 "shifted_chebyshev_polynomial_u", 10345 "shifted_chebyshev_polynomial_v", 10346 "shifted_chebyshev_polynomial_w", 10347 "hermite_polynomial_h", 10348 "hermite_polynomial_he", 10349 "laguerre_polynomial_l", 10350 }: 10351 # <func>_cuda not implemented for Half 10352 check_lowp = False 10353 10354 if name in {"gammainc", "gammaincc"}: 10355 args = ( 10356 torch.randn(8, 8, dtype=dtype, device=self.device), 10357 torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2), 10358 ) 10359 10360 def fn(x, y): 10361 return op(x, y) 10362 10363 elif name in {"xlog1py", "xlogy", "zeta"}: 10364 args = ( 10365 torch.randn(8, 8, dtype=dtype, device=self.device), 10366 torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2), 10367 ) 10368 10369 def fn(x, y): 10370 return op(x, y) 10371 10372 elif name == "multigammaln": 10373 args = ( 10374 torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 2), 10375 2, 10376 ) 10377 10378 def fn(x, p): 10379 return op(x, p) 10380 10381 elif name == "polygamma": 10382 args = ( 10383 1, 10384 torch.empty(8, 8, dtype=dtype, device=self.device).uniform_(1, 10), 10385 ) 10386 10387 def fn(n, x): 10388 return op(n, x) 10389 10390 elif "_polynomial_" in name: 10391 args = ( 10392 torch.randn(8, 8, dtype=dtype, device=self.device), 10393 2, 10394 ) 10395 10396 def fn(x, n): 10397 return op(x, n) 10398 10399 else: 10400 args = (torch.randn(8, 8, dtype=dtype, device=self.device),) 10401 10402 def fn(x): 10403 return op(x) 10404 10405 self.common(fn, args, check_lowp=check_lowp) 10406 10407 # codegen test fails with no dynamic for loop in dynamic shape tests 10408 @expectedFailureCodegenDynamic 10409 def test_view_uint8_through_differing_bitwidths(self): 10410 # https://github.com/pytorch/pytorch/issues/120998 10411 def fn(x, view_dtype): 10412 return x.view(view_dtype).view(torch.uint8) 10413 10414 view_dtypes = [torch.int16, torch.int32, torch.int64] 10415 for dtype in view_dtypes: 10416 x = torch.randint(0, 2**4, [4096, 4096], dtype=torch.uint8) 10417 self.common( 10418 fn, 10419 ( 10420 x, 10421 dtype, 10422 ), 10423 ) 10424 10425 @torch._dynamo.config.patch(capture_scalar_outputs=True) 10426 def test_split_with_sizes_with_unbacked_symints(self): 10427 @torch.compile() 10428 def f(sz, x): 10429 s0, s1 = sz.tolist() 10430 r0, r1 = torch.ops.aten.split_with_sizes.default(x, [s0, s1]) 10431 return torch.ops.aten.sort.default(r1) 10432 10433 N = 7312 10434 S0 = 420 10435 S1 = N - S0 10436 10437 result = f(torch.tensor([S0, S1]), torch.randn(N)) 10438 self.assertTrue(len(result) == 2) 10439 10440 @torch.compile() 10441 def f2(x): 10442 y = torch.arange(x.item()) 10443 return torch.ops.aten.split_with_sizes.default(y, [5, 5, 10]) 10444 10445 result = f2(torch.tensor([20])) 10446 self.assertTrue(len(result) == 3) 10447 10448 @torch._dynamo.config.patch(capture_scalar_outputs=True) 10449 def test_split_with_unbacked_symints(self): 10450 # https://github.com/pytorch/pytorch/issues/122937 10451 @torch.compile() 10452 def f(x): 10453 y = torch.arange(x.item()) 10454 return torch.split(y, [5, 5, 10]) 10455 10456 result = 
        f(torch.tensor([20]))
        self.assertTrue(len(result) == 3)

    def test_complex_memory_overlap(self):
        t = rand_strided((8, 1500, 1), (1504, 1, 1), device=self.device)
        self.assertFalse(complex_memory_overlap(t))

    def test_generate_rand_fp8(self):
        """
        PyTorch cannot generate fp8 tensors with a normal distribution because the
        needed kernels are missing.

        We work around that in rand_strided by generating an fp16 tensor first and
        then casting it.
        """
        t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn)
        self.assertTrue(t.dtype is torch.float8_e4m3fn)

    def test_large_grid(self):
        # https://github.com/pytorch/pytorch/issues/123210
        def fn(primals_5):
            view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
            primals_5 = None
            permute = torch.ops.aten.permute.default(view, [0, 2, 1])
            clone = torch.ops.aten.clone.default(
                permute, memory_format=torch.contiguous_format
            )
            return clone

        s0 = 16777472
        s1 = 8
        compiled_fn = torch._dynamo.optimize()(fn)
        actual = compiled_fn(torch.ones(s0, s1))
        self.assertTrue((actual == 1).all())


@dataclasses.dataclass
class TestFailure:
    suffixes: Tuple[str]
    is_skip: bool = False
    __test__: bool = False


def copy_tests(
    my_cls, other_cls, suffix, test_failures=None, xfail_prop=None
):  # noqa: B902
    for name, value in my_cls.__dict__.items():
        if name.startswith("test_"):
            # You cannot copy functions in Python, so we use closures here to
            # create objects with different ids. Otherwise, unittest.skip
            # would modify all methods sharing the same object id. Also, by
            # using a default argument, we create a copy instead of a
            # reference. Otherwise, we would lose access to the value.
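            # (Without the `value=value` default, `new_test` would late-bind to the loop
            # variable and every generated test would end up calling whichever test
            # function the loop visited last.)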
10509 10510 @functools.wraps(value) 10511 def new_test(self, value=value): 10512 return value(self) 10513 10514 # Copy __dict__ which may contain test metadata 10515 new_test.__dict__ = copy.deepcopy(value.__dict__) 10516 10517 if xfail_prop is not None and hasattr(value, xfail_prop): 10518 new_test = unittest.expectedFailure(new_test) 10519 10520 tf = test_failures and test_failures.get(name) 10521 if tf is not None and suffix in tf.suffixes: 10522 skip_func = ( 10523 unittest.skip("Skipped!") 10524 if tf.is_skip 10525 else unittest.expectedFailure 10526 ) 10527 new_test = skip_func(new_test) 10528 10529 setattr(other_cls, f"{name}_{suffix}", new_test) 10530 10531 10532if HAS_CPU: 10533 10534 class SweepInputsCpuTest(SweepInputs2, TestCase): 10535 gen = InputGen(10, "cpu") 10536 10537 SweepInputsCpuTest.populate() 10538 10539 class CpuTests(TestCase): 10540 common = check_model 10541 device = "cpu" 10542 10543 copy_tests(CommonTemplate, CpuTests, "cpu") 10544 10545if HAS_GPU and not TEST_WITH_ASAN: 10546 10547 class SweepInputsGPUTest(SweepInputs2, TestCase): 10548 gen = InputGen(10, GPU_TYPE) 10549 10550 SweepInputsGPUTest.populate() 10551 10552 class GPUTests(TestCase): 10553 common = check_model_gpu 10554 device = GPU_TYPE 10555 10556 copy_tests(CommonTemplate, GPUTests, GPU_TYPE) 10557 10558 class TritonCodeGenTests(TestCase): 10559 from torch._inductor.runtime.triton_heuristics import CachingAutotuner 10560 10561 device_type = GPU_TYPE 10562 10563 class NoOpCompilerBackend: 10564 def __init__(self): 10565 self.example_args = None 10566 self.model = None 10567 10568 def noop_backend( 10569 self, 10570 model_: torch.fx.GraphModule, 10571 example_inputs_: typing.List[torch.Tensor], 10572 ): 10573 """ 10574 The Noop backend does not compile the fx graph it is given. 10575 Instead, it transforms the fx graph so that its functions are 10576 aten operations. It then saves this graph. 
10577 """ 10578 from torch._inductor.decomposition import select_decomp_table 10579 from torch._subclasses import FakeTensorMode 10580 from torch.fx import Interpreter 10581 10582 fake_mode = FakeTensorMode() 10583 10584 def interpret(*args, **kwargs): 10585 return Interpreter(model_).run(*args[0:], **kwargs) 10586 10587 fake_flat_tensor_args = [ 10588 fake_mode.from_tensor(x) for x in example_inputs_ 10589 ] 10590 fw_module = make_fx(interpret, select_decomp_table())( 10591 *fake_flat_tensor_args 10592 ) 10593 self.model = fw_module 10594 self.example_args = fake_flat_tensor_args 10595 return lambda x: example_inputs_ 10596 10597 def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: 10598 from torch._inductor.debug import DebugContext 10599 from torch._inductor.graph import GraphLowering 10600 from torch._inductor.virtualized import V 10601 10602 cxt = TritonCodeGenTests.NoOpCompilerBackend() 10603 torch._dynamo.optimize(backend=cxt.noop_backend)(fn)(*args) 10604 graph = GraphLowering(cxt.model) 10605 kernels = [] 10606 with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): 10607 graph.run(*(cxt.example_args)) 10608 mod = graph.compile_to_module() 10609 10610 for val in mod.__dict__.values(): 10611 if isinstance( 10612 val, torch._inductor.runtime.triton_heuristics.CachingAutotuner 10613 ): 10614 kernels.append(val) 10615 10616 return kernels 10617 10618 def test_divisible_by_16_covers_numel_args(self): 10619 torch._dynamo.reset() 10620 10621 def fn(a: torch.Tensor) -> torch.Tensor: 10622 return torch.sum(a) 10623 10624 kernels = self.get_kernels(fn, [torch.randn([256, 256], device=GPU_TYPE)]) 10625 if config.triton.multi_kernel: 10626 self.assertTrue( 10627 len(kernels) == 4, 10628 "SUM should result in four kernels when multi-kernel is enabled", 10629 ) 10630 else: 10631 self.assertTrue(len(kernels) == 2, "SUM should result in two kernels") 10632 10633 # kernel0 reduces from 256 to (xnumel=8, rnumel=8192), which means it reduces 256 by 256 into an array of 10634 # size 8 by accumulating 8192 elements at once note that rnumel is equal to 512 * 16, so rnumel which is 10635 # at slot 3 should be in the divisible by 16 descriptor 10636 arguments_that_are_divisible_by_16_in_kernel0 = ( 10637 kernels[0].triton_meta["configs"][0].divisible_by_16 10638 ) 10639 self.assertEqual(arguments_that_are_divisible_by_16_in_kernel0, (0, 1, 3)) 10640 10641 # kernel1 reduces from 8 elements to a single scalar. 10642 # Since multi-kernel generate 2 variants for each kernel. The second 10643 # persistent-reduction has index 2. 
            kernel1_index = 2 if config.triton.multi_kernel else 1
            arguments_that_are_divisible_by_16_in_kernel1 = (
                kernels[kernel1_index].triton_meta["configs"][0].divisible_by_16
            )
            self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1))
            torch._dynamo.reset()

        @config.patch(assume_aligned_inputs=False)
        def test_codegen_config_option_dont_assume_alignment(self):
            def fn(x: torch.Tensor) -> torch.Tensor:
                return x.sin() + x.cos()

            # We want code that assumes alignment if the initial input is 16-byte aligned
            for offset in (0, 1, 2, 3, 4):
                base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=GPU_TYPE)
                inps = torch.as_strided(base, (64, 64), (64, 1), offset)
                torch._dynamo.reset()
                kernels = self.get_kernels(fn, [inps])
                arguments_that_are_divisible_by_16 = (
                    kernels[0].triton_meta["configs"][0].divisible_by_16
                )

                #             NO_ALIGN ALIGN     ALIGN
                # def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr)

                if offset % 4 == 0:
                    expected_aligned = (0, 1, 2)
                else:
                    expected_aligned = (1, 2)
                self.assertEqual(arguments_that_are_divisible_by_16, expected_aligned)

            # If the input isn't a view with storage offset != 0, inductor will
            # assume alignment.
            torch._dynamo.reset()
            inp = torch.randn((64, 64), device=GPU_TYPE)
            kernels = self.get_kernels(fn, [inp])
            arguments_that_are_divisible_by_16 = (
                kernels[0].triton_meta["configs"][0].divisible_by_16
            )
            self.assertEqual(arguments_that_are_divisible_by_16, (0, 1, 2))

        def test_optimize_indexing_dtype(self):
            def fn(x: torch.Tensor) -> torch.Tensor:
                return aten.upsample_bilinear2d.vec(x, None, True, [2.0, 2.0])

            fn_opt = torch._dynamo.optimize("inductor")(fn)
            inps = [torch.randn(2, 4, 16, 16, device=GPU_TYPE)]
            code = run_and_get_triton_code(fn_opt, *inps)
            self.assertTrue("to(tl.int32)" in code)
            self.assertFalse("to(tl.int64)" in code)

            self.assertEqual(fn_opt(*inps), fn(*inps))

        def test_optimize_indexing_dtype_with_constraint(self):
            def fn1(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                x = torch.arange(0, b.shape[0], device=GPU_TYPE)
                y = ((x + x) / 3).int()
                return a[y.to(torch.int64)]

            def fn2(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
                torch._check_is_size(b.shape[0])
                torch._check(b.shape[0] >= 2)
                torch._check(b.shape[0] <= 100)
                return fn1(a, b)

            fn1_opt = torch._dynamo.optimize("inductor")(fn1)
            fn2_opt = torch._dynamo.optimize("inductor")(fn2)

            a = torch.rand([100, 100], device=GPU_TYPE)
            b = torch.rand([100], device=GPU_TYPE)
            torch._dynamo.mark_dynamic(b, 0)
            inps = [a, b]

            code1 = run_and_get_triton_code(fn1_opt, *inps)
            code2 = run_and_get_triton_code(fn2_opt, *inps)

            # The function with the constrained tensor should be optimized, but
            # the other should not:
            self.assertTrue("to(tl.int64)" in code1)
            self.assertTrue("to(tl.int32)" in code2)
            self.assertFalse("to(tl.int64)" in code2)

            self.assertEqual(fn1_opt(*inps), fn1(*inps))
            self.assertEqual(fn2_opt(*inps), fn1(*inps))

        def test_constant_folding_deallocation(self):
            import torch._inductor

            def fn():
                li = []
                for i in range(10):
                    x = torch.full([100], i)
                    x = x + 1
                    li.append(x)

                return li

            mod = make_fx(fn)()

            live_tensors = WeakTensorKeyDictionary()
            max_live_tensors = 0

            class LiveTensors(TorchDispatchMode):
                def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                    nonlocal live_tensors
                    nonlocal max_live_tensors

                    kwargs = kwargs if kwargs else {}
                    for arg in pytree.arg_tree_leaves(*args, **kwargs):
                        if isinstance(arg, torch.Tensor):
                            live_tensors[arg] = True

                    out = func(*args, **kwargs)
                    if not isinstance(out, torch.Tensor):
                        return out

                    live_tensors[out] = True
                    max_live_tensors = max(max_live_tensors, len(live_tensors))
                    return out

            mode = LiveTensors()
            from torch._inductor.fx_passes.joint_graph import UniformValueConstantFolder

            with mode:
                UniformValueConstantFolder(mod).run()

            # there are a couple extra tensors created in `insertable_tensor_check`
            self.assertTrue(max_live_tensors == 4)

        # See https://github.com/pytorch/pytorch/issues/100348
        def test_inductor_detach_view(self):
            def fn(x: torch.Tensor) -> torch.Tensor:
                a = x * 2
                return a, a.detach()

            fn_opt = torch._dynamo.optimize("inductor")(fn)
            inp = torch.ones(2, 2, requires_grad=True, device=GPU_TYPE)
            inp_ref = inp.clone().detach().requires_grad_(True)
            out_ref = fn(inp_ref)
            out = fn_opt(inp)
            out_ref[0].sum().backward()
            out[0].sum().backward()
            self.assertEqual(inp.grad, inp_ref.grad)

        @skipIfRocm  # asserts not implemented in Rocm yet
        def test_optimize_indexing_assert(self):
            def has_indirect(code, tl_fn: str):
                self.assertTrue(
                    tl_fn in code,
                    msg=f"{tl_fn} not present:\n{code}",
                )
                for line in code.split("\n"):
                    if tl_fn in line:
                        stmt = line.split(tl_fn)[-1]
                        # indirect indexing involves a `tmp` variable
                        self.assertTrue(
                            "tmp" in stmt,
                            msg=f"Indirect indexing not present in code:\n{line}",
                        )

            def has_assert(code, lower: bool, upper: bool):
                self.assertIn(
                    "device_assert", code, msg=f"No device assert found:\n{code}"
                )
                for line in code.split("\n"):
                    if "device_assert" in line:
                        self.assertTrue(
                            ("0 <= " in line) is lower,
                            msg=f"Lower bound {'' if lower else 'not '}elided:{line}",
                        )
                        self.assertTrue(
                            (" < " in line) is upper,
                            msg=f"Upper bound {'' if upper else 'not '}elided:{line}",
                        )

            def fn(x: torch.Tensor) -> torch.Tensor:
                s = 1.0 * torch.arange(x.shape[0], device=x.device)
                return x[s.long()]

            # aten.index
            for dynamic in (False, True):
                fn_opt = torch.compile(fn, dynamic=dynamic)

                x = torch.randn(8, device=GPU_TYPE)
                code = run_and_get_triton_code(fn_opt, x)
                self.assertEqual(fn_opt(x), fn(x), msg=f"{dynamic=}")

                # Check that there's indirect indexing...
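                # (i.e. a load whose index comes from another load, roughly
                # something like `tl.load(in_ptr0 + (tmp0), xmask)` where tmp0
                # was itself loaded -- that is what has_indirect looks for)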
                has_indirect(code, tl_fn="tl.load")
                if not dynamic:
                    # We elide the assert for static shapes
                    self.assertNotIn("device_assert", code)
                else:
                    # ...but we generate an upper bound for dynamic shapes
                    has_assert(code, lower=False, upper=True)

            def fn(a, z, b, idx0, idx1):
                idx2 = torch.arange(a.shape[-1], device=a.device)
                a.index_put_((z, idx0, idx1, idx2), b, accumulate=True)
                return a

            # aten.index_put
            for dynamic in (False, True):
                fn_opt = torch.compile(fn, dynamic=dynamic)
                a = torch.randn(1, 32, 32, 4, device=GPU_TYPE)
                z = torch.zeros((), dtype=torch.int64, device=GPU_TYPE)
                b = torch.randn(33, 1, device=GPU_TYPE)
                idx0 = torch.randint(32, (33,), device=GPU_TYPE).view(33, 1, 1)
                idx1 = torch.randint(32, (33,), device=GPU_TYPE).view(33, 1)
                inps = (a.clone(), z, b, idx0, idx1)
                code = run_and_get_triton_code(fn_opt, *inps)

                # Correctness
                out_opt = fn_opt(a.clone(), z, b, idx0, idx1)
                out = fn(a.clone(), z, b, idx0, idx1)
                self.assertEqual(out_opt, out, msg=f"{dynamic=}")

                # We have an indirect store via atomic_add
                has_indirect(code, tl_fn="tl.atomic_add")
                # We cannot elide the assert in this case
                has_assert(code, lower=True, upper=True)

        def test_not_materialize_pointwise_reduction(self):
            def fn(a, b):
                return (a - b).sum(dim=-1).amax(dim=-1)

            N = 16
            K = 7
            fn_opt = torch._dynamo.optimize("inductor")(fn)
            inps = [
                torch.randn(N, 1, K, device=GPU_TYPE),
                torch.randn(1, N, K, device=GPU_TYPE),
            ]
            code = run_and_get_triton_code(fn_opt, *inps)
            self.assertEqual(
                code.count("tl.store"), 2 if config.triton.multi_kernel else 1
            )
            self.assertTrue("out_ptr1" in code)
            self.assertFalse("out_ptr0" in code)
            self.assertEqual(fn_opt(*inps), fn(*inps))

        def test_numpy_on_gpu(self):
            x = np.arange(10, dtype=np.float32)

            @torch.compile
            def fn(x):
                return np.sin(x)

            def fn_gpu(x):
                with torch.device(GPU_TYPE):
                    return fn(x)

            r = fn_gpu(x)
            code = run_and_get_triton_code(fn_gpu, x)
            self.assertIn("tl_math.sin", code)
            self.assertEqual(type(r), np.ndarray)
            self.assertEqual(r, np.sin(x))

        def test_numpy_autograd(self):
            def my_torch(x):
                y = torch.cat([torch.sin(x) ** 2, torch.max(x)[None]])
                return y.sum()

            def my_np(x):
                y = np.concatenate([np.sin(x) ** 2, np.max(x)[None]])
                return np.sum(y)

            @torch.compile
            def wrapper(x):
                return torch.compiler.wrap_numpy(my_np)(x)

            @torch.compile
            def wrapper2(x):
                x = x.numpy()
                y = my_np(x)
                return torch.from_numpy(y)

            x_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
            x = torch.arange(8, dtype=torch.float32, requires_grad=True)
            out_np = wrapper(x_np)
            out = my_torch(x)
            self.assertEqual(out, out_np)

            x2_np = torch.arange(8, dtype=torch.float32, requires_grad=True)
            out2_np = wrapper2(x2_np)
            self.assertEqual(out, out2_np)

            out_np.backward()
            out.backward()
            self.assertEqual(x.grad, x_np.grad)

            out2_np.backward()
            self.assertEqual(x.grad, x2_np.grad)

        # Disable constant propagation, so we isolate value range analysis
        @patch.object(config, "constant_and_index_propagation", False)
        @patch.object(config, "joint_graph_constant_folding", False)
        def test_cant_optimize_compute(self):
            def ones():
                return torch.ones([4], device=GPU_TYPE)

            def suffix(inp):
                return (inp.to(torch.int64) + 1).to(torch.float64)

            ten = torch.rand([4], device=GPU_TYPE)

            for foo in (
                lambda x: x + 2147483657,
                lambda x: torch.where(x < 0, ones(), ones() - 2) * (-(2 ** (40))),
                lambda x: x + ten,
                lambda x: x + ten.sum(),
            ):

                def fn():
                    return suffix(foo(ones()))

                fn_opt = torch._dynamo.optimize("inductor")(fn)
                code = run_and_get_triton_code(fn_opt)

                # this cannot be optimized away, the values are too large for int32
                self.assertTrue("to(tl.int64)" in code)
                self.assertEqual(fn_opt(), fn())

        # Disable constant propagation, so we isolate value range analysis
        @patch.object(config, "constant_and_index_propagation", False)
        @patch.object(config, "joint_graph_constant_folding", False)
        def test_optimize_compute(self):
            def ones():
                return torch.ones([4], device=GPU_TYPE)

            def suffix(inp):
                return (inp.to(torch.int64) + 1).to(torch.float64)

            for foo in (
                lambda x: x + 500,
                lambda x: torch.where(x < 0, ones(), ones() - 2) * (-(2 ** (20))),
                lambda x: x / 30,
            ):

                def fn():
                    return suffix(foo(ones()))

                fn_opt = torch._dynamo.optimize("inductor")(fn)
                code = run_and_get_triton_code(fn_opt)

                # this can be optimized away since the values fit in int32
                self.assertTrue("to(tl.int64)" not in code)
                self.assertTrue("to(tl.int32)" in code)

                self.assertEqual(fn_opt(), fn())

        @config.patch("triton.use_block_ptr", False)
        def test_evict_last_non_coalesced_loads(self):
            @torch.compile
            def f(a, b):
                return (a * b).sum(dim=-1)

            N = 512
            inps = (
                torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0),
                torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0),
            )
            code = run_and_get_triton_code(f, *inps)
            lines = [line for line in code.split("\n") if "tl.load" in line]
            if config.triton.multi_kernel:
                # the first 2 lines are generated for the persistent reduction
                # variant.
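                # Roughly speaking, `evict_last` keeps the non-coalesced, reused
                # in_ptr0 tile resident in cache across the reduction, while the
                # streamed, coalesced in_ptr1 load is tagged `evict_first`.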
11010 self.assertExpectedInline( 11011 "\n".join(lines), 11012 """\ 11013 tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) 11014 tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, other=0.0) 11015 tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) 11016 tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""", 11017 ) 11018 else: 11019 self.assertExpectedInline( 11020 "\n".join(lines), 11021 """\ 11022 tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) 11023 tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""", 11024 ) 11025 11026 @skipIfRocm 11027 @config.patch("triton.use_block_ptr", True) 11028 def test_evict_last_non_coalesced_loads_block_ptr(self): 11029 @torch.compile 11030 def f(a, b): 11031 return (a * b).sum(dim=-1) 11032 11033 N = 512 11034 inps = ( 11035 torch.randn(N, N, N, device=GPU_TYPE).permute(2, 1, 0), 11036 torch.randn(N, N, N, device=GPU_TYPE).permute(1, 2, 0), 11037 ) 11038 code = run_and_get_triton_code(f, *inps) 11039 lines = [line for line in code.split("\n") if "tl.load" in line] 11040 11041 if config.triton.multi_kernel: 11042 # the first 2 lines are generated for the persistent reduction 11043 # variant. 11044 self.assertExpectedInline( 11045 "\n".join(lines), 11046 """\ 11047 tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) 11048 tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, RBLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero') 11049 tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) 11050 tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long 11051 ) 11052 else: 11053 self.assertExpectedInline( 11054 "\n".join(lines), 11055 """\ 11056 tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) 11057 tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", 11058 ) 11059 11060 # Disable index propagation, so the indirect indexing isn't optimized away 11061 @patch.object(config, "constant_and_index_propagation", False) 11062 def test_computed_indirect_mask(self): 11063 def fn(x, n): 11064 tmp = torch.arange(n, device=x.device) 11065 return x[tmp] + 1 11066 11067 x = torch.randn(8, device=GPU_TYPE) 11068 fn_opt = torch.compile(fn) 11069 code = run_and_get_triton_code(fn_opt, x, 8) 11070 # load should be masked 11071 self.assertTrue("tl.load(in_ptr0 + (tmp0), xmask" in code) 11072 self.assertEqual(fn(x, 8), fn_opt(x, 8)) 11073 11074 def test_kernel_names_descriptive(self): 11075 @torch._dynamo.optimize("inductor") 11076 def fn1(x): 11077 return x.cos().sin() 11078 11079 @torch._dynamo.optimize("inductor") 11080 def fn2(x): 11081 x = torch.mm(x, x) 11082 x = torch.softmax(x, dim=1) 11083 return x 11084 11085 mod = nn.Sequential( 11086 nn.Linear(4, 4), 11087 nn.LayerNorm(4), 11088 nn.ReLU(), 11089 ).to(device=GPU_TYPE) 11090 11091 @torch._dynamo.optimize("inductor") 11092 def fn3(x): 11093 return mod(x) 11094 11095 func_and_kernel_aten = [ 11096 (fn1, "triton_poi_fused_cos_sin", (torch.randn(8, device=GPU_TYPE),)), 11097 ( 11098 fn2, 11099 
"triton_poi_fused__softmax", 11100 (torch.randn(4, 4, device=GPU_TYPE),), 11101 ), 11102 ( 11103 fn3, 11104 "triton_poi_fused_native_layer_norm_relu", 11105 (torch.randn(4, 4, device=GPU_TYPE),), 11106 ), 11107 ] 11108 func_and_kernel_torch = [ 11109 (fn1, "triton_poi_fused_cos_sin", (torch.randn(8, device=GPU_TYPE),)), 11110 ( 11111 fn2, 11112 "triton_poi_fused_softmax", 11113 (torch.randn(4, 4, device=GPU_TYPE),), 11114 ), 11115 ( 11116 fn3, 11117 "triton_poi_fused_layer_norm_relu" 11118 if torch._dynamo.config.inline_inbuilt_nn_modules 11119 else "triton_poi_fused_LayerNorm_ReLU", 11120 (torch.randn(4, 4, device=GPU_TYPE),), 11121 ), 11122 ] 11123 11124 def test_funcs(func_and_kernel): 11125 with torch.no_grad(): 11126 for fn, kernel_name, inps in func_and_kernel: 11127 code = run_and_get_triton_code(fn, *inps) 11128 if kernel_name not in code: 11129 print(code) 11130 self.assertTrue(kernel_name in code) 11131 11132 test_funcs(func_and_kernel_aten) 11133 patch.object(config.triton, "descriptive_names", "torch")(test_funcs)( 11134 func_and_kernel_torch 11135 ) 11136 11137 @patch.object(config, "profile_bandwidth", True) 11138 def test_bandwidth_profiler(self): 11139 @torch._dynamo.optimize("inductor") 11140 def fn(x): 11141 x = x.cos() 11142 x = x.cos() 11143 x = torch.mm(x, x) 11144 x = x.sin() 11145 x = x.relu() 11146 return x 11147 11148 inp = torch.randn(4, 4, device=GPU_TYPE) 11149 code = run_and_get_triton_code(fn, inp) 11150 fn(inp) 11151 self.assertTrue("start_graph" in code) 11152 self.assertTrue("end_graph" in code) 11153 11154 def test_split_op_with_sym(self): 11155 def fn(x: torch.Tensor) -> torch.Tensor: 11156 # split(tensor, sympy.Integer), split(tensor, sympy.Expr) 11157 return torch.split(x, x.shape[0]), torch.split(x, x.shape[0] // 2) 11158 11159 for dynamic_shapes in [True, False]: 11160 with torch._dynamo.config.patch(dynamic_shapes=dynamic_shapes): 11161 torch._dynamo.reset() 11162 fn_opt = torch._dynamo.optimize("inductor", dynamic=dynamic_shapes)( 11163 fn 11164 ) 11165 inps = torch.randn([5, 5]) 11166 fn_opt(inps) 11167 11168 @skipIfRocm 11169 @unittest.skipIf(IS_FBCODE, "fbcode system python does not provide torch") 11170 def test_indirect_device_assert(self): 11171 dir_path = os.path.dirname(os.path.realpath(__file__)) 11172 test_path = os.path.join(dir_path, "indirect_assert_helper.py") 11173 fns = ("first_arg", "store", "second_arg", "same_pm_one", "same_pp_one") 11174 11175 def test(fn, ndims, dyn_shape, one_size=False): 11176 proc = subprocess.Popen( 11177 [ 11178 sys.executable, 11179 test_path, 11180 fn, 11181 str(ndims), 11182 str(dyn_shape), 11183 str(one_size), 11184 ], 11185 stdout=subprocess.PIPE, 11186 stderr=subprocess.PIPE, 11187 env={**os.environ, "MKL_THREADING_LAYER": "GNU"}, 11188 ) 11189 stderr = proc.communicate()[1] 11190 self.assertTrue( 11191 any( 11192 "out of bounds" in err.decode("utf-8") 11193 for err in stderr.splitlines() 11194 ), 11195 f"{fn}, {ndims}, {dyn_shape}, {one_size}", 11196 ) 11197 11198 for fn, ndims, dyn_shape in itertools.product(fns, (2, 3), (True, False)): 11199 test(fn, ndims, dyn_shape) 11200 11201 test("first_arg", 2, False, True) 11202 11203 for fn, dyn_shape in itertools.product( 11204 ("upper1", "upper2", "lower1", "lower2"), (True, False) 11205 ): 11206 test(fn, 2, dyn_shape) 11207 11208 @patch("torch._inductor.config.comment_origin", True) 11209 @patch("torch._functorch.config.max_dist_from_bw", 0) 11210 def test_inductor_sequence_nr(self): 11211 class Model(torch.nn.Module): 11212 def __init__(self): 11213 
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=16,
                    out_channels=16,
                    kernel_size=(1, 1),
                    stride=1,
                    padding="same",
                    bias=True,
                )
                self.bn1 = torch.nn.BatchNorm2d(num_features=16)
                self.relu1 = torch.nn.ReLU()
                self.loss_fn = torch.nn.L1Loss()

            def forward(self, x, target):
                y = x
                x = self.conv1(x)
                x = self.bn1(x)
                x = self.relu1(x)
                x = x + y
                x = torch.flatten(x)
                output = self.loss_fn(x, target)
                return (output,)

        def get_triton_codegen(optimized_module, args):
            def run_with_backward():
                result = optimized_module(*args)
                result[0].backward()
                return result

            res, (fwd_code, bwd_code) = run_and_get_code(run_with_backward)
            return fwd_code, bwd_code

        x = torch.rand(100, 16, 32, 32, requires_grad=True, device=GPU_TYPE)
        target = torch.rand(1, device=GPU_TYPE)
        args = [x, target]
        model = Model().to(device=GPU_TYPE)
        opt_model = torch.compile(model)
        fwd_code, bwd_code = get_triton_codegen(opt_model, args)

        bwd_seq_nr_set = set()
        fwd_seq_nr_set = set()
        for idx, code in enumerate([fwd_code, bwd_code]):
            seq_nr_set = bwd_seq_nr_set if idx > 0 else fwd_seq_nr_set
            prefix = "BWD" if idx > 0 else "FWD"
            for line in code.split("\n"):
                if "seq_nr" in line:
                    res = re.search(r"seq_nr:(\d+)", line)
                    if res:
                        seq_nr_set.add(int(res.group(1)))
        self.assertTrue(bwd_seq_nr_set.issubset(fwd_seq_nr_set))

    @config.patch(
        {
            "coordinate_descent_tuning": True,
            "triton.unique_kernel_names": True,
            "benchmark_kernel": True,
        }
    )
    @skipIfRocm
    @expectedFailureXPU
    @unittest.skipIf(
        torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0),
        "Triton does not support fp8 on A100",
    )
    def test_red_followed_by_transposed_pointwise(self):
        bs = 26624
        dim = 1024

        @torch.compile(dynamic=False)
        def f(in1, in2, a, b):
            out = torch.nn.functional.silu(in1) * in2
            out_row = (out / out.amax(dim=1, keepdim=True)).to(torch.float8_e4m3fn)
            out_col = (out / out.amax(dim=0, keepdim=True)).to(torch.float8_e4m3fn)

            # setup strides for _scaled_mm
            out_row = out_row.contiguous()
            out_col = out_col.t().contiguous().t()

            return (
                torch._scaled_mm(out_row, a, out_dtype=torch.bfloat16)[0],
                torch._scaled_mm(b, out_col, out_dtype=torch.bfloat16)[0],
            )

        in1 = torch.randn((bs, dim), dtype=torch.bfloat16, device=GPU_TYPE)
        in2 = torch.randn((bs, dim), dtype=torch.bfloat16, device=GPU_TYPE)
        a = (
            torch.randn((dim, dim), dtype=torch.bfloat16, device=GPU_TYPE)
            .t()
            .to(torch.float8_e4m3fn)
        )
        b = torch.randn((dim, bs), dtype=torch.bfloat16, device=GPU_TYPE).to(
            torch.float8_e4m3fn
        )

        # warmup
        _, (wrapper,) = run_and_get_code(f, in1, in2, a, b)

        # Previously, inductor decided the reduction hint for a reduction kernel
        # without considering the pointwise nodes. That caused the third reduction
        # kernel in this wrapper to be a persistent inner reduction and hurt perf.
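        # (A persistent reduction keeps the whole reduction dimension in registers
        # and reduces it in a single pass, which is presumably the wrong trade-off
        # here given how large that dimension is.)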
11313 # 11314 # We fix that by making the third reduction a non-persistent reduction 11315 # and improve the perf by 4.14x (451us -> 109us) 11316 self.assertEqual(3, wrapper.count("def triton_red_")) 11317 self.assertEqual(0, wrapper.count("def triton_per_")) 11318 11319 if DO_PERF_TEST: 11320 with torch.profiler.profile( 11321 activities=[torch.profiler.ProfilerActivity.CUDA] 11322 ) as p: 11323 for _ in range(1000): 11324 f(in1, in2, a, b) 11325 11326 print(p.key_averages().table(max_name_column_width=200)) 11327 11328 class RNNTest(TestCase): 11329 device_type = GPU_TYPE 11330 11331 class Model(torch.nn.Module): 11332 def __init__(self): 11333 super().__init__() 11334 self.gru = torch.nn.GRU(16, 16, batch_first=True) 11335 11336 def forward(self, x): 11337 return self.gru(x) 11338 11339 @expectedFailureXPU 11340 def test_rnn_compile_safe(self): 11341 device = torch.device(GPU_TYPE) 11342 model = RNNTest.Model().to(device) 11343 model = torch._dynamo.optimize("inductor")(model) 11344 x = torch.rand(1024, 20, 16).to(device) 11345 model(x) 11346 11347 class NanCheckerTest(TestCase): 11348 @config.patch("nan_asserts", True) 11349 def test_nan_checker_pass(self): 11350 def f(x): 11351 return torch.softmax(x, dim=-1) 11352 11353 x = torch.randn(2, 1024, device=GPU_TYPE) 11354 ref = f(x) 11355 actual, (code,) = run_and_get_code(torch.compile(f), x) 11356 self.assertTrue(torch.allclose(ref, actual)) 11357 self.assertTrue("# make sure graph inputs are not nan/inf" in code) 11358 self.assertTrue( 11359 re.search(r"assert not .*\.isnan\(\)\.any\(\).item\(\)", code) 11360 is not None 11361 ) 11362 self.assertTrue( 11363 re.search(r"assert not .*\.isinf\(\)\.any\(\).item\(\)", code) 11364 is not None 11365 ) 11366 11367 @config.patch("nan_asserts", True) 11368 def test_nan_checker_fail(self): 11369 def f(x): 11370 return torch.softmax(x, dim=-1) 11371 11372 x = torch.randn(2, 1024, device=GPU_TYPE) 11373 x[0, 0] = float("nan") 11374 with self.assertRaises(AssertionError): 11375 torch.compile(f)(x) 11376 11377 11378if HAS_CPU: 11379 11380 class TestFull(TestCase): 11381 def test_full_dtype(self): 11382 pytypes = ( 11383 bool, 11384 int, 11385 float, 11386 # TODO: Triton's JITFunction._type_of has no support for complex 11387 # complex, 11388 ) 11389 11390 dtypes = ( 11391 torch.bool, 11392 torch.int32, 11393 torch.int64, 11394 torch.float32, 11395 torch.float64, 11396 None, 11397 # torch.complex64, 11398 # torch.complex128, 11399 ) 11400 11401 def fn(pytype, dtype): 11402 if pytype is bool: 11403 fill_value = True 11404 elif pytype is int: 11405 fill_value = 42 11406 elif pytype is float: 11407 fill_value = 42.0 11408 else: 11409 raise AssertionError(f"Unexpected Python type: {pytype}") 11410 11411 return torch.full( 11412 (4, 6), fill_value, dtype=dtype, device=torch.device("cpu") 11413 ) 11414 11415 fn_opt = torch._dynamo.optimize("inductor")(fn) 11416 11417 for pytype, dtype in itertools.product(pytypes, dtypes): 11418 with enable_python_dispatcher(): 11419 with torch.no_grad(): 11420 ret_opt = fn_opt(pytype, dtype) 11421 11422 self.assertEqual(ret_opt, fn(pytype, dtype)) 11423 11424 11425if __name__ == "__main__": 11426 from torch._inductor.test_case import run_tests 11427 11428 if HAS_CPU or HAS_GPU: 11429 run_tests(needs="filelock") 11430