# Owner(s): ["oncall: pt2"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import unittest
import warnings
from contextlib import nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch

from common_utils import decorate, decorateForModules, skip, skipOps, xfail

import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.utils._pytree as pytree
from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
    aot_function,
    aot_module,
    aot_module_simplified,
    compiled_function,
    compiled_module,
    default_decompositions,
    default_partition,
    get_aot_compilation_context,
    make_boxed_compiler,
    make_boxed_func,
    memory_efficient_fusion,
    min_cut_rematerialization_partition,
    nnc_jit,
    nop,
)
from functorch.experimental import control_flow
from torch._decomp import decomposition_table
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    compare_equal_outs_and_grads,
    instantiate_parametrized_tests,
    IS_ARM64,
    IS_MACOS,
    IS_WINDOWS,
    IS_X86,
    outs_and_grads,
    parametrize,
    run_tests,
    skipIfRocm,
    skipIfTorchDynamo,
    TestCase,
    xfail_inherited_tests,
    xfailIfTorchDynamo,
)
from torch.testing._internal.custom_tensor import ConstantExtraMetadataTensor
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.optests import (
    _test_aot_autograd_forwards_backwards_helper,
    aot_autograd_check,
)
from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode


# Optional dependency probes.
#
# Each probe binds a module-level boolean flag recording whether the optional
# package imported. A missing package only emits a UserWarning; any other
# exception raised during the import still propagates and aborts module load.

# torchvision: on success the `torchvision` name is bound at module scope for
# tests that build torchvision models (guarded by USE_TORCHVISION).
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        UserWarning,
    )
else:
    USE_TORCHVISION = True

# networkx: only its availability is recorded here (the name itself is not
# used in this preamble, hence the noqa).
try:
    import networkx  # noqa: F401
except ImportError:
    USE_NETWORKX = False
    warnings.warn("Some tests use networkx but it was not installed", UserWarning)
else:
    USE_NETWORKX = True

# NB: numpy is a testing dependency!
class AOTTestCase(TestCase): pass class TestPythonKey(AOTTestCase): def test_make_fx(self, device): def f(x): return torch.sin(x) inp = torch.randn(3) fx_f = make_fx(f)(inp) new_inp = torch.randn(3) self.assertEqual(fx_f(new_inp), f(new_inp)) def test_make_fx_grad(self, device): def f(x): return torch.sin(x).sum() inp = torch.randn(3) f = grad(f) fx_f = make_fx(f)(inp) new_inp = torch.randn(3) self.assertEqual(fx_f(new_inp), f(new_inp)) def test_scalar_device(self, device): def f(a, b): return a + b inps = [torch.randn(3, device=device), torch.tensor(5)] fx_f = make_fx(f)(*inps) self.assertEqual(fx_f(*inps), f(*inps)) def test_make_fx_vmap(self, device): def f(x): return torch.sin(x) inp = torch.randn(5, 3) f = vmap(f) fx_f = make_fx(f)(inp) new_inp = torch.randn(5, 3) self.assertEqual(fx_f(new_inp), f(new_inp)) def test_make_fx_jacrev(self, device): def f(x): return x.sin().sum() inp = torch.randn(3) f = jacrev(jacrev(f)) fx_f = make_fx(f)(inp) new_inp = torch.randn(3) self.assertEqual(fx_f(new_inp), f(new_inp)) def test_make_fx_vjp(self, device): def f(x): return torch.sin(x).sum() primals = torch.randn(3) _, vjp_fn = vjp(f, primals) cotangent = torch.randn(()) fx_f = make_fx(vjp_fn)(cotangent, True, True) new_cotangent = torch.randn(()) self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent)) def test_make_fx_functionalize(self, device): from functorch.experimental import functionalize def fn(a): a = a * 2 a.relu_() return a a = torch.randn(3, device=device) symbolic_gm = torch.fx.symbolic_trace(fn) includes_method_relu_ = any( str(n.target) == "relu_" for n in symbolic_gm.graph.nodes ) self.assertTrue(includes_method_relu_) # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570 gm = make_fx(functionalize(symbolic_gm))(a) includes_aten_relu = any( n.target == torch.ops.aten.relu.default for n in gm.graph.nodes ) self.assertTrue(includes_aten_relu) def test_make_fx_no_decompose(self, device): # FIXME return self.skipTest("error: 
maximum recursion reached") def f(x): return torch.tanh(x).sum() fx_f = make_fx(grad(f))(torch.randn(5)) ops = {i.target for i in fx_f.graph.nodes} self.assertEqual(torch.ops.aten.tanh_backward in ops, True) fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5)) ops = {i.target for i in fx_f.graph.nodes} self.assertEqual(torch.ops.aten.tanh_backward in ops, False) def test_nnc_jit(self, device): def f(x): return torch.sin(x) jit_f = nnc_jit(f) inp = torch.randn(3) self.assertEqual(jit_f(inp), f(inp)) def test_nnc_scalar(self, device): def f(x): return torch.sin(x) jit_f = nnc_jit(f) inp = torch.randn(()) self.assertEqual(jit_f(inp), f(inp)) def test_nnc_pytrees(self, device): def f(x): return [torch.sin(x[0])] jit_f = nnc_jit(f) inp = [torch.randn(3)] self.assertEqual(jit_f(inp), f(inp)) def test_external_calls(self, device): def f(a, b): return torch.mv(a, b) jit_f = nnc_jit(f) inp = [torch.randn(3, 3), torch.randn(3)] self.assertEqual(jit_f(*inp), f(*inp)) def test_nnc_passthrough(self, device): def f(x, y): return x + y, y inp = (torch.randn(3), torch.randn(3)) jit_f = nnc_jit(f) self.assertEqual(jit_f(*inp), f(*inp)) def f(x): x["a"] = x["a"] * 2 return x inp = ({"a": torch.randn(3), "b": torch.randn(3)},) jit_f = nnc_jit(f) self.assertEqual(jit_f(*inp), f(*inp)) @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") def test_resnet18_backward_trace(self, device): mod = torchvision.models.resnet18() def f(x): out = mod(x) out.sum().backward() return [a.grad for a in mod.parameters()] inp = torch.randn(3, 3, 250, 250, requires_grad=True) grads = f(inp) mod.zero_grad() mod(inp).sum().backward() grads2 = [a.grad for a in mod.parameters()] self.assertEqual(grads, grads2) def get_base(t): return t._base if t._is_view() else t def is_in_base(t, maybe_tensors): t_base = get_base(t) for maybe_tensor in maybe_tensors: if isinstance(maybe_tensor, torch.Tensor): if t_base is get_base(maybe_tensor): return True return False def skipIfDynamoInput(reason): 
""" Skip TestAOTAutograd if running with dynamo input """ def decorator(func): @wraps(func) def wrapper(self, *args, **kwargs): if isinstance(self, TestAOTAutogradWithDynamo): self.skipTest( f"Skipping {self._testMethodName} in TestAOTAutogradWithDynamo because {reason}" ) else: func(self, *args, **kwargs) return wrapper return decorator class TestAOTAutograd(AOTTestCase): def run_autograd( self, f: Callable, fw_graph_cell: List[Optional[Callable]], decompositions: Optional[Dict], keep_input_mutations: bool, dynamic: bool, ): """ Runs aot_autograd with the specified settings on f. """ if isinstance(f, nn.Module): compiled_f = aot_module( f, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=nop, decompositions=decompositions, keep_inference_input_mutations=keep_input_mutations, dynamic=dynamic, ) else: compiled_f = aot_function( f, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=nop, decompositions=decompositions, keep_inference_input_mutations=keep_input_mutations, dynamic=dynamic, ) return compiled_f # test_mutation will: # - Ensure that inputs are non-leaves, so our graphs can mutate them # - try to mutate outputs of the graph (to ensure that autograd meta is set properly on outputs) @patch("functorch.compile.config.debug_assert", True) def verify_aot_autograd( self, f, inp_: Union[Callable, List[Any]], *, test_mutation: bool = False, keep_inp_mutations: bool = False, decompositions: Optional[Dict] = None, dynamic: bool = False, # Only active when inp_ is Callable. # TODO: probably consolidate all tests to make inp a Callable. 
make_inputs_subclasses: bool = False, ): def make_inputs(inp_): # Some tests pass in a callable for inp, to generate the inputs # (useful if we want to generate complicated aliasing inputs) if isinstance(inp_, Callable): inp_callable = inp_ # The callable should return a tuple of f_inputs, f_graph_inputs # (The idea is that we might want to compile a function with the graph inputs, # but test autograd backprop all the way through the actual inputs) with TwoTensorMode() if make_inputs_subclasses else nullcontext(): inp, graph_inps = inp_callable() else: inp = [] # Our input clones need to mimic when inputs are duplicates of one another dupes_map = {} for i, x in enumerate(inp_): if x in dupes_map: x_dupe_idx = dupes_map[x] inp.append(inp[x_dupe_idx]) else: dupes_map[x] = i if not isinstance(x, torch.Tensor): x_copy = x else: x_copy = x.clone().detach().requires_grad_(x.requires_grad) if x.requires_grad and not x.is_leaf: x_copy = x_copy.clone() inp.append(x_copy) if test_mutation: # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves graph_inps = [x.add(1) for x in inp] else: graph_inps = inp return inp, graph_inps def check_results( ref_results, test_results, ref_graph_inps, test_graph_inps, ref_inp, test_inp, ): ref_out, ref_grad = ref_results test_out, test_grad = test_results self.assertEqual(ref_grad, test_grad) if isinstance(ref_out, torch.Tensor): self.assertTrue(isinstance(test_out, torch.Tensor)) ref_out, test_out = [ref_out], [test_out] for ref_o, test_o in zip(ref_out, test_out): if isinstance(ref_o, torch.Tensor): self.assertEqual(ref_o.requires_grad, test_o.requires_grad) self.assertEqual(ref_o.is_leaf, test_o.is_leaf) ref_is_view_of_non_interm = is_in_base( ref_o, ref_graph_inps ) or is_in_base(ref_o, ref_out) test_is_view_of_non_interm = is_in_base( test_o, test_graph_inps ) or is_in_base(test_o, test_out) self.assertEqual( ref_is_view_of_non_interm, test_is_view_of_non_interm ) self.assertEqual(ref_o, test_o) if 
test_mutation: # This tests that autograd meta is set properly on the output we can # mutate it. ref_o.add_(2) test_o.add_(2) self.assertEqual(ref_o, test_o) # Reverse the modification ref_o.sub_(2) test_o.sub_(2) self.assertEqual(ref_o, test_o) for ref_i, test_i in zip(ref_inp, test_inp): if isinstance(ref_i, torch.Tensor): self.assertEqual(ref_i.requires_grad, test_i.requires_grad) self.assertEqual(ref_i, test_i) for keep_input_mutations in [True] if keep_inp_mutations else [True, False]: inp, graph_inps = make_inputs(inp_) test_inp, test_graph_inps = make_inputs(inp_) fw_graph_cell = [None] compiled_f = self.run_autograd( f, fw_graph_cell, decompositions, keep_input_mutations, dynamic ) ref_results = outs_and_grads(f, graph_inps, inp) test_results = outs_and_grads(compiled_f, test_graph_inps, test_inp) check_results( ref_results, test_results, graph_inps, test_graph_inps, inp, test_inp ) if isinstance(self, TestAOTAutogradWithCache): # When testing with cache, run compiled_f a second time cached_inp, cached_graph_inps = make_inputs(inp_) cached_results = outs_and_grads( compiled_f, cached_graph_inps, cached_inp ) check_results( ref_results, cached_results, graph_inps, cached_graph_inps, inp, cached_inp, ) return fw_graph_cell[0] def test_non_tensor_and_none_inputs(self): # int, None, Tensor def f(a, b, c): return a * c inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=True)] self.verify_aot_autograd(f, inp) inp = [2, None, torch.ones(3, 3, dtype=torch.float32, requires_grad=False)] self.verify_aot_autograd(f, inp) def test_single_output(self): def f(a, b): return a + b inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) def test_multi_output(self): def f(a, b): return a + b, a - b inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) inp = [torch.randn(3, 3, 
requires_grad=False), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) def test_multi_output_list(self): def f(a, b): return [a + b, a - b] inp = [torch.randn(3, 3, requires_grad=True), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) inp = [torch.randn(3, 3, requires_grad=False), torch.randn(3, 3)] self.verify_aot_autograd(f, inp) # Test for bug occurring at the intersection of fake tensors & functionalization. def test_squeeze_mutation(self): def f(a): b = a.clone().squeeze(-1) b.add_(1.0) return a + b inp = [torch.randn(3, 1, requires_grad=True)] self.verify_aot_autograd(f, inp, dynamic=True) inp = [torch.randn(3, 1, requires_grad=False)] self.verify_aot_autograd(f, inp, dynamic=True) def test_complex_linear(self): # https://github.com/pytorch/pytorch/issues/93424 inp = [torch.randn(1, 10, 10, dtype=torch.complex64)] class F(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = nn.Linear(10, 10, dtype=torch.complex64) def forward(self, x): return self.linear(x).sum().abs() self.verify_aot_autograd(F(), inp) def test_embedding_bag_view_dynamic(self): # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper; # test that this works even though the sparse tensor has no storage. 
class F(torch.nn.Module): def __init__(self) -> None: super().__init__() self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True) def forward(self, x, y): return self.emb(x, y).view(-1) x = torch.arange(3) y = torch.arange(3) self.verify_aot_autograd(F(), [x, y], dynamic=False) self.verify_aot_autograd(F(), [x, y], dynamic=True) def test_input_mutation_simple(self): def f(a): a.mul_(2) return a * 3 inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) # Things to note: # - the extra clone is because we need to pass the pre-mutated input to grad(), # but autograd operates above functionalization so we need to manually clone. # Hopefully backends can optimize this easily. # - The extra return arg is because the compiled forward returns (mutated inputs + outputs) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None mul_1 = torch.ops.aten.mul.Tensor(mul, 3) return (mul, mul_1)""", ) def test_input_mutation_set__input_mutation(self): def f(a): b = torch.arange(9, dtype=a.dtype).reshape(3, 3) with torch.no_grad(): a.set_(b) return a * b inp = [torch.ones(3, 3, requires_grad=True)] self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) def test_set__steals_view_chain(self): def f(a, b): a_ = a.mul(2) b_ = b.mul(2) b_slice = b_[1].view(3, 3) # a_clone should inherit the view chain from b_slice a_.set_(b_slice) # Also mutates b_, a_.view(-1).mul_(2) return a_ * b_slice inp = [ torch.ones(3, 3, requires_grad=False), torch.zeros(3, 9, requires_grad=False), ] self.verify_aot_autograd(f, inp, keep_inp_mutations=True) 
@skipIfDynamoInput( "Test doesn't make sense with dynamo, which changes order of mutations" ) def test_set__and_data_mutation_good(self): def f(a, b): # The data mutation happens *after* the set_(). This is ok (see the graph below) with torch.no_grad(): a.set_(b) b.mul_(2) return a + b inp = [ torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True), ] fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) inp = [ torch.ones(3, 3, requires_grad=False), torch.zeros(3, 3, requires_grad=False), ] self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) # Important things to note: # - "return a.set_(b)" desugars into "return b" # - Both a and b are recorded as experiencing mutations, # which is why we see "b_updated" (output of the mul) twice in the graph outputs. # a is recorded as both a data mutation and a metadata mutation (due to set_ swapping its storage). # - the runtime epilogue for a is "a.set_(mul)" # - the runtime epilogue for b is "b.copy_(mul)" self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): mul = torch.ops.aten.mul.Tensor(primals_2, 2) add = torch.ops.aten.add.Tensor(mul, mul) set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = set_ = None copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = copy_ = None return (add,)""", ) # This is a (hopefully) extremely rare case that is difficult to handle, # so we ban it. # https://github.com/pytorch/pytorch/issues/126236 # https://github.com/pytorch/pytorch/pull/126113 @xfailIfTorchDynamo def test_set__and_data_mutation_bad(self): def f(a): a_view = a.view(-1) tmp = torch.ones(3, 3, requires_grad=True) # Now, any mutations on either tmp # will be tracked as graph input mutations. with torch.no_grad(): a.set_(tmp) # BAD: a_view is now detached from every graph input, # so we won't recognize that this caused an input mutation! 
a_view.mul_(2) return a + tmp inp = [torch.ones(3, 3, requires_grad=True)] with self.assertRaisesRegex( RuntimeError, "cannot mutate tensors with frozen storage" ): self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) @skipIfDynamoInput( "Test doesn't make sense with dynamo, which changes order of mutations" ) def test_set__not_allowed(self): def f(a, b): with torch.no_grad(): a.set_(b) # Mutating a will change a's grad_fn, which requires us to replay the mutation outside of the graph. # We currently ban this today, when the input also received a set_() input mutation. a.mul_(2) return a + b inp = [ torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True), ] with self.assertRaisesRegex( AssertionError, "but the input has other mutations that we cannot" ): fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) def test_input_mutation_set__nop(self): def f(a): b = torch.arange(9, dtype=a.dtype) a_old = torch.ops.aten.alias.default(a) with torch.no_grad(): a.set_(b) a.set_(a_old) return a + b.reshape(3, 3) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) # Things to note: # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b") # - There is only **1** graph output. We properly realized that the two set_() calls # undo each other, and so effectively no inputs are mutated. 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) alias = torch.ops.aten.alias.default(primals_1); primals_1 = None view = torch.ops.aten.view.default(arange, [3, 3]); arange = None add = torch.ops.aten.add.Tensor(alias, view); alias = view = None return (add,)""", ) @unittest.skipIf(IS_WINDOWS, "TODO: need to fix the test case") @unittest.skipIf(IS_MACOS, "TODO: need to fix the test case") def test_input_mutation_fsdp_set__into_same_input(self): import torch.distributed._composable.fsdp._fsdp_param def f(a): b = torch.arange(9, dtype=a.dtype).view(3, 3) c = torch.arange(9, dtype=a.dtype).view(3, 3) d = torch.arange(9, dtype=a.dtype).view(3, 3) with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): torch.ops.fsdp.set_.default(a, b) x = a * a with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): torch.ops.fsdp.set_.default(a, c) y = a * a with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(a): torch.ops.fsdp.set_.default(a, c) z = a * a return x + y + z inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True) """ Expected behavior: (1) When there are multiple set_() calls on the same graph input primal_X, we want those set_() calls to all show up with primal_X as the first arg in the graph. (2) Behavior (1) is not the case today with normal aten.set_ (blocked on #129892), but using a custom fsdp.set_ op with no returns is a simple workaround to achieve that behavior. 
""" self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) view = torch.ops.aten.view.default(arange, [3, 3]); arange = None arange_1 = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) view_1 = torch.ops.aten.view.default(arange_1, [3, 3]); arange_1 = None set_ = torch.ops.fsdp.set_.default(primals_1, view); view = set_ = None mul = torch.ops.aten.mul.Tensor(primals_1, primals_1) set__1 = torch.ops.fsdp.set_.default(primals_1, view_1); set__1 = None mul_1 = torch.ops.aten.mul.Tensor(primals_1, primals_1) set__2 = torch.ops.fsdp.set_.default(primals_1, view_1); view_1 = set__2 = None mul_2 = torch.ops.aten.mul.Tensor(primals_1, primals_1) add = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None add_1 = torch.ops.aten.add.Tensor(add, mul_2); add = mul_2 = None return (add_1, primals_1)""", ) self.assertEqual(torch.compile(f, backend="inductor")(*inp), f(*inp)) def test_input_mutation_simple_with_none_and_nontensor(self): # Tensor, None, int def f(a, b, c): return a * c f_compiled = aot_function(f, nop) for req_grad in [True, False]: inp = [torch.ones(3, 3, requires_grad=req_grad), None, 3] out_ref = f(*inp) out_test = f_compiled(*inp) self.assertEqual(out_ref, out_test) # https://github.com/pytorch/pytorch/issues/93363 def test_mutates_input_noncontiguous(self): def f(a): a.add_(1) return () f_compiled = aot_function(f, nop) ref = torch.ones(4, requires_grad=True) + 0 ref_view = ref[0::2] test = torch.ones(4, requires_grad=True) + 0 test_view = test[0::2] out_ref = f(ref_view) out_test = f_compiled(test_view) self.assertEqual(ref, test) def test_input_mutation_modifies_autograd_meta_of_aliases(self): def f(a): a.mul_(2) out = a + 1 return out.detach() x_ref = torch.ones(3, 3, requires_grad=True).clone() x_ref_view = x_ref.view(3, 3) x_test = torch.ones(3, 3, 
requires_grad=True).clone() x_test_view = x_test.view(3, 3) f_compiled = aot_function(f, nop, keep_inference_input_mutations=True) f(x_ref) f_compiled(x_test) # f will mutate aliases of the input, including its autograd metadata! # y.grad_fn is AsStridedBackward self.assertEqual(x_ref_view, x_test_view) self.assertEqual(x_ref_view._version, x_test_view._version) self.assertEqual(x_ref_view.grad_fn.__class__, x_test_view.grad_fn.__class__) # Test the actual gradients are correct (x_ref * x_ref_view).sum().backward() (x_test * x_test_view).sum().backward() self.assertEqual(x_ref.grad, x_test.grad) self.assertEqual(x_ref_view.grad, x_test_view.grad) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") def test_nested_subclasses(self): @torch.compile(backend="aot_eager") def f(x): return x.sin().cos() a = torch.ones(4, requires_grad=True) a2 = a.clone().detach().requires_grad_() a3 = a.clone().detach().requires_grad_() a4 = a.clone().detach().requires_grad_() aa = TwoTensor(a, a2) aa2 = TwoTensor(a3, a4) aaaa = TwoTensor(aa, aa2) out = f(aaaa) self.assertTrue(isinstance(out, TwoTensor)) self.assertTrue(isinstance(out.a, TwoTensor)) self.assertTrue(isinstance(out.b, TwoTensor)) self.assertTrue(isinstance(out.a.a, torch.Tensor)) self.assertTrue(isinstance(out.a.b, torch.Tensor)) self.assertTrue(isinstance(out.b.a, torch.Tensor)) self.assertTrue(isinstance(out.b.b, torch.Tensor)) out.sum().backward() self.assertTrue(isinstance(aaaa.grad, TwoTensor)) self.assertTrue(isinstance(aaaa.grad.a, TwoTensor)) self.assertTrue(isinstance(aaaa.grad.b, TwoTensor)) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") def test_nested_subclasses_non_nested_grad(self): @torch.compile(backend="aot_eager") def f(x): return x.sin().cos() a = torch.ones(4, requires_grad=True) a2 = a.clone().detach().requires_grad_() a3 = a.clone().detach().requires_grad_() a4 = a.clone().detach().requires_grad_() new_aa = TwoTensor(a3, a4) aa = TwoTensor(a, a2) aa2 = 
aa.clone().detach().requires_grad_() aaaa = TwoTensor(aa, aa2) out = f(new_aa) new_out = out + aaaa with self.assertRaisesRegex( RuntimeError, "The grad inputs should be same tensor subclass type as forward output", ): new_out.sum().backward() @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") def test_custom_tensor_metadata(self): def f(x): x_elem = x.elem x_elem_elem = x_elem.elem x_elem_metadata = x_elem.constant_attribute return x * x_elem * x_elem_elem * x_elem_metadata a = torch.ones(4, requires_grad=True) custom_a = ConstantExtraMetadataTensor(a) custom_a.constant_attribute = 6 custom_aa = ConstantExtraMetadataTensor(custom_a) custom_aa.constant_attribute = 4 custom_aa_compile = custom_aa.clone().detach().requires_grad_() custom_aa_compile.elem.constant_attribute = 6 out_eager = f(custom_aa) compiled_f = torch.compile(f, backend="aot_eager") out = compiled_f(custom_aa_compile) self.assertTrue(torch.allclose(out_eager, out)) out.sum().backward() self.assertTrue(isinstance(custom_aa_compile.grad, ConstantExtraMetadataTensor)) self.assertTrue( isinstance(custom_aa_compile.grad.elem, ConstantExtraMetadataTensor) ) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") def test_nested_subclasses_complicated_inps(self): def f(x, y, z): temp = x + y temp_plain = x.a + y.b res = temp.sum() + temp_plain.sum() return x.sin().cos() + res x = torch.ones(4, requires_grad=True) x2 = x.clone().detach().requires_grad_() xx = TwoTensor(x, x2) xx2 = xx.clone().detach().requires_grad_() x_nested = TwoTensor(xx, xx2) x_nested_compile = x_nested.clone().detach().requires_grad_() y_nested = x_nested.clone().detach().requires_grad_() y_nested_compile = y_nested.clone().detach().requires_grad_() z = x.clone().detach().requires_grad_() z_compile = z.clone().detach().requires_grad_() out_eager = f(x_nested, y_nested, z) compiled_f = torch.compile(f, backend="aot_eager") out = 
compiled_f(x_nested_compile, y_nested_compile, z_compile) self.assertTrue(torch.allclose(out_eager, out)) self.assertTrue(isinstance(out, TwoTensor)) self.assertTrue(isinstance(out.a, TwoTensor)) self.assertTrue(isinstance(out.b, TwoTensor)) self.assertTrue(isinstance(out.a.a, torch.Tensor)) self.assertTrue(isinstance(out.a.b, torch.Tensor)) self.assertTrue(isinstance(out.b.a, torch.Tensor)) self.assertTrue(isinstance(out.b.b, torch.Tensor)) out.sum().backward() out_eager.sum().backward() self.assertTrue(isinstance(x_nested_compile.grad, TwoTensor)) self.assertTrue(isinstance(x_nested_compile.grad.a, TwoTensor)) self.assertTrue(isinstance(x_nested_compile.grad.b, TwoTensor)) self.assertTrue(isinstance(y_nested_compile.grad, TwoTensor)) self.assertTrue(isinstance(y_nested_compile.grad.a, TwoTensor)) self.assertTrue(isinstance(y_nested_compile.grad.b, TwoTensor)) self.assertTrue(torch.allclose(x_nested_compile.grad.a.a, x_nested.grad.a.a)) self.assertTrue(torch.allclose(x_nested_compile.grad.a.b, x_nested.grad.a.b)) self.assertTrue(torch.allclose(y_nested_compile.grad.a.a, y_nested.grad.a.a)) self.assertTrue(torch.allclose(y_nested_compile.grad.a.b, y_nested.grad.a.b)) @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127470") def test_nested_subclasses_complicated_inps_mixed(self): def f(x, y): y_elem = y.elem y_elem_elem = y_elem.elem y_elem_metadata = y_elem.constant_attribute return y * y_elem * y_elem_elem * y_elem_metadata + x x = torch.ones(4, requires_grad=True) x2 = x.clone().detach().requires_grad_() xx = TwoTensor(x, x2) xx2 = xx.clone().detach().requires_grad_() x_nested = TwoTensor(xx, xx2) x_nested_compile = x_nested.clone().detach().requires_grad_() a = torch.ones(4, requires_grad=True) custom_a = ConstantExtraMetadataTensor(a) custom_a.constant_attribute = 6 custom_aa = ConstantExtraMetadataTensor(custom_a) custom_aa.constant_attribute = 4 custom_aa_compile = 
custom_aa.clone().detach().requires_grad_() custom_aa_compile.constant_attribute = 4 custom_aa_compile.elem.constant_attribute = 6 compiled_f = torch.compile(f, backend="aot_eager") out_eager = f(x_nested, custom_aa) out = compiled_f(x_nested_compile, custom_aa_compile) self.assertTrue(torch.allclose(out_eager, out)) out.sum().backward() out_eager.sum().backward() self.assertTrue(torch.allclose(x_nested_compile.grad, x_nested.grad)) self.assertTrue(torch.allclose(custom_aa_compile.grad, custom_aa.grad)) @skipIfTorchDynamo("This test suite already uses dynamo") def test_composite_impl_compile(self): class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(3, 3) def forward(self, a): return self.linear(a) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=False)] self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2, primals_3): t = torch.ops.aten.t.default(primals_1); primals_1 = None addmm = torch.ops.aten.addmm.default(primals_2, primals_3, t); primals_2 = None return (addmm, primals_3, t)""", ) with torch.inference_mode(): fw_graph = self.verify_aot_autograd(Foo(), inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=False)] self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1): t = torch.ops.aten.t.default(arg0_1); arg0_1 = None addmm = torch.ops.aten.addmm.default(arg1_1, arg2_1, t); arg1_1 = arg2_1 = t = None return (addmm,)""", ) def test_outputs_are_aliased(self): # Tensor, None, int def f(a): b = a.mul(2) c = b.view(-1) return b, c f_compiled = aot_function(f, nop) for req_grad in [True, False]: inp = torch.ones(3, requires_grad=req_grad) out_ref = f(inp) out_test = f_compiled(inp) self.assertEqual(out_ref[0], out_test[0]) self.assertEqual(out_ref[1], out_test[1]) # Try mutating one of the outputs, which is aliased. 
out_ref[0].mul_(3) out_test[0].mul_(3) # Assert that the aliasing relationship was preserved self.assertEqual(out_ref[0], out_test[0]) self.assertEqual(out_ref[1], out_test[1]) def test_input_mutation_is_output(self): def f(a): a.mul_(2) return a inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None return (mul, mul)""", ) def test_input_mutation_multiple(self): def f(a, b, c): a.mul_(2) c.mul_(2) return a + b + c def create_inp(req_grad): return [ torch.ones(3, 3, requires_grad=req_grad), torch.ones(3, 3, requires_grad=req_grad), torch.ones(3, 3, requires_grad=req_grad), ] self.verify_aot_autograd(f, create_inp(False), test_mutation=True) fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2, primals_3): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None return (mul, mul_1, add_1)""", ) def test_input_mutation_return(self): def f(a, b): return torch.sin(a, out=b) inp = [torch.randn(3, 3), torch.ones(3, 3)] fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, arg0_1, arg1_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None copy_ = 
torch.ops.aten.copy_.default(arg1_1, sin); arg1_1 = sin = None return (copy_,)""", ) def test_input_mutation_metadata(self): def f(a, b): a.transpose_(1, 0) return a + b def create_inp(req_grad): return [ torch.ones(3, 3, requires_grad=req_grad), torch.ones(3, 3, requires_grad=req_grad), ] self.verify_aot_autograd(f, create_inp(True), test_mutation=True) self.verify_aot_autograd(f, create_inp(False), test_mutation=True) def test_input_mutation_storage_resize_up(self): def f(a): torch.ops.inductor.resize_storage_bytes_(a, 32) # float32, 4 bytes per element, 32 bytes == 8 elements with torch.no_grad(): a.copy_(torch.ones(8)) return a + 1 inp = torch.zeros(8, requires_grad=True) # Input starts with zero-size-storage inp.untyped_storage().resize_(0) fw_graph_cell = [None] compiled_f = aot_function( f, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=nop, decompositions={}, keep_inference_input_mutations=True, dynamic=False, ) out = compiled_f(inp) # Final functionalized graph has two mutation ops: # (1) a resize_() to resize input tensor up # (2) a copy_() to fill in the resized input with valid data self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1): resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32); resize_storage_bytes_ = None ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False) copy = torch.ops.aten.copy.default(primals_1, ones); ones = None add = torch.ops.aten.add.Tensor(copy, 1) copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = copy_ = None return (add,)""", ) def test_input_mutation_storage_resize_down(self): def f(a): out = a.sin() torch.ops.inductor.resize_storage_bytes_(a, 0) return out inp = torch.zeros(8, requires_grad=True) fw_graph_cell = [None] compiled_f = aot_function( f, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), 
bw_compiler=nop, decompositions={}, keep_inference_input_mutations=True, dynamic=False, ) out = compiled_f(inp) # Final functionalized graph has one mutation ops: # (1) a resize_() to resize input tensor down # Even though there was technically a "data mutation" on the input (from a.copy_()), # We don't include it in the graph since the final input size has zero storage self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1): sin = torch.ops.aten.sin.default(primals_1) resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0); resize_storage_bytes_ = None return (sin, primals_1)""", ) # def test_input_mutation_storage_resize_up_down(self): # def f(a): # torch.ops.inductor.resize_storage_bytes_(a, 32) # # float32, 4 bytes per element, 32 bytes == 8 elements # with torch.no_grad(): # a.copy_(torch.ones(8)) # out = a.sin() # torch.ops.inductor.resize_storage_bytes_(a, 0) # return out # inp = torch.zeros(8, requires_grad=True) # # Input starts with zero-size-storage # inp.untyped_storage().resize_(0) # fw_graph_cell = [None] # compiled_f = aot_function( # f, # fw_compiler=make_boxed_compiler( # partial(extract_graph, graph_cell=fw_graph_cell) # ), # bw_compiler=nop, # decompositions={}, # keep_inference_input_mutations=True, # dynamic=False, # ) # out = compiled_f(inp) # # Final graph has two interesting properties: # # (1) no resizes in the functional graph, since the two resizes cancel out # # and the final size is zero # # (2) no copy_ in the functional graph, even though we copied data into the input, # # because the input has no storage at the end of graph execution (so no data to copy) # self.assertExpectedInline( # fw_graph_cell[0].code.strip(), # """\ # def forward(self, primals_1): # ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False) # copy = torch.ops.aten.copy.default(primals_1, ones); primals_1 = ones = None # sin = torch.ops.aten.sin.default(copy) # 
return [sin, copy]""", # ) def test_input_mutation_storage_resize_down_and_set_(self): # Meant to mimic ppFSDP class TracableCreateParameter(torch.autograd.Function): @staticmethod def forward(ctx, tensor, placeholder): assert not tensor.requires_grad return placeholder.set_(tensor) @staticmethod def backward(ctx, grad): return None, grad # grad flows to placeholder def f(dummy_param, param_shard): # simulate allgather with torch.no_grad(): allgather_param = torch.cat([param_shard, param_shard]) # simulate propagating grad state through dummy param, using data of allgather param dummy_param_with_grad_state = TracableCreateParameter.apply( allgather_param, dummy_param ) out = dummy_param.sin() # Resize out dummy param, which now has the allgather data torch.ops.inductor.resize_storage_bytes_(dummy_param, 0) return out # Simulates the local shard of our param param_shard = torch.zeros(8, requires_grad=True) # The dummy, zero-sized allgathered param that autograd will actually compute gradients on dummy_param = torch.zeros(16, requires_grad=True) dummy_param.untyped_storage().resize_(0) fw_graph_cell = [None] compiled_f = aot_function( f, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=nop, decompositions={}, keep_inference_input_mutations=True, dynamic=False, ) out = compiled_f(dummy_param, param_shard) # Important stuff to point out: # (1) We save cat for backward (input to the sin()). # While the original code was dummy_param.sin(), # dummy_param actually contains the `cat` tensor due to the set_() call # (2) We emit a cat.resize_storage_(0) in the graph. 
# After the set_(), cat is the actually data of dummy_param, which is what we call resize_() on self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1, primals_2): cat = torch.ops.aten.cat.default([primals_2, primals_2]); primals_2 = None sin = torch.ops.aten.sin.default(cat) resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0); resize_storage_bytes_ = None set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = set_ = None return (sin, cat)""", ) def test_input_mutation_storage_resize_before_set_(self): def f(a): with torch.no_grad(): torch.ops.inductor.resize_storage_bytes_(a, 0) a.set_(torch.ones(2)) inp = torch.zeros(8, requires_grad=True) compiled_f = aot_function( f, fw_compiler=nop, bw_compiler=nop, decompositions={}, keep_inference_input_mutations=True, dynamic=False, ) out = compiled_f(inp) # def test_input_mutation_storage_resize_not_supported(self): # def f(a): # a.mul_(2) # torch.ops.inductor.resize_storage_bytes_(a, 0) # return a # inp = torch.zeros(8, requires_grad=True) # with self.assertRaisesRegex( # AssertionError, "the input has other mutations that we cannot" # ): # compiled_f = aot_function( # f, # fw_compiler=nop, # bw_compiler=nop, # decompositions={}, # keep_inference_input_mutations=True, # dynamic=False, # ) # out = compiled_f(inp) def test_input_output_aliase_custom_autograd_function(self): class Foo(torch.autograd.Function): @staticmethod def forward(ctx, x): return x @staticmethod def backward(ctx, gx): return gx * 0.5 def f(x): return Foo.apply(x) inp = [torch.ones(2, 2, requires_grad=True)] self.verify_aot_autograd(f, inp, test_mutation=False) def test_input_mutation_requires_grad_detach(self): # Here, "a" requires grad, and gets mutated, so we append a copy_() to the end of the graph. # Its mutation doesn't take part in autograd though, because we mutated a detach'd view. 
# Need to make sure that this copy_() doesn't error, and doesn't participate in autograd either. def f(a): a.detach().mul_(2) return a + 3 inp = [torch.ones(4, requires_grad=True)] self.verify_aot_autograd(f, inp, test_mutation=False) inp = [torch.ones(4, requires_grad=True)] # test_mutation=True will first do some compute on inp, so it is no longer an autograd leaf # by the time it becomes a graph input. Good to test both cases. self.verify_aot_autograd(f, inp, test_mutation=True) def test_input_mutation_hidden_from_autograd_aliasing(self): def f(a): a_alias = a.view(-1) with torch.no_grad(): a_alias.mul_(2) return a + 1 inp = [torch.ones(4, requires_grad=True)] # The important bit: we detected that the input mutation is safe # to include **inside** the graph, since it was under no_grad # (so all we need to do is use mark_dirty() on the input to bump the VC) fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): view = torch.ops.aten.view.default(primals_1, [-1]) mul = torch.ops.aten.mul.Tensor(view, 2); view = None view_1 = torch.ops.aten.view.default(mul, [4]); mul = None add = torch.ops.aten.add.Tensor(view_1, 1) copy_ = torch.ops.aten.copy_.default(primals_1, view_1); primals_1 = view_1 = copy_ = None return (add,)""", ) def test_input_mutation_requires_grad_no_grad(self): def f(a): with torch.no_grad(): a.mul_(2) return a + 3 inp = [torch.ones(4, requires_grad=True)] fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) # Even though the input requires_grad, we expect the keep the input mutation in the graph # (Even though this is a training graph!) 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 2) add = torch.ops.aten.add.Tensor(mul, 3) copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = copy_ = None return (add,)""", ) def test_input_mutation_requires_grad_no_grad_inference_graph(self): def f(a): with torch.no_grad(): a.mul_(2) return a + 3 inp = [torch.ones(4, requires_grad=True)] # Even though the input requires_grad, we expect the keep the input mutation in the graph fw_graph = self.verify_aot_autograd( f, inp, test_mutation=True, keep_inp_mutations=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(arg0_1, 2) add = torch.ops.aten.add.Tensor(mul, 3) copy_ = torch.ops.aten.copy_.default(arg0_1, mul); arg0_1 = mul = copy_ = None return (add,)""", ) def test_input_mutation_requires_grad_no_grad_detach_mixed(self): # Perform a mix of mutations on a: # 1 normal, 1 in no_grad, 1 on a detach'd tensor. # Only the first should participate in gradient computation. def f(a): a.detach().mul_(2) a.mul_(3) with torch.no_grad(): a.mul_(4) return a + 5 inp = [torch.ones(4, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) def test_input_mutation_metadata2(self): def f(a): a.transpose_(1, 0) a.mul_(2) return a + 1 inp = [torch.ones(3, 3, requires_grad=True)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) def test_input_mutation_batchnorm(self): def f(inpt, weight, bias, running_mean, running_var): # This is additionally a good test, because the input tensors that we mutate # are *also* saved for backwards. # This tests that what we save for the backward is actually cloned inputs, # and not the original inputs that got mutated. 
return torch._native_batch_norm_legit( inpt, weight, bias, running_mean, running_var, True, 0.5, 1e-5 ) def create_inp(req_grad): return [ torch.ones(2, 5, 5, 5, requires_grad=req_grad), torch.ones(5, requires_grad=req_grad), torch.ones(5, requires_grad=req_grad), torch.ones(5), torch.ones(5), ] from torch._decomp import get_decompositions # This simulates what inductor does (running the fw + bw decompositions) decompositions = get_decompositions( [ torch.ops.aten._native_batch_norm_legit_functional, torch.ops.aten.native_batch_norm_backward, ] ) self.verify_aot_autograd( f, create_inp(True), test_mutation=True, decompositions=decompositions ) self.verify_aot_autograd( f, create_inp(False), test_mutation=True, decompositions=decompositions ) def test_batchnorm_inference(self): inp = [ torch.ones(2, 5, 5, 5, requires_grad=True), torch.ones(5, requires_grad=True), torch.ones(5, requires_grad=True), torch.ones(5), torch.ones(5), ] m = torch.nn.BatchNorm2d(4, 4) m.eval() fw_graph_cell = [None] inp = torch.ones(4, 4, 4, 4) fw_graph_cell = [None] compiled_m = aot_module( m, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=nop, keep_inference_input_mutations=True, ) inp = torch.ones(4, 4, 4, 4) with torch.no_grad(): out = compiled_m(inp) # expectation: there are no copy_() calls in the decomposed batch norm when running under training=False (eval mode) code = fw_graph_cell[0].code.strip() self.assertTrue("copy_" not in str(code)) def test_input_output_view_simple(self): def f(a): return a.view(-1) inp = [torch.ones(2, 2, requires_grad=False).add(1)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(2, 2, requires_grad=True).add(1)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) # Outputs that alias inputs are pulled out of the graph entirely, so we don't compile anything here self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, arg0_1): view = torch.ops.aten.view.default(arg0_1, [-1]); 
arg0_1 = None return (view,)""", ) def test_input_output_view_mutate_multiple(self): def f(a, b, c): a.mul_(2) c.mul_(3) return b.view(2, 2), c.view(2, 2) def create_inp(req_grad): return [ torch.ones(2, 2, requires_grad=req_grad).add(1), torch.ones(2, 2, requires_grad=req_grad).add(1), torch.ones(2, 2, requires_grad=req_grad).add(1), ] self.verify_aot_autograd(f, create_inp(False), test_mutation=True) fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) # The original function returned two outputs, both of which aliased inputs. # We expect two outputs in the functional graph, a_updated and c_updated. # The actual aliased outputs themselves aren't in the compiled forward graph; # Instead, they're generated outside of the graph. self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2, primals_3): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None view_2 = torch.ops.aten.view.default(mul_1, [2, 2]) return (mul, mul_1, view, view_2)""", ) def test_input_output_view_metadata_mutate_multiple(self): def f(a, b, c): b.mul_(3) c.t_() return a.view(2, 2), b.view(2, 2), c.view(2, 2) def create_inp(req_grad): return [ torch.ones(2, 2, requires_grad=req_grad).add(1), torch.ones(2, 2, requires_grad=req_grad).add(1), torch.ones(2, 2, requires_grad=req_grad).add(1), ] self.verify_aot_autograd(f, create_inp(False), test_mutation=True) fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) # Important thing to check here: of the three inputs: # Only the b.mul_(3) should show up in the graph (we functionalize it and return it). 
# Everything else that does not show up in the graph includes: # - The metadata mutation on c (we do it outside the graph) # - All 3 original fw outputs, which are aliases of inputs (we regenerate them outside of the graph) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2, primals_3): clone = torch.ops.aten.clone.default(primals_2); primals_2 = None view = torch.ops.aten.view.default(primals_3, [2, 2]); primals_3 = None mul = torch.ops.aten.mul.Tensor(clone, 3); clone = None t = torch.ops.aten.t.default(view); view = None view_1 = torch.ops.aten.view.default(primals_1, [2, 2]); primals_1 = None view_3 = torch.ops.aten.view.default(t, [2, 2]) view_4 = torch.ops.aten.view.default(mul, [2, 2]) return (mul, t, view_1, view_4, view_3)""", ) def test_input_mutation_and_output_view(self): def f(a): a.add_(1) return a.view(-1) inp = [torch.ones(2, 2, requires_grad=False).add(1)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(2, 2, requires_grad=True).add(1)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) # Here, total # of outputs is 1 because: # - num_mutated_inps = 1 (a_updated) # - num_fw_outputs = 0 (the output is an alias of the input, so we move it outside the compiled fw) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None add = torch.ops.aten.add.Tensor(clone, 1); clone = None view_1 = torch.ops.aten.view.default(add, [-1]) return (add, view_1)""", ) def test_input_mutation_output_view_multiple(self): def f(a, b, c, d): b.transpose_(1, 0) c.add_(1) return d + 1, b.diagonal(), a + c def create_inp(req_grad): return [ torch.arange(4, requires_grad=req_grad, dtype=torch.float32) .view(2, 2) .add(1), torch.arange(4, requires_grad=req_grad, dtype=torch.float32) .view(2, 2) .add(1), torch.ones(2, 2, requires_grad=req_grad).add(1), torch.ones(2, 2, requires_grad=req_grad).add(1), ] 
self.verify_aot_autograd(f, create_inp(False), test_mutation=True) fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2, primals_3, primals_4): view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None clone = torch.ops.aten.clone.default(primals_3); primals_3 = None transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None add = torch.ops.aten.add.Tensor(clone, 1); clone = None add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None diagonal = torch.ops.aten.diagonal.default(transpose) add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None return (transpose, add, add_1, diagonal, add_2)""", ) def test_output_aliases_intermediate_single(self): def f(a): out = torch.mul(a, 3) return out.view(-1) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) # In AOTAutograd, we are obligated to make the compiled forward directly return `out`, # and reconstruct `out.view(-1)` as a fresh output. 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None view = torch.ops.aten.view.default(mul, [-1]); mul = None return (view,)""", ) def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self): def f1(a): return list(a.unbind(0)) f1_compiled = aot_function(f1, nop) inp1 = torch.ones(3, 3, requires_grad=True).clone() inp2 = torch.ones(3, 3, requires_grad=True).clone() inp3 = torch.ones(3, 3, requires_grad=True).clone() with self.assertRaisesRegex( RuntimeError, "Such functions do not allow the output views" ): out_test1 = f1_compiled(inp1) # This raises a runtime error from autograd in eager mode out_test1[0].mul_(2) with self.assertRaisesRegex( RuntimeError, "Such functions do not allow the output views" ): out_test2 = f1_compiled(inp2) inp2.mul_(2) # In eager mode, if we mutate a tensor, any multi-output-view aliases # get their grad_fn replaced with error nodes, so accessing grad_fn should error grad_fn = out_test2[0].grad_fn with self.assertRaisesRegex( RuntimeError, "Such functions do not allow the output views" ): out_test3 = f1_compiled(inp3) out_test1[0].detach().mul_(2) # The above case also applies to detached aliases (they turn the multi-output-view # alias's grad_fns into error nodes) grad_fn = out_test2[0].grad_fn def test_output_aliases_input_multi_output_view(self): # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd. def f1(a): return list(a.unbind(0)) inp = torch.ones(3, 3, requires_grad=True) inp_ref = torch.ones(3, 3, requires_grad=True) f1_compiled = aot_function(f1, nop) out_ref = f1(inp_ref) out_test = f1_compiled(inp) # Assert that we get CompiledFunctionBackward in the backward graph, # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] self.assertTrue( all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) ) sum(out_ref).sum().backward() sum(out_test).sum().backward() self.assertEqual(inp_ref.grad, inp.grad) # Several of the outputs are from multi-output views. # However: they are part of the same alias set as "a", and "a.view(out.shape)", # which are both user-visible. # AOTAutograd will not try to be smart here and hide the aliasing relationships from autograd. # Instead, it will perform its "output aliases input" logic, and regenerate all aliases. def f3(a): return *list(a.unbind(0)), a.view(a.shape) inp = torch.ones(3, 3, requires_grad=True) inp_ref = torch.ones(3, 3, requires_grad=True) f3_compiled = aot_function(f3, nop) inp_ref_clone = inp_ref.clone() inp_clone = inp.clone() out_ref = f3(inp_ref_clone) out_test = f3_compiled(inp_clone) self.assertTrue(all("UnbindBackward" in str(o.grad_fn) for o in out_test[:3])) # The last output is not from a multi-output view, so autograd will let us mutate it. out_ref[-1].mul_(2) out_test[-1].mul_(2) # Also mutate the input, which should affect the aliased output. inp_ref_clone.view(-1).mul_(3) inp_clone.view(-1).mul_(3) # Do backward (inp_ref + out_ref[-1]).sum().backward() (inp + out_test[-1]).sum().backward() self.assertEqual(inp_ref.grad, inp.grad) def test_output_aliases_intermediate_multi_output_view(self): # All aliased outs are from multi-output views, so AOTAutograd will hide the aliasing from autograd. def f1(a): out = torch.mul(a, 3) return list(out.unbind(0)) inp = torch.ones(3, 3, requires_grad=True) inp_ref = torch.ones(3, 3, requires_grad=True) f1_compiled = aot_function(f1, nop) out_ref = f1(inp_ref) out_test = f1_compiled(inp) # Assert that we get CompiledFunctionBackward in the backward graph, # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] self.assertTrue( all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) ) sum(out_ref).sum().backward() sum(out_test).sum().backward() self.assertEqual(inp_ref.grad, inp.grad) # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd. def f2(a): out = torch.mul(a, 3) return *list(out.unbind(0)), out inp = torch.ones(3, 3, requires_grad=True) inp_ref = torch.ones(3, 3, requires_grad=True) f2_compiled = aot_function(f2, nop) out_ref = f2(inp_ref) out_test = f2_compiled(inp) # Assert that we get CompiledFunctionBackward in the backward graph, # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. # See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] self.assertTrue( all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) ) # The last output is not from a multi-output view, so autograd will let us mutate it. out_ref[-1].mul_(2) out_test[-1].mul_(2) out_ref[-1].sum().backward() out_test[-1].sum().backward() self.assertEqual(inp_ref.grad, inp.grad) # All aliased outs but one are from multi-output views, so AOTAutograd will hide the aliasing from autograd. def f3(a): out = torch.mul(a, 3) return *list(out.unbind(0)), out.view(out.shape) inp = torch.ones(3, 3, requires_grad=True) inp_ref = torch.ones(3, 3, requires_grad=True) f3_compiled = aot_function(f3, nop) out_ref = f3(inp_ref) out_test = f3_compiled(inp) # Assert that we get CompiledFunctionBackward in the backward graph, # and not AsStridedBackward. No view-regeneration necessary for this mult-output view case. 
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call] self.assertTrue( all("CompiledFunctionBackward" in str(o.grad_fn) for o in out_test) ) # The last output is not from a multi-output view, so autograd will let us mutate it. out_ref[-1].mul_(2) out_test[-1].mul_(2) out_ref[-1].sum().backward() out_test[-1].sum().backward() self.assertEqual(inp_ref.grad, inp.grad) # There are 5 outputs that all alias each other. # 3 of them come from multi-output views, but the other 3 are "ordinary" aliases. # Therefore, AOTAutograd will not attempt the multi-output-view optimization, # and apply the intermediate_base logic to all aliases. # (In theory we could probably get AOTAutograd to only apply the intermediate base # logic to the last 2 outputs and not the first 3. We should probably # just do the graph partitioning defined in this doc instead though). # https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit def f4(a): out = torch.mul(a, 3) # also return the graph intermediate directly, # which will force AOTAutograd to do the "intermediate base" logic. # (Why? The user can mutate "out", which should change the autograd metadata # of the other aliased outputs) return *list(out.unbind(0)), out, out.view(out.shape) inp = torch.ones(3, 3, requires_grad=True) inp_ref = torch.ones(3, 3, requires_grad=True) f4_compiled = aot_function(f4, nop) out_ref = f4(inp_ref) out_test = f4_compiled(inp) # Mutate the last output of f4 (autograd will allow this, since it is not a multi-output view, # as long as *only* the non-multi-output views participate in the backward) # Note: We could probably try to hide **only** the multi-output views from autograd here # and only do the intermediate base logic for the last two aliases. # Longer term solution of graph partitioning is probably cleaner though (see the note). 
out_ref[-1].mul_(2) out_test[-1].mul_(2) out_ref_sum = out_ref[-1] + out_ref[-2] out_test_sum = out_test[-1] + out_test[-2] out_ref_sum.sum().backward() out_test_sum.sum().backward() self.assertEqual(inp_ref.grad, inp.grad) def test_output_aliases_intermediate_mutation_linear(self): def f(x): return (x + 1).view(-1) inp = [torch.ones(3, 3, requires_grad=True)] # use inductor's decomps (which will e.g. turn _unsafe_view() into view()) from torch._inductor.decomposition import decompositions f_compiled = aot_function(f, nop, decompositions=decompositions) out_ref = f(*inp) out_test = f_compiled(*inp) out_ref.mul_(2) out_test.mul_(2) self.assertEqual(out_ref, out_test) def test_output_aliases_intermediate_no_grad(self): def f(a, b): out = torch.mul(a, 3) # First output is an alias of an intermediate that doesn't require grad return out.view(-1), b.add(1) inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3), torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) # important bit: we don't bother generating an intermediate base as an output in the graph, # because the intermediate base itself didn't require gradients. # (the only problematic case is when both the base and the aliasesed output require gradients). 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None view = torch.ops.aten.view.default(mul, [-1]); mul = None add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None return (view, add)""", ) def test_output_aliases_intermediate_returned_multiple_times(self): def f(a): out = torch.mul(a, 3) out_view = out.view(-1) return out, out_view, out inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) def test_output_aliases_intermediate_multiple(self): def f(a): out = torch.mul(a, 3) # AOTAutograd should manually generate these two output views in the epilogue. return out.view(-1), out.view(-1) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None view = torch.ops.aten.view.default(mul, [-1]) view_1 = torch.ops.aten.view.default(mul, [-1]) return (view, view_1, mul)""", ) def test_output_aliases_intermediate_and_returned(self): def f(a): out = torch.mul(a, 3) # AOTAutograd should manually generate the first output (a view of an intermediate) # but not the second (which is itself the intermediate for the first) return out.view(-1), out inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None view = 
torch.ops.aten.view.default(mul, [-1]) return (view, mul)""", ) def test_output_aliases_intermediate_and_returned_flipped(self): def f(a): out = torch.mul(a, 3) # AOTAutograd should manually generate the first output (a view of an intermediate) # but not the second (which is itself the intermediate for the first) return out, out.view(-1) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None view = torch.ops.aten.view.default(mul, [-1]) return (mul, view)""", ) def test_output_aliases_intermediate_and_returned_different_grad(self): def f(a): out = torch.mul(a, 3) # AOTAutograd should manually generate the first output (a view of an intermediate) # but not the second (which is itself the intermediate for the first) return out.view(-1), out, out[0].detach() inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None view = torch.ops.aten.view.default(mul, [-1]) select = torch.ops.aten.select.int(mul, 0, 0) detach = torch.ops.aten.detach.default(select); select = None detach_1 = torch.ops.aten.detach.default(detach); detach = None detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None return (view, mul, detach_2)""", ) def test_output_aliases_intermediate_inplace_view(self): def f(a): out = torch.mul(a, 3) out.t_() return out inp = [torch.ones(2, 4, requires_grad=True)] # TODO: fix this test. 
# See https://github.com/pytorch/pytorch/issues/90507 # self.verify_aot_autograd(f, inp, test_mutation=True) def test_output_aliases_intermediate_inplace_view_with_detach(self): def f(a): out = torch.mul(a, 3) out.t_() out.detach_() # Thanks to the detach_() AOT Autograd doesn't need to do anything. # `out` will show up as having OutputType.non_alias, # and ._is_view() == False return out, a + 1 inp = [torch.ones(2, 4, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(2, 4, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3) t = torch.ops.aten.t.default(mul); mul = None add = torch.ops.aten.add.Tensor(primals_1, 1); primals_1 = None return (t, add)""", ) def test_output_aliases_intermediate_inplace_view_and_view(self): def f(a): out = torch.mul(a, 3) out_view = out.unsqueeze(0) out.t_() out_view2 = out.unsqueeze(0) return out_view, out, out_view2 inp = [torch.ones(2, 4, requires_grad=True)] # TODO: fix this test. # See # self.verify_aot_autograd(f, inp, test_mutation=True) def test_output_aliases_intermediate_multiple_mixed(self): def f(a): out1 = torch.mul(a, 3) out2 = torch.mul(a, 4) # AOTAutograd should manually generate these two output views in the epilogue. 
return out1.view(-1), out2.transpose(1, 0), out1.transpose(1, 0) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, 3) mul_1 = torch.ops.aten.mul.Tensor(primals_1, 4); primals_1 = None view = torch.ops.aten.view.default(mul, [-1]) transpose = torch.ops.aten.transpose.int(mul_1, 1, 0); mul_1 = None transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) return (view, transpose, transpose_1, mul)""", ) def test_output_all_alias_types(self): # There are 3 types of aliasing that require us to return metadata in the compiled fw: # (1) outputs that are views of inputs # (2) outputs that are views of intermediates # (3) inputs that get metadata mutations # test all 3 of them here def f(a): a.transpose_(1, 0) tmp = a.mul(2) return tmp.squeeze(), tmp.transpose(1, 0), a.unsqueeze(0) def inp_callable(req_grad): x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() return [(x,), (x,)] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) # TODO: make this test run with dynamic shapes so it is more meaningful # metadata output order: (a_updated_meta, out1_meta, out2_meta, out3_meta) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): view = torch.ops.aten.view.default(primals_1, [1, 2, 4]); primals_1 = None transpose = torch.ops.aten.transpose.int(view, 1, 0); view = None mul = torch.ops.aten.mul.Tensor(transpose, 2) squeeze = torch.ops.aten.squeeze.default(mul) transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0) unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0) return (transpose, squeeze, transpose_1, unsqueeze, mul)""", ) 
@parametrize("req_grad", [False, True]) def test_subclass_metadata_mutation(self, req_grad): def f(a): a.transpose_(1, 0) tmp = a.mul(2) return tmp.transpose(1, 0) def inp_callable(req_grad): x = torch.ones(1, 2, 4, requires_grad=req_grad).clone() return [(x,), (x,)] # See https://github.com/pytorch/pytorch/issues/114975 with self.assertRaisesRegex( RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses", ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=req_grad), test_mutation=True, make_inputs_subclasses=True, ) def test_input_data_and_metadata_mutation(self): def f(a): a.t_() a[0].mul_(2) return a.view(a.shape) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=True)] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None t = torch.ops.aten.t.default(clone) select = torch.ops.aten.select.int(t, 0, 0); t = None mul = torch.ops.aten.mul.Tensor(select, 2); select = None t_1 = torch.ops.aten.t.default(clone); clone = None select_scatter = torch.ops.aten.select_scatter.default(t_1, mul, 0, 0); t_1 = mul = None t_2 = torch.ops.aten.t.default(select_scatter); select_scatter = None t_4 = torch.ops.aten.t.default(t_2) t_6 = torch.ops.aten.t.default(t_2); t_2 = None view_1 = torch.ops.aten.view.default(t_6, [3, 3]); t_6 = None return (t_4, view_1)""", ) def test_view_and_inplace_view(self): def f(a, b): a.t_() return b.view(b.shape), a.view(a.shape) def create_inp(req_grad): return [ torch.ones(3, 3, requires_grad=req_grad), torch.ones(3, 3, requires_grad=req_grad), ] self.verify_aot_autograd(f, create_inp(False), test_mutation=True) fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, arg0_1, 
arg1_1): t = torch.ops.aten.t.default(arg0_1); arg0_1 = None view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None view_1 = torch.ops.aten.view.default(t, [3, 3]) return (t, view, view_1)""", ) def test_view_detach(self): def f(a): tmp = a.detach() a.mul_(2) return a, tmp inp = [torch.ones(3, 3, requires_grad=True)] self.verify_aot_autograd(f, inp, test_mutation=True) inp = [torch.ones(3, 3, requires_grad=False)] self.verify_aot_autograd(f, inp, test_mutation=True) def test_input_inplace_requires_grad_true(self): def f(a, b): a.requires_grad_(True) return a.mul(3), b.mul(4) inp = [ # First inp doesnt require grad, but we switch it on torch.ones(3, 3, requires_grad=False), torch.ones(3, 3, requires_grad=True), ] fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None return (mul, mul_1)""", ) # This is a torture test: # a and b get turned into a synthetic base in the compiled graph # One gets a data mutation, the other gets a metadata mutation. # We need to make sure that the metadata mutation gets propagated # back to the original input. @skipIfDynamoInput("Dynamo removes runtime error") def test_input_data_and_metadata_mutation_aliases_other_input(self): # a and b are aliased def f(a, b): a.mul_(2) b.t_() return a.mul(b) def inp_callable(req_grad): base = torch.ones(2, 2, requires_grad=req_grad) # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 
x = base.add(1) inp1 = x[0] inp2 = x[0] return [base], [inp1, inp2] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) with self.assertRaisesRegex( RuntimeError, "Encountered aliased inputs that are mutated in the graph, but", ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True, ) with self.assertRaisesRegex( RuntimeError, "Encountered aliased inputs that are mutated in the graph, but", ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True, ) # https://github.com/pytorch/pytorch/issues/106456 def test_input_mutation_noncontiguous(self): def f(a): a.mul_(2) return a + 1 def inp_callable(req_grad): base = torch.ones(2, 2, requires_grad=req_grad) x = base.add(1) # create a non-contiguous view to pass as an input to the compiler inp = x[:, 0] return [base], [inp] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) with self.assertRaisesRegex( RuntimeError, "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses", ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True, ) with self.assertRaisesRegex( RuntimeError, "Mutations on non-contiguous inputs are currently not allowed on tensor subclasses", ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True, ) def test_backward_mutation_data(self): class BwMutation(torch.autograd.Function): @staticmethod def forward(ctx, x): ctx.save_for_backward(x) return x.clone() @staticmethod def backward(ctx, grad_output): (x,) = ctx.saved_tensors # bw mutation x.mul_(2) return grad_output.clone() def f(a, b): 
out = BwMutation.apply(b) return a * out inp_no_grad = [ torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=False), ] # Mutation on buffer that does not require grad during the backward is allowed self.verify_aot_autograd(f, inp_no_grad, test_mutation=True) inp_grad = [ torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True), ] self.verify_aot_autograd(f, inp_grad, test_mutation=True) def test_backward_mutation_metadata(self): class BwMutation(torch.autograd.Function): @staticmethod def forward(ctx, a, b): ctx.save_for_backward(b) return a.clone(), b.clone() @staticmethod def backward(ctx, grad_a, grad_b): (b,) = ctx.saved_tensors # bw metadata mutation b.transpose_(1, 0) return grad_a.clone(), grad_b.clone() def f(a, b): a_, b_ = BwMutation.apply(a, b) out = a_ * b_ return out inp_no_grad = [ torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=False), ] with self.assertRaisesRegex( AssertionError, "input that had its metadata mutated in the backward" ): self.verify_aot_autograd(f, inp_no_grad, test_mutation=True) def test_backward_mutation_on_grad_out(self): class BwMutation(torch.autograd.Function): @staticmethod def forward(ctx, x): return x.clone() @staticmethod def backward(ctx, grad_output): grad_output.mul_(2) return grad_output.clone() def f(a, b): tmp = a * b out = BwMutation.apply(tmp) return out inp_grad = [ torch.ones(3, 3, requires_grad=True), torch.ones(3, 3, requires_grad=True), ] f_compiled = aot_function(f, nop) with self.assertRaisesRegex( AssertionError, "input to the backward that was mutated during the backward" ): out = f_compiled(*inp_grad) def test_backward_mutation_forward_inputs(self): @torch.library.custom_op("_test::_clone", mutates_args={}) def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: return x.clone() def f_fake(x, x1): return torch.empty_like(x) def backward(ctx, grad): with torch.no_grad(): ctx.x1.zero_() return grad * 2, None def setup_context(ctx, inputs, output): 
(x, x1) = inputs ctx.x = x ctx.x1 = x1 f.register_fake(f_fake) f.register_autograd(backward, setup_context=setup_context) def fn(x: torch.Tensor, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: x2.mul_(5) return torch.ops._test._clone(x, x1) + x2 inp_x, inp_x1, inp_x2 = ( torch.randn(3, requires_grad=True), torch.randn(3, requires_grad=False), torch.randn(3, requires_grad=False), ) ref_x, ref_x1, ref_x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone() ref_y = fn(ref_x, ref_x1, ref_x2) compiled_f = aot_function(fn, nop, keep_inference_input_mutations=True) x, x1, x2 = inp_x.clone(), inp_x1.clone(), inp_x2.clone() y = compiled_f(x, x1, x2) # Verify mutation in forward applied and mutation in backward is not in forward self.assertEqual(ref_x, x) self.assertEqual(ref_x1, x1) self.assertEqual(ref_x2, x2) self.assertEqual(ref_y, y) ref_y.sum().backward() y.sum().backward() # Verify mutations in backward applied self.assertEqual(ref_x, x) self.assertEqual(ref_x1, x1) self.assertEqual(ref_x2, x2) self.assertEqual(ref_y, y) self.assertEqual(ref_x.grad, x.grad) self.assertEqual(ref_x1.grad, x1.grad) self.assertEqual(ref_x2.grad, x2.grad) def test_backward_mutation_forward_inputs_create_graph(self): @torch.library.custom_op("_test::_clone_create_graph", mutates_args={}) def f(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: return x.clone() def f_fake(x, x1): return torch.empty_like(x) def backward(ctx, grad): with torch.no_grad(): ctx.x1.zero_() return grad * 2, None def setup_context(ctx, inputs, output): (x, x1) = inputs ctx.x = x ctx.x1 = x1 f.register_fake(f_fake) f.register_autograd(backward, setup_context=setup_context) def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: return torch.ops._test._clone_create_graph(x, x1) inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn( 3, requires_grad=True ) ref_x, ref_x1 = inp_x.clone(), inp_x1.clone() ref_y = f(ref_x, ref_x1) ref_y.sum().backward() x, x1 = inp_x.clone(), inp_x1.clone() compiled_f = 
aot_function(fn, nop) y = compiled_f(x, x1) loss = y.sum() with self.assertRaisesRegex( RuntimeError, "aot_autograd does not support input mutations with requires_grad in backward for create_graph=True", ): torch.autograd.grad(loss, inp_x, create_graph=True) # Not checking equality of ref and x as Exception is expected # Partially addresses https://github.com/pytorch/pytorch/issues/106457 def test_input_mutation_false_aliasing(self): def f(a, b): a.mul_(3) b.mul_(2) return a.clone().view(-1) + b.clone().view(-1) # No overlap, contiguous def inp_callable1(req_grad): base = torch.ones(4, 4, requires_grad=req_grad) x = base.add(1) # create two views that share storage, but are actually non-overlapping a = x[0:2] b = x[2:4] return [base], [a, b] fw_graph = self.verify_aot_autograd( f, partial(inp_callable1, req_grad=False), test_mutation=True ) self.verify_aot_autograd( f, partial(inp_callable1, req_grad=True), test_mutation=True ) self.verify_aot_autograd( f, partial(inp_callable1, req_grad=False), test_mutation=True, make_inputs_subclasses=True, ) # Input mutations on subclasses with training graphs fail backward guards today. with self.assertRaisesRegex( AssertionError, "attempted to compile the backward with incorrect subclass metadata", ): self.verify_aot_autograd( f, partial(inp_callable1, req_grad=True), test_mutation=True, make_inputs_subclasses=True, ) # Important characteristic: the graph takes in 2 inputs! # That shows that we didn't try to run our complicated synthetic base logic, # because we successfully detected false aliasing across the two inputs. 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, arg0_1, arg1_1): mul = torch.ops.aten.mul.Tensor(arg0_1, 3); arg0_1 = None mul_1 = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None clone = torch.ops.aten.clone.default(mul) view = torch.ops.aten.view.default(clone, [-1]); clone = None clone_1 = torch.ops.aten.clone.default(mul_1) view_1 = torch.ops.aten.view.default(clone_1, [-1]); clone_1 = None add = torch.ops.aten.add.Tensor(view, view_1); view = view_1 = None return (mul, mul_1, add)""", ) # No overlap, non-contiguous: first tensor ends before second tensor start def inp_callable2(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) a = x.as_strided((4, 4), (8, 1), storage_offset=0) b = x.as_strided((4, 4), (8, 1), storage_offset=28) return [base], [a, b] # No overlap, non-contiguous: tensors are perfectly interleaved def inp_callable3(req_grad): base = torch.ones(4, 4, requires_grad=req_grad) x = base.add(1) a = x[:, 0:2] b = x[:, 2:4] return [base], [a, b] # No overlap, non-contiguous def inp_callable4(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) a = x.as_strided((4, 4), (9, 1), storage_offset=0) b = x.as_strided((4, 4), (9, 1), storage_offset=22) return [base], [a, b] # No overlap, non-contiguous def inp_callable5(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) a = x.as_strided((4, 4), (9, 1), storage_offset=0) b = x.as_strided((4, 4), (9, 1), storage_offset=23) return [base], [a, b] # No overlap, non-contiguous def inp_callable6(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) # a's last element is at offset 195 (24 total elements) a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5) # b's first element is at offset 196: no overlap b = x[196 : 196 + a.numel()] return [base], [a, b] # overlap! 
non-contiguous def inp_callable_overlap1(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) a = x.as_strided((4, 4), (9, 1), storage_offset=0) b = x.as_strided((4, 4), (9, 1), storage_offset=24) return [base], [a, b] # overlap! non-contiguous def inp_callable_overlap2(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) a = x.as_strided((4, 4), (9, 1), storage_offset=0) b = x.as_strided((4, 4), (9, 1), storage_offset=25) return [base], [a, b] # overlap! non-contiguous def inp_callable_overlap3(req_grad): base = torch.ones(256, requires_grad=req_grad) x = base.add(1) # a's last element is at offset 195 (24 total elements) a = x.as_strided((2, 4, 3), (110, 24, 4), storage_offset=5) # b's first element is at offset 195: overlap! b = x[195 : 195 + a.numel()] return [base], [a, b] fw_graph2 = self.verify_aot_autograd( f, partial(inp_callable2, req_grad=False), test_mutation=True ) fw_graph3 = self.verify_aot_autograd( f, partial(inp_callable3, req_grad=False), test_mutation=True ) fw_graph4 = self.verify_aot_autograd( f, partial(inp_callable4, req_grad=False), test_mutation=True ) fw_graph5 = self.verify_aot_autograd( f, partial(inp_callable5, req_grad=False), test_mutation=True ) fw_graph6 = self.verify_aot_autograd( f, partial(inp_callable6, req_grad=False), test_mutation=True ) fw_graph_overlap1 = self.verify_aot_autograd( f, partial(inp_callable_overlap2, req_grad=False), test_mutation=True ) fw_graph_overlap2 = self.verify_aot_autograd( f, partial(inp_callable_overlap1, req_grad=False), test_mutation=True ) # All non-overlap graphs should be the same since we detected false aliasing self.assertEqual(str(fw_graph.code), str(fw_graph2.code)) self.assertEqual(str(fw_graph.code), str(fw_graph3.code)) self.assertEqual(str(fw_graph.code), str(fw_graph4.code)) self.assertEqual(str(fw_graph.code), str(fw_graph5.code)) self.assertEqual(str(fw_graph.code), str(fw_graph6.code)) # All overlap graphs should be the same since we 
detected real aliasing self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap1.code)) self.assertNotEqual(str(fw_graph.code), str(fw_graph_overlap2.code)) self.assertTrue("as_strided_scatter" in str(fw_graph_overlap1.code)) self.assertTrue("as_strided_scatter" in str(fw_graph_overlap2.code)) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_mem_leak_from_save_for_bw(self): # See a full diagnosis at this issue: https://github.com/pytorch/pytorch/issues/94990 # Note [Detaching saved tensors in AOTAutograd] # This program creates a ref-cycle. Long term, we should fix this ref cycle # (since it can arise, naturally albeit rarely, from uses of autograd.Function). # But AOTAutograd makes it more likely to show up from tracing user programs, # so we deal with it by manually detaching the tensors that we save for backward. # This is completely wrong and would give wrong results if we were to do double backward. # Fortunately today, double backward is explicitly banned in AOTAutograd. def f(a, b): add = a + a split = torch.functional.split(add, [4, 4], dim=1) getitem_2 = split[1] unsqueeze = getitem_2.unsqueeze(-1) mul = unsqueeze * b return (getitem_2, mul) f_compiled = aot_function(f, nop) inps = [ torch.ones(8, 8, device="cuda", requires_grad=True), torch.ones(1, 4, 1, device="cuda", requires_grad=True), ] mem_before = torch.cuda.memory_allocated() f_compiled(*inps) mem_after = torch.cuda.memory_allocated() self.assertTrue(mem_after == mem_before) def test_output_aliases_multiple_inputs_get_correct_one(self): # a and b are aliased, but have different shapes # The first output should view off the first input, the 2nd output should view off the 2nd input def f(a, b): return a.view(a.shape), b.view(b.shape) def inp_callable(req_grad): base = torch.ones(2, 2, requires_grad=req_grad) # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 
x = base.mul(2) inp1 = x.view(-1) inp2 = x[0] return [base], [inp1, inp2] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True, ) self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True, make_inputs_subclasses=True, ) def test_input_mutation_aliases_other_input(self): def f(a, b): a.add_(1) return a + b def inp_callable(req_grad): base = torch.ones(4, 2, requires_grad=req_grad) # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. x = base.add(1) inp1 = x[0] inp2 = x[0] return [base], [inp1, inp2] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) # Important parts of the graph: # - the compiled graph takes in a base, and we generate a and b (the views) off of the base # - clone() is still in the graph, because we need to call grad() on the original (non-mutated) inputs # - We re-generate the views *after* the clone, to preserve view relationships. 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None return (as_strided_scatter, add_1)""", ) # noqa: B950 def test_input_mutation_aliases_other_input2(self): def f(a, b): a.add_(1) return a + b def inp_callable(req_grad): base = torch.ones(2, 2, requires_grad=req_grad) x = base.add(1) inp1 = x[0] # Here, one of the aliased inputs is the base itself inp2 = x return [base], [inp1, inp2] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [2], [1], 0) add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [2], [1], 0); clone = add = None as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0) add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None return (as_strided_scatter, add_1)""", ) # noqa: B950 def test_input_mutation_aliases_and_output_alias(self): def f(a, b): # Here, we need to take care:that 
because and b are aliased # since a and b are aliased, we generate a view off of "updated b" a.add_(1) return b.view(b.shape) def inp_callable(req_grad): base = torch.ones(2, 2, requires_grad=req_grad) x = base.add(1) return [base], [x.view(-1), x.view(-1)] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None return (as_strided_scatter, view_1)""", ) # noqa: B950 def test_input_aliased_with_mutation_output_alias(self): def f(a, b, c): # a and c alias c.mul_(2) # The main thing we're testing here is that # (1) We need to reconstruct c.view(-1) from the 3rd input to the forward # (2) But we need to be careful to do this *before* converting aliased inputs into synthetic bases. # The original fw takes in 3 args, but the compiled fw takes in only 2 args. 
return b.add(1), c.view(-1) def inp_callable(req_grad): base1 = torch.ones(2, 2, requires_grad=req_grad) base2 = torch.ones(2, 2, requires_grad=req_grad) x = base1.add(1) y = base2.add(1) return [base1, base2], [x.view(-1), y, x.view(-1)] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None return (as_strided_scatter, add, view_1)""", ) # noqa: B950 def test_input_metadata_mutation_aliases(self): def f(a, b): # a and b alias, and we do a metadata mutation on a # Since we're not mutating data, then b isn't affected at all. # We expect aot autograd to not bother with constructing a synthetic base. a.t_() return a + b def inp_callable(req_grad): base = torch.ones(2, 2, requires_grad=req_grad) x = base.add(1) return [base], [x.view(-1), x.view(-1)] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) # Expectation: fwd() takes in 2 args, and we don't construct a synthetic base. 
self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): t = torch.ops.aten.t.default(primals_1); primals_1 = None add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None return (add,)""", ) def test_input_mutation_aliases_and_none_require_gradients(self): def f(a, b, c): # a and b alias, but neither require gradients (so they don't have a _base) # aot autograd should construct the synthetic base from `torch.Tensor(a.storage())` a.mul_(2) return b + 1, c + 1 def inp_callable(req_grad): base = torch.ones(2, 2) c_arg = torch.ones(2, 2, requires_grad=req_grad) x = base.add(1) return [base, c_arg], [x.view(-1), x.view(-1), c_arg] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) with self.assertRaisesRegex( RuntimeError, "is a tensor subclass. This is not supported today" ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True, ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): as_strided = torch.ops.aten.as_strided.default(primals_1, [4], [1], 0) mul = torch.ops.aten.mul.Tensor(as_strided, 2); as_strided = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(primals_1, mul, [4], [1], 0); primals_1 = mul = None as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None return (as_strided_scatter, add, add_1)""", ) # noqa: B950 @skipIfDynamoInput("Fails with dynamo") def test_input_mutation_aliases_bases_out_of_order(self): # This tests our calling convention: if b and d are aliased, then the outer calling convention # that we send to the compiled forward becomes: # (b_d_base, a, c) # Importantly, even 
though a and c alias in our test, neither inputs are mutated, # So we don't need to do the base construction / deconstruction def f(a, b, c, d): b.add_(1) d.unsqueeze_(0) return a + c + d, b.view(-1) def inp_callable(req_grad): base1 = torch.ones(2, 2, requires_grad=req_grad) base2 = torch.ones(2, 2, requires_grad=req_grad) x1 = base1.add(1) x2 = base2.add(1) # a and c alias, b and d alias return [base1, base2], [x1.view(-1), x2.view(-1), x1.view(-1), x2.view(-1)] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) with self.assertRaisesRegex( RuntimeError, "Metadata mutations are currently not allowed on tensor subclasses", ): self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True, make_inputs_subclasses=True, ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) # 3 graph inputs: (b_d_base, a, c) # 2 returns: (b_updated, a+c+d) # (there are 2 original fw outs, but one is a view of b so it's not part of the graph) # (there are also 2 input mutations, but one is a metadata-only mutation so the compiled forward doesn't return it) self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2, primals_3): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None as_strided = torch.ops.aten.as_strided.default(clone, [4], [1], 0) add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None add_1 = torch.ops.aten.add.Tensor(primals_2, primals_3); primals_2 = primals_3 = None as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) unsqueeze_1 = torch.ops.aten.unsqueeze.default(as_strided_5, 0); as_strided_5 = None add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) view_2 = 
torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None return (as_strided_scatter, add_2, view_2, unsqueeze_1)""", ) # noqa: B950 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_synthetic_base_base_attribute_is_none(self): def f(a, b): a.add_(1) return a + b def inp_callable(): base = torch.ones(4, 4, device="cuda") # detach() so that none of the inputs have a ._base attribute. a = base[0].detach() b = base[1].detach() base2 = torch.ones(2, 2, requires_grad=True) return [base], [a, b] self.verify_aot_autograd(f, inp_callable, test_mutation=True) def test_input_mutation_alias_everything(self): # Mondo test that tests a combination of: # input is mutated, that aliases another input (so we make a synthetic base) # an output is an alias of another output # an output is an alias of an intermediate # a and c are aliased def f(a, b, c): c.mul_(2) # mutates c b.t_() # metadata mutate b tmp = a + c out1 = tmp.view(-1) out2 = b.t() out3 = out1.unsqueeze(0) # out1 and out3 are aliases of an intermediate, and alias each other! # out2 aliases an input, so we don't return it return out1, out2, out3 def inp_callable(req_grad): base1 = torch.ones(2, 2, requires_grad=req_grad) base2 = torch.ones(2, 2, requires_grad=req_grad) # Note: in our test, the add() is important because we need the graph inputs to be non-leaves so we can mutate them. 
base1_ = base1.add(1) base2_ = base2.add(1) a = base1_.view(-1) b = base2_ c = base1_.view(-1) return [base1, base2], [a, b, c] self.verify_aot_autograd( f, partial(inp_callable, req_grad=False), test_mutation=True ) fw_graph = self.verify_aot_autograd( f, partial(inp_callable, req_grad=True), test_mutation=True ) # Expected: # - 2 inputs in the forward: synthetic_base_a_c, b # - 1 output in the forward: "tmp" # out2 is an alias of an input, and will be generated off of b outside of the compiled fn # out1 and out3 are aliases of tmp, that we generate outside of the compiled function self.assertExpectedInline( fw_graph.code.strip(), """\ def forward(self, primals_1, primals_2): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None as_strided_1 = torch.ops.aten.as_strided.default(clone, [4], [1], 0) mul = torch.ops.aten.mul.Tensor(as_strided_1, 2); as_strided_1 = None as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, mul, [4], [1], 0); clone = mul = None as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) t = torch.ops.aten.t.default(view); view = None as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) add = torch.ops.aten.add.Tensor(as_strided_5, as_strided_2); as_strided_5 = as_strided_2 = None view_1 = torch.ops.aten.view.default(add, [-1]) t_1 = torch.ops.aten.t.default(t) unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0) return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""", ) # noqa: B950 def test_dynamic_shape_output_not_in_bw_graph(self): def f(x): return [x + 1, x.shape[0]] inp = torch.ones(5, requires_grad=True) bw_graph_cell = [None] compiled_f = aot_function( f, fw_compiler=nop, bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), decompositions={}, keep_inference_input_mutations=False, dynamic=True, ) out = compiled_f(inp) out[0].sum().backward() # The important bit: 
the forward fn returns 2 outputs, # but one of them is a symint so we should only see # 1 grad_output as an input to the backward graph. # (Otherwise, autograd will plumb a None as the value of the grad_output, # which causes inductor to complain). self.assertExpectedInline( bw_graph_cell[0].code.strip(), """\ def forward(self, tangents_1): return (tangents_1,)""", ) def test_no_grad_input_output(self): def f(a, b): return a.cos(), b.cos(), a * b inp_thunks = [ lambda: torch.randn(5, requires_grad=True), lambda: torch.randn(5, requires_grad=False), ] for inps in itertools.product(inp_thunks, repeat=2): inps = [i() for i in inps] self.verify_aot_autograd(f, inps) def test_some_output_requires_grad_input_doesnt(self): def f(a, b): a_view = a.view(-1) a_view.requires_grad_(True) return a_view inp = [torch.randn(3, 3), torch.randn(3, 3, requires_grad=True)] self.verify_aot_autograd(f, inp) def test_some_outputs_dont_require_grad_view(self): def f(a, b): return a.detach(), b inp = [ torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True), ] self.verify_aot_autograd(f, inp) def test_some_outputs_dont_require_grad_non_view(self): def f(a, b): return a.add(1).detach(), b inp = [ torch.randn(3, 3, requires_grad=True), torch.randn(3, 3, requires_grad=True), ] self.verify_aot_autograd(f, inp) def test_inner_grad(self): def foo(x): y = torch.exp(x) z = torch.autograd.grad(y, x) return z inps = [torch.randn((), requires_grad=True)] self.verify_aot_autograd(foo, inps) def test_grad_context(self): def foo(x): return x * 2 inps = [torch.randn((), requires_grad=True)] graph_size = None def get_graph_size(fx_g, _): nonlocal graph_size graph_size = len(fx_g.graph.nodes) return fx_g f = aot_function(foo, nop, get_graph_size) with torch.set_grad_enabled(False): f(*inps) self.assertIsNone(graph_size) f = aot_function(foo, nop, get_graph_size) with torch.set_grad_enabled(True): out = f(*inps) self.assertIsNone(graph_size) out.sum().backward() 
self.assertTrue(graph_size > 2) def test_output_dict(self): def f(x): return {"a": x, "b": x} inp = [torch.randn(3, 3, requires_grad=True)] self.verify_aot_autograd(f, inp) def f(x, y): return {"a": x, "b": y + x} inp = [torch.randn(3, requires_grad=True), torch.randn(3)] self.verify_aot_autograd(f, inp) def f(x): new_d = {} for k in x: new_d[k] = x[k] * 2 return new_d a = torch.randn(3, requires_grad=True) b = torch.randn(3, requires_grad=True) def inp_callable(): inps = [{"a": a, "b": b}] return inps, inps self.verify_aot_autograd(f, inp_callable) def test_module(self): mod = nn.Sequential(nn.Linear(32, 32), nn.ReLU()) compiled_mod = compiled_module(mod, nop, nop) inp = torch.randn(32, 32) ref_out = mod(inp) ref_out.sum().backward() ref_grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) out = compiled_mod(inp) out.sum().backward() grads = sorted([(name, p.grad) for name, p in mod.named_parameters()]) self.assertEqual((out, grads), (ref_out, ref_grads)) def test_batchnorm(self): mod = compiled_module(nn.BatchNorm2d(4), nop, nop) x = torch.ones(1, 4, 2, 2) mod(x).sum().backward() def test_list_codegen(self): def list_nop(f, _): def g(inps): return f(*inps) g._boxed_call = True return g def f(a, b, c): return a.sin() * b.cos() * c.sin() f = aot_function(f, list_nop) inp = [torch.randn(5, requires_grad=True) for _ in range(3)] f(*inp).sum().backward() @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) def test_compilation_context(self, counter): def f(x): return x.sin().sin() count = [] def compiler(fx_g, _): context = get_aot_compilation_context() count.append((context[0], len(fx_g.graph.nodes))) return fx_g f = aot_function(f, compiler) out = f(torch.randn(5, requires_grad=True)) f = aot_function(f, compiler) f(torch.randn(5)) out.sum().backward() self.assertExpectedInline( str(count), """[(['0_forward'], 4), (['1_inference'], 4), (['0_backward'], 8)]""", ) def test_dupe_arg(self): def f(x, y): return x + y x = 
torch.randn(3, 3, requires_grad=True) self.verify_aot_autograd(f, [x, x]) def test_dupe_arg_torture(self): def f(x, y): x.t_() y.unsqueeze_(0) return x + y x = torch.randn(3, 3, requires_grad=True).clone() self.verify_aot_autograd(f, [x, x]) # See https://github.com/pytorch/pytorch/issues/100224 def test_dupe_arg_returned_as_output(self): def f(a, b, a_): a[0].add_(1) return a_ f_compiled = aot_function(f, nop) a = torch.ones(2) b = torch.ones(2) out_ref = f(a, b, a) a2 = torch.ones(2) b2 = torch.ones(2) out_test = f_compiled(a2, b2, a2) self.assertEqual(out_ref, out_test) self.assertEqual(a, a2) @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @patch("torch._functorch.config.debug_assert", True) def test_invalid_dupe_left_bias(self, counter): # This test checks that, just because only the first # argument did a metadata mutation, we still correctly # switch to strategy 2 (deduplicate) # See: https://github.com/pytorch/pytorch/pull/89896#discussion_r1036224447 class F(torch.nn.Module): def forward(self, x, y): x.t_() return (x + y,) x = torch.randn(3, 3, requires_grad=True).clone() y = torch.randn(3, 3, requires_grad=True) self.verify_aot_autograd(F(), [x, x]) fxx = aot_module_simplified(F(), (x, x), nop) self.assertExpectedRaisesInline( AssertionError, lambda: fxx(x, y), """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. 
This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 ) @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @patch("torch._functorch.config.debug_assert", True) def test_invalid_dupe(self, counter): self._test_invalid_dupe(counter, fake=False) # See Note: Dynamo recompilation guarding invalid grad for why this test exists @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @patch("torch._functorch.config.debug_assert", True) def test_invalid_dupe_fake(self, counter): self._test_invalid_dupe(counter, fake=True) def _test_invalid_dupe(self, counter, fake): class F(torch.nn.Module): def forward(self, x, y): x.unsqueeze_(0) y.unsqueeze_(0) return (x + y,) x = torch.randn(3, 3, requires_grad=True).clone() y = torch.randn(3, 3, requires_grad=True).clone() if fake: shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) fake_x = fake_mode.from_tensor(x) fake_y = fake_mode.from_tensor(y) if fake: fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) else: fxy = aot_module_simplified(F(), (x, y), nop) fxy(x, y) x = torch.randn(3, 3, requires_grad=True).clone() y = torch.randn(3, 3, requires_grad=True).clone() fxy(x, x) # is ok! if fake: fxx = aot_module_simplified(F(), (fake_x, fake_x), nop) else: fxx = aot_module_simplified(F(), (x, x), nop) x = torch.randn(3, 3, requires_grad=True).clone() y = torch.randn(3, 3, requires_grad=True).clone() fxx(x, x) # Note This should not raise! Once we have guards in place here, # we will have this working correctly, as it should recompile. x = torch.randn(3, 3, requires_grad=True).clone() y = torch.randn(3, 3, requires_grad=True).clone() self.assertExpectedRaisesInline( AssertionError, lambda: fxx(x, y), """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. 
This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 ) @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @patch("torch._functorch.config.debug_assert", True) def test_invalid_requires_grad(self, counter): self._test_invalid_requires_grad(counter, fake=False) # See Note: Dynamo recompilation guarding invalid grad for why this test exists @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @patch("torch._functorch.config.debug_assert", True) def test_invalid_requires_grad_fake(self, counter): self._test_invalid_requires_grad(counter, fake=True) def _test_invalid_requires_grad(self, counter, fake): class F(torch.nn.Module): def forward(self, x, y): return (x + y,) x = torch.randn(3, 3, requires_grad=True) y = torch.randn(3, 3, requires_grad=True) z = torch.randn(3, 3, requires_grad=False) if fake: shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) fake_x = fake_mode.from_tensor(x) fake_y = fake_mode.from_tensor(y) fake_z = fake_mode.from_tensor(z) if fake: fxy = aot_module_simplified(F(), (fake_x, fake_y), nop) else: fxy = aot_module_simplified(F(), (x, y), nop) compare_equal_outs_and_grads(self, F(), fxy, (x, y)) compare_equal_outs_and_grads(self, F(), fxy, (x, z)) if fake: fxz = aot_module_simplified(F(), (fake_x, fake_z), nop) else: fxz = aot_module_simplified(F(), (x, z), nop) compare_equal_outs_and_grads(self, F(), fxz, (x, z)) self.assertExpectedRaisesInline( AssertionError, lambda: fxz(x, y), """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. 
This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 ) def test_custom_autograd(self): class CustomFn(torch.autograd.Function): @staticmethod def forward(ctx, x): return x.clone() @staticmethod def backward(ctx, grad_output): return grad_output + 1 def f(x): return CustomFn.apply(x) self.verify_aot_autograd(f, [torch.randn(3)]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_autocast_disable_guard(self): with torch._C._DisableAutocast(): x = torch.rand([4, 4]).cuda() y = x @ x self.assertEqual(y.dtype, torch.float32) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_nonidempotent_amp(self): def f(self_s_emb, add_3): einsum_2 = torch.functional.einsum("ah,th->t", self_s_emb, add_3) log_softmax_2 = einsum_2.log_softmax(-1) return (log_softmax_2,) args = [ torch.rand((1, 256), dtype=torch.float32, device="cuda"), torch.rand((30, 256), dtype=torch.float16, device="cuda"), ] with torch.cuda.amp.autocast(enabled=True): self.verify_aot_autograd(f, args) args = [e.requires_grad_(True) for e in args] with torch.cuda.amp.autocast(enabled=True): self.verify_aot_autograd(f, args) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") @unittest.skipIf(not torch.backends.cudnn.is_available(), "CUDNN is unavailable") @skipIfRocm # https://github.com/pytorch/pytorch/issues/96560 def test_batch_norm_amp(self): device = "cuda" input_dtype = torch.float16 param_dtype = torch.float32 weight, bias = ( torch.ones(64, device=device, dtype=param_dtype, requires_grad=True) for _ in range(2) ) running_mean, running_var = ( torch.ones(64, device=device, dtype=param_dtype) for _ in range(2) ) def bn(x): return torch.ops.aten.cudnn_batch_norm( x, weight, bias, running_mean, running_var, False, 0.1, 1e-05, ) inp = torch.ones( torch.Size([16, 64, 112, 112]), dtype=input_dtype, device=device ) ref = bn(inp) cudnn_batch_norm_decomp = torch._decomp.get_decompositions( 
{torch.ops.aten.cudnn_batch_norm} ) aot_fn = make_fx(bn, decomposition_table=cudnn_batch_norm_decomp)(inp) res = aot_fn(inp) for a, b in zip(ref, res): assert torch.allclose(a, b) def test_output_op_depending_on_symint(self): """ It won't be obvious from reading this test what it's testing for. We should probably make it into a more focused unit test. An issue with the following program was the expand op would end up depending on a symint whose proxy was incorrectly associated with one of the grad tensors rather than input tensors. It broke partitioner logic and the net result was aot_function failed to produce a function and threw an exception instead. """ inp = torch.randn(5, requires_grad=True) def f(x): return x.expand(x.shape) # TODO(whc) make this work (test setup is wrong somehow) # joint_forward_backward = create_joint_forward_backward(f) # out = f(inp) # joint_inputs = ([inp], [out.detach().contiguous()]) # fx_g = make_fx(joint_forward_backward)(*joint_inputs) # TODO: assert outputs of fwd graph trace to correct symint # e2e test that fails without symint clone fix af = aot_function( f, nop, partition_fn=partial( min_cut_rematerialization_partition, compiler="inductor" ), dynamic=True, ) out = af(inp) self.assertEqual(out, f(inp)) def test_inference_mode(self): m = torch.nn.Linear(4, 4) inp = torch.randn(4, 4) aot_mod = aot_module(m, fw_compiler=nop) with torch.inference_mode(): out_ref = m(inp) out_test = aot_mod(inp) self.assertEqual(out_ref, out_test) def test_default_partitioner_saves_symints_not_tensors_for_bw(self): """ In this test, the important thing is that primals_1 is **only** needed in the backward in order to grab its sizes. We need to assert that what we save for the backward are the tensor's sizes, and not the tensor itself. The way this test is set up, it will actually fail if we try to save the input tensor for backward. Why? 
b.masked_fill_(c, 0) has a backward that requires knowing a's sizes b.masked_fill_(c, 0) **also** mutates a (because b and a are aliased) The autograd engine yells at us if we save "a" for backward, and then try to mutate it. """ inp = torch.randn(2, 2, requires_grad=True) def f(a): b = a[0] c = torch.ones_like(b, dtype=torch.bool) d = b.masked_fill_(c, 0) return d compiled_f = aot_function(f, nop, dynamic=True) inp_ref = torch.ones(2, 2, requires_grad=True) inp_test = torch.ones(2, 2, requires_grad=True) out_ref = f(inp_ref.clone()) out_test = compiled_f(inp_test.clone()) self.assertEqual(out_ref, out_test) out_ref.sum().backward() out_test.sum().backward() self.assertEqual(inp_ref.grad, inp_test.grad) def test_buffer_copied_in_graph(self): class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buf = torch.nn.Buffer(torch.zeros(1)) self.w1 = torch.nn.Parameter(torch.zeros(1)) self.w2 = torch.nn.Parameter(torch.zeros(1)) def forward(self, x): self.buf.add_(1) return (self.w1 * x * self.w2).sum() + self.buf.sum() model_for_eager = MyModel() model_for_compile = copy.deepcopy(model_for_eager) fw_graph_cell = [None] compiled_f = aot_module( model_for_compile, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=nop, keep_inference_input_mutations=True, ) inp_ref = torch.ones(1, requires_grad=True) inp_test = torch.ones(1, requires_grad=True) out_ref = model_for_eager(inp_ref.clone()) out_test = compiled_f(inp_test.clone()) self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1, primals_2, primals_3, primals_4): add = torch.ops.aten.add.Tensor(primals_3, 1) mul = torch.ops.aten.mul.Tensor(primals_1, primals_4) mul_1 = torch.ops.aten.mul.Tensor(mul, primals_2) sum_1 = torch.ops.aten.sum.default(mul_1); mul_1 = None sum_2 = torch.ops.aten.sum.default(add) add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None copy_ = 
torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = copy_ = None return (add_1, primals_1, primals_2, primals_4, mul)""", ) self.assertEqual(out_ref, out_test) out_ref.sum().backward() out_test.sum().backward() eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] self.assertEqual(eager_grads, compile_grads) self.assertEqual(inp_ref.grad, inp_test.grad) def test_buffer_copied_in_graph_with_different_shapes(self): class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buf = torch.nn.Buffer(torch.ones(4, 4)) self.w = torch.nn.Parameter( torch.Tensor([[4, 5], [1, 2], [6, 7], [8, 9]]) ) def forward(self, x): self.buf.add_(1) return (self.w @ x).sum() + self.buf.sum() model_for_eager = MyModel() model_for_compile = copy.deepcopy(model_for_eager) fw_graph_cell = [None] compiled_f = aot_module( model_for_compile, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=nop, keep_inference_input_mutations=True, ) inp_ref = torch.ones(2, 4, requires_grad=True) inp_test = torch.ones(2, 4, requires_grad=True) out_ref = model_for_eager(inp_ref.clone()) out_test = compiled_f(inp_test.clone()) self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1, primals_2, primals_3): add = torch.ops.aten.add.Tensor(primals_2, 1) mm = torch.ops.aten.mm.default(primals_1, primals_3) sum_1 = torch.ops.aten.sum.default(mm); mm = None sum_2 = torch.ops.aten.sum.default(add) add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = copy_ = None return (add_1, primals_1, primals_3)""", ) self.assertEqual(out_ref, out_test) out_ref.sum().backward() out_test.sum().backward() eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] compile_grads = [p.grad for _, p in 
model_for_compile.named_parameters()] self.assertEqual(eager_grads, compile_grads) self.assertEqual(inp_ref.grad, inp_test.grad) def test_buffer_batch_norm(self): class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.m = torch.nn.BatchNorm1d(100) def forward(self, x): return self.m(x) model_for_eager = MyModel() model_for_compile = copy.deepcopy(model_for_eager) fw_graph_cell = [None] bw_graph_cell = [None] compiled_f = aot_module( model_for_compile, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=bw_graph_cell) ), keep_inference_input_mutations=True, ) inp_ref = torch.ones(20, 100, requires_grad=True) inp_test = torch.ones(20, 100, requires_grad=True) out_ref = model_for_eager(inp_ref.clone()) out_test = compiled_f(inp_test.clone()) self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6): add = torch.ops.aten.add.Tensor(primals_5, 1) _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(primals_6, primals_1, primals_2, primals_3, primals_4, True, 0.1, 1e-05); primals_2 = None getitem = _native_batch_norm_legit_functional[0] getitem_1 = _native_batch_norm_legit_functional[1] getitem_2 = _native_batch_norm_legit_functional[2] getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = copy_ = None copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = copy__1 = None copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = copy__2 = None return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""", # noqa: B950 ) self.assertEqual(out_ref, out_test) out_ref.sum().backward() 
out_test.sum().backward() eager_grads = [p.grad for _, p in model_for_eager.named_parameters()] compile_grads = [p.grad for _, p in model_for_compile.named_parameters()] self.assertEqual(eager_grads, compile_grads) self.assertExpectedInline( bw_graph_cell[0].code.strip(), """\ def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4, tangents_1): native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, primals_6, primals_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); tangents_1 = primals_6 = primals_1 = getitem_3 = getitem_4 = getitem_1 = getitem_2 = None getitem_5 = native_batch_norm_backward[0] getitem_6 = native_batch_norm_backward[1] getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None return (getitem_6, getitem_7, None, None, None, getitem_5)""", # noqa: B950 ) self.assertEqual(inp_ref.grad, inp_test.grad) def test_new_inp_requires_grad_now(self): def f(x, y): return x.add_(y) fw_graph_cell = [None] bw_graph_cell = [None] compiled_f = aot_function( f, fw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=fw_graph_cell) ), bw_compiler=make_boxed_compiler( partial(extract_graph, graph_cell=bw_graph_cell) ), keep_inference_input_mutations=True, ) inp_ref = ( torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True), ) inp_test = ( torch.ones(20, 100, requires_grad=False), torch.ones(20, 100, requires_grad=True), ) out_ref = f(*inp_ref) out_test = compiled_f(*inp_test) # There is no copy_ method self.assertExpectedInline( fw_graph_cell[0].code.strip(), """\ def forward(self, primals_1, primals_2): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None return (add, add)""", ) # noqa: B950 self.assertEqual(out_ref, out_test) out_ref.sum().backward() out_test.sum().backward() self.assertExpectedInline( 
bw_graph_cell[0].code.strip(), """\ def forward(self, tangents_1): return (None, tangents_1)""", ) # noqa: B950 def test_real_weights_in_symbolic_mode(self): from functorch.experimental import functionalize class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear(x) return x m = M().eval() inp = torch.randn(2, 5) gm = make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) self.assertEqual(gm(torch.ones(2, 5)), m(torch.ones(2, 5))) gm_functionalized = make_fx( functionalize( gm, ), tracing_mode="symbolic", _allow_non_fake_inputs=True, )(inp) self.assertEqual(gm_functionalized(torch.ones(2, 5)), m(torch.ones(2, 5))) inp_count = 0 for node in gm.graph.nodes: if node.op == "placeholder": inp_count += 1 # No more param lifting self.assertEqual(inp_count, 1) inp_count = 0 for node in gm_functionalized.graph.nodes: if node.op == "placeholder": inp_count += 1 # No more param lifting self.assertEqual(inp_count, 1) with self.assertRaisesRegex( Exception, "Please convert all Tensors to FakeTensors" ): make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=False)( torch.randn(2, 5) ) def test_real_weights_in_symbolic_mode_with_inplace_ops(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer = torch.nn.Buffer(torch.ones(4, 5)) def forward(self, x): y = self.buffer.add_(3) y.resize_([20]) assert y.shape == self.buffer.shape return x.sum() + self.buffer.sum() m = M().eval() inp = torch.randn(2, 5) # inplace mutation on attr is not allowed with self.assertRaisesRegex(Exception, "Can't call metadata"): make_fx(m, tracing_mode="symbolic", _allow_non_fake_inputs=True)(inp) def _compile_and_erase_bases(self, *output_view_indices): # Overrides _base and _view_func tensor attributes, so as to avoid the view-replay # execution path when reconstructing views. 
        # Tensor subclass that reports no `_base` and no `_view_func`.
        class NoViewReplayTensor(torch.Tensor):
            @property
            def _base(self):
                return None

            @property
            def _view_func(self):
                return None

        # Wraps the outputs that are views of the FX graph 'g' with NoViewReplayTensor,
        # since they are the only ones that will get reconstructed.
        def wrapper(g, *args, **kwargs):
            outs = list(g(*args, **kwargs))
            for i in output_view_indices:
                outs[i] = NoViewReplayTensor(outs[i])
            return tuple(outs)

        # Returns a decorator. It compiles the decorated function with
        # aot_function, using a forward "compiler" that runs graph `g` as-is
        # and passes its outputs through `wrapper`.
        return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))

    def test_output_aliases_input_view_meta_replay(self):
        @self._compile_and_erase_bases(0)
        def f(a):
            return a.view(-1)

        inp = torch.ones(2, 2, requires_grad=True)
        out = f(inp)

        self.assertIsNotNone(out.grad_fn)
        # NOTE(review): the expected string is empty, but str() of a class is
        # never empty. This check will fail unless the expectation is
        # re-accepted; the expected grad_fn class name looks lost -- confirm and
        # restore it.
        self.assertExpectedInline(
            str(out.grad_fn.__class__), """"""
        )

    def test_output_aliases_intermediate_view_meta_replay(self):
        @self._compile_and_erase_bases(0, 1)
        def f(a):
            b = a.clone()
            return b.view(-1), b.view(-1)

        inp = torch.ones(2, 2, requires_grad=True)
        out1, out2 = f(inp)

        self.assertIsNotNone(out1.grad_fn)
        # NOTE(review): empty expected class string, as above -- confirm.
        self.assertExpectedInline(
            str(out1.grad_fn.__class__), """"""
        )

        self.assertIsNotNone(out2.grad_fn)
        self.assertExpectedInline(
            str(out2.grad_fn.__class__), """"""
        )

    def test_output_aliases_output_view_meta_replay(self):
        @self._compile_and_erase_bases(1)
        def f(a):
            b = a.add(10)
            return b, b.view(-1)

        inp = torch.ones(2, 2, requires_grad=True)
        out1, out2 = f(inp)

        # Both outputs must share storage: out2 is a view of out1.
        self.assertEqual(out1.untyped_storage(), out2.untyped_storage())
        self.assertIsNotNone(out2.grad_fn)
        # NOTE(review): empty expected class string, as above -- confirm.
        self.assertExpectedInline(
            str(out2.grad_fn.__class__), """"""
        )

    @skipIfTorchDynamo()
    @patch("torch._dynamo.config.assume_static_by_default", False)
    def test_dynamic_output_aliases_input_view_meta_replay(self):
        # - torch.compile: using it so we can have a SymInt in the FX graph.
        # - Compiling with inductor, so that tensor._base isn't tracked.
        #
        # This should force the use of as_strided in the view reconstruction path.
# The first 2 view-replay paths won't be taken because: # - target_functional_tensor will be symbolic (_functionalize_is_symbolic call) # - tensor._base will be None @torch.compile(backend="inductor") def f(a, sz): return a.view(sz), a.view(-1) inp = torch.ones(2, 2, requires_grad=True) out1, out2 = f(inp, (4,)) self.assertIsNotNone(out1.grad_fn) self.assertExpectedInline( str(out1.grad_fn.__class__), """""" ) self.assertIsNotNone(out2.grad_fn) self.assertExpectedInline( str(out2.grad_fn.__class__), """""" ) def extract_graph(fx_g, _, graph_cell): graph_cell[0] = fx_g return fx_g def get_ins_outs(fx_g): ins = [] outs = [] for n in fx_g.graph.nodes: if n.op == "placeholder": ins.append(n) elif n.op == "output": outs = tuple(n.args[0]) return ins, outs def get_num_ins_outs(fx_g): return tuple(len(i) for i in get_ins_outs(fx_g)) def get_fw_bw_graph( f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False ): fw_graph_cell = [None] bw_graph_cell = [None] aot_function( f, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=partitioner, decompositions=default_decompositions, dynamic=dynamic, )(*inps).sum().backward() return (fw_graph_cell[0], bw_graph_cell[0]) class TestMod(torch.nn.Module): def __init__(self, fn): super().__init__() self.p = torch.nn.Parameter(torch.ones(2, requires_grad=True)) self.fn = fn def forward(self, *args): return self.fn(self.p, *args) class TestAOTExport(AOTTestCase): def test_aot_export_ban_dropout_mut_pre_dispatch(self): def fn(p, x): y = torch.ops.aten.dropout.default(x, 0.1, train=False) y.add_(1) return (y,) mod = TestMod(fn) inp = torch.randn(2, 2) with self.assertRaisesRegex( RuntimeError, "cannot mutate tensors with frozen storage" ): aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=False) self.assertExpectedInline( str(gm.code).strip(), """\ def 
forward(self, arg0_1, arg1_1): clone = torch.ops.aten.clone.default(arg1_1); arg1_1 = None add = torch.ops.aten.add.Tensor(clone, 1); clone = None return (add,)""", ) fw_graph_cell = [None] bw_graph_cell = [None] compiled_outs = aot_function( fn, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=default_partition, decompositions=default_decompositions, dynamic=True, )(*inp) fw_graph = fw_graph_cell[0] bw_graph = bw_graph_cell[0] self.assertExpectedInline( str(fw_graph.code).strip(), """\ def forward(self, arg0_1, arg1_1): clone = torch.ops.aten.clone.default(arg1_1); arg1_1 = None add = torch.ops.aten.add.Tensor(clone, 1); clone = None return (add,)""", ) def test_aot_export_predispatch_func_simple(self): def fn(p, x): y = x + 2 with torch.no_grad(): y.add_(2) return (x * 2 + y,) mod = TestMod(fn) inp = torch.randn(2, 2) with torch.no_grad(): gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): add = torch.ops.aten.add.Tensor(arg1_1, 2) _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None add_1 = torch.ops.aten.add.Tensor(add, 2); add = None _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None add_2 = torch.ops.aten.add.Tensor(mul, add_1); mul = add_1 = None return (add_2,)""", ) def test_aot_export_predispatch_func_composite_implicit(self): def fn(p, x): with torch.enable_grad(): y = x @ x y.add_(2) return (x.sum() + y.sum(),) mod = TestMod(fn) inp = torch.randn(2, 2) with torch.no_grad(): gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None matmul = 
torch.ops.aten.matmul.default(arg1_1, arg1_1) _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None sum_2 = torch.ops.aten.sum.default(add); add = None add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add_1,)""", ) def test_aot_export_predispatch_composite_implicit_inplace(self): def fn(x, p): return (torch.ops.aten.absolute_.default(x.clone()),) mod = TestMod(fn) inp = torch.randn(2, 2) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): clone = torch.ops.aten.clone.default(arg0_1); arg0_1 = None abs_1 = torch.ops.aten.abs.default(clone); clone = None return (abs_1,)""", ) def test_aot_export_predispatch_composite_implicit_linear(self): class MM(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) def forward(self, x): return (self.linear(x),) mod = MM() inp = torch.randn(2, 2) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1): linear = torch.ops.aten.linear.default(arg2_1, arg0_1, arg1_1); arg2_1 = arg0_1 = arg1_1 = None return (linear,)""", ) @unittest.expectedFailure def test_aot_export_predispatch_outdtype(self): class M(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def forward(self, x): y = x + 2 y.add_(5) return ( out_dtype(torch.ops.aten.mm.default, torch.int32, y, self.weight), ) weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) mod = M(weight) inp = torch.randint(-128, 127, (5, 5), dtype=torch.int8) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, 
arg1_1): _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None mm = torch.ops.aten.mm.default(arg1_1, arg1_1) _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None add = torch.ops.aten.add.Tensor(mm, 2); mm = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None sum_2 = torch.ops.aten.sum.default(add); add = None add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add_1,)""", ) def test_aot_export_predispatch_func_view(self): def fn(p, x): y = x @ x y.add_(2) return (x.sum() + y.view(1, 4).sum(),) mod = TestMod(fn) inp = torch.randn(2, 2) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): matmul = torch.ops.aten.matmul.default(arg1_1, arg1_1) add = torch.ops.aten.add.Tensor(matmul, 2); matmul = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None view_1 = torch.ops.aten.view.default(add, [1, 4]); add = None sum_2 = torch.ops.aten.sum.default(view_1); view_1 = None add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add_1,)""", ) def test_aot_export_predispatch_buffer_mutation_metadata(self): class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.foo = torch.nn.Buffer(torch.zeros(2, 2)) def forward(self, x): self.foo.add_(4) return (x.sum() + self.foo.sum(),) inp = torch.randn(2, 2) gm, graph_sig = aot_export_module( Foo(), [inp], trace_joint=False, pre_dispatch=True ) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): add = torch.ops.aten.add.Tensor(arg0_1, 4); arg0_1 = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None sum_2 = torch.ops.aten.sum.default(add) add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add, add_1)""", ) eager_mod = Foo() output_1, output_2 = gm(torch.zeros(2, 2), inp) eager_output = eager_mod(inp) 
self.assertTrue(torch.allclose(output_2, eager_output[0])) _, output_2 = gm(output_1, inp) eager_output = eager_mod(inp) self.assertTrue(torch.allclose(output_2, eager_output[0])) self.assertTrue("foo" in graph_sig.buffers) self.assertTrue(graph_sig.inputs_to_buffers["arg0_1"] == "foo") def test_aot_export_predispatch_with_autograd_op(self): def foo(p, x): with torch.enable_grad(): y = x + 5 y.add_(5) y.add_(7) return (x.cos() + y.sin(),) inp = torch.randn(2, 2) mod = TestMod(foo) with torch.no_grad(): gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): _set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None add = torch.ops.aten.add.Tensor(arg1_1, 5) add_1 = torch.ops.aten.add.Tensor(add, 5); add = None add_2 = torch.ops.aten.add.Tensor(add_1, 7); add_1 = None cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sin = torch.ops.aten.sin.default(add_2); add_2 = None add_3 = torch.ops.aten.add.Tensor(cos, sin); cos = sin = None _set_grad_enabled_1 = torch._C._set_grad_enabled(False); _set_grad_enabled_1 = None return (add_3,)""", ) @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @unittest.skipIf( not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported" ) def test_aot_export_predispatch_with_cond_nested(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): def true_fn(x): y = x.sin() y.add_(5) def true_true_fn(x): y = x.sin() y.add_(7) return y.sin() def true_false_fn(x): return x.cos() return torch.cond( y.cos().sum() > 5, true_true_fn, true_false_fn, [y.cos()] ) def false_fn(x): z = x.cos() z.add_(6) return z.sin() a = torch.cond(x.sum() > 4, true_fn, false_fn, [x]) return (a + 3, a + 4) inp = torch.randn(2, 2) gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def 
forward(self, arg0_1): sum_1 = torch.ops.aten.sum.default(arg0_1) gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None return (add, add_1)""", # noqa: B950 ) self.assertExpectedInline( str(gm.true_graph_0.code).strip(), """\ def forward(self, arg0_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None add = torch.ops.aten.add.Tensor(sin, 5); sin = None cos = torch.ops.aten.cos.default(add) sum_1 = torch.ops.aten.sum.default(cos); cos = None gt = torch.ops.aten.gt.Scalar(sum_1, 5); sum_1 = None cos_1 = torch.ops.aten.cos.default(add); add = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [cos_1]); gt = true_graph_0 = false_graph_0 = cos_1 = None getitem = cond[0]; cond = None return (getitem,)""", # noqa: B950 ) self.assertExpectedInline( str(gm.true_graph_0.true_graph_0.code).strip(), """\ def forward(self, arg0_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None add = torch.ops.aten.add.Tensor(sin, 7); sin = None sin_1 = torch.ops.aten.sin.default(add); add = None return (sin_1,)""", ) @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @unittest.skipIf( not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported" ) def test_aot_export_predispatch_map_1(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, y): def true_fn(x, r): y = x.sin() y.add_(5) return y.cos() + r.sum() def false_fn(x, r): z = x.cos() def f(x, y): a = x.cos() a.add_(5) return a + y return ( z + control_flow.map(f, z, r).sum() + control_flow.map(f, z, r).sum() ) a = torch.cond(x.sum() > 
4, true_fn, false_fn, [x, y]) return (a + 3, a + 4) inps = [torch.randn(2, 2), torch.ones(2)] gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): sum_1 = torch.ops.aten.sum.default(arg0_1) gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1, arg1_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = arg1_1 = None getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None return (add, add_1)""", # noqa: B950 ) self.assertExpectedInline( str(gm.true_graph_0.code).strip(), """\ def forward(self, arg0_1, arg1_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None add = torch.ops.aten.add.Tensor(sin, 5); sin = None cos = torch.ops.aten.cos.default(add); add = None sum_1 = torch.ops.aten.sum.default(arg1_1); arg1_1 = None add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add_1,)""", ) self.assertExpectedInline( str(gm.false_graph_0.code).strip(), """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None select = torch.ops.aten.select.int(cos, 0, 0); select = None body_graph_0 = self.body_graph_0 map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None getitem = map_impl[0]; map_impl = None sum_1 = torch.ops.aten.sum.default(getitem); getitem = None add = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None select_1 = torch.ops.aten.select.int(cos, 0, 0); select_1 = None body_graph_1 = self.body_graph_1 map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None getitem_1 = map_impl_1[0]; map_impl_1 = None sum_2 = torch.ops.aten.sum.default(getitem_1); getitem_1 = None add_1 = 
torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None return (add_1,)""", ) self.assertExpectedInline( str(gm.false_graph_0.body_graph_0.code).strip(), """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None add = torch.ops.aten.add.Tensor(cos, 5); cos = None add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None return (add_1,)""", ) def test_aot_export_predispatch_map_2(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, y): z = x.cos() def f(x, y): a = x.cos() a.add_(5) return a + y return (z + control_flow.map(f, z, y).sum(),) inps = [torch.randn(2, 2), torch.ones(2)] gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None body_graph_0 = self.body_graph_0 map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None getitem = map_impl[0]; map_impl = None sum_1 = torch.ops.aten.sum.default(getitem); getitem = None add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) # noqa: B950 self.assertExpectedInline( str(gm.body_graph_0.code).strip(), """\ def forward(self, arg0_1, arg1_1): cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None add = torch.ops.aten.add.Tensor(cos, 5); cos = None add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None return [add_1]""", ) @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @unittest.skipIf( not torchdynamo.is_dynamo_supported(), "TorchDynamo is not supported" ) def test_aot_export_predispatch_with_cond(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): def true_fn(x): y = x.sin() z = torch.ops.aten.linear.default(y, torch.randn(2, 2)) z.add_(5) return z.cos() def false_fn(x): z = x.cos() z.add_(6) return 
z.sin() a = torch.cond(x.sum() > 4, true_fn, false_fn, [x]) return (a + 3, a + 4) inp = torch.randn(2, 2) gm, _ = aot_export_module(M(), [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1): sum_1 = torch.ops.aten.sum.default(arg0_1) gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None return (add, add_1)""", # noqa: B950 ) self.assertExpectedInline( str(gm.true_graph_0.code).strip(), """\ def forward(self, arg0_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None randn = torch.ops.aten.randn.default([2, 2], device = device(type='cpu'), pin_memory = False) linear = torch.ops.aten.linear.default(sin, randn); sin = randn = None add = torch.ops.aten.add.Tensor(linear, 5); linear = None cos = torch.ops.aten.cos.default(add); add = None return (cos,)""", ) def test_aot_export_predispatch_conv_and_bn(self): class ConvBatchnorm(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 3, 1, 1) self.bn = torch.nn.BatchNorm2d(3) def forward(self, x): x = self.conv(x) x = self.bn(x) return (x,) mod = ConvBatchnorm() mod.train() inp = torch.randn(1, 1, 3, 3) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1): conv2d = torch.ops.aten.conv2d.default(arg7_1, arg0_1, arg1_1); arg7_1 = arg0_1 = arg1_1 = None add = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, arg2_1, 
arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); conv2d = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None getitem = _native_batch_norm_legit_functional[0] getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None return (getitem_3, getitem_4, add, getitem)""", # noqa: B950 ) def test_aot_export_predispatch_reshape(self): class Reshape(torch.nn.Module): def forward(self, x): y = x.reshape(4, 4) return (y.sum(),) mod = Reshape() inp = torch.randn(2, 8) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1): view = torch.ops.aten.view.default(arg0_1, [4, 4]); arg0_1 = None sum_1 = torch.ops.aten.sum.default(view); view = None return (sum_1,)""", ) # noqa: B950 def test_aot_export_predispatch_contiguous(self): class Cont(torch.nn.Module): def forward(self, x): y = torch.ops.aten.contiguous.default(x) return (y.sum(),) mod = Cont() inp = torch.randn(2, 8) gm, _ = aot_export_module(mod, [inp], trace_joint=False, pre_dispatch=True) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1): sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None return (sum_1,)""", ) # noqa: B950 def test_aot_export_module_joint(self): class ConvBatchnormRelu(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 3, 1, 1) self.bn = torch.nn.BatchNorm2d(3) def forward(self, x): x = self.conv(x) x = self.bn(x) user_out = torch.nn.functional.relu(x) loss = user_out.sum() return loss, user_out.detach() mod = ConvBatchnormRelu() mod.train() inp = torch.randn(1, 1, 3, 3) o_ref = mod(inp) fx_g, signature = aot_export_module( mod, [inp], trace_joint=True, output_loss_index=0 ) # Some important characteristics of the exported graph below: # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input # 9 outputs: 3 mutated 
buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) for node in fx_g.graph.nodes: node.meta.pop("stack_trace", None) self.assertExpectedInline( fx_g.print_readable(print_output=False), """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): # No stacktrace found for following nodes convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); arg3_1 = arg4_1 = arg5_1 = None getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] getitem_1: "f32[3]" = _native_batch_norm_legit_functional[1] getitem_2: "f32[3]" = _native_batch_norm_legit_functional[2] getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); detach = None detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu) detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None detach_3: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_2); detach_2 = None detach_4: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_3); detach_3 = None sum_1: "f32[]" = torch.ops.aten.sum.default(relu) detach_5: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None detach_6: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_5); detach_5 = None detach_7: "f32[1, 3, 3, 3]" = 
torch.ops.aten.detach.default(detach_6); detach_6 = None detach_8: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_7); detach_7 = None detach_9: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_8); detach_8 = None detach_10: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_9); detach_9 = None ones_like: "f32[]" = torch.ops.aten.ones_like.default(sum_1, pin_memory = False, memory_format = torch.preserve_format) expand: "f32[1, 3, 3, 3]" = torch.ops.aten.expand.default(ones_like, [1, 3, 3, 3]); ones_like = None detach_11: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_4); detach_4 = None detach_12: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_11); detach_11 = None detach_13: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_12); detach_12 = None detach_14: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_13); detach_13 = None threshold_backward: "f32[1, 3, 3, 3]" = torch.ops.aten.threshold_backward.default(expand, detach_14, 0); expand = detach_14 = None native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(threshold_backward, convolution, arg2_1, getitem_3, getitem_4, getitem_1, getitem_2, True, 1e-05, [True, True, True]); threshold_backward = convolution = arg2_1 = getitem_1 = getitem_2 = None getitem_5: "f32[1, 3, 3, 3]" = native_batch_norm_backward[0] getitem_6: "f32[3]" = native_batch_norm_backward[1] getitem_7: "f32[3]" = native_batch_norm_backward[2]; native_batch_norm_backward = None convolution_backward = torch.ops.aten.convolution_backward.default(getitem_5, arg7_1, arg0_1, [3], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); getitem_5 = arg7_1 = arg0_1 = None getitem_8 = convolution_backward[0]; getitem_8 = None getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1] getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None return (getitem_3, getitem_4, add, sum_1, detach_10, getitem_9, getitem_10, getitem_6, getitem_7) """, # noqa: B950 
) self.assertExpectedInline( str(signature.parameters), """['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']""", ) self.assertExpectedInline( str(signature.buffers), """['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']""", ) self.assertExpectedInline(str(signature.user_inputs), """['arg7_1']""") self.assertExpectedInline( str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""", ) # noqa: B950 self.assertExpectedInline( str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""", ) # noqa: B950 self.assertExpectedInline( str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""", ) # noqa: B950 self.assertExpectedInline( str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""", ) # noqa: B950 self.assertExpectedInline( str(signature.backward_signature.gradients_to_user_inputs), """{}""" ) self.assertExpectedInline( str(signature.backward_signature.loss_output), """getitem_3""" ) # Also check the inference graph # Main important thing here is that there are 5 total outputs: 3 total mutated buffers (from batchnorm), 2 user outputs. 
fx_g_inference, signature_inference = aot_export_module( mod, [inp], trace_joint=False ) for node in fx_g_inference.graph.nodes: node.meta.pop("stack_trace", None) self.assertExpectedInline( fx_g_inference.print_readable(print_output=False), """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"): # No stacktrace found for following nodes convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, arg2_1, arg3_1, arg4_1, arg5_1, True, 0.1, 1e-05); convolution = arg2_1 = arg3_1 = arg4_1 = arg5_1 = None getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0] getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3] getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None relu: "f32[1, 3, 3, 3]" = torch.ops.aten.relu.default(getitem); getitem = None sum_1: "f32[]" = torch.ops.aten.sum.default(relu) detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None return (getitem_3, getitem_4, add, sum_1, detach_2) """, # noqa: B950 ) # Some important characteristics of the exported graph below: # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input # 9 outputs: 2 mutated buffers (from batchnorm), 2 user outputs and 4 gradients (since there were 4 parameters) def test_aot_export_simplified_basic(self): def f(x, y): return x * y, y * y.detach() x = 
torch.randn(2, requires_grad=True) y = torch.randn(2, requires_grad=True) f_graph_fw = aot_export_joint_simple(f, [x, y], trace_joint=False) out_ref = f(x, y) # No calling convention changes necessary to invoke the traced graph out_test = f_graph_fw(x, y) self.assertEqual(out_ref, out_test) # Now test the backward x = torch.randn(2, requires_grad=True) y = torch.randn(2, requires_grad=True) x2 = x.clone().detach().requires_grad_(True) y2 = y.clone().detach().requires_grad_(True) x3 = x.clone().detach().requires_grad_(True) y3 = y.clone().detach().requires_grad_(True) f_graph_joint = aot_export_joint_simple(f, [x, y], trace_joint=True) num_fw_outputs = 2 fw_g, bw_g = default_partition( f_graph_joint, [x, y], num_fwd_outputs=num_fw_outputs ) out_ref2 = f(x2, y2) fw_outs = fw_g(x3, y3) out_test2, activations = fw_outs[:num_fw_outputs], fw_outs[num_fw_outputs:] self.assertEqual(out_ref2, out_test2) # Test running the traced backward graph with a mocked-up grad_output grad_outs = [torch.ones_like(x) for x in out_ref2] grads_ref = torch.autograd.grad(out_ref2, [x2, y2], grad_outputs=grad_outs) grads_test = bw_g(*activations, *grad_outs) for g_ref, g_test in zip(grads_ref, grads_test): self.assertEqual(g_ref, g_test) def test_aot_export_metadata_mutation_banned(self): def fn(p, x): x.t_() return (x * 2,) mod = TestMod(fn) inp = torch.randn(2, 4) with self.assertRaisesRegex( RuntimeError, "Found an input that received a metadata mutation" ): aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) aot_export_module(mod, [inp], trace_joint=False) def test_aot_export_forward_mutation_no_buffer_mut(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self, x): x.add_(4) return (x.cos().sum() + self.buffer1.sum(),) mod = M() inp = torch.ones(6, 4) gm, sig = aot_export_module(mod, [inp], trace_joint=False) 
self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1): add = torch.ops.aten.add.Tensor(arg1_1, 4); arg1_1 = None cos = torch.ops.aten.cos.default(add) sum_1 = torch.ops.aten.sum.default(cos); cos = None sum_2 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add, add_1)""", ) # noqa: B950 self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg1_1"}) def test_aot_export_forward_mutation_multiple_mut(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self, x, y): y.add_(4) self.buffer1.add_(5) return ( x.cos().sum() + y.sin().sum(), self.buffer1.sum(), ) mod = M() inp = [torch.ones(6, 4), torch.zeros(6, 4)] gm, sig = aot_export_module(mod, inp, trace_joint=False) self.assertExpectedInline( str(gm.code).strip(), """\ def forward(self, arg0_1, arg1_1, arg2_1): add = torch.ops.aten.add.Tensor(arg2_1, 4); arg2_1 = None add_1 = torch.ops.aten.add.Tensor(arg0_1, 5); arg0_1 = None cos = torch.ops.aten.cos.default(arg1_1); arg1_1 = None sum_1 = torch.ops.aten.sum.default(cos); cos = None sin = torch.ops.aten.sin.default(add) sum_2 = torch.ops.aten.sum.default(sin); sin = None add_2 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None sum_3 = torch.ops.aten.sum.default(add_1) return (add_1, add, add_2, sum_3)""", ) # noqa: B950 self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg2_1"}) self.assertEqual(sig.buffers_to_mutate, {"add_1": "buffer1"}) def test_aot_export_input_mutation_on_input_requiring_grad_banned(self): class M(torch.nn.Module): def forward(self, x): x.add_(4) return (x,) mod = M() inp = torch.randn(2, requires_grad=True) with self.assertRaisesRegex( RuntimeError, "Found a graph input that requires gradients, and received a mutation", ): aot_export_module(mod, [inp], trace_joint=False) def 
test_aot_export_input_mutation_on_parameter_banned(self): def fn(p, x): p.mul_(2) return (p + x,) mod = TestMod(fn) inp = torch.randn(2) with self.assertRaisesRegex( RuntimeError, "Found a graph input that requires gradients, and received a mutation", ): aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) aot_export_module(mod, [inp], trace_joint=False) def test_aot_export_synthetic_bases_banned(self): def fn(p, x, y): x.mul_(2) return (x + y,) mod = TestMod(fn) inp = torch.randn(2) inp2 = inp.view(-1) with self.assertRaisesRegex( RuntimeError, "Encountered aliased inputs that are mutated" ): aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp, inp2], trace_joint=True) aot_export_module(mod, [inp, inp2], trace_joint=False) def test_aot_export_input_dupes_banned(self): def fn(p, x, y): x.mul_(2) return (x + y,) mod = TestMod(fn) inp = torch.randn(2) with self.assertRaisesRegex( RuntimeError, "Encountered duplicated inputs that are mutated in the graph" ): aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp, inp], trace_joint=True) aot_export_module(mod, [inp, inp], trace_joint=False) def test_aot_export_multiple_outputs_require_grad_banned(self): def fn(p, x): out = p * x return out, out.sum() mod = TestMod(fn) inp = torch.randn(2) with self.assertRaisesRegex( RuntimeError, "Found an output of the forward that requires gradients, that was not", ): aot_export_module(mod, [inp], trace_joint=True, output_loss_index=1) @unittest.skipIf(IS_WINDOWS, "Windows isn't supported for this case") @unittest.skipIf( not torch._dynamo.is_dynamo_supported(), "Cond needs dynamo to run" ) def test_aot_export_with_torch_cond(self): class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): def true_fn(x): y = x + 4 y.add_(5) return x.cos() def false_fn(x): y = x + 5 y.add_(6) 
return x.sin() a = torch.cond(x.sum() > 4, true_fn, false_fn, [x]) return (a + 3, a + 4) inp = torch.randn(3, 4) gm, _ = aot_export_module(M(), (inp,), trace_joint=False) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, arg0_1): sum_1 = torch.ops.aten.sum.default(arg0_1) gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None return (add, add_1)""", # noqa: B950 ) self.assertExpectedInline( gm.true_graph_0.code.strip(), """\ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 4) add_1 = torch.ops.aten.add.Tensor(add, 5); add = add_1 = None cos = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos,)""", ) self.assertExpectedInline( gm.false_graph_0.code.strip(), """\ def forward(self, arg0_1): add = torch.ops.aten.add.Tensor(arg0_1, 5) add_1 = torch.ops.aten.add.Tensor(add, 6); add = add_1 = None sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None return (sin,)""", ) def test_aot_export_simplified_pytrees_banned(self): def fn(inps): return (inps[0] + inps[1],) inp1 = torch.randn(2) inp2 = torch.randn(2) inps = [inp1, inp2] with self.assertRaisesRegex( RuntimeError, "aot_export_joint_simple requires individual inputs not to be pytrees", ): aot_export_joint_simple(fn, [inps], trace_joint=False) aot_export_joint_simple(fn, [inps], trace_joint=True) def test_aot_export_functionalized_rng_banned(self): def fn(p, x): return (p + x,) mod = TestMod(fn) inp = torch.randn(2) with patch( "functorch.compile.config.functionalize_rng_ops", True ), self.assertRaisesRegex( RuntimeError, "Functionalized RNG is not currently supported in the aot_export", ): aot_export_joint_simple(fn, [mod.p, inp], 
trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) aot_export_module(mod, [inp], trace_joint=False) def test_aot_export_unbacked_arg(self): class M(torch.nn.Module): def forward(self): full = torch.full((), 11) i0 = full.item() return (torch.full((i0,), 0),) gm, _ = aot_export_module( mod=M(), args=(), trace_joint=False, dynamic_shapes=True ) self.assertExpectedInline( gm.code.strip(), """\ def forward(self): full = torch.ops.aten.full.default([], 11, device = device(type='cpu'), pin_memory = False) _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(full); full = None full_1 = torch.ops.aten.full.default([_local_scalar_dense], 0, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None return (full_1,)""", # noqa: B950 ) class TestPartitioning(AOTTestCase): @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_recompute_partitioning(self): def fn(a, b): return torch.sin(torch.sin(a)) + b # Reference calculation ref_a = torch.rand(10, 10, requires_grad=True) ref_b = torch.rand(10, 10, requires_grad=True) ref = fn(ref_a, ref_b) ref.sum().backward() # Compiled function calculation res_a = ref_a.clone().detach().requires_grad_(True) res_b = ref_b.clone().detach().requires_grad_(True) def compile_fn(x, _): return x compiled_fn = compiled_function( fn, compile_fn, compile_fn, min_cut_rematerialization_partition ) res = compiled_fn(res_a, res_b) res.sum().backward() assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3) assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3) def test_meta_tensor_inplace_op(self): # Following module results in inplace ops while tracing. The test checks # that the meta tensor information is stored for inplace ops. 
class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.nn.Parameter( torch.randn(3072, 768, requires_grad=True) ) self.bias = torch.nn.Parameter(torch.randn(3072, requires_grad=True)) def forward(self, add_4): linear_4 = torch.nn.functional.linear( add_4, self.weight, bias=self.bias ) gelu = torch.nn.functional.gelu(linear_4) return gelu def check_meta_tensor(fx_g, _): for node in fx_g.graph.nodes: if node.op != "output": assert "tensor_meta" in node.meta return fx_g inp0 = torch.randn(16, 128, 768, requires_grad=True) inputs = [ inp0, ] mod = MockModule().to(device="cpu") aot_mod = aot_module(mod, fw_compiler=check_meta_tensor) aot_mod(*inputs) def test_default_partitioner_getitem(self): mod = nn.LayerNorm([10]) def f(x, mod_weight, mod_bias): return torch.nn.functional.layer_norm( x, [10], mod_weight, mod_bias, eps=1e-6 ) fw_graph, bw_graph = get_fw_bw_graph( f, [torch.randn(3, 10, requires_grad=True), mod.weight, mod.bias], partitioner=default_partition, ) self.assertEqual(get_num_ins_outs(fw_graph), (3, 6)) self.assertEqual(get_num_ins_outs(bw_graph), (6, 3)) @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_min_cut_partitioner_save_shape(self): def f(x): s = x.sum(dim=1) return s inp = [torch.ones([10, 10], requires_grad=True)] fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) _, fw_output = get_ins_outs(fw_graph) self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) self.assertEqual(str(fw_output[0]), "sum_1") # make sure we don't do the suboptimal thing of saving the bigger primals input to sum, # rather than saving the sizes of the primals input for use in backward expand self.assertEqual(str(fw_output[1]), "sym_size_int") self.assertEqual(str(fw_output[2]), "sym_size_int_1") inp = [ torch.randn(10, requires_grad=True), torch.randn((3, 10), requires_grad=True), torch.randn((2, 10), requires_grad=True), ] def f(a, b, c): # tried to 
test what happens if we save a size tuple in the graph; # turns out we never will due to how we trace, but this is probably # still a good test case for various size manipulations sb = torch.ops.aten.sym_size(b) sc = c.size() x = sb[0] + sc[0] a_sz = (x, a.size(0)) return torch.cat([a.expand(a_sz), b, c]) fw_graph, bw_graph = get_fw_bw_graph(f, inp, dynamic=True) self.assertEqual(get_num_ins_outs(fw_graph), (3, 4)) self.assertEqual(get_num_ins_outs(bw_graph), (4, 3)) _, outs = get_ins_outs(fw_graph) self.assertTrue(all(is_sym_node(n) for n in outs[1:])) def test_default_partitioner_output_tensor_shape_tensor(self): inp = [ torch.randn(10, requires_grad=True), torch.randn((3, 10), requires_grad=True), torch.randn((2, 10), requires_grad=True), torch.randn((10, 1), requires_grad=True), ] def f(a, b, c, d): # Try to force symints intermixed with outputs in the function's returns sb = b.size() sc = c.size() x = sb[0] + sc[0] a_sz = (x, a.size(0)) cat = torch.cat([a.expand(a_sz), b, c]) mm = torch.mm(cat, d) mm2 = torch.mm( mm, a.view(mm.size(1), a.size(0)) ) # this saves 4 new ints for backward. why? # and what do i have to do to make it save a tensor for backward? 
return cat, sb, c, mm2 fw_graph_cell = [None] bw_graph_cell = [None] compiled_outs = aot_function( f, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=default_partition, decompositions=default_decompositions, dynamic=True, )(*inp) fw_graph = fw_graph_cell[0] (compiled_outs[0].sum() + compiled_outs[2].sum()).backward() bw_graph = bw_graph_cell[0] # in the fwd graph, 13 outs because: # - 5 original outputs (sb is a tuple, gets expanded to 2 symints) # - 8 saved outputs for backward: 5 tensors, 3 symints self.assertEqual(get_num_ins_outs(fw_graph), (4, 13)) # in the bwd graph, 10 inputs (grad outs) because: # - The fwd graph had 13 outputs # - 1 was a view of an input, which gets regenerated outside of the graph # and doesn't participate in the backward # - 2 user outs were symints (b.size()), which don't get tangents in the backward self.assertEqual(get_num_ins_outs(bw_graph), (10, 4)) _, fw_graph_out_nodes = get_ins_outs(fw_graph) self.assertEqual( # fw outputs include b.size() which expands to 2 symints, # # TODO(whc)- are the saved-tensors/saved-symints correct here? 
# i just made the test pass based on what default partition did # Of the 5 original forward outputs, the 4th (c) is an input, # which won't show up in the compiled forward graph [False, True, True, False, False] + [False] * 4 + [True] * 4, [is_sym_node(n) for n in fw_graph_out_nodes], ) real_outs = f(*inp) self.assertEqual(compiled_outs, real_outs) self.assertTrue(isinstance(real_outs[1], torch.Size)) # TODO(whc) we should learn to return torch.Sizes self.assertFalse(isinstance(compiled_outs[1], torch.Size)) @unittest.skipIf(not USE_NETWORKX, "networkx not available") def test_min_cut_partitioner_output_tensor_shape_tensor(self): inp = [ torch.randn(10, requires_grad=True), torch.randn((3, 10), requires_grad=True), torch.randn((2, 10), requires_grad=True), torch.randn((10, 1), requires_grad=True), ] def f(a, b, c, d): # Try to force symints intermixed with outputs in the function's returns sb = b.size() sc = c.size() x = sb[0] + sc[0] a_sz = (x, a.size(0)) cat = torch.cat([a.expand(a_sz), b, c]) mm = torch.mm(cat, d) mm2 = torch.mm( mm, a.view(mm.size(1), a.size(0)) ) # this saves 4 new ints for backward. why? # and what do i have to do to make it save a tensor for backward? 
# --- continuation of test_min_cut_partitioner_output_tensor_shape_tensor (header on the preceding line) ---
            return cat, sb, c, mm2

        fw_graph_cell = [None]
        bw_graph_cell = [None]
        # Compile with the min-cut partitioner and capture the partitioned
        # forward/backward graphs into the cells via extract_graph.
        compiled_outs = aot_function(
            f,
            fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
            bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
            partition_fn=min_cut_rematerialization_partition,
            decompositions=default_decompositions,
            dynamic=True,
        )(*inp)
        fw_graph = fw_graph_cell[0]
        # Only outputs 0 (cat) and 2 (c) feed the loss; running backward is what
        # populates bw_graph_cell.
        (compiled_outs[0].sum() + compiled_outs[2].sum()).backward()
        bw_graph = bw_graph_cell[0]

        # (inputs, outputs) of each graph. 12 forward outputs = 5 + 4 + 3, the
        # same split as the is_sym_node pattern asserted below.
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 12))
        self.assertEqual(get_num_ins_outs(bw_graph), (9, 4))
        _, fw_graph_out_nodes = get_ins_outs(fw_graph)
        self.assertEqual(
            # fw outputs include b.size() which expands to 2 symints,
            # then 4 tensors (transposes of matrices used for mm) are saved
            # finally 3 symints are saved
            [False, True, True, False, False] + [False] * 4 + [True] * 3,
            [is_sym_node(n) for n in fw_graph_out_nodes],
        )

        # Numerics must match eager execution.
        real_outs = f(*inp)
        self.assertEqual(compiled_outs, real_outs)
        self.assertTrue(isinstance(real_outs[1], torch.Size))

        # TODO(whc) we should learn to return torch.Sizes
        self.assertFalse(isinstance(compiled_outs[1], torch.Size))

    @unittest.skipIf(not USE_NETWORKX, "networkx not available")
    def test_min_cut_partitioner(self):
        # Checks the (inputs, outputs) counts of the forward and backward graphs
        # that get_fw_bw_graph produces for two small pointwise programs.
        def f(x):
            return x.cos().cos().cos()

        fw_graph, bw_graph = get_fw_bw_graph(f, [torch.randn(3, requires_grad=True)])
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))

        # Four inputs summed before the pointwise chain: the forward still has
        # 2 outputs, and the backward has 4 outputs (one per forward input).
        def f(a, b, c, d):
            x = a + b + c + d
            return x.cos().cos()

        fw_graph, bw_graph = get_fw_bw_graph(
            f, [torch.randn(3, requires_grad=True) for _ in range(4)]
        )
        self.assertEqual(get_num_ins_outs(fw_graph), (4, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 4))

    def test_contiguous(self):
        # The test simulates the condition where transpose followed by view
        # happens in the backward pass.
# https://discuss.pytorch.org/t/error-on-transpose-and-view/434 def f(x): return x.view(2, 3).t() inp = torch.randn(6, requires_grad=True) out = aot_function(f, nop)(inp) torch.autograd.grad(out, inp, torch.randn(3, 2)) def test_preserve_random(self): def fn(x): return torch.nn.functional.dropout(x, 0.5) + x x = torch.randn(4) torch.manual_seed(0) ref = fn(x) torch.manual_seed(0) aot_fn = aot_function(fn, nop) res = aot_fn(x) assert torch.allclose(ref, res) # https://github.com/pytorch/pytorch/issues/110666 def test_generate_gives_inference_graph(self): # We expect this to give an inference graph def generate(x): with torch.no_grad(): return torch.mul(x, x) inference_graph_cell = [None] inference_compiler = make_boxed_compiler( partial(extract_graph, graph_cell=inference_graph_cell) ) aot_fn = aot_function(generate, nop, inference_compiler=inference_compiler) # Even though x requires grad, we should still get an inference graph x = torch.randn(4, requires_grad=True) res = aot_fn(x) self.assertTrue(inference_graph_cell[0] is not None) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") def test_autocast(self): mod = torchvision.models.resnet18().cuda() mod.train() x = torch.randn(16, 3, 32, 32, device="cuda") aot_mod = memory_efficient_fusion(mod) # Ensure that AOT Autograd works with AMP with torch.cuda.amp.autocast(True): res = aot_mod(x) res.sum().backward() class TestAOTDispatch(AOTTestCase): # Tests to add cases for (non-exhaustive list, mostly for my notes): # - subclass / mode introduced in the middle of the compiled fn # - various input mutation / intermediate base tests # - input mutation that changes a tensor into a subclass # - metadata mutation? 
(TBD) # - guard tests (fw guards *and* bw guards) # - subclass test involving _indices_of_inps_to_detach def test_aot_dispatch_simple(self): # a is a subclass, b is not def f(a, b): aa = torch.mul(a, 6) bb = torch.div(b, 2) return aa + bb a1_ref = torch.ones(3, 3, requires_grad=True) a2_ref = torch.ones(3, 3, requires_grad=True) a_ref = TwoTensor(a1_ref, a2_ref) b_ref = torch.ones(3, 3, requires_grad=True) a1_test = a1_ref.clone().detach().requires_grad_(True) a2_test = a2_ref.clone().detach().requires_grad_(True) a_test = TwoTensor(a1_test, a2_test) b_test = b_ref.clone().detach().requires_grad_(True) fw_graph_cell = [None] bw_graph_cell = [None] compiled_f = aot_function( f, fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell), bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell), partition_fn=min_cut_rematerialization_partition, ) out_ref = f(a_ref, b_ref) out_test = compiled_f(a_test, b_test) # Output is a TwoTensor (check both inner tensors) self.assertEqual(out_ref.a, out_test.a) self.assertEqual(out_ref.b, out_test.b) out_ref.sum().backward() out_test.sum().backward() # Both grad_inputs are TwoTensor self.assertEqual(a_ref.grad.a, a_test.grad.a) self.assertEqual(a_ref.grad.b, a_test.grad.b) self.assertEqual(b_ref.grad.a, b_test.grad.a) self.assertEqual(b_ref.grad.b, b_test.grad.b) # Important pieces of the graph: # - mul() and div() show up twice, because we called them on a TwoTensor # - add() shows up once, because we called it on a plain Tensor # - The user forward() fn returns 1 output (the result of add), # while the graph itself returns two outputs (add, add_1) # - add, add_1 correspond to the two inner dense tensors that will be wrapped # - into a single TwoTensor output. 
# --- continuation of TestAOTDispatch.test_aot_dispatch_simple (graph assertions) ---
        # NOTE(review): in this copy of the file the two expected-graph literals
        # below appear whitespace-collapsed (every statement on one line, and a
        # backslash followed by a space instead of a backslash-newline). They
        # likely no longer match the multi-line code FX generates; restore them
        # from upstream and confirm.
        self.assertExpectedInline(
            fw_graph_cell[0].code.strip(),
            """\ def forward(self, primals_1, primals_2, primals_3): mul = torch.ops.aten.mul.Tensor(primals_1, 6); primals_1 = None mul_1 = torch.ops.aten.mul.Tensor(primals_2, 6); primals_2 = None div = torch.ops.aten.div.Tensor(primals_3, 2); primals_3 = None add = torch.ops.aten.add.Tensor(mul, div); mul = None add_1 = torch.ops.aten.add.Tensor(mul_1, div); mul_1 = div = None return (add, add_1)""",
        )

        # Important pieces of the graph:
        # - 4 total dense outputs.
        #   This corresponds to the fact that each user fwd input (a, b)
        #   will get a gradient that is a TwoTensor subclass,
        #   so (mul_2, mul_3) will be wrapped into a.grad
        #   and (div_1, div_2) will be wrapped into b.grad
        # - 2 dense inputs (tangents_1, tangents_2): presumably the two inner
        #   tensors of the single TwoTensor grad_output (confirm).
        self.assertExpectedInline(
            bw_graph_cell[0].code.strip(),
            """\ def forward(self, tangents_1, tangents_2): div_1 = torch.ops.aten.div.Tensor(tangents_1, 2) div_2 = torch.ops.aten.div.Tensor(tangents_2, 2) mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6); tangents_1 = None mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6); tangents_2 = None return (mul_2, mul_3, div_1, div_2)""",
        )

    def test_aot_dispatch_inference(self):
        # a is a subclass, b is not
        def f(a, b):
            aa = torch.mul(a, 6)
            bb = torch.div(b, 2)
            return aa + bb

        # No input requires grad here, so only forward outputs are compared.
        a1_ref = torch.ones(3, 3)
        a2_ref = torch.ones(3, 3)
        a_ref = TwoTensor(a1_ref, a2_ref)
        b_ref = torch.ones(3, 3)

        a1_test = a1_ref.clone()
        a2_test = a2_ref.clone()
        a_test = TwoTensor(a1_test, a2_test)
        b_test = b_ref.clone()

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)

        # Output is a TwoTensor (check both inner tensors)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

    def test_aot_dispatch_incorrect_backward(self):
        # a is a subclass, b is not
        def f(a, b):
            aa = torch.mul(a, 2)
            bb = torch.add(b, 3)
            out_subclass = torch.div(aa, bb)
            out_reg = torch.add(b, b)
            # When creating the joint, we assume that the second grad_out
            # is not a subclass.
            # In the below test case though, we end up being wrong.
            # This would require re-tracing and recompiling the backward.
            return out_subclass, out_reg

        a1_ref = torch.ones(3, 3, requires_grad=True)
        a2_ref = torch.ones(3, 3, requires_grad=True)
        a_ref = TwoTensor(a1_ref, a2_ref)
        b_ref = torch.ones(3, 3, requires_grad=True)

        a1_test = a1_ref.clone().detach().requires_grad_(True)
        a2_test = a2_ref.clone().detach().requires_grad_(True)
        a_test = TwoTensor(a1_test, a2_test)
        b_test = b_ref.clone().detach().requires_grad_(True)

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)
        # First out is a TwoTensor, second is an ordinary tensor
        self.assertEqual(out_ref[0].a, out_test[0].a)
        self.assertEqual(out_ref[0].b, out_test[0].b)
        self.assertEqual(out_ref[1], out_test[1])

        # We compiled our graph assuming type(grad_out[1]) == torch.Tensor,
        # but we were wrong: in the below tests, it is a subclass.
# --- continuation of TestAOTDispatch.test_aot_dispatch_incorrect_backward ---
        # This will eventually require a repartition + recompile
        # Summing out_test[0] (a TwoTensor) with out_test[1] makes the grad for
        # out_test[1] a subclass too, contradicting the compile-time assumption.
        with self.assertRaisesRegex(
            AssertionError,
            "incorrectly attempted to compile the backward with incorrect subclass metadata",
        ):
            (out_test[0] + out_test[1]).sum().backward()

    def test_aot_dispatch_output_alias(self):
        # a is a tensor, b is a TwoTensor
        # The first output is a view (alias) of input b.
        def f(a, b):
            return b.view(b.shape), a * b

        b1_ref = torch.ones(3, 3, requires_grad=True)
        b2_ref = torch.ones(3, 3, requires_grad=True)
        b_ref = TwoTensor(b1_ref, b2_ref)
        a_ref = torch.ones(3, 3, requires_grad=True)

        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test = TwoTensor(b1_test, b2_test)
        a_test = a_ref.clone().detach().requires_grad_(True)

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref1, out_ref2 = f(a_ref, b_ref)
        out_test1, out_test2 = compiled_f(a_test, b_test)
        self.assertEqual(out_ref1, out_test1)
        self.assertEqual(out_ref2.a, out_test2.a)
        self.assertEqual(out_ref2.b, out_test2.b)

        (out_ref1 + out_ref2).sum().backward()
        (out_test1 + out_test2).sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref.grad.a, a_test.grad.a)
        self.assertEqual(a_ref.grad.b, a_test.grad.b)
        self.assertEqual(b_ref.grad.a, b_test.grad.a)
        self.assertEqual(b_ref.grad.b, b_test.grad.b)

    def test_aot_dispatch_input_mutation(self):
        # Both inputs are mutated in place; b is a TwoTensor, a is a plain Tensor.
        def f(a, b):
            a.mul_(2)
            b.mul_(3)
            return a + b

        b1_ref = torch.ones(3, 3, requires_grad=True)
        b2_ref = torch.ones(3, 3, requires_grad=True)
        b_ref_base = TwoTensor(b1_ref, b2_ref)
        a_ref_base = torch.ones(3, 3, requires_grad=True)
        # Inputs are non-leaf (base + 1), presumably so the in-place mul_ calls
        # are allowed by autograd; gradients are then checked on the bases.
        b_ref = b_ref_base + 1
        a_ref = a_ref_base + 1

        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test_base = TwoTensor(b1_test, b2_test)
        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
        b_test = b_test_base + 1
        a_test = a_test_base + 1

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

        # confirm input mutations worked
        self.assertEqual(a_test, a_ref)
        self.assertEqual(b_test.a, b_ref.a)
        self.assertEqual(b_test.b, b_ref.b)

        # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward.
        (b_ref * out_ref).sum().backward()
        (b_test * out_test).sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)

    # NB: Metadata mutation for subclasses is currently broken and disabled
    # See https://github.com/pytorch/pytorch/issues/114975
    @unittest.expectedFailure
    def test_aot_dispatch_input_metadata_mutation(self):
        # Metadata-only in-place mutations: transpose a, add a leading dim to b.
        def f(a, b):
            a.t_()
            b.unsqueeze_(0)
            return a + b

        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b_ref_base = TwoTensor(b1_ref, b2_ref)
        a_ref_base = (
            torch.arange(9, dtype=torch.float32)
            .reshape(3, 3)
            .detach()
            .requires_grad_(True)
        )
        b_ref = b_ref_base + 1
        a_ref = a_ref_base + 1

        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test_base = TwoTensor(b1_test, b2_test)
        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
        b_test = b_test_base + 1
        a_test = a_test_base + 1

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref = f(a_ref, b_ref)
        out_test = compiled_f(a_test, b_test)
        self.assertEqual(out_ref.a, out_test.a)
        self.assertEqual(out_ref.b, out_test.b)

        # confirm input mutations worked
        self.assertEqual(a_test, a_ref)
        self.assertEqual(b_test.a, b_ref.a)
self.assertEqual(b_test.b, b_ref.b) # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. (b_ref * out_ref).sum().backward() (b_test * out_test).sum().backward() # Both grad_inputs are TwoTensor self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a) self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b) self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a) self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b) # NB: Metadata mutation for subclasses is currently broken and disabled # See https://github.com/pytorch/pytorch/issues/114975 @unittest.expectedFailure def test_aot_dispatch_input_data_and_metadata_mutation(self): def f(a, b): a.t_() b.unsqueeze_(0) a.mul_(2) b.mul_(3) return a + b b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3) b_ref_base = TwoTensor(b1_ref, b2_ref) a_ref_base = ( torch.arange(9, dtype=torch.float32) .reshape(3, 3) .detach() .requires_grad_(True) ) b_ref = b_ref_base + 1 a_ref = a_ref_base + 1 b1_test = b1_ref.clone().detach().requires_grad_(True) b2_test = b2_ref.clone().detach().requires_grad_(True) b_test_base = TwoTensor(b1_test, b2_test) a_test_base = a_ref_base.clone().detach().requires_grad_(True) b_test = b_test_base + 1 a_test = a_test_base + 1 compiled_f = aot_function( f, fw_compiler=nop, bw_compiler=nop, partition_fn=min_cut_rematerialization_partition, ) out_ref = f(a_ref, b_ref) out_test = compiled_f(a_test, b_test) self.assertEqual(out_ref.a, out_test.a) self.assertEqual(out_ref.b, out_test.b) # confirm input mutations worked self.assertEqual(a_test, a_ref) self.assertEqual(b_test.a, b_ref.a) self.assertEqual(b_test.b, b_ref.b) # NOTE: we need to use b in our gradient compute. Otherwise we will need to recompile the backward. 
# --- continuation of TestAOTDispatch.test_aot_dispatch_input_data_and_metadata_mutation ---
        (b_ref * out_ref).sum().backward()
        (b_test * out_test).sum().backward()
        # Both grad_inputs are TwoTensor
        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
        self.assertEqual(b_ref_base.grad.a, b_test_base.grad.a)
        self.assertEqual(b_ref_base.grad.b, b_test_base.grad.b)

    def test_aot_dispatch_input_mutation_and_output_alias(self):
        # Inputs are mutated in place and the first output is a view of the
        # mutated input b.
        def f(a, b):
            a.mul_(2)
            b.mul_(3)
            return b.view(b.shape), a + b

        b1_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b2_ref = torch.arange(9, requires_grad=True, dtype=torch.float32).reshape(3, 3)
        b_ref_base = TwoTensor(b1_ref, b2_ref)
        a_ref_base = (
            torch.arange(9, dtype=torch.float32)
            .reshape(3, 3)
            .detach()
            .requires_grad_(True)
        )
        b_ref = b_ref_base + 1
        a_ref = a_ref_base + 1

        b1_test = b1_ref.clone().detach().requires_grad_(True)
        b2_test = b2_ref.clone().detach().requires_grad_(True)
        b_test_base = TwoTensor(b1_test, b2_test)
        a_test_base = a_ref_base.clone().detach().requires_grad_(True)
        b_test = b_test_base + 1
        a_test = a_test_base + 1

        compiled_f = aot_function(
            f,
            fw_compiler=nop,
            bw_compiler=nop,
            partition_fn=min_cut_rematerialization_partition,
        )
        out_ref1, out_ref2 = f(a_ref, b_ref)
        out_test1, out_test2 = compiled_f(a_test, b_test)
        self.assertEqual(out_ref1.a, out_test1.a)
        self.assertEqual(out_ref1.b, out_test1.b)
        self.assertEqual(out_ref2.a, out_test2.a)
        self.assertEqual(out_ref2.b, out_test2.b)

        # confirm input mutations worked
        self.assertEqual(a_test, a_ref)
        self.assertEqual(b_test.a, b_ref.a)
        self.assertEqual(b_test.b, b_ref.b)

        (out_ref1 * out_ref2).sum().backward()
        (out_test1 * out_test2).sum().backward()
        # Both grad_inputs are TwoTensors
        self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
        self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)

    def test_aot_dispatch_output_requires_grad_in_no_grad(self):
        # out2 is computed under enable_grad inside fn, so it can require grad
        # even when fn is called under no_grad; the compiled fn must match eager.
        def fn(x):
            out1 = x.sin()
            with torch.enable_grad():
                out2 = x.cos()
            return out1, out2

        inp_fns = [
            lambda: torch.ones(10, requires_grad=True),
            lambda: torch.ones(10, requires_grad=False),
        ]

        compiled_f = aot_function(fn, nop)
        for inp_fn in inp_fns:
            with torch.no_grad():
                ref_x = inp_fn()
                ref_out = fn(ref_x)
                x = inp_fn()
                out = compiled_f(x)
                for r, o in zip(ref_out, out):
                    self.assertEqual(r.requires_grad, o.requires_grad)
            if ref_x.requires_grad:
                with torch.enable_grad():
                    (ref_out[0] + ref_out[1]).sum().backward()
                    (out[0] + out[1]).sum().backward()
                    self.assertEqual(ref_x.grad, x.grad)
                    # Explicit-tolerance variant of the check above.
                    assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)

    def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
        # view-type ops preserve requires_grad even in no_grad.
        def fn(x):
            return x.view(-1), x.sin()

        inference_graph_cell = [None]
        inference_compiler = make_boxed_compiler(
            partial(extract_graph, graph_cell=inference_graph_cell)
        )
        compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)

        inp_x0 = torch.ones(2, 3, requires_grad=True)
        # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
        ref_x0 = inp_x0.clone()
        x0 = inp_x0.clone()
        with torch.no_grad():
            ref_out1, ref_out2 = fn(ref_x0)

            out1, out2 = compiled_fn(x0)
            # Assert that we executed inference graph
            self.assertTrue(inference_graph_cell[0] is not None)

            self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
            self.assertEqual(ref_out2.requires_grad, out2.requires_grad)


class TestAOTModuleSimplified(AOTTestCase):
    def test_aot_module_simplified(self):
        # Compares a module run eagerly against aot_module_simplified with the
        # nop compiler: outputs and input gradients must match.
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(20, 30)

            def forward(self, x, y):
                return (self.linear(x) + y,)

        mod = MockModule()
        mod.zero_grad()

        x = torch.randn(128, 20, requires_grad=True)
        y = torch.randn(128, 30, requires_grad=True)
        inputs = [x, y]
        cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs]

        ref = mod(*inputs)
        ref[0].sum().backward()

        compiled_f = aot_module_simplified(mod, cloned_inputs, nop)
        mod.zero_grad()
        res = compiled_f(*cloned_inputs)
        res[0].sum().backward()
assert torch.allclose(ref[0], res[0]) assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) def test_aot_module_simplified_dynamic(self): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(20, 30) def forward(self, x, y): return (self.linear(x) + y,) mod = MockModule() shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) x = torch.randn(128, 20, requires_grad=True) y = torch.randn(128, 30, requires_grad=True) inputs = [x, y] fake_inputs = [fake_mode.from_tensor(x) for x in inputs] compiled_f = aot_module_simplified(mod, fake_inputs, nop) ref = mod(*inputs) ref[0].sum().backward() cloned_inputs = [x.detach().clone().requires_grad_(True) for x in inputs] res = compiled_f(*cloned_inputs) res[0].sum().backward() self.assertExpectedInline( shape_env.format_guards(), """\ - Eq(s1, 20) - Eq(s2, 30)""", ) assert torch.allclose(ref[0], res[0]) assert torch.allclose(inputs[0].grad, cloned_inputs[0].grad) assert torch.allclose(inputs[1].grad, cloned_inputs[1].grad) # https://github.com/pytorch/pytorch/issues/105327 def test_lift_fresh_copy_in_graph(self): class MyMod(torch.nn.Module): def forward(self, x): _tensor_constant0 = torch.tensor([1]) lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default( _tensor_constant0 ) y = x.mul(lift_fresh_copy) return (y,) mod = MyMod() shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) x = torch.ones(4, requires_grad=True) inputs = [x] fake_inputs = [fake_mode.from_tensor(x) for x in inputs] compiled_f = aot_module_simplified(mod, fake_inputs, nop) out_ref = mod(x) out_test = compiled_f(x) self.assertEqual(out_ref[0].detach(), out_test[0].detach()) def test_inference_python_dispatcher(self): # Extracted from unet class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.upsample = torch.nn.Upsample( scale_factor=2, mode="bilinear", 
align_corners=True ) def forward(self, x): return (self.upsample(x),) mod = MockModule() shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) x = torch.randn(2, 512, 40, 59) # NB: must not require grad inputs = [x] fake_inputs = [fake_mode.from_tensor(x) for x in inputs] compiled_f = aot_module_simplified(mod, fake_inputs, nop) def test_aot_module_simplified_preserves_stack_trace(self): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(20, 30) def forward(self, x, y): z = self.linear(x) z = z + y z = z.relu() return (z,) tracer = torch.fx.Tracer() tracer.record_stack_traces = True graph = tracer.trace(MockModule()) mod = torch.fx.GraphModule(tracer.root, graph) for node in mod.graph.nodes: if node.op == "output": continue self.assertTrue(node.stack_trace is not None) assert "test_aotdispatch.py" in node.stack_trace def assert_compiler(gm: torch.fx.GraphModule, _): for node in gm.graph.nodes: if node.op == "output" or node.op == "placeholder": continue self.assertTrue(node.stack_trace is not None) assert "test_aotdispatch.py" in node.stack_trace return gm.forward # return a python callable x = torch.randn(128, 20, requires_grad=True) y = torch.randn(128, 30, requires_grad=True) inputs = [x, y] compiled_f = aot_module_simplified( mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler ) res = compiled_f(*inputs) res[0].sum().backward() def test_aot_module_simplified_preserves_stack_trace_from_mutation(self): class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): x_view = x[0] x_view.mul_(2) return (x + x,) tracer = torch.fx.Tracer() tracer.record_stack_traces = True graph = tracer.trace(MockModule()) mod = torch.fx.GraphModule(tracer.root, graph) for node in mod.graph.nodes: if node.op == "output": continue self.assertTrue(node.stack_trace is not None) assert "test_aotdispatch.py" in node.stack_trace def assert_compiler(gm: 
torch.fx.GraphModule, _): assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes] for node in gm.graph.nodes: if node.target == torch.ops.aten.copy_.default: assert "stack_trace" in node.meta assert "x_view.mul_(2)" in node.meta["stack_trace"] return gm.forward # return a python callable x = torch.randn(128, 20) inputs = [x] aot_module_simplified( mod, inputs, fw_compiler=assert_compiler, bw_compiler=assert_compiler, keep_inference_input_mutations=True, ) def test_aot_module_simplified_fake_tensor_gm_raises(self): fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() real_x = torch.randn(4, requires_grad=True) fake_x = fake_mode.from_tensor(real_x) real_z = torch.randn(4) fake_z = fake_mode.from_tensor(real_z) class MockModule(torch.nn.Module): def forward(self, x): # Accessing a free variable fake tensor will look like a # constant to make_fx, and result in the tensor being traced # into the graph, which is an error condition. Make sure we # report adequately in this case. return (x + fake_z,) with self.assertRaisesRegex(AssertionError, "Unexpected fake"): aot_module_simplified(MockModule(), (fake_x,), nop) def test_aot_test_subclasses_with_tensor_factories(self): from torch.testing._internal.common_subclass import SubclassWithTensorFactory inp = SubclassWithTensorFactory(torch.zeros(3, 5)) def fn(x): return 2 * x ref_out = fn(inp) out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp) self.assertEqual(ref_out, out) # entries in here don't work and need to be fixed. # Each one of these is a bug (or needs to be investigated) aot_autograd_failures = { # data-dependent control flow xfail("cov"), xfail("nn.functional.gaussian_nll_loss"), xfail("tensor_split"), xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), xfail("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), skip("as_strided", "partial_views"), # flaky # Given input size: (s0xs1x2). Calculated output size: ... 
skip("max_pool2d_with_indices_backward"), skip("nn.functional.nll_loss", ""), # UBSAN failure! # Misc xfail("to_sparse"), xfail("corrcoef"), xfail("cov"), xfail("chalf"), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf' xfail("sparse.sampled_addmm"), xfail("sparse.mm", "reduce"), skip("nn.functional.binary_cross_entropy_with_logits"), # seems to fail sometimes? skip("nn.functional.margin_ranking_loss"), # seems flaky skip("linalg.lu_solve"), # flaky decorate("matmul", decorator=unittest.skipIf(IS_ARM64, "flaky")), decorate("__rmatmul__", decorator=unittest.skipIf(IS_ARM64, "flaky")), # overrides atol=1e-4, rtol=1e-5 would do as well decorate( "svd_lowrank", decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}), ), decorate( "linalg.householder_product", decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"), ), decorate( "linalg.pinv", "singular", decorator=toleranceOverride({torch.float32: tol(atol=1e-05, rtol=1e-05)}), ), decorate( "nn.functional.interpolate", "bicubic", decorator=toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-05)}), ), # conv2d sometimes nondeterministic in this config? decorate("nn.functional.conv2d", decorator=unittest.skipIf(IS_ARM64, "flaky")), } symbolic_aot_autograd_failures = { xfail("combinations", ""), # aten.masked_select.default xfail( "index_fill", "" ), # Cannot call sizes() on tensor with symbolic sizes/strides xfail( "linalg.lstsq", "" ), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition xfail( "linalg.lstsq", "grad_oriented" ), # aten.linalg_lstsq.default - couldn't find symbolic meta funct... xfail( "linalg.lu_solve", "" ), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco... 
skip( "nn.functional.batch_norm", "" ), # '0 is not tracked with proxy for None: self.cache = {} def save(self, key, gm): self.cache[key] = gm def load(self, gm, inputs): key, _ = compiled_fx_graph_hash(gm, inputs, {}, {}) if key in self.cache: gm = make_boxed_func(gm) gm._fx_graph_cache_key = key return gm else: self.save(key, gm) gm = make_boxed_func(gm) gm._fx_graph_cache_key = key return gm def _lookup_graph(self, key, inputs, local, remote_cache): gm = self.cache.get(key) if gm is not None: gm = make_boxed_func(gm) return gm def post_compile(self, gm, inputs, cudagraphs): pass # The following tests fail in strict caching mode (i.e. they bypass or # cache miss instead of cache hitting). They will be fixed in the PRs above this. FAILING_CACHE_TESTS = ( # BypassAOTAutogradCache: unsupported nodes "test_backward_mutation_data", # Custom Autograd Function "test_backward_mutation_metadata", # Custom Autograd Function "test_custom_autograd", # Custom Autograd Function "test_input_output_aliase_custom_autograd_function", ) @xfail_inherited_tests(FAILING_CACHE_TESTS) class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo): """ In memory version of FXGraphCache so we can isolate testing for FXGraphCache """ def make_compiler(self, fw_graph_cell): mock_inductor_cache = self.inductor_cache def compiler(gm, inputs): nonlocal mock_inductor_cache, fw_graph_cell result = mock_inductor_cache.load(gm, inputs) fw_graph_cell[0] = gm return result return compiler def run_autograd( self, f: Callable, fw_graph_cell: List[Optional[Callable]], decompositions: Optional[Dict], keep_input_mutations: bool, dynamic: bool, ): return super().run_autograd( f, fw_graph_cell, decompositions, keep_input_mutations, dynamic, ) @torch._functorch.config.patch( { "enable_autograd_cache": True, "strict_autograd_cache": True, "view_replay_for_aliased_outputs": False, } ) @torch._inductor.config.patch("fx_graph_cache", True) def verify_aot_autograd( self, f, inp_: Union[Callable, List[Any]], *, 
test_mutation: bool = False, keep_inp_mutations: bool = False, decompositions: Optional[Dict] = None, dynamic: bool = False, # Only active when inp_ is Callable. # TODO: probably consolidate all tests to make inp a Callable. make_inputs_subclasses: bool = False, ): self.inductor_cache = MockFXGraphCache() AOTAutogradCache.clear() with patch( "torch._inductor.codecache.FxGraphCache._lookup_graph", new=self.inductor_cache._lookup_graph, ), patch( "torch._inductor.codecache.FxGraphCache.post_compile", new=self.inductor_cache.post_compile, ): return super().verify_aot_autograd( f, inp_, test_mutation=test_mutation, keep_inp_mutations=keep_inp_mutations, decompositions=decompositions, dynamic=dynamic, make_inputs_subclasses=make_inputs_subclasses, ) def test_input_mutation_false_aliasing(self): # This test is disabled because it fails in strict cache mode # But also can't be xfailed because it causes undefined behavior for # ASAN self.skipTest("Skipping because it fails in strict cache mode") if __name__ == "__main__": run_tests()