1# Owner(s): ["module: fx"] 2 3import unittest 4 5import torch 6import torch.fx 7from torch.testing._internal.common_utils import TestCase 8 9 10class MyModuleBase(torch.nn.Module): 11 def forward(self, x): 12 matrx = self.get_mul_matrix() 13 if self.no_relu(): 14 return torch.mm(x, matrx) 15 else: 16 return torch.relu(torch.mm(x, matrx)) 17 18 def get_mul_matrix(self): 19 return self.param 20 21 def no_relu(self): 22 raise Exception("not implemented") # noqa: TRY002 23 24 25class MyModuleParamShape(MyModuleBase): 26 def __init__(self, in_channels): 27 super().__init__() 28 self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) 29 30 def no_relu(self): 31 return self.param.shape[0] < 10 32 33 34class MyModuleParamSize(MyModuleBase): 35 def __init__(self, in_channels): 36 super().__init__() 37 self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) 38 39 def no_relu(self): 40 return self.param.size()[0] < 10 41 42 43class MyModuleParamDim(MyModuleBase): 44 def __init__(self, param): 45 super().__init__() 46 self.param = param 47 48 def get_mul_matrix(self): 49 return self.param[0] if (self.param.dim() == 3) else self.param 50 51 def no_relu(self): 52 return self.param.dim() == 3 53 54 55class MyModuleParamNDim(MyModuleBase): 56 def __init__(self, param): 57 super().__init__() 58 self.param = param 59 60 def get_mul_matrix(self): 61 return self.param[0] if (self.param.ndim == 3) else self.param 62 63 def no_relu(self): 64 return self.param.ndim == 3 65 66 67class MyModuleParamNumEl(MyModuleBase): 68 def __init__(self, in_channels): 69 super().__init__() 70 self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) 71 72 def no_relu(self): 73 return self.param.numel() < 10 * 3 74 75 76class MyModuleParamNElement(MyModuleBase): 77 def __init__(self, in_channels): 78 super().__init__() 79 self.param = torch.nn.Parameter(torch.randn(in_channels, 3)) 80 81 def no_relu(self): 82 return self.param.nelement() < 10 * 3 83 84 85class TestConstParamShapeInControlFlow(TestCase): 86 def verify_mm_relu_mods(self, mm_only_mod, relu_mod): 87 """ 88 Verify one module only does a mm op while the other 89 performs both mm and relu ops in cascade 90 """ 91 x = torch.randn(10, 5) 92 torch.testing.assert_close( 93 mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix()) 94 ) 95 tracer = torch.fx.Tracer(param_shapes_constant=True) 96 traced_graph = tracer.trace(mm_only_mod) 97 98 # verify the graph module calculates the same result 99 graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph) 100 torch.testing.assert_close( 101 graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix()) 102 ) 103 104 # Make a new module with different parameter shape to go down the different 105 # code path 106 x = torch.randn(10, 15) 107 torch.testing.assert_close( 108 relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())) 109 ) 110 111 tracer2 = torch.fx.Tracer(param_shapes_constant=True) 112 traced_graph2 = tracer2.trace(relu_mod) 113 114 # verify the graph module calculates the same result 115 graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2) 116 torch.testing.assert_close( 117 graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())) 118 ) 119 120 graph1_node_targets = [n.target for n in traced_graph.nodes] 121 graph2_node_targets = [n.target for n in traced_graph2.nodes] 122 123 # the second graph has an exta relu function call node 124 assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets 125 assert ( 126 torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets 127 ) 128 129 def test_param_shape_const(self): 130 mymod = MyModuleParamShape(in_channels=5) 131 mymod2 = MyModuleParamShape(in_channels=15) 132 self.verify_mm_relu_mods(mymod, mymod2) 133 134 def test_param_size_const(self): 135 mymod = MyModuleParamSize(in_channels=5) 136 mymod2 = MyModuleParamSize(in_channels=15) 137 self.verify_mm_relu_mods(mymod, mymod2) 138 139 def test_param_dim_const(self): 140 mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3))) 141 mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3))) 142 self.verify_mm_relu_mods(mymod, mymod2) 143 144 def test_param_ndim_const(self): 145 mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3))) 146 mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3))) 147 self.verify_mm_relu_mods(mymod, mymod2) 148 149 def test_param_numel_const(self): 150 mymod = MyModuleParamNumEl(in_channels=5) 151 mymod2 = MyModuleParamNumEl(in_channels=15) 152 self.verify_mm_relu_mods(mymod, mymod2) 153 154 def test_param_nelement_const(self): 155 mymod = MyModuleParamNElement(in_channels=5) 156 mymod2 = MyModuleParamNElement(in_channels=15) 157 self.verify_mm_relu_mods(mymod, mymod2) 158 159 160if __name__ == "__main__": 161 unittest.main() 162