# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List

import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
    get_buffer,
    get_param,
    is_buffer,
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export.unflatten import _disable_interpreter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
    LeafSpec,
    tree_flatten,
    tree_unflatten,
    TreeSpec,
    treespec_dumps,
    treespec_loads,
)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestUnflatten(TestCase):
    def compare_outputs(self, eager, unflattened, args):
        orig_output = eager(*args)
        unflattened_output = unflattened(*args)
        self.assertTrue(torch.allclose(orig_output, unflattened_output))

    def test_unflatten_nested(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))
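
    # Mutating a submodule buffer inside forward(): after one call, the buffer
    # held by the unflattened module should match the eagerly mutated buffer.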
    def test_unflatten_buffer_mutation(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                self.child2buffer.add_(x)
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.foo(x)
                return x * self.rootparam

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        # Buffer should look the same before and after one run
        eager_buffer = eager_module.foo.child2buffer
        unflattened_buffer = unflattened_module.foo.child2buffer
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

        inputs = (torch.rand(2, 3),)
        eager_module(*inputs)
        unflattened_module(*inputs)
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

    def test_unflatten_nested_access(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x + self.foo.child2buffer
                x = self.foo(x)
                return x

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(eager_module, unflattened_module, inputs)

    def test_unflatten_shared_submodule(self):
        class Shared(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layernorm = torch.nn.LayerNorm(10)
                self.sub_net = torch.nn.Sequential(
                    layernorm,
                    torch.nn.ReLU(),
                    layernorm,
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.sub_net(x)

        eager_module = Shared()
        inps = (torch.rand(10),)
        export_module = export(eager_module, inps, {})
        unflattened_module = unflatten(export_module)
        self.compare_outputs(eager_module, unflattened_module, inps)
        self.assertTrue(hasattr(unflattened_module, "sub_net"))
        for i in range(len(eager_module.sub_net)):
            self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
        self.assertEqual(
            id(getattr(unflattened_module.sub_net, "0")),
            id(getattr(unflattened_module.sub_net, "2")),
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
        for strict in [True, False]:
            export_module = export(
                orig_eager,
                inps,
                {},
                preserve_module_call_signature=("foo.nested",),
                strict=strict,
            )
            unflattened = unflatten(export_module)

            self.compare_outputs(export_module.module(), unflattened, inps)
            unflattened.foo.nested = NestedChild()
            self.compare_outputs(export_module.module(), unflattened, inps)

            # Test tree spec mismatched input
            orig_outs = export_module.module()(*inps)
            new_inps = *inps, torch.rand(2, 3)
            with self.assertRaisesRegex(
                TypeError,
                "There is no flat args adapter sepcified. "
                "Are you sure you are calling this with the right arguments?",
            ):
                unflattened(new_inps)

            # With flat args adapter
            class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
                def adapt(
                    self,
                    target_spec: TreeSpec,
                    input_spec: TreeSpec,
                    input_args: List[Any],
                ) -> List[Any]:
                    while len(input_args) > 2:
                        input_args.pop(-1)
                    return input_args

            unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
            new_outs = unflattened(*new_inps)
            self.assertTrue(torch.allclose(orig_outs, new_outs))

    def test_unflatten_param_list_dict(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                for i in range(2):
                    x = x + self.param_list[i]
                    x = x + self.param_dict[f"key_{i}"]
                return x

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(
            ep.module(), unflattened, (torch.randn(2), torch.randn(5))
        )

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                a = x.sum()
                for i in range(2):
                    a = a + self.param_list[i].sum()
                    a = a + self.param_dict[f"key_{i}"].sum()
                return a

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            export_module.module()(torch.randn(6, 6))

        unflattened = unflatten(export_module)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            unflattened(torch.randn(6, 6))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_with_inplace_compile(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        # in-place compilation should work. Pass fullgraph to ensure no graph breaks.
        from torch._dynamo.backends.debugging import ExplainWithBackend

        eb = ExplainWithBackend("inductor")
        unflattened.foo.compile(backend=eb, fullgraph=True)
        inputs = (torch.randn(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.assertEqual(len(eb.graphs), 1)

    def test_fx_trace(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x[0] + x[1]
                x = x + y["foo"]
                return x

        orig_eager = MyModule()
        inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
        export_module = export(orig_eager, inputs, {})
        unflattened = unflatten(export_module)

        torch.fx.symbolic_trace(
            unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
        )

    def test_double_nested_submodule(self):
        class SubSubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subsubmod = SubSubMod()

            def forward(self, x):
                return x - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod.subsubmod(x)

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_container_type(self):
        class Leaf(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()
                self.buffer = torch.nn.Buffer(torch.randn(4, 4))

            def forward(self, x, z):
                return (
                    self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()
                )

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x, z):
                y = self.bar.buffer + x + z[0] + z[1]
                return self.bar(x, z) + y.sum()

        inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
        mod = Foo()
        ep_strict = torch.export.export(mod, inp)
        ep_non_strict = torch.export.export(mod, inp, strict=False)

        gm_unflat_non_strict = unflatten(ep_non_strict)
        ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
        self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))

    def test_unflattened_module_nodes_has_meta_val(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x, x * x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + sum(self.submod(x))

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

        def check_meta(gm):
            for n in gm.graph.nodes:
                if n.op == "output":
                    continue
                self.assertTrue(n.meta.get("val") is not None)

        for m in unflattened.modules():
            check_meta(m)
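
    # A parameter created with requires_grad=False should keep that flag after
    # export + unflatten, both in the state dict and on the module attribute.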
    def test_unflatten_requires_grad_param(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        ep = export(mod, inputs)
        unflattened = unflatten(ep)
        self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
        self.assertTrue(unflattened.p.requires_grad is False)

    def test_placeholder_and_get_attr_ordering_after_unflattened(self):
        class TransposeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        exported_program = export(TransposeModule(), args=(x,))
        unflattened_module = unflatten(exported_program)

        # Check the inputs of the created call_module node are in order
        call_module_input_order = []
        for node in unflattened_module.graph.nodes:
            if node.op == "call_module":
                transpose_module = unflattened_module.get_submodule(node.target)
                for sub_node in transpose_module.graph.nodes:
                    if sub_node.op == "placeholder" or sub_node.op == "get_attr":
                        call_module_input_order.append(sub_node.op)
        self.assertEqual(
            call_module_input_order, ["placeholder", "get_attr", "get_attr"]
        )

    def test_unflatten_constant_tensor(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.initializer = 0.1

            def forward(self, x):
                return x + torch.tensor(self.initializer)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @skipIfTorchDynamo("custom objects not supported in dynamo yet")
    def test_unflatten_constant_obj(self):
        init_torchbind_implementations()

        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_ctx):
                return cls(**dict(flat_ctx))

            def add_tensor(self, z):
                return (self.x + self.y) * z

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        with enable_torchbind_tracing():
            export_module = torch.export.export(
                Mod(), (torch.randn((2, 3)),), strict=False
            )
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    # skip connection is not supported yet
    @unittest.expectedFailure
    def test_unflatten_skipped_call_module(self):
        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return a.d(x.cos())

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = C()

            def forward(self, x):
                return self.c(x) + x

        class D(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sin()

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B()
                self.d = D()

            def forward(self, x):
                return self.b(x)

        a = A()

        # The call chain looks like this:
        # A -> B -> C -> A.d
        ep = torch.export.export(a, (torch.randn(3),), strict=False)
        unflattened = unflatten(ep)

    def test_nested_leaf_non_strict(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=False,
            preserve_module_call_signature=("nested",),
        )

        torch.export.unflatten(ep)

    def test_unflatten_submodule_ordering(self):
        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod2 = Module2()
                self.mod3 = self.mod2
                self.mod1 = Module1()

            def forward(self, x):
                return self.mod3(self.mod2(self.mod1(x)))

        mod = Module()
        ep = torch.export.export(mod, (torch.randn(3, 4),))

        unflattened = torch.export.unflatten(ep)
        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
        self.assertEqual(len(fqn_list), 4)
        self.assertEqual(
            [x for x, _ in mod.named_modules(remove_duplicate=False)],
            fqn_list,
        )

    def test_duplicate_placeholder(self):
        N, C, H, W = 1, 2, 2, 3

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layer = torch.nn.LayerNorm([C, H, W])
                self.norms = torch.nn.ModuleList(
                    [
                        layer,  # reuse layer norm
                        layer,
                        layer,
                    ]
                )

            def forward(self, input_):
                for i in range(len(self.norms)):
                    output = self.norms[i](input_)
                    input_ = output
                return output

        mod = MyModule()
        input_ = torch.randn(N, C, H, W)

        ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
        umod = unflatten(ep_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

        ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
        umod = unflatten(ep_non_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

    def test_simple_alias(self):
        # handle weight sharing, check tensor ids after unflattening
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alias param
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x) + self.bias

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(id(unep.m.bias) == id(unep.bias))

        # handle aliasing where one alias is unused
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = (
                    self.bias
                )  # self.bias is unused, aliasing should be handled

            def forward(self, x):
                return self.m(x)

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(torch.allclose(unep(*inps), m(*inps)))
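
    # A root-level buffer passed as an explicit argument into each call of a
    # repeated submodule should still be wired up correctly after unflattening.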
    def test_attr_as_submod_input(self):
        class layer(torch.nn.Module):
            def forward(self, x, const) -> torch.Tensor:
                return x + const

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Buffer(torch.ones(4, 8))
                self.layers = torch.nn.ModuleList([layer() for _ in range(2)])

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                for layer in self.layers:
                    x = layer(x, self.const)
                return x

        mod = M()
        x = torch.randn(4, 8)
        ep = export(mod, (x,))
        unflattened = unflatten(ep)
        torch.testing.assert_close(unflattened(x), mod(x))

    def test_dedup_sym_size(self):
        # Here, sym_size & floor div are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be copied as recompute since we
        # preserve the call signature, but for m2 floordiv should be passed in
        # as a placeholder.
        # Test that this is preserved, and the unflattened module runs correctly.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                d = x.size(0) // 2
                m1_res = self.m1(x, y)
                m2_res = self.m2(x, y)
                return y[d:] + m1_res + m2_res

        inputs = (torch.ones(10), torch.ones(10))
        d_ = torch.export.Dim("foo", max=2048)
        d = 2 * d_
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((d,), (d,)),
            strict=False,
            preserve_module_call_signature=("m1",),
        )
        unflat = unflatten(ep)
        unflat(*inputs)

        fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
            torch.ops.aten.sym_size.int
        )
        self.assertEqual(fn_count_sym_size(unflat.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
        self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

    def test_unflatten_eager(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        with _disable_interpreter():
            unflattened = unflatten(export_module)

        self.assertEqual(unflattened._run_with_interpeter, False)
        self.assertEqual(unflattened.foo._run_with_interpeter, False)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

        # Check composability with symbolic trace, as torchrec ddp uses symbolic
        # tracer
        symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
        self.assertTrue(torch.allclose(orig_eager(*inputs), symbolic_traced(*inputs)))

        # torch.compile submodule
        unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
        self.compare_outputs(orig_eager, unflattened, inputs)


if __name__ == "__main__":
    run_tests()