#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch::jit {

namespace {

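// Rewrites aten::linear calls into a metal_prepack::linear_prepack /
// metal_prepack::linear_run pair so the Metal backend can pack the weight
// and bias ahead of time.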
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %output_min_max, %output_min_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

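// Normalizes convolutions to aten::conv2d, then rewrites them into a
// metal_prepack::conv2d_prepack / metal_prepack::conv2d_run pair.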
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);
}

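// Folds aten::relu / aten::relu_ that directly follows a prepacked linear or
// conv2d run into the prepack op by encoding the clamp as output_min = 0 with
// an unbounded output_max.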
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

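  // Fused replacement: the clamp bounds live on the packed op context, so the
  // standalone relu node is dropped from the graph.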
  std::string linear_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu(%linear_res)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu, linear_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu(%r)
        return (%r) )";

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string linear_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu_(%linear_res)
        return (%res))";

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu_(%r)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

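// Folds aten::hardtanh / aten::hardtanh_ that directly follows a prepacked
// linear or conv2d run into the prepack op by forwarding its min/max bounds.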
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

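  // Fused replacement: %output_min and %output_max are forwarded to the
  // prepack op, so the standalone hardtanh node is dropped from the graph.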
  std::string linear_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(%weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh(%r, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%r, %output_min, %output_max)
        return (%r) )";

  std::string linear_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

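// Inserts Metal prepack/run ops for both linear and conv2d on a single graph.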
void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
}

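// Applies the prepack rewrite to every method of the module and recurses into
// its child modules.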
void metalInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    metalInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    metalInsertPrePackedOps(m);
  }
}

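// Folds the metal_prepack::*_prepack ops so the packed weights are computed
// once at optimization time rather than on every forward call.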
void metalFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("metal_prepack::conv2d_prepack")) ||
        (n->kind() == Symbol::fromQualString("metal_prepack::linear_prepack")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

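// Fuses ReLU and hardtanh activations that follow prepacked linear/conv2d ops
// in the module's forward graph.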
void metalFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

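// Rewrites in-place tensor mutation in the forward graph into out-of-place
// equivalents where possible.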
static void metalRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

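// Runs the canonical JIT graph optimizations on the forward graph with loop
// unrolling disabled; among other things this pools duplicated constants.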
static void metalRunCanonicalOptimizations(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  runOptimization(graph, false /* no loop unrolling */);
}

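// Returns a Metal-ready copy of the module: folds Conv2d + BatchNorm, inserts
// and folds prepacked ops, fuses trailing clamps, removes dropout and
// mutation, and tags the result with an `optimized_for_metal` attribute.
// The input module itself is left untouched.
//
// Minimal usage sketch (file names are illustrative, not part of this pass):
//   auto m = torch::jit::load("model.pt");
//   auto optimized =
//       torch::jit::metalOptimizeForMobile(m, /*preserved_methods=*/{});
//   optimized._save_for_mobile("model_metal.ptl");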
script::Module metalOptimizeForMobile(
    const script::Module& m,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  metalInsertPrePackedOps(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  metalFusePrePackedConvWithClamp(cloned_module);
  metalFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  metalRemoveMutation(cloned_module);
  // remove duplicated constants
  metalRunCanonicalOptimizations(cloned_module);
  cloned_module.register_attribute(
      "optimized_for_metal", BoolType::get(), true);
  return cloned_module;
}

} // namespace torch::jit