xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/mkldnn_rewrite.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Config.h>
2 #include <ATen/code_template.h>
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/constant_propagation.h>
6 #include <torch/csrc/jit/passes/dead_code_elimination.h>
7 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
8 #include <torch/csrc/jit/passes/mkldnn_rewrite.h>
9 #include <torch/csrc/jit/tensorexpr/kernel.h>
10 
11 namespace torch::jit {
12 
13 #if AT_MKLDNN_ENABLED()
14 
getSizesOf(Node * n,size_t idx)15 static c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
16   auto tt = n->input(idx)->type()->cast<TensorType>();
17   return tt->sizes();
18 }
19 
insertPrePackedConvOpForNode(Node * n)20 static void insertPrePackedConvOpForNode(Node* n) {
21   constexpr int POS_INPUT = 0;
22   constexpr int POS_WEIGHT = 1;
23   if (!tensorexpr::isContiguous(
24           n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) {
25     GRAPH_DEBUG(
26         "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous");
27     return;
28   }
29 
30   if (!tensorexpr::isContiguous(
31           n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) {
32     GRAPH_DEBUG(
33         "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous");
34     return;
35   }
36 
37   // Leave depthwise conv2d to NNC
38   if (tensorexpr::conv2dIsSupportedJit(n)) {
39     GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC");
40     return;
41   }
42 
43   WithInsertPoint guard(n);
44   auto graph = n->owningGraph();
45 
46   auto input_sizes = getSizesOf(n, POS_INPUT);
47   IValue input_size_value(*input_sizes.concrete_sizes());
48   auto input_size = graph->insertConstant(input_size_value);
49 
50   auto prepack_node = graph->create(
51       Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);
52 
53   // skip input value
54   for (const auto i : c10::irange(1, n->inputs().size())) {
55     Value* v = n->input(i);
56     prepack_node->addInput(v);
57   }
58   prepack_node->addInput(input_size);
59   auto attr = graph->insertConstant(IValue("none"));
60   prepack_node->addInput(attr);
61   prepack_node->output()->setType(
62       getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext"));
63   graph->insertNode(prepack_node);
64 
65   auto prepack_conv = graph->insertNode(
66       graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run"), 1));
67   prepack_conv->addInput(n->input(0));
68   prepack_conv->addInput(prepack_node->output());
69   prepack_conv->output()->setType(n->output()->type()->cast<TensorType>());
70 
71   n->output()->replaceAllUsesWith(prepack_conv->output());
72 }
73 
isTensorTypeCPU(Node * node)74 static bool isTensorTypeCPU(Node* node) {
75   for (const auto& input : node->inputs()) {
76     auto type = input->type()->cast<TensorType>();
77     if (!type) {
78       continue;
79     }
80     auto device = type->device();
81     if (!device) {
82       return false;
83     }
84     if (!device->is_cpu()) {
85       return false;
86     }
87   }
88   return true;
89 }
90 
insertPrePackedConvOp(Block * b)91 static void insertPrePackedConvOp(Block* b) {
92   for (Node* n : b->nodes()) {
93     for (Block* b : n->blocks()) {
94       insertPrePackedConvOp(b);
95     }
96 
97     if (n->kind() == aten::conv2d) {
98       if (isTensorTypeCPU(n)) {
99         insertPrePackedConvOpForNode(n);
100       }
101     }
102   }
103   EliminateDeadCode(b);
104 }
105 
insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph> & graph)106 static void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
107   insertPrePackedConvOp(graph->block());
108 }
109 
insertMkldnnPrePackedOps(std::shared_ptr<Graph> & graph)110 static void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
111   insertMkldnnPrePackedConv2dOp(graph);
112 }
113 
FuseReluWithPackedOps(std::shared_ptr<Graph> & graph)114 static void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
115   auto conv_op_rstring = at::jit::CodeTemplate(R"(
116     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
117           %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
118         %packed_weight_bias = mkldnn_prepacked::conv2d_prepack(
119             %weight, %bias, %stride, %padding, %dilation, %groups,
120             %input_size, %dummy_attr)
121         %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
122         %res = aten::${op}(%conv2d_res)
123         return (%res))");
124 
125   auto conv_op_fused_rstring = at::jit::CodeTemplate(R"(
126     graph(%input, %weight, %bias, %stride:int[], %padding:int[],
127           %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
128         %attr: str = prim::Constant[value="${op_attr}"]()
129         %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack(
130             %weight, %bias, %stride, %padding, %dilation, %groups,
131             %input_size, %attr)
132         %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
133         return (%res))");
134 
135   for (auto const& it : mkldnn::fusion_rewrite_map) {
136     std::string op = it.first;
137     if (op == std::string("none")) {
138       continue;
139     }
140 
141     at::jit::TemplateEnv env;
142     env.s("op", op);
143 
144     at::jit::TemplateEnv env_fused;
145     env_fused.s("op_attr", op);
146 
147     SubgraphRewriter rewriter;
148     rewriter.RegisterRewritePattern(
149         conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused));
150 
151     auto filters = it.second;
152     rewriter.runOnGraph(graph, filters);
153   }
154 }
155 
// Constant-folds mkldnn_prepacked::conv2d_prepack nodes: when every input of
// a prepack node is constant, the node is executed once here and its output
// (the packed ConvOpContext object) is spliced back into the graph as a
// constant, so no packing work remains at runtime. Recurses into sub-blocks.
static void PrePackingOpsFolder(Block* b) {
  auto is_foldable_op = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"));
  };

  std::unordered_set<Node*> nodes_to_delete;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      PrePackingOpsFolder(block);
    }
    if (is_foldable_op(n)) {
      // Returns nullopt (and the node is kept) unless all inputs are constant.
      auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n);
      if (optional_outputs) {
        auto outputs = optional_outputs.value();
        TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
        Value* prepack_op_value = n->output(0);
        auto graph = n->owningGraph();
        // Insert the constant right where the prepack node sits so all
        // existing uses remain dominated by it.
        WithInsertPoint ins(prepack_op_value->node());
        // copy_to_weak_compilation_ref: store the object with a weak
        // compilation-unit reference (avoids the constant keeping the unit
        // alive via a strong ref — see the helper's own documentation).
        auto weak_class_obj =
            outputs[0].toObject()->copy_to_weak_compilation_ref();
        Value* packed_weight = graph->insertConstant(weak_class_obj)
                                   ->setType(n->output(0)->type());
        prepack_op_value->replaceAllUsesWith(packed_weight);
        nodes_to_delete.insert(n);
      }
    }
  }
  // Two phases: first drop every input edge of all condemned nodes, then
  // destroy them, so no node is destroyed while another still references it.
  for (auto n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete) {
    n->destroy();
  }
}
192 
FoldPrePackingOps(std::shared_ptr<Graph> & graph)193 static void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
194   PrePackingOpsFolder(graph->block());
195 }
196 
// Pass entry point (MKLDNN builds): rewrite CPU aten::conv2d into the
// prepacked MKLDNN ops, fuse supported eltwise post-ops into the packed op's
// attribute, then constant-fold the prepack nodes. Each stage is bracketed by
// GRAPH_DEBUG dumps of the whole graph for pass debugging.
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n",
      *graph);
  insertMkldnnPrePackedOps(graph);
  GRAPH_DEBUG(
      "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph);
  FuseReluWithPackedOps(graph);
  GRAPH_DEBUG(
      "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph);
  FoldPrePackingOps(graph);
  GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph);
}
210 
211 #else
212 
213 void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
214   GRAPH_DEBUG("MKLDNN Not enabled");
215 }
216 
217 #endif // AT_MKLDNN_ENABLED()
218 
219 } // namespace torch::jit
220