#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>

namespace torch::jit {

namespace {

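// Replace aten::batch_norm with a Vulkan prepacked pair: create_batchnorm_context
// packs weight/bias/mean/var once, run_batchnorm_context consumes the packed
// context at execution time.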
void insertPrePackedBatchNormOp(std::shared_ptr<Graph>& graph) {
  std::string batchnorm_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %r = aten::batch_norm(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.BatchNormPackedContext = vulkan_prepack::create_batchnorm_context(
            %weight, %bias, %mean, %var, %training, %momentum, %eps, %cudnn_enable)
        %res = vulkan_prepack::run_batchnorm_context(%input, %op_context)
        return (%res))";

  SubgraphRewriter batchnorm_rewriter;
  batchnorm_rewriter.RegisterRewritePattern(
      batchnorm_pattern, prepacked_ops_pattern);
  batchnorm_rewriter.runOnGraph(graph);
}

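// Fuse decomposed matmul+add patterns into aten::linear first, then replace
// aten::linear with create_linear_context/run_linear_context. Note the weight
// is transposed before it is packed.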
void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);

  std::string linear_pattern = R"(
    graph(%input, %weight, %bias):
        %r = aten::linear(%input, %weight, %bias)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %weight, %bias):
        %weight_t = aten::t(%weight)
        %packed_weight_bias = vulkan_prepack::create_linear_context(
            %weight_t, %bias)
        %res = vulkan_prepack::run_linear_context(%input, %packed_weight_bias)
        return (%res))";

  SubgraphRewriter linear_rewriter;
  linear_rewriter.RegisterRewritePattern(linear_pattern, prepacked_ops_pattern);
  linear_rewriter.runOnGraph(graph);
}

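// Replace aten::layer_norm with create_layernorm_context/run_layernorm_context.
// Only weight, bias and eps are packed; normalized_shape stays a runtime argument.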
void insertPrePackedLayernormOp(std::shared_ptr<Graph>& graph) {
  std::string layernorm_pattern = R"(
    graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
        %r = aten::layer_norm(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable)
        return (%r))";
  std::string prepacked_ops_pattern = R"(
    graph(%input, %normalized_shape, %weight, %bias, %eps, %cudnn_enable):
        %op_context : __torch__.torch.classes.vulkan.LayernormPackedContext = vulkan_prepack::create_layernorm_context(
            %weight, %bias, %eps)
        %res = vulkan_prepack::run_layernorm_context(%input, %normalized_shape, %op_context)
        return (%res))";

  SubgraphRewriter layernorm_rewriter;
  layernorm_rewriter.RegisterRewritePattern(
      layernorm_pattern, prepacked_ops_pattern);
  layernorm_rewriter.runOnGraph(graph);
}

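// Replace aten::conv2d and aten::conv_transpose2d with their Vulkan prepacked
// create/run counterparts. The output_min/output_max slots are left as None here;
// clamp fusion (relu/hardtanh) fills them in later.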
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv2d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_2d_pattern, prepacked_ops_conv2d_pattern);
  rewriter.runOnGraph(graph);

  std::string conv_2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
          %output_padding:int[], %groups:int):
        %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
        return (%res) )";

  std::string prepacked_ops_conv2d_transpose_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
        %output_min_max : None = prim::Constant()
        %packed_weight_bias = vulkan_prepack::create_tconv2d_context(
            %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
            %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_tconv2d_context(%input, %packed_weight_bias)
        return (%res) )";

  SubgraphRewriter transpose_rewriter;
  transpose_rewriter.RegisterRewritePattern(
      conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern);
  transpose_rewriter.runOnGraph(graph);
}

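// Replace aten::conv1d with create_conv1d_context/run_conv1d_context.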
void insertPrePackedConv1dOp(std::shared_ptr<Graph>& graph) {
  graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

  std::string conv_1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %r = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
        return (%r) )";

  std::string prepacked_ops_conv1d_pattern = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
        %packed_weight_bias = vulkan_prepack::create_conv1d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups)
        %r = vulkan_prepack::run_conv1d_context(%input, %packed_weight_bias)
        return (%r) )";

  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(
      conv_1d_pattern, prepacked_ops_conv1d_pattern);
  rewriter.runOnGraph(graph);
}

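// Insert aten::to nodes so that tensor graph inputs are transferred to the Vulkan
// backend and tensor graph outputs are transferred back to the CPU backend.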
void transferInputOutputBackends(std::shared_ptr<Graph>& graph) {
  // Move inputs to Vulkan backend
  for (Value* input : graph->inputs()) {
    NamedValue named_input = NamedValue("", input);
    if (named_input.type()->kind() == TypeKind::TensorType &&
        !input->uses().empty()) {
      // find the insertion point
      WithInsertPoint ip(input->uses()[0].user->prev());
      Value* replaced_input = graph->insert(
          Symbol::fromQualString("aten::to"), {named_input, "vulkan"});
      // replace the input
      input->replaceAllUsesAfterNodeWith(
          replaced_input->node(), replaced_input);
    }
  }

  // Move outputs to CPU backend
  at::ArrayRef<Value*>&& outputs = graph->outputs();
  for (size_t i = 0; i < outputs.size(); i++) {
    Value* output = outputs[i];
    NamedValue named_output = NamedValue("", output);
    if (named_output.type()->kind() == TypeKind::TensorType) {
      // find the insertion point
      WithInsertPoint ip(output->node()->next());
      Value* replaced_output = graph->insert(
          Symbol::fromQualString("aten::to"), {named_output, "cpu"});
      // replace the output
      graph->block()->replaceOutput(i, replaced_output);
    }
  }

  SubgraphRewriter rewriter;
  rewriter.runOnGraph(graph);
}

void transferInputOutputBackends(script::Module& module) {
  std::shared_ptr<Graph> graph = module.get_methods()[0].graph();
  transferInputOutputBackends(graph);
}

void eliminateDeadCode(script::Module& module) {
  for (auto& method : module.get_methods()) {
    EliminateDeadCode(method.graph());
  }
}

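// Map quantized CPU ops (add, mul, conv2d, conv_transpose2d, conv2d_relu, linear)
// onto their Vulkan equivalents, converting the CPU packed params into Vulkan
// packed contexts. conv2d_relu additionally clamps the output at a minimum of 0.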
void rewriteQuantizedOps(std::shared_ptr<Graph>& graph) {
  // quantized::add
  std::string quantized_add_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_add_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = vulkan_quantized::add(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_add_rewriter;
  quantized_add_rewriter.RegisterRewritePattern(
      quantized_add_pattern, vk_quantized_add_pattern);
  quantized_add_rewriter.runOnGraph(graph);

  // quantized::mul
  std::string quantized_mul_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_mul_pattern = R"(
    graph(%a_quant, %b_quant, %r_scale, %r_zero_point) :
        %res = vulkan_quantized::mul(%a_quant, %b_quant, %r_scale, %r_zero_point)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_mul_rewriter;
  quantized_mul_rewriter.RegisterRewritePattern(
      quantized_mul_pattern, vk_quantized_mul_pattern);
  quantized_mul_rewriter.runOnGraph(graph);

  // quantized::conv2d
  std::string quantized_conv2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::conv2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min_max : None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
            %packed_params, %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv2d_rewriter;
  quantized_conv2d_rewriter.RegisterRewritePattern(
      quantized_conv2d_pattern, vk_quantized_conv2d_pattern);
  quantized_conv2d_rewriter.runOnGraph(graph);

  // quantized::conv_transpose2d
  std::string quantized_conv_transpose2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv_transpose2d_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min_max : None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qtconv2d_context(
            %packed_params, %output_min_max, %output_min_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv_transpose2d_rewriter;
  quantized_conv_transpose2d_rewriter.RegisterRewritePattern(
      quantized_conv_transpose2d_pattern,
      vk_quantized_conv_transpose2d_pattern);
  quantized_conv_transpose2d_rewriter.runOnGraph(graph);

  // quantized::conv2d_relu
  std::string quantized_conv2d_relu_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::conv2d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_conv2d_relu_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %vk_packed_params : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_quantized_prepack::convert_qconv2d_context(
            %packed_params, %output_min, %output_max)
        %res = vulkan_prepack::run_qconv2d_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_conv2d_relu_rewriter;
  quantized_conv2d_relu_rewriter.RegisterRewritePattern(
      quantized_conv2d_relu_pattern, vk_quantized_conv2d_relu_pattern);
  quantized_conv2d_relu_rewriter.runOnGraph(graph);

  // quantized::linear
  std::string quantized_linear_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point) :
        %res = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point)
        return (%res) )";
  std::string vk_quantized_linear_pattern = R"(
    graph(%a_quant, %packed_params, %r_scale, %r_zero_point):
        %vk_packed_params : __torch__.torch.classes.vulkan.LinearPackedContext = vulkan_quantized_prepack::convert_linear_context(
            %packed_params)
        %res = vulkan_prepack::run_qlinear_context(
            %a_quant, %r_scale, %r_zero_point, %vk_packed_params)
        return (%res) )";

  torch::jit::SubgraphRewriter quantized_linear_rewriter;
  quantized_linear_rewriter.RegisterRewritePattern(
      quantized_linear_pattern, vk_quantized_linear_pattern);
  quantized_linear_rewriter.runOnGraph(graph);
}

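// Replace aten::gru with create_gru_context/run_gru_context. The filter only
// accepts matches whose %params_cpu value is statically typed as Tensor[].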
void insertPrePackedGruOp(std::shared_ptr<Graph>& graph) {
  std::string gru_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor = aten::gru(%input.1, %hx.1, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx.1, %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_gru_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %y.1 : Tensor, %hn.1 : Tensor = vulkan_prepack::run_gru_context(%input.1, %hx.1, %packed_weights_biases)
        return (%y.1, %hn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("params_cpu"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter gru_rewriter;
  gru_rewriter.RegisterRewritePattern(gru_pattern, prepacked_ops_pattern);
  gru_rewriter.runOnGraph(graph, filter);
}

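// Replace aten::lstm with create_lstm_context/run_lstm_context, unpacking the
// (hx, cx) state list into separate values. The filter requires %hx to be
// statically typed as Tensor[].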
void insertPrePackedLstmOp(std::shared_ptr<Graph>& graph) {
  std::string lstm_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = aten::lstm(%input.1, %hx, %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        return (%y.1, %hn.1, %cn.1) )";
  std::string prepacked_ops_pattern = R"(
    graph(%input.1, %hx:Tensor[], %params_cpu:Tensor[], %has_biases:bool, %num_layers:int, %dropout:float, %train:bool, %bidirectional:bool, %batch_first:bool):
        %packed_weights_biases = vulkan_prepack::create_lstm_context(
            %params_cpu, %has_biases, %num_layers, %dropout, %train, %bidirectional, %batch_first)
        %hx.1 : Tensor, %cx.1 : Tensor = prim::ListUnpack(%hx)
        %y.1 : Tensor, %hn.1 : Tensor, %cn.1 : Tensor = vulkan_prepack::run_lstm_context(%input.1, %hx.1, %cx.1, %packed_weights_biases)
        return (%y.1, %hn.1, %cn.1) )";

  auto filter = [&](const Match& match,
                    const std::unordered_map<std::string, Value*>& vmap) {
    auto node = match.values_map.at(vmap.at("hx"))->node();
    return node->output()->type()->str() == "Tensor[]";
  };

  SubgraphRewriter lstm_rewriter;
  lstm_rewriter.RegisterRewritePattern(lstm_pattern, prepacked_ops_pattern);
  lstm_rewriter.runOnGraph(graph, filter);
}

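// Fold aten::hardtanh / aten::hardtanh_ that follows a prepacked conv2d into the
// conv2d context's output_min/output_max arguments, subject to isClampFusable.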
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_hardtanh_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_hardtanh = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);

  std::string conv2d_prepack_run_hardtanh_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);

  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

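// Fold aten::relu / aten::relu_ that follows a prepacked conv2d into the conv2d
// context by setting output_min to 0 and leaving output_max unbounded.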
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

  std::string conv2d_prepack_run_relu_fused = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %output_min: float = prim::Constant[value=0.0]()
        %output_max: None = prim::Constant()
        %packed_weight_bias : __torch__.torch.classes.vulkan.Conv2dPackedContext = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %output_min, %output_max)
        %r = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        return (%r) )";

  std::string conv2d_prepack_run_relu = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);

  std::string conv2d_prepack_run_relu_inplace = R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %dummy_min_max):
        %packed_weight_bias = vulkan_prepack::create_conv2d_context(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %dummy_min_max, %dummy_min_max)
        %conv2d_res = vulkan_prepack::run_conv2d_context(%input, %packed_weight_bias)
        %r = aten::relu_(%conv2d_res)
        return (%r) )";

  rewriter.RegisterRewritePattern(
      conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
  rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
}

} // namespace

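// Run every prepacking rewrite on the given graph. The Module overload below
// applies this to all methods and recurses into child modules.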
void vulkanInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertPrePackedLinearOp(graph);
  insertPrePackedLayernormOp(graph);
  insertPrePackedConv2dOp(graph);
  insertPrePackedConv1dOp(graph);
  rewriteQuantizedOps(graph);
  insertPrePackedGruOp(graph);
  insertPrePackedLstmOp(graph);
  insertPrePackedBatchNormOp(graph);
}

void vulkanInsertPrePackedOps(script::Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    vulkanInsertPrePackedOps(graph);
  }
  for (script::Module m : module.children()) {
    vulkanInsertPrePackedOps(m);
  }
}

void vulkanFusePrePackedConvWithClamp(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseReluWithPackedOps(graph);
  fuseHardtanhWithPackedOps(graph);
}

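// Pre-compute the listed create_*/convert_* context ops over constant weights and
// fold their outputs into module attributes (via PrePackingOpsFolder with the
// "prepack_folding" attribute prefix), so packing is not redone on every forward.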
void vulkanFoldPrePackingOps(script::Module& m) {
  PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
    return (
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_tconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_qconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_qtconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_qconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_conv1d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_qtconv2d_context")) ||
        (n->kind() ==
         Symbol::fromQualString(
             "vulkan_quantized_prepack::convert_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_linear_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_layernorm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_gru_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_lstm_context")) ||
        (n->kind() ==
         Symbol::fromQualString("vulkan_prepack::create_batchnorm_context")));
  };
  PrePackingOpsFolder(m, filter_fn, "prepack_folding");
}

static void vulkanRemoveMutation(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  RemoveTensorMutation(graph);
}

static void vulkanRunCanonicalOptimizations(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  for (const auto& method : module.get_methods()) {
    auto method_graph = method.graph();
    runOptimization(method_graph, false /* no loop unrolling */);
  }
}

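// Top-level Vulkan mobile optimization pipeline: clone and freeze the module,
// fold conv+batchnorm, insert and fold prepacked ops, fuse clamps, remove dropout
// and mutation, optionally insert CPU<->Vulkan transfer ops, then run canonical
// optimizations and dead code elimination.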
script::Module vulkanOptimizeForMobile(
    const script::Module& m,
    const std::set<MobileOptimizerType>& optimization_blocklist,
    const std::vector<std::string>& preserved_methods) {
  auto cloned_module = m.clone();
  cloned_module.eval();
  cloned_module = FoldConvBatchNorm(cloned_module);
  cloned_module = freeze_module(cloned_module, preserved_methods);
  vulkanInsertPrePackedOps(cloned_module);
  vulkanFusePrePackedConvWithClamp(cloned_module);
  vulkanFoldPrePackingOps(cloned_module);
  removeDropout(cloned_module);
  vulkanRemoveMutation(cloned_module);

  if (!optimization_blocklist.count(
          MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)) {
    transferInputOutputBackends(cloned_module);
    cloned_module.register_attribute(
        "requires_backend_transfers", BoolType::get(), false);
  }

  // remove duplicated constants
  vulkanRunCanonicalOptimizations(cloned_module);
  eliminateDeadCode(cloned_module);

  cloned_module.register_attribute(
      "optimized_for_vulkan", BoolType::get(), true);
  return cloned_module;
}

} // namespace torch::jit