# Owner(s): ["module: fx"]

import functools
import math
import numbers
import operator
import pickle
import sys
import sympy
import tempfile
import unittest
from types import BuiltinFunctionType
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.fx.experimental.meta_tracer
import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
from torch.fx.experimental.partitioner_utils import (
    Device,
    get_latency_of_partitioned_graph,
    get_partition_to_latency_mapping,
    NodeLatency,
    PartitionerConfig,
    PartitionMode,
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
    _torchscript_type_to_python_type,
    create_type_hint,
    normalize_function,
    normalize_module,
    type_matches,
)
from torch.fx.passes import graph_manipulation
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.split_module import split_module
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase
import torch.utils._pytree as pytree

try:
    import torchvision.models
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)


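# Trace `root` with RewritingTracer, which rewrites Python `assert` statements
# into `torch._assert` calls so they appear as call_function nodes in the traced
# graph (exercised by the test_call_to_assert_* tests below).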
def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )


class TestFXExperimental(JitTestCase):
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
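        # Device(name, available_mem_bytes, logical_id); dev_1 has the most
        # memory, so the single partition should land on logical device 1.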
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 150, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [1]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        # Neither device has enough memory for the graph, so partitioning must fail.
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)

    def test_large_node_error(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        # A single node is larger than any one device, so partitioning must fail.
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)

    def test_partition_node_manipulation(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
        assert partition.used_mem_bytes == 112
        # Select the add_2 node and remove it from the partition
        selected_node = None
        for node in partition.nodes:
            if node.name == "add_2":
                selected_node = node
        assert selected_node is not None, "add_2 node not found in partition"
        partition.remove_node(selected_node)
        assert partition.used_mem_bytes == 80

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                add_2 = linear + self.c
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]

    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]

    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self) -> None:
                super().__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for _ in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for _ in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8,))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(
                            self.embedding_layers[i](torch.randint(10, (8,)), offset)
                        )
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8,))
        offset = torch.randint(1, (2,))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given an FX module, generate a NodeLatency for each node
            based on the node's size.
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
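        # Each PartitionLatency below is compared as the tuple
        # (mem_latency_sec, computation_latency_sec, overall_latency_sec).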
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
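        # Critical-path latency of the partitioned graph, including the time to
        # transfer bytes between partitions at the given rate.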
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0

    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
            Device("dev_3", 125, 3),
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions,
            partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec,
        )
        assert critical_path_latency_sec == 160.0

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
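        # Pre-assign each compute node to its own partition, all mapped to logical
        # device 0; aot_based mode consumes this mapping instead of computing one.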
        for node in traced.graph.nodes:
            if node.op not in {"placeholder", "get_attr", "output"}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device("dev_0", 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model will be split into two partitions.
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)
        # With host saturation, partition 0 is replicated to dev_4 and partition 1
        # is replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = optimization.fuse(traced)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_conv_bn_fusion_not_running_state(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
                self.bn = torch.nn.BatchNorm2d(
                    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn([1, 32, 50, 50])

        # A BatchNorm without running stats cannot be folded into the conv,
        # so it must survive fusion.
        self.assertTrue(
            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_conv_bn_fusion_mixed_dtype(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, kernel_size=3, stride=1, padding=1, bias=False,
                    dtype=torch.bfloat16,
                )
                self.bn = torch.nn.BatchNorm2d(
                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it should and is silent when it shouldn't be
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

        mttm = MetaTracerTestModule()
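        # Trace once per batch size: the shape-dependent branch (lol.shape[0] < 30)
        # is resolved at trace time from the meta tensor's shape, so each traced
        # graph is specialized to its input size.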
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            meta_args = {'x': x.to(device='meta')}
            gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
            torch.testing.assert_close(gm(x), mttm(x))

            # Test serialization/deserialization
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
                    pickle.dump(gm, f)

                with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
                    loaded = pickle.load(f)

                torch.testing.assert_close(loaded(x), mttm(x))

    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it should and is silent when it shouldn't be
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it should and is silent when it shouldn't be
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

700*da0073e9SAndroid Build Coastguard Worker    def test_call_to_assert_with_multiline_message(self):
701*da0073e9SAndroid Build Coastguard Worker        class M(torch.nn.Module):
702*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b):
703*da0073e9SAndroid Build Coastguard Worker                error_msg = """
704*da0073e9SAndroid Build Coastguard WorkerAn error message with
705*da0073e9SAndroid Build Coastguard Workerterrible spacing
706*da0073e9SAndroid Build Coastguard Worker                """
707*da0073e9SAndroid Build Coastguard Worker                assert a == b, error_msg
708*da0073e9SAndroid Build Coastguard Worker                return a + b
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker        m = M()
711*da0073e9SAndroid Build Coastguard Worker        traced = symbolic_trace_with_rewrite(m)
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker        # Make sure the graph is well-formed
714*da0073e9SAndroid Build Coastguard Worker        traced.graph.lint()
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        # Check the IR to make sure there's a call_function node with target == "Assert"
717*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
718*da0073e9SAndroid Build Coastguard Worker            any(
719*da0073e9SAndroid Build Coastguard Worker                node.op == "call_function" and node.target == torch._assert
720*da0073e9SAndroid Build Coastguard Worker                for node in traced.graph.nodes
721*da0073e9SAndroid Build Coastguard Worker            )
722*da0073e9SAndroid Build Coastguard Worker        )
723*da0073e9SAndroid Build Coastguard Worker
724*da0073e9SAndroid Build Coastguard Worker        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
725*da0073e9SAndroid Build Coastguard Worker        error_msg = """
726*da0073e9SAndroid Build Coastguard WorkerAn error message with
727*da0073e9SAndroid Build Coastguard Workerterrible spacing
728*da0073e9SAndroid Build Coastguard Worker    """
729*da0073e9SAndroid Build Coastguard Worker        traced(3, 3)
730*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, error_msg):
731*da0073e9SAndroid Build Coastguard Worker            traced(3, 5)
732*da0073e9SAndroid Build Coastguard Worker
733*da0073e9SAndroid Build Coastguard Worker        # Confirm that the output is correct
734*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(traced(3, 3), m(3, 3))
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker    def test_subgraph_creation(self):
737*da0073e9SAndroid Build Coastguard Worker        class MyModule(torch.nn.Module):
738*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
739*da0073e9SAndroid Build Coastguard Worker                super().__init__()
740*da0073e9SAndroid Build Coastguard Worker                self.param = torch.nn.Parameter(torch.rand(3, 4))
741*da0073e9SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(4, 5)
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
744*da0073e9SAndroid Build Coastguard Worker                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
745*da0073e9SAndroid Build Coastguard Worker                w = self.linear(y).clamp(min=0.0, max=1.0)
746*da0073e9SAndroid Build Coastguard Worker                return z + w
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker        # symbolically trace model
749*da0073e9SAndroid Build Coastguard Worker        my_module = MyModule()
750*da0073e9SAndroid Build Coastguard Worker        my_module_traced = symbolic_trace(my_module)
751*da0073e9SAndroid Build Coastguard Worker
752*da0073e9SAndroid Build Coastguard Worker        # random mod partitioning
753*da0073e9SAndroid Build Coastguard Worker        partition_counter = 0
754*da0073e9SAndroid Build Coastguard Worker        NPARTITIONS = 3
755*da0073e9SAndroid Build Coastguard Worker
756*da0073e9SAndroid Build Coastguard Worker        # Add some random meta info to make sure it is kept around.
757*da0073e9SAndroid Build Coastguard Worker        for node in my_module_traced.graph.nodes:
758*da0073e9SAndroid Build Coastguard Worker            if node.op != "output":
759*da0073e9SAndroid Build Coastguard Worker                node.meta["test_meta_info"] = True
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        def mod_partition(node: Node):
762*da0073e9SAndroid Build Coastguard Worker            nonlocal partition_counter
763*da0073e9SAndroid Build Coastguard Worker            partition = partition_counter % NPARTITIONS
764*da0073e9SAndroid Build Coastguard Worker            partition_counter = (partition_counter + 1) % NPARTITIONS
765*da0073e9SAndroid Build Coastguard Worker            return partition
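
        # `mod_partition` assigns nodes to partitions 0, 1, 2 in round-robin
        # order; split_module then groups nodes by the returned partition index
        # into submodules (submod_0, submod_1, ...) while preserving the
        # topological order of the original graph.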

        # split module into a module with submodules
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition
        )

        # Check that test_meta_info is still present on all nodes.
        submodules = dict(module_with_submodules.named_modules())
        for node in module_with_submodules.graph.nodes:
            if node.op == "call_module":
                submod = submodules[node.target]
                self.assertTrue(isinstance(submod, torch.fx.GraphModule))
                for submod_node in submod.graph.nodes:
                    if submod_node.op != "output":
                        stored_op = submod_node.meta.get("test_meta_info")
                        self.assertTrue(stored_op is not None and stored_op)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    def test_split_module_dead_code(self):
        class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2  # we want this
                dead_line = x + 2  # this is dead
                return output

        mod = ModWithDeadCode()
        traced = torch.fx.symbolic_trace(mod)

        # split into before (0), target (1), and after (2)
        saw_mul = False

        def split_callback(n):
            nonlocal saw_mul
            if n.target == operator.mul:
                saw_mul = True
                return 1

            if not saw_mul:
                return 0
            if saw_mul:
                return 2
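
        # The dead `x + 2` node has no users, so however split_module places it,
        # the observable behavior must not change; we only check that the split
        # module still computes the same output as the traced original.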

        split = split_module(traced, mod, split_callback)

        x = torch.randn((5,))
        torch.testing.assert_close(
            split(x), traced(x)
        )

    def test_split_module_kwargs_expansion(self):
        class ModuleWithKwargsExpansion(torch.nn.Module):
            def forward(self, x, **kwargs):
                return x + kwargs['foo']

        mod = ModuleWithKwargsExpansion()
        traced = torch.fx.symbolic_trace(mod)

        seen_getitem = False

        def split_callback(n):
            nonlocal seen_getitem
            split_idx = int(seen_getitem)
            if n.target == operator.getitem:
                seen_getitem = True
            return split_idx
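
        # Tracing a forward with **kwargs produces operator.getitem accesses on
        # the kwargs placeholder; the callback above puts everything up to and
        # including the first getitem in partition 0 and the rest in partition 1.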

        split = split_module(traced, mod, split_callback)

        x = torch.randn(5, 3)
        foo = torch.randn(5, 3)
        torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo))

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting resnet into 1 partition works.
        # There was an issue before causing submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_split_module_default_arg(self):
        class ModelToTrace(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)

            def forward(self, x, targets=None):
                x = self.lin(x)

                if targets is not None:
                    x = x + targets

                return x

        mtt = ModelToTrace()
        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})

        split = split_module(traced, mtt, lambda node: 0)

        x = torch.randn(50, 512)
        torch.testing.assert_close(split(x), traced(x))

    def test_normalize_binary_operators(self):
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }
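
        # NormalizeOperators rewrites calls to these torch-namespace binary ops
        # into their canonical builtin-operator spellings (roughly,
        # torch.add(x, y) -> x + y); the assertFalse checks below verify that
        # none of the original targets remain in the normalized graph.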

        # Test Tensor/Tensor callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x, y = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_close(traced(x, y), normalized(x, y))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

        # Test Tensor/scalar callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x = torch.randn(3, 4)
            torch.testing.assert_close(traced(x), normalized(x))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

    @skipIfNoTorchVision
    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()
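
        # ShapeProp records tensor metadata on each node, which NormalizeArgs
        # uses to resolve the callsite's schema; the transform then moves
        # positional arguments into keyword arguments, so the loop below
        # expects empty `args` on normalized callsites.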

        modules = dict(traced.named_modules())

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target != operator.add:
                self.assertEqual(len(node.args), 0)
            elif node.op == "call_module":
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
        traced(input)
        self.assertEqual(traced(input), ref_outs)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
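        # As a rough illustration of what normalization should produce here:
        # tracing `self.mod(v0)` for an nn.Linear instance would normalize to
        # args=() and kwargs={'input': v0}, following the parameter names of
        # the module's forward signature.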
        for test_params in module_tests + new_module_tests:
            if "constructor" not in test_params:
                constructor = getattr(torch.nn, test_params["module_name"])
            else:
                constructor = test_params["constructor"]

            if "constructor_args" not in test_params:
                args = ()
            else:
                args = test_params["constructor_args"]

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if "input_fn" not in test_params:
                inputs = torch.randn(test_params["input_size"])
            else:
                inputs = test_params["input_fn"]()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

            params = ", ".join(f"v{i}" for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f"Test{mod.__class__.__name__}"
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

            gbls = {"torch": torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Use `Node.normalized_arguments` to get a new set of arguments
            # to feed to the Module. Then, rewrite the node to only take
            # in those arguments as kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        normalized_args = node.normalized_arguments(traced)
                        normalized_args2 = normalize_module(
                            traced, node.target, node.args, node.kwargs
                        )
                        assert normalized_args == normalized_args2
                        assert normalized_args
                        node.args = normalized_args.args
                        node.kwargs = normalized_args.kwargs

            traced.recompile()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not meaningful. Skip that
            # check for these.
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    def test_normalize_args_preserve_type(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: List[torch.Tensor]):
                return torch.add(a[0], a[1])

        m = MyModule()
        traced = symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.type, List[torch.Tensor])

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals
        ).transform()
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                excluded_nodes = {
                    ("placeholder", "x"),
                    # Return type differs based on boolean dispatch :(
                    ("call_function", torch.nn.functional.max_pool2d),
                    ("output", "output"),
                }
                # AnnotateTypesWithSchema doesn't work with bound C++ functions
                if not isinstance(node.target, BuiltinFunctionType):
                    self.assertIn(check, excluded_nodes)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_functionals_annotated)

    def test_annotate_getitem_node(self):
        class CustomType:
            pass

        class CustomNamedTuple(NamedTuple):
            x: int
            y: float

        class MyModule(torch.nn.Module):
            def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple):
                inp_0 = inp[0]
                inp_1 = inp[1]
                inp2_0 = inp2[0]
                inp3_x = inp3.x
                inp3_y = inp3.y
                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y

        my_module = MyModule()
        my_module_traced = torch.fx.symbolic_trace(my_module)

        # By default, FX tracing loses the type annotations of getitem nodes.
        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                assert node.type is None

        annotate_getitem_nodes(my_module_traced.graph)
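
        # annotate_getitem_nodes infers each getitem node's type from the
        # annotated container type of its input, e.g. indexing a placeholder
        # annotated Tuple[CustomType, torch.Tensor] at 0 yields CustomType.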

        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node: torch.fx.Node):
            if node.name == "a" or node.name == "b" or node.name == "add":
                return 0
            else:
                return 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_split_qualname_mapping(self):
        d_hid = 4

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param)
                x = self.lin(x)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        my_module = ExampleCode()
        my_module_traced = symbolic_trace(my_module)

        part_idx = 0

        def split_callback(n: torch.fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == ('call_module', 'lin'):
                part_idx += 1
            return part_idx
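
        # part_idx increments at each `lin` call, so the graph is cut into
        # three partitions and the two calls to the shared `lin` submodule land
        # in separate split submodules, each mapping back to the original
        # qualified name 'lin'.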

        # split module into a module with submodules
        qualname_map: Dict[str, str] = {}
        module_with_submodules = split_module(
            my_module_traced, my_module, split_callback, qualname_map
        )
        expected_qualname_map = {
            'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
        }
        self.assertEqual(qualname_map, expected_qualname_map)

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.attr2 = torch.nn.Buffer(torch.randn(2))
                self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

        mod = symbolic_trace(Test())
        module_name = "Foo"
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
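
            # to_folder (roughly) serializes the GraphModule as an importable
            # Python package: an __init__.py defining a `Foo` module class plus
            # saved parameter/buffer state, which the import machinery below
            # loads back for comparison against the in-memory module.
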
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util

            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / "__init__.py"
            )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """
        # Utility function for counting matmuls for test assertions.
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls
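
        # merge_matmul (roughly) rewrites groups of matmuls that share a common
        # right-hand side into a single matmul over the concatenated left-hand
        # sides, splitting the result back apart afterwards; the tests below
        # check both numerics and the resulting matmul counts.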
1341*da0073e9SAndroid Build Coastguard Worker
1342*da0073e9SAndroid Build Coastguard Worker        # Simple test case in which there are two matmuls of the same size to merge.
1343*da0073e9SAndroid Build Coastguard Worker        class SimpleMergeMatmulModule(torch.nn.Module):
1344*da0073e9SAndroid Build Coastguard Worker            def __init__(self, rhs):
1345*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1346*da0073e9SAndroid Build Coastguard Worker                self.rhs = rhs
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, y):
1349*da0073e9SAndroid Build Coastguard Worker                a = torch.matmul(x, self.rhs)
1350*da0073e9SAndroid Build Coastguard Worker                b = torch.matmul(y, self.rhs)
1351*da0073e9SAndroid Build Coastguard Worker                return a + b
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker        # Initialize inputs.
1354*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3)
1355*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3)
1356*da0073e9SAndroid Build Coastguard Worker
1357*da0073e9SAndroid Build Coastguard Worker        # Initialize RHS for matmuls.
1358*da0073e9SAndroid Build Coastguard Worker        rhs = torch.randn(3, 4)
1359*da0073e9SAndroid Build Coastguard Worker
1360*da0073e9SAndroid Build Coastguard Worker        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
1361*da0073e9SAndroid Build Coastguard Worker        module = SimpleMergeMatmulModule(rhs)
1362*da0073e9SAndroid Build Coastguard Worker        opt_module = merge_matmul.merge_matmul(module)
1363*da0073e9SAndroid Build Coastguard Worker
1364*da0073e9SAndroid Build Coastguard Worker        # Numerical correctness check.
1365*da0073e9SAndroid Build Coastguard Worker        before = module(a, b)
1366*da0073e9SAndroid Build Coastguard Worker        after = opt_module(a, b)
1367*da0073e9SAndroid Build Coastguard Worker        before.allclose(after)
1368*da0073e9SAndroid Build Coastguard Worker
1369*da0073e9SAndroid Build Coastguard Worker        # Basic graph structure check; original module should have 2 matmuls
1370*da0073e9SAndroid Build Coastguard Worker        # and optimized module should have 1.
1371*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_count_matmuls(module), 2)
1372*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_count_matmuls(opt_module), 1)
1373*da0073e9SAndroid Build Coastguard Worker
1374*da0073e9SAndroid Build Coastguard Worker        # Test case in which there are multiple matmuls of different sizes to merge.
1375*da0073e9SAndroid Build Coastguard Worker        class FiveMergeMatmulModule(torch.nn.Module):
1376*da0073e9SAndroid Build Coastguard Worker            def __init__(self, rhs):
1377*da0073e9SAndroid Build Coastguard Worker                super().__init__()
1378*da0073e9SAndroid Build Coastguard Worker                self.rhs = rhs
1379*da0073e9SAndroid Build Coastguard Worker
1380*da0073e9SAndroid Build Coastguard Worker            def forward(self, a, b, c, d, e):
1381*da0073e9SAndroid Build Coastguard Worker                s = torch.tensor([])
1382*da0073e9SAndroid Build Coastguard Worker                matmuls = []
1383*da0073e9SAndroid Build Coastguard Worker
1384*da0073e9SAndroid Build Coastguard Worker                # For some reason using a list comprehension or for-loop for this
1385*da0073e9SAndroid Build Coastguard Worker                # doesn't work.
1386*da0073e9SAndroid Build Coastguard Worker                matmuls.append(torch.matmul(a, self.rhs))
1387*da0073e9SAndroid Build Coastguard Worker                matmuls.append(torch.matmul(b, self.rhs))
1388*da0073e9SAndroid Build Coastguard Worker                matmuls.append(torch.matmul(c, self.rhs))
1389*da0073e9SAndroid Build Coastguard Worker                matmuls.append(torch.matmul(d, self.rhs))
1390*da0073e9SAndroid Build Coastguard Worker                matmuls.append(torch.matmul(e, self.rhs))
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker                for m in matmuls:
1393*da0073e9SAndroid Build Coastguard Worker                    s += torch.sum(m)
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker                return s
1396*da0073e9SAndroid Build Coastguard Worker
        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

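        # The second matmul's LHS is derived from the first matmul's output, so
        # concatenating the two LHS operands would make the merged matmul depend
        # on its own result; merge_matmul must leave such chains alone.
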
        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; the number of matrix multiplications should not have changed.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 2)

    def test_type_matches(self):
        should_be_equal = [
            (int, int),
            (numbers.Number, int),
            (numbers.Number, float),
            (int, type(torch.float)),
            (Union[int, float], int),
            (Union[int, float], float),
            (List[int], int),
            (List[int], create_type_hint([int, int])),
            (List[int], create_type_hint((int, int))),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
            (
                List[torch.Tensor],
                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
            (
                List[torch.Tensor],
                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
            (Optional[List[torch.Tensor]], List[torch.Tensor]),
            (Optional[List[int]], List[int]),
        ]
        for sig_type, arg_type in should_be_equal:
            self.assertTrue(type_matches(sig_type, arg_type))

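        # Note: create_type_hint unifies container element types to a common base
        # class, which is why mixed [torch.nn.Parameter, torch.Tensor] containers
        # match List[torch.Tensor] above (Parameter subclasses Tensor).
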
        should_fail = [
            (int, float),
            (Union[int, float], str),
            (List[torch.Tensor], List[int]),
        ]

        for sig_type, arg_type in should_fail:
            self.assertFalse(type_matches(sig_type, arg_type))

    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu(self):
        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layers = []
                layers2 = []
                for _ in range(10):
                    layers.append(nn.Conv2d(3, 3, 1))
                    layers.append(nn.BatchNorm2d(3))
                    layers.append(nn.ReLU())

                    layers2.append(nn.Conv2d(3, 3, 1))
                    layers2.append(nn.BatchNorm2d(3))
                    layers2.append(nn.ReLU())
                self.model = nn.Sequential(*layers)
                self.model2 = nn.Sequential(*layers2)

            def forward(self, x):
                return self.model(x) + self.model2(x)

        N, C, H, W = 1, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        with torch.no_grad():
            model = Foo().eval()
            optimized_model = optimization.optimize_for_inference(model)
            torch.testing.assert_close(model(inp), optimized_model(inp))

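            # pass_config toggles individual optimization passes; disable the
            # dropout-removal pass and check the numerics are still unchanged.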
            optimized_model2 = optimization.optimize_for_inference(
                model, pass_config={"remove_dropout": False}
            )
            torch.testing.assert_close(model(inp), optimized_model2(inp))

    @skipIfNoTorchVision
    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu_torchvision(self):
        models = [
            torchvision.models.resnet18,
            torchvision.models.resnet50,
            torchvision.models.densenet121,
            torchvision.models.shufflenet_v2_x1_0,
            torchvision.models.vgg16,
            torchvision.models.mobilenet_v2,
            torchvision.models.mnasnet1_0,
            torchvision.models.resnext50_32x4d,
        ]
        with torch.no_grad():
            for model_type in models:
                model = model_type()
                C, H, W = 3, 224, 224
                inp = torch.randn(3, C, H, W)
                model(inp)
                model.eval()
                inp = torch.randn(1, C, H, W)
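                # Build an MKL autotuning heuristic for the layout pass; it is
                # constructed here but not passed to optimize_for_inference,
                # which runs with its defaults below.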
                heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
                optimized_model = optimization.optimize_for_inference(model)

                orig_out = model(inp)
                new_out = optimized_model(inp)
                torch.testing.assert_close(orig_out, new_out)


class TestNormalizeOperators(JitTestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        # These ops currently don't trace in FX for various reasons (e.g. they take a list of tensors)
        fx_fail = {
            "cat",
            "stack",
            "hstack",
            "vstack",
            "dstack",
            "linalg.multi_dot",
            "_upsample_bilinear2d_aa",
            "_chunk_cat",
        }
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        if isinstance(op.op, torch._ops.OpOverload):
            self.skipTest("normalize operator doesn't work on torch.ops")
        for sample_input in sample_inputs_itr:
            unsupported_arg_type = False
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            arg_types = []
            kwarg_types = {}

            def jit_infer_type(v):
                inferred_arg_type = torch._C._jit_try_infer_type(v)
                assert inferred_arg_type.success()
                t = _torchscript_type_to_python_type(inferred_arg_type.type())
                return t

            for v in arg_values:
                if isinstance(v, torch.Tensor):
                    arg_types.append(type(v))
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    arg_types.append(jit_infer_type(v))

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    kwarg_types[k] = type(v)
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    kwarg_types[k] = jit_infer_type(v)

            if unsupported_arg_type:
                continue
            # Test normalize_function by itself
            ref_out = op.op(*arg_values, **kwarg_values)
            norm_args_and_kwargs = normalize_function(
                op.op, arg_values, kwarg_values, arg_types, kwarg_types
            )
            if norm_args_and_kwargs is None:
                raise RuntimeError(
                    """
                    FX failed to normalize op - add the op to the fx_fail list.
                    A common reason is if your OpInfo was implemented with a lambda
                    - otherwise, file an issue
                    """
                )
            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
            self.assertEqual(test_out, ref_out)

            # Test normalized_arguments as part of FX
            if op.name in fx_fail:
                continue
            param_names = []
            param_values = []
            fx_args = []

            idx = 0

            def process_arg(arg, name):
                if isinstance(arg, torch.Tensor):
                    param_names.append(name)
                    param_values.append(arg)
                    return name
                else:
                    return repr(arg)

            def process_arg_with_idx(arg):
                nonlocal idx
                res = process_arg(arg, f"arg_{idx}")
                idx += 1
                return res

            def str_arg(arg):
                if isinstance(arg, tuple):
                    args = [f"{str_arg(v)}, " for v in arg]
                    return f"({' '.join(args)})"
                elif isinstance(arg, list):
                    args = [f"{str_arg(v)}" for v in arg]
                    return f"[{', '.join(args)}]"
                else:
                    return arg

            for v in arg_values:
                arg = pytree.tree_map(process_arg_with_idx, v)
                fx_args.append(str_arg(arg))

            for k, v in kwarg_values.items():
                arg = pytree.tree_map(functools.partial(process_arg, name=k), v)
                fx_args.append(f"{k} = {str_arg(arg)}")

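            # Codegen a tiny module whose forward calls torch.<op name>: tensor
            # arguments become forward() parameters, and everything else is
            # inlined as a literal, so symbolic tracing sees the same call the
            # OpInfo sample made.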
            code = f"""
class TestModule(torch.nn.Module):
    def forward(self, {', '.join(param_names)}):
        return torch.{op.name}({', '.join(fx_args)})
"""

            g = {"torch": torch, "inf": math.inf}
            exec(code, g)
            TestModule = g["TestModule"]

            m = TestModule()
            traced = torch.fx.symbolic_trace(m)
            ref_out = traced(*param_values)

            for node in traced.graph.nodes:
                if node.op == "call_function":
                    normalized_args = node.normalized_arguments(
                        traced, arg_types, kwarg_types
                    )
                    assert normalized_args
                    node.args = normalized_args.args
                    node.kwargs = normalized_args.kwargs
            traced.recompile()

            test_out = traced(*param_values)
            self.assertEqual(test_out, ref_out)

    def test_normalize_quantized_eb(self):
        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
        args = (
            torch.empty((2, 3), dtype=torch.uint8),
            torch.empty((2,), dtype=torch.int64),
            torch.empty((2,), dtype=torch.int64),
        )
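        # With normalize_to_only_use_kwargs=True, every positional argument is
        # rebound to its schema name, so .args comes back empty and .kwargs
        # carries the op's full signature.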
        norm_args_and_kwargs = normalize_function(
            target, args, normalize_to_only_use_kwargs=True
        )
        self.assertIsNotNone(norm_args_and_kwargs)
        self.assertEqual(
            set(norm_args_and_kwargs.kwargs.keys()),
            {
                "weight",
                "indices",
                "offsets",
                "scale_grad_by_freq",
                "mode",
                "pruned_weights",
                "per_sample_weights",
                "compressed_indices_mapping",
                "include_last_offset",
            },
        )
        self.assertEqual(norm_args_and_kwargs.args, ())

    def test_normalize_args_op_overload(self):
        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
            inp1 = torch.rand([1])
            inp2 = torch.rand([4])
            args, kwargs = normalize_function(
                target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True
            )
            self.assertIs(kwargs["input"], inp1)
            self.assertIs(kwargs["the_template"], inp2)


if TEST_Z3:
    import z3

    import torch._dynamo.config

    from torch.fx.experimental.validator import (
        SympyToZ3,
        TranslationValidator,
        ValidationException,
        z3str,
    )
    from torch.utils._sympy.functions import FloorDiv, Mod

    class TestTranslationValidation(TestCase):
        def _prepare_for_translation_validation(self):
            validator = TranslationValidator()

            # SymPy symbols.
            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

            # Z3 symbols.
            for s in (s0, s1, s2):
                validator.add_var(s, int)
            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

            return (s0, s1, s2), (z0, z1, z2), validator

        def test_sympy_to_z3(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            test_cases = [
                # Integer constants.
                (sympy.S.Zero, z3.IntVal(0)),
                (sympy.S.One, z3.IntVal(1)),
                (sympy.S.NegativeOne, z3.IntVal(-1)),
                (sympy.Integer(2), z3.IntVal(2)),
                # Variables.
                (s0, z0),
                # Arithmetic operations.
                *[
                    (op(s0, s1), op(z0, z1))
                    for op in (
                        operator.add,
                        operator.mul,
                        operator.pow,
                    )
                ],
                # Comparison operations.
                *[
                    (sympy_op(s0, s1), z3_op(z0, z1))
                    for sympy_op, z3_op in (
                        (sympy.Eq, operator.eq),
                        (sympy.Ne, operator.ne),
                        (sympy.Lt, operator.lt),
                        (sympy.Le, operator.le),
                        (sympy.Gt, operator.gt),
                        (sympy.Ge, operator.ge),
                    )
                ],
                # Other operations.
                (
                    s0 - s1,
                    z0 + z3.IntVal(-1) * z1,
                ),
                (
                    s0 / s1,
                    z3.ToReal(z0) * (z1**-1),
                ),
                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
                (
                    Mod(s2, (s0 / s1)),
                    z2
                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
                    * (z3.ToReal(z0) * z1**-1),
                ),
                (
                    Mod(s2, s0**3),
                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
                ),
            ]

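            # FloorDiv and Mod have no direct Z3 counterpart here: FloorDiv(a, b)
            # is lowered as ToInt(ToReal(a) / ToReal(b)), and Mod through the
            # identity a mod b == a - floor(a / b) * b, as the cases above show.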
            toZ3 = SympyToZ3(validator)
            for sympy_expr, z3_expr in test_cases:
                result = toZ3.run(sympy_expr)
                self.assertTrue(
                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
                )

        def test_sat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # The solutions of the target are a subset of the solutions of the source.
            validator.add_target_expr(s0 > 20)
            validator.add_target_expr(s1 > s0**2)

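            # validate() checks that the target expressions imply the source
            # expressions, i.e. the stricter target admits no assignment that
            # the source rules out, so this should pass.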
            validator.validate()

        def test_unsat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # The solutions of the target are NOT a subset of the solutions of the source.
            validator.add_target_expr(s0 > 20)
            # This expression is less restrictive than its source counterpart.
            validator.add_target_expr(s1 > s0 + 2)

            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
                validator.validate()

        def test_z3str(self):
            a = z3.Int("a")
            b = z3.Int("b")
            special = z3.Real("this.size()[2]")

            test_cases = [
                # Integer constant.
                (z3.IntVal(42), "42"),
                # Variable.
                (a, "a"),
                # Name with special characters.
                (special, "this.size()[2]"),
                # Renamed function applications.
                (a != b, "(!= a b)"),
                (a ** b, "(pow a b)"),
                # Chain of associative operations.
                *[
                    (op(op(a, 5), b), f"({opstr} 5 a b)")
                    for op, opstr in [
                        (operator.add, "+"),
                        (operator.mul, "*")
                    ]
                ],
                # Revert 'Not' conversions.
                (a != b, "(!= a b)"),
                (a < b, "(> b a)"),
                (a > b, "(> a b)"),
                # Ignore 'ToInt' and 'ToReal' functions.
                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
                (z3.ToReal(a + b), "(+ a b)"),
                # Convert to floor division: 'idiv'.
                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
            ]

            for expr, expected in test_cases:
                self.assertEqual(z3str(expr), expected)


instantiate_device_type_tests(TestNormalizeOperators, globals())

if __name__ == "__main__":
    run_tests()