xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/vulkan_rewrite.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/jit_type.h>
2 #include <torch/csrc/jit/ir/ir.h>
3 #include <torch/csrc/jit/ir/subgraph_matcher.h>
4 #include <torch/csrc/jit/passes/dead_code_elimination.h>
5 #include <torch/csrc/jit/passes/fold_conv_bn.h>
6 #include <torch/csrc/jit/passes/freeze_module.h>
7 #include <torch/csrc/jit/passes/fuse_linear.h>
8 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
9 #include <torch/csrc/jit/passes/prepack_folding.h>
10 #include <torch/csrc/jit/passes/remove_dropout.h>
11 #include <torch/csrc/jit/passes/remove_mutation.h>
12 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
13 #include <torch/csrc/jit/passes/vulkan_rewrite.h>
14 #include <torch/csrc/jit/runtime/graph_executor_impl.h>
15 
16 namespace torch::jit {
17 
18 namespace {
19 
insertPrePackedBatchNormOp(std::shared_ptr<Graph> & graph)20 void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
21   std::string batchnorm_pattern = R"(
22     graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
23         %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
24         return (%r))";
25   std::string prepacked_ops_pattern = R"(
26     graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
27         %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
28             %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
29         %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
30         return (%res))";
31 
32   SubgraphRewriter batchnorm_rewriter;
33   batchnorm_rewriter.RegisterRewritePattern(
34       batchnorm_pattern, prepacked_ops_pattern);
35   batchnorm_rewriter.runOnGraph(graph);
36 }
37 
insertPrePackedLinearOp(std::shared_ptr<Graph> & graph)38 void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
39   // fuse decomposed linear into aten::linear
40   FuseLinear(graph);
41 
42   std::string linear_pattern = R"(
43     graph(%input, %weight, %bias):
44         %r = aten::linear(%input, %weight, %bias)
45         return (%r))";
46   std::string prepacked_ops_pattern = R"(
47     graph(%input, %weight, %bias):
48         %weight_t = aten::t(%weight)
49         %packed_weight_bias = vulkan_prepack::create_linear_context(
50             %weight_t, %bias)
51         %res = vulkan_prepack::run_linear_context(%input, %packed_weight_bias)
52         return (%res))";
53 
54   SubgraphRewriter linear_rewriter;
55   linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
56   linear_rewriter.runOnGraph(graph);
57 }
58 
insertPrePackedLayernormOp(std::shared_ptr<Graph> & graph)59 void insertPrePackedLayernormOp(std::shared_ptr<Graph>& graph) {
60   std::string layernorm_pattern = R"(
61     graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
62         %r = aten::layer_norm(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable)
63         return (%r))";
64   std::string prepacked_ops_pattern = R"(
65     graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
66         %op_context : __torch__.torch.classes.vulkan.LayernormPackedContext = vulkan_prepack::create_layernorm_context(
67             %weight, %bias, %eps)
68         %res = vulkan_prepack::run_layernorm_context(%input, %normalized_shape, %op_context)
69         return (%res))";
70 
71   SubgraphRewriter layernorm_rewriter;
72   layernorm_rewriter.RegisterRewritePattern(
73       layernorm_pattern, prepacked_ops_pattern);
74   layernorm_rewriter.runOnGraph(graph);
75 }
76 
insertPrePackedConv2dOp(std::shared_ptr<Graph> & graph)77 void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
78   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
79 
80   std::string conv_2d_pattern = R"(
81     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
82         %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
83         return (%r) )";
84 
85   std::string prepacked_ops_conv2d_pattern = R"(
86     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
87         %output_min_max : None = prim::Constant()
88         %packed_weight_bias = vulkan_prepack::create_conv2d_context(
89             %weight, %bias, %stride, %padding, %dilation, %groups,
90             %output_min_max, %output_min_max)
91         %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
92         return (%r) )";
93 
94   SubgraphRewriter rewriter;
95   rewriter.RegisterRewritePattern(
96       conv_2d_pattern, prepacked_ops_conv2d_pattern);
97   rewriter.runOnGraph(graph);
98 
99   std::string conv_2d_transpose_pattern = R"(
100       graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
101           %output_padding:int[], %groups:int):
102         %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
103         return (%res) )";
104 
105   std::string prepacked_ops_conv2d_transpose_pattern = R"(
106     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
107         %output_min_max : None = prim::Constant()
108         %packed_weight_bias = vulkan_prepack::create_tconv2d_context(
109             %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
110             %output_min_max, %output_min_max)
111         %res = vulkan_prepack::run_tconv2d_context(%input, %packed_weight_bias)
112         return (%res) )";
113 
114   SubgraphRewriter transpose_rewriter;
115   transpose_rewriter.RegisterRewritePattern(
116       conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
117   transpose_rewriter.runOnGraph(graph);
118 }
119 
insertPrePackedConv1dOp(std::shared_ptr<Graph> & graph)120 void insertPrePackedConv1dOp(std::shared_ptr<Graph>& graph) {
121   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
122 
123   std::string conv_1d_pattern = R"(
124     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
125         %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
126         return (%r) )";
127 
128   std::string prepacked_ops_conv1d_pattern = R"(
129     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
130         %packed_weight_bias = vulkan_prepack::create_conv1d_context(
131             %weight, %bias, %stride, %padding, %dilation, %groups)
132         %r = vulkan_prepack::run_conv1d_context(%input, %packed_weight_bias)
133         return (%r) )";
134 
135   SubgraphRewriter rewriter;
136   rewriter.RegisterRewritePattern(
137       conv_1d_pattern, prepacked_ops_conv1d_pattern);
138   rewriter.runOnGraph(graph);
139 }
140 
transferInputOutputBackends(std::shared_ptr<Graph> & graph)141 void transferInputOutputBackends(std::shared_ptr<Graph>& graph) {
142   // Move inputs to Vulkan backend
143   for (Value* input : graph->inputs()) {
144     NamedValue named_input = NamedValue("", input);
145     if (named_input.type()->kind() == TypeKind::TensorType &&
146         !input->uses().empty()) {
147       // find the insertion point
148       WithInsertPoint ip(input->uses()[0].user->prev());
149       Value* replaced_input = graph->insert(
150           Symbol::fromQualString("aten::to"), {named_input, "vulkan"});
151       // replace the input
152       input->replaceAllUsesAfterNodeWith(
153           replaced_input->node(), replaced_input);
154     }
155   }
156 
157   // Move outputs to CPU backend
158   at::ArrayRef<Value*>&& outputs = graph->outputs();
159   for (size_t i = 0; i < outputs.size(); i++) {
160     Value* output = outputs[i];
161     NamedValue named_output = NamedValue("", output);
162     if (named_output.type()->kind() == TypeKind::TensorType) {
163       // find the insertion point
164       WithInsertPoint ip(output->node()->next());
165       Value* replaced_output = graph->insert(
166           Symbol::fromQualString("aten::to"), {named_output, "cpu"});
167       // replace the output
168       graph->block()->replaceOutput(i, replaced_output);
169     }
170   }
171 
172   SubgraphRewriter rewriter;
173   rewriter.runOnGraph(graph);
174 }
175 
transferInputOutputBackends(script::Module & module)176 void transferInputOutputBackends(script::Module& module) {
177   std::shared_ptr<Graph> graph = module.get_methods()[0].graph();
178   transferInputOutputBackends(graph);
179 }
180 
eliminateDeadCode(script::Module & module)181 void eliminateDeadCode(script::Module& module) {
182   for (auto& method : module.get_methods()) {
183     EliminateDeadCode(method.graph());
184   }
185 }
186 
rewriteQuantizedOps(std::shared_ptr<Graph> & graph)187 void rewriteQuantizedOps(std::shared_ptr<Graph>& graph) {
188   // quantized::add
189   std::string quantized_add_pattern = R"(
190     graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
191       %res = quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
192       return (%res) )";
193   std::string vk_quantized_add_pattern = R"(
194     graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
195       %res = vulkan_quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
196       return (%res) )";
197 
198   torch::jit::SubgraphRewriter quantized_add_rewriter;
199   quantized_add_rewriter.RegisterRewritePattern(
200       quantized_add_pattern, vk_quantized_add_pattern);
201   quantized_add_rewriter.runOnGraph(graph);
202 
203   // quantized::mul
204   std::string quantized_mul_pattern = R"(
205     graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
206       %res = quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
207       return (%res) )";
208   std::string vk_quantized_mul_pattern = R"(
209     graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
210       %res = vulkan_quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
211       return (%res) )";
212 
213   torch::jit::SubgraphRewriter quantized_mul_rewriter;
214   quantized_mul_rewriter.RegisterRewritePattern(
215       quantized_mul_pattern, vk_quantized_mul_pattern);
216   quantized_mul_rewriter.runOnGraph(graph);
217 
218   // quantized::conv2d
219   std::string quantized_conv2d_pattern = R"(
220     graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
221       %res = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
222       return (%res) )";
223   std::string vk_quantized_conv2d_pattern = R"(
224     graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
225       %output_min_max : None = prim::Constant()
226       %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
227         %packed_params, %output_min_max, %output_min_max)
228       %res = vulkan_prepack::run_qconv2d_context(
229         %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
230       return (%res) )";
231 
232   torch::jit::SubgraphRewriter quantized_conv2d_rewriter;
233   quantized_conv2d_rewriter.RegisterRewritePattern(
234       quantized_conv2d_pattern, vk_quantized_conv2d_pattern);
235   quantized_conv2d_rewriter.runOnGraph(graph);
236 
237   // quantized::conv_transpose2d
238   std::string quantized_conv_transpose2d_pattern = R"(
239     graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
240       %res = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
241       return (%res) )";
242   std::string vk_quantized_conv_transpose2d_pattern = R"(
243     graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
244       %output_min_max : None = prim::Constant()
245       %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qtconv2d_context(
246         %packed_params, %output_min_max, %output_min_max)
247       %res = vulkan_prepack::run_qconv2d_context(
248         %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
249       return (%res) )";
250 
251   torch::jit::SubgraphRewriter quantized_conv_transpose2d_rewriter;
252   quantized_conv_transpose2d_rewriter.RegisterRewritePattern(
253       quantized_conv_transpose2d_pattern,
254       vk_quantized_conv_transpose2d_pattern);
255   quantized_conv_transpose2d_rewriter.runOnGraph(graph);
256 
257   // quantized::conv2d_relu
258   std::string quantized_conv2d_relu_pattern = R"(
259     graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
260       %res = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
261       return (%res) )";
262   std::string vk_quantized_conv2d_relu_pattern = R"(
263     graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
264       %output_min: float = prim::Constant[value=0.0]()
265       %output_max: None = prim::Constant()
266       %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
267         %packed_params, %output_min, %output_max)
268       %res = vulkan_prepack::run_qconv2d_context(
269         %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
270       return (%res) )";
271 
272   torch::jit::SubgraphRewriter quantized_conv2d_relu_rewriter;
273   quantized_conv2d_relu_rewriter.RegisterRewritePattern(
274       quantized_conv2d_relu_pattern, vk_quantized_conv2d_relu_pattern);
275   quantized_conv2d_relu_rewriter.runOnGraph(graph);
276 
277   // quantized::linear
278   std::string quantized_linear_pattern = R"(
279     graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
280       %res = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
281       return (%res) )";
282   std::string vk_quantized_linear_pattern = R"(
283     graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
284       %vk_packed_params : __torch__.torch.classes.vulkan.LinearPackedContext = vulkan_quantized_prepack::convert_linear_context(
285         %packed_params)
286       %res = vulkan_prepack::run_qlinear_context(
287         %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
288       return (%res) )";
289 
290   torch::jit::SubgraphRewriter quantized_linear_rewriter;
291   quantized_linear_rewriter.RegisterRewritePattern(
292       quantized_linear_pattern, vk_quantized_linear_pattern);
293   quantized_linear_rewriter.runOnGraph(graph);
294 }
295 
insertPrePackedGruOp(std::shared_ptr<Graph> & graph)296 void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) {
297   std::string gru_pattern = R"(
298       graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
299         %y.1 : Tensor, %hn.1 : Tensor = aten::gru(%input.1, %hx.1, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
300         return (%y.1, %hn.1) )";
301   std::string prepacked_ops_pattern = R"(
302       graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
303         %packed_weights_biases = vulkan_prepack::create_gru_context(
304             %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
305         %y.1 : Tensor, %hn.1 : Tensor = vulkan_prepack::run_gru_context(%input.1, %hx.1, %packed_weights_biases)
306         return (%y.1, %hn.1) )";
307 
308   auto filter = [&](const Match& match,
309                     const std::unordered_map<std::string, Value*>& vmap) {
310     auto node = match.values_map.at(vmap.at("params_cpu"))->node();
311     return node->output()->type()->str() == "Tensor[]";
312   };
313 
314   SubgraphRewriter gru_rewriter;
315   gru_rewriter.RegisterRewritePattern(gru_pattern, prepacked_ops_pattern);
316   gru_rewriter.runOnGraph(graph, filter);
317 }
318 
insertPrePackedLstmOp(std::shared_ptr<Graph> & graph)319 void insertPrePackedLstmOp(std::shared_ptr<Graph>& graph) {
320   std::string lstm_pattern = R"(
321       graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
322         %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
323         return (%y.1, %hn.1, %cn.1) )";
324   std::string prepacked_ops_pattern = R"(
325       graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
326         %packed_weights_biases = vulkan_prepack::create_lstm_context(
327             %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
328         %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx)
329         %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases)
330         return (%y.1, %hn.1, %cn.1) )";
331 
332   auto filter = [&](const Match& match,
333                     const std::unordered_map<std::string, Value*>& vmap) {
334     auto node = match.values_map.at(vmap.at("hx"))->node();
335     return node->output()->type()->str() == "Tensor[]";
336   };
337 
338   SubgraphRewriter lstm_rewriter;
339   lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern);
340   lstm_rewriter.runOnGraph(graph, filter);
341 }
342 
fuseHardtanhWithPackedOps(std::shared_ptr<Graph> & graph)343 void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
344   SubgraphRewriter rewriter;
345 
346   std::string conv2d_prepack_run_hardtanh_fused = R"(
347     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
348           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
349         %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
350             %weight, %bias, %stride, %padding, %dilation, %groups,
351             %output_min, %output_max)
352         %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
353         return (%r) )";
354 
355   std::string conv2d_prepack_run_hardtanh = R"(
356     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
357           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
358         %packed_weight_bias = vulkan_prepack::create_conv2d_context(
359             %weight, %bias, %stride, %padding, %dilation, %groups,
360             %dummy_min_max, %dummy_min_max)
361         %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
362         %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
363         return (%r) )";
364 
365   rewriter.RegisterRewritePattern(
366       conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
367 
368   std::string conv2d_prepack_run_hardtanh_inplace = R"(
369     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
370           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
371         %packed_weight_bias = vulkan_prepack::create_conv2d_context(
372             %weight, %bias, %stride, %padding, %dilation, %groups,
373             %dummy_min_max, %dummy_min_max)
374         %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
375         %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
376         return (%r) )";
377 
378   rewriter.RegisterRewritePattern(
379       conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
380 
381   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
382 }
383 
fuseReluWithPackedOps(std::shared_ptr<Graph> & graph)384 void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
385   SubgraphRewriter rewriter;
386 
387   std::string conv2d_prepack_run_relu_fused = R"(
388     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
389           %dilation:int[], %groups:int, %dummy_min_max):
390         %output_min: float = prim::Constant[value=0.0]()
391         %output_max: None = prim::Constant()
392         %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
393             %weight, %bias, %stride, %padding, %dilation, %groups,
394             %output_min, %output_max)
395         %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
396         return (%r) )";
397 
398   std::string conv2d_prepack_run_relu = R"(
399     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
400           %dilation:int[], %groups:int, %dummy_min_max):
401         %packed_weight_bias = vulkan_prepack::create_conv2d_context(
402             %weight, %bias, %stride, %padding, %dilation, %groups,
403             %dummy_min_max, %dummy_min_max)
404         %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
405         %r = aten::relu(%conv2d_res)
406         return (%r) )";
407 
408   rewriter.RegisterRewritePattern(
409       conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
410 
411   std::string conv2d_prepack_run_relu_inplace = R"(
412     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
413           %dilation:int[], %groups:int, %dummy_min_max):
414         %packed_weight_bias = vulkan_prepack::create_conv2d_context(
415             %weight, %bias, %stride, %padding, %dilation, %groups,
416             %dummy_min_max, %dummy_min_max)
417         %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
418         %r = aten::relu_(%conv2d_res)
419         return (%r) )";
420 
421   rewriter.RegisterRewritePattern(
422       conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
423   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
424 }
425 
426 } // namespace
427 
vulkanInsertPrePackedOps(std::shared_ptr<Graph> & graph)428 void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
429   insertPrePackedLinearOp(graph);
430   insertPrePackedLayernormOp(graph);
431   insertPrePackedConv2dOp(graph);
432   insertPrePackedConv1dOp(graph);
433   rewriteQuantizedOps(graph);
434   insertPrePackedGruOp(graph);
435   insertPrePackedLstmOp(graph);
436   insertPrePackedBatchNormOp(graph);
437 }
438 
vulkanInsertPrePackedOps(script::Module & module)439 void vulkanInsertPrePackedOps(script::Module& module) {
440   for (auto& method : module.get_methods()) {
441     auto graph = method.graph();
442     vulkanInsertPrePackedOps(graph);
443   }
444   for (script::Module m : module.children()) {
445     vulkanInsertPrePackedOps(m);
446   }
447 }
448 
vulkanFusePrePackedConvWithClamp(script::Module & module)449 void vulkanFusePrePackedConvWithClamp(script::Module& module) {
450   auto graph = module.get_method("forward").graph();
451   fuseReluWithPackedOps(graph);
452   fuseHardtanhWithPackedOps(graph);
453 }
454 
vulkanFoldPrePackingOps(script::Module & m)455 void vulkanFoldPrePackingOps(script::Module& m) {
456   PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
457     return (
458         (n->kind() ==
459          Symbol::fromQualString("vulkan_prepack::create_conv2d_context")) ||
460         (n->kind() ==
461          Symbol::fromQualString("vulkan_prepack::create_tconv2d_context")) ||
462         (n->kind() ==
463          Symbol::fromQualString("vulkan_prepack::create_qconv2d_context")) ||
464         (n->kind() ==
465          Symbol::fromQualString("vulkan_prepack::create_qtconv2d_context")) ||
466         (n->kind() ==
467          Symbol::fromQualString(
468              "vulkan_quantized_prepack::convert_qconv2d_context")) ||
469         (n->kind() ==
470          Symbol::fromQualString("vulkan_prepack::create_conv1d_context")) ||
471         (n->kind() ==
472          Symbol::fromQualString(
473              "vulkan_quantized_prepack::convert_qtconv2d_context")) ||
474         (n->kind() ==
475          Symbol::fromQualString(
476              "vulkan_quantized_prepack::convert_linear_context")) ||
477         (n->kind() ==
478          Symbol::fromQualString("vulkan_prepack::create_linear_context")) ||
479         (n->kind() ==
480          Symbol::fromQualString("vulkan_prepack::create_layernorm_context")) ||
481         (n->kind() ==
482          Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
483         (n->kind() ==
484          Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
485         (n->kind() ==
486          Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
487   };
488   PrePackingOpsFolder(m, filter_fn, "prepack_folding");
489 }
490 
vulkanRemoveMutation(script::Module & module)491 static void vulkanRemoveMutation(script::Module& module) {
492   auto graph = module.get_method("forward").graph();
493   RemoveTensorMutation(graph);
494 }
495 
vulkanRunCanonicalOptimizations(script::Module & module)496 static void vulkanRunCanonicalOptimizations(script::Module& module) {
497   auto graph = module.get_method("forward").graph();
498   for (const auto& method : module.get_methods()) {
499     auto method_graph = method.graph();
500     runOptimization(method_graph, false /* no loop unrolling */);
501   }
502 }
503 
// Produces a Vulkan-optimized clone of the given module: folds conv+bn,
// freezes, rewrites ops to prepacked Vulkan forms, fuses clamps, folds the
// prepacking into constants, and (unless blocklisted) inserts automatic
// CPU<->Vulkan device transfers at the graph boundary. The input module is
// left untouched.
script::Module vulkanOptimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  // Work on a clone in eval mode so the caller's module is not modified.
  auto cloned_module = m.clone();
  cloned_module.eval();

  cloned_module = FoldConvBatchNorm(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);

  vulkanInsertPrePackedOps(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  vulkanRemoveMutation(cloned_module);

  const bool auto_gpu_transfer =
      optimization_blocklist.count(
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER) == 0;
  if (auto_gpu_transfer) {
    // Insert the device transfers and mark the module so the runtime knows
    // it no longer needs to move tensors between backends itself.
    transferInputOutputBackends(cloned_module);
    cloned_module.register_attribute(
        "requires_backend_transfers", BoolType::get(), false);
  }

  // Also deduplicates constants introduced by the rewrites above.
  vulkanRunCanonicalOptimizations(cloned_module);
  eliminateDeadCode(cloned_module);

  cloned_module.register_attribute(
      "optimized_for_vulkan", BoolType::get(), true);
  return cloned_module;
}
533 
534 } // namespace torch::jit
535