# Owner(s): ["module: dynamo"] """ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_export_persist_assert) """ import copy import functools import inspect import io import operator import unittest from enum import Enum from typing import Dict, List, Sequence from unittest.mock import patch import torch import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing from functorch.experimental.control_flow import cond from torch._dynamo import config from torch._dynamo.exc import UserError from torch._dynamo.testing import normalize_gm from torch._higher_order_ops.out_dtype import out_dtype from torch._subclasses import fake_tensor from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, ShapeEnv, StatelessSymbolicContext, ) from torch.testing._internal import common_utils from torch.testing._internal.common_cuda import TEST_CUDA class ExportTests(torch._dynamo.test_case.TestCase): # TODO(voz): Refactor to a shared test function. # The tests in this file are a little redundant, # They all take a func, run it with eager, then export it, then compare def test_export(self): def pre_attention_state_ops(input, mems, state): lc_key = state[0] lc_val = state[1] bar = [] for i in range(0, 4): bar2 = [] for j in range(0, 3): bar2.append( lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1]) ) bar.append(bar2) return bar def func(): mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]]) state = [ torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]), ] i = torch.tensor( [ [0.0313, -0.1487, -0.3846, -0.5321], [-1.7073, 1.3331, -0.0890, -1.4935], [-0.8314, -0.1862, -0.5935, 1.5232], ] ) return pre_attention_state_ops(i, mems, state) opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) real_result = opt_func() torch._dynamo.reset() exported = torch._dynamo.export(func)() out_graph = exported[0] dynamo_result = out_graph() self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) def test_no_tensor_computation_fail(self): with self.assertRaisesRegex( AssertionError, "Failed to produce a graph", ): inp = [torch.randn(3)] inp2 = 2 inps = [inp, inp2] def func(x, y): return x exported = torch._dynamo.export(func, same_signature=False)(*inps) def test_no_tensor_computation(self): inp = [torch.randn(3)] inp2 = 2 inps = [inp, inp2] def func(x, y): return x opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) real_result = opt_func(*inps) torch._dynamo.reset() exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) self.assertExpectedInline( out_graph.code.strip(), """\ def forward(self, x, y): arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec) x = arg0 return pytree.tree_unflatten([x], self._out_spec)""", ) def test_no_tensor_computation_2(self): inp = torch.randn(3) inp2 = 2 inps = [inp, inp2] def func(x, y): return y opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) real_result = opt_func(*inps) torch._dynamo.reset() exported = torch._dynamo.export(func)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) self.assertExpectedInline( out_graph.code.strip(), """\ def forward(self, 

    def test_no_tensor_computation_2(self):
        inp = torch.randn(3)
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    x = arg0
    return pytree.tree_unflatten([2], self._out_spec)""",
        )

    def test_export_mismatched_out(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_shape_control_flow_1(self):
        def func(x):
            if x.shape[0] > 10:
                return x.cos()
            return x.sin()

        opt_func = torch._dynamo.optimize("eager")(func)
        real_result = opt_func(torch.ones(6, 4))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.ones(6, 4))
        out_graph, out_guards = exported

        dynamo_result = out_graph(torch.ones(6, 4))

        from torch._guards import GuardSource

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        hit = False
        for guard in out_guards:
            if guard.source == GuardSource.SHAPE_ENV:
                hit = True
                self.assertExpectedInline(
                    guard.code_list,
                    """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""",  # noqa: B950
                )
                break

        self.assertTrue(hit)

    def test_export_control_flow_with_getattr(self):
        class Animal(Enum):
            COW = "moo"

        class MyModule(torch.nn.Module):
            def __init__(self, a):
                super().__init__()
                self.a = a

            def forward(self, x):
                if self.a == Animal.COW.value:
                    return x * x
                else:
                    raise ValueError("bad")

        module = MyModule("moo")
        input = (torch.ones(4, 3),)
        resA = module(*input)
        graph, _ = torch._dynamo.export(module)(*input)
        resB = graph(*input)
        self.assertTrue(torch._dynamo.utils.same(resA, resB))

    def test_export_graph_bypass(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_unpack(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return x[0], first * second, x[1], x[2]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_shallow_list_copy_wo_side_effects(self):
        def f(x):
            y = x.copy()
            return y[0] + y[1]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            inp
        ).graph_module
        self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp)))

    def test_export_with_shallow_list_copy_with_side_effects(self):
        def f(x):
            y = x.copy()
            x[0] = x[1]
            y.append(torch.tensor([[100]]))
            return x[0] + x[1], y[0] + y[1], y[2]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            inp
        ).graph_module
        res = gm(inp)
        ref = f(inp)
        self.assertTrue(torch._dynamo.utils.same(res, ref))
        self.assertEqual(res[0], res[1])

    def test_export_mismatched_out_2(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_list(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second, x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_complex_reorder(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[0]
            second = x[1]
            third = x[2]
            return third, first, second, first * second, first * third

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_2(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.4, 0.4])
        inps = [inp, inp2]

        def func(x, z):
            y = x + 1
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_non_tensor_arg(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return z, y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_dupes_and_bypass_with_non_tensor_output(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y[0].item(), y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_zeroes_in_and_out_different_shape_on_test(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return [[a], [b, c], [a + b], [[c + c]]]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return a[0].item() + b[0].item() + c[0].item()

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out_permute(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return b[0].item() + c[0].item() + a[0].item() + a[0].item()

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_func_return(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c

            def func2(y):
                return x * y

            return func2(x)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dict_return(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c
            return {"a": x}

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_aten_graph(self):
        def pre_attention_state_ops(input, mems, state):
            lc_key = state[0]
            lc_val = state[1]
            bar = []
            for i in range(0, 4):
                bar2 = []
                for j in range(0, 3):
                    bar2.append(
                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                    )
                bar.append(bar2)
            return bar

        def func():
            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
            state = [
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            ]
            i = torch.tensor(
                [
                    [0.0313, -0.1487, -0.3846, -0.5321],
                    [-1.7073, 1.3331, -0.0890, -1.4935],
                    [-0.8314, -0.1862, -0.5935, 1.5232],
                ]
            )
            return pre_attention_state_ops(i, mems, state)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func()

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)()
        out_graph = exported[0]

        dynamo_result = out_graph()
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_no_tensor_computation_with_aten_graph(self):
        inp = [torch.randn(3)]
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    arg0_1 = arg0
    return pytree.tree_unflatten([arg0_1], self._out_spec)""",
        )

    def test_no_tensor_computation_2_with_aten_graph(self):
        inp = torch.randn(3)
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    arg0_1 = arg0
    return pytree.tree_unflatten([2], self._out_spec)""",
        )

    def test_export_mismatched_out_with_aten_graph(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(
            torch.tensor([[[1.3737, 0.1]]])
        )
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_bypass_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_unpack_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return x[0], first * second, x[1], x[2]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_mismatched_out_2_with_aten_graph(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(
            torch.tensor([[[1.3737, 0.1]]])
        )
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_list_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second, x

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_complex_reorder_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[0]
            second = x[1]
            third = x[2]
            return third, first, second, first * second, first * third

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_2_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.4, 0.4])
        inps = [inp, inp2]

        def func(x, z):
            y = x + 1
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y, y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return z, y, y

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y[0].item(), y, z

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return [[a], [b, c], [a + b], [[c + c]]]

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_func_return_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c

            def func2(y):
                return x * y

            return func2(x)

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dict_return_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c
            return {"a": x}

        opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_stack_trace(self):
        inp = torch.randn(4, 4)

        class MyBlock(torch.nn.Module):
            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        exported = torch._dynamo.export(MyModule(), aten_graph=False)(inp)
        out_graph = exported[0]

        for node in out_graph.graph.nodes:
            if node.op not in {"placeholder", "output"}:
                self.assertTrue(node.stack_trace is not None)
                self.assertTrue(node.meta["nn_module_stack"] is not None)
                self.assertTrue(node.meta["source_fn_stack"] is not None)

        torch._dynamo.reset()

        exported = torch._dynamo.export(MyModule(), aten_graph=True)(inp)
        out_graph = exported[0]
        for node in out_graph.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.stack_trace is not None)
                self.assertTrue(node.meta["nn_module_stack"] is not None)
                self.assertTrue(node.meta["source_fn_stack"] is not None)
                self.assertTrue(node.meta["val"] is not None)
                self.assertTrue(node.meta["original_aten"] is not None)

    def test_export_preserves_nn_module_stack_for_get_attr(self):
        inp = torch.randn(4, 4)

        class MyBlock(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))
                self.buffer = torch.nn.Buffer(torch.ones(1, 1))

            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + self.weight + self.buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        m = MyModule()
        exported = torch._dynamo.export(m, aten_graph=False)(inp)
        out_graph = exported[0]

        attr_access_count = 0
        for node in out_graph.graph.nodes:
            if node.op == "get_attr":
                attr_access_count += 1
                self.assertTrue(node.meta["nn_module_stack"] is not None)
        self.assertEqual(attr_access_count, 2)

        torch._dynamo.reset()

        exported = torch._dynamo.export(m, aten_graph=True)(inp)
        out_graph = exported[0]

        attr_access_count = 0
        for node in out_graph.graph.nodes:
            if node.op == "get_attr":
                attr_access_count += 1
                self.assertTrue(node.meta["nn_module_stack"] is not None)
        self.assertEqual(attr_access_count, 2)

    def test_export_compare_optimize_with_make_fx(self):
        inp = torch.tensor([0.1, 0.1])
        linear = torch.nn.Linear(2, 2)

        def func(x):
            x = x + 1
            y = x.t()
            y = y.relu()
            y = linear(y)
            return y

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]
        export_result = out_graph(inp)

        torch._dynamo.reset()

        def compiler(gm, sample_inputs):
            def fw(*args):
                aten_gm = make_fx(gm)(*args)
                return aten_gm(*args)

            return fw

        opt_func = torch._dynamo.optimize(compiler, nopython=True, dynamic=True)(func)
        make_fx_result_through_backend = opt_func(inp)

        fx_g = make_fx(func)(inp)
        make_fx_result_through_direct = fx_g(inp)

        self.assertTrue(
            torch._dynamo.utils.same(make_fx_result_through_backend, export_result)
        )
        self.assertTrue(
            torch._dynamo.utils.same(make_fx_result_through_direct, export_result)
        )

    def test_export_with_constant_method_on_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_method_on_module_invoke_twice(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x) + self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function_and_class_method(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function_and_class_method_multiarg(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, z):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + helper_fn(z)
                return y

        module = MyModule()
        real_result = module(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        result = graph(
            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(
            torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def forward(self, x, z):
                y = helper_fn(x) + helper_fn(z)
                return y

        module = MyModule()
        real_result = module(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
        )
        result = graph(
            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(
            torch.tensor([[1, 0], [0.25, 0.25]]),
            torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_tuple_nonzero(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return (torch.nonzero(x), torch.nonzero(x))

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture
        # what we made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_list_nonzero(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return [torch.nonzero(x), torch.nonzero(x)]

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture
        # what we made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_list_nonzero_free_function(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return [torch.nonzero(x), torch.nonzero(x)]

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                elements = helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture
        # what we made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_dict_values(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return {"x": x, "x^2": x * x}

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                y = y * elements["x"]
                y = y * elements["x^2"]
                return y

        module = MyModule()
        real_result = module(torch.tensor([2.0, 2.0]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([2.0, 2.0]))

        # Tensor input can be almost anything here, and the result will capture
        # what we made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_none_control_flow(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([-1]))

        # X is negative, so .item() < 0, which means we return y
        self.assertEqual(real_result, torch.tensor([0.5]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
        result = graph(torch.tensor([2]))
        # X is positive, but we compiled helper_fn to return None, so it will
        # still return y
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_none_control_flow(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will
        # still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_none_control_flow_free_func(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([-1]))

        # X is negative, so .item() < 0, which means we return y
        self.assertEqual(real_result, torch.tensor([0.5]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([-1]))
        result = graph(torch.tensor([2]))
        # X is positive, but we compiled helper_fn to return None, so it will
        # still return y
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_none_control_flow_pos(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will
        # still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_none_control_flow_free_func(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will
        # still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_return_const(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return self.val

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x == "A":
                    return y
                return -1

        module = MyModule()
        module.val = "A"
        resA = module(torch.tensor([2]))
        graph, guards = torch._dynamo.export(module)(torch.tensor([2]))
        module.val = "B"
        resB = graph(torch.tensor([2]))
        self.assertTrue(torch._dynamo.utils.same(resA, resB))

    def test_export_with_builtin_op_on_assume_constant(self):
        @torch._dynamo.assume_constant_result
        def get_y(y) -> torch.Tensor:
            return y

        class Bob(torch.nn.Module):
            def __init__(self, p, val) -> None:
                super().__init__()
                self.p = p
                self.y = torch.nn.Parameter(torch.tensor(val))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # This only looks dynamic but it's actually a constant value
                if get_y(self.y) < self.p:
                    return torch.cat([x, x])
                else:
                    return x

        model = Bob(0.5, 0.3)
        inp = torch.ones(3, 4)
        graph, guards = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), graph(inp))

    def test_export_with_constant_in_unspecialized_nn_module(self):
        class Module(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.y = y

            @torch._dynamo.assume_constant_result
            def check(self):
                return self.y[0].item() == 1

            def forward(self, x):
                # This line leads to the module object being tracked as
                # UnspecializedNNModuleVariable in dynamo
                self.device = x.device
                if self.check():
                    return x + 1
                else:
                    return x + 2

        model = Module(torch.tensor([1]))
        inp = torch.ones(3, 4)
        graph, _ = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), graph(inp))
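
    # A minimal sketch (not part of the suite) of what the constant tests above
    # rely on: `assume_constant_result` freezes the decorated call's value at
    # export time, so the exported graph replays the traced result even for
    # inputs that would change it eagerly. The name `_constant_result_sketch`
    # is illustrative only.
    def _constant_result_sketch(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        def f(x):
            # nonzero(x) is baked in from the export-time input below
            return torch.tensor([0.5]) * helper_fn(x)[0]

        graph, _ = torch._dynamo.export(f)(torch.tensor([1.0, 1.0]))
        # Any later input replays the constant captured at export time.
        result = graph(torch.tensor([0.0, 0.0]))
        self.assertTrue(torch._dynamo.utils.same(result, f(torch.tensor([1.0, 1.0]))))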

    def test_export_decomp(self):
        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        graph, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            decomposition_table={torch.ops.aten.t.default: nop},
        )(torch.randn(5))
        self.assertEqual(
            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
            0,
        )

        graph, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)(
            torch.randn(5)
        )
        self.assertEqual(
            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
            2,
        )

    def test_export_decomp_asserts_bad_args(self):
        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        with self.assertRaises(AssertionError):
            graph, _ = torch._dynamo.export(
                f,
                (torch.randn(5)),
                aten_graph=False,
                decomposition_table={torch.ops.aten.t.default: nop},
            )

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_module_layer(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                def true_fn(val):
                    return self.linear(val) * torch.tensor(2)

                def false_fn(val):
                    return self.linear(val) * torch.tensor(-1)

                return cond(pred, true_fn, false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)

        torch._dynamo.reset()

        exported = torch._dynamo.export(mod.forward)(pred, x)
        out_graph = exported[0]

        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

        # New X, just to show we did not specialize
        x = x * -1
        pred = torch.tensor(x[0][0].item() < 0)
        real_result_2 = mod.forward(pred, x)
        dynamo_result_2 = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_branches_calling_methods(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            # ok
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return cond(pred, self.true_fn, self.false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)
        out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_closure(self):
        from functorch.experimental.control_flow import cond

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x])

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x + 1])

        class FooBar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                y = x + x

                def true_fn(x, y):
                    return self.linear(x) * (x + y)

                def false_fn(x, y):
                    return x * (y - x)

                return cond(pred, true_fn, false_fn, [x, y])

        for Module in [Foo, Bar, FooBar]:
            mod = Module()
            x = torch.randn([3, 3], requires_grad=True)
            pred = torch.tensor(x[0][0].item() < 0)
            real_result = mod.forward(pred, x)
            out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
            dynamo_result = out_graph(pred, x)
            self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_cond_with_closed_function(self):
        def hello(x):
            return x + 1

        def hi(x):
            return x + 2

        def foo(pred, x):
            def true_fn(x):
                return hello(x)

            def false_fn(x):
                return hi(x)

            return cond(pred, true_fn, false_fn, [x])

        x = torch.randn(5)
        pred = x[0] > 0
        real_result = foo(pred, x)
        out_graph, _ = torch._dynamo.export(foo)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_cond_dynamic_shape_pred(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

        class Module2(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))

        mods = [Module(), Module2()]
        for mod in mods:
            x = torch.randn(2, 2)
            out_graph, guards = torch._dynamo.export(mod)(x)
            self.assertExpectedInline(
                out_graph.code.strip(),
                """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    size = l_x_.size()
    getitem = size[0]; size = None
    le = getitem <= 2; getitem = None
    cond_true_0 = self.cond_true_0
    cond_false_0 = self.cond_false_0
    cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None
    getitem_2 = cond[0]; cond = None
    return pytree.tree_unflatten([getitem_2], self._out_spec)""",
            )
            self.assertExpectedInline(
                out_graph.cond_true_0.code.strip(),
                """\
def forward(self, l_x_):
    l_x__1 = l_x_
    add = l_x__1 + l_x__1; l_x__1 = None
    return (add,)""",
            )
            self.assertExpectedInline(
                out_graph.cond_false_0.code.strip(),
                """\
def forward(self, l_x_):
    l_x__1 = l_x_
    getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
    return (getitem,)""",
            )

            with self.assertRaisesRegex(
                torch._dynamo.exc.UncapturedHigherOrderOpError,
                "Cond doesn't work unless it is captured completely with torch.compile",
            ):
                # True branch and false branch return tensors of different shape
                torch._dynamo.export(mod)(torch.randn(3, 2))

            # We specialize into one of the branches since predicate is a
            # python boolean.
            test_x = torch.randn(3, 2)
            mod(test_x)

    def test_export_with_map_cond(self):
        from functorch.experimental.control_flow import cond, map

        class Module(torch.nn.Module):
            def inner(self, x, pred):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x * x

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred)

                return map(body, xs, pred)

        mod = Module()
        x = torch.randn(3, 2, 1)
        pred_x = torch.tensor(True)

        y = torch.randn(4, 3, 2)
        pred_y = torch.tensor(False)
        real_result = mod(pred_y, y)

        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
        self.assertEqual(real_result, out_graph(pred_y, y))

    def test_export_with_map_zero_sized_tensor(self):
        from functorch.experimental.control_flow import map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def body(x):
                    return x + 1

                return map(body, xs)

        mod = Module()
        xs = torch.randn(0, 2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "zero-sized tensor",
        ):
            out_graph, _ = torch._dynamo.export(mod)(xs)

    def test_export_meta_val(self):
        def f(x, y, z):
            return x * y + z

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
        )(
            torch.ones(3, 2),
            torch.zeros(3, 2),
            torch.ones(3, 2),
        )
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                self.assertIn("val", node.meta)

    def test_input_container_type(self):
        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
            return {"a": x.sum() + sum(y).sum()}

        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

        gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)

        self.assertEqual(gm(*inp), f(*inp))

    @config.patch(assume_static_by_default=False)
    def test_export_symbolic_shape(self):
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.empty(x.shape[0] * 2)

        inp = (torch.randn(6, 5),)
        gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)

        has_sym_size = False
        for node in gm.graph.nodes:
            if node.target is torch.ops.aten.sym_size.int:
                has_sym_size = True

        self.assertTrue(has_sym_size)

    @config.patch(assume_static_by_default=False)
    def test_dynamic_slicing(self):
        def f(x):
            return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

        gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

        inp = torch.randn(6, 7)
        self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)

        count = 0
        # aten graph should flatten getitem calls to actual
        # slice kernel call.
        for node in gm_aten_mode.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.slice.Tensor
            ):
                count += 1

        self.assertEqual(count, 2)

        gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5))

        # In torch mode, the slice itself shows up as a single getitem call:
        # the Tensor class has its own __getitem__ method, which only gets
        # translated to aten.slice later.
        count = 0
        for node in gm_torch_mode.graph.nodes:
            if node.op == "call_function" and node.target == operator.getitem:
                count += 1
        self.assertEqual(count, 1)
        self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)

    def test_dynamic_slicing_invalid(self):
        def g(x, y):
            return x[y : x.shape[0]]

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Dynamic slicing on data-dependent value is not supported",
        ):
            torch._dynamo.export(
                g,
                aten_graph=True,
            )(
                torch.randn(4, 5),
                torch.tensor(2),
            )

    @config.patch(capture_scalar_outputs=True)
    def test_dynamic_slicing_simple(self):
        def f(x):
            return x[slice(None, None, None)]

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))

        inp = torch.randn(6, 7)
        self.assertEqual(gm(inp), f(inp))

    def test_pre_dispatch_simple(self):
        def f(x):
            y = torch.ones_like(x)
            return torch.matmul(x, y)

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            pre_dispatch=True,
            tracing_mode="fake",
        )(
            torch.randn(5, 5),
        )
        inp = torch.randn(6, 6)
        self.assertEqual(gm(inp), f(inp))

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    arg0_1 = arg0
    ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
    matmul = torch.ops.aten.matmul.default(arg0_1, ones_like); arg0_1 = ones_like = None
    return pytree.tree_unflatten([matmul], self._out_spec)""",
        )

    @patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
    def test_export_cond_in_aten_symbolic(self):
        class ConditionOp(torch.nn.Module):
            def true_fn(self, x, y):
                return x * y

            def false_fn(self, x, y):
                return x + y

            def forward(self, pred, x, y):
                return cond(pred, self.true_fn, self.false_fn, [x, y])

        model = ConditionOp()
        inp = (
            torch.tensor(False),
            torch.randn(4, 4),
            torch.randn(4, 4),
        )
        gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp)

        gm.print_readable()

        self.assertEqual(gm(*inp), model(*inp))

    def test_export_with_kwargs(self):
        def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
            out = pos0
            for arg in tuple0:
                out *= arg
            for arg in myargs:
                out *= arg
            out *= mykw0
            out *= mykwargs["input0"] * mykwargs["input1"]
            return out

        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
        tuple0 = (torch.randn(4), torch.randn(4))
        mykw0 = torch.randn(4)
        pos0 = torch.randn(4)
        myargs = [torch.randn(4), torch.randn(4)]

        expected_argument_names = [
            "pos0",
            "tuple0",
            "myargs_0",
            "myargs_1",
            "mykw0",
            "input0",
            "input1",
        ]
        self._test_export_preserving_original_signature(
            fn_with_kwargs,
            expected_argument_names,
            pos0,
            tuple0,
            *myargs,
            mykw0=mykw0,
            **mykwargs,
        )

    def test_export_with_kwargs_and_empty_args(self):
        def fn_with_kwargs(mykw0=None, **mykwargs):
            out = mykw0
            out *= mykwargs["input0"] * mykwargs["input1"]
            return out

        mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
        mykw0 = torch.randn(4)

        expected_argument_names = ["mykw0"] + list(mykwargs.keys())
        self._test_export_preserving_original_signature(
            fn_with_kwargs, expected_argument_names, mykw0, **mykwargs
        )

    def test_export_with_args_and_empty_kwargs(self):
        def fn_with_kwargs(pos0, tuple0, *myargs):
            out = pos0
            for arg in tuple0:
                out *= arg
            for arg in myargs:
                out *= arg
            return out

        tuple0 = (torch.randn(4), torch.randn(4))
        pos0 = torch.randn(4)
        myargs = [torch.randn(4), torch.randn(4)]

        expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"]
        self._test_export_preserving_original_signature(
            fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs
        )

    @common_utils.parametrize(
        "default_value",
        [
            common_utils.subtest(None, name="None"),
            common_utils.subtest(42.0, name="float"),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                torch.randn(4),
                name="tensor",
                decorators=[unittest.expectedFailure],
            ),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                (torch.randn(4),),
                name="tuple",
                decorators=[unittest.expectedFailure],
            ),
        ],
    )
    def test_export_with_args_with_default(self, default_value):
        def fn(pos0, pos1_default=default_value):
            out = pos0
            if pos1_default is None:
                pos1_default = torch.randn(4)
            if isinstance(pos1_default, tuple):
                pos1_default = pos1_default[0]
            out *= pos1_default
            return out

        pos0 = torch.randn(4)
        expected_argument_names = ["pos0"]
        self._test_export_preserving_original_signature(
            fn, expected_argument_names, pos0
        )

    @common_utils.parametrize(
        "default_value",
        [
            common_utils.subtest(None, name="None"),
            common_utils.subtest(42.0, name="float"),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                torch.randn(4),
                name="tensor",
                decorators=[unittest.expectedFailure],
            ),
            common_utils.subtest(
                # FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
                (torch.randn(4),),
                name="tuple",
                decorators=[unittest.expectedFailure],
            ),
        ],
    )
    def test_export_with_kwargs_with_default(self, default_value):
        def fn(pos0, *, kw0, kw1_default=default_value, **kwargs):
            out = pos0
            out += kw0
            if kw1_default is None:
                kw1_default = torch.randn(4)
            elif isinstance(kw1_default, tuple):
                kw1_default = kw1_default[0]
            out += kw1_default
            out += kwargs["kw2"]
            return out

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        kw2 = torch.randn(4)

        args = (pos0,)
        kwargs = {"kw0": kw0, "kw2": kw2}
        expected_argument_names = ["pos0", "kw0", "kw2"]
        self._test_export_preserving_original_signature(
            fn, expected_argument_names, *args, **kwargs
        )

    def test_export_with_wrapped_fn(self):
        # To ensure dynamo.export is robust to wrapped functions
        # when it cannot use `inspect` to retrieve original signature info.
        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

        def wrapped_fn(*args, **kwargs):
            return _fn(*args, **kwargs)

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        args = (pos0, torch.randn(4), torch.randn(4))
        kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
        expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
            kwargs.keys()
        )

        self._test_export_preserving_original_signature(
            wrapped_fn, expected_argument_names, *args, **kwargs
        )

    def test_export_with_functools_wrapped_method(self):
        def test_decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x

            @test_decorator
            def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
                out = pos0
                out += pos1
                out += kw0
                out += kw1
                for arg in args:
                    out += arg
                for kwarg in kwargs.values():
                    out += kwarg
                return out

        pos0 = torch.randn(4)
        pos1 = torch.randn(4)
        unnamed_pos = torch.randn(4)
        kw0 = torch.randn(4)
        args = (pos0, pos1, unnamed_pos)
        kwargs = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)}
        expected_argument_names = [
            "pos0",
            "pos1",
            "args_0",  # 3rd unnamed positional argument
        ] + list(kwargs.keys())
        m = MyModule()

        self._test_export_preserving_original_signature(
            m.method_to_test, expected_argument_names, *args, **kwargs
        )

    def test_export_with_functools_wrapped_fn(self):
        def test_decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        @test_decorator
        def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
            out = pos0
            out += pos1
            out += kw0
            out += kw1
            for arg in args:
                out += arg
            for kwarg in kwargs.values():
                out += kwarg
            return out

        def wrapped_fn(*args, **kwargs):
            return _fn(*args, **kwargs)

        pos0 = torch.randn(4)
        kw0 = torch.randn(4)
        args = (pos0, torch.randn(4), torch.randn(4))
        kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
        expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
            kwargs.keys()
        )

        self._test_export_preserving_original_signature(
            wrapped_fn, expected_argument_names, *args, **kwargs
        )

    def _test_export_preserving_original_signature(
        self, fn, expected_argument_names: Sequence[str], *args, **kwargs
    ):
        torch._dynamo.reset()
        exported = torch._dynamo.export(
            fn,
            *args,
            **kwargs,
            aten_graph=False,
        )

        out_graph = exported[0]
        dynamo_result = out_graph(*args, **kwargs)
        real_result = fn(*args, **kwargs)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

        # Check that the exported graph preserves the same argument names.
self.assertEqual( inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names ) def test_dataclass_input_output(self): from dataclasses import dataclass @dataclass class Tensors: x: torch.Tensor y: torch.Tensor def f(t): return t.x + t.y with self.assertRaisesRegex( UserError, "It looks like one of the inputs with type .*Tensors.* " "is not supported or pytree-flattenable", ): torch._dynamo.export(f, aten_graph=False)( Tensors(x=torch.randn(10), y=torch.randn(10)) ) def f(x, y): return Tensors(x=x.sin(), y=y.cos()) with self.assertRaisesRegex( UserError, "It looks like one of the outputs with type .*Tensors.* " "is not supported or pytree-flattenable", ): torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10)) def test_empty(self): def f(x): return x exported = torch._dynamo.export(f)(torch.randn(3, 3)) out_graph = exported[0] inp = torch.randn(3, 3) self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp))) class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = torch.ones(3, 3) def forward(self): return self.a exported = torch._dynamo.export(M())() out_graph = exported[0] self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph())) @unittest.skipIf(not TEST_CUDA, "No CUDA available.") def test_export_with_parameters(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.features = torch.nn.Sequential( torch.nn.Conv2d( 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) ), torch.nn.ReLU(inplace=True), ) def forward(self, x): return self.features(x) model = MyModule().eval().cuda() random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),) dim_x = torch.export.Dim("dim_x", min=1, max=32) exp_program = torch.export.export( model, random_inputs, dynamic_shapes={"x": {0: dim_x}} ) output_buffer = io.BytesIO() # Tests if we can restore saved nn.Parameters when we load them again torch.export.save(exp_program, output_buffer) loaded_model = torch.export.load(output_buffer) self.assertTrue( isinstance( loaded_model.module().get_parameter("features.0.weight"), torch.nn.Parameter, ) ) def test_export_fast_binary_broadcast_check(self): # This test looks at the case where we erroneously create a guard # when checking the equality of the operands' shape and the output # shape during FakeTensor's binary op fast path. 
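# Rough intuition (an assumption about the fast path, not a spec): with
# a: (s0, 4, 8) and b: (4, 8), broadcasting yields output shape (s0, 4, 8).
# Naively testing `out.shape == operand.shape` element-by-element could
# install a guard on the symbolic s0 even though the equality holds for
# every s0; `return b + a` below puts the shape-matching operand second to
# exercise exactly that path.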
class MyModel(torch.nn.Module): def forward(self, a, b): # final shape is (dim0, 4, 8) # order matters since a & the output have the same shape return b + a a = torch.randn(100, 4, 8) b = torch.randn(4, 8) model = MyModel().eval().cuda() batchsize = torch.export.Dim("dim0", min=3, max=1024) dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]} torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec) def test_export_fast_binary_broadcast_check_unbacked(self): class MyModel(torch.nn.Module): def forward(self, numel, scalar): u0 = numel.item() torch._check_is_size(u0) x = torch.ones(u0 + 1) return scalar - x model = MyModel().eval().cuda() numel = torch.tensor(10) scalar = torch.randn(1) torch.export.export(model, (numel, scalar)) def test_export_meta(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.p = torch.nn.Parameter(torch.ones(2, 3)) def forward(self, x): return self.p + x with torch.device("meta"): m = MyModule() inp = torch.ones(2, 3, device="meta") exported = torch._dynamo.export(m)(inp) out_graph = exported[0] dynamo_result = out_graph(inp) self.assertEqual(dynamo_result, m(inp)) def test_constraint_violation_error_messages(self): class Foo(torch.nn.Module): def forward(self, x): if x.shape[0] == x.shape[1] * 2: return x + 1 else: return x + 2 foo = Foo() t = torch.zeros([8, 4]) dim0 = torch.export.Dim("dim0", min=3, max=10) dim1 = torch.export.Dim("dim1") dynamic_shapes = {"x": (dim0, dim1)} with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Constraints violated .*!(.*\n)*.*" "by dim0 = 2\\*dim1(.*\n)*.*" "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) class Bar(torch.nn.Module): def forward(self, x): if x.shape[0] == 5: return x + 1 else: return x + 2 bar = Bar() t = torch.zeros([5]) dim0 = torch.export.Dim("dim0", min=3, max=8) dynamic_shapes = {"x": (dim0,)} with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Not all values.*valid.*inferred to be a constant", ): torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes) class Qux(torch.nn.Module): def forward(self, x): if x.shape[0] > 5 and x.shape[0] < 10: return x + 1 else: return x + 2 qux = Qux() t = torch.zeros([7]) dim0 = torch.export.Dim("dim0", min=3, max=8) dynamic_shapes = {"x": (dim0,)} with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Not all values.*satisfy the generated guard", ): torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes) def test_untracked_inputs_in_constraints(self): from copy import copy class Foo(torch.nn.Module): def forward(self, x, y): return y + 1 foo = Foo() x = torch.randn(2) y = torch.randn(5, 4) dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y") dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}} example_inputs = (copy(x), y) ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes) ep.module()(torch.randn(3), y) # no specialization error def test_export_raise_guard_full_constraint(self): y = torch.randn([3, 3, 3]) def my_dyn_fn(x): if x.shape[0] == 3: return x.sin() return x.cos() torch._dynamo.export(my_dyn_fn)(y) with self.assertRaises(ConstraintViolationError): torch._dynamo.export( my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},) )(y) def test_export_module_specify_constraints_signature(self): y = torch.randn([3, 3, 3]) class Mod(torch.nn.Module): def forward(self, x): if x.shape[0] == 3: return x.sin() return x.cos() mod = Mod() 
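# Exporting statically below specializes on shape[0] == 3; re-exporting with
# dim 0 marked dynamic then contradicts the `x.shape[0] == 3` branch, so the
# error message is expected to report the inferred specialization "dimx = 3".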
torch._dynamo.export(mod)(y) with self.assertRaisesRegex(ConstraintViolationError, "dimx = 3"): torch._dynamo.export(mod, dynamic_shapes=({0: torch.export.Dim("dimx")},))( y ) def test_export_raise_guard_partial_constraint(self): y = torch.randn([3, 3, 3]) def my_dyn_fn(x): if x.shape[0] > 3: return x.sin() return x.cos() torch._dynamo.export(my_dyn_fn)(y) with self.assertRaises(ConstraintViolationError): torch._dynamo.export( my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},) )(y) def test_export_raise_on_relationship(self): y = torch.randn([3, 3, 3]) def my_dyn_fn(a, b, c): if a.shape[0] == b.shape[1] == c.shape[2]: return a.sin() return a.cos() torch._dynamo.export(my_dyn_fn)(y, y, y) dim = torch.export.Dim("dim") dynamic_shapes = ({0: dim}, {0: dim}, {0: dim}) with self.assertRaises(ConstraintViolationError): torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y) dynamic_shapes = ({0: dim}, {1: dim}, {2: dim}) torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y) def test_export_no_raise(self): y = torch.randn([3, 3, 3]) def my_dyn_fn(a, b, c): if a.shape[1] == 3: return a.cos() return a * b * c torch._dynamo.export(my_dyn_fn)(y, y, y) dim = torch.export.Dim("dim") dynamic_shapes = ({0: dim}, {0: dim}, {0: dim}) torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y) def test_export_multi_dynamic_dim_unsafe_relationship(self): x = torch.randn([3, 3, 3]) y = torch.randn([2, 2, 2]) z = torch.randn([3, 3, 3]) def my_dyn_fn(a, b, c): if a.shape[0] == c.shape[0]: return a.cos() return a * c, b torch._dynamo.export(my_dyn_fn)(x, y, z) dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz") dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz}) with self.assertRaises(ConstraintViolationError): torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) dimz = dimx dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz}) torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) def test_remove_redundant_dynamic_dim_in_error_message(self): class Foo(torch.nn.Module): def forward(self, x, y): if x.shape[0] == y["k"].shape[0]: return x + 1 else: return x - 1 foo = Foo() a = torch.randn(3) b = torch.randn(3) dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b") with self.assertRaisesRegex(torch._dynamo.exc.UserError, "dim0_b = dim0_a"): torch.export.export( foo, (a, {"k": b}), dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}}, ) def test_enforce_equalities(self): class Bar(torch.nn.Module): def forward(self, x, y): return torch.matmul(x, y) bar = Bar() batch, size = torch.export.dims("batch", "size") dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)} x = torch.randn(10, 3, 3) y = torch.randn(10, 3, 4) with self.assertRaisesRegex( torch._dynamo.exc.UserError, ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3", ): torch.export.export( bar, (x, y), dynamic_shapes=dynamic_shapes, ) y = torch.randn(10, 3, 3) ebar = torch.export.export( bar, (x, y), dynamic_shapes=dynamic_shapes, ) self.assertEqual( [ str(node.meta["val"].shape) for node in ebar.graph_module.graph.nodes if node.op == "placeholder" ], ["torch.Size([s0, s1, s1])", "torch.Size([s0, s1, s1])"], ) @torch._dynamo.config.patch( capture_dynamic_output_shape_ops=True, specialize_int=True, capture_scalar_outputs=True, ) def test_export_preserve_constraints_as_metadata_tensor(self): def f(x): b = x.nonzero() torch._check(b.shape[0] >= 2) torch._check(b.shape[0] <= 5) return b y = torch.tensor([8, 8, 6]) gm, _ = torch._dynamo.export( f, 
aten_graph=True, tracing_mode="symbolic", )(y) @config.patch( capture_dynamic_output_shape_ops=True, specialize_int=True, capture_scalar_outputs=True, ) def test_exported_graph_serialization(self): def f(x, y): b = x.item() torch._check_is_size(b) return torch.empty((b, y.shape[0])) x = torch.tensor([3]) y = torch.randn([8, 8, 6]) example_inputs = [x, y] dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)}) gm, _ = torch._dynamo.export( f, dynamic_shapes=dynamic_shapes, aten_graph=True, tracing_mode="symbolic", )(*example_inputs) # Ensure the exported graph module with metadata is serializable, # metadata won't be saved in the serialized module buffer = io.BytesIO() torch.save(gm, buffer) def test_export_dynamic_dim_not_1(self): x = torch.randn([1, 1, 1]) def my_dyn_fn(a): if a.shape[0] != 1: return a.cos() return a * a torch._dynamo.export(my_dyn_fn)(x) with self.assertRaises(ConstraintViolationError): torch._dynamo.export( my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},) )(x) def test_symbool(self): def f(x): a = torch.scalar_tensor(x.shape[0] > 4) return x.sin().sum() + a.sum() gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(gm(torch.ones(3, 4)), f(torch.ones(3, 4))) def test_export_multi_dynamic_dim_constraint(self): x = torch.randn([3, 3, 3]) y = torch.randn([2, 2, 2]) z = torch.randn([3, 3, 3]) def my_dyn_fn(a, b, c): if a.shape[0] == c.shape[0]: return a.cos() return a * c, b torch._dynamo.export(my_dyn_fn)(x, y, z) dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2") dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, None) with self.assertRaises(ConstraintViolationError): torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, {0: dimx_0}) torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z) def test_export_dynamic_dim_range_constraint(self): x = torch.ones(6, 4, 4) dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},) def foo(x): if x.shape[0] > 3: # ok return x.sin() return x.cos() torch._dynamo.export( foo, dynamic_shapes=dynamic_shapes, aten_graph=True, )(x) def bar(x): if x.shape[0] > 5: # error return x.sin() return x.cos() with self.assertRaises(ConstraintViolationError): torch._dynamo.export( bar, dynamic_shapes=dynamic_shapes, aten_graph=True, )(x) def test_trivial_constraint(self): class Foo(torch.nn.Module): def forward(self, x): # complex divisibility condition if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0: return x + 1 else: return x - 1 foo = Foo() class Bar(torch.nn.Module): def forward(self, x): # trivially true if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0: return x + 1 else: return x - 1 bar = Bar() class Qux(torch.nn.Module): def forward(self, x): # simple divisibility condition (not trivially true) if (3 * x.shape[0]) % 2 == 0: return x + 1 else: return x - 1 qux = Qux() x = torch.randn(12) dim0 = torch.export.Dim("dim0", max=100) dynamic_shapes = {"x": (dim0,)} with self.assertRaisesRegex( torch._dynamo.exc.UserError, r"Constraints violated \(dim0\)", ): torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes) torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes) with self.assertRaisesRegex( torch._dynamo.exc.UserError, r"Constraints violated \(dim0\)", ): torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes) def test_list_contains(self): def func(x): assert x.size(-1) in [4, 5, 6], "bad" return x + x inps = (torch.randn(1, 5),) opt_func = 
torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) real_result = opt_func(*inps) torch._dynamo.reset() exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) def test_list_not_contains(self): def func(x): assert x.size(0) not in [4, 5, 6], "bad1" assert "monkey" not in ["cow", "pig"], "bad2" return x + x inps = (torch.randn(1, 5),) opt_func = torch._dynamo.optimize("eager", nopython=True, dynamic=True)(func) real_result = opt_func(*inps) torch._dynamo.reset() exported = torch._dynamo.export(func, aten_graph=True)(*inps) out_graph = exported[0] dynamo_result = out_graph(*inps) self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) def test_export_identity(self): inp = torch.tensor([0.1, 0.1]) def func(x): return x torch._dynamo.reset() exported, _ = torch._dynamo.export(func)(inp) dynamo_result = exported(inp) self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result)) def test_export_specialized_int(self): class Foo(torch.nn.Module): def __init__( self, input_dim, ): super().__init__() self.torch_module = torch.nn.LayerNorm( input_dim, eps=1e-5, elementwise_affine=True ) self.int_val = 100 def forward(self, input): return input.cos() * self.int_val * self.torch_module.eps mod = Foo(128) inp = torch.randn(3, 128) # In export, int & float in forward should always be specialized gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp) count = 0 for node in gm.graph.nodes: if node.op == "placeholder": count += 1 self.assertEqual(count, 1) def test_export_with_nonzero_static(self): class BasicModule(torch.nn.Module): def __init__(self, static_size): super().__init__() self.static_size = static_size def forward(self, x): return torch.nonzero_static(x, size=self.static_size) input_tensors = torch.tensor([6, 8]), torch.zeros(2, 3) static_sizes = 3, 4 for input_tensor, static_size in zip(input_tensors, static_sizes): m = BasicModule(static_size) gm, _ = torch._dynamo.export(m, aten_graph=True)(input_tensor) res = gm(input_tensor) self.assertEqual(res.size(0), static_size) self.assertTrue( torch._dynamo.utils.same( res, torch.nonzero_static(input_tensor, size=static_size) ) ) def test_export_pass_arg_by_name(self): class BasicModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.my_lin = torch.nn.Linear(3, 4, bias=True) def forward(self, x): return self.my_lin(x) mod, input_tensor = BasicModule(), torch.randn(2, 3) gm, guard = torch._dynamo.export(mod, aten_graph=True)(input_tensor) ref = mod(x=input_tensor) res = gm(x=input_tensor) self.assertTrue(torch._dynamo.utils.same(ref, res)) def test_export_pass_arg_by_name_star_args(self): class BasicModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.my_lin = torch.nn.Linear(3, 4, bias=True) def forward(self, *args): return self.my_lin(args[0]) * self.my_lin(args[1]) mod, input_tensor, input_tensor2 = ( BasicModule(), torch.randn(2, 3), torch.randn(2, 3), ) gm, guard = torch._dynamo.export(mod, aten_graph=True)( input_tensor, input_tensor2 ) ref = mod(input_tensor, input_tensor2) res = gm(input_tensor, input_tensor2) self.assertTrue(torch._dynamo.utils.same(ref, res)) def test_export_mark_dynamic_conflict_dynamic_dim(self): y = torch.randn([3, 3, 3]) def my_dyn_fn(x): if x.shape[0] > 3: return x.sin() return x.cos() torch._dynamo.mark_dynamic(y, 0) with self.assertRaisesRegex( RuntimeError, "Constraints violated", ): torch._dynamo.export( 
my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},) )(y) def test_export_dynamic_dim_cleanup(self): y = torch.randn([3, 3, 3]) def my_dyn_fn(x): return x.cos() torch._dynamo.export(my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},))( y ) @config.patch(capture_dynamic_output_shape_ops=True) def test_export_dynamic_control_flow_error(self): def f(x): if x.nonzero() > 3: return x.cos() return x.sin() with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Dynamic control flow is not supported at the moment", ): gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6)) @config.patch(assume_static_by_default=False) def test_export_persist_assert(self): def f(x): assert x[0].sum() > 4, "Shape must be more than 4" return x.cos() + x.sin() gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")( torch.ones(5, 4, 6) ) def has_aten_op(gm, op): for node in gm.graph.nodes: if node.target == op: return True return False self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg)) gm.graph.eliminate_dead_code() gm.recompile() self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg)) with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"): gm(torch.zeros(3, 4, 5)) @common_utils.parametrize( "type_fn", [ common_utils.subtest(type, name="builtin"), common_utils.subtest(lambda obj: obj.__class__, name="attr"), ], ) def test_access_class_method_from_user_class(self, type_fn): class A: @classmethod def func(cls): return torch.Tensor([4, 5]) def f(x): a = A() return x.sum() + type_fn(a).func().sum() gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) def test_not_functionalize(self): class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.ones(6, 2)) def forward(self, x): x.add_(2) return x.sum() + self.buffer1.sum() example_inputs = (torch.ones(1, 2, 3),) gm, _ = torch._dynamo.export( Foo(), aten_graph=True, tracing_mode="symbolic", )(*example_inputs) count = 0 for node in gm.graph.nodes: if node.target == torch.ops.aten.add_.Tensor: count += 1 self.assertEqual(count, 1) test_inp = (torch.ones(1, 2, 3),) test_inp_v2 = (torch.ones(1, 2, 3),) self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2)) def test_round_dynamic_shapes(self): def f(x): return x[: round(x.shape[0] / 2)] gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) def test_cond_supported_pred_types(self): def true_fn(x): return x.cos() def false_fn(x): return x.sin() def f_pred_traced_as_symnode_var(x): return cond(x.shape[0] > 2, true_fn, false_fn, [x]) def f_pred_traced_as_tensor_var(x): return cond(x.all(), true_fn, false_fn, [x]) def f_pred_complex_expression_traced_as_symnode_var(x): return cond( x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10, true_fn, false_fn, [x], ) example_inputs = (torch.rand(5, 8),) for f in [ f_pred_traced_as_symnode_var, f_pred_traced_as_tensor_var, f_pred_complex_expression_traced_as_symnode_var, ]: gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs) self.assertEqual(gm(*example_inputs), f(*example_inputs)) @unittest.expectedFailure # TODO: Not sure why dynamo creates a new inputs for self.a def test_sum_param(self): # Setting a new attribute inside forward() class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = torch.randn(3, 2) def forward(self, x): self.b = 2 return x.sum() + self.a.sum() + 
self.b torch._dynamo.export(Foo())(torch.randn(3, 2)) def test_mixed_real_and_fake_inputs(self): class _TestPattern(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(1, 1, 1) self.bn = torch.nn.BatchNorm2d(1) def forward(self, input): running_std = torch.sqrt(self.bn.running_var + self.bn.eps) scale_factor = self.bn.weight / running_std weight_shape = [1] * len(self.conv.weight.shape) weight_shape[0] = -1 bias_shape = [1] * len(self.conv.weight.shape) bias_shape[1] = -1 scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape) zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype) conv = self.conv._conv_forward(input, scaled_weight, zero_bias) conv_orig = conv / scale_factor.reshape(bias_shape) conv_orig = conv_orig + self.conv.bias.reshape(bias_shape) conv = self.bn(conv_orig) return conv example_inputs = (torch.randn(1, 1, 3, 3),) torch._dynamo.export( _TestPattern(), aten_graph=True, )(*example_inputs) @config.patch( capture_dynamic_output_shape_ops=True, capture_scalar_outputs=True, assume_static_by_default=False, ) def test_sym_contains(self): def f(x, y): return x.size(0) in y gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3)) true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5)) false_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2)) self.assertEqual(gm(*true_inp), f(*true_inp)) self.assertEqual(gm(*false_inp), f(*false_inp)) def test_cond_raise_user_error_on_missing_args(self): def true_fn(x): return x.cos() def false_fn(x): return x.sin() def f(x): return cond(x.shape[0] > 10, true_fn, false_fn) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( TypeError, r"cond\(\) missing 1 required positional argument: 'operands'", ): f(*example_inputs) def test_cond_raise_user_error_on_unsupported_pred(self): def f_unsupported_pred(x): pred = torch.nn.Module() return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x]) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( RuntimeError, "Expected pred to be bool or tensor, but got Module()", ): f_unsupported_pred(*example_inputs) def test_cond_raise_user_error_on_non_list_operands(self): def f_non_list_operands(x): return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( RuntimeError, r"Expect operands to be a tuple of possibly nested dict/list/tuple", ): f_non_list_operands(*example_inputs) def test_cond_raise_user_error_on_non_tensor_operands(self): def f_non_tensor_operands(x): a: float = 3.14 return cond( torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a] ) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( RuntimeError, r"Expect operands to be a tuple of possibly nested dict/list/tuple", ): f_non_tensor_operands(*example_inputs) def test_cond_raise_user_error_on_branch_args_mismatch(self): def true_fn(x, y): return x.sin() def false_fn(x): return x.cos() def f_branch_args_mismatch(x, y): return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y]) example_inputs = (torch.rand(5), torch.rand(2)) with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", ): torch._dynamo.export( f_branch_args_mismatch, aten_graph=True, )( *example_inputs, ) @config.patch(suppress_errors=True) def test_uncaptured_higher_order_op_error_not_suppressed(self): def true_fn(x, y): return x.sin() def false_fn(x): return x.cos() def
f_branch_args_mismatch(x, y): return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y]) example_inputs = (torch.rand(5), torch.rand(2)) with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", ): torch._dynamo.export( f_branch_args_mismatch, aten_graph=True, )( *example_inputs, ) def test_cond_raise_user_error_on_branch_return_non_tensor(self): def f_branch_return_non_tensor(x): return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x]) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", ): torch._dynamo.export( f_branch_return_non_tensor, aten_graph=True, )(*example_inputs) def test_cond_raise_user_error_on_branch_return_multiple_tensors(self): def f_branch_return_multiple_tensors(pred, x, y): return cond(pred, lambda x: (x, x), lambda x: (x, x), [y]) example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2)) gm, _ = torch._dynamo.export( f_branch_return_multiple_tensors, aten_graph=True, )(*example_inputs) self.assertEqual( gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs) ) def test_multiple_outputs_op_with_evaluator(self): class TopKModel(torch.nn.Module): def forward(self, x): values, _ = torch.topk(x, 3) return torch.sum(values) x = torch.arange(1.0, 6.0, requires_grad=True) torch._dynamo.export(TopKModel())(x) def test_cond_raise_user_error_on_mismatch_return_length(self): def true_fn(x): return x def false_fn(x): return (x, x) def f_mismatch_return_length(x): return cond(torch.tensor(100), true_fn, false_fn, [x]) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( RuntimeError, "Unmatched number of outputs from cond" ): torch._dynamo.export( f_mismatch_return_length, aten_graph=True, )(*example_inputs) def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self): def true_fn(x): return torch.tensor([[3], [2]]) def false_fn(x): return torch.tensor([3.14]) def f_return_tensor_mismatch(x): return cond(x.shape[0] < 3, true_fn, false_fn, [x]) example_inputs = (torch.rand(5),) with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", ): torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)( *example_inputs, ) def test_byte_tensor_does_not_crash(self): # See https://github.com/pytorch/pytorch/issues/100455 def func(text): tensor = torch.ByteTensor(list(bytes(text, "utf8"))) return tensor + tensor text = "".join(chr(a % 90 + 40) for a in range(111)) opt_func = torch._dynamo.optimize("eager", dynamic=True)(func) for i in [99, 100]: input = text[:i] opt_func(input) def test_export_defaults_ok(self): class DynamicSliceExportMod(torch.nn.Module): def forward(self, x): results = [] for i in range(4): results.append(x[: x.size(0) - i, i : x.size(2), i:3]) return tuple(results) gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)( torch.randn(5, 5, 5), ) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) arg0_1 = arg0 sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3) sub = sym_size_int - 1 slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None slice_4 = 
torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None sub_1 = sym_size_int - 2 slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int); slice_5 = None slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None sub_2 = sym_size_int - 3 slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int); slice_8 = sym_size_int = None slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""", ) def test_capture_symbolic_tracing_simple_within_fake_mode(self): from torch._dynamo.output_graph import config def f(x): y = torch.randn(3) return x + x * y with fake_tensor.FakeTensorMode( shape_env=ShapeEnv( allow_scalar_outputs=config.capture_scalar_outputs, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, ), ): x = torch.randn(3) for aten_graph in [True, False]: gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x) self.assertTrue( isinstance(gm, torch.fx.GraphModule), msg="test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_" + str(aten_graph), ) def test_export_with_symbool_inputs(self): def f(pred: bool, x: torch.Tensor): if pred: return x.sin() else: return x.cos() x = torch.randn([3, 4]) def test_symbool_guards( f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards ): shape_env = ShapeEnv() with fake_tensor.FakeTensorMode( shape_env=shape_env, ) as fake_mode: fake_x = fake_mode.from_tensor( x, symbolic_context=StatelessSymbolicContext( dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())], ), ) for i, size in enumerate(size_tests): pred = fake_x.size(0) == size gm, guards = torch._dynamo.export(f)(pred, x) actual = normalize_gm(gm.print_readable(print_output=False)) # TODO: This is naughty, EXPECTTEST_ACCEPT=1 doesn't work self.assertExpectedInline(actual, exp_graph[i]) dynamo_shape_env_guards = [ guard for guard in guards if guard.guard_types is not None and "SHAPE_ENV" in guard.guard_types ] self.assertEqual(len(dynamo_shape_env_guards), 1) guard_code_on_predicate = [ code for code in dynamo_shape_env_guards[0].code_list if "L['pred']" in code ] self.assertEqual(guard_code_on_predicate, exp_guard_code[i]) outer_shape_env_guards = [ str(guard.expr) for guard in shape_env.guards ] self.assertEqual(outer_shape_env_guards, exp_shape_env_guards[i]) true_graph = """\ class GraphModule(torch.nn.Module): def forward(self, pred, x): arg1: "f32[s1, s2]"; arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) l_x_ = arg1 sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None return pytree.tree_unflatten([sin], self._out_spec) """ false_graph = """\ class GraphModule(torch.nn.Module): def forward(self, pred, x): arg1: "f32[s1, s2]"; arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) l_x_ = arg1 cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None return pytree.tree_unflatten([cos], self._out_spec) """ true_guard_code = [ "cast_symbool_to_symint_guardless(L['pred']) == 1", ] false_guard_code = [ "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", ] test_symbool_guards( f, [3, 3, 4, 5], [true_graph, true_graph, false_graph, false_graph], [true_guard_code, true_guard_code, false_guard_code, false_guard_code], # The outer shape env should have no guards in it because we never specialize on the outer symbool.
[[], [], [], []], ) def test_invalid_input_global(self) -> None: global bulbous_bouffant bulbous_bouffant = torch.randn(3) def f(y): return bulbous_bouffant + y self.assertExpectedInlineMunged( UserError, lambda: torch._dynamo.export(f)(torch.randn(3)), """\ G['bulbous_bouffant'], accessed at: File "test_export.py", line N, in f return bulbous_bouffant + y """, ) def test_invalid_input_global_multiple_access(self) -> None: global macademia macademia = torch.randn(3) def g(y): global macademia y = macademia + y return y def f(y): global macademia y = g(y) return macademia + y # NB: This doesn't actually work (it only reports the first usage), # but I'm leaving the test here in case we fix it later self.assertExpectedInlineMunged( UserError, lambda: torch._dynamo.export(f)(torch.randn(3)), """\ G['macademia'], accessed at: File "test_export.py", line N, in f y = g(y) File "test_export.py", line N, in g y = macademia + y """, ) def test_invalid_input_nonlocal(self) -> None: arglebargle = torch.randn(3) def f(y): return arglebargle + y self.assertExpectedInlineMunged( UserError, lambda: torch._dynamo.export(f)(torch.randn(3)), """L['arglebargle'], a closed over free variable""", ) def test_invalid_input_unused_nonlocal_ok(self) -> None: arglebargle = torch.randn(3) def f(y): x = arglebargle return y torch._dynamo.export(f)(torch.randn(3)) def test_symbolic_tracing_within_fake_mode_with_constraints(self): from torch._subclasses import fake_tensor fake_mode = fake_tensor.FakeTensorMode() class DynamicShapeSimpleModel(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, a, b, c) -> torch.Tensor: d = (torch.matmul(a, b) + c) / 2 d_s0 = d.shape[0] d_s1 = d.shape[1] d_s3 = d_s0 * d_s1 e = d.view(d_s3) return torch.cat([e, e]) with fake_mode: model = DynamicShapeSimpleModel() inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) dim = torch.export.Dim("dim") dynamic_shapes = ({0: dim}, None, {0: dim}) for aten_graph in [True, False]: gm = torch._dynamo.export( model, dynamic_shapes=dynamic_shapes, aten_graph=aten_graph, )(*inputs).graph_module # Since there are no parameters we can do this inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7)) self.assertEqual(model(*inputs), gm(*inputs)) def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self): from torch._subclasses import fake_tensor fake_mode = fake_tensor.FakeTensorMode() # TODO: Seems to choke if you don't make a fresh model and # just try to export Linear directly... 
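# (A guess at why, not verified here: constructing the module under
# `with fake_mode:` below creates its parameters as fake tensors, whereas a
# pre-built torch.nn.Linear would drag real tensors into the fake-mode trace.)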
class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) def forward(self, x): out = self.linear(x) return out with fake_mode: model = Model() inputs = (torch.randn(10, 2, 2),) dynamic_shapes = ({0: torch.export.Dim("dim")},) for aten_graph in [True, False]: gm = torch._dynamo.export( model, dynamic_shapes=dynamic_shapes, aten_graph=aten_graph, )(*inputs).graph_module def test_capture_symbolic_tracing_within_fake_mode(self): from torch._dynamo.output_graph import config from torch._subclasses import fake_tensor from torch.fx.experimental.symbolic_shapes import ShapeEnv class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(2, 2) self.linear2 = torch.nn.Linear(2, 2) def forward(self, x): out = self.linear(x) out = self.linear2(out) return out # User-instantiated FakeTensorMode fake_mode = fake_tensor.FakeTensorMode( allow_non_fake_inputs=False, allow_fallback_kernels=True, shape_env=ShapeEnv( allow_scalar_outputs=config.capture_scalar_outputs, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, ), ) # Fakefy input+model before exporting it with fake_mode: x = torch.rand(5, 2, 2) model = Model() # Export the model with fake inputs and parameters for aten_graph in [True, False]: graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x) self.assertTrue( isinstance(graph_module, torch.fx.GraphModule), msg="test_capture_symbolic_tracing_within_fake_mode_aten_graph_" + str(aten_graph), ) def test_cond_op_param_buffer_lifted(self): class A(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) def forward(self): return self.buffer1.sum() class B(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self): return self.buffer2.sum() class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = A() self.b = B() def forward(self, x): def true_fn(x): return x.cos() + self.a() def false_fn(x): return x.sin() + self.b() return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) def test_nested_cond_op_param_buffer_lifted(self): class A(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) def forward(self): return self.buffer1.sum() class B(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self): return self.buffer2.sum() class M(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = A() self.b = B() def forward(self, x): def true_true_fn(x): return x.cos() + self.a() def true_false_fn(x): return x.cos() + self.a() + 1 def true_fn(x): return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x]) def false_fn(x): return x.sin() + self.b() return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),) gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4)) self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4))) self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4))) self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4))) def test_map_cond_param_buffer_lifted(self): from functorch.experimental.control_flow import cond, map class A(torch.nn.Module): def 
__init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4)) def forward(self): return self.buffer1.sum() class B(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer2 = torch.nn.Buffer(torch.ones(6, 4)) def forward(self): return self.buffer2.sum() class Module(torch.nn.Module): def __init__(self) -> None: super().__init__() self.a = A() self.b = B() def inner(self, x, pred): def true_fn(x): return x + x + self.a() def false_fn(x): return x * x + self.b() return cond(pred, true_fn, false_fn, [x]) def forward(self, pred, xs): def body(x, pred): return self.inner(x, pred) + self.b() return map(body, xs, pred) mod = Module() x = torch.randn(3, 2, 1) pred_x = torch.tensor(True) y = torch.randn(4, 3, 2) pred_y = torch.tensor(False) real_result = mod(pred_y, y) out_graph, _ = torch._dynamo.export(mod)(pred_x, x) self.assertEqual(real_result, out_graph(pred_y, y)) def test_cond_free_variables_overlapping(self): from functorch.experimental.control_flow import cond class Module(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, pred, x): a = torch.ones(6, 4) b = torch.ones(6, 4) c = torch.ones(6, 4) d = torch.ones(6, 4) def true_fn(x): return x + x + a.cos() + b.cos() + d.cos() def false_fn(x): return x * x + a.sin() + b.sin() + c.sin() return cond(pred, true_fn, false_fn, [x]) mod = Module() x = torch.ones(6, 4) pred_x = torch.tensor(True) out_graph, _ = torch._dynamo.export(mod)(pred_x, x) self.assertExpectedInline( out_graph.code.strip(), """\ def forward(self, pred, x): arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec) l_pred_ = arg0 l_x_ = arg1 a = torch.ones(6, 4) b = torch.ones(6, 4) c = torch.ones(6, 4) d = torch.ones(6, 4) cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None getitem = cond[0]; cond = None return pytree.tree_unflatten([getitem], self._out_spec)""", # noqa: B950,E122 ) self.assertExpectedInline( out_graph.cond_true_0.code.strip(), """\ def forward(self, a, b, l_x_, d_true_branch, c_false_branch): a_1 = a b_1 = b l_x__1 = l_x_ add = l_x__1 + l_x__1; l_x__1 = None cos = a_1.cos(); a_1 = None add_1 = add + cos; add = cos = None cos_1 = b_1.cos(); b_1 = None add_2 = add_1 + cos_1; add_1 = cos_1 = None cos_2 = d_true_branch.cos(); d_true_branch = None add_3 = add_2 + cos_2; add_2 = cos_2 = None return (add_3,)""", ) self.assertExpectedInline( out_graph.cond_false_0.code.strip(), """\ def forward(self, a, b, l_x_, d_true_branch, c_false_branch): a_1 = a b_1 = b l_x__1 = l_x_ mul = l_x__1 * l_x__1; l_x__1 = None sin = a_1.sin(); a_1 = None add = mul + sin; mul = sin = None sin_1 = b_1.sin(); b_1 = None add_1 = add + sin_1; add = sin_1 = None sin_2 = c_false_branch.sin(); c_false_branch = None add_2 = add_1 + sin_2; add_1 = sin_2 = None return (add_2,)""", ) @unittest.skipIf( common_utils.TEST_WITH_ASAN, "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416", ) def test_retracibility(self): class MyLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.randn(20, 98) self.bias = torch.randn(20) def forward(self, x): return torch.nn.functional.linear(x, self.weight, self.bias) class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(16, 33, 3) self.linear = MyLinear() def forward(self, x): a, b = x a_conv 
= self.conv(a) a_linear = self.linear(a_conv) b_conv = self.conv(b) b_linear = self.linear(b_conv) return ( a_linear.cos() + b_linear.sin(), a_linear.sin() + b_linear.cos(), ) inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) inp_test = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) self.assertTrue(torch.allclose(gm(inp_test)[0], gm2(inp_test)[0])) self.assertTrue(torch.allclose(gm(inp_test)[1], gm2(inp_test)[1])) def test_retracibility_dict_container_inp_out(self): class MyLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.randn(20, 98) self.bias = torch.randn(20) def forward(self, x): return torch.nn.functional.linear(x, self.weight, self.bias) class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(16, 33, 3) self.linear = MyLinear() def forward(self, x): a1, a2 = x["a"] b = x["b"] a1_conv = self.conv(a1) a1_linear = self.linear(a1_conv) a2_conv = self.conv(a2) a2_linear = self.linear(a2_conv) b_conv = self.conv(b) b_linear = self.linear(b_conv) return { "a": [ a1_linear.cos() + b_linear.sin(), a1_linear.cos() + b_linear.sin(), ], "b": a2_linear.sin() + b_linear.cos(), } inp_container = { "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), "b": torch.randn(20, 16, 50, 100), } gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) inp_test = { "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), "b": torch.randn(20, 16, 50, 100), } self.assertTrue(torch.allclose(gm(inp_test)["a"][0], gm2(inp_test)["a"][0])) self.assertTrue(torch.allclose(gm(inp_test)["a"][1], gm2(inp_test)["a"][1])) self.assertTrue(torch.allclose(gm(inp_test)["b"], gm2(inp_test)["b"])) def test_retracibility_nested_list_out(self): class MyLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = torch.randn(20, 98) self.bias = torch.randn(20) def forward(self, x): return torch.nn.functional.linear(x, self.weight, self.bias) class Foo(torch.nn.Module): def __init__(self) -> None: super().__init__() self.conv = torch.nn.Conv2d(16, 33, 3) self.linear = MyLinear() def forward(self, x): a1, a2 = x["a"] b = x["b"] a1_conv = self.conv(a1) a1_linear = self.linear(a1_conv) a2_conv = self.conv(a2) a2_linear = self.linear(a2_conv) b_conv = self.conv(b) b_linear = self.linear(b_conv) return [ [ a1_linear.cos() + b_linear.sin(), a1_linear.cos() + b_linear.sin(), ], [ a2_linear.sin() + b_linear.cos(), a2_linear.sin() + b_linear.cos(), ], ] inp_container = { "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), "b": torch.randn(20, 16, 50, 100), } gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True) gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True) inp_test = { "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)), "b": torch.randn(20, 16, 50, 100), } self.assertTrue(torch.allclose(gm(inp_test)[0][0], gm2(inp_test)[0][0])) self.assertTrue(torch.allclose(gm(inp_test)[0][1], gm2(inp_test)[0][1])) self.assertTrue(torch.allclose(gm(inp_test)[1][0], gm2(inp_test)[1][0])) self.assertTrue(torch.allclose(gm(inp_test)[1][1], gm2(inp_test)[1][1])) def test_fx_pytree(self): def foo(args): flat_args, spec = torch.utils._pytree.tree_flatten(args) flat_args_fx = torch.fx._pytree.tree_flatten_spec(args, 
spec) return flat_args_fx[0] + flat_args[0] inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)) gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True) self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container))) @config.patch(suppress_errors=True) @config.patch(verbose=True) def test_export_with_map_zero_sized_tensor_suppress_errors(self): from functorch.experimental.control_flow import map class Module(torch.nn.Module): def forward(self, xs): def body(x): return x + 1 return map(body, xs) mod = Module() xs = torch.randn(0, 2) with self.assertRaises( torch._dynamo.exc.Unsupported, ): out_graph, _ = torch._dynamo.export(mod, xs) def test_param_buffer_safe_from_mutation_simple(self): class Module(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.zeros(5, 5)) def forward(self, x): self.buffer1.add_(1) return x + self.buffer1 gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False) buffers = list(gm.named_buffers()) self.assertEqual(len(buffers), 1) name, buffer = buffers[0] self.assertEqual(name, "L__self___buffer1") self.assertTrue(torch.allclose(buffer, torch.zeros(5))) def test_param_buffer_safe_from_mutation_recurse(self): class Child(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer2 = torch.nn.Buffer(torch.zeros(5)) def forward(self, x): return x.sum() + self.buffer2.sum() class Module(torch.nn.Module): def __init__(self) -> None: super().__init__() self.buffer1 = torch.nn.Buffer(torch.zeros(5)) self.child = Child() def forward(self, x): self.buffer1.add_(1) self.child.buffer2.add_(2) return x.sum() + self.buffer1.sum() + self.child(x) gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False) for name, buffer in gm.named_buffers(): self.assertTrue(torch.allclose(buffer, torch.zeros(5))) def test_predispatch_with_higher_order(self): def f(x): return cond(x.shape[0] > 4, lambda x: x + 5, lambda x: x - 3, [x]) gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)( torch.randn(4, 4) ) inp1 = torch.randn(4, 4) inp2 = torch.randn(6, 4) self.assertTrue(torch.allclose(f(inp1), gm(inp1))) self.assertTrue(torch.allclose(f(inp2), gm(inp2))) def test_predispatch_with_higher_order_nested(self): def f(x): def true_fn(x): return cond(x.shape[0] > 6, lambda x: x + 10, lambda x: x - 10, [x]) return cond(x.shape[0] > 4, true_fn, lambda x: x - 3, [x]) gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)( torch.randn(4, 4) ) inp1 = torch.randn(4, 4) inp2 = torch.randn(6, 4) inp3 = torch.randn(8, 4) self.assertTrue(torch.allclose(f(inp1), gm(inp1))) self.assertTrue(torch.allclose(f(inp2), gm(inp2))) self.assertTrue(torch.allclose(f(inp3), gm(inp3))) def test_predispatch_with_for_out_dtype(self): class M(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def forward(self, x): return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight) weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) m = M(weight) x = torch.randint(-128, 127, (5, 5), dtype=torch.int8) gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True) self.assertTrue(torch.allclose(m(x), gm(x))) def test_predispatch_with_for_out_dtype_nested(self): class M(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def true_fn(self, x): return out_dtype( torch.ops.aten.mm.default, torch.int32, x, self.weight ).sum() def false_fn(self, x): return out_dtype( 
torch.ops.aten.mul.Tensor, torch.int32, x, self.weight ).sum() def forward(self, x): return cond(x.sum() != 0, self.true_fn, self.false_fn, [x]) weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8) m = M(weight) x = torch.ones((5, 5), dtype=torch.int8) gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True) self.assertTrue(torch.allclose(m(x), gm(x))) y = torch.zeros((5, 5), dtype=torch.int8) self.assertTrue(torch.allclose(m(y), gm(y))) self.assertExpectedInline( gm.true_graph_0.code.strip(), """\ def forward(self, arg0_1, arg1_1): out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None return (sum_1,)""", ) self.assertExpectedInline( gm.false_graph_0.code.strip(), """\ def forward(self, arg0_1, arg1_1): out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None return (sum_1,)""", ) def test_export_nn_module_stack_patched_module(self): def forward(self, x, y): return x * y class Toplevel(torch.nn.Module): def __init__(self, m): super().__init__() self.m = m def forward(self, x, y): return self.m(x, y) class M(torch.nn.Module): def forward(self, x, y): return x + y t = Toplevel(M()) t.m.forward = forward.__get__(t.m, M) x, y = torch.rand(3), torch.rand(3) gm, _ = torch._dynamo.export(t, x, y) self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y))) for node in gm.graph.nodes: if node.op == "call_function": self.assertIn("nn_module_stack", node.meta) def test_preserve_fx_node_metadata(self): class Module1(torch.nn.Module): def forward(self, x): return torch.sin(x) class Module2(torch.nn.Module): def __init__(self) -> None: super().__init__() self.mod1 = Module1() def forward(self, x): x = torch.cos(x) x = self.mod1(x) x = torch.relu(x) return x def fn(x): return torch.abs(x) mod = Module2() inp = torch.randn(3, 3) gm, _ = torch._dynamo.export(mod)(inp) # replace relu with fn gm_edit = copy.deepcopy(gm) for nd in gm_edit.graph.nodes: if nd.target == torch.relu: nd.target = fn nd.meta.clear() break gm_edit.recompile() gm2, _ = torch._dynamo.export(gm_edit)(inp) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_x_ = arg0 x = torch.cos(l_x_); l_x_ = None x_1 = torch.sin(x); x = None x_2 = torch.relu(x_1); x_1 = None return pytree.tree_unflatten([x_2], self._out_spec)""", ) def _contains_op(gm, target): for nd in gm.graph.nodes: if nd.target == target: return True return False self.assertTrue(_contains_op(gm_edit, torch.cos)) self.assertTrue(_contains_op(gm_edit, torch.sin)) self.assertTrue(not _contains_op(gm_edit, torch.relu)) self.assertExpectedInline( gm2.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_x_ = arg0 x = torch.cos(l_x_); l_x_ = None x_1 = torch.sin(x); x = None x_2 = torch.abs(x_1); x_1 = None return pytree.tree_unflatten([x_2], self._out_spec)""", ) # check for other metadata for op in (torch.sin, torch.cos): nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes)) nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes)) self.assertTrue( ("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta) ) if "nn_module_stack" in nd1.meta: self.assertEqual( nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"] )
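# stack_trace metadata should likewise survive the node-target edit and
# the re-export.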
self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"]) def test_preserve_fx_node_metadata_recompile(self): def fn(x): return torch.sin(x) gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3)) do_export = torch._dynamo.export(gm) torch._dynamo.optimize("eager")(fn)(torch.randn(3, 3)) gm1, _ = do_export(torch.randn(3, 3)) gm2, _ = do_export(torch.randn(5, 3)) self.assertExpectedInline( gm1.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_x_ = arg0 sin = torch.sin(l_x_); l_x_ = None return pytree.tree_unflatten([sin], self._out_spec)""", ) self.assertExpectedInline( gm2.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_x_ = arg0 sin = torch.sin(l_x_); l_x_ = None return pytree.tree_unflatten([sin], self._out_spec)""", ) def test_preserve_fx_node_metadata_inline(self): def f1(x): return torch.sin(x) gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3)) def f2(x): x = torch.cos(x) return gm(x) gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3)) self.assertExpectedInline( gm2.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_x_ = arg0 x = torch.cos(l_x_); l_x_ = None sin = torch.sin(x); x = None return pytree.tree_unflatten([sin], self._out_spec)""", ) def test_preserve_fx_node_metadata_graph_break(self): def fn(x): x = torch.sin(x) x = torch.abs(x) return torch.cos(x) def bad_fn(x): torch._dynamo.graph_break() return x gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3)) # replace abs with graph break gm_edit = copy.deepcopy(gm) for nd in gm_edit.graph.nodes: if nd.target == torch.abs: nd.target = bad_fn nd.meta.clear() break gm_edit.recompile() expected = [ """x = torch.sin(l_x_)""", """cos = torch.cos(l_stack0_)""", ] def test_backend(gm: torch.fx.GraphModule, example_inputs): self.assertTrue(expected) # Normalize output for dynamic and not for nd in gm.graph.nodes: if "example_value" in nd.meta: del nd.meta["example_value"] self.assertIn(expected[0], gm.print_readable(print_output=False)) expected.pop(0) return gm.forward torch._dynamo.reset() opt_gm_edit = torch.compile(gm_edit, backend=test_backend) opt_gm_edit(torch.randn(3, 3)) def test_torch_inference_mode_ctx(self): @torch.inference_mode() def fn(x): return x + 1 gm, _ = torch._dynamo.export(fn, torch.rand(2, 2)) inp = torch.randn(2, 2) out = gm(inp) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_args_0_ = arg0 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) add = l_args_0_ + 1; l_args_0_ = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 ) self.assertEqual(out.requires_grad, False) with self.assertRaisesRegex( RuntimeError, "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.", ): out.requires_grad = True @torch.inference_mode(False) def fn_no_inference(x): return x + 1 gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2)) self.assertExpectedInline( gm_no_inference.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_args_0_ = arg0 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False) add = l_args_0_ + 1; l_args_0_ = None _exit_inference_mode = 
torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 ) inp = torch.randn(2, 2) out = gm_no_inference(inp) self.assertEqual(out.requires_grad, False) out.requires_grad = True def fn(x): with torch.inference_mode(): return x + 1 gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2)) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x): arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) l_x_ = arg0 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) add = l_x_ + 1; l_x_ = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 ) inp = torch.randn(2, 2, requires_grad=True) out = gm(inp) self.assertEqual(out.requires_grad, False) def test_export_masking_with_no_grad(self): def fn(x, b, y): x = x.clone() x[b] = y return x def fn_no_grad(x, b, y): with torch.no_grad(): return fn(x, b, y) def fn_inference_mode(x, b, y): with torch.inference_mode(): return fn(x, b, y) x = torch.randn(4, requires_grad=True) b = torch.tensor([True, False, True, False]) y = torch.randn(2, requires_grad=True) gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x, b, y): arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec) l_x_ = arg0 l_b_ = arg1 l_y_ = arg2 _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None x = l_x_.clone(); l_x_ = None x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return pytree.tree_unflatten([x], self._out_spec)""", ) gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x, b, y): arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec) l_x_ = arg0 l_b_ = arg1 l_y_ = arg2 _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) x = l_x_.clone(); l_x_ = None x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None return pytree.tree_unflatten([x], self._out_spec)""", # NOQA: B950 ) with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, "boolean masking setitem backwards" ): gm, _ = torch._dynamo.export(fn)(x, b, y) def test_dynamo_list_index(self): def fn(x, in_list): return x + in_list.index(2) inputs = (torch.ones(2, 2), [1, 2]) graph, _ = torch._dynamo.export(fn)(*inputs) out = graph(*inputs) self.assertEqual(out, torch.ones(2, 2) + 1) common_utils.instantiate_parametrized_tests(ExportTests) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()