# Owner(s): ["module: fx"]

import functools
import math
import numbers
import operator
import pickle
import sys
import sympy
import tempfile
import unittest
from types import BuiltinFunctionType
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.fx.experimental.meta_tracer
import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
from torch.fx.experimental.partitioner_utils import (
    Device,
    get_latency_of_partitioned_graph,
    get_partition_to_latency_mapping,
    NodeLatency,
    PartitionerConfig,
    PartitionMode,
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
    _torchscript_type_to_python_type,
    create_type_hint,
    normalize_function,
    normalize_module,
    type_matches,
)
from torch.fx.passes import graph_manipulation
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.split_module import split_module
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase
import torch.utils._pytree as pytree

try:
    import torchvision.models
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)


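# RewritingTracer rewrites Python `assert` statements into `torch._assert` calls
# while tracing, so assertions show up as call_function nodes in the generated
# graph (exercised by the test_call_to_assert_* tests below).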
def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )


class TestFXExperimental(JitTestCase):
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 150, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [1]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_large_node_error(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_partition_node_manipulation(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
        assert partition.used_mem_bytes == 112
        # Select add_2 node to remove
        selected_node = None
        for node in partition.nodes:
            if node.name == "add_2":
                selected_node = node
        partition.remove_node(selected_node)
        assert partition.used_mem_bytes == 80

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                add_2 = linear + self.c
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]

    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]

    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self) -> None:
                super().__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8,))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(
                            self.embedding_layers[i](torch.randint(10, (8,)), offset)
                        )
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8,))
        offset = torch.randint(1, (2,))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given an FX module, generate a latency for each node
            based on that node's size.
            """
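            # A NodeLatency bundles a memory latency and a compute latency for a
            # node; here both are derived synthetically from the node's size info.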
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
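        # Each mapped value is the partition's latency summary of
        # (memory latency, compute latency, overall latency), which the
        # expected tuples below compare against.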
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0

    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
            Device("dev_3", 125, 3),
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions,
            partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec,
        )
        assert critical_path_latency_sec == 160.0

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
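        # AOT-based partitioning takes precomputed mappings instead of deciding
        # placement itself: each non-placeholder/get_attr/output node is assigned
        # its own partition, and every partition is mapped to logical device 0.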
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        for node in traced.graph.nodes:
            if node.op not in {"placeholder", "get_attr", "output"}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device("dev_0", 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class testModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = testModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model will be split into two partitions.
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)
        # With host saturation, partition 0 is additionally replicated to dev_4, and
        # partition 1 is additionally replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = optimization.fuse(traced)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_conv_bn_fusion_not_running_state(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
                self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn([1, 32, 50, 50])

        # The BN should not be folded into the conv here, since it does not track running stats
        self.assertTrue(
            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_conv_bn_fusion_mixed_dtype(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16)
                self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

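        # The meta tracer runs the module with tensors on the "meta" device, so
        # shapes and dtypes propagate without real compute. The shape-dependent
        # branch above (`lol.shape[0] < 30`) is therefore specialized per traced
        # input, which is why tracing is repeated for BS = 15 and BS = 35.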
        mttm = MetaTracerTestModule()
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            meta_args = {'x' : x.to(device='meta')}
            gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
            torch.testing.assert_close(gm(x), mttm(x))

            # Test serialization/deserialization
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
                    pickle.dump(gm, f)

                with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
                    loaded = pickle.load(f)

                torch.testing.assert_close(loaded(x), mttm(x))


    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_multiline_message(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                error_msg = """
An error message with
terrible spacing
                """
                assert a == b, error_msg
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        error_msg = """
An error message with
terrible spacing
    """
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, error_msg):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_subgraph_creation(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # random mod partitioning
        partition_counter = 0
        NPARTITIONS = 3

        # Add some random meta info to make sure it is kept around.
        for node in my_module_traced.graph.nodes:
            if node.op != "output":
                node.meta["test_meta_info"] = True

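        # split_module invokes the callback once per node; nodes that return the
        # same partition index are grouped into the same submodule. Here nodes are
        # assigned round-robin across NPARTITIONS partitions.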
        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # Split the module into a module with submodules
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition
        )

        # Check that test_meta_info was still on all nodes.
        submodules = dict(module_with_submodules.named_modules())
        for node in module_with_submodules.graph.nodes:
            if node.op == "call_module":
                submod = submodules[node.target]
                self.assertTrue(isinstance(submod, torch.fx.GraphModule))
                for submod_node in submod.graph.nodes:
                    if submod_node.op != "output":
                        stored_op = submod_node.meta.get("test_meta_info")
                        self.assertTrue(stored_op is not None and stored_op)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    def test_split_module_dead_code(self):
        class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2  # we want this
                dead_line = x + 2  # this is dead
                return output

        mod = ModWithDeadCode()
        traced = torch.fx.symbolic_trace(mod)

        # split into before (0), target (1), and after (2)
        saw_mul = False

        def split_callback(n):
            nonlocal saw_mul
            if n.target == operator.mul:
                saw_mul = True
                return 1

            if not saw_mul:
                return 0
            if saw_mul:
                return 2

        split = split_module(traced, mod, split_callback)

        x = torch.randn((5,))
        torch.testing.assert_close(
            split(x), traced(x)
        )


    def test_split_module_kwargs_expansion(self):
        class ModuleWithKwargsExpansion(torch.nn.Module):
            def forward(self, x, **kwargs):
                return x + kwargs['foo']

        mod = ModuleWithKwargsExpansion()
        traced = torch.fx.symbolic_trace(mod)

        seen_getitem = False

        def split_callback(n):
            nonlocal seen_getitem
            split_idx = int(seen_getitem)
            if n.target == operator.getitem:
                seen_getitem = True
            return split_idx

        split = split_module(traced, mod, split_callback)

        x = torch.randn(5, 3)
        foo = torch.randn(5, 3)
        torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo))

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting resnet18 into a single partition works.
        # There was previously an issue that caused submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_split_module_default_arg(self):
        class ModelToTrace(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)

            def forward(self, x, targets=None):
                x = self.lin(x)

                if targets is not None:
                    x = x + targets

                return x

        mtt = ModelToTrace()
        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})

        split = split_module(traced, mtt, lambda node: 0)

        x = torch.randn(50, 512)
        torch.testing.assert_close(split(x), traced(x))

    def test_normalize_binary_operators(self):
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }

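        # NormalizeOperators rewrites these torch.* binary ops into an equivalent
        # canonical call target, so after the transform none of the ops above
        # should appear as node targets (checked by the assertFalse calls below).
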
        # Test Tensor/Tensor callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x, y = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_close(traced(x, y), normalized(x, y))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

        # Test Tensor/scalar callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x = torch.randn(3, 4)
            torch.testing.assert_close(traced(x), normalized(x))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

    @skipIfNoTorchVision
    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        ShapeProp(traced).propagate(input)
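        # NormalizeArgs uses operator schemas (together with the types recorded by
        # ShapeProp) to move positional arguments into kwargs, so normalized
        # call_function nodes and standard torch.nn call_module nodes should end
        # up with empty `args`, as asserted below.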
        traced = NormalizeArgs(traced).transform()

        modules = dict(traced.named_modules())

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target != operator.add:
                self.assertEqual(len(node.args), 0)
            elif node.op == "call_module":
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
        traced(input)
        self.assertEqual(traced(input), ref_outs)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if "constructor" not in test_params:
                constructor = getattr(torch.nn, test_params["module_name"])
            else:
                constructor = test_params["constructor"]

            if "constructor_args" not in test_params:
                args = ()
            else:
                args = test_params["constructor_args"]

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if "input_fn" not in test_params:
                inputs = torch.randn(test_params["input_size"])
            else:
                inputs = test_params["input_fn"]()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

            params = ", ".join(f"v{i}" for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f"Test{mod.__class__.__name__}"
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

            gbls = {"torch": torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Use `Node.normalized_arguments` to get a new set of arguments
            # to feed to the Module. Then, rewrite the node to only take
            # in those arguments as kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        normalized_args = node.normalized_arguments(traced)
                        normalized_args2 = normalize_module(
                            traced, node.target, node.args, node.kwargs
                        )
                        assert normalized_args == normalized_args2
                        assert normalized_args
                        node.args = normalized_args.args
                        node.kwargs = normalized_args.kwargs

            traced.recompile()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not correct. Skip that
            # check for these
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    def test_normalize_args_perserve_type(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: List[torch.Tensor]):
                return torch.add(a[0], a[1])

        m = MyModule()
        traced = symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.type, List[torch.Tensor])

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals
        ).transform()
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                excluded_nodes = {
                    ("placeholder", "x"),
                    # Return type differs based on boolean dispatch :(
                    ("call_function", torch.nn.functional.max_pool2d),
                    ("output", "output"),
                }
                # AnnotateTypesWithSchema doesn't work with bound C++ functions
                if not isinstance(node.target, BuiltinFunctionType):
                    self.assertIn(check, excluded_nodes)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_functionals_annotated)

    def test_annotate_getitem_node(self):
        class CustomType:
            pass

        class CustomNamedTuple(NamedTuple):
            x: int
            y: float

        class MyModule(torch.nn.Module):
            def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple):
                inp_0 = inp[0]
                inp_1 = inp[1]
                inp2_0 = inp2[0]
                inp3_x = inp3.x
                inp3_y = inp3.y
                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y

        my_module = MyModule()
        my_module_traced = torch.fx.symbolic_trace(my_module)

        # By default, FX tracing loses the type annotations of getitem nodes.
        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                assert node.type is None

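        # annotate_getitem_nodes fills in each getitem node's type from the
        # annotated container type of its input (tuple/list element types and
        # NamedTuple field types in this example).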
        annotate_getitem_nodes(my_module_traced.graph)

        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node: torch.fx.Node):
            if node.name == "a" or node.name == "b" or node.name == "add":
                return 0
            else:
                return 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_split_qualname_mapping(self):
        d_hid = 4

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param)
                x = self.lin(x)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        my_module = ExampleCode()
        my_module_traced = symbolic_trace(my_module)

        part_idx = 0

        def split_callback(n : torch.fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == ('call_module', 'lin'):
                part_idx += 1
            return part_idx

        # Split the module into a module with submodules
        qualname_map : Dict[str, str] = {}
        module_with_submodules = split_module(
            my_module_traced, my_module, split_callback, qualname_map
        )
        expected_qualname_map = {
            'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
        }
        self.assertEqual(qualname_map, expected_qualname_map)

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.attr2 = torch.nn.Buffer(torch.randn(2))
                self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

        mod = symbolic_trace(Test())
        module_name = "Foo"
        import tempfile
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
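            # to_folder writes the GraphModule out as an importable Python package
            # (generated code plus its parameters/buffers), which the importlib
            # recipe below then loads back and runs.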
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util

            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / "__init__.py"
            )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
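        # lift_lowering_attrs_to_nodes copies each leaf module's lowering-relevant
        # attributes onto its call_module node as `node.attrs_for_lowering`, keyed
        # by attribute name plus a "name" entry holding the module's class path.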
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """
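        # merge_matmul looks for matmuls that share a right-hand side and combines
        # them (roughly: concatenate the LHS operands, do one matmul, then split
        # the result), which is why the optimized modules below end up with a
        # single matmul wherever merging is legal.
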
1331        # Utility function for counting matmuls for test assertions.
1332        def _count_matmuls(mod):
1333            gm = torch.fx.symbolic_trace(mod)
1334
1335            num_matmuls = 0
1336            for node in gm.graph.nodes:
1337                if node.target == torch.matmul:
1338                    num_matmuls += 1
1339
1340            return num_matmuls
1341
1342        # Simple test case in which there are two matmuls of the same size to merge.
1343        class SimpleMergeMatmulModule(torch.nn.Module):
1344            def __init__(self, rhs):
1345                super().__init__()
1346                self.rhs = rhs
1347
1348            def forward(self, x, y):
1349                a = torch.matmul(x, self.rhs)
1350                b = torch.matmul(y, self.rhs)
1351                return a + b
1352
1353        # Initialize inputs.
1354        a = torch.randn(3, 3)
1355        b = torch.randn(3, 3)
1356
1357        # Initialize RHS for matmuls.
1358        rhs = torch.randn(3, 4)
1359
1360        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
1361        module = SimpleMergeMatmulModule(rhs)
1362        opt_module = merge_matmul.merge_matmul(module)
1363
1364        # Numerical correctness check.
1365        before = module(a, b)
1366        after = opt_module(a, b)
1367        before.allclose(after)
1368
1369        # Basic graph structure check; original module should have 2 matmuls
1370        # and optimized module should have 1.
1371        self.assertEqual(_count_matmuls(module), 2)
1372        self.assertEqual(_count_matmuls(opt_module), 1)
1373
1374        # Test case in which there are multiple matmuls of different sizes to merge.
1375        class FiveMergeMatmulModule(torch.nn.Module):
1376            def __init__(self, rhs):
1377                super().__init__()
1378                self.rhs = rhs
1379
1380            def forward(self, a, b, c, d, e):
1381                s = torch.tensor([])
1382                matmuls = []
1383
1384                # For some reason using a list comprehension or for-loop for this
1385                # doesn't work.
1386                matmuls.append(torch.matmul(a, self.rhs))
1387                matmuls.append(torch.matmul(b, self.rhs))
1388                matmuls.append(torch.matmul(c, self.rhs))
1389                matmuls.append(torch.matmul(d, self.rhs))
1390                matmuls.append(torch.matmul(e, self.rhs))
1391
1392                for m in matmuls:
1393                    s += torch.sum(m)
1394
1395                return s
1396
1397        # Initialize inputs.
1398        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]
1399
1400        # Initialize RHS.
1401        rhs = torch.randn(5, 4)
1402
1403        # Construct FiveMergeMatmulModule and call merge_matmul on it.
1404        module = FiveMergeMatmulModule(rhs)
1405        opt_module = merge_matmul.merge_matmul(module)
1406
1407        # Numerical correctness check.
1408        before = module(*inputs)
1409        after = opt_module(*inputs)
1410        before.allclose(after)
1411
1412        # Basic graph structure check; original module should have len(inputs) matmuls
1413        # and optimized module should have 1.
1414        self.assertEqual(_count_matmuls(module), len(inputs))
1415        self.assertEqual(_count_matmuls(opt_module), 1)
1416
1417        # Simple test case in which two matmuls cannot be merged due to a data dependency between
1418        # the LHS operands.
1419        class UnmergeableMatmulModule(torch.nn.Module):
1420            def __init__(self, rhs):
1421                super().__init__()
1422                self.rhs = rhs
1423
1424            def forward(self, x):
1425                a = torch.matmul(x, self.rhs)
1426                a_abs = torch.abs(a)
1427                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
1428                return b
1429
1430        # Initialize inputs.
1431        a = torch.randn(3, 3)
1432
1433        # Initialize RHS for matmuls.
1434        rhs = torch.randn(3, 4)
1435
1436        # Construct UnmergeableMatmulModule and call merge_matmul on it.
1437        module = UnmergeableMatmulModule(rhs)
1438        opt_module = merge_matmul.merge_matmul(module)
1439
1440        # Numerical correctness check.
1441        before = module(a)
1442        after = opt_module(a)
1443        self.assertTrue(before.allclose(after))
1444
1445        # Basic graph structure check; the number of matrix multiplications should not have changed.
1446        self.assertEqual(_count_matmuls(module), 2)
1447        self.assertEqual(_count_matmuls(opt_module), 2)
1448
1449    def test_type_matches(self):
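        # type_matches(sig_type, arg_type) asks whether an argument of arg_type is
        # acceptable where a schema declares sig_type; create_type_hint builds a
        # List[...] hint from a concrete list/tuple of element types. Note that
        # torch.nn.Parameter is accepted wherever torch.Tensor is expected.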
1450        should_be_equal = [
1451            (int, int),
1452            (numbers.Number, int),
1453            (numbers.Number, float),
1454            (int, type(torch.float)),
1455            (Union[int, float], int),
1456            (Union[int, float], float),
1457            (List[int], int),
1458            (List[int], create_type_hint([int, int])),
1459            (List[int], create_type_hint((int, int))),
1460            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
1461            (
1462                List[torch.Tensor],
1463                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
1464            ),
1465            (torch.Tensor, torch.nn.Parameter),
1466            (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
1467            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
1468            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
1469            (
1470                List[torch.Tensor],
1471                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
1472            ),
1473            (torch.Tensor, torch.nn.Parameter),
1474            (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
1475            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
1476            (Optional[List[torch.Tensor]], List[torch.Tensor]),
1477            (Optional[List[int]], List[int]),
1478        ]
1479        for sig_type, arg_type in should_be_equal:
1480            self.assertTrue(type_matches(sig_type, arg_type))
1481
1482        should_fail = [
1483            (int, float),
1484            (Union[int, float], str),
1485            (List[torch.Tensor], List[int]),
1486        ]
1487
1488        for sig_type, arg_type in should_fail:
1489            self.assertFalse(type_matches(sig_type, arg_type))
1490
1491    @skipIfNoMkldnn
1492    def test_optimize_for_inference_cpu(self):
1493        import torch.nn as nn
1494
1495        class Foo(nn.Module):
1496            def __init__(self) -> None:
1497                super().__init__()
1498                layers = []
1499                layers2 = []
1500                for _ in range(10):
1501                    layers.append(nn.Conv2d(3, 3, 1))
1502                    layers.append(nn.BatchNorm2d(3))
1503                    layers.append(nn.ReLU())
1504
1505                    layers2.append(nn.Conv2d(3, 3, 1))
1506                    layers2.append(nn.BatchNorm2d(3))
1507                    layers2.append(nn.ReLU())
1508                self.model = nn.Sequential(*layers)
1509                self.model2 = nn.Sequential(*layers2)
1510
1511            def forward(self, x):
1512                return self.model(x) + self.model2(x)
1513
1514        N, C, H, W = (1, 3, 224, 224)
1520        inp = torch.randn(N, C, H, W)
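        # optimize_for_inference is expected to fold BatchNorm into the preceding Conv,
        # drop dropout (unless disabled via pass_config), and may convert modules to
        # MKLDNN implementations; either way its output must match the eager model.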
1521        with torch.no_grad():
1522            model = Foo().eval()
1523            optimized_model = optimization.optimize_for_inference(model)
1524            torch.testing.assert_close(model(inp), optimized_model(inp))
1525
1526            optimized_model2 = optimization.optimize_for_inference(
1527                model, pass_config={"remove_dropout": False}
1528            )
1529            torch.testing.assert_close(model(inp), optimized_model2(inp))
1530
1531    @skipIfNoTorchVision
1532    @skipIfNoMkldnn
1533    def test_optimize_for_inference_cpu_torchvision(self):
1534        models = [
1535            torchvision.models.resnet18,
1536            torchvision.models.resnet50,
1537            torchvision.models.densenet121,
1538            torchvision.models.shufflenet_v2_x1_0,
1539            torchvision.models.vgg16,
1540            torchvision.models.mobilenet_v2,
1541            torchvision.models.mnasnet1_0,
1542            torchvision.models.resnext50_32x4d,
1543        ]
1544        with torch.no_grad():
1545            for model_type in models:
1546                model = model_type()
1547                C, H, W = (3, 224, 224)
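                # Run one batch in training mode (updating the BatchNorm running stats)
                # before switching to eval() for the optimization pass.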
1552                inp = torch.randn(3, C, H, W)
1553                model(inp)
1554                model.eval()
1555                inp = torch.randn(1, C, H, W)
1556                heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
1557                optimized_model = optimization.optimize_for_inference(model)
1558
1559                orig_out = model(inp)
1560                new_out = optimized_model(inp)
1561                torch.testing.assert_close(orig_out, new_out)
1562
1563
1564class TestNormalizeOperators(JitTestCase):
1565    @onlyCPU
1566    @ops(op_db, allowed_dtypes=(torch.float,))
1567    def test_normalize_operator_exhaustive(self, device, dtype, op):
1568        # These ops currently don't trace in FX for various reasons (e.g. they take a list of tensors)
1569        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"}
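        # normalize_function resolves the op's positional/keyword arguments against its
        # schema; calling the op with the normalized arguments must reproduce the
        # original output.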
1570        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
1571        if isinstance(op.op, torch._ops.OpOverload):
1572            self.skipTest("normalize operator doesn't work on torch.ops")
1573        for sample_input in sample_inputs_itr:
1574            unsupported_arg_type = False
1575            arg_values = [sample_input.input] + list(sample_input.args)
1576            kwarg_values = sample_input.kwargs
1577            arg_types = []
1578            kwarg_types = {}
1579
1580            def jit_infer_type(v):
1581                inferred_arg_type = torch._C._jit_try_infer_type(v)
1582                assert inferred_arg_type.success()
1583                t = _torchscript_type_to_python_type(inferred_arg_type.type())
1584                return t
1585
1586            for v in arg_values:
1587                if isinstance(v, torch.Tensor):
1588                    arg_types.append(type(v))
1589                else:
1590                    if isinstance(v, complex):
1591                        # Complex type not supported in FX
1592                        unsupported_arg_type = True
1593                    arg_types.append(jit_infer_type(v))
1594
1595            for k, v in kwarg_values.items():
1596                if isinstance(v, torch.Tensor):
1597                    kwarg_types[k] = type(v)
1598                else:
1599                    if isinstance(v, complex):
1600                        # Complex type not supported in FX
1601                        unsupported_arg_type = True
1602                    kwarg_types[k] = jit_infer_type(v)
1603
1604            if unsupported_arg_type:
1605                continue
1606            # Test normalize_function by itself
1607            ref_out = op.op(*arg_values, **kwarg_values)
1608            norm_args_and_kwargs = normalize_function(
1609                op.op, arg_values, kwarg_values, arg_types, kwarg_types
1610            )
1611            if norm_args_and_kwargs is None:
1612                raise RuntimeError(
1613                    "FX failed to normalize op - add the op to the op_skip list. "
1614                    "A common reason is an OpInfo implemented with a lambda; "
1615                    "otherwise, file an issue."
1618                )
1619            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
1620            self.assertEqual(test_out, ref_out)
1621
1622            # Test normalized_arguments as part of FX
1623            if op.name in fx_fail:
1624                continue
1625            param_names = []
1626            param_values = []
1627            fx_args = []
1628
1629            idx = 0
1630
1631            def process_arg(arg, name):
1632                if isinstance(arg, torch.Tensor):
1633                    param_names.append(name)
1634                    param_values.append(arg)
1635                    return name
1636                else:
1637                    return repr(arg)
1638
1639            def process_arg_with_idx(arg):
1640                nonlocal idx
1641                res = process_arg(arg, f"arg_{idx}")
1642                idx = idx + 1
1643                return res
1644
1645            def str_arg(arg):
1646                if isinstance(arg, tuple):
1647                    args = [f"{str_arg(v)}, " for v in arg]
1648                    return f"({' '.join(args)})"
1649                elif isinstance(arg, list):
1650                    args = [f"{str_arg(v)}" for v in arg]
1651                    return f"[{', '.join(args)}]"
1652                else:
1653                    return arg
1654
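            # Build the argument expressions for the generated forward(): tensors become
            # parameters of the module, everything else is inlined as a repr() literal.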
1655            for v in arg_values:
1656                arg = pytree.tree_map(process_arg_with_idx, v)
1657                fx_args.append(str_arg(arg))
1658
1659            for k, v in kwarg_values.items():
1660                arg = pytree.tree_map(functools.partial(process_arg, name=k), v)
1661                fx_args.append(f"{k} = {str_arg(arg)}")
1662
1663            code = f"""
1664class TestModule(torch.nn.Module):
1665    def forward(self, {', '.join(param_names)}):
1666        return torch.{op.name}({', '.join(fx_args)})
1667            """
1668
1669            g = {"torch": torch, "inf": math.inf}
1670            exec(code, g)
1671            TestModule = g["TestModule"]
1672
1673            m = TestModule()
1674            traced = torch.fx.symbolic_trace(m)
1675            ref_out = traced(*param_values)
1676
1677            for node in traced.graph.nodes:
1678                if node.op == "call_function":
1679                    normalized_args = node.normalized_arguments(
1680                        traced, arg_types, kwarg_types
1681                    )
1682                    assert normalized_args
1683                    node.args = normalized_args.args
1684                    node.kwargs = normalized_args.kwargs
1685            traced.recompile()
1686
1687            test_out = traced(*param_values)
1688            self.assertEqual(test_out, ref_out)
1689
1690    def test_normalize_quantized_eb(self):
1691        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
1692        args = (
1693            torch.empty((2, 3), dtype=torch.uint8),
1694            torch.empty((2,), dtype=torch.int64),
1695            torch.empty((2,), dtype=torch.int64),
1696        )
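        # With normalize_to_only_use_kwargs=True every positional argument is rebound to
        # its schema keyword, and defaulted parameters are filled in as well, so the full
        # keyword set is expected below.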
1697        norm_args_and_kwargs = normalize_function(
1698            target, args, normalize_to_only_use_kwargs=True
1699        )
1700        self.assertIsNotNone(norm_args_and_kwargs)
1701        self.assertEqual(
1702            set(norm_args_and_kwargs.kwargs.keys()),
1703            {
1704                "weight",
1705                "indices",
1706                "offsets",
1707                "scale_grad_by_freq",
1708                "mode",
1709                "pruned_weights",
1710                "per_sample_weights",
1711                "compressed_indices_mapping",
1712                "include_last_offset",
1713            },
1714        )
1715        self.assertEqual(norm_args_and_kwargs.args, ())
1716
1717    def test_normalize_args_op_overload(self):
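        # normalize_function should handle both OpOverload and OpOverloadPacket targets,
        # binding the positional input and the explicit kwarg to their schema names.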
1718        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
1719            inp1 = torch.rand([1])
1720            inp2 = torch.rand([4])
1721            args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True)
1722            self.assertIs(kwargs["input"], inp1)
1723            self.assertIs(kwargs["the_template"], inp2)
1724
1725
1726if TEST_Z3:
1727    import z3
1728
1729    import torch._dynamo.config
1730
1731    from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
1732    from torch.utils._sympy.functions import FloorDiv, Mod
1733
1734    class TestTranslationValidation(TestCase):
1735        def _prepare_for_translation_validation(self):
1736            validator = TranslationValidator()
1737
1738            # SymPy symbols.
1739            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
1740
1741            # Z3 symbols.
1742            [validator.add_var(s, int) for s in (s0, s1, s2)]
1743            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
1744
1745            return (s0, s1, s2), (z0, z1, z2), validator
1746
1747        def test_sympy_to_z3(self):
1749            (
1750                (s0, s1, s2),
1751                (z0, z1, z2),
1752                validator,
1753            ) = self._prepare_for_translation_validation()
1754
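            # Expected Z3 terms mirror how SympyToZ3 lowers each construct: subtraction
            # becomes addition with -1, true division routes through ToReal, and
            # FloorDiv/Mod are expressed via ToInt of a real division.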
1755            test_cases = [
1756                # Integer constants.
1757                (sympy.S.Zero, z3.IntVal(0)),
1758                (sympy.S.One, z3.IntVal(1)),
1759                (sympy.S.NegativeOne, z3.IntVal(-1)),
1760                (sympy.Integer(2), z3.IntVal(2)),
1761                (
1762                    s0,
1763                    z0,
1764                ),
1765                # Arithmetic operations.
1766                *[
1767                    (op(s0, s1), op(z0, z1))
1768                    for op in (
1769                        operator.add,
1770                        operator.mul,
1771                        operator.pow,
1772                    )
1773                ],
1774                # Logical operations.
1775                *[
1776                    (sympy_op(s0, s1), z3_op(z0, z1))
1777                    for sympy_op, z3_op in (
1778                        (sympy.Eq, operator.eq),
1779                        (sympy.Ne, operator.ne),
1780                        (sympy.Lt, operator.lt),
1781                        (sympy.Le, operator.le),
1782                        (sympy.Gt, operator.gt),
1783                        (sympy.Ge, operator.ge),
1784                    )
1785                ],
1786                # Other operations.
1787                (
1788                    s0 - s1,
1789                    z0 + z3.IntVal(-1) * z1,
1790                ),
1791                (
1792                    s0 / s1,
1793                    z3.ToReal(z0) * (z1**-1),
1794                ),
1795                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
1796                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
1797                (
1798                    Mod(s2, (s0 / s1)),
1799                    z2
1800                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
1801                    * (z3.ToReal(z0) * z1**-1),
1802                ),
1803                (
1804                    Mod(s2, s0**3),
1805                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
1806                ),
1807            ]
1808
1809            toZ3 = SympyToZ3(validator)
1810            for sympy_expr, z3_expr in test_cases:
1811                result = toZ3.run(sympy_expr)
1812                self.assertTrue(
1813                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
1814                )
1815
1816        def test_sat(self):
1817            (
1818                (s0, s1, s2),
1819                (z0, z1, z2),
1820                validator,
1821            ) = self._prepare_for_translation_validation()
1822
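            # Source expressions are added over the Z3 variables, target expressions over
            # the SymPy symbols; validation passes when every assignment satisfying the
            # targets also satisfies the sources.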
1823            validator.add_source_expr(z0 > 5)
1824            validator.add_source_expr(z1 / 2 > z0)
1825
1826            # The solutions for the target are a subset of the solutions for the source.
1827            validator.add_target_expr(s0 > 20)
1828            validator.add_target_expr(s1 > s0**2)
1829
1830            validator.validate()
1831
1832        def test_unsat(self):
1833            (
1834                (s0, s1, s2),
1835                (z0, z1, z2),
1836                validator,
1837            ) = self._prepare_for_translation_validation()
1838
1839            validator.add_source_expr(z0 > 5)
1840            validator.add_source_expr(z1 / 2 > z0)
1841
1842            # The solutions for the target are NOT a subset of the solutions for the source.
1843            validator.add_target_expr(s0 > 20)
1844            # This expression is less restrictive than its counterpart.
1845            validator.add_target_expr(s1 > s0 + 2)
1846
1847            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
1848                validator.validate()
1849
1850        def test_z3str(self):
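            # z3str renders Z3 expressions as compact prefix (s-expression) strings,
            # flattening associative chains and eliding ToInt/ToReal coercions.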
1851            a = z3.Int("a")
1852            b = z3.Int("b")
1853            special = z3.Real("this.size()[2]")
1854
1855            test_cases = [
1856                (z3.IntVal(42), "42"),
1857                # Variable.
1858                (a, "a"),
1859                # Name with special characters.
1860                (special, "this.size()[2]"),
1861                # Renamed function applications.
1862                (a != b, "(!= a b)"),
1863                (a ** b, "(pow a b)"),
1864                # Chain of associative operations.
1865                *[
1866                    (op(op(a, 5), b), f"({opstr} 5 a b)")
1867                    for op, opstr in [
1868                        (operator.add, "+"),
1869                        (operator.mul, "*")
1870                    ]
1871                ],
1872                # Revert 'Not' conversions.
1873                (a != b, "(!= a b)"),
1874                (a < b, "(> b a)"),
1875                (a > b, "(> a b)"),
1876                # Ignore 'ToInt' and 'ToReal' functions.
1877                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
1878                (z3.ToReal(a + b), "(+ a b)"),
1879                # Convert to floor division: 'idiv'.
1880                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
1881            ]
1882
1883            for expr, expected in test_cases:
1884                self.assertEqual(z3str(expr), expected)
1885
1886
1887instantiate_device_type_tests(TestNormalizeOperators, globals())
1888
1889if __name__ == "__main__":
1890    run_tests()
1891