xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/xnnpack_rewrite.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/jit_type.h>
#include <ATen/native/xnnpack/OpContext.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
18 
19 namespace torch::jit {
20 
21 namespace {
22 
replaceConv1dWithConv2d(std::shared_ptr<Graph> & graph)23 void replaceConv1dWithConv2d(std::shared_ptr<Graph>& graph) {
24   std::string conv_1d_pattern = R"(
25     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
26         %res = aten::conv1d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
27         return (%res) )";
28 
29   std::string conv_2d_pattern = R"(
30     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
31         %zero : int = prim::Constant[value=0]()
32         %one : int = prim::Constant[value=1]()
33         %stride_w : int = prim::ListUnpack(%stride)
34         %stride_2d : int[] = prim::ListConstruct(%one, %stride_w)
35         %padding_w : int = prim::ListUnpack(%padding)
36         %padding_2d : int[] = prim::ListConstruct(%zero, %padding_w)
37         %dilation_w : int = prim::ListUnpack(%dilation)
38         %dilation_2d : int[] = prim::ListConstruct(%one, %dilation_w)
39         %two : int = prim::Constant[value=2]()
40         %input_2d : Tensor = aten::unsqueeze(%input, %two)
41         %weight_2d : Tensor = aten::unsqueeze(%weight, %two)
42         %output_2d = aten::conv2d(
43             %input_2d, %weight_2d, %bias, %stride_2d, %padding_2d, %dilation_2d, %groups)
44         %output : Tensor = aten::squeeze(%output_2d, %two)
45         return (%output) )";
46 
47   std::vector<std::pair<std::string, std::string>> value_mappings(
48       {{"zero", "res"},
49        {"one", "res"},
50        {"stride_w", "res"},
51        {"stride_2d", "res"},
52        {"padding_w", "res"},
53        {"padding_2d", "res"},
54        {"dilation_w", "res"},
55        {"dilation_2d", "res"},
56        {"two", "res"},
57        {"input_2d", "res"},
58        {"weight_2d", "res"},
59        {"output_2d", "res"},
60        {"output", "res"}});
61 
62   SubgraphRewriter rewriter;
63   rewriter.RegisterRewritePattern(
64       conv_1d_pattern, conv_2d_pattern, value_mappings);
65   rewriter.runOnGraph(graph);
66 }
67 
68 } // namespace
69 
transformConv1dToConv2d(std::shared_ptr<Graph> & graph)70 void transformConv1dToConv2d(std::shared_ptr<Graph>& graph) {
71   // Replace _convolution with conv1d and conv2d
72   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
73   replaceConv1dWithConv2d(graph);
74 }
75 
transformConv1dToConv2d(script::Module & module)76 void transformConv1dToConv2d(script::Module& module) {
77   for (auto& method : module.get_methods()) {
78     auto graph = method.graph();
79     transformConv1dToConv2d(graph);
80   }
81   for (script::Module m : module.children()) {
82     transformConv1dToConv2d(m);
83   }
84 }
85 
86 #ifdef USE_XNNPACK
87 
88 namespace {
89 
insertPrePackedLinearOp(std::shared_ptr<Graph> & graph)90 void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
91   // fuse decomposed linear into aten::linear
92   FuseLinear(graph);
93 
94   std::string linear_pattern = R"(
95     graph(%input, %weight, %bias):
96         %res = aten::linear(%input, %weight, %bias)
97         return (%res))";
98   std::string prepacked_ops_pattern = R"(
99     graph(%input, %weight, %bias):
100         %output_min_max : None = prim::Constant()
101         %packed_weight_bias = prepacked::linear_clamp_prepack(
102             %weight, %bias, %output_min_max, %output_min_max)
103         %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
104         return (%res))";
105 
106   std::vector<std::pair<std::string, std::string>> value_mappings(
107       {{"output_min_max", "res"},
108        {"packed_weight_bias", "res"},
109        {"res", "res"}});
110 
111   SubgraphRewriter linear_rewriter;
112   linear_rewriter.RegisterRewritePattern(
113       linear_pattern, prepacked_ops_pattern, value_mappings);
114   linear_rewriter.runOnGraph(graph);
115 }
116 
insertPrePackedConv2dOp(std::shared_ptr<Graph> & graph)117 void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
118   // Replace _convolution with conv2d
119   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
120 
121   std::string conv_2d_pattern = R"(
122     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
123         %res = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
124         return (%res) )";
125 
126   std::string prepacked_ops_conv2d_pattern = R"(
127     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
128         %output_min_max : None = prim::Constant()
129         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
130             %weight, %bias, %stride, %padding, %dilation, %groups,
131             %output_min_max, %output_min_max)
132         %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
133         return (%res) )";
134 
135   std::vector<std::pair<std::string, std::string>> value_mappings(
136       {{"output_min_max", "res"},
137        {"packed_weight_bias", "res"},
138        {"res", "res"}});
139 
140   SubgraphRewriter rewriter;
141   rewriter.RegisterRewritePattern(
142       conv_2d_pattern, prepacked_ops_conv2d_pattern, value_mappings);
143   rewriter.runOnGraph(graph);
144 
145   std::string conv_2d_transpose_pattern = R"(
146       graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[],
147           %output_padding:int[], %groups:int):
148         %res = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation)
149         return (%res) )";
150 
151   std::string prepacked_ops_conv2d_transpose_pattern = R"(
152     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int):
153         %output_min_max : None = prim::Constant()
154         %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack(
155             %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups,
156             %output_min_max, %output_min_max)
157         %res = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias)
158         return (%res) )";
159 
160   value_mappings = {
161       {"output_min_max", "res"}, {"packed_weight_bias", "res"}, {"res", "res"}};
162 
163   SubgraphRewriter transpose_rewriter;
164   transpose_rewriter.RegisterRewritePattern(
165       conv_2d_transpose_pattern,
166       prepacked_ops_conv2d_transpose_pattern,
167       value_mappings);
168   transpose_rewriter.runOnGraph(graph);
169 }
170 
fuseHardtanhWithPackedOps(std::shared_ptr<Graph> & graph)171 void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
172   SubgraphRewriter rewriter;
173 
174   std::string linear_prepack_run_hardtanh_fused = R"(
175     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
176         %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
177             %weight, %bias, %output_min, %output_max)
178         %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
179         return (%res))";
180 
181   std::string conv2d_prepack_run_hardtanh_fused = R"(
182     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
183           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
184         %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
185             %weight, %bias, %stride, %padding, %dilation, %groups,
186             %output_min, %output_max)
187         %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
188         return (%res) )";
189 
190   std::string linear_prepack_run_hardtanh = R"(
191     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
192         %packed_weight_bias = prepacked::linear_clamp_prepack(
193             %weight, %bias, %dummy_min_max, %dummy_min_max)
194         %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
195         %res = aten::hardtanh(%linear_res, %output_min, %output_max)
196         return (%res))";
197 
198   std::vector<std::pair<std::string, std::string>> value_mappings(
199       {{"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}});
200 
201   rewriter.RegisterRewritePattern(
202       linear_prepack_run_hardtanh,
203       linear_prepack_run_hardtanh_fused,
204       value_mappings);
205 
206   std::string conv2d_prepack_run_hardtanh = R"(
207     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
208           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
209         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
210             %weight, %bias, %stride, %padding, %dilation, %groups,
211             %dummy_min_max, %dummy_min_max)
212         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
213         %res = aten::hardtanh(%conv2d_res, %output_min, %output_max)
214         return (%res) )";
215 
216   value_mappings = {
217       {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
218 
219   rewriter.RegisterRewritePattern(
220       conv2d_prepack_run_hardtanh,
221       conv2d_prepack_run_hardtanh_fused,
222       value_mappings);
223 
224   std::string linear_prepack_run_hardtanh_inplace = R"(
225     graph(%input, %weight, %bias, %output_min, %output_max, %dummy_min_max):
226         %packed_weight_bias = prepacked::linear_clamp_prepack(
227             %weight, %bias, %dummy_min_max, %dummy_min_max)
228         %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
229         %res = aten::hardtanh_(%linear_res, %output_min, %output_max)
230         return (%res))";
231 
232   std::string conv2d_prepack_run_hardtanh_inplace = R"(
233     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
234           %dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
235         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
236             %weight, %bias, %stride, %padding, %dilation, %groups,
237             %dummy_min_max, %dummy_min_max)
238         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
239         %res = aten::hardtanh_(%conv2d_res, %output_min, %output_max)
240         return (%res) )";
241 
242   value_mappings = {
243       {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
244 
245   rewriter.RegisterRewritePattern(
246       linear_prepack_run_hardtanh_inplace,
247       linear_prepack_run_hardtanh_fused,
248       value_mappings);
249 
250   value_mappings = {
251       {"packed_weight_bias", "packed_weight_bias"}, {"res", "res"}};
252 
253   rewriter.RegisterRewritePattern(
254       conv2d_prepack_run_hardtanh_inplace,
255       conv2d_prepack_run_hardtanh_fused,
256       value_mappings);
257 
258   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
259 }
260 
fuseReluWithPackedOps(std::shared_ptr<Graph> & graph)261 void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
262   SubgraphRewriter rewriter;
263 
264   std::string linear_prepack_run_relu_fused = R"(
265     graph(%input, %weight, %bias, %dummy_min_max):
266         %output_min: float = prim::Constant[value=0.0]()
267         %output_max: None = prim::Constant()
268         %packed_weight_bias : __torch__.torch.classes.xnnpack.LinearOpContext = prepacked::linear_clamp_prepack(
269             %weight, %bias, %output_min, %output_max)
270         %res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
271         return (%res))";
272 
273   std::string conv2d_prepack_run_relu_fused = R"(
274     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
275           %dilation:int[], %groups:int, %dummy_min_max):
276         %output_min: float = prim::Constant[value=0.0]()
277         %output_max: None = prim::Constant()
278         %packed_weight_bias : __torch__.torch.classes.xnnpack.Conv2dOpContext = prepacked::conv2d_clamp_prepack(
279             %weight, %bias, %stride, %padding, %dilation, %groups,
280             %output_min, %output_max)
281         %res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
282         return (%res) )";
283 
284   std::string linear_prepack_run_relu = R"(
285     graph(%input, %weight, %bias, %dummy_min_max):
286         %packed_weight_bias = prepacked::linear_clamp_prepack(
287             %weight, %bias, %dummy_min_max, %dummy_min_max)
288         %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
289         %res = aten::relu(%linear_res)
290         return (%res))";
291 
292   std::vector<std::pair<std::string, std::string>> value_mappings(
293       {{"output_min", "packed_weight_bias"},
294        {"output_max", "packed_weight_bias"},
295        {"packed_weight_bias", "packed_weight_bias"},
296        {"res", "res"}});
297 
298   rewriter.RegisterRewritePattern(
299       linear_prepack_run_relu, linear_prepack_run_relu_fused, value_mappings);
300 
301   std::string conv2d_prepack_run_relu = R"(
302     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
303           %dilation:int[], %groups:int, %dummy_min_max):
304         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
305             %weight, %bias, %stride, %padding, %dilation, %groups,
306             %dummy_min_max, %dummy_min_max)
307         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
308         %res = aten::relu(%conv2d_res)
309         return (%res) )";
310 
311   value_mappings = {
312       {"output_min", "packed_weight_bias"},
313       {"output_max", "packed_weight_bias"},
314       {"packed_weight_bias", "packed_weight_bias"},
315       {"res", "res"}};
316 
317   rewriter.RegisterRewritePattern(
318       conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused, value_mappings);
319 
320   std::string linear_prepack_run_relu_inplace = R"(
321     graph(%input, %weight, %bias, %dummy_min_max):
322         %packed_weight_bias = prepacked::linear_clamp_prepack(
323             %weight, %bias, %dummy_min_max, %dummy_min_max)
324         %linear_res = prepacked::linear_clamp_run(%input, %packed_weight_bias)
325         %res = aten::relu_(%linear_res)
326         return (%res))";
327 
328   std::string conv2d_prepack_run_relu_inplace = R"(
329     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
330           %dilation:int[], %groups:int, %dummy_min_max):
331         %packed_weight_bias = prepacked::conv2d_clamp_prepack(
332             %weight, %bias, %stride, %padding, %dilation, %groups,
333             %dummy_min_max, %dummy_min_max)
334         %conv2d_res = prepacked::conv2d_clamp_run(%input, %packed_weight_bias)
335         %res = aten::relu_(%conv2d_res)
336         return (%res) )";
337 
338   value_mappings = {
339       {"output_min", "packed_weight_bias"},
340       {"output_max", "packed_weight_bias"},
341       {"packed_weight_bias", "packed_weight_bias"},
342       {"res", "res"}};
343 
344   rewriter.RegisterRewritePattern(
345       linear_prepack_run_relu_inplace,
346       linear_prepack_run_relu_fused,
347       value_mappings);
348 
349   value_mappings = {
350       {"output_min", "packed_weight_bias"},
351       {"output_max", "packed_weight_bias"},
352       {"packed_weight_bias", "packed_weight_bias"},
353       {"res", "res"}};
354 
355   rewriter.RegisterRewritePattern(
356       conv2d_prepack_run_relu_inplace,
357       conv2d_prepack_run_relu_fused,
358       value_mappings);
359   rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
360 }
361 
runCanonicalOptimizations(script::Module & module)362 void runCanonicalOptimizations(script::Module& module) {
363   for (const auto& method : module.get_methods()) {
364     auto graph = method.graph();
365     // Not sure if we have models running on mobile that require loop unrolling.
366     // Perhaps language/speech models? Conservatively setting that to false.
367     runOptimization(graph, false /* no loop unrolling */);
368   }
369 }
370 
371 } // namespace
372 
insertPrePackedOps(std::shared_ptr<Graph> & graph)373 void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
374   insertPrePackedLinearOp(graph);
375   insertPrePackedConv2dOp(graph);
376 }
377 
insertPrePackedOps(script::Module & module)378 void insertPrePackedOps(script::Module& module) {
379   for (auto& method : module.get_methods()) {
380     auto graph = method.graph();
381     insertPrePackedOps(graph);
382   }
383   for (script::Module m : module.children()) {
384     insertPrePackedOps(m);
385   }
386 }
387 
fusePrePackedLinearConvWithClamp(script::Module & module)388 void fusePrePackedLinearConvWithClamp(script::Module& module) {
389   for (auto& method : module.get_methods()) {
390     auto graph = method.graph();
391     fuseReluWithPackedOps(graph);
392     fuseHardtanhWithPackedOps(graph);
393 
394     // Ignore user defined classes for later passes
395     ConstantPropagation(graph, true);
396   }
397 }
398 
FoldPrePackingOps(script::Module & m)399 void FoldPrePackingOps(script::Module& m) {
400   PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
401     return (
402         (n->kind() ==
403          Symbol::fromQualString("prepacked::linear_clamp_prepack")) ||
404         n->kind() ==
405             Symbol::fromQualString("prepacked::conv2d_clamp_prepack") ||
406         n->kind() ==
407             Symbol::fromQualString(
408                 "prepacked::conv2d_transpose_clamp_prepack"));
409   };
410   PrePackingOpsFolder(m, filter_fn, "prepack_folding");
411   for (auto& method : m.get_methods()) {
412     auto graph = method.graph();
413     // Folding requires a const propagation through user defined classes
414     ConstantPropagation(graph, false);
415   }
416 }
417 
optimizeForMobile(const script::Module & m,const std::set<MobileOptimizerType> & optimization_blocklist,const std::vector<std::string> & preserved_methods)418 script::Module optimizeForMobile(
419     const script::Module& m,
420     const std::set<MobileOptimizerType>& optimization_blocklist,
421     const std::vector<std::string>& preserved_methods) {
422   auto cloned_module = m.clone();
423   cloned_module.eval();
424 
425   if (!optimization_blocklist.count(MobileOptimizerType::CONV_1D_TO_2D)) {
426     transformConv1dToConv2d(cloned_module);
427   }
428 
429   if (!optimization_blocklist.count(MobileOptimizerType::CONV_BN_FUSION)) {
430     cloned_module = FoldConvBatchNorm(cloned_module);
431   }
432 
433   // Many optimizations require a frozen module, but ConvBatchNorm requires
434   // an unfrozen module
435   cloned_module = freeze_module(cloned_module, preserved_methods);
436 
437   if (!optimization_blocklist.count(
438           MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)) {
439     // TODO fix duplication caused by referencing same op across multiple
440     // functions
441     insertPrePackedOps(cloned_module);
442     cloned_module = freeze_module(cloned_module, preserved_methods);
443     fusePrePackedLinearConvWithClamp(cloned_module);
444     FoldPrePackingOps(cloned_module);
445   }
446 
447   if (!optimization_blocklist.count(
448           MobileOptimizerType::HOIST_CONV_PACKED_PARAMS) &&
449       cloned_module.find_method("forward")) {
450     // freeze again in case it was not done in previous optional passes
451     cloned_module = freeze_module(cloned_module, preserved_methods);
452     HoistConvPackedParams(cloned_module);
453     // and freeze yet again to remove the empty QuantizedConv modules
454     cloned_module = freeze_module(cloned_module, preserved_methods);
455   }
456 
457   // Run canonical optimizations post freezing
458   // since freezing inlines the graph. Otherwise we
459   // will have to explicitly call Inlining pass.
460   runCanonicalOptimizations(cloned_module);
461 
462   if (!optimization_blocklist.count(MobileOptimizerType::REMOVE_DROPOUT)) {
463     for (const auto& method : cloned_module.get_methods()) {
464       auto graph = method.graph();
465       // Module must be not be in training mode but optimize calls eval()
466       removeDropout(graph);
467     }
468   }
469 
470   if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
471     for (const auto& method : cloned_module.get_methods()) {
472       auto graph = method.graph();
473       FuseAddRelu(graph);
474     }
475   }
476   cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
477   return cloned_module;
478 }
479 
480 #else
481 
insertPrePackedOps(std::shared_ptr<Graph> & graph)482 void insertPrePackedOps(std::shared_ptr<Graph>& graph) {
483   TORCH_INTERNAL_ASSERT(
484       false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
485 }
486 
insertPrePackedOps(script::Module & module)487 void insertPrePackedOps(script::Module& module) {
488   TORCH_INTERNAL_ASSERT(
489       false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
490 }
491 
fusePrePackedLinearConvWithClamp(script::Module & module)492 void fusePrePackedLinearConvWithClamp(script::Module& module) {
493   TORCH_INTERNAL_ASSERT(
494       false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
495 }
496 
FoldPrePackingOps(script::Module & m)497 void FoldPrePackingOps(script::Module& m) {
498   TORCH_INTERNAL_ASSERT(
499       false, "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
500 }
501 
// Stub for builds without XNNPACK: always asserts. The trailing return
// keeps the signature well-formed for compilers that require it.
script::Module optimizeForMobile(
    const script::Module& module,
    const std::set<MobileOptimizerType>& blocklist,
    const std::vector<std::string>& preserved_methods) {
  TORCH_INTERNAL_ASSERT(
      false,
      "Mobile optimization only available with XNNPACK at the moment. "
      "XNNPACK is not enabled. Please build with USE_XNNPACK=1");
  return module;
}
512 
513 #endif
514 } // namespace torch::jit
515