1# Owner(s): ["oncall: mobile"] 2 3import torch 4import torch._C 5import torch.nn.functional as F 6from torch.testing._internal.common_utils import skipIfNoXNNPACK 7from torch.testing._internal.jit_utils import JitTestCase 8 9 10class TestOptimizeForMobilePreserveDebugInfo(JitTestCase): 11 def check_replacement( 12 self, 13 model, 14 replacements, 15 jit_pass, 16 ): 17 """ 18 model: Model which optimization is performed on 19 replacements: Dict mapping from nodes' kinds in the optimized model 20 to the kinds of nodes they replaced in the original model 21 jit_pass: Function to perform optimization 22 """ 23 24 original_kinds = set(replacements.values()) 25 original_source_ranges = { 26 node.kind(): node.sourceRange() 27 for node in model.graph.nodes() 28 if node.kind() in original_kinds 29 } 30 31 jit_pass(model._c) 32 33 for node in model.graph.nodes(): 34 if node.kind() in replacements: 35 self.assertEqual( 36 node.sourceRange(), 37 original_source_ranges[replacements[node.kind()]], 38 ) 39 40 @skipIfNoXNNPACK 41 def test_replace_conv1d_with_conv2d(self): 42 class TestConv1d(torch.nn.Module): 43 def __init__(self, weight, bias): 44 super().__init__() 45 self.weight = weight 46 self.bias = bias 47 48 def forward(self, x): 49 return F.conv1d(x, self.weight, self.bias) 50 51 self.check_replacement( 52 model=torch.jit.script( 53 TestConv1d( 54 weight=torch.rand(3, 3, 3), 55 bias=torch.rand(3), 56 ), 57 ), 58 replacements={ 59 "prim::ListUnpack": "aten::conv1d", 60 "prim::ListConstruct": "aten::conv1d", 61 "aten::unsqueeze": "aten::conv1d", 62 "aten::conv2d": "aten::conv1d", 63 "aten::squeeze": "aten::conv1d", 64 }, 65 jit_pass=torch._C._jit_pass_transform_conv1d_to_conv2d, 66 ) 67 68 @skipIfNoXNNPACK 69 def test_insert_pre_packed_linear_before_inline_and_conv_2d_op(self): 70 class TestPrepackedLinearBeforeInlineAndConv2dOp(torch.nn.Module): 71 def __init__( 72 self, 73 linear_weight, 74 linear_bias, 75 conv2d_weight, 76 conv2d_bias, 77 conv_transpose2d_weight, 78 conv_transpose2d_bias, 79 ): 80 super( 81 TestPrepackedLinearBeforeInlineAndConv2dOp, 82 self, 83 ).__init__() 84 self.linear_weight = linear_weight.float() 85 self.linear_bias = linear_bias.float() 86 self.conv2d_weight = conv2d_weight.float() 87 self.conv2d_bias = conv2d_bias.float() 88 self.conv_transpose2d_weight = conv_transpose2d_weight.float() 89 self.conv_transpose2d_bias = conv_transpose2d_bias.float() 90 91 def forward(self, x): 92 linear_res = F.linear( 93 x.float(), 94 self.linear_weight, 95 self.linear_bias, 96 ) 97 conv2d_res = F.conv2d( 98 input=linear_res.unsqueeze(dim=0).float(), 99 weight=self.conv2d_weight, 100 bias=self.conv2d_bias, 101 ) 102 return F.conv_transpose2d( 103 input=conv2d_res, 104 weight=self.conv_transpose2d_weight, 105 bias=self.conv_transpose2d_bias, 106 ) 107 108 minibatch = 1 109 in_channels = 6 110 iH = 4 111 iW = 5 112 out_channels = 6 113 kH = 2 114 kW = 3 115 116 self.check_replacement( 117 model=torch.jit.script( 118 TestPrepackedLinearBeforeInlineAndConv2dOp( 119 linear_weight=torch.rand(iW, 3), 120 linear_bias=torch.rand(iW), 121 conv2d_weight=torch.rand(out_channels, in_channels, kH, kW), 122 conv2d_bias=torch.rand(out_channels), 123 conv_transpose2d_weight=torch.rand( 124 out_channels, 125 in_channels, 126 kH, 127 kW, 128 ), 129 conv_transpose2d_bias=torch.rand(out_channels), 130 ), 131 ), 132 replacements={ 133 "prepacked::linear_clamp_prepack": "aten::linear", 134 "prepacked::linear_clamp_run": "aten::linear", 135 "prepacked::conv2d_clamp_prepack": "aten::conv2d", 136 "prepacked::conv2d_clamp_run": "aten::conv2d", 137 "prepacked::conv2d_transpose_clamp_prepack": "aten::conv_transpose2d", 138 "prepacked::conv2d_transpose_clamp_run": "aten::conv_transpose2d", 139 }, 140 jit_pass=torch._C._jit_pass_insert_prepacked_ops, 141 ) 142 143 @skipIfNoXNNPACK 144 def test_insert_pre_packed_linear_op(self): 145 self.check_replacement( 146 model=torch.jit.trace(torch.nn.Linear(5, 4), torch.rand(3, 2, 5)), 147 replacements={ 148 "prepacked::linear_clamp_prepack": "aten::linear", 149 "prepacked::linear_clamp_run": "aten::linear", 150 }, 151 jit_pass=torch._C._jit_pass_insert_prepacked_ops, 152 ) 153 154 def run_test_fuse_activation_with_pack_ops_linear_conv2d( 155 self, 156 linear_activation, 157 linear_activation_kind, 158 conv2d_activation, 159 conv2d_activation_kind, 160 ): 161 class TestFuseActivationLinearConv2d(torch.nn.Module): 162 def __init__( 163 self, 164 linear_weight, 165 linear_bias, 166 conv2d_weight, 167 conv2d_bias, 168 ): 169 super().__init__() 170 self.linear_weight = linear_weight 171 self.linear_bias = linear_bias 172 self.conv2d_weight = conv2d_weight 173 self.conv2d_bias = conv2d_bias 174 175 def forward(self, x): 176 x = F.linear( 177 input=x, 178 weight=self.linear_weight, 179 bias=self.linear_bias, 180 ) 181 x = linear_activation(x) 182 x = F.conv2d( 183 input=x.unsqueeze(dim=0), 184 weight=self.conv2d_weight, 185 bias=self.conv2d_bias, 186 ) 187 return conv2d_activation(x) 188 189 linear_in_features = 5 190 linear_out_features = 4 191 conv2d_in_channels = 3 192 conv2d_out_channels = 4 193 conv2d_kernel = 2 194 x_shape = (3, 2, 5) 195 196 model = torch.jit.trace( 197 TestFuseActivationLinearConv2d( 198 linear_weight=torch.nn.Parameter( 199 data=torch.rand( 200 linear_out_features, 201 linear_in_features, 202 ), 203 requires_grad=False, 204 ), 205 linear_bias=torch.nn.Parameter( 206 data=torch.rand(linear_out_features), 207 requires_grad=False, 208 ), 209 conv2d_weight=torch.rand( 210 conv2d_out_channels, 211 conv2d_in_channels, 212 conv2d_kernel, 213 conv2d_kernel, 214 ), 215 conv2d_bias=torch.rand(conv2d_out_channels), 216 ), 217 torch.rand(x_shape), 218 ) 219 220 torch._C._jit_pass_insert_prepacked_ops(model._c) 221 222 self.check_replacement( 223 model=model, 224 replacements={ 225 "prepacked::linear_clamp_prepack": "prepacked::linear_clamp_prepack", 226 "prepacked::linear_clamp_run": linear_activation_kind, 227 "prepacked::conv2d_clamp_prepack": "prepacked::conv2d_clamp_prepack", 228 "prepacked::conv2d_clamp_run": conv2d_activation_kind, 229 }, 230 jit_pass=torch._C._jit_pass_fuse_clamp_w_prepacked_linear_conv, 231 ) 232 233 @skipIfNoXNNPACK 234 def test_fuse_activation_with_pack_ops_linear_conv2d_1(self): 235 self.run_test_fuse_activation_with_pack_ops_linear_conv2d( 236 linear_activation=F.hardtanh, 237 linear_activation_kind="aten::hardtanh", 238 conv2d_activation=F.hardtanh_, 239 conv2d_activation_kind="aten::hardtanh_", 240 ) 241 242 @skipIfNoXNNPACK 243 def test_fuse_activation_with_pack_ops_linear_conv2d_2(self): 244 self.run_test_fuse_activation_with_pack_ops_linear_conv2d( 245 linear_activation=F.hardtanh_, 246 linear_activation_kind="aten::hardtanh_", 247 conv2d_activation=F.hardtanh, 248 conv2d_activation_kind="aten::hardtanh", 249 ) 250 251 @skipIfNoXNNPACK 252 def test_fuse_activation_with_pack_ops_linear_conv2d_3(self): 253 self.run_test_fuse_activation_with_pack_ops_linear_conv2d( 254 linear_activation=F.relu, 255 linear_activation_kind="aten::relu", 256 conv2d_activation=F.relu_, 257 conv2d_activation_kind="aten::relu_", 258 ) 259 260 @skipIfNoXNNPACK 261 def test_fuse_activation_with_pack_ops_linear_conv2d_4(self): 262 self.run_test_fuse_activation_with_pack_ops_linear_conv2d( 263 linear_activation=F.relu_, 264 linear_activation_kind="aten::relu_", 265 conv2d_activation=F.relu, 266 conv2d_activation_kind="aten::relu", 267 ) 268