#include <ATen/core/jit_type.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch::jit {

namespace {

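// Replace aten::linear with a metal_prepack::linear_prepack /
// metal_prepack::linear_run pair so that weight and bias can be
// prepacked for the Metal backend.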
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %output_min_max, %output_min_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

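// Normalize aten::_convolution to aten::conv2d, then replace aten::conv2d
// with a metal_prepack::conv2d_prepack / metal_prepack::conv2d_run pair.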
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);
}

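// Fuse aten::relu / aten::relu_ following a prepacked linear or conv2d run
// into the prepack op by setting output_min to 0 (output_max stays None).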
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(
            %weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu(%linear_res)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu, linear_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::relu(%r)
        return (%r) )";

  std::string conv2d_prepack_run_relu_fused = R"(
  graph(%input, %weight, %bias, %stride:int[], %padding:int[],
        %dilation:int[], %groups:int, %dummy_min_max):
      %output_min: float = prim::Constant[value=0.0]()
      %output_max: None = prim::Constant()
      %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
          %weight, %bias, %stride, %padding, %dilation, %groups,
          %output_min, %output_max)
      %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
      return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string linear_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::relu_(%linear_res)
        return (%res))";

  std::string conv2d_prepack_run_relu_inplace = R"(
  graph(%input, %weight, %bias, %stride:int[], %padding:int[],
        %dilation:int[], %groups:int, %dummy_min_max):
      %packed_weight_bias = metal_prepack::conv2d_prepack(
          %weight, %bias, %stride, %padding, %dilation, %groups,
          %dummy_min_max, %dummy_min_max)
      %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
      %r = aten::relu_(%r)
      return (%r) )";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_relu_inplace, linear_prepack_run_relu_fused);
  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

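// Fuse aten::hardtanh / aten::hardtanh_ following a prepacked linear or
// conv2d run into the prepack op by forwarding the clamp bounds as
// output_min / output_max.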
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string linear_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.metal.LinearOpContext = metal_prepack::linear_prepack(%weight, %bias, %output_min, %output_max)
        %res = metal_prepack::linear_run(%input, %packed_weight_bias)
        return (%res))";

  std::string linear_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh, linear_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh(%r, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%r, %output_min, %output_max)
        return (%r) )";

  std::string linear_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = metal_prepack::linear_prepack(
            %weight, %bias, %dummy_min_max, %dummy_min_max)
        %linear_res = metal_prepack::linear_run(%input, %packed_weight_bias)
        %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
        return (%res))";

  rewriter.RegisterRewritePattern(
      linear_prepack_run_hardtanh_inplace, linear_prepack_run_hardtanh_fused);

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

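// Insert Metal prepack/run ops for linear and conv2d nodes in a graph.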
void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedConv2dOp(graph);
}

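// Insert the prepacked ops in every method of the module and, recursively,
// in all of its submodules.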
void metalInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    metalInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    metalInsertPrePackedOps(m);
  }
}

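// Fold metal_prepack::conv2d_prepack / linear_prepack ops with constant
// inputs into module attributes, so packing is done once ahead of time
// rather than on every forward call.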
void metalFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("metal_prepack::conv2d_prepack")) ||
        (n->kind() == Symbol::fromQualString("metal_prepack::linear_prepack")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

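// Fuse relu/hardtanh clamps that follow prepacked linear/conv2d runs in the
// module's forward method.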
void metalFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

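// Rewrite in-place (mutating) tensor ops in forward to their functional
// equivalents where possible.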
static void metalRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

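// Run the standard graph optimizations on forward, with loop unrolling
// disabled.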
static void metalRunCanonicalOptimizations(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  runOptimization(graph, false /* no loop unrolling */);
}

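// Produce a Metal-optimized clone of the module: fold conv+batchnorm, insert
// and fold prepacked ops, freeze, fuse clamps, remove dropout and mutation,
// run canonical optimizations, and tag the result with the
// `optimized_for_metal` attribute.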
script::Module metalOptimizeForMobile(
    const script::Module& m,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  metalInsertPrePackedOps(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  metalFusePrePackedConvWithClamp(cloned_module);
  metalFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  metalRemoveMutation(cloned_module);
  // remove duplicated constants
  metalRunCanonicalOptimizations(cloned_module);
  cloned_module.register_attribute(
      "optimized_for_metal", BoolType::get(), true);
  return cloned_module;
}

} // namespace torch::jit