#include <ATen/Config.h>
#include <ATen/code_template.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/mkldnn_rewrite.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

namespace torch::jit {

#if AT_MKLDNN_ENABLED()

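// Returns the sizes recorded in the TensorType of the idx-th input of n.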
static c10::VaryingShape<int64_t> getSizesOf(Node* n, size_t idx) {
  auto tt = n->input(idx)->type()->cast<TensorType>();
  return tt->sizes();
}

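// Rewrites a single aten::conv2d node into the MKLDNN prepacked form, roughly:
//
//   %ctx = mkldnn_prepacked::conv2d_prepack(
//       %weight, %bias, %stride, %padding, %dilation, %groups,
//       %input_size, "none")
//   %out = mkldnn_prepacked::conv2d_run(%input, %ctx)
//
// The rewrite is skipped when the input or weight is not ChannelsLast
// contiguous, or when the conv is a depthwise conv2d that NNC handles itself.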
static void insertPrePackedConvOpForNode(Node* n) {
  constexpr int POS_INPUT = 0;
  constexpr int POS_WEIGHT = 1;
  if (!tensorexpr::isContiguous(
          n->input(POS_INPUT), at::MemoryFormat::ChannelsLast)) {
    GRAPH_DEBUG(
        "insertPrePackedConvOpForNode: input is not ChannelsLast contiguous");
    return;
  }

  if (!tensorexpr::isContiguous(
          n->input(POS_WEIGHT), at::MemoryFormat::ChannelsLast)) {
    GRAPH_DEBUG(
        "insertPrePackedConvOpForNode: weight is not ChannelsLast contiguous");
    return;
  }

  // Leave depthwise conv2d to NNC
  if (tensorexpr::conv2dIsSupportedJit(n)) {
    GRAPH_DEBUG("insertPrePackedConvOpForNode: leave depthwise conv2d to NNC");
    return;
  }

  WithInsertPoint guard(n);
  auto graph = n->owningGraph();

  auto input_sizes = getSizesOf(n, POS_INPUT);
  IValue input_size_value(*input_sizes.concrete_sizes());
  auto input_size = graph->insertConstant(input_size_value);

  auto prepack_node = graph->create(
      Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"), 1);

  // Forward all conv2d inputs except the input activation (weight, bias,
  // stride, padding, dilation, groups) to the prepack node.
  for (const auto i : c10::irange(1, n->inputs().size())) {
    Value* v = n->input(i);
    prepack_node->addInput(v);
  }
  prepack_node->addInput(input_size);
  auto attr = graph->insertConstant(IValue("none"));
  prepack_node->addInput(attr);
  prepack_node->output()->setType(
      getCustomClass("__torch__.torch.classes.mkldnn.ConvOpContext"));
  graph->insertNode(prepack_node);

  auto prepack_conv = graph->insertNode(
      graph->create(Symbol::fromQualString("mkldnn_prepacked::conv2d_run"), 1));
  prepack_conv->addInput(n->input(0));
  prepack_conv->addInput(prepack_node->output());
  prepack_conv->output()->setType(n->output()->type()->cast<TensorType>());

  n->output()->replaceAllUsesWith(prepack_conv->output());
}

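// Returns true only if every tensor-typed input of the node has a known CPU
// device; inputs without a device, or on a non-CPU device, disqualify it.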
static bool isTensorTypeCPU(Node* node) {
  for (const auto& input : node->inputs()) {
    auto type = input->type()->cast<TensorType>();
    if (!type) {
      continue;
    }
    auto device = type->device();
    if (!device) {
      return false;
    }
    if (!device->is_cpu()) {
      return false;
    }
  }
  return true;
}

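// Recursively walks a block and rewrites every CPU aten::conv2d into the
// prepacked form, then cleans up nodes left dead by the rewrite.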
static void insertPrePackedConvOp(Block* b) {
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      insertPrePackedConvOp(block);
    }

    if (n->kind() == aten::conv2d) {
      if (isTensorTypeCPU(n)) {
        insertPrePackedConvOpForNode(n);
      }
    }
  }
  EliminateDeadCode(b);
}

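// Collection point for MKLDNN prepacking rewrites; conv2d is currently the
// only op that is prepacked.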
static void insertMkldnnPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
  insertPrePackedConvOp(graph->block());
}

static void insertMkldnnPrePackedOps(std::shared_ptr<Graph>& graph) {
  insertMkldnnPrePackedConv2dOp(graph);
}

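// Fuses a supported elementwise op that directly consumes a prepacked conv2d
// into the prepack context: the pattern
//   conv2d_prepack(..., "none") -> conv2d_run -> aten::${op}
// is rewritten to
//   conv2d_prepack(..., "${op}") -> conv2d_run
// so MKLDNN can apply the post-op (e.g. relu) as part of the convolution.
// The set of supported ops and their extra filters comes from
// mkldnn::fusion_rewrite_map.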
static void FuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
  auto conv_op_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
        %packed_weight_bias = mkldnn_prepacked::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %input_size, %dummy_attr)
        %conv2d_res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
        %res = aten::${op}(%conv2d_res)
        return (%res))");

  auto conv_op_fused_rstring = at::jit::CodeTemplate(R"(
    graph(%input, %weight, %bias, %stride:int[], %padding:int[],
          %dilation:int[], %groups:int, %input_size:int[], %dummy_attr:str):
        %attr: str = prim::Constant[value="${op_attr}"]()
        %packed_weight_bias : __torch__.torch.classes.mkldnn.ConvOpContext = mkldnn_prepacked::conv2d_prepack(
            %weight, %bias, %stride, %padding, %dilation, %groups,
            %input_size, %attr)
        %res = mkldnn_prepacked::conv2d_run(%input, %packed_weight_bias)
        return (%res))");

  for (auto const& it : mkldnn::fusion_rewrite_map) {
    std::string op = it.first;
    if (op == std::string("none")) {
      continue;
    }

    at::jit::TemplateEnv env;
    env.s("op", op);

    at::jit::TemplateEnv env_fused;
    env_fused.s("op_attr", op);

    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(
        conv_op_rstring.format(env), conv_op_fused_rstring.format(env_fused));

    auto filters = it.second;
    rewriter.runOnGraph(graph, filters);
  }
}

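// Constant-folds mkldnn_prepacked::conv2d_prepack nodes whose inputs are all
// constants: the prepack op is run once here, its ConvOpContext result is
// inserted back into the graph as a constant, and the original prepack node
// is deleted.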
static void PrePackingOpsFolder(Block* b) {
  auto is_foldable_op = [](const Node* n) -> bool {
    return (
        n->kind() ==
        Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack"));
  };

  std::unordered_set<Node*> nodes_to_delete;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      PrePackingOpsFolder(block);
    }
    if (is_foldable_op(n)) {
      auto optional_outputs = torch::jit::runNodeIfInputsAreConstant(n);
      if (optional_outputs) {
        auto outputs = optional_outputs.value();
        TORCH_CHECK(outputs.size() == 1, "Prepack ops have single output");
        Value* prepack_op_value = n->output(0);
        auto graph = n->owningGraph();
        WithInsertPoint ins(prepack_op_value->node());
        auto weak_class_obj =
            outputs[0].toObject()->copy_to_weak_compilation_ref();
        Value* packed_weight = graph->insertConstant(weak_class_obj)
                                   ->setType(n->output(0)->type());
        prepack_op_value->replaceAllUsesWith(packed_weight);
        nodes_to_delete.insert(n);
      }
    }
  }
  for (auto n : nodes_to_delete) {
    n->removeAllInputs();
  }
  for (auto n : nodes_to_delete) {
    n->destroy();
  }
}

static void FoldPrePackingOps(std::shared_ptr<Graph>& graph) {
  PrePackingOpsFolder(graph->block());
}

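// Entry point of this pass: insert MKLDNN prepacked conv ops, fuse eligible
// elementwise ops into the prepack attribute, then constant-fold the prepack
// calls. A minimal usage sketch (hypothetical call site, not part of this
// file):
//   auto graph = module.get_method("forward").graph();
//   FuseConvWithEltwise(graph);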
void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG(
      "Before insertMkldnnPrePackedOps. Beginning of FuseConvWithEltwise\n",
      *graph);
  insertMkldnnPrePackedOps(graph);
  GRAPH_DEBUG(
      "After insertMkldnnPrePackedOps, before FuseReluWithPackedOps\n", *graph);
  FuseReluWithPackedOps(graph);
  GRAPH_DEBUG(
      "After FuseReluWithPackedOps, before FoldPrePackingOps\n", *graph);
  FoldPrePackingOps(graph);
  GRAPH_DEBUG("After FoldPrePackingOps. End of FuseConvWithEltwise\n", *graph);
}

#else

void FuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("MKLDNN Not enabled");
}

#endif // AT_MKLDNN_ENABLED()

} // namespace torch::jit