# Owner(s): ["module: fx"] import functools import math import numbers import operator import pickle import sys import sympy import tempfile import unittest from types import BuiltinFunctionType from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union import torch import torch.fx.experimental.meta_tracer import torch.fx.experimental.optimization as optimization from torch.fx._symbolic_trace import symbolic_trace from torch.fx.experimental import merge_matmul from torch.fx.experimental.accelerator_partitioner import Partitioner from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators from torch.fx.experimental.partitioner_utils import ( Device, get_latency_of_partitioned_graph, get_partition_to_latency_mapping, NodeLatency, PartitionerConfig, PartitionMode, ) from torch.fx.experimental.rewriter import RewritingTracer from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema from torch.fx.graph_module import GraphModule from torch.fx.node import Node from torch.fx.operator_schemas import ( _torchscript_type_to_python_type, create_type_hint, normalize_function, normalize_module, type_matches, ) from torch.fx.passes import graph_manipulation from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes from torch.fx.passes.shape_prop import ShapeProp from torch.fx.passes.split_module import split_module from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCPU, ops, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_nn import module_tests, new_module_tests from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase import torch.utils._pytree as pytree try: import torchvision.models from torchvision.models import resnet18 HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") skipIfNoMkldnn = unittest.skipIf( not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()), "no MKLDNN", ) def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule: return GraphModule( root if isinstance(root, torch.nn.Module) else torch.nn.Module(), RewritingTracer().trace(root), ) class TestFXExperimental(JitTestCase): def test_find_single_partition(self): class TestModule(torch.nn.Module): def forward(self, a, b): return a + b m = TestModule() traced = symbolic_trace(m) a = torch.rand(1) b = torch.rand(1) graph_manipulation.get_size_of_all_nodes(traced, [a, b]) partitioner = Partitioner() devices = [ Device("dev_0", 125, 0), Device("dev_1", 150, 1), Device("dev_2", 125, 2), ] partitioner_config = PartitionerConfig(devices) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a, b), module_with_submodules(a, b)) assert dag.nodes[0].logical_device_ids == [1] def test_lack_of_devices(self): class TestModule(torch.nn.Module): def forward(self, a, b): return a + b m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) b = torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a, b]) partitioner = Partitioner() devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) catch_runtime_error = 
False try: ret = partitioner.partition_graph(traced, m, partitioner_config) except RuntimeError: catch_runtime_error = True assert catch_runtime_error def test_large_node_error(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 4) def forward(self, a): linear = self.linear(a) add = linear + a return add m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a]) partitioner = Partitioner() devices = [ Device("dev_0", 40, 0), Device("dev_1", 40, 0), Device("dev_2", 40, 0), Device("dev_3", 40, 0), Device("dev_4", 40, 0), ] partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) catch_runtime_error = False try: ret = partitioner.partition_graph(traced, m, partitioner_config) except RuntimeError: catch_runtime_error = True assert catch_runtime_error def test_partition_node_manipulation(self): class TestModule(torch.nn.Module): def forward(self, a, b): add_1 = a + b add_2 = add_1 + torch.rand(4) add_3 = add_2 + torch.rand(4) return add_3 m = TestModule() traced = symbolic_trace(m) a, b = torch.rand(4), torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a, b]) partitioner = Partitioner() devices = [Device("dev_0", 1000, 0)] partitioner_config = PartitionerConfig(devices) ret = partitioner.partition_graph(traced, m, partitioner_config) partition = partitioner.partitions[0] assert partition.used_mem_bytes == 112 # Select add_2 node to remove selected_node = None for node in partition.nodes: if node.name == "add_2": selected_node = node partition.remove_node(selected_node) assert partition.used_mem_bytes == 80 def test_size_based_partition(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 4) self.c = torch.rand(4) def forward(self, a, b): add_1 = a + b linear = self.linear(add_1) add_2 = linear + self.c return add_2 m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) b = torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a, b]) partitioner = Partitioner() devices = [ Device("dev_0", 125, 0), Device("dev_1", 125, 1), Device("dev_2", 125, 2), ] partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a, b), module_with_submodules(a, b)) for i, node in enumerate(dag.nodes): assert node.logical_device_ids == [i] def test_partition_device_mapping(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 4) def forward(self, a): b = torch.rand(4) add_1 = a + b linear_1 = self.linear(add_1) add_2 = torch.rand(4) + a add_3 = add_2 + linear_1 return add_3 m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a]) partitioner = Partitioner() devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a), module_with_submodules(a)) for i, node in enumerate(dag.nodes): if i == 1: assert node.logical_device_ids == [1] else: assert node.logical_device_ids == [0] def test_sparse_nn_partition(self): class MyRecommendationModule(torch.nn.Module): def 
create_mlp(self, num_of_layers: int, input_size: int, output_size: int): layers = torch.nn.ModuleList() for _ in range(num_of_layers): ll = torch.nn.Linear(input_size, output_size) layers.append(ll) layers.append(torch.nn.ReLU()) return layers def __init__(self) -> None: super().__init__() layers = self.create_mlp(4, 4, 4) self.bottom_layers = torch.nn.Sequential(*layers) layers = self.create_mlp(3, 24, 24) self.top_layers = torch.nn.Sequential(*layers) self.embedding_layers = torch.nn.ModuleList() el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) for i in range(3): el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) def forward(self, a, b, offset): x = self.bottom_layers(a) y = [] c = [] for i in range(len(self.embedding_layers)): temp = torch.randint(10, (8,)) c.append(temp + b) for i in range(len(self.embedding_layers)): if i % 2 == 0: y.append(self.embedding_layers[i](c[i], offset)) else: y.append( self.embedding_layers[i](torch.randint(10, (8,)), offset) ) z = torch.cat([x] + y, dim=1) p = self.top_layers(z) return p m = MyRecommendationModule() a = torch.rand(2, 4) b = torch.randint(10, (8,)) offset = torch.randint(1, (2,)) traced = symbolic_trace(m) graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) devices = [ Device("dev_0", 33000000, 0), Device("dev_1", 33000000, 1), Device("dev_2", 33000000, 2), ] partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) partitioner = Partitioner() ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) assert len(module_with_submodules.graph.nodes) == 24 def test_partition_latency(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 4) def forward(self, a): add_1 = a + torch.rand(4) add_2 = add_1 + torch.rand(4) linear_1 = self.linear(add_1) add_3 = add_2 + linear_1 add_4 = add_2 + add_3 return add_4 def get_node_to_latency_mapping(fx_module: GraphModule): """Given a fx module, generate node latency for each node based on the size of each node """ node_to_latency_mapping: Dict[Node, NodeLatency] = {} for node in fx_module.graph.nodes: if node.op not in {"output", "placeholder", "get_attr"}: if node.size_bytes.total_size == node.size_bytes.output_size: node_to_latency_mapping[node] = NodeLatency( node.size_bytes.total_size, 2.0 * node.size_bytes.total_size ) else: node_to_latency_mapping[node] = NodeLatency( node.size_bytes.total_size, node.size_bytes.output_size ) return node_to_latency_mapping m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a]) node_to_latency_mapping = get_node_to_latency_mapping(traced) devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] partitioner = Partitioner() partitioner_config = PartitionerConfig(devices) ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules self.assertEqual(traced(a), module_with_submodules(a)) partitions = partitioner.partitions partition_to_latency_mapping = get_partition_to_latency_mapping( partitions, node_to_latency_mapping ) for p in partition_to_latency_mapping: if p.partition_id == 0: assert partition_to_latency_mapping[p] == 
(128.0, 80.0, 160.0) else: assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) transfer_rate_bytes_per_sec = 2 critical_path_latency_sec = get_latency_of_partitioned_graph( partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec ) assert critical_path_latency_sec == 208.0 def test_cost_aware_partition(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(4, 4) def forward(self, a): add_1 = a + torch.rand(4) add_2 = add_1 + torch.rand(4) linear_1 = self.linear(add_1) add_3 = add_2 + torch.rand(4) add_4 = add_2 + linear_1 add_5 = add_3 + add_4 return add_5 def get_node_to_latency_mapping(fx_module: GraphModule): node_to_latency_mapping: Dict[Node, NodeLatency] = {} for node in fx_module.graph.nodes: if node.op not in {"output", "placeholder", "get_attr"}: if node.size_bytes.total_size == node.size_bytes.output_size: node_to_latency_mapping[node] = NodeLatency( node.size_bytes.total_size, 1 ) else: node_to_latency_mapping[node] = NodeLatency( node.size_bytes.total_size, node.size_bytes.output_size ) return node_to_latency_mapping m = MyModule() traced = symbolic_trace(m) a = torch.rand(4) graph_manipulation.get_size_of_all_nodes(traced, [a]) devices = [ Device("dev_0", 125, 0), Device("dev_1", 125, 1), Device("dev_2", 125, 2), Device("dev_3", 125, 3), ] node_to_latency_mapping = get_node_to_latency_mapping(traced) partitioner_config = PartitionerConfig( devices, mode=PartitionMode.cost_aware, transfer_rate_bytes_per_sec=2, node_to_latency_mapping=node_to_latency_mapping, ) partitioner = Partitioner() ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(traced(a), module_with_submodules(a)) partitions = partitioner.partitions partition_to_latency_mapping = get_partition_to_latency_mapping( partitions, node_to_latency_mapping ) critical_path_latency_sec = get_latency_of_partitioned_graph( partitions, partition_to_latency_mapping, partitioner_config.transfer_rate_bytes_per_sec, ) assert critical_path_latency_sec == 160.0 def test_aot_based_partition(self): class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.b = torch.rand(4) self.c = torch.rand(4) def forward(self, a): add_1 = a + self.b add_2 = self.c + add_1 return add_2 m = TestModule() traced = symbolic_trace(m) a = torch.rand(4) node_to_partition_id = {} partition_to_logical_devices = {} count = 0 graph_manipulation.get_size_of_all_nodes(traced, [a]) for node in traced.graph.nodes: if node.op not in {"placeholder", "get_attr", "output"}: node_to_partition_id[node] = count partition_to_logical_devices[count] = [0] count += 1 devices = [Device("dev_0", 200, 0)] partitioner_config = PartitionerConfig( devices=devices, mode=PartitionMode.aot_based, node_to_partition_mapping=node_to_partition_id, partition_to_logical_device_mapping=partition_to_logical_devices, ) partitioner = Partitioner() ret = partitioner.partition_graph(traced, m, partitioner_config) module_with_submodules = ret.module_with_submodules dag = ret.dag self.assertEqual(module_with_submodules(a), traced(a)) for node in dag.nodes: assert node.size_bytes == 48 assert node.logical_device_ids == [0] def test_replace_target_nodes_with(self): class testModule(torch.nn.Module): def forward(self, a, b): return a + b m = testModule() traced = symbolic_trace(m) input1 = torch.randn(1) input2 = torch.randn(1) assert (input1 + input2) == traced(input1, input2) 
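        # replace_target_nodes_with rewrites, in place, every node matching
        # (old_op, old_target) to use (new_op, new_target): here, each
        # call_function node targeting operator.add becomes operator.mul,
        # which the assertion after the call verifies.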
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model will be split into two partitions:
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)

        # With host saturation, partition 1 will be replicated to dev_4, and partition 2
        # will be replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = optimization.fuse(traced)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_conv_bn_fusion_not_running_state(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
                self.bn = torch.nn.BatchNorm2d(
                    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn([1, 32, 50, 50])

        # bn cannot be folded into conv here, since it does not track running stats
        self.assertTrue(
            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_conv_bn_fusion_mixed_dtype(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16
                )
                self.bn = torch.nn.BatchNorm2d(
                    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
                )

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

        mttm = MetaTracerTestModule()
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            meta_args = {'x': x.to(device='meta')}
            gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
            torch.testing.assert_close(gm(x), mttm(x))

            # Test serialization/deserialization
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
                    pickle.dump(gm, f)

                with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
                    loaded = pickle.load(f)

                torch.testing.assert_close(loaded(x), mttm(x))

    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_multiline_message(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                error_msg = """
An error message with
terrible spacing
                """
                assert a == b, error_msg
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == "Assert"
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        error_msg = """
An error message with
terrible spacing
        """
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, error_msg):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_subgraph_creation(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # random mod partitioning
        partition_counter = 0
        NPARTITIONS = 3

        # Add some random meta info to make sure it is kept around.
        for node in my_module_traced.graph.nodes:
            if node.op != "output":
                node.meta["test_meta_info"] = True

        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # split module into a module with submodules
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition
        )

        # Check that test_meta_info was still on all nodes.
        submodules = dict(module_with_submodules.named_modules())
        for node in module_with_submodules.graph.nodes:
            if node.op == "call_module":
                submod = submodules[node.target]
                self.assertTrue(isinstance(submod, torch.fx.GraphModule))
                for submod_node in submod.graph.nodes:
                    if submod_node.op != "output":
                        stored_op = submod_node.meta.get("test_meta_info")
                        self.assertTrue(stored_op is not None and stored_op)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    def test_split_module_dead_code(self):
        class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2  # we want this
                dead_line = x + 2  # this is dead
                return output

        mod = ModWithDeadCode()
        traced = torch.fx.symbolic_trace(mod)

        # split into before (0), target (1), and after (2)
        saw_mul = False

        def split_callback(n):
            nonlocal saw_mul
            if n.target == operator.mul:
                saw_mul = True
                return 1

            if not saw_mul:
                return 0
            if saw_mul:
                return 2

        split = split_module(traced, mod, split_callback)

        x = torch.randn((5,))
        torch.testing.assert_close(
            split(x), traced(x)
        )

    def test_split_module_kwargs_expansion(self):
        class ModuleWithKwargsExpansion(torch.nn.Module):
            def forward(self, x, **kwargs):
                return x + kwargs['foo']

        mod = ModuleWithKwargsExpansion()
        traced = torch.fx.symbolic_trace(mod)

        seen_getitem = False

        def split_callback(n):
            nonlocal seen_getitem
            split_idx = int(seen_getitem)
            if n.target == operator.getitem:
                seen_getitem = True
            return split_idx

        split = split_module(traced, mod, split_callback)

        x = torch.randn(5, 3)
        foo = torch.randn(5, 3)
        torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo))

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting resnet into 1 partition works.
        # There was an issue before causing submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_split_module_default_arg(self):
        class ModelToTrace(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)

            def forward(self, x, targets=None):
                x = self.lin(x)

                if targets is not None:
                    x = x + targets

                return x

        mtt = ModelToTrace()
        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})

        split = split_module(traced, mtt, lambda node: 0)

        x = torch.randn(50, 512)
        torch.testing.assert_close(split(x), traced(x))

    def test_normalize_binary_operators(self):
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
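            # NormalizeOperators should rewrite every call site targeting one
            # of these ops into a canonical form; the assertions below verify
            # that none of these targets survive in the normalized graph.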
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }

        # Test Tensor/Tensor callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x, y = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_close(traced(x, y), normalized(x, y))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

        # Test Tensor/scalar callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x = torch.randn(3, 4)
            torch.testing.assert_close(traced(x), normalized(x))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

    @skipIfNoTorchVision
    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        modules = dict(traced.named_modules())

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target != operator.add:
                self.assertEqual(len(node.args), 0)
            elif node.op == "call_module":
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
        traced(input)
        self.assertEqual(traced(input), ref_outs)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if "constructor" not in test_params:
                constructor = getattr(torch.nn, test_params["module_name"])
            else:
                constructor = test_params["constructor"]
            if "constructor_args" not in test_params:
                args = ()
            else:
                args = test_params["constructor_args"]
            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if "input_fn" not in test_params:
                inputs = torch.randn(test_params["input_size"])
            else:
                inputs = test_params["input_fn"]()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

            params = ", ".join(f"v{i}" for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f"Test{mod.__class__.__name__}"
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
"""

            gbls = {"torch": torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Use `Node.normalized_arguments` to get a new set of arguments
            # to feed to the Module. Then, rewrite the node to only take
            # in those arguments as kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        normalized_args = node.normalized_arguments(traced)
                        normalized_args2 = normalize_module(
                            traced, node.target, node.args, node.kwargs
                        )
                        assert normalized_args == normalized_args2
                        assert normalized_args
                        node.args = normalized_args.args
                        node.kwargs = normalized_args.kwargs

            traced.recompile()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not correct. Skip that
            # check for these
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    def test_normalize_args_preserve_type(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: List[torch.Tensor]):
                return torch.add(a[0], a[1])

        m = MyModule()
        traced = symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.type, List[torch.Tensor])

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals
        ).transform()
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                excluded_nodes = {
                    ("placeholder", "x"),
                    # Return type differs based on boolean dispatch :(
                    ("call_function", torch.nn.functional.max_pool2d),
                    ("output", "output"),
                }
                # AnnotateTypesWithSchema doesn't work with bound C++ functions
                if not isinstance(node.target, BuiltinFunctionType):
                    self.assertIn(check, excluded_nodes)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_functionals_annotated)

    def test_annotate_getitem_node(self):
        class CustomType:
            pass

        class CustomNamedTuple(NamedTuple):
            x: int
            y: float

        class MyModule(torch.nn.Module):
            def forward(
                self,
                inp: Tuple[CustomType, torch.Tensor],
                inp2: List[CustomType],
                inp3: CustomNamedTuple,
            ):
                inp_0 = inp[0]
                inp_1 = inp[1]
                inp2_0 = inp2[0]
                inp3_x = inp3.x
                inp3_y = inp3.y
                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y

        my_module = MyModule()
        my_module_traced = torch.fx.symbolic_trace(my_module)

        # by default, the fx transform loses the type annotation of getitem nodes.
        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                assert node.type is None

        annotate_getitem_nodes(my_module_traced.graph)

        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node: torch.fx.Node):
            if node.name == "a" or node.name == "b" or node.name == "add":
                return 0
            else:
                return 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_split_qualname_mapping(self):
        d_hid = 4

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param)
                x = self.lin(x)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        my_module = ExampleCode()
        my_module_traced = symbolic_trace(my_module)

        part_idx = 0

        def split_callback(n: torch.fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == ('call_module', 'lin'):
                part_idx += 1
            return part_idx

        # split module into a module with submodules
        qualname_map: Dict[str, str] = {}
        module_with_submodules = split_module(
            my_module_traced, my_module, split_callback, qualname_map
        )
        expected_qualname_map = {
            'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
        }
        self.assertEqual(qualname_map, expected_qualname_map)

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self) -> None:
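                # A mix of parameters, submodules, plain tensor attributes, and
                # registered buffers, so that to_folder has to serialize each
                # kind of state when it dumps the module to disk.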
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.attr2 = torch.nn.Buffer(torch.randn(2))
                self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

        mod = symbolic_trace(Test())
        module_name = "Foo"
        import tempfile
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util

            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / "__init__.py"
            )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """

        # Utility function for counting matmuls for test assertions.
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls

        # Simple test case in which there are two matmuls of the same size to merge.
        class SimpleMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x, y):
                a = torch.matmul(x, self.rhs)
                b = torch.matmul(y, self.rhs)
                return a + b

        # Initialize inputs.
        a = torch.randn(3, 3)
        b = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
        module = SimpleMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a, b)
        after = opt_module(a, b)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have 2 matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Test case in which there are multiple matmuls of different sizes to merge.
        class FiveMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, a, b, c, d, e):
                s = torch.tensor([])
                matmuls = []

                # For some reason using a list comprehension or for-loop for this
                # doesn't work.
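                # The five matmuls below share self.rhs but have different LHS
                # shapes; merge_matmul should still be able to concatenate them
                # into a single matmul (checked by _count_matmuls below).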
                matmuls.append(torch.matmul(a, self.rhs))
                matmuls.append(torch.matmul(b, self.rhs))
                matmuls.append(torch.matmul(c, self.rhs))
                matmuls.append(torch.matmul(d, self.rhs))
                matmuls.append(torch.matmul(e, self.rhs))

                for m in matmuls:
                    s += torch.sum(m)

                return s

        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; the number of matrix multiplications should not have changed.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 2)

    def test_type_matches(self):
        should_be_equal = [
            (int, int),
            (numbers.Number, int),
            (numbers.Number, float),
            (int, type(torch.float)),
            (Union[int, float], int),
            (Union[int, float], float),
            (List[int], int),
            (List[int], create_type_hint([int, int])),
            (List[int], create_type_hint((int, int))),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
            (
                List[torch.Tensor],
                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
            (
                List[torch.Tensor],
                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
            (Optional[List[torch.Tensor]], List[torch.Tensor]),
            (Optional[List[int]], List[int]),
        ]
        for sig_type, arg_type in should_be_equal:
            self.assertTrue(type_matches(sig_type, arg_type))

        should_fail = [
            (int, float),
            (Union[int, float], str),
            (List[torch.Tensor], List[int]),
        ]

        for sig_type, arg_type in should_fail:
            self.assertFalse(type_matches(sig_type, arg_type))

    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu(self):
        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layers = []
                layers2 = []
                for _ in range(10):
                    layers.append(nn.Conv2d(3, 3, 1))
                    layers.append(nn.BatchNorm2d(3))
                    layers.append(nn.ReLU())

                    layers2.append(nn.Conv2d(3, 3, 1))
                    layers2.append(nn.BatchNorm2d(3))
                    layers2.append(nn.ReLU())
                self.model = nn.Sequential(*layers)
                self.model2 = nn.Sequential(*layers2)

            def forward(self, x):
                return self.model(x) + self.model2(x)

        N, C, H, W = (1, 3, 224, 224)
        inp = torch.randn(N, C, H, W)

        with torch.no_grad():
            model = Foo().eval()
            optimized_model = optimization.optimize_for_inference(model)
            torch.testing.assert_close(model(inp), optimized_model(inp))

            optimized_model2 = optimization.optimize_for_inference(
                model, pass_config={"remove_dropout": False}
            )
            torch.testing.assert_close(model(inp), optimized_model2(inp))

    @skipIfNoTorchVision
    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu_torchvision(self):
        models = [
            torchvision.models.resnet18,
            torchvision.models.resnet50,
            torchvision.models.densenet121,
            torchvision.models.shufflenet_v2_x1_0,
            torchvision.models.vgg16,
            torchvision.models.mobilenet_v2,
            torchvision.models.mnasnet1_0,
            torchvision.models.resnext50_32x4d,
        ]
        with torch.no_grad():
            for model_type in models:
                model = model_type()
                C, H, W = (3, 224, 224)
                inp = torch.randn(3, C, H, W)
                model(inp)
                model.eval()
                inp = torch.randn(1, C, H, W)
                heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
                optimized_model = optimization.optimize_for_inference(model)

                orig_out = model(inp)
                new_out = optimized_model(inp)
                torch.testing.assert_close(orig_out, new_out)


class TestNormalizeOperators(JitTestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
        fx_fail = {
            "cat",
            "stack",
            "hstack",
            "vstack",
            "dstack",
            "linalg.multi_dot",
            "_upsample_bilinear2d_aa",
            "_chunk_cat",
        }
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        if isinstance(op.op, torch._ops.OpOverload):
            self.skipTest("normalize operator doesn't work on torch.ops")
        for sample_input in sample_inputs_itr:
            unsupported_arg_type = False
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            arg_types = []
            kwarg_types = {}

            def jit_infer_type(v):
                inferred_arg_type = torch._C._jit_try_infer_type(v)
                assert inferred_arg_type.success()
                t = _torchscript_type_to_python_type(inferred_arg_type.type())
                return t

            for v in arg_values:
                if isinstance(v, torch.Tensor):
                    arg_types.append(type(v))
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    arg_types.append(jit_infer_type(v))

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    kwarg_types[k] = type(v)
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    kwarg_types[k] = jit_infer_type(v)

            if unsupported_arg_type:
                continue
            # Test normalize_function by itself
            ref_out = op.op(*arg_values, **kwarg_values)
            norm_args_and_kwargs = normalize_function(
                op.op, arg_values, kwarg_values, arg_types, kwarg_types
            )
            if norm_args_and_kwargs is None:
                raise RuntimeError(
                    """
                    FX failed to normalize op - add the op to the op_skip list.
                    A common reason is if your OpInfo was implemented with a lambda
                    - otherwise, file an issue
                    """
                )
            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
            self.assertEqual(test_out, ref_out)

            # Test normalized_arguments as part of FX
            if op.name in fx_fail:
                continue
            param_names = []
            param_values = []
            fx_args = []

            idx = 0

            def process_arg(arg, name):
                if isinstance(arg, torch.Tensor):
                    param_names.append(name)
                    param_values.append(arg)
                    return name
                else:
                    return f"{repr(arg)}"

            def process_arg_with_idx(arg):
                nonlocal idx
                res = process_arg(arg, f"arg_{idx}")
                idx = idx + 1
                return res

            def str_arg(arg):
                if isinstance(arg, tuple):
                    args = [f"{str_arg(v)}, " for v in arg]
                    return f"({' '.join(args)})"
                elif isinstance(arg, list):
                    args = [f"{str_arg(v)}" for v in arg]
                    return f"[{', '.join(args)}]"
                else:
                    return arg

            for v in arg_values:
                arg = pytree.tree_map(process_arg_with_idx, v)
                fx_args.append(str_arg(arg))

            for k, v in kwarg_values.items():
                arg = pytree.tree_map(functools.partial(process_arg, name=k), v)
                fx_args.append(f"{k} = {str_arg(arg)}")

            code = f"""
class TestModule(torch.nn.Module):
    def forward(self, {', '.join(param_names)}):
        return torch.{op.name}({', '.join(fx_args)})
"""

            g = {"torch": torch, "inf": math.inf}
            exec(code, g)
            TestModule = g["TestModule"]

            m = TestModule()
            traced = torch.fx.symbolic_trace(m)
            ref_out = traced(*param_values)

            for node in traced.graph.nodes:
                if node.op == "call_function":
                    normalized_args = node.normalized_arguments(
                        traced, arg_types, kwarg_types
                    )
                    assert normalized_args
                    node.args = normalized_args.args
                    node.kwargs = normalized_args.kwargs
            traced.recompile()

            test_out = traced(*param_values)
            self.assertEqual(test_out, ref_out)

    def test_normalize_quantized_eb(self):
        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
        args = (
            torch.empty((2, 3), dtype=torch.uint8),
            torch.empty((2,), dtype=torch.int64),
            torch.empty((2,), dtype=torch.int64),
        )
        norm_args_and_kwargs = normalize_function(
            target, args, normalize_to_only_use_kwargs=True
        )
        self.assertTrue(norm_args_and_kwargs is not None)
        self.assertEqual(
            set(norm_args_and_kwargs.kwargs.keys()),
            {
                "weight",
                "indices",
                "offsets",
                "scale_grad_by_freq",
                "mode",
                "pruned_weights",
                "per_sample_weights",
                "compressed_indices_mapping",
                "include_last_offset",
            },
        )
        self.assertEqual(norm_args_and_kwargs.args, ())

    def test_normalize_args_op_overload(self):
        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
            inp1 = torch.rand([1])
            inp2 = torch.rand([4])
            args, kwargs = normalize_function(
                target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True
            )
            self.assertIs(kwargs["input"], inp1)
            self.assertIs(kwargs["the_template"], inp2)


if TEST_Z3:
    import z3

    import torch._dynamo.config
    from torch.fx.experimental.validator import (
        SympyToZ3,
        TranslationValidator,
        ValidationException,
        z3str,
    )
    from torch.utils._sympy.functions import FloorDiv, Mod

    class TestTranslationValidation(TestCase):
        def _prepare_for_translation_validation(self):
            validator = TranslationValidator()

            # SymPy symbols.
            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

            # Z3 symbols.
            [validator.add_var(s, int) for s in (s0, s1, s2)]
            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

            return (s0, s1, s2), (z0, z1, z2), validator

        def test_sympy_to_z3(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            test_cases = [
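                # Each (sympy_expr, z3_expr) pair asserts that SympyToZ3
                # translates the SymPy expression into a structurally equal Z3
                # term (checked with z3_expr.eq(result) in the loop below).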
                # Integer constants.
                (sympy.S.Zero, z3.IntVal(0)),
                (sympy.S.One, z3.IntVal(1)),
                (sympy.S.NegativeOne, z3.IntVal(-1)),
                (sympy.Integer(2), z3.IntVal(2)),
                (
                    s0,
                    z0,
                ),
                # Arithmetic operations.
                *[
                    (op(s0, s1), op(z0, z1))
                    for op in (
                        operator.add,
                        operator.mul,
                        operator.pow,
                    )
                ],
                # Logical operations.
                *[
                    (sympy_op(s0, s1), z3_op(z0, z1))
                    for sympy_op, z3_op in (
                        (sympy.Eq, operator.eq),
                        (sympy.Ne, operator.ne),
                        (sympy.Lt, operator.lt),
                        (sympy.Le, operator.le),
                        (sympy.Gt, operator.gt),
                        (sympy.Ge, operator.ge),
                    )
                ],
                # Other operations.
                (
                    s0 - s1,
                    z0 + z3.IntVal(-1) * z1,
                ),
                (
                    s0 / s1,
                    z3.ToReal(z0) * (z1**-1),
                ),
                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
                (
                    Mod(s2, (s0 / s1)),
                    z2
                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
                    * (z3.ToReal(z0) * z1**-1),
                ),
                (
                    Mod(s2, s0**3),
                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
                ),
            ]

            toZ3 = SympyToZ3(validator)
            for sympy_expr, z3_expr in test_cases:
                result = toZ3.run(sympy_expr)
                self.assertTrue(
                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
                )

        def test_sat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # The solutions for the target are a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            validator.add_target_expr(s1 > s0**2)

            validator.validate()

        def test_unsat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # The solutions for the target are NOT a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            # This expression is less restrictive than its counterpart.
            validator.add_target_expr(s1 > s0 + 2)

            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
                validator.validate()

        def test_z3str(self):
            a = z3.Int("a")
            b = z3.Int("b")
            special = z3.Real("this.size()[2]")

            test_cases = [
                (z3.IntVal(42), "42"),
                # Variable.
                (a, "a"),
                # Name with special characters.
                (special, "this.size()[2]"),
                # Renamed function applications.
                (a != b, "(!= a b)"),
                (a ** b, "(pow a b)"),
                # Chain of associative operations.
                *[
                    (op(op(a, 5), b), f"({opstr} 5 a b)")
                    for op, opstr in [
                        (operator.add, "+"),
                        (operator.mul, "*"),
                    ]
                ],
                # Revert 'Not' conversions.
                (a != b, "(!= a b)"),
                (a < b, "(> b a)"),
                (a > b, "(> a b)"),
                # Ignore 'ToInt' and 'ToReal' functions.
                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
                (z3.ToReal(a + b), "(+ a b)"),
                # Convert to floor division: 'idiv'.
                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
            ]

            for expr, expected in test_cases:
                self.assertEqual(z3str(expr), expected)


instantiate_device_type_tests(TestNormalizeOperators, globals())

if __name__ == "__main__":
    run_tests()
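# A minimal usage sketch of the translation validator exercised above, assuming
# z3 is installed (it mirrors test_sat/test_unsat; names are illustrative only):
#
#     validator = TranslationValidator()
#     s0 = sympy.Symbol("s0", integer=True)
#     validator.add_var(s0, int)
#     validator.add_source_expr(validator.z3var(s0) > 5)
#     validator.add_target_expr(s0 > 20)
#     # validate() raises ValidationException when the target's solutions are
#     # not a subset of the source's solutions.
#     validator.validate()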