1 #include <ATen/core/jit_type.h>
2 #include <ATen/native/xnnpack/OpContext.h>
3
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/passes/constant_propagation.h>
6 #include <torch/csrc/jit/passes/fold_conv_bn.h>
7 #include <torch/csrc/jit/passes/freeze_module.h>
8 #include <torch/csrc/jit/passes/fuse_linear.h>
9 #include <torch/csrc/jit/passes/fuse_relu.h>
10 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
11 #include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
12 #include <torch/csrc/jit/passes/mobile_optimizer_type.h>
13 #include <torch/csrc/jit/passes/prepack_folding.h>
14 #include <torch/csrc/jit/passes/remove_dropout.h>
15 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
16 #include <torch/csrc/jit/passes/xnnpack_rewrite.h>
17 #include <torch/csrc/jit/runtime/graph_executor_impl.h>
18
19 namespace torch::jit {
20
21 namespace {
22
replaceConv1dWithConv2d(std::shared_ptr<Graph> & graph)23 void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
24 std::string conv_1d_pattern = R"(
25 graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
26 %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
27 return (%res) )";
28
29 std::string conv_2d_pattern = R"(
30 graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
31 %zero : int = prim::Constant[value=0]()
32 %one : int = prim::Constant[value=1]()
33 %stride_w : int = prim::ListUnpack(%stride)
34 %stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
35 %padding_w : int = prim::ListUnpack(%padding)
36 %padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
37 %dilation_w : int = prim::ListUnpack(%dilation)
38 %dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
39 %two : int = prim::Constant[value=2]()
40 %input_2d : Tensor = aten::unsqueeze(%input, %two)
41 %weight_2d : Tensor = aten::unsqueeze(%weight, %two)
42 %output_2d = aten::conv2d(
43 %input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
44 %output : Tensor = aten::squeeze(%output_2d, %two)
45 return (%output) )";
46
47 std::vector<std::pair<std::string, std::string>> value_mappings(
48 {{"zero", "res"},
49 {"one", "res"},
50 {"stride_w", "res"},
51 {"stride_2d", "res"},
52 {"padding_w", "res"},
53 {"padding_2d", "res"},
54 {"dilation_w", "res"},
55 {"dilation_2d", "res"},
56 {"two", "res"},
57 {"input_2d", "res"},
58 {"weight_2d", "res"},
59 {"output_2d", "res"},
60 {"output", "res"}});
61
62 SubgraphRewriter rewriter;
63 rewriter.RegisterRewritePattern(
64 conv_1d_pattern, conv_2d_pattern, value_mappings);
65 rewriter.runOnGraph(graph);
66 }
67
68 } // namespace
69
transformConv1dToConv2d(std::shared_ptr<Graph> & graph)70 void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
71 // Replace _convolution with conv1d and conv2d
72 graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
73 replaceConv1dWithConv2d(graph);
74 }
75
transformConv1dToConv2d(script::Module & module)76 void transformConv1dToConv2d(script::Module& module) {
77 for (auto& method : module.get_methods()) {
78 auto graph = method.graph();
79 transformConv1dToConv2d(graph);
80 }
81 for (script::Module m : module.children()) {
82 transformConv1dToConv2d(m);
83 }
84 }
85
86 #ifdef USE_XNNPACK
87
88 namespace {
89
insertPrePackedLinearOp(std::shared_ptr<Graph> & graph)90 void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
91 // fuse decomposed linear into aten::linear
92 FuseLinear(graph);
93
94 std::string linear_pattern = R"(
95 graph(%input, %weight, %bias):
96 %res = aten::linear(%input, %weight, %bias)
97 return (%res))";
98 std::string prepacked_ops_pattern = R"(
99 graph(%input, %weight, %bias):
100 %output_min_max : None = prim::Constant()
101 %packed_weight_bias = prepacked::linear_clamp_prepack(
102 %weight, %bias, %output_min_max, %output_min_max)
103 %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
104 return (%res))";
105
106 std::vector<std::pair<std::string, std::string>> value_mappings(
107 {{"output_min_max", "res"},
108 {"packed_weight_bias", "res"},
109 {"res", "res"}});
110
111 SubgraphRewriter linear_rewriter;
112 linear_rewriter.RegisterRewritePattern(
113 linear_pattern, prepacked_ops_pattern, value_mappings);
114 linear_rewriter.runOnGraph(graph);
115 }
116
insertPrePackedConv2dOp(std::shared_ptr<Graph> & graph)117 void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
118 // Replace _convolution with conv2d
119 graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
120
121 std::string conv_2d_pattern = R"(
122 graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
123 %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
124 return (%res) )";
125
126 std::string prepacked_ops_conv2d_pattern = R"(
127 graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
128 %output_min_max : None = prim::Constant()
129 %packed_weight_bias = prepacked::conv2d_clamp_prepack(
130 %weight, %bias, %stride, %padding, %dilation, %groups,
131 %output_min_max, %output_min_max)
132 %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
133 return (%res) )";
134
135 std::vector<std::pair<std::string, std::string>> value_mappings(
136 {{"output_min_max", "res"},
137 {"packed_weight_bias", "res"},
138 {"res", "res"}});
139
140 SubgraphRewriter rewriter;
141 rewriter.RegisterRewritePattern(
142 conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings);
143 rewriter.runOnGraph(graph);
144
145 std::string conv_2d_transpose_pattern = R"(
146 graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
147 %output_padding:int[], %groups:int):
148 %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
149 return (%res) )";
150
151 std::string prepacked_ops_conv2d_transpose_pattern = R"(
152 graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
153 %output_min_max : None = prim::Constant()
154 %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
155 %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
156 %output_min_max, %output_min_max)
157 %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
158 return (%res) )";
159
160 value_mappings = {
161 {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};
162
163 SubgraphRewriter transpose_rewriter;
164 transpose_rewriter.RegisterRewritePattern(
165 conv_2d_transpose_pattern,
166 prepacked_ops_conv2d_transpose_pattern,
167 value_mappings);
168 transpose_rewriter.runOnGraph(graph);
169 }
170
fuseHardtanhWithPackedOps(std::shared_ptr<Graph> & graph)171 void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
172 SubgraphRewriter rewriter;
173
174 std::string linear_prepack_run_hardtanh_fused = R"(
175 graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
176 %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
177 %weight, %bias, %output_min, %output_max)
178 %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
179 return (%res))";
180
181 std::string conv2d_prepack_run_hardtanh_fused = R"(
182 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
183 %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
184 %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
185 %weight, %bias, %stride, %padding, %dilation, %groups,
186 %output_min, %output_max)
187 %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
188 return (%res) )";
189
190 std::string linear_prepack_run_hardtanh = R"(
191 graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
192 %packed_weight_bias = prepacked::linear_clamp_prepack(
193 %weight, %bias, %dummy_min_max, %dummy_min_max)
194 %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
195 %res = aten::hardtanh(%linear_res, %output_min, %output_max)
196 return (%res))";
197
198 std::vector<std::pair<std::string, std::string>> value_mappings(
199 {{"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}});
200
201 rewriter.RegisterRewritePattern(
202 linear_prepack_run_hardtanh,
203 linear_prepack_run_hardtanh_fused,
204 value_mappings);
205
206 std::string conv2d_prepack_run_hardtanh = R"(
207 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
208 %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
209 %packed_weight_bias = prepacked::conv2d_clamp_prepack(
210 %weight, %bias, %stride, %padding, %dilation, %groups,
211 %dummy_min_max, %dummy_min_max)
212 %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
213 %res = aten::hardtanh(%conv2d_res, %output_min, %output_max)
214 return (%res) )";
215
216 value_mappings = {
217 {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
218
219 rewriter.RegisterRewritePattern(
220 conv2d_prepack_run_hardtanh,
221 conv2d_prepack_run_hardtanh_fused,
222 value_mappings);
223
224 std::string linear_prepack_run_hardtanh_inplace = R"(
225 graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
226 %packed_weight_bias = prepacked::linear_clamp_prepack(
227 %weight, %bias, %dummy_min_max, %dummy_min_max)
228 %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
229 %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
230 return (%res))";
231
232 std::string conv2d_prepack_run_hardtanh_inplace = R"(
233 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
234 %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
235 %packed_weight_bias = prepacked::conv2d_clamp_prepack(
236 %weight, %bias, %stride, %padding, %dilation, %groups,
237 %dummy_min_max, %dummy_min_max)
238 %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
239 %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
240 return (%res) )";
241
242 value_mappings = {
243 {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
244
245 rewriter.RegisterRewritePattern(
246 linear_prepack_run_hardtanh_inplace,
247 linear_prepack_run_hardtanh_fused,
248 value_mappings);
249
250 value_mappings = {
251 {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
252
253 rewriter.RegisterRewritePattern(
254 conv2d_prepack_run_hardtanh_inplace,
255 conv2d_prepack_run_hardtanh_fused,
256 value_mappings);
257
258 rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
259 }
260
fuseReluWithPackedOps(std::shared_ptr<Graph> & graph)261 void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
262 SubgraphRewriter rewriter;
263
264 std::string linear_prepack_run_relu_fused = R"(
265 graph(%input, %weight, %bias, %dummy_min_max):
266 %output_min: float = prim::Constant[value=0.0]()
267 %output_max: None = prim::Constant()
268 %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
269 %weight, %bias, %output_min, %output_max)
270 %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
271 return (%res))";
272
273 std::string conv2d_prepack_run_relu_fused = R"(
274 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
275 %dilation:int[], %groups:int, %dummy_min_max):
276 %output_min: float = prim::Constant[value=0.0]()
277 %output_max: None = prim::Constant()
278 %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
279 %weight, %bias, %stride, %padding, %dilation, %groups,
280 %output_min, %output_max)
281 %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
282 return (%res) )";
283
284 std::string linear_prepack_run_relu = R"(
285 graph(%input, %weight, %bias, %dummy_min_max):
286 %packed_weight_bias = prepacked::linear_clamp_prepack(
287 %weight, %bias, %dummy_min_max, %dummy_min_max)
288 %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
289 %res = aten::relu(%linear_res)
290 return (%res))";
291
292 std::vector<std::pair<std::string, std::string>> value_mappings(
293 {{"output_min", "packed_weight_bias"},
294 {"output_max", "packed_weight_bias"},
295 {"packed_weight_bias", "packed_weight_bias"},
296 {"res", "res"}});
297
298 rewriter.RegisterRewritePattern(
299 linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings);
300
301 std::string conv2d_prepack_run_relu = R"(
302 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
303 %dilation:int[], %groups:int, %dummy_min_max):
304 %packed_weight_bias = prepacked::conv2d_clamp_prepack(
305 %weight, %bias, %stride, %padding, %dilation, %groups,
306 %dummy_min_max, %dummy_min_max)
307 %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
308 %res = aten::relu(%conv2d_res)
309 return (%res) )";
310
311 value_mappings = {
312 {"output_min", "packed_weight_bias"},
313 {"output_max", "packed_weight_bias"},
314 {"packed_weight_bias", "packed_weight_bias"},
315 {"res", "res"}};
316
317 rewriter.RegisterRewritePattern(
318 conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings);
319
320 std::string linear_prepack_run_relu_inplace = R"(
321 graph(%input, %weight, %bias, %dummy_min_max):
322 %packed_weight_bias = prepacked::linear_clamp_prepack(
323 %weight, %bias, %dummy_min_max, %dummy_min_max)
324 %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
325 %res = aten::relu_(%linear_res)
326 return (%res))";
327
328 std::string conv2d_prepack_run_relu_inplace = R"(
329 graph(%input, %weight, %bias, %stride:int[], %padding:int[],
330 %dilation:int[], %groups:int, %dummy_min_max):
331 %packed_weight_bias = prepacked::conv2d_clamp_prepack(
332 %weight, %bias, %stride, %padding, %dilation, %groups,
333 %dummy_min_max, %dummy_min_max)
334 %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
335 %res = aten::relu_(%conv2d_res)
336 return (%res) )";
337
338 value_mappings = {
339 {"output_min", "packed_weight_bias"},
340 {"output_max", "packed_weight_bias"},
341 {"packed_weight_bias", "packed_weight_bias"},
342 {"res", "res"}};
343
344 rewriter.RegisterRewritePattern(
345 linear_prepack_run_relu_inplace,
346 linear_prepack_run_relu_fused,
347 value_mappings);
348
349 value_mappings = {
350 {"output_min", "packed_weight_bias"},
351 {"output_max", "packed_weight_bias"},
352 {"packed_weight_bias", "packed_weight_bias"},
353 {"res", "res"}};
354
355 rewriter.RegisterRewritePattern(
356 conv2d_prepack_run_relu_inplace,
357 conv2d_prepack_run_relu_fused,
358 value_mappings);
359 rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
360 }
361
runCanonicalOptimizations(script::Module & module)362 void runCanonicalOptimizations(script::Module& module) {
363 for (const auto& method : module.get_methods()) {
364 auto graph = method.graph();
365 // Not sure if we have models running on mobile that require loop unrolling.
366 // Perhaps language/speech models? Conservatively setting that to false.
367 runOptimization(graph, false /* no loop unrolling */);
368 }
369 }
370
371 } // namespace
372
insertPrePackedOps(std::shared_ptr<Graph> & graph)373 void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
374 insertPrePackedLinearOp(graph);
375 insertPrePackedConv2dOp(graph);
376 }
377
insertPrePackedOps(script::Module & module)378 void insertPrePackedOps(script::Module& module) {
379 for (auto& method : module.get_methods()) {
380 auto graph = method.graph();
381 insertPrePackedOps(graph);
382 }
383 for (script::Module m : module.children()) {
384 insertPrePackedOps(m);
385 }
386 }
387
fusePrePackedLinearConvWithClamp(script::Module & module)388 void fusePrePackedLinearConvWithClamp(script::Module& module) {
389 for (auto& method : module.get_methods()) {
390 auto graph = method.graph();
391 fuseReluWithPackedOps(graph);
392 fuseHardtanhWithPackedOps(graph);
393
394 // Ignore user defined classes for later passes
395 ConstantPropagation(graph, true);
396 }
397 }
398
FoldPrePackingOps(script::Module & m)399 void FoldPrePackingOps(script::Module& m) {
400 PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
401 return (
402 (n->kind() ==
403 Symbol::fromQualString("prepacked::linear_clamp_prepack")) ||
404 n->kind() ==
405 Symbol::fromQualString("prepacked::conv2d_clamp_prepack") ||
406 n->kind() ==
407 Symbol::fromQualString(
408 "prepacked::conv2d_transpose_clamp_prepack"));
409 };
410 PrePackingOpsFolder(m, filter_fn, "prepack_folding");
411 for (auto& method : m.get_methods()) {
412 auto graph = method.graph();
413 // Folding requires a const propagation through user defined classes
414 ConstantPropagation(graph, false);
415 }
416 }
417
// Returns an inference-optimized clone of `m` for mobile deployment.
// Each stage runs unless its MobileOptimizerType appears in
// `optimization_blocklist`; `preserved_methods` are kept alive across the
// freeze_module() calls. The input module itself is never modified.
// NOTE: pass ordering matters — see the inline comments on each stage.
script::Module optimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  // Work on a clone in eval mode; eval() also makes removeDropout valid below.
  auto cloned_module = m.clone();
  cloned_module.eval();

  if (!optimization_blocklist.count(MobileOptimizerType::CONV_1D_TO_2D)) {
    transformConv1dToConv2d(cloned_module);
  }

  if (!optimization_blocklist.count(MobileOptimizerType::CONV_BN_FUSION)) {
    cloned_module = FoldConvBatchNorm(cloned_module);
  }

  // Many optimizations require a frozen module, but ConvBatchNorm requires
  // an unfrozen module
  cloned_module = freeze_module(cloned_module, preserved_methods);

  if (!optimization_blocklist.count(
          MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)) {
    // TODO fix duplication caused by referencing same op across multiple
    // functions
    insertPrePackedOps(cloned_module);
    // Re-freeze so the newly inserted prepack ops see constant weights,
    // then fuse activations and fold the prepack ops into attributes.
    cloned_module = freeze_module(cloned_module, preserved_methods);
    fusePrePackedLinearConvWithClamp(cloned_module);
    FoldPrePackingOps(cloned_module);
  }

  if (!optimization_blocklist.count(
          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS) &&
      cloned_module.find_method("forward")) {
    // freeze again in case it was not done in previous optional passes
    cloned_module = freeze_module(cloned_module, preserved_methods);
    HoistConvPackedParams(cloned_module);
    // and freeze yet again to remove the empty QuantizedConv modules
    cloned_module = freeze_module(cloned_module, preserved_methods);
  }

  // Run canonical optimizations post freezing
  // since freezing inlines the graph. Otherwise we
  // will have to explicitly call Inlining pass.
  runCanonicalOptimizations(cloned_module);

  if (!optimization_blocklist.count(MobileOptimizerType::REMOVE_DROPOUT)) {
    for (const auto& method : cloned_module.get_methods()) {
      auto graph = method.graph();
      // Module must not be in training mode, but optimize calls eval() above.
      removeDropout(graph);
    }
  }

  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
    for (const auto& method : cloned_module.get_methods()) {
      auto graph = method.graph();
      FuseAddRelu(graph);
    }
  }
  // Tag the result so downstream consumers can tell it has been optimized.
  cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
  return cloned_module;
}
479
480 #else
481
// Stub for builds without XNNPACK: fail loudly rather than silently skip.
void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
486
// Stub for builds without XNNPACK: fail loudly rather than silently skip.
void insertPrePackedOps(script::Module& module) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
491
// Stub for builds without XNNPACK: fail loudly rather than silently skip.
void fusePrePackedLinearConvWithClamp(script::Module& module) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
496
// Stub for builds without XNNPACK: fail loudly rather than silently skip.
void FoldPrePackingOps(script::Module& m) {
  TORCH_INTERNAL_ASSERT(
      false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
}
501
// Stub for builds without XNNPACK: mobile optimization is unavailable, so
// assert; the unreachable return keeps the signature well-formed.
script::Module optimizeForMobile(
    const script::Module& module,
    const std::set<MobileOptimizerType>& blocklist,
    const std::vector<std::string>& preserved_methods) {
  TORCH_INTERNAL_ASSERT(
      false,
      "Mobile optimization only available with XNNPACK at the moment. "
      "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
  return module;
}
512
513 #endif
514 } // namespace torch::jit
515