1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: fx"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport functools 4*da0073e9SAndroid Build Coastguard Workerimport math 5*da0073e9SAndroid Build Coastguard Workerimport numbers 6*da0073e9SAndroid Build Coastguard Workerimport operator 7*da0073e9SAndroid Build Coastguard Workerimport pickle 8*da0073e9SAndroid Build Coastguard Workerimport sys 9*da0073e9SAndroid Build Coastguard Workerimport sympy 10*da0073e9SAndroid Build Coastguard Workerimport tempfile 11*da0073e9SAndroid Build Coastguard Workerimport unittest 12*da0073e9SAndroid Build Coastguard Workerfrom types import BuiltinFunctionType 13*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerimport torch 16*da0073e9SAndroid Build Coastguard Workerimport torch.fx.experimental.meta_tracer 17*da0073e9SAndroid Build Coastguard Workerimport torch.fx.experimental.optimization as optimization 18*da0073e9SAndroid Build Coastguard Workerfrom torch.fx._symbolic_trace import symbolic_trace 19*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental import merge_matmul 20*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.accelerator_partitioner import Partitioner 21*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators 22*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.partitioner_utils import ( 23*da0073e9SAndroid Build Coastguard Worker Device, 24*da0073e9SAndroid Build Coastguard Worker get_latency_of_partitioned_graph, 25*da0073e9SAndroid Build Coastguard Worker get_partition_to_latency_mapping, 26*da0073e9SAndroid Build Coastguard Worker NodeLatency, 27*da0073e9SAndroid Build Coastguard Worker PartitionerConfig, 28*da0073e9SAndroid Build Coastguard Worker PartitionMode, 29*da0073e9SAndroid Build Coastguard Worker) 30*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.rewriter import RewritingTracer 31*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema 32*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.graph_module import GraphModule 33*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.node import Node 34*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.operator_schemas import ( 35*da0073e9SAndroid Build Coastguard Worker _torchscript_type_to_python_type, 36*da0073e9SAndroid Build Coastguard Worker create_type_hint, 37*da0073e9SAndroid Build Coastguard Worker normalize_function, 38*da0073e9SAndroid Build Coastguard Worker normalize_module, 39*da0073e9SAndroid Build Coastguard Worker type_matches, 40*da0073e9SAndroid Build Coastguard Worker) 41*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes import graph_manipulation 42*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes 43*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.shape_prop import ShapeProp 44*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.split_module import split_module 45*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes 46*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 47*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 48*da0073e9SAndroid Build Coastguard Worker onlyCPU, 49*da0073e9SAndroid Build Coastguard Worker ops, 50*da0073e9SAndroid Build Coastguard Worker) 51*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db 52*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import module_tests, new_module_tests 53*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase 54*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase 55*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Workertry: 58*da0073e9SAndroid Build Coastguard Worker import torchvision.models 59*da0073e9SAndroid Build Coastguard Worker from torchvision.models import resnet18 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker HAS_TORCHVISION = True 62*da0073e9SAndroid Build Coastguard Workerexcept ImportError: 63*da0073e9SAndroid Build Coastguard Worker HAS_TORCHVISION = False 64*da0073e9SAndroid Build Coastguard WorkerskipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") 65*da0073e9SAndroid Build Coastguard WorkerskipIfNoMkldnn = unittest.skipIf( 66*da0073e9SAndroid Build Coastguard Worker not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()), 67*da0073e9SAndroid Build Coastguard Worker "no MKLDNN", 68*da0073e9SAndroid Build Coastguard Worker) 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerdef symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule: 72*da0073e9SAndroid Build Coastguard Worker return GraphModule( 73*da0073e9SAndroid Build Coastguard Worker root if isinstance(root, torch.nn.Module) else torch.nn.Module(), 74*da0073e9SAndroid Build Coastguard Worker RewritingTracer().trace(root), 75*da0073e9SAndroid Build Coastguard Worker ) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Workerclass TestFXExperimental(JitTestCase): 79*da0073e9SAndroid Build Coastguard Worker def test_find_single_partition(self): 80*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 81*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 82*da0073e9SAndroid Build Coastguard Worker return a + b 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker m = TestModule() 85*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 86*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1) 87*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1) 88*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 89*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 90*da0073e9SAndroid Build Coastguard Worker devices = [ 91*da0073e9SAndroid Build Coastguard Worker Device("dev_0", 125, 0), 92*da0073e9SAndroid Build Coastguard Worker Device("dev_1", 150, 1), 93*da0073e9SAndroid Build Coastguard Worker Device("dev_2", 125, 2), 94*da0073e9SAndroid Build Coastguard Worker ] 95*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices) 96*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 97*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 98*da0073e9SAndroid Build Coastguard Worker dag = ret.dag 99*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a, b), module_with_submodules(a, b)) 100*da0073e9SAndroid Build Coastguard Worker assert dag.nodes[0].logical_device_ids == [1] 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def test_lack_of_devices(self): 103*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 104*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 105*da0073e9SAndroid Build Coastguard Worker return a + b 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker m = TestModule() 108*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 109*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 110*da0073e9SAndroid Build Coastguard Worker b = torch.rand(4) 111*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 112*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 113*da0073e9SAndroid Build Coastguard Worker devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)] 114*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 115*da0073e9SAndroid Build Coastguard Worker catch_runtime_error = False 116*da0073e9SAndroid Build Coastguard Worker try: 117*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 118*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 119*da0073e9SAndroid Build Coastguard Worker catch_runtime_error = True 120*da0073e9SAndroid Build Coastguard Worker assert catch_runtime_error 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker def test_large_node_error(self): 123*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 124*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 125*da0073e9SAndroid Build Coastguard Worker super().__init__() 126*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 129*da0073e9SAndroid Build Coastguard Worker linear = self.linear(a) 130*da0073e9SAndroid Build Coastguard Worker add = linear + a 131*da0073e9SAndroid Build Coastguard Worker return add 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker m = TestModule() 134*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 135*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 136*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a]) 137*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 138*da0073e9SAndroid Build Coastguard Worker devices = [ 139*da0073e9SAndroid Build Coastguard Worker Device("dev_0", 40, 0), 140*da0073e9SAndroid Build Coastguard Worker Device("dev_1", 40, 0), 141*da0073e9SAndroid Build Coastguard Worker Device("dev_2", 40, 0), 142*da0073e9SAndroid Build Coastguard Worker Device("dev_3", 40, 0), 143*da0073e9SAndroid Build Coastguard Worker Device("dev_4", 40, 0), 144*da0073e9SAndroid Build Coastguard Worker ] 145*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 146*da0073e9SAndroid Build Coastguard Worker catch_runtime_error = False 147*da0073e9SAndroid Build Coastguard Worker try: 148*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 149*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 150*da0073e9SAndroid Build Coastguard Worker catch_runtime_error = True 151*da0073e9SAndroid Build Coastguard Worker assert catch_runtime_error 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker def test_partition_node_manipulation(self): 154*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 155*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 156*da0073e9SAndroid Build Coastguard Worker add_1 = a + b 157*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + torch.rand(4) 158*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + torch.rand(4) 159*da0073e9SAndroid Build Coastguard Worker return add_3 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker m = TestModule() 162*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 163*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(4), torch.rand(4) 164*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 165*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 166*da0073e9SAndroid Build Coastguard Worker devices = [Device("dev_0", 1000, 0)] 167*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices) 168*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 169*da0073e9SAndroid Build Coastguard Worker partition = partitioner.partitions[0] 170*da0073e9SAndroid Build Coastguard Worker assert partition.used_mem_bytes == 112 171*da0073e9SAndroid Build Coastguard Worker # Select add_2 node to remove 172*da0073e9SAndroid Build Coastguard Worker selected_node = None 173*da0073e9SAndroid Build Coastguard Worker for node in partition.nodes: 174*da0073e9SAndroid Build Coastguard Worker if node.name == "add_2": 175*da0073e9SAndroid Build Coastguard Worker selected_node = node 176*da0073e9SAndroid Build Coastguard Worker partition.remove_node(selected_node) 177*da0073e9SAndroid Build Coastguard Worker assert partition.used_mem_bytes == 80 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker def test_size_based_partition(self): 180*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 181*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 182*da0073e9SAndroid Build Coastguard Worker super().__init__() 183*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 184*da0073e9SAndroid Build Coastguard Worker self.c = torch.rand(4) 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 187*da0073e9SAndroid Build Coastguard Worker add_1 = a + b 188*da0073e9SAndroid Build Coastguard Worker linear = self.linear(add_1) 189*da0073e9SAndroid Build Coastguard Worker add_2 = linear + self.c 190*da0073e9SAndroid Build Coastguard Worker return add_2 191*da0073e9SAndroid Build Coastguard Worker 192*da0073e9SAndroid Build Coastguard Worker m = TestModule() 193*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 194*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 195*da0073e9SAndroid Build Coastguard Worker b = torch.rand(4) 196*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a, b]) 197*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 198*da0073e9SAndroid Build Coastguard Worker devices = [ 199*da0073e9SAndroid Build Coastguard Worker Device("dev_0", 125, 0), 200*da0073e9SAndroid Build Coastguard Worker Device("dev_1", 125, 1), 201*da0073e9SAndroid Build Coastguard Worker Device("dev_2", 125, 2), 202*da0073e9SAndroid Build Coastguard Worker ] 203*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 204*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 205*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 206*da0073e9SAndroid Build Coastguard Worker dag = ret.dag 207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a, b), module_with_submodules(a, b)) 208*da0073e9SAndroid Build Coastguard Worker for i, node in enumerate(dag.nodes): 209*da0073e9SAndroid Build Coastguard Worker assert node.logical_device_ids == [i] 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker def test_partition_device_mapping(self): 212*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 213*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 214*da0073e9SAndroid Build Coastguard Worker super().__init__() 215*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 218*da0073e9SAndroid Build Coastguard Worker b = torch.rand(4) 219*da0073e9SAndroid Build Coastguard Worker add_1 = a + b 220*da0073e9SAndroid Build Coastguard Worker linear_1 = self.linear(add_1) 221*da0073e9SAndroid Build Coastguard Worker add_2 = torch.rand(4) + a 222*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + linear_1 223*da0073e9SAndroid Build Coastguard Worker return add_3 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker m = TestModule() 226*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 227*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 228*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a]) 229*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 230*da0073e9SAndroid Build Coastguard Worker devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)] 231*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices, PartitionMode.size_based) 232*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 233*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 234*da0073e9SAndroid Build Coastguard Worker dag = ret.dag 235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a), module_with_submodules(a)) 236*da0073e9SAndroid Build Coastguard Worker for i, node in enumerate(dag.nodes): 237*da0073e9SAndroid Build Coastguard Worker if i == 1: 238*da0073e9SAndroid Build Coastguard Worker assert node.logical_device_ids == [1] 239*da0073e9SAndroid Build Coastguard Worker else: 240*da0073e9SAndroid Build Coastguard Worker assert node.logical_device_ids == [0] 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def test_sparse_nn_partition(self): 243*da0073e9SAndroid Build Coastguard Worker class MyRecommendationModule(torch.nn.Module): 244*da0073e9SAndroid Build Coastguard Worker def create_mlp(self, num_of_layers: int, input_size: int, output_size: int): 245*da0073e9SAndroid Build Coastguard Worker layers = torch.nn.ModuleList() 246*da0073e9SAndroid Build Coastguard Worker for _ in range(num_of_layers): 247*da0073e9SAndroid Build Coastguard Worker ll = torch.nn.Linear(input_size, output_size) 248*da0073e9SAndroid Build Coastguard Worker layers.append(ll) 249*da0073e9SAndroid Build Coastguard Worker layers.append(torch.nn.ReLU()) 250*da0073e9SAndroid Build Coastguard Worker return layers 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 253*da0073e9SAndroid Build Coastguard Worker super().__init__() 254*da0073e9SAndroid Build Coastguard Worker layers = self.create_mlp(4, 4, 4) 255*da0073e9SAndroid Build Coastguard Worker self.bottom_layers = torch.nn.Sequential(*layers) 256*da0073e9SAndroid Build Coastguard Worker layers = self.create_mlp(3, 24, 24) 257*da0073e9SAndroid Build Coastguard Worker self.top_layers = torch.nn.Sequential(*layers) 258*da0073e9SAndroid Build Coastguard Worker self.embedding_layers = torch.nn.ModuleList() 259*da0073e9SAndroid Build Coastguard Worker el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) 260*da0073e9SAndroid Build Coastguard Worker self.embedding_layers.append(el) 261*da0073e9SAndroid Build Coastguard Worker for i in range(3): 262*da0073e9SAndroid Build Coastguard Worker el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) 263*da0073e9SAndroid Build Coastguard Worker self.embedding_layers.append(el) 264*da0073e9SAndroid Build Coastguard Worker el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) 265*da0073e9SAndroid Build Coastguard Worker self.embedding_layers.append(el) 266*da0073e9SAndroid Build Coastguard Worker 267*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, offset): 268*da0073e9SAndroid Build Coastguard Worker x = self.bottom_layers(a) 269*da0073e9SAndroid Build Coastguard Worker y = [] 270*da0073e9SAndroid Build Coastguard Worker c = [] 271*da0073e9SAndroid Build Coastguard Worker for i in range(len(self.embedding_layers)): 272*da0073e9SAndroid Build Coastguard Worker temp = torch.randint(10, (8,)) 273*da0073e9SAndroid Build Coastguard Worker c.append(temp + b) 274*da0073e9SAndroid Build Coastguard Worker for i in range(len(self.embedding_layers)): 275*da0073e9SAndroid Build Coastguard Worker if i % 2 == 0: 276*da0073e9SAndroid Build Coastguard Worker y.append(self.embedding_layers[i](c[i], offset)) 277*da0073e9SAndroid Build Coastguard Worker else: 278*da0073e9SAndroid Build Coastguard Worker y.append( 279*da0073e9SAndroid Build Coastguard Worker self.embedding_layers[i](torch.randint(10, (8,)), offset) 280*da0073e9SAndroid Build Coastguard Worker ) 281*da0073e9SAndroid Build Coastguard Worker z = torch.cat([x] + y, dim=1) 282*da0073e9SAndroid Build Coastguard Worker p = self.top_layers(z) 283*da0073e9SAndroid Build Coastguard Worker return p 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker m = MyRecommendationModule() 286*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 4) 287*da0073e9SAndroid Build Coastguard Worker b = torch.randint(10, (8,)) 288*da0073e9SAndroid Build Coastguard Worker offset = torch.randint(1, (2,)) 289*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 290*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset]) 291*da0073e9SAndroid Build Coastguard Worker devices = [ 292*da0073e9SAndroid Build Coastguard Worker Device("dev_0", 33000000, 0), 293*da0073e9SAndroid Build Coastguard Worker Device("dev_1", 33000000, 1), 294*da0073e9SAndroid Build Coastguard Worker Device("dev_2", 33000000, 2), 295*da0073e9SAndroid Build Coastguard Worker ] 296*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn) 297*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 298*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 299*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 300*da0073e9SAndroid Build Coastguard Worker dag = ret.dag 301*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset)) 302*da0073e9SAndroid Build Coastguard Worker assert len(module_with_submodules.graph.nodes) == 24 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker def test_partition_latency(self): 305*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 306*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 307*da0073e9SAndroid Build Coastguard Worker super().__init__() 308*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 311*da0073e9SAndroid Build Coastguard Worker add_1 = a + torch.rand(4) 312*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + torch.rand(4) 313*da0073e9SAndroid Build Coastguard Worker linear_1 = self.linear(add_1) 314*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + linear_1 315*da0073e9SAndroid Build Coastguard Worker add_4 = add_2 + add_3 316*da0073e9SAndroid Build Coastguard Worker return add_4 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker def get_node_to_latency_mapping(fx_module: GraphModule): 319*da0073e9SAndroid Build Coastguard Worker """Given a fx module, generate node latency for each node 320*da0073e9SAndroid Build Coastguard Worker based on the size of each node 321*da0073e9SAndroid Build Coastguard Worker """ 322*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping: Dict[Node, NodeLatency] = {} 323*da0073e9SAndroid Build Coastguard Worker for node in fx_module.graph.nodes: 324*da0073e9SAndroid Build Coastguard Worker if node.op not in {"output", "placeholder", "get_attr"}: 325*da0073e9SAndroid Build Coastguard Worker if node.size_bytes.total_size == node.size_bytes.output_size: 326*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping[node] = NodeLatency( 327*da0073e9SAndroid Build Coastguard Worker node.size_bytes.total_size, 2.0 * node.size_bytes.total_size 328*da0073e9SAndroid Build Coastguard Worker ) 329*da0073e9SAndroid Build Coastguard Worker else: 330*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping[node] = NodeLatency( 331*da0073e9SAndroid Build Coastguard Worker node.size_bytes.total_size, node.size_bytes.output_size 332*da0073e9SAndroid Build Coastguard Worker ) 333*da0073e9SAndroid Build Coastguard Worker return node_to_latency_mapping 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker m = TestModule() 336*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 337*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 338*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a]) 339*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping = get_node_to_latency_mapping(traced) 340*da0073e9SAndroid Build Coastguard Worker devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)] 341*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 342*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices) 343*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 344*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a), module_with_submodules(a)) 346*da0073e9SAndroid Build Coastguard Worker partitions = partitioner.partitions 347*da0073e9SAndroid Build Coastguard Worker partition_to_latency_mapping = get_partition_to_latency_mapping( 348*da0073e9SAndroid Build Coastguard Worker partitions, node_to_latency_mapping 349*da0073e9SAndroid Build Coastguard Worker ) 350*da0073e9SAndroid Build Coastguard Worker for p in partition_to_latency_mapping: 351*da0073e9SAndroid Build Coastguard Worker if p.partition_id == 0: 352*da0073e9SAndroid Build Coastguard Worker assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0) 353*da0073e9SAndroid Build Coastguard Worker else: 354*da0073e9SAndroid Build Coastguard Worker assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0) 355*da0073e9SAndroid Build Coastguard Worker transfer_rate_bytes_per_sec = 2 356*da0073e9SAndroid Build Coastguard Worker critical_path_latency_sec = get_latency_of_partitioned_graph( 357*da0073e9SAndroid Build Coastguard Worker partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec 358*da0073e9SAndroid Build Coastguard Worker ) 359*da0073e9SAndroid Build Coastguard Worker assert critical_path_latency_sec == 208.0 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def test_cost_aware_partition(self): 362*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 363*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 364*da0073e9SAndroid Build Coastguard Worker super().__init__() 365*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 368*da0073e9SAndroid Build Coastguard Worker add_1 = a + torch.rand(4) 369*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + torch.rand(4) 370*da0073e9SAndroid Build Coastguard Worker linear_1 = self.linear(add_1) 371*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + torch.rand(4) 372*da0073e9SAndroid Build Coastguard Worker add_4 = add_2 + linear_1 373*da0073e9SAndroid Build Coastguard Worker add_5 = add_3 + add_4 374*da0073e9SAndroid Build Coastguard Worker return add_5 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker def get_node_to_latency_mapping(fx_module: GraphModule): 377*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping: Dict[Node, NodeLatency] = {} 378*da0073e9SAndroid Build Coastguard Worker for node in fx_module.graph.nodes: 379*da0073e9SAndroid Build Coastguard Worker if node.op not in {"output", "placeholder", "get_attr"}: 380*da0073e9SAndroid Build Coastguard Worker if node.size_bytes.total_size == node.size_bytes.output_size: 381*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping[node] = NodeLatency( 382*da0073e9SAndroid Build Coastguard Worker node.size_bytes.total_size, 1 383*da0073e9SAndroid Build Coastguard Worker ) 384*da0073e9SAndroid Build Coastguard Worker else: 385*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping[node] = NodeLatency( 386*da0073e9SAndroid Build Coastguard Worker node.size_bytes.total_size, node.size_bytes.output_size 387*da0073e9SAndroid Build Coastguard Worker ) 388*da0073e9SAndroid Build Coastguard Worker return node_to_latency_mapping 389*da0073e9SAndroid Build Coastguard Worker 390*da0073e9SAndroid Build Coastguard Worker m = MyModule() 391*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 392*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 393*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a]) 394*da0073e9SAndroid Build Coastguard Worker devices = [ 395*da0073e9SAndroid Build Coastguard Worker Device("dev_0", 125, 0), 396*da0073e9SAndroid Build Coastguard Worker Device("dev_1", 125, 1), 397*da0073e9SAndroid Build Coastguard Worker Device("dev_2", 125, 2), 398*da0073e9SAndroid Build Coastguard Worker Device("dev_3", 125, 3), 399*da0073e9SAndroid Build Coastguard Worker ] 400*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping = get_node_to_latency_mapping(traced) 401*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig( 402*da0073e9SAndroid Build Coastguard Worker devices, 403*da0073e9SAndroid Build Coastguard Worker mode=PartitionMode.cost_aware, 404*da0073e9SAndroid Build Coastguard Worker transfer_rate_bytes_per_sec=2, 405*da0073e9SAndroid Build Coastguard Worker node_to_latency_mapping=node_to_latency_mapping, 406*da0073e9SAndroid Build Coastguard Worker ) 407*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 408*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 409*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 410*da0073e9SAndroid Build Coastguard Worker dag = ret.dag 411*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a), module_with_submodules(a)) 412*da0073e9SAndroid Build Coastguard Worker partitions = partitioner.partitions 413*da0073e9SAndroid Build Coastguard Worker partition_to_latency_mapping = get_partition_to_latency_mapping( 414*da0073e9SAndroid Build Coastguard Worker partitions, node_to_latency_mapping 415*da0073e9SAndroid Build Coastguard Worker ) 416*da0073e9SAndroid Build Coastguard Worker critical_path_latency_sec = get_latency_of_partitioned_graph( 417*da0073e9SAndroid Build Coastguard Worker partitions, 418*da0073e9SAndroid Build Coastguard Worker partition_to_latency_mapping, 419*da0073e9SAndroid Build Coastguard Worker partitioner_config.transfer_rate_bytes_per_sec, 420*da0073e9SAndroid Build Coastguard Worker ) 421*da0073e9SAndroid Build Coastguard Worker assert critical_path_latency_sec == 160.0 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker def test_aot_based_partition(self): 424*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 425*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 426*da0073e9SAndroid Build Coastguard Worker super().__init__() 427*da0073e9SAndroid Build Coastguard Worker self.b = torch.rand(4) 428*da0073e9SAndroid Build Coastguard Worker self.c = torch.rand(4) 429*da0073e9SAndroid Build Coastguard Worker 430*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 431*da0073e9SAndroid Build Coastguard Worker add_1 = a + self.b 432*da0073e9SAndroid Build Coastguard Worker add_2 = self.c + add_1 433*da0073e9SAndroid Build Coastguard Worker return add_2 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker m = TestModule() 436*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 437*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 438*da0073e9SAndroid Build Coastguard Worker node_to_partition_id = {} 439*da0073e9SAndroid Build Coastguard Worker partition_to_logical_devices = {} 440*da0073e9SAndroid Build Coastguard Worker count = 0 441*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a]) 442*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 443*da0073e9SAndroid Build Coastguard Worker if node.op not in {"placeholder", "get_attr", "output"}: 444*da0073e9SAndroid Build Coastguard Worker node_to_partition_id[node] = count 445*da0073e9SAndroid Build Coastguard Worker partition_to_logical_devices[count] = [0] 446*da0073e9SAndroid Build Coastguard Worker count += 1 447*da0073e9SAndroid Build Coastguard Worker devices = [Device("dev_0", 200, 0)] 448*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig( 449*da0073e9SAndroid Build Coastguard Worker devices=devices, 450*da0073e9SAndroid Build Coastguard Worker mode=PartitionMode.aot_based, 451*da0073e9SAndroid Build Coastguard Worker node_to_partition_mapping=node_to_partition_id, 452*da0073e9SAndroid Build Coastguard Worker partition_to_logical_device_mapping=partition_to_logical_devices, 453*da0073e9SAndroid Build Coastguard Worker ) 454*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 455*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 456*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 457*da0073e9SAndroid Build Coastguard Worker dag = ret.dag 458*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module_with_submodules(a), traced(a)) 459*da0073e9SAndroid Build Coastguard Worker for node in dag.nodes: 460*da0073e9SAndroid Build Coastguard Worker assert node.size_bytes == 48 461*da0073e9SAndroid Build Coastguard Worker assert node.logical_device_ids == [0] 462*da0073e9SAndroid Build Coastguard Worker 463*da0073e9SAndroid Build Coastguard Worker def test_replace_target_nodes_with(self): 464*da0073e9SAndroid Build Coastguard Worker class testModule(torch.nn.Module): 465*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 466*da0073e9SAndroid Build Coastguard Worker return a + b 467*da0073e9SAndroid Build Coastguard Worker 468*da0073e9SAndroid Build Coastguard Worker m = testModule() 469*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 470*da0073e9SAndroid Build Coastguard Worker input1 = torch.randn(1) 471*da0073e9SAndroid Build Coastguard Worker input2 = torch.randn(1) 472*da0073e9SAndroid Build Coastguard Worker assert (input1 + input2) == traced(input1, input2) 473*da0073e9SAndroid Build Coastguard Worker graph_manipulation.replace_target_nodes_with( 474*da0073e9SAndroid Build Coastguard Worker fx_module=traced, 475*da0073e9SAndroid Build Coastguard Worker old_op="call_function", 476*da0073e9SAndroid Build Coastguard Worker old_target=operator.add, 477*da0073e9SAndroid Build Coastguard Worker new_op="call_function", 478*da0073e9SAndroid Build Coastguard Worker new_target=operator.mul, 479*da0073e9SAndroid Build Coastguard Worker ) 480*da0073e9SAndroid Build Coastguard Worker assert (input1 * input2) == traced(input1, input2) 481*da0073e9SAndroid Build Coastguard Worker 482*da0073e9SAndroid Build Coastguard Worker def test_saturate_host(self): 483*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 484*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 485*da0073e9SAndroid Build Coastguard Worker super().__init__() 486*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 489*da0073e9SAndroid Build Coastguard Worker add_1 = a + torch.rand(4) 490*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + torch.rand(4) 491*da0073e9SAndroid Build Coastguard Worker linear_1 = self.linear(add_1) 492*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + linear_1 493*da0073e9SAndroid Build Coastguard Worker add_4 = add_2 + add_3 494*da0073e9SAndroid Build Coastguard Worker return add_4 495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker m = TestModule() 497*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 498*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 499*da0073e9SAndroid Build Coastguard Worker graph_manipulation.get_size_of_all_nodes(traced, [a]) 500*da0073e9SAndroid Build Coastguard Worker devices = [ 501*da0073e9SAndroid Build Coastguard Worker Device("dev_0", 200, 0), 502*da0073e9SAndroid Build Coastguard Worker Device("dev_1", 200, 1), 503*da0073e9SAndroid Build Coastguard Worker Device("dev_2", 100, 2), 504*da0073e9SAndroid Build Coastguard Worker Device("dev_3", 100, 3), 505*da0073e9SAndroid Build Coastguard Worker Device("dev_4", 200, 4), 506*da0073e9SAndroid Build Coastguard Worker Device("dev_5", 100, 5), 507*da0073e9SAndroid Build Coastguard Worker ] 508*da0073e9SAndroid Build Coastguard Worker partitioner = Partitioner() 509*da0073e9SAndroid Build Coastguard Worker # Without host saturation, the model will be split into two partitions. 510*da0073e9SAndroid Build Coastguard Worker # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes. 511*da0073e9SAndroid Build Coastguard Worker partitioner_config = PartitionerConfig(devices, saturate_host=True) 512*da0073e9SAndroid Build Coastguard Worker ret = partitioner.partition_graph(traced, m, partitioner_config) 513*da0073e9SAndroid Build Coastguard Worker module_with_submodules = ret.module_with_submodules 514*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(a), module_with_submodules(a)) 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker partitions = partitioner.partitions 517*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(partitions), 2) 518*da0073e9SAndroid Build Coastguard Worker # With host saturation, partition 1 will be replicated to dev_4, and partition 2 519*da0073e9SAndroid Build Coastguard Worker # will be replicated to dev_2. 520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(partitions[0].logical_device_ids, [0, 4]) 521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(partitions[1].logical_device_ids, [1, 2]) 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 524*da0073e9SAndroid Build Coastguard Worker def test_conv_bn_fusion(self): 525*da0073e9SAndroid Build Coastguard Worker rn18 = resnet18().eval() 526*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(rn18) 527*da0073e9SAndroid Build Coastguard Worker fused = optimization.fuse(traced) 528*da0073e9SAndroid Build Coastguard Worker 529*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 530*da0073e9SAndroid Build Coastguard Worker all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) 531*da0073e9SAndroid Build Coastguard Worker ) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker N, C, H, W = 20, 3, 224, 224 534*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, H, W) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fused(inp), rn18(inp)) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker def test_conv_bn_fusion_not_running_state(self): 539*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 540*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 541*da0073e9SAndroid Build Coastguard Worker super().__init__() 542*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(32, 64, 3, stride=2) 543*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False) 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 546*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 547*da0073e9SAndroid Build Coastguard Worker x = self.bn(x) 548*da0073e9SAndroid Build Coastguard Worker return x 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker model = M().eval() 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(model) 553*da0073e9SAndroid Build Coastguard Worker fused = optimization.fuse(traced) 554*da0073e9SAndroid Build Coastguard Worker inp = torch.randn([1, 32, 50, 50]) 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker # bn need not be folded in conv 557*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 558*da0073e9SAndroid Build Coastguard Worker any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) 559*da0073e9SAndroid Build Coastguard Worker ) 560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fused(inp), model(inp)) 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker def test_conv_bn_fusion_mixed_dtype(self): 563*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 564*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 565*da0073e9SAndroid Build Coastguard Worker super().__init__() 566*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16) 567*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True) 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 570*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 571*da0073e9SAndroid Build Coastguard Worker x = self.bn(x) 572*da0073e9SAndroid Build Coastguard Worker return x 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Worker model = M().eval() 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(model) 577*da0073e9SAndroid Build Coastguard Worker fused = optimization.fuse(traced) 578*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16) 579*da0073e9SAndroid Build Coastguard Worker 580*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 581*da0073e9SAndroid Build Coastguard Worker all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules()) 582*da0073e9SAndroid Build Coastguard Worker ) 583*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fused(inp), model(inp)) 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker def test_call_to_assert_no_msg(self): 586*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 587*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 588*da0073e9SAndroid Build Coastguard Worker assert a == b 589*da0073e9SAndroid Build Coastguard Worker return a + b 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker m = M() 592*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace_with_rewrite(m) 593*da0073e9SAndroid Build Coastguard Worker 594*da0073e9SAndroid Build Coastguard Worker # Make sure the graph is well-formed 595*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker # Check the IR to make sure there's a call_function node with target == "Assert" 598*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 599*da0073e9SAndroid Build Coastguard Worker any( 600*da0073e9SAndroid Build Coastguard Worker node.op == "call_function" and node.target == torch._assert 601*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes 602*da0073e9SAndroid Build Coastguard Worker ) 603*da0073e9SAndroid Build Coastguard Worker ) 604*da0073e9SAndroid Build Coastguard Worker 605*da0073e9SAndroid Build Coastguard Worker # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 606*da0073e9SAndroid Build Coastguard Worker traced(3, 3) 607*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, ""): 608*da0073e9SAndroid Build Coastguard Worker traced(3, 5) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker # Confirm that the output is correct 611*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(3, 3), m(3, 3)) 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker def test_meta_tracer(self): 614*da0073e9SAndroid Build Coastguard Worker class MetaTracerTestModule(torch.nn.Module): 615*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 616*da0073e9SAndroid Build Coastguard Worker super().__init__() 617*da0073e9SAndroid Build Coastguard Worker self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16) 618*da0073e9SAndroid Build Coastguard Worker self.layernorm = torch.nn.LayerNorm(16) 619*da0073e9SAndroid Build Coastguard Worker 620*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 621*da0073e9SAndroid Build Coastguard Worker emb = self.emb(x) 622*da0073e9SAndroid Build Coastguard Worker emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device) 623*da0073e9SAndroid Build Coastguard Worker lol = self.layernorm(emb) 624*da0073e9SAndroid Build Coastguard Worker return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol) 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker mttm = MetaTracerTestModule() 627*da0073e9SAndroid Build Coastguard Worker for BS in [15, 35]: 628*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(BS, dtype=torch.long).random_(42) 629*da0073e9SAndroid Build Coastguard Worker meta_args = {'x' : x.to(device='meta')} 630*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args) 631*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(gm(x), mttm(x)) 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker # Test serialization/deserialization 634*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmp_dir: 635*da0073e9SAndroid Build Coastguard Worker with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f: 636*da0073e9SAndroid Build Coastguard Worker pickle.dump(gm, f) 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f: 639*da0073e9SAndroid Build Coastguard Worker loaded = pickle.load(f) 640*da0073e9SAndroid Build Coastguard Worker 641*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(loaded(x), mttm(x)) 642*da0073e9SAndroid Build Coastguard Worker 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker def test_call_to_assert_with_msg(self): 645*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 646*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 647*da0073e9SAndroid Build Coastguard Worker assert a == b, "test message" 648*da0073e9SAndroid Build Coastguard Worker return a + b 649*da0073e9SAndroid Build Coastguard Worker 650*da0073e9SAndroid Build Coastguard Worker m = M() 651*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace_with_rewrite(m) 652*da0073e9SAndroid Build Coastguard Worker 653*da0073e9SAndroid Build Coastguard Worker # Make sure the graph is well-formed 654*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker # Check the IR to make sure there's a call_function node with target == "Assert" 657*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 658*da0073e9SAndroid Build Coastguard Worker any( 659*da0073e9SAndroid Build Coastguard Worker node.op == "call_function" and node.target == torch._assert 660*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes 661*da0073e9SAndroid Build Coastguard Worker ) 662*da0073e9SAndroid Build Coastguard Worker ) 663*da0073e9SAndroid Build Coastguard Worker 664*da0073e9SAndroid Build Coastguard Worker # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 665*da0073e9SAndroid Build Coastguard Worker traced(3, 3) 666*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "test message"): 667*da0073e9SAndroid Build Coastguard Worker traced(3, 5) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker # Confirm that the output is correct 670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(3, 3), m(3, 3)) 671*da0073e9SAndroid Build Coastguard Worker 672*da0073e9SAndroid Build Coastguard Worker def test_call_to_assert_with_empty_msg(self): 673*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 674*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 675*da0073e9SAndroid Build Coastguard Worker assert a == b, "" 676*da0073e9SAndroid Build Coastguard Worker return a + b 677*da0073e9SAndroid Build Coastguard Worker 678*da0073e9SAndroid Build Coastguard Worker m = M() 679*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace_with_rewrite(m) 680*da0073e9SAndroid Build Coastguard Worker 681*da0073e9SAndroid Build Coastguard Worker # Make sure the graph is well-formed 682*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 683*da0073e9SAndroid Build Coastguard Worker 684*da0073e9SAndroid Build Coastguard Worker # Check the IR to make sure there's a call_function node with target == "Assert" 685*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 686*da0073e9SAndroid Build Coastguard Worker any( 687*da0073e9SAndroid Build Coastguard Worker node.op == "call_function" and node.target == torch._assert 688*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes 689*da0073e9SAndroid Build Coastguard Worker ) 690*da0073e9SAndroid Build Coastguard Worker ) 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 693*da0073e9SAndroid Build Coastguard Worker traced(3, 3) 694*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, ""): 695*da0073e9SAndroid Build Coastguard Worker traced(3, 5) 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker # Confirm that the output is correct 698*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(3, 3), m(3, 3)) 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker def test_call_to_assert_with_multiline_message(self): 701*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 702*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 703*da0073e9SAndroid Build Coastguard Worker error_msg = """ 704*da0073e9SAndroid Build Coastguard WorkerAn error message with 705*da0073e9SAndroid Build Coastguard Workerterrible spacing 706*da0073e9SAndroid Build Coastguard Worker """ 707*da0073e9SAndroid Build Coastguard Worker assert a == b, error_msg 708*da0073e9SAndroid Build Coastguard Worker return a + b 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker m = M() 711*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace_with_rewrite(m) 712*da0073e9SAndroid Build Coastguard Worker 713*da0073e9SAndroid Build Coastguard Worker # Make sure the graph is well-formed 714*da0073e9SAndroid Build Coastguard Worker traced.graph.lint() 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker # Check the IR to make sure there's a call_function node with target == "Assert" 717*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 718*da0073e9SAndroid Build Coastguard Worker any( 719*da0073e9SAndroid Build Coastguard Worker node.op == "call_function" and node.target == torch._assert 720*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes 721*da0073e9SAndroid Build Coastguard Worker ) 722*da0073e9SAndroid Build Coastguard Worker ) 723*da0073e9SAndroid Build Coastguard Worker 724*da0073e9SAndroid Build Coastguard Worker # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to 725*da0073e9SAndroid Build Coastguard Worker error_msg = """ 726*da0073e9SAndroid Build Coastguard WorkerAn error message with 727*da0073e9SAndroid Build Coastguard Workerterrible spacing 728*da0073e9SAndroid Build Coastguard Worker """ 729*da0073e9SAndroid Build Coastguard Worker traced(3, 3) 730*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, error_msg): 731*da0073e9SAndroid Build Coastguard Worker traced(3, 5) 732*da0073e9SAndroid Build Coastguard Worker 733*da0073e9SAndroid Build Coastguard Worker # Confirm that the output is correct 734*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(3, 3), m(3, 3)) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker def test_subgraph_creation(self): 737*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 738*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 739*da0073e9SAndroid Build Coastguard Worker super().__init__() 740*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 4)) 741*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 5) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 744*da0073e9SAndroid Build Coastguard Worker z = self.linear(x + self.param).clamp(min=0.0, max=1.0) 745*da0073e9SAndroid Build Coastguard Worker w = self.linear(y).clamp(min=0.0, max=1.0) 746*da0073e9SAndroid Build Coastguard Worker return z + w 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker # symbolically trace model 749*da0073e9SAndroid Build Coastguard Worker my_module = MyModule() 750*da0073e9SAndroid Build Coastguard Worker my_module_traced = symbolic_trace(my_module) 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker # random mod partitioning 753*da0073e9SAndroid Build Coastguard Worker partition_counter = 0 754*da0073e9SAndroid Build Coastguard Worker NPARTITIONS = 3 755*da0073e9SAndroid Build Coastguard Worker 756*da0073e9SAndroid Build Coastguard Worker # Add some random meta info to make sure it is kept around. 757*da0073e9SAndroid Build Coastguard Worker for node in my_module_traced.graph.nodes: 758*da0073e9SAndroid Build Coastguard Worker if node.op != "output": 759*da0073e9SAndroid Build Coastguard Worker node.meta["test_meta_info"] = True 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker def mod_partition(node: Node): 762*da0073e9SAndroid Build Coastguard Worker nonlocal partition_counter 763*da0073e9SAndroid Build Coastguard Worker partition = partition_counter % NPARTITIONS 764*da0073e9SAndroid Build Coastguard Worker partition_counter = (partition_counter + 1) % NPARTITIONS 765*da0073e9SAndroid Build Coastguard Worker return partition 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker # split module in module with submodules 768*da0073e9SAndroid Build Coastguard Worker module_with_submodules = split_module( 769*da0073e9SAndroid Build Coastguard Worker my_module_traced, my_module, mod_partition 770*da0073e9SAndroid Build Coastguard Worker ) 771*da0073e9SAndroid Build Coastguard Worker 772*da0073e9SAndroid Build Coastguard Worker # Check that test_meta_info was still on all nodes. 773*da0073e9SAndroid Build Coastguard Worker submodules = dict(module_with_submodules.named_modules()) 774*da0073e9SAndroid Build Coastguard Worker for node in module_with_submodules.graph.nodes: 775*da0073e9SAndroid Build Coastguard Worker if node.op == "call_module": 776*da0073e9SAndroid Build Coastguard Worker submod = submodules[node.target] 777*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(submod, torch.fx.GraphModule)) 778*da0073e9SAndroid Build Coastguard Worker for submod_node in submod.graph.nodes: 779*da0073e9SAndroid Build Coastguard Worker if submod_node.op != "output": 780*da0073e9SAndroid Build Coastguard Worker stored_op = submod_node.meta.get("test_meta_info") 781*da0073e9SAndroid Build Coastguard Worker self.assertTrue(stored_op is not None and stored_op) 782*da0073e9SAndroid Build Coastguard Worker 783*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 784*da0073e9SAndroid Build Coastguard Worker y = torch.rand(3, 4) 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker orig_out = my_module_traced(x, y) 787*da0073e9SAndroid Build Coastguard Worker submodules_out = module_with_submodules(x, y) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker self.assertEqual(orig_out, submodules_out) 790*da0073e9SAndroid Build Coastguard Worker 791*da0073e9SAndroid Build Coastguard Worker def test_split_module_dead_code(self): 792*da0073e9SAndroid Build Coastguard Worker class ModWithDeadCode(torch.nn.Module): 793*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 794*da0073e9SAndroid Build Coastguard Worker output = x * 2 # we want this 795*da0073e9SAndroid Build Coastguard Worker dead_line = x + 2 # this is dead 796*da0073e9SAndroid Build Coastguard Worker return output 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker mod = ModWithDeadCode() 799*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(mod) 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker # split into before (0), target (1), and after(2) 802*da0073e9SAndroid Build Coastguard Worker saw_mul = False 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker def split_callback(n): 805*da0073e9SAndroid Build Coastguard Worker nonlocal saw_mul 806*da0073e9SAndroid Build Coastguard Worker if n.target == operator.mul: 807*da0073e9SAndroid Build Coastguard Worker saw_mul = True 808*da0073e9SAndroid Build Coastguard Worker return 1 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker if not saw_mul: 811*da0073e9SAndroid Build Coastguard Worker return 0 812*da0073e9SAndroid Build Coastguard Worker if saw_mul: 813*da0073e9SAndroid Build Coastguard Worker return 2 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker split = split_module(traced, mod, split_callback) 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker x = torch.randn((5,)) 818*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close( 819*da0073e9SAndroid Build Coastguard Worker split(x), traced(x) 820*da0073e9SAndroid Build Coastguard Worker ) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Worker def test_split_module_kwargs_expansion(self): 824*da0073e9SAndroid Build Coastguard Worker class ModuleWithKwargsExpansion(torch.nn.Module): 825*da0073e9SAndroid Build Coastguard Worker def forward(self, x, **kwargs): 826*da0073e9SAndroid Build Coastguard Worker return x + kwargs['foo'] 827*da0073e9SAndroid Build Coastguard Worker 828*da0073e9SAndroid Build Coastguard Worker mod = ModuleWithKwargsExpansion() 829*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(mod) 830*da0073e9SAndroid Build Coastguard Worker 831*da0073e9SAndroid Build Coastguard Worker seen_getitem = False 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker def split_callback(n): 834*da0073e9SAndroid Build Coastguard Worker nonlocal seen_getitem 835*da0073e9SAndroid Build Coastguard Worker split_idx = int(seen_getitem) 836*da0073e9SAndroid Build Coastguard Worker if n.target == operator.getitem: 837*da0073e9SAndroid Build Coastguard Worker seen_getitem = True 838*da0073e9SAndroid Build Coastguard Worker return split_idx 839*da0073e9SAndroid Build Coastguard Worker 840*da0073e9SAndroid Build Coastguard Worker split = split_module(traced, mod, split_callback) 841*da0073e9SAndroid Build Coastguard Worker 842*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 3) 843*da0073e9SAndroid Build Coastguard Worker foo = torch.randn(5, 3) 844*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo)) 845*da0073e9SAndroid Build Coastguard Worker 846*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 847*da0073e9SAndroid Build Coastguard Worker def test_subgraph_trivial_resnet(self): 848*da0073e9SAndroid Build Coastguard Worker # Smoke test trivially splitting resnet into 1 partition works 849*da0073e9SAndroid Build Coastguard Worker # There was an issue before causing submodule names to be aliased 850*da0073e9SAndroid Build Coastguard Worker m = resnet18() 851*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 852*da0073e9SAndroid Build Coastguard Worker a = torch.rand(64, 3, 7, 7) 853*da0073e9SAndroid Build Coastguard Worker module_with_submodules = split_module(traced, m, lambda node: 0) 854*da0073e9SAndroid Build Coastguard Worker module_with_submodules(a) 855*da0073e9SAndroid Build Coastguard Worker 856*da0073e9SAndroid Build Coastguard Worker def test_split_module_default_arg(self): 857*da0073e9SAndroid Build Coastguard Worker class ModelToTrace(torch.nn.Module): 858*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 859*da0073e9SAndroid Build Coastguard Worker super().__init__() 860*da0073e9SAndroid Build Coastguard Worker self.lin = torch.nn.Linear(512, 512) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker def forward(self, x, targets=None): 863*da0073e9SAndroid Build Coastguard Worker x = self.lin(x) 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker if targets is not None: 866*da0073e9SAndroid Build Coastguard Worker x = x + targets 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker return x 869*da0073e9SAndroid Build Coastguard Worker 870*da0073e9SAndroid Build Coastguard Worker mtt = ModelToTrace() 871*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None}) 872*da0073e9SAndroid Build Coastguard Worker 873*da0073e9SAndroid Build Coastguard Worker split = split_module(traced, mtt, lambda node: 0) 874*da0073e9SAndroid Build Coastguard Worker 875*da0073e9SAndroid Build Coastguard Worker x = torch.randn(50, 512) 876*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(split(x), traced(x)) 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker def test_normalize_binary_operators(self): 879*da0073e9SAndroid Build Coastguard Worker ops_to_test = { 880*da0073e9SAndroid Build Coastguard Worker torch.add, 881*da0073e9SAndroid Build Coastguard Worker torch.mul, 882*da0073e9SAndroid Build Coastguard Worker torch.sub, 883*da0073e9SAndroid Build Coastguard Worker torch.div, 884*da0073e9SAndroid Build Coastguard Worker torch.floor_divide, 885*da0073e9SAndroid Build Coastguard Worker torch.remainder, 886*da0073e9SAndroid Build Coastguard Worker torch.eq, 887*da0073e9SAndroid Build Coastguard Worker torch.ne, 888*da0073e9SAndroid Build Coastguard Worker torch.lt, 889*da0073e9SAndroid Build Coastguard Worker torch.le, 890*da0073e9SAndroid Build Coastguard Worker torch.gt, 891*da0073e9SAndroid Build Coastguard Worker torch.ge, 892*da0073e9SAndroid Build Coastguard Worker } 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker # Test Tensor/Tensor callsite 895*da0073e9SAndroid Build Coastguard Worker for op in ops_to_test: 896*da0073e9SAndroid Build Coastguard Worker 897*da0073e9SAndroid Build Coastguard Worker class WrapperMod(torch.nn.Module): 898*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 899*da0073e9SAndroid Build Coastguard Worker return op(x, y) 900*da0073e9SAndroid Build Coastguard Worker 901*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(WrapperMod()) 902*da0073e9SAndroid Build Coastguard Worker normalized = NormalizeOperators(traced).transform() 903*da0073e9SAndroid Build Coastguard Worker x, y = torch.randn(3, 4), torch.randn(3, 4) 904*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(traced(x, y), normalized(x, y)) 905*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 906*da0073e9SAndroid Build Coastguard Worker any(n.target in ops_to_test for n in normalized.graph.nodes) 907*da0073e9SAndroid Build Coastguard Worker ) 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker # Test Tensor/scalar callsite 910*da0073e9SAndroid Build Coastguard Worker for op in ops_to_test: 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker class WrapperMod(torch.nn.Module): 913*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 914*da0073e9SAndroid Build Coastguard Worker return op(x, 42) 915*da0073e9SAndroid Build Coastguard Worker 916*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(WrapperMod()) 917*da0073e9SAndroid Build Coastguard Worker normalized = NormalizeOperators(traced).transform() 918*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4) 919*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(traced(x), normalized(x)) 920*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 921*da0073e9SAndroid Build Coastguard Worker any(n.target in ops_to_test for n in normalized.graph.nodes) 922*da0073e9SAndroid Build Coastguard Worker ) 923*da0073e9SAndroid Build Coastguard Worker 924*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 925*da0073e9SAndroid Build Coastguard Worker def test_normalize_args(self): 926*da0073e9SAndroid Build Coastguard Worker m = resnet18() 927*da0073e9SAndroid Build Coastguard Worker 928*da0073e9SAndroid Build Coastguard Worker class FunctionalTracer(torch.fx.Tracer): 929*da0073e9SAndroid Build Coastguard Worker def is_leaf_module( 930*da0073e9SAndroid Build Coastguard Worker self, m: torch.nn.Module, module_qualified_name: str 931*da0073e9SAndroid Build Coastguard Worker ) -> bool: 932*da0073e9SAndroid Build Coastguard Worker # `leaves` contains the set of standard `nn.Modules` that are not 933*da0073e9SAndroid Build Coastguard Worker # currently symbolically traceable. Ideally this set would be empty 934*da0073e9SAndroid Build Coastguard Worker leaves = {torch.nn.BatchNorm2d} 935*da0073e9SAndroid Build Coastguard Worker return type(m) in leaves 936*da0073e9SAndroid Build Coastguard Worker 937*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) 938*da0073e9SAndroid Build Coastguard Worker 939*da0073e9SAndroid Build Coastguard Worker input = torch.randn(5, 3, 224, 224) 940*da0073e9SAndroid Build Coastguard Worker ref_outs = traced(input) 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Worker ShapeProp(traced).propagate(input) 943*da0073e9SAndroid Build Coastguard Worker traced = NormalizeArgs(traced).transform() 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker modules = dict(traced.named_modules()) 946*da0073e9SAndroid Build Coastguard Worker 947*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 948*da0073e9SAndroid Build Coastguard Worker if node.op == "call_function" and node.target != operator.add: 949*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(node.args), 0) 950*da0073e9SAndroid Build Coastguard Worker elif node.op == "call_module": 951*da0073e9SAndroid Build Coastguard Worker submod_class = modules[node.target].__class__ 952*da0073e9SAndroid Build Coastguard Worker nn_class = getattr(torch.nn, submod_class.__name__) 953*da0073e9SAndroid Build Coastguard Worker if submod_class == nn_class: 954*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(node.args), 0) 955*da0073e9SAndroid Build Coastguard Worker traced(input) 956*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(input), ref_outs) 957*da0073e9SAndroid Build Coastguard Worker 958*da0073e9SAndroid Build Coastguard Worker def test_normalize_modules_exhaustive(self): 959*da0073e9SAndroid Build Coastguard Worker """ 960*da0073e9SAndroid Build Coastguard Worker Exhaustively test `Node.normalized_arguments` on all standard 961*da0073e9SAndroid Build Coastguard Worker torch.nn Module classes 962*da0073e9SAndroid Build Coastguard Worker """ 963*da0073e9SAndroid Build Coastguard Worker for test_params in module_tests + new_module_tests: 964*da0073e9SAndroid Build Coastguard Worker if "constructor" not in test_params: 965*da0073e9SAndroid Build Coastguard Worker constructor = getattr(torch.nn, test_params["module_name"]) 966*da0073e9SAndroid Build Coastguard Worker else: 967*da0073e9SAndroid Build Coastguard Worker constructor = test_params["constructor"] 968*da0073e9SAndroid Build Coastguard Worker 969*da0073e9SAndroid Build Coastguard Worker if "constructor_args" not in test_params: 970*da0073e9SAndroid Build Coastguard Worker args = () 971*da0073e9SAndroid Build Coastguard Worker else: 972*da0073e9SAndroid Build Coastguard Worker args = test_params["constructor_args"] 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker mod = constructor(*args) 975*da0073e9SAndroid Build Coastguard Worker # Skip modules that are not standard `torch.nn` 976*da0073e9SAndroid Build Coastguard Worker # instances, including functionals. (functionals 977*da0073e9SAndroid Build Coastguard Worker # are tested in test_normalize_args) 978*da0073e9SAndroid Build Coastguard Worker if mod.__class__.__name__ not in dir(torch.nn): 979*da0073e9SAndroid Build Coastguard Worker continue 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker if "input_fn" not in test_params: 982*da0073e9SAndroid Build Coastguard Worker inputs = torch.randn(test_params["input_size"]) 983*da0073e9SAndroid Build Coastguard Worker else: 984*da0073e9SAndroid Build Coastguard Worker inputs = test_params["input_fn"]() 985*da0073e9SAndroid Build Coastguard Worker 986*da0073e9SAndroid Build Coastguard Worker if not isinstance(inputs, (tuple, list)): 987*da0073e9SAndroid Build Coastguard Worker inputs = (inputs,) 988*da0073e9SAndroid Build Coastguard Worker 989*da0073e9SAndroid Build Coastguard Worker params = ", ".join(f"v{i}" for i in range(len(inputs))) 990*da0073e9SAndroid Build Coastguard Worker 991*da0073e9SAndroid Build Coastguard Worker # Generate a class to wrap this standard `nn.Module` instance 992*da0073e9SAndroid Build Coastguard Worker test_classname = f"Test{mod.__class__.__name__}" 993*da0073e9SAndroid Build Coastguard Worker test_mod_code = f""" 994*da0073e9SAndroid Build Coastguard Workerclass {test_classname}(torch.nn.Module): 995*da0073e9SAndroid Build Coastguard Worker def __init__(self, mod): 996*da0073e9SAndroid Build Coastguard Worker super().__init__() 997*da0073e9SAndroid Build Coastguard Worker self.mod = mod 998*da0073e9SAndroid Build Coastguard Worker 999*da0073e9SAndroid Build Coastguard Worker def forward(self, {params}): 1000*da0073e9SAndroid Build Coastguard Worker return self.mod({params}) 1001*da0073e9SAndroid Build Coastguard Worker """ 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker gbls = {"torch": torch} 1004*da0073e9SAndroid Build Coastguard Worker exec(test_mod_code, gbls) 1005*da0073e9SAndroid Build Coastguard Worker 1006*da0073e9SAndroid Build Coastguard Worker test_instance = gbls[test_classname](mod) 1007*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(test_instance) 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker # Use `Node.normalized_arguments` to get a new set of arguments 1010*da0073e9SAndroid Build Coastguard Worker # to feed to the Module. Then, rewrite the node to only take 1011*da0073e9SAndroid Build Coastguard Worker # in those arguments as kwargs 1012*da0073e9SAndroid Build Coastguard Worker modules = dict(traced.named_modules()) 1013*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1014*da0073e9SAndroid Build Coastguard Worker if node.op == "call_module": 1015*da0073e9SAndroid Build Coastguard Worker submod_class = modules[node.target].__class__ 1016*da0073e9SAndroid Build Coastguard Worker nn_class = getattr(torch.nn, submod_class.__name__) 1017*da0073e9SAndroid Build Coastguard Worker if submod_class == nn_class: 1018*da0073e9SAndroid Build Coastguard Worker normalized_args = node.normalized_arguments(traced) 1019*da0073e9SAndroid Build Coastguard Worker normalized_args2 = normalize_module( 1020*da0073e9SAndroid Build Coastguard Worker traced, node.target, node.args, node.kwargs 1021*da0073e9SAndroid Build Coastguard Worker ) 1022*da0073e9SAndroid Build Coastguard Worker assert normalized_args == normalized_args2 1023*da0073e9SAndroid Build Coastguard Worker assert normalized_args 1024*da0073e9SAndroid Build Coastguard Worker node.args = normalized_args.args 1025*da0073e9SAndroid Build Coastguard Worker node.kwargs = normalized_args.kwargs 1026*da0073e9SAndroid Build Coastguard Worker 1027*da0073e9SAndroid Build Coastguard Worker traced.recompile() 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker # These Modules have an RNG in their forward, so testing 1030*da0073e9SAndroid Build Coastguard Worker # correctness by comparing outputs is not correct. Skip that 1031*da0073e9SAndroid Build Coastguard Worker # check for these 1032*da0073e9SAndroid Build Coastguard Worker stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"} 1033*da0073e9SAndroid Build Coastguard Worker 1034*da0073e9SAndroid Build Coastguard Worker if mod.__class__.__name__ not in stochastic_modules: 1035*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(*inputs), mod(*inputs)) 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker traced = NormalizeArgs(symbolic_trace(test_instance)).transform() 1038*da0073e9SAndroid Build Coastguard Worker modules = dict(traced.named_modules()) 1039*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1040*da0073e9SAndroid Build Coastguard Worker if node.op == "call_module": 1041*da0073e9SAndroid Build Coastguard Worker submod_class = modules[node.target].__class__ 1042*da0073e9SAndroid Build Coastguard Worker nn_class = getattr(torch.nn, submod_class.__name__) 1043*da0073e9SAndroid Build Coastguard Worker if submod_class == nn_class: 1044*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(node.args), 0) 1045*da0073e9SAndroid Build Coastguard Worker 1046*da0073e9SAndroid Build Coastguard Worker def test_normalize_args_preserve_meta(self): 1047*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1048*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 1049*da0073e9SAndroid Build Coastguard Worker return torch.add(a, 3) 1050*da0073e9SAndroid Build Coastguard Worker 1051*da0073e9SAndroid Build Coastguard Worker m = MyModule() 1052*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 1053*da0073e9SAndroid Build Coastguard Worker 1054*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1055*da0073e9SAndroid Build Coastguard Worker if node.op == "call_function" and node.target == torch.add: 1056*da0073e9SAndroid Build Coastguard Worker node.meta["my_key"] = 7 1057*da0073e9SAndroid Build Coastguard Worker break 1058*da0073e9SAndroid Build Coastguard Worker else: 1059*da0073e9SAndroid Build Coastguard Worker self.fail("Didn't find call_function torch.add") 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 3) 1062*da0073e9SAndroid Build Coastguard Worker ShapeProp(traced).propagate(input) 1063*da0073e9SAndroid Build Coastguard Worker traced = NormalizeArgs(traced).transform() 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1066*da0073e9SAndroid Build Coastguard Worker if node.op == "call_function" and node.target == torch.add: 1067*da0073e9SAndroid Build Coastguard Worker self.assertTrue("my_key" in node.meta) 1068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(node.meta["my_key"], 7) 1069*da0073e9SAndroid Build Coastguard Worker break 1070*da0073e9SAndroid Build Coastguard Worker else: 1071*da0073e9SAndroid Build Coastguard Worker self.fail("Didn't find call_function torch.add") 1072*da0073e9SAndroid Build Coastguard Worker 1073*da0073e9SAndroid Build Coastguard Worker def test_normalize_args_perserve_type(self): 1074*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1075*da0073e9SAndroid Build Coastguard Worker def forward(self, a: List[torch.Tensor]): 1076*da0073e9SAndroid Build Coastguard Worker return torch.add(a[0], a[1]) 1077*da0073e9SAndroid Build Coastguard Worker 1078*da0073e9SAndroid Build Coastguard Worker m = MyModule() 1079*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(m) 1080*da0073e9SAndroid Build Coastguard Worker traced = NormalizeArgs(traced).transform() 1081*da0073e9SAndroid Build Coastguard Worker 1082*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1083*da0073e9SAndroid Build Coastguard Worker if node.op == "placeholder": 1084*da0073e9SAndroid Build Coastguard Worker self.assertEqual(node.type, List[torch.Tensor]) 1085*da0073e9SAndroid Build Coastguard Worker 1086*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 1087*da0073e9SAndroid Build Coastguard Worker def test_annotate_returns_with_schema(self): 1088*da0073e9SAndroid Build Coastguard Worker m = resnet18() 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker traced_modules = symbolic_trace(m) 1091*da0073e9SAndroid Build Coastguard Worker traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform() 1092*da0073e9SAndroid Build Coastguard Worker for node in traced_modules_annotated.graph.nodes: 1093*da0073e9SAndroid Build Coastguard Worker if node.type is None: 1094*da0073e9SAndroid Build Coastguard Worker check = (node.op, node.target) 1095*da0073e9SAndroid Build Coastguard Worker self.assertIn( 1096*da0073e9SAndroid Build Coastguard Worker check, 1097*da0073e9SAndroid Build Coastguard Worker { 1098*da0073e9SAndroid Build Coastguard Worker ("placeholder", "x"), 1099*da0073e9SAndroid Build Coastguard Worker ("call_module", "maxpool"), 1100*da0073e9SAndroid Build Coastguard Worker ("call_function", operator.add), 1101*da0073e9SAndroid Build Coastguard Worker ("call_function", torch.flatten), 1102*da0073e9SAndroid Build Coastguard Worker ("output", "output"), 1103*da0073e9SAndroid Build Coastguard Worker } 1104*da0073e9SAndroid Build Coastguard Worker ) 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker # Smoke test torchscript compilation since now we're emitting type annotations 1107*da0073e9SAndroid Build Coastguard Worker torch.jit.script(traced_modules_annotated) 1108*da0073e9SAndroid Build Coastguard Worker 1109*da0073e9SAndroid Build Coastguard Worker class FunctionalTracer(torch.fx.Tracer): 1110*da0073e9SAndroid Build Coastguard Worker def is_leaf_module( 1111*da0073e9SAndroid Build Coastguard Worker self, m: torch.nn.Module, module_qualified_name: str 1112*da0073e9SAndroid Build Coastguard Worker ) -> bool: 1113*da0073e9SAndroid Build Coastguard Worker # `leaves` contains the set of standard `nn.Modules` that are not 1114*da0073e9SAndroid Build Coastguard Worker # currently symbolically traceable. Ideally this set would be empty 1115*da0073e9SAndroid Build Coastguard Worker leaves = {torch.nn.BatchNorm2d} 1116*da0073e9SAndroid Build Coastguard Worker return type(m) in leaves 1117*da0073e9SAndroid Build Coastguard Worker 1118*da0073e9SAndroid Build Coastguard Worker traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m)) 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker traced_functionals_annotated = AnnotateTypesWithSchema( 1121*da0073e9SAndroid Build Coastguard Worker traced_functionals 1122*da0073e9SAndroid Build Coastguard Worker ).transform() 1123*da0073e9SAndroid Build Coastguard Worker for node in traced_functionals_annotated.graph.nodes: 1124*da0073e9SAndroid Build Coastguard Worker if node.type is None: 1125*da0073e9SAndroid Build Coastguard Worker check = (node.op, node.target) 1126*da0073e9SAndroid Build Coastguard Worker excluded_nodes = { 1127*da0073e9SAndroid Build Coastguard Worker ("placeholder", "x"), 1128*da0073e9SAndroid Build Coastguard Worker # Return type differs based on boolean dispatch :( 1129*da0073e9SAndroid Build Coastguard Worker ("call_function", torch.nn.functional.max_pool2d), 1130*da0073e9SAndroid Build Coastguard Worker ("output", "output"), 1131*da0073e9SAndroid Build Coastguard Worker } 1132*da0073e9SAndroid Build Coastguard Worker # AnnotateTypesWithSchema doesn't work with bound C++ functions 1133*da0073e9SAndroid Build Coastguard Worker if not isinstance(node.target, BuiltinFunctionType): 1134*da0073e9SAndroid Build Coastguard Worker self.assertIn(check, excluded_nodes) 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker # Smoke test torchscript compilation since now we're emitting type annotations 1137*da0073e9SAndroid Build Coastguard Worker torch.jit.script(traced_functionals_annotated) 1138*da0073e9SAndroid Build Coastguard Worker 1139*da0073e9SAndroid Build Coastguard Worker def test_annotate_getitem_node(self): 1140*da0073e9SAndroid Build Coastguard Worker class CustomType: 1141*da0073e9SAndroid Build Coastguard Worker pass 1142*da0073e9SAndroid Build Coastguard Worker 1143*da0073e9SAndroid Build Coastguard Worker class CustomNamedTuple(NamedTuple): 1144*da0073e9SAndroid Build Coastguard Worker x: int 1145*da0073e9SAndroid Build Coastguard Worker y: float 1146*da0073e9SAndroid Build Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1148*da0073e9SAndroid Build Coastguard Worker def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple): 1149*da0073e9SAndroid Build Coastguard Worker inp_0 = inp[0] 1150*da0073e9SAndroid Build Coastguard Worker inp_1 = inp[1] 1151*da0073e9SAndroid Build Coastguard Worker inp2_0 = inp2[0] 1152*da0073e9SAndroid Build Coastguard Worker inp3_x = inp3.x 1153*da0073e9SAndroid Build Coastguard Worker inp3_y = inp3.y 1154*da0073e9SAndroid Build Coastguard Worker return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y 1155*da0073e9SAndroid Build Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Worker my_module = MyModule() 1157*da0073e9SAndroid Build Coastguard Worker my_module_traced = torch.fx.symbolic_trace(my_module) 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Worker # by default, fx transform loses type annotation of getitem nodes. 1160*da0073e9SAndroid Build Coastguard Worker for node in my_module_traced.graph.nodes: 1161*da0073e9SAndroid Build Coastguard Worker if node.target == operator.getitem: 1162*da0073e9SAndroid Build Coastguard Worker assert node.type is None 1163*da0073e9SAndroid Build Coastguard Worker 1164*da0073e9SAndroid Build Coastguard Worker annotate_getitem_nodes(my_module_traced.graph) 1165*da0073e9SAndroid Build Coastguard Worker 1166*da0073e9SAndroid Build Coastguard Worker for node in my_module_traced.graph.nodes: 1167*da0073e9SAndroid Build Coastguard Worker if node.target == operator.getitem: 1168*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.") 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker def test_subgraph_uniquename(self): 1171*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1172*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1173*da0073e9SAndroid Build Coastguard Worker super().__init__() 1174*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(4, 4) 1175*da0073e9SAndroid Build Coastguard Worker 1176*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, c, d): 1177*da0073e9SAndroid Build Coastguard Worker add_1 = a + b 1178*da0073e9SAndroid Build Coastguard Worker add_2 = add_1 + c 1179*da0073e9SAndroid Build Coastguard Worker linear_1 = self.linear(add_1) 1180*da0073e9SAndroid Build Coastguard Worker add_3 = add_2 + d 1181*da0073e9SAndroid Build Coastguard Worker add_4 = add_2 + linear_1 1182*da0073e9SAndroid Build Coastguard Worker add_5 = add_3 + add_4 1183*da0073e9SAndroid Build Coastguard Worker return add_5 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4) 1186*da0073e9SAndroid Build Coastguard Worker mm = MyModule() 1187*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(mm) 1188*da0073e9SAndroid Build Coastguard Worker 1189*da0073e9SAndroid Build Coastguard Worker def split_cb(node: torch.fx.Node): 1190*da0073e9SAndroid Build Coastguard Worker if node.name == "a" or node.name == "b" or node.name == "add": 1191*da0073e9SAndroid Build Coastguard Worker return 0 1192*da0073e9SAndroid Build Coastguard Worker else: 1193*da0073e9SAndroid Build Coastguard Worker return 1 1194*da0073e9SAndroid Build Coastguard Worker 1195*da0073e9SAndroid Build Coastguard Worker module_with_submodule = split_module(traced, mm, split_cb) 1196*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d)) 1197*da0073e9SAndroid Build Coastguard Worker 1198*da0073e9SAndroid Build Coastguard Worker def test_split_qualname_mapping(self): 1199*da0073e9SAndroid Build Coastguard Worker d_hid = 4 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker class ExampleCode(torch.nn.Module): 1202*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1203*da0073e9SAndroid Build Coastguard Worker super().__init__() 1204*da0073e9SAndroid Build Coastguard Worker self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 1205*da0073e9SAndroid Build Coastguard Worker self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid)) 1206*da0073e9SAndroid Build Coastguard Worker self.lin = torch.nn.Linear(d_hid, d_hid) 1207*da0073e9SAndroid Build Coastguard Worker 1208*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1209*da0073e9SAndroid Build Coastguard Worker x = torch.mm(x, self.mm_param) 1210*da0073e9SAndroid Build Coastguard Worker x = torch.relu(x) 1211*da0073e9SAndroid Build Coastguard Worker x = torch.mm(x, self.mm_param) 1212*da0073e9SAndroid Build Coastguard Worker x = self.lin(x) 1213*da0073e9SAndroid Build Coastguard Worker x = torch.relu(x) 1214*da0073e9SAndroid Build Coastguard Worker x = torch.mm(x, self.mm_param2) 1215*da0073e9SAndroid Build Coastguard Worker x = self.lin(x) 1216*da0073e9SAndroid Build Coastguard Worker return x 1217*da0073e9SAndroid Build Coastguard Worker 1218*da0073e9SAndroid Build Coastguard Worker my_module = ExampleCode() 1219*da0073e9SAndroid Build Coastguard Worker my_module_traced = symbolic_trace(my_module) 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker part_idx = 0 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker def split_callback(n : torch.fx.Node): 1224*da0073e9SAndroid Build Coastguard Worker nonlocal part_idx 1225*da0073e9SAndroid Build Coastguard Worker if (n.op, n.target) == ('call_module', 'lin'): 1226*da0073e9SAndroid Build Coastguard Worker part_idx += 1 1227*da0073e9SAndroid Build Coastguard Worker return part_idx 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker # split module in module with submodules 1230*da0073e9SAndroid Build Coastguard Worker qualname_map : Dict[str, str] = {} 1231*da0073e9SAndroid Build Coastguard Worker module_with_submodules = split_module( 1232*da0073e9SAndroid Build Coastguard Worker my_module_traced, my_module, split_callback, qualname_map 1233*da0073e9SAndroid Build Coastguard Worker ) 1234*da0073e9SAndroid Build Coastguard Worker expected_qualname_map = { 1235*da0073e9SAndroid Build Coastguard Worker 'submod_1.lin': 'lin', 'submod_2.lin': 'lin' 1236*da0073e9SAndroid Build Coastguard Worker } 1237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(qualname_map, expected_qualname_map) 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker def test_traceable_function_with_nonstandard_name(self): 1240*da0073e9SAndroid Build Coastguard Worker def foo(x): 1241*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 1242*da0073e9SAndroid Build Coastguard Worker 1243*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace_with_rewrite(foo) 1244*da0073e9SAndroid Build Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker def test_to_folder(self): 1246*da0073e9SAndroid Build Coastguard Worker class Test(torch.nn.Module): 1247*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1248*da0073e9SAndroid Build Coastguard Worker super().__init__() 1249*da0073e9SAndroid Build Coastguard Worker self.W = torch.nn.Parameter(torch.randn(2)) 1250*da0073e9SAndroid Build Coastguard Worker self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2)) 1251*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(2, 2) 1252*da0073e9SAndroid Build Coastguard Worker self.attr = torch.randn(2) 1253*da0073e9SAndroid Build Coastguard Worker self.attr2 = torch.nn.Buffer(torch.randn(2)) 1254*da0073e9SAndroid Build Coastguard Worker self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32)) 1255*da0073e9SAndroid Build Coastguard Worker 1256*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1257*da0073e9SAndroid Build Coastguard Worker return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x)) 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker mod = symbolic_trace(Test()) 1260*da0073e9SAndroid Build Coastguard Worker module_name = "Foo" 1261*da0073e9SAndroid Build Coastguard Worker import tempfile 1262*da0073e9SAndroid Build Coastguard Worker from pathlib import Path 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmp_dir: 1265*da0073e9SAndroid Build Coastguard Worker tmp_dir = Path(tmp_dir) 1266*da0073e9SAndroid Build Coastguard Worker mod.to_folder(tmp_dir, module_name) 1267*da0073e9SAndroid Build Coastguard Worker # Recipe taken from here: 1268*da0073e9SAndroid Build Coastguard Worker # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly 1269*da0073e9SAndroid Build Coastguard Worker import importlib.util 1270*da0073e9SAndroid Build Coastguard Worker 1271*da0073e9SAndroid Build Coastguard Worker spec = importlib.util.spec_from_file_location( 1272*da0073e9SAndroid Build Coastguard Worker module_name, tmp_dir / "__init__.py" 1273*da0073e9SAndroid Build Coastguard Worker ) 1274*da0073e9SAndroid Build Coastguard Worker module = importlib.util.module_from_spec(spec) 1275*da0073e9SAndroid Build Coastguard Worker sys.modules[module_name] = module 1276*da0073e9SAndroid Build Coastguard Worker spec.loader.exec_module(module) 1277*da0073e9SAndroid Build Coastguard Worker t = torch.randn(2, 2) 1278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module.Foo()(t), mod(t)) 1279*da0073e9SAndroid Build Coastguard Worker 1280*da0073e9SAndroid Build Coastguard Worker def test_fetch(self): 1281*da0073e9SAndroid Build Coastguard Worker attrs_for_lowering: Dict[str, List[str]] = { 1282*da0073e9SAndroid Build Coastguard Worker "torch.nn.modules.conv.Conv2d": [ 1283*da0073e9SAndroid Build Coastguard Worker "weight", 1284*da0073e9SAndroid Build Coastguard Worker "bias", 1285*da0073e9SAndroid Build Coastguard Worker "kernel_size", 1286*da0073e9SAndroid Build Coastguard Worker "stride", 1287*da0073e9SAndroid Build Coastguard Worker "padding", 1288*da0073e9SAndroid Build Coastguard Worker "dilation", 1289*da0073e9SAndroid Build Coastguard Worker "groups", 1290*da0073e9SAndroid Build Coastguard Worker "padding_mode", 1291*da0073e9SAndroid Build Coastguard Worker ], 1292*da0073e9SAndroid Build Coastguard Worker "torch.nn.modules.batchnorm.BatchNorm2d": [ 1293*da0073e9SAndroid Build Coastguard Worker "weight", 1294*da0073e9SAndroid Build Coastguard Worker "bias", 1295*da0073e9SAndroid Build Coastguard Worker "running_mean", 1296*da0073e9SAndroid Build Coastguard Worker "running_var", 1297*da0073e9SAndroid Build Coastguard Worker "eps", 1298*da0073e9SAndroid Build Coastguard Worker ], 1299*da0073e9SAndroid Build Coastguard Worker } 1300*da0073e9SAndroid Build Coastguard Worker 1301*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 1302*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1303*da0073e9SAndroid Build Coastguard Worker super().__init__() 1304*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(3, 3, 2) 1305*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(3) 1306*da0073e9SAndroid Build Coastguard Worker 1307*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 1308*da0073e9SAndroid Build Coastguard Worker a = self.conv(a) 1309*da0073e9SAndroid Build Coastguard Worker a += a 1310*da0073e9SAndroid Build Coastguard Worker return self.bn(a) 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker mod = TestModule() 1313*da0073e9SAndroid Build Coastguard Worker traced = symbolic_trace(mod) 1314*da0073e9SAndroid Build Coastguard Worker lift_lowering_attrs_to_nodes(traced) 1315*da0073e9SAndroid Build Coastguard Worker 1316*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1317*da0073e9SAndroid Build Coastguard Worker if node.op == "call_module": 1318*da0073e9SAndroid Build Coastguard Worker assert hasattr(node, "attrs_for_lowering") 1319*da0073e9SAndroid Build Coastguard Worker para_list = attrs_for_lowering[node.attrs_for_lowering["name"]] 1320*da0073e9SAndroid Build Coastguard Worker 1321*da0073e9SAndroid Build Coastguard Worker # node.attrs_for_lowering has an addition field of class name 1322*da0073e9SAndroid Build Coastguard Worker assert len(para_list) + 1 == len(node.attrs_for_lowering) 1323*da0073e9SAndroid Build Coastguard Worker for p_name in para_list: 1324*da0073e9SAndroid Build Coastguard Worker assert p_name in node.attrs_for_lowering 1325*da0073e9SAndroid Build Coastguard Worker 1326*da0073e9SAndroid Build Coastguard Worker def test_merge_matmuls(self): 1327*da0073e9SAndroid Build Coastguard Worker """ 1328*da0073e9SAndroid Build Coastguard Worker A collection of test cases for torch.fx.experimental.merge_matmul, 1329*da0073e9SAndroid Build Coastguard Worker a graph transformation that merges matrix multiplication operations. 1330*da0073e9SAndroid Build Coastguard Worker """ 1331*da0073e9SAndroid Build Coastguard Worker # Utility function for counting matmuls for test assertions. 1332*da0073e9SAndroid Build Coastguard Worker def _count_matmuls(mod): 1333*da0073e9SAndroid Build Coastguard Worker gm = torch.fx.symbolic_trace(mod) 1334*da0073e9SAndroid Build Coastguard Worker 1335*da0073e9SAndroid Build Coastguard Worker num_matmuls = 0 1336*da0073e9SAndroid Build Coastguard Worker for node in gm.graph.nodes: 1337*da0073e9SAndroid Build Coastguard Worker if node.target == torch.matmul: 1338*da0073e9SAndroid Build Coastguard Worker num_matmuls += 1 1339*da0073e9SAndroid Build Coastguard Worker 1340*da0073e9SAndroid Build Coastguard Worker return num_matmuls 1341*da0073e9SAndroid Build Coastguard Worker 1342*da0073e9SAndroid Build Coastguard Worker # Simple test case in which there are two matmuls of the same size to merge. 1343*da0073e9SAndroid Build Coastguard Worker class SimpleMergeMatmulModule(torch.nn.Module): 1344*da0073e9SAndroid Build Coastguard Worker def __init__(self, rhs): 1345*da0073e9SAndroid Build Coastguard Worker super().__init__() 1346*da0073e9SAndroid Build Coastguard Worker self.rhs = rhs 1347*da0073e9SAndroid Build Coastguard Worker 1348*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 1349*da0073e9SAndroid Build Coastguard Worker a = torch.matmul(x, self.rhs) 1350*da0073e9SAndroid Build Coastguard Worker b = torch.matmul(y, self.rhs) 1351*da0073e9SAndroid Build Coastguard Worker return a + b 1352*da0073e9SAndroid Build Coastguard Worker 1353*da0073e9SAndroid Build Coastguard Worker # Initialize inputs. 1354*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 3) 1355*da0073e9SAndroid Build Coastguard Worker b = torch.randn(3, 3) 1356*da0073e9SAndroid Build Coastguard Worker 1357*da0073e9SAndroid Build Coastguard Worker # Initialize RHS for matmuls. 1358*da0073e9SAndroid Build Coastguard Worker rhs = torch.randn(3, 4) 1359*da0073e9SAndroid Build Coastguard Worker 1360*da0073e9SAndroid Build Coastguard Worker # Construct SimpleMergeMatmulModule and call merge_matmul on it. 1361*da0073e9SAndroid Build Coastguard Worker module = SimpleMergeMatmulModule(rhs) 1362*da0073e9SAndroid Build Coastguard Worker opt_module = merge_matmul.merge_matmul(module) 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker # Numerical correctness check. 1365*da0073e9SAndroid Build Coastguard Worker before = module(a, b) 1366*da0073e9SAndroid Build Coastguard Worker after = opt_module(a, b) 1367*da0073e9SAndroid Build Coastguard Worker before.allclose(after) 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker # Basic graph structure check; original module should have 2 matmuls 1370*da0073e9SAndroid Build Coastguard Worker # and optimized module should have 1. 1371*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_count_matmuls(module), 2) 1372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_count_matmuls(opt_module), 1) 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker # Test case in which there are multiple matmuls of different sizes to merge. 1375*da0073e9SAndroid Build Coastguard Worker class FiveMergeMatmulModule(torch.nn.Module): 1376*da0073e9SAndroid Build Coastguard Worker def __init__(self, rhs): 1377*da0073e9SAndroid Build Coastguard Worker super().__init__() 1378*da0073e9SAndroid Build Coastguard Worker self.rhs = rhs 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, c, d, e): 1381*da0073e9SAndroid Build Coastguard Worker s = torch.tensor([]) 1382*da0073e9SAndroid Build Coastguard Worker matmuls = [] 1383*da0073e9SAndroid Build Coastguard Worker 1384*da0073e9SAndroid Build Coastguard Worker # For some reason using a list comprehension or for-loop for this 1385*da0073e9SAndroid Build Coastguard Worker # doesn't work. 1386*da0073e9SAndroid Build Coastguard Worker matmuls.append(torch.matmul(a, self.rhs)) 1387*da0073e9SAndroid Build Coastguard Worker matmuls.append(torch.matmul(b, self.rhs)) 1388*da0073e9SAndroid Build Coastguard Worker matmuls.append(torch.matmul(c, self.rhs)) 1389*da0073e9SAndroid Build Coastguard Worker matmuls.append(torch.matmul(d, self.rhs)) 1390*da0073e9SAndroid Build Coastguard Worker matmuls.append(torch.matmul(e, self.rhs)) 1391*da0073e9SAndroid Build Coastguard Worker 1392*da0073e9SAndroid Build Coastguard Worker for m in matmuls: 1393*da0073e9SAndroid Build Coastguard Worker s += torch.sum(m) 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker return s 1396*da0073e9SAndroid Build Coastguard Worker 1397*da0073e9SAndroid Build Coastguard Worker # Initialize inputs. 1398*da0073e9SAndroid Build Coastguard Worker inputs = [torch.randn(2 * i + 1, 5) for i in range(5)] 1399*da0073e9SAndroid Build Coastguard Worker 1400*da0073e9SAndroid Build Coastguard Worker # Initialize RHS. 1401*da0073e9SAndroid Build Coastguard Worker rhs = torch.randn(5, 4) 1402*da0073e9SAndroid Build Coastguard Worker 1403*da0073e9SAndroid Build Coastguard Worker # Construct FiveMergeMatmulModule and call merge_matmul on it. 1404*da0073e9SAndroid Build Coastguard Worker module = FiveMergeMatmulModule(rhs) 1405*da0073e9SAndroid Build Coastguard Worker opt_module = merge_matmul.merge_matmul(module) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker # Numerical correctness check. 1408*da0073e9SAndroid Build Coastguard Worker before = module(*inputs) 1409*da0073e9SAndroid Build Coastguard Worker after = opt_module(*inputs) 1410*da0073e9SAndroid Build Coastguard Worker before.allclose(after) 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker # Basic graph structure check; original module should have len(inputs) matmuls 1413*da0073e9SAndroid Build Coastguard Worker # and optimized module should have 1. 1414*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_count_matmuls(module), len(inputs)) 1415*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_count_matmuls(opt_module), 1) 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker # Simple test case in which two matmuls cannot be merged due to a data dependency between 1418*da0073e9SAndroid Build Coastguard Worker # the LHS operands. 1419*da0073e9SAndroid Build Coastguard Worker class UnmergeableMatmulModule(torch.nn.Module): 1420*da0073e9SAndroid Build Coastguard Worker def __init__(self, rhs): 1421*da0073e9SAndroid Build Coastguard Worker super().__init__() 1422*da0073e9SAndroid Build Coastguard Worker self.rhs = rhs 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1425*da0073e9SAndroid Build Coastguard Worker a = torch.matmul(x, self.rhs) 1426*da0073e9SAndroid Build Coastguard Worker a_abs = torch.abs(a) 1427*da0073e9SAndroid Build Coastguard Worker b = torch.matmul(a_abs.transpose(1, 0), self.rhs) 1428*da0073e9SAndroid Build Coastguard Worker return b 1429*da0073e9SAndroid Build Coastguard Worker 1430*da0073e9SAndroid Build Coastguard Worker # Initialize inputs. 1431*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 3) 1432*da0073e9SAndroid Build Coastguard Worker 1433*da0073e9SAndroid Build Coastguard Worker # Initialize RHS for matmuls. 1434*da0073e9SAndroid Build Coastguard Worker rhs = torch.randn(3, 4) 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker # Construct UnmergeableMatmulModule and call merge_matmul on it. 1437*da0073e9SAndroid Build Coastguard Worker module = UnmergeableMatmulModule(rhs) 1438*da0073e9SAndroid Build Coastguard Worker opt_module = merge_matmul.merge_matmul(module) 1439*da0073e9SAndroid Build Coastguard Worker 1440*da0073e9SAndroid Build Coastguard Worker # Numerical correctness check. 1441*da0073e9SAndroid Build Coastguard Worker before = module(a) 1442*da0073e9SAndroid Build Coastguard Worker after = opt_module(a) 1443*da0073e9SAndroid Build Coastguard Worker before.allclose(after) 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker # Basic graph structure check; the number of matrix multiplcations should not have changed. 1446*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_count_matmuls(module), 2) 1447*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_count_matmuls(opt_module), 2) 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker def test_type_matches(self): 1450*da0073e9SAndroid Build Coastguard Worker should_be_equal = [ 1451*da0073e9SAndroid Build Coastguard Worker (int, int), 1452*da0073e9SAndroid Build Coastguard Worker (numbers.Number, int), 1453*da0073e9SAndroid Build Coastguard Worker (numbers.Number, float), 1454*da0073e9SAndroid Build Coastguard Worker (int, type(torch.float)), 1455*da0073e9SAndroid Build Coastguard Worker (Union[int, float], int), 1456*da0073e9SAndroid Build Coastguard Worker (Union[int, float], float), 1457*da0073e9SAndroid Build Coastguard Worker (List[int], int), 1458*da0073e9SAndroid Build Coastguard Worker (List[int], create_type_hint([int, int])), 1459*da0073e9SAndroid Build Coastguard Worker (List[int], create_type_hint((int, int))), 1460*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])), 1461*da0073e9SAndroid Build Coastguard Worker ( 1462*da0073e9SAndroid Build Coastguard Worker List[torch.Tensor], 1463*da0073e9SAndroid Build Coastguard Worker create_type_hint([torch.nn.Parameter, torch.nn.Parameter]), 1464*da0073e9SAndroid Build Coastguard Worker ), 1465*da0073e9SAndroid Build Coastguard Worker (torch.Tensor, torch.nn.Parameter), 1466*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])), 1467*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])), 1468*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))), 1469*da0073e9SAndroid Build Coastguard Worker ( 1470*da0073e9SAndroid Build Coastguard Worker List[torch.Tensor], 1471*da0073e9SAndroid Build Coastguard Worker create_type_hint((torch.nn.Parameter, torch.nn.Parameter)), 1472*da0073e9SAndroid Build Coastguard Worker ), 1473*da0073e9SAndroid Build Coastguard Worker (torch.Tensor, torch.nn.Parameter), 1474*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))), 1475*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))), 1476*da0073e9SAndroid Build Coastguard Worker (Optional[List[torch.Tensor]], List[torch.Tensor]), 1477*da0073e9SAndroid Build Coastguard Worker (Optional[List[int]], List[int]), 1478*da0073e9SAndroid Build Coastguard Worker ] 1479*da0073e9SAndroid Build Coastguard Worker for sig_type, arg_type in should_be_equal: 1480*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type_matches(sig_type, arg_type)) 1481*da0073e9SAndroid Build Coastguard Worker 1482*da0073e9SAndroid Build Coastguard Worker should_fail = [ 1483*da0073e9SAndroid Build Coastguard Worker (int, float), 1484*da0073e9SAndroid Build Coastguard Worker (Union[int, float], str), 1485*da0073e9SAndroid Build Coastguard Worker (List[torch.Tensor], List[int]), 1486*da0073e9SAndroid Build Coastguard Worker ] 1487*da0073e9SAndroid Build Coastguard Worker 1488*da0073e9SAndroid Build Coastguard Worker for sig_type, arg_type in should_fail: 1489*da0073e9SAndroid Build Coastguard Worker self.assertFalse(type_matches(sig_type, arg_type)) 1490*da0073e9SAndroid Build Coastguard Worker 1491*da0073e9SAndroid Build Coastguard Worker @skipIfNoMkldnn 1492*da0073e9SAndroid Build Coastguard Worker def test_optimize_for_inference_cpu(self): 1493*da0073e9SAndroid Build Coastguard Worker import torch.nn as nn 1494*da0073e9SAndroid Build Coastguard Worker 1495*da0073e9SAndroid Build Coastguard Worker class Foo(nn.Module): 1496*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1497*da0073e9SAndroid Build Coastguard Worker super().__init__() 1498*da0073e9SAndroid Build Coastguard Worker layers = [] 1499*da0073e9SAndroid Build Coastguard Worker layers2 = [] 1500*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 1501*da0073e9SAndroid Build Coastguard Worker layers.append(nn.Conv2d(3, 3, 1)) 1502*da0073e9SAndroid Build Coastguard Worker layers.append(nn.BatchNorm2d(3)) 1503*da0073e9SAndroid Build Coastguard Worker layers.append(nn.ReLU()) 1504*da0073e9SAndroid Build Coastguard Worker 1505*da0073e9SAndroid Build Coastguard Worker layers2.append(nn.Conv2d(3, 3, 1)) 1506*da0073e9SAndroid Build Coastguard Worker layers2.append(nn.BatchNorm2d(3)) 1507*da0073e9SAndroid Build Coastguard Worker layers2.append(nn.ReLU()) 1508*da0073e9SAndroid Build Coastguard Worker self.model = nn.Sequential(*layers) 1509*da0073e9SAndroid Build Coastguard Worker self.model2 = nn.Sequential(*layers2) 1510*da0073e9SAndroid Build Coastguard Worker 1511*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1512*da0073e9SAndroid Build Coastguard Worker return self.model(x) + self.model2(x) 1513*da0073e9SAndroid Build Coastguard Worker 1514*da0073e9SAndroid Build Coastguard Worker N, C, H, W, = ( 1515*da0073e9SAndroid Build Coastguard Worker 1, 1516*da0073e9SAndroid Build Coastguard Worker 3, 1517*da0073e9SAndroid Build Coastguard Worker 224, 1518*da0073e9SAndroid Build Coastguard Worker 224, 1519*da0073e9SAndroid Build Coastguard Worker ) 1520*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(N, C, H, W) 1521*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1522*da0073e9SAndroid Build Coastguard Worker model = Foo().eval() 1523*da0073e9SAndroid Build Coastguard Worker optimized_model = optimization.optimize_for_inference(model) 1524*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(model(inp), optimized_model(inp)) 1525*da0073e9SAndroid Build Coastguard Worker 1526*da0073e9SAndroid Build Coastguard Worker optimized_model2 = optimization.optimize_for_inference( 1527*da0073e9SAndroid Build Coastguard Worker model, pass_config={"remove_dropout": False} 1528*da0073e9SAndroid Build Coastguard Worker ) 1529*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(model(inp), optimized_model2(inp)) 1530*da0073e9SAndroid Build Coastguard Worker 1531*da0073e9SAndroid Build Coastguard Worker @skipIfNoTorchVision 1532*da0073e9SAndroid Build Coastguard Worker @skipIfNoMkldnn 1533*da0073e9SAndroid Build Coastguard Worker def test_optimize_for_inference_cpu_torchvision(self): 1534*da0073e9SAndroid Build Coastguard Worker models = [ 1535*da0073e9SAndroid Build Coastguard Worker torchvision.models.resnet18, 1536*da0073e9SAndroid Build Coastguard Worker torchvision.models.resnet50, 1537*da0073e9SAndroid Build Coastguard Worker torchvision.models.densenet121, 1538*da0073e9SAndroid Build Coastguard Worker torchvision.models.shufflenet_v2_x1_0, 1539*da0073e9SAndroid Build Coastguard Worker torchvision.models.vgg16, 1540*da0073e9SAndroid Build Coastguard Worker torchvision.models.mobilenet_v2, 1541*da0073e9SAndroid Build Coastguard Worker torchvision.models.mnasnet1_0, 1542*da0073e9SAndroid Build Coastguard Worker torchvision.models.resnext50_32x4d, 1543*da0073e9SAndroid Build Coastguard Worker ] 1544*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1545*da0073e9SAndroid Build Coastguard Worker for model_type in models: 1546*da0073e9SAndroid Build Coastguard Worker model = model_type() 1547*da0073e9SAndroid Build Coastguard Worker C, H, W, = ( 1548*da0073e9SAndroid Build Coastguard Worker 3, 1549*da0073e9SAndroid Build Coastguard Worker 224, 1550*da0073e9SAndroid Build Coastguard Worker 224, 1551*da0073e9SAndroid Build Coastguard Worker ) 1552*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, C, H, W) 1553*da0073e9SAndroid Build Coastguard Worker model(inp) 1554*da0073e9SAndroid Build Coastguard Worker model.eval() 1555*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(1, C, H, W) 1556*da0073e9SAndroid Build Coastguard Worker heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0) 1557*da0073e9SAndroid Build Coastguard Worker optimized_model = optimization.optimize_for_inference(model) 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker orig_out = model(inp) 1560*da0073e9SAndroid Build Coastguard Worker new_out = optimized_model(inp) 1561*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(orig_out, new_out) 1562*da0073e9SAndroid Build Coastguard Worker 1563*da0073e9SAndroid Build Coastguard Worker 1564*da0073e9SAndroid Build Coastguard Workerclass TestNormalizeOperators(JitTestCase): 1565*da0073e9SAndroid Build Coastguard Worker @onlyCPU 1566*da0073e9SAndroid Build Coastguard Worker @ops(op_db, allowed_dtypes=(torch.float,)) 1567*da0073e9SAndroid Build Coastguard Worker def test_normalize_operator_exhaustive(self, device, dtype, op): 1568*da0073e9SAndroid Build Coastguard Worker # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors) 1569*da0073e9SAndroid Build Coastguard Worker fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"} 1570*da0073e9SAndroid Build Coastguard Worker sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False) 1571*da0073e9SAndroid Build Coastguard Worker if isinstance(op.op, torch._ops.OpOverload): 1572*da0073e9SAndroid Build Coastguard Worker self.skipTest("normalize operator doesn't work on torch.ops") 1573*da0073e9SAndroid Build Coastguard Worker for sample_input in sample_inputs_itr: 1574*da0073e9SAndroid Build Coastguard Worker unsupported_arg_type = False 1575*da0073e9SAndroid Build Coastguard Worker arg_values = [sample_input.input] + list(sample_input.args) 1576*da0073e9SAndroid Build Coastguard Worker kwarg_values = sample_input.kwargs 1577*da0073e9SAndroid Build Coastguard Worker arg_types = [] 1578*da0073e9SAndroid Build Coastguard Worker kwarg_types = {} 1579*da0073e9SAndroid Build Coastguard Worker 1580*da0073e9SAndroid Build Coastguard Worker def jit_infer_type(v): 1581*da0073e9SAndroid Build Coastguard Worker inferred_arg_type = torch._C._jit_try_infer_type(v) 1582*da0073e9SAndroid Build Coastguard Worker assert inferred_arg_type.success() 1583*da0073e9SAndroid Build Coastguard Worker t = _torchscript_type_to_python_type(inferred_arg_type.type()) 1584*da0073e9SAndroid Build Coastguard Worker return t 1585*da0073e9SAndroid Build Coastguard Worker 1586*da0073e9SAndroid Build Coastguard Worker for v in arg_values: 1587*da0073e9SAndroid Build Coastguard Worker if isinstance(v, torch.Tensor): 1588*da0073e9SAndroid Build Coastguard Worker arg_types.append(type(v)) 1589*da0073e9SAndroid Build Coastguard Worker else: 1590*da0073e9SAndroid Build Coastguard Worker if isinstance(v, complex): 1591*da0073e9SAndroid Build Coastguard Worker # Complex type not supported in FX 1592*da0073e9SAndroid Build Coastguard Worker unsupported_arg_type = True 1593*da0073e9SAndroid Build Coastguard Worker arg_types.append(jit_infer_type(v)) 1594*da0073e9SAndroid Build Coastguard Worker 1595*da0073e9SAndroid Build Coastguard Worker for k, v in kwarg_values.items(): 1596*da0073e9SAndroid Build Coastguard Worker if isinstance(v, torch.Tensor): 1597*da0073e9SAndroid Build Coastguard Worker kwarg_types[k] = type(v) 1598*da0073e9SAndroid Build Coastguard Worker else: 1599*da0073e9SAndroid Build Coastguard Worker if isinstance(v, complex): 1600*da0073e9SAndroid Build Coastguard Worker # Complex type not supported in FX 1601*da0073e9SAndroid Build Coastguard Worker unsupported_arg_type = True 1602*da0073e9SAndroid Build Coastguard Worker kwarg_types[k] = jit_infer_type(v) 1603*da0073e9SAndroid Build Coastguard Worker 1604*da0073e9SAndroid Build Coastguard Worker if unsupported_arg_type: 1605*da0073e9SAndroid Build Coastguard Worker continue 1606*da0073e9SAndroid Build Coastguard Worker # Test normalize_function by itself 1607*da0073e9SAndroid Build Coastguard Worker ref_out = op.op(*arg_values, **kwarg_values) 1608*da0073e9SAndroid Build Coastguard Worker norm_args_and_kwargs = normalize_function( 1609*da0073e9SAndroid Build Coastguard Worker op.op, arg_values, kwarg_values, arg_types, kwarg_types 1610*da0073e9SAndroid Build Coastguard Worker ) 1611*da0073e9SAndroid Build Coastguard Worker if norm_args_and_kwargs is None: 1612*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 1613*da0073e9SAndroid Build Coastguard Worker """ 1614*da0073e9SAndroid Build Coastguard Worker FX failed to normalize op - add the op to the op_skip list. 1615*da0073e9SAndroid Build Coastguard Worker A common reason is if your OpInfo was implemented with a lambda 1616*da0073e9SAndroid Build Coastguard Worker - otherwise, file an issue 1617*da0073e9SAndroid Build Coastguard Worker """ 1618*da0073e9SAndroid Build Coastguard Worker ) 1619*da0073e9SAndroid Build Coastguard Worker test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs) 1620*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_out, ref_out) 1621*da0073e9SAndroid Build Coastguard Worker 1622*da0073e9SAndroid Build Coastguard Worker # Test normalized_arguments as part of FX 1623*da0073e9SAndroid Build Coastguard Worker if op.name in fx_fail: 1624*da0073e9SAndroid Build Coastguard Worker continue 1625*da0073e9SAndroid Build Coastguard Worker param_names = [] 1626*da0073e9SAndroid Build Coastguard Worker param_values = [] 1627*da0073e9SAndroid Build Coastguard Worker fx_args = [] 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build Coastguard Worker idx = 0 1630*da0073e9SAndroid Build Coastguard Worker 1631*da0073e9SAndroid Build Coastguard Worker def process_arg(arg, name): 1632*da0073e9SAndroid Build Coastguard Worker if isinstance(arg, torch.Tensor): 1633*da0073e9SAndroid Build Coastguard Worker param_names.append(name) 1634*da0073e9SAndroid Build Coastguard Worker param_values.append(arg) 1635*da0073e9SAndroid Build Coastguard Worker return name 1636*da0073e9SAndroid Build Coastguard Worker else: 1637*da0073e9SAndroid Build Coastguard Worker return f"{repr(arg)}" 1638*da0073e9SAndroid Build Coastguard Worker 1639*da0073e9SAndroid Build Coastguard Worker def process_arg_with_idx(arg): 1640*da0073e9SAndroid Build Coastguard Worker nonlocal idx 1641*da0073e9SAndroid Build Coastguard Worker res = process_arg(arg, f"arg_{idx}") 1642*da0073e9SAndroid Build Coastguard Worker idx = idx + 1 1643*da0073e9SAndroid Build Coastguard Worker return res 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker def str_arg(arg): 1646*da0073e9SAndroid Build Coastguard Worker if isinstance(arg, tuple): 1647*da0073e9SAndroid Build Coastguard Worker args = [f"{str_arg(v)}, " for v in arg] 1648*da0073e9SAndroid Build Coastguard Worker return f"({' '.join(args)})" 1649*da0073e9SAndroid Build Coastguard Worker elif isinstance(arg, list): 1650*da0073e9SAndroid Build Coastguard Worker args = [f"{str_arg(v)}" for v in arg] 1651*da0073e9SAndroid Build Coastguard Worker return f"[{', '.join(args)}]" 1652*da0073e9SAndroid Build Coastguard Worker else: 1653*da0073e9SAndroid Build Coastguard Worker return arg 1654*da0073e9SAndroid Build Coastguard Worker 1655*da0073e9SAndroid Build Coastguard Worker for v in arg_values: 1656*da0073e9SAndroid Build Coastguard Worker arg = pytree.tree_map(process_arg_with_idx, v) 1657*da0073e9SAndroid Build Coastguard Worker fx_args.append(str_arg(arg)) 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker for k, v in kwarg_values.items(): 1660*da0073e9SAndroid Build Coastguard Worker arg = pytree.tree_map(functools.partial(process_arg, name=k), v) 1661*da0073e9SAndroid Build Coastguard Worker fx_args.append(f"{k} = {str_arg(arg)}") 1662*da0073e9SAndroid Build Coastguard Worker 1663*da0073e9SAndroid Build Coastguard Worker code = f""" 1664*da0073e9SAndroid Build Coastguard Workerclass TestModule(torch.nn.Module): 1665*da0073e9SAndroid Build Coastguard Worker def forward(self, {', '.join(param_names)}): 1666*da0073e9SAndroid Build Coastguard Worker return torch.{op.name}({', '.join(fx_args)}) 1667*da0073e9SAndroid Build Coastguard Worker """ 1668*da0073e9SAndroid Build Coastguard Worker 1669*da0073e9SAndroid Build Coastguard Worker g = {"torch": torch, "inf": math.inf} 1670*da0073e9SAndroid Build Coastguard Worker exec(code, g) 1671*da0073e9SAndroid Build Coastguard Worker TestModule = g["TestModule"] 1672*da0073e9SAndroid Build Coastguard Worker 1673*da0073e9SAndroid Build Coastguard Worker m = TestModule() 1674*da0073e9SAndroid Build Coastguard Worker traced = torch.fx.symbolic_trace(m) 1675*da0073e9SAndroid Build Coastguard Worker ref_out = traced(*param_values) 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker for node in traced.graph.nodes: 1678*da0073e9SAndroid Build Coastguard Worker if node.op == "call_function": 1679*da0073e9SAndroid Build Coastguard Worker normalized_args = node.normalized_arguments( 1680*da0073e9SAndroid Build Coastguard Worker traced, arg_types, kwarg_types 1681*da0073e9SAndroid Build Coastguard Worker ) 1682*da0073e9SAndroid Build Coastguard Worker assert normalized_args 1683*da0073e9SAndroid Build Coastguard Worker node.args = normalized_args.args 1684*da0073e9SAndroid Build Coastguard Worker node.kwargs = normalized_args.kwargs 1685*da0073e9SAndroid Build Coastguard Worker traced.recompile() 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker test_out = traced(*param_values) 1688*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_out, ref_out) 1689*da0073e9SAndroid Build Coastguard Worker 1690*da0073e9SAndroid Build Coastguard Worker def test_normalize_quantized_eb(self): 1691*da0073e9SAndroid Build Coastguard Worker target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets 1692*da0073e9SAndroid Build Coastguard Worker args = ( 1693*da0073e9SAndroid Build Coastguard Worker torch.empty((2, 3), dtype=torch.uint8), 1694*da0073e9SAndroid Build Coastguard Worker torch.empty((2,), dtype=torch.int64), 1695*da0073e9SAndroid Build Coastguard Worker torch.empty((2,), dtype=torch.int64), 1696*da0073e9SAndroid Build Coastguard Worker ) 1697*da0073e9SAndroid Build Coastguard Worker norm_args_and_kwargs = normalize_function( 1698*da0073e9SAndroid Build Coastguard Worker target, args, normalize_to_only_use_kwargs=True 1699*da0073e9SAndroid Build Coastguard Worker ) 1700*da0073e9SAndroid Build Coastguard Worker self.assertTrue(norm_args_and_kwargs is not None) 1701*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1702*da0073e9SAndroid Build Coastguard Worker set(norm_args_and_kwargs.kwargs.keys()), 1703*da0073e9SAndroid Build Coastguard Worker { 1704*da0073e9SAndroid Build Coastguard Worker "weight", 1705*da0073e9SAndroid Build Coastguard Worker "indices", 1706*da0073e9SAndroid Build Coastguard Worker "offsets", 1707*da0073e9SAndroid Build Coastguard Worker "scale_grad_by_freq", 1708*da0073e9SAndroid Build Coastguard Worker "mode", 1709*da0073e9SAndroid Build Coastguard Worker "pruned_weights", 1710*da0073e9SAndroid Build Coastguard Worker "per_sample_weights", 1711*da0073e9SAndroid Build Coastguard Worker "compressed_indices_mapping", 1712*da0073e9SAndroid Build Coastguard Worker "include_last_offset", 1713*da0073e9SAndroid Build Coastguard Worker }, 1714*da0073e9SAndroid Build Coastguard Worker ) 1715*da0073e9SAndroid Build Coastguard Worker self.assertEqual(norm_args_and_kwargs.args, ()) 1716*da0073e9SAndroid Build Coastguard Worker 1717*da0073e9SAndroid Build Coastguard Worker def test_normalize_args_op_overload(self): 1718*da0073e9SAndroid Build Coastguard Worker for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]: 1719*da0073e9SAndroid Build Coastguard Worker inp1 = torch.rand([1]) 1720*da0073e9SAndroid Build Coastguard Worker inp2 = torch.rand([4]) 1721*da0073e9SAndroid Build Coastguard Worker args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True) 1722*da0073e9SAndroid Build Coastguard Worker self.assertIs(kwargs["input"], inp1) 1723*da0073e9SAndroid Build Coastguard Worker self.assertIs(kwargs["the_template"], inp2) 1724*da0073e9SAndroid Build Coastguard Worker 1725*da0073e9SAndroid Build Coastguard Worker 1726*da0073e9SAndroid Build Coastguard Workerif TEST_Z3: 1727*da0073e9SAndroid Build Coastguard Worker import z3 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker import torch._dynamo.config 1730*da0073e9SAndroid Build Coastguard Worker 1731*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str 1732*da0073e9SAndroid Build Coastguard Worker from torch.utils._sympy.functions import FloorDiv, Mod 1733*da0073e9SAndroid Build Coastguard Worker 1734*da0073e9SAndroid Build Coastguard Worker class TestTranslationValidation(TestCase): 1735*da0073e9SAndroid Build Coastguard Worker def _prepare_for_translation_validation(self): 1736*da0073e9SAndroid Build Coastguard Worker validator = TranslationValidator() 1737*da0073e9SAndroid Build Coastguard Worker 1738*da0073e9SAndroid Build Coastguard Worker # SymPy symbols. 1739*da0073e9SAndroid Build Coastguard Worker s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True) 1740*da0073e9SAndroid Build Coastguard Worker 1741*da0073e9SAndroid Build Coastguard Worker # Z3 symbols. 1742*da0073e9SAndroid Build Coastguard Worker [validator.add_var(s, int) for s in (s0, s1, s2)] 1743*da0073e9SAndroid Build Coastguard Worker z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2)) 1744*da0073e9SAndroid Build Coastguard Worker 1745*da0073e9SAndroid Build Coastguard Worker return (s0, s1, s2), (z0, z1, z2), validator 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker def test_sympy_to_z3(self): 1748*da0073e9SAndroid Build Coastguard Worker 1749*da0073e9SAndroid Build Coastguard Worker ( 1750*da0073e9SAndroid Build Coastguard Worker (s0, s1, s2), 1751*da0073e9SAndroid Build Coastguard Worker (z0, z1, z2), 1752*da0073e9SAndroid Build Coastguard Worker validator, 1753*da0073e9SAndroid Build Coastguard Worker ) = self._prepare_for_translation_validation() 1754*da0073e9SAndroid Build Coastguard Worker 1755*da0073e9SAndroid Build Coastguard Worker test_cases = [ 1756*da0073e9SAndroid Build Coastguard Worker # Integer constants. 1757*da0073e9SAndroid Build Coastguard Worker (sympy.S.Zero, z3.IntVal(0)), 1758*da0073e9SAndroid Build Coastguard Worker (sympy.S.One, z3.IntVal(1)), 1759*da0073e9SAndroid Build Coastguard Worker (sympy.S.NegativeOne, z3.IntVal(-1)), 1760*da0073e9SAndroid Build Coastguard Worker (sympy.Integer(2), z3.IntVal(2)), 1761*da0073e9SAndroid Build Coastguard Worker ( 1762*da0073e9SAndroid Build Coastguard Worker s0, 1763*da0073e9SAndroid Build Coastguard Worker z0, 1764*da0073e9SAndroid Build Coastguard Worker ), 1765*da0073e9SAndroid Build Coastguard Worker # Arithmetic operations. 1766*da0073e9SAndroid Build Coastguard Worker *[ 1767*da0073e9SAndroid Build Coastguard Worker (op(s0, s1), op(z0, z1)) 1768*da0073e9SAndroid Build Coastguard Worker for op in ( 1769*da0073e9SAndroid Build Coastguard Worker operator.add, 1770*da0073e9SAndroid Build Coastguard Worker operator.mul, 1771*da0073e9SAndroid Build Coastguard Worker operator.pow, 1772*da0073e9SAndroid Build Coastguard Worker ) 1773*da0073e9SAndroid Build Coastguard Worker ], 1774*da0073e9SAndroid Build Coastguard Worker # Logical operations. 1775*da0073e9SAndroid Build Coastguard Worker *[ 1776*da0073e9SAndroid Build Coastguard Worker (sympy_op(s0, s1), z3_op(z0, z1)) 1777*da0073e9SAndroid Build Coastguard Worker for sympy_op, z3_op in ( 1778*da0073e9SAndroid Build Coastguard Worker (sympy.Eq, operator.eq), 1779*da0073e9SAndroid Build Coastguard Worker (sympy.Ne, operator.ne), 1780*da0073e9SAndroid Build Coastguard Worker (sympy.Lt, operator.lt), 1781*da0073e9SAndroid Build Coastguard Worker (sympy.Le, operator.le), 1782*da0073e9SAndroid Build Coastguard Worker (sympy.Gt, operator.gt), 1783*da0073e9SAndroid Build Coastguard Worker (sympy.Ge, operator.ge), 1784*da0073e9SAndroid Build Coastguard Worker ) 1785*da0073e9SAndroid Build Coastguard Worker ], 1786*da0073e9SAndroid Build Coastguard Worker # Other operations. 1787*da0073e9SAndroid Build Coastguard Worker ( 1788*da0073e9SAndroid Build Coastguard Worker s0 - s1, 1789*da0073e9SAndroid Build Coastguard Worker z0 + z3.IntVal(-1) * z1, 1790*da0073e9SAndroid Build Coastguard Worker ), 1791*da0073e9SAndroid Build Coastguard Worker ( 1792*da0073e9SAndroid Build Coastguard Worker s0 / s1, 1793*da0073e9SAndroid Build Coastguard Worker z3.ToReal(z0) * (z1**-1), 1794*da0073e9SAndroid Build Coastguard Worker ), 1795*da0073e9SAndroid Build Coastguard Worker (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))), 1796*da0073e9SAndroid Build Coastguard Worker (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1), 1797*da0073e9SAndroid Build Coastguard Worker ( 1798*da0073e9SAndroid Build Coastguard Worker Mod(s2, (s0 / s1)), 1799*da0073e9SAndroid Build Coastguard Worker z2 1800*da0073e9SAndroid Build Coastguard Worker - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1))) 1801*da0073e9SAndroid Build Coastguard Worker * (z3.ToReal(z0) * z1**-1), 1802*da0073e9SAndroid Build Coastguard Worker ), 1803*da0073e9SAndroid Build Coastguard Worker ( 1804*da0073e9SAndroid Build Coastguard Worker Mod(s2, s0**3), 1805*da0073e9SAndroid Build Coastguard Worker z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3, 1806*da0073e9SAndroid Build Coastguard Worker ), 1807*da0073e9SAndroid Build Coastguard Worker ] 1808*da0073e9SAndroid Build Coastguard Worker 1809*da0073e9SAndroid Build Coastguard Worker toZ3 = SympyToZ3(validator) 1810*da0073e9SAndroid Build Coastguard Worker for sympy_expr, z3_expr in test_cases: 1811*da0073e9SAndroid Build Coastguard Worker result = toZ3.run(sympy_expr) 1812*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1813*da0073e9SAndroid Build Coastguard Worker z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}" 1814*da0073e9SAndroid Build Coastguard Worker ) 1815*da0073e9SAndroid Build Coastguard Worker 1816*da0073e9SAndroid Build Coastguard Worker def test_sat(self): 1817*da0073e9SAndroid Build Coastguard Worker ( 1818*da0073e9SAndroid Build Coastguard Worker (s0, s1, s2), 1819*da0073e9SAndroid Build Coastguard Worker (z0, z1, z2), 1820*da0073e9SAndroid Build Coastguard Worker validator, 1821*da0073e9SAndroid Build Coastguard Worker ) = self._prepare_for_translation_validation() 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker validator.add_source_expr(z0 > 5) 1824*da0073e9SAndroid Build Coastguard Worker validator.add_source_expr(z1 / 2 > z0) 1825*da0073e9SAndroid Build Coastguard Worker 1826*da0073e9SAndroid Build Coastguard Worker # Solutions for target is a subset of the solutions for the source. 1827*da0073e9SAndroid Build Coastguard Worker validator.add_target_expr(s0 > 20) 1828*da0073e9SAndroid Build Coastguard Worker validator.add_target_expr(s1 > s0**2) 1829*da0073e9SAndroid Build Coastguard Worker 1830*da0073e9SAndroid Build Coastguard Worker validator.validate() 1831*da0073e9SAndroid Build Coastguard Worker 1832*da0073e9SAndroid Build Coastguard Worker def test_unsat(self): 1833*da0073e9SAndroid Build Coastguard Worker ( 1834*da0073e9SAndroid Build Coastguard Worker (s0, s1, s2), 1835*da0073e9SAndroid Build Coastguard Worker (z0, z1, z2), 1836*da0073e9SAndroid Build Coastguard Worker validator, 1837*da0073e9SAndroid Build Coastguard Worker ) = self._prepare_for_translation_validation() 1838*da0073e9SAndroid Build Coastguard Worker 1839*da0073e9SAndroid Build Coastguard Worker validator.add_source_expr(z0 > 5) 1840*da0073e9SAndroid Build Coastguard Worker validator.add_source_expr(z1 / 2 > z0) 1841*da0073e9SAndroid Build Coastguard Worker 1842*da0073e9SAndroid Build Coastguard Worker # Solutions for target is NOT a subset of the solutions for the source. 1843*da0073e9SAndroid Build Coastguard Worker validator.add_target_expr(s0 > 20) 1844*da0073e9SAndroid Build Coastguard Worker # This expression is less restrictive than its counterpart. 1845*da0073e9SAndroid Build Coastguard Worker validator.add_target_expr(s1 > s0 + 2) 1846*da0073e9SAndroid Build Coastguard Worker 1847*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValidationException, "translation validation failed."): 1848*da0073e9SAndroid Build Coastguard Worker validator.validate() 1849*da0073e9SAndroid Build Coastguard Worker 1850*da0073e9SAndroid Build Coastguard Worker def test_z3str(self): 1851*da0073e9SAndroid Build Coastguard Worker a = z3.Int("a") 1852*da0073e9SAndroid Build Coastguard Worker b = z3.Int("b") 1853*da0073e9SAndroid Build Coastguard Worker special = z3.Real("this.size()[2]") 1854*da0073e9SAndroid Build Coastguard Worker 1855*da0073e9SAndroid Build Coastguard Worker test_cases = [ 1856*da0073e9SAndroid Build Coastguard Worker (z3.IntVal(42), "42"), 1857*da0073e9SAndroid Build Coastguard Worker # Variable. 1858*da0073e9SAndroid Build Coastguard Worker (a, "a"), 1859*da0073e9SAndroid Build Coastguard Worker # Name with special characters. 1860*da0073e9SAndroid Build Coastguard Worker (special, "this.size()[2]"), 1861*da0073e9SAndroid Build Coastguard Worker # Renamed function fpplications. 1862*da0073e9SAndroid Build Coastguard Worker (a != b, "(!= a b)"), 1863*da0073e9SAndroid Build Coastguard Worker (a ** b, "(pow a b)"), 1864*da0073e9SAndroid Build Coastguard Worker # Chain of associative operations. 1865*da0073e9SAndroid Build Coastguard Worker *[ 1866*da0073e9SAndroid Build Coastguard Worker (op(op(a, 5), b), f"({opstr} 5 a b)") 1867*da0073e9SAndroid Build Coastguard Worker for op, opstr in [ 1868*da0073e9SAndroid Build Coastguard Worker (operator.add, "+"), 1869*da0073e9SAndroid Build Coastguard Worker (operator.mul, "*") 1870*da0073e9SAndroid Build Coastguard Worker ] 1871*da0073e9SAndroid Build Coastguard Worker ], 1872*da0073e9SAndroid Build Coastguard Worker # Revert 'Not' conversions. 1873*da0073e9SAndroid Build Coastguard Worker (a != b, "(!= a b)"), 1874*da0073e9SAndroid Build Coastguard Worker (a < b, "(> b a)"), 1875*da0073e9SAndroid Build Coastguard Worker (a > b, "(> a b)"), 1876*da0073e9SAndroid Build Coastguard Worker # Ignore 'ToInt' and 'ToReal' functions. 1877*da0073e9SAndroid Build Coastguard Worker (z3.ToInt(special) + a, "(+ this.size()[2] a)"), 1878*da0073e9SAndroid Build Coastguard Worker (z3.ToReal(a + b), "(+ a b)"), 1879*da0073e9SAndroid Build Coastguard Worker # Convert to floor division: 'idiv'. 1880*da0073e9SAndroid Build Coastguard Worker (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"), 1881*da0073e9SAndroid Build Coastguard Worker ] 1882*da0073e9SAndroid Build Coastguard Worker 1883*da0073e9SAndroid Build Coastguard Worker for expr, expected in test_cases: 1884*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z3str(expr), expected) 1885*da0073e9SAndroid Build Coastguard Worker 1886*da0073e9SAndroid Build Coastguard Worker 1887*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestNormalizeOperators, globals()) 1888*da0073e9SAndroid Build Coastguard Worker 1889*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1890*da0073e9SAndroid Build Coastguard Worker run_tests() 1891