xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ATen.h>
2 #include <ATen/Config.h>
3 #include <ATen/NativeFunctions.h>
4 #include <ATen/Utils.h>
5 #include <ATen/core/symbol.h>
6 #include <ATen/native/layer_norm.h>
7 #include <c10/core/ScalarType.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/irange.h>
10 
11 #include <torch/csrc/jit/ir/alias_analysis.h>
12 #include <torch/csrc/jit/ir/constants.h>
13 #include <torch/csrc/jit/ir/ir.h>
14 #include <torch/csrc/jit/jit_log.h>
15 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
16 #include <torch/csrc/jit/passes/constant_propagation.h>
17 #include <torch/csrc/jit/passes/dead_code_elimination.h>
18 #include <torch/csrc/jit/passes/fold_conv_bn.h>
19 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
20 #include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
21 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
22 #include <torch/csrc/jit/passes/peephole.h>
23 #include <torch/csrc/jit/passes/remove_mutation.h>
24 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
25 #include <torch/csrc/jit/runtime/custom_operator.h>
26 #include <torch/csrc/jit/runtime/operator_options.h>
27 #include <torch/csrc/jit/tensorexpr/types.h>
28 // clang-format off
29 // moving ConvUtils include induces import cycle
30 #include <ATen/native/ConvUtils.h>
31 #include <algorithm>
32 #include <memory>
33 #include <ATen/core/stack.h>
34 #include <c10/core/Layout.h>
35 #include <c10/util/StringUtil.h>
36 
37 #if AT_MKLDNN_ENABLED()
38 #include <ATen/CPUFunctions.h>
39 #include <dnnl_types.h>
40 #include <ATen/native/mkldnn/Utils.h>
41 #include <ATen/native/mkldnn/MKLDNNCommon.h>
42 #include <ideep.hpp>
43 #endif
44 
45 // clang-format on
46 
47 namespace torch::jit {
48 
49 #if AT_MKLDNN_ENABLED()
50 
51 using Tensor = at::Tensor;
52 
53 namespace {
54 
55 c10::AliasAnalysisKind aliasAnalysisFromSchema() {
56   return AliasAnalysisKind::FROM_SCHEMA;
57 }
58 
59 using ValueSet = std::unordered_set<Value*>;
60 using ValueSetPtr = std::shared_ptr<std::unordered_set<Value*>>;
61 
62 Node* getLastUse(Value* v) {
63   auto last_use_node = v->node();
64   for (const auto& use : v->uses()) {
65     if (use.user->isAfter(last_use_node)) {
66       last_use_node = use.user;
67     }
68   }
69   return last_use_node;
70 }
71 
72 void merge_sets(
73     std::unordered_map<Value*, ValueSetPtr>& alias_mapping,
74     Value* existing,
75     Value* new_v) {
76   if (alias_mapping[existing] == alias_mapping[new_v]) {
77     return;
78   }
79   auto existing_set = alias_mapping[existing];
80   auto set_to_remove = alias_mapping[new_v];
81   for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) {
82     existing_set->insert(*it);
83     alias_mapping[*it] = existing_set;
84   }
85 }
86 
87 // no uses of tensors in container types
88 void assertNonTensorTypeDoesNotContainTensors(TypePtr type) {
89   if (type->cast<TensorType>()) {
90     return;
91   }
92   for (const auto& t : type->containedTypes()) {
93     TORCH_INTERNAL_ASSERT(!t->cast<TensorType>());
94   }
95 }
96 
97 void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
98   // This function first calculates aliasing sets,
99   // then calculates the last node each aliasing set is alive for.
100   // Then we go through each node; if it's a node which has an equivalent
101   // inplace node and the aliasing set for its input is dead after this node, we
102   // inplace it. Then we merge the aliasing sets for the input and output of the
103   // node and extend the liveness of the set. To inplace a node you need to
104   // prove device and dtype of the input and output are the same, which we've
105   // already done, and prove that the output size is the same as the input size,
106   // which is achieved by explicit Broadcast nodes (which we inserted for other
107   // reasons).
108   // The graphs here are simple subgraphs without uses of Tensors in
109   // containers (Lists, GetAttrs, etc)
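  //
  // As an illustrative sketch (not IR produced verbatim by this pass): if the
  // aliasing set of %a is not alive past the relu node, then
  //   %b = aten::relu(%a)
  // is rewritten to reuse %a's buffer as
  //   %b = aten::relu_(%a)
  // after which %a and %b share one aliasing set whose liveness extends to the
  // last use of %b.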
110 
111   // CALCULATE ALIASING SETS
112 
113   auto aliasDb = std::make_unique<AliasDb>(graph);
114 
115   // map from Value to its Aliasing Set
116   std::unordered_map<Value*, ValueSetPtr> alias_mapping;
117   ValueSet set;
118   ValueSetPtr input_set = std::make_shared<ValueSet>(set);
119   for (Value* v : graph->inputs()) {
120     if (v->type()->cast<TensorType>()) {
121       input_set->insert(v);
122       alias_mapping[v] = input_set;
123     } else {
124       assertNonTensorTypeDoesNotContainTensors(v->type());
125     }
126   }
127 
128   for (Node* n : graph->nodes()) {
129     for (Value* output : n->outputs()) {
130       if (!output->type()->cast<TensorType>()) {
131         assertNonTensorTypeDoesNotContainTensors(output->type());
132         continue;
133       }
134 
135       std::unordered_set<Value*> new_set = {output};
136       alias_mapping[output] = std::make_shared<ValueSet>(new_set);
137       for (Value* input : n->inputs()) {
138         if (aliasDb->mayAlias(input, output)) {
139           merge_sets(alias_mapping, input, output);
140         }
141       }
142     }
143   }
144 
145   // CALCULATE ALIASING SET LIVENESS
146 
147   // map from aliased set -> last use of set
148   std::unordered_map<ValueSetPtr, Node*> set_liveness;
149   for (auto& set : alias_mapping) {
150     if (set_liveness.count(set.second)) {
151       continue;
152     }
153     Node* last = nullptr;
154     for (const auto& v : *set.second) {
155       auto k = v->node()->kind();
156       if (k == prim::Constant || k == prim::ConstantMKLDNNTensor ||
157           k == prim::Param) {
158         last = graph->return_node();
159         continue;
160       }
161 
162       auto last_use = getLastUse(v);
163       if (!last || last_use->isAfter(last)) {
164         last = last_use;
165       }
166     }
167     set_liveness[set.second] = last;
168   }
169 
170   // REUSING MEMORY BY REINPLACING NODES
171   std::vector<Node*> nodes_to_inplace;
172 
173   auto add_to_inplace_set = [&](Node* node) {
174     // defer making the inplacing change because that would invalidate the old
175     // Node output Value*
176     nodes_to_inplace.push_back(node);
177     TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
178     auto output_liveness_end =
179         set_liveness[alias_mapping[node->outputs().at(0)]];
180     merge_sets(alias_mapping, node->inputs().at(0), node->output());
181     set_liveness[alias_mapping[node->output()]] = output_liveness_end;
182   };
183 
184   for (Node* node : graph->nodes()) {
185     auto k = node->kind();
186     if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
187         k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
188         k == prim::MKLDNNHardTanh || k == aten::tanh ||
189         k == prim::MKLDNNClamp || k == Symbol::prim("MKLDNNScalarMul") ||
190         k == Symbol::prim("MKLDNNLayerNorm")) {
191       if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
192         continue;
193       }
194       add_to_inplace_set(node);
195     } else if (k == aten::mul || k == aten::add) {
196       // the binary operators (add/mul) are commutative and only take tensor
197       // inputs, so we can inplace either the first or second input
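      // e.g. (illustrative) for %c = aten::mul(%a, %b), if only %b's aliasing
      // set is dead here, %b is swapped into input position 0 below so that the
      // inplaced op can write into %b's buffer.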
198       int64_t reusable_value_index = -1;
199       for (const auto i : c10::irange(2)) {
200         TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast<TensorType>());
201         if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) {
202           reusable_value_index = i;
203           break;
204         }
205       }
206 
207       if (reusable_value_index == -1) {
208         continue;
209       }
210 
211       if (reusable_value_index == 1) {
212         node->insertInput(0, node->inputs().at(1));
213         node->removeInput(2);
214       }
215       add_to_inplace_set(node);
216     }
217   }
218 
219   for (Node* node : nodes_to_inplace) {
220     node->replaceWithNewSymbol(
221         Symbol::fromQualString(node->schema().name() + "_"));
222     node->destroy();
223   }
224 }
225 
226 // This is a factory function that creates an Operation that takes
227 // MKLDNN tensors and unpacks them into 1D contiguous tensors that we can
228 // run aten operations on. The precondition for using this function is that the
229 // aten operation in `aten_op` should be an identity for zero inputs. In other
230 // words, it should satisfy `aten_op(0) == 0`. The reason for this precondition
231 // has to do with the blocked formats MKLDNN uses to lay out tensor elements
232 // (nChw8c, nChw16c, etc.). It splits the channel dimension into chunks of 8/16
233 // and makes it the innermost dimension. Whenever the channel dim isn't divisible
234 // by 8/16, the innermost dimension is padded with 0s. The precondition
235 // `aten_op(0) == 0` allows us to avoid any special casing of padded elements.
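// As a concrete sketch of why the precondition matters (sizes here are
// illustrative): with C = 3 stored in nChw8c, each 8-wide channel block holds
// 3 real values followed by 5 zero-padded slots. `aten_op` runs over the whole
// physical buffer (real plus padded elements), so `aten_op(0) == 0` guarantees
// the padded slots stay zero and the result is still a valid blocked tensor.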
236 Operation createUnaryOp(
237     std::function<void(at::Tensor output, at::Tensor input)> aten_op,
238     bool inplace = false) {
239   return [aten_op, inplace](Stack& stack) {
240     auto a = pop(stack).toTensor();
241     c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
242     // we cast `a` to an `ideep::tensor`, so we can get at its descriptor,
243     // which we then use to set up the `out` tensor w/ the same props as `a`
244     auto a_it = at::native::itensor_from_mkldnn(a);
245     auto mkldnn_raw_data = a_it.get_data_handle();
246     auto a_options_with_strided = a.options().layout(c10::kStrided);
247 
248     // tensor's physical size could be bigger than a logical one
249     // `a_it.get_desc().get_size()` returns the real physical size (in bytes)
250     // we use it to compute `nelem` for `aten` ops
251     auto nelem = static_cast<int64_t>(
252         a_it.get_desc().get_size() / elementSize(a.scalar_type()));
253     // we also wrap `a` storage into an aten tensor
254     auto in_aten =
255         at::from_blob(mkldnn_raw_data, {nelem}, a_options_with_strided);
256 
257     auto out_raw_data = mkldnn_raw_data;
258     auto out = a;
259     if (!inplace) {
260       // `a_it.get_desc()` will allocate a tensor
261       // of the right physical size.
262       auto it_empty = ideep::tensor(a_it.get_desc());
263       TORCH_INTERNAL_ASSERT(it_empty.get_desc() == a_it.get_desc());
264       out = at::native::new_with_itensor_mkldnn(
265           std::move(it_empty),
266           c10::optTypeMetaToScalarType(a.options().dtype_opt()),
267           a.options().device_opt());
268 
269       out_raw_data = at::native::itensor_from_mkldnn(out).get_data_handle();
270     }
271 
272     TORCH_INTERNAL_ASSERT(
273         a_it.get_desc().get_size() % elementSize(a.scalar_type()) == 0);
274 
275     auto out_aten = at::from_blob(
276         out_raw_data, {static_cast<int64_t>(nelem)}, a_options_with_strided);
277     aten_op(out_aten, in_aten);
278     push(stack, out);
279   };
280 }
281 
282 void MKLDNNLayerNormOp(Stack& stack, bool inplace) {
283   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
284 
285   // enable_cudnn not used
286   pop(stack);
287   auto eps = pop(stack).toDouble();
288 
289   Tensor bias{};
290   Tensor weight{};
291   auto bias_ival = pop(stack);
292   TORCH_INTERNAL_ASSERT(bias_ival.isTensor());
293   bias = bias_ival.toTensor();
294 
295   auto weight_ival = pop(stack);
296   TORCH_INTERNAL_ASSERT(weight_ival.isTensor());
297   weight = weight_ival.toTensor();
298 
299   auto shape = pop(stack).toDimVector();
300   auto input = pop(stack).toTensor();
301 
302   auto [dst, mean, rstd] =
303       at::native::mkldnn_layer_norm_last_index_weight_bias_f32(
304           input, shape, weight, bias, eps, inplace);
305   push(stack, dst);
306 };
307 
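// BroadOp broadcasts two MKLDNN tensor operands to a common size before a
// binary op. Illustrative example (shapes made up): a = [8, 3, 1, 1] and
// b = [8, 3, 32, 32] infer an output size of [8, 3, 32, 32]; a's numel doesn't
// match, so a is converted to_dense, expanded, and converted back to_mkldnn,
// while b already matches. A heuristic below may then reorder one operand into
// the other's blocked format before both are pushed back on the stack.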
308 Operation BroadOp(const Node* node) {
309   return [](Stack& stack) {
310     auto b = pop(stack).toTensor();
311     auto a = pop(stack).toTensor();
312     auto b_size = b.sizes();
313     auto a_size = a.sizes();
314     if (a_size.equals(b_size)) {
315       // TODO: follow up with MKLDNN on the best way
316       // to handle perf-incompatible formats
317       push(stack, a, b);
318       return;
319     } else {
320       auto out_size = at::infer_size(a_size, b_size);
321       int64_t out_numel = out_size[0];
322       for (size_t i = 1, end = out_size.size(); i < end; ++i) {
323         out_numel = out_numel * out_size[i];
324       }
325 
326       auto exp_a = a;
327       auto exp_b = b;
328       int stacked = 0;
329       // mkldnn tensors only support reshape, not expand or view operators
330       if (a_size.equals(out_size)) {
331         push(stack, a);
332         ++stacked;
333       } else if (out_numel == a.numel()) {
334         exp_a = a.reshape(out_size);
335       } else {
336         // TODO: consider initializing to a blocked layout
337         // directly if needed
338         exp_a = a.to_dense().expand(out_size).to_mkldnn();
339       }
340 
341       if (b_size.equals(out_size)) {
342         push(stack, b);
343         ++stacked;
344       } else if (out_numel == b.numel()) {
345         exp_b = b.reshape(out_size);
346       } else {
347         exp_b = b.to_dense().expand(out_size).to_mkldnn();
348       }
349 
350       if (stacked < 2) {
351         if (stacked == 1) {
352           pop(stack);
353         }
354         // If one of the inputs was expanded and converted to nchw/nhwc
355         // we might end up in a very bad spot if the second argument
356         // is in a blocked format. In this case, MKLDNN uses its
357         // reference implementation for a binary operation that follows
358         // these broadcasts and it could be up to ~100x slower.
359         // We use a very simple heuristic to convert an arg in nchw
360         // to the blocked format of the other argument.
361         c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
362         auto a_it = at::native::itensor_from_mkldnn(exp_a);
363         auto b_it = at::native::itensor_from_mkldnn(exp_b);
364 
365         // `is_public_format` means a tensor's physical layout isn't in an MKLDNN
366         // blocked layout, e.g. it is nchw or nhwc but not nChw8c
367         if (!a_it.is_public_format()) {
368           if (b_it.is_public_format()) {
369             b_it = b_it.reorder_if_differ_in(a_it.get_desc());
370           }
371         } else if (!b_it.is_public_format()) {
372           if (a_it.is_public_format()) {
373             a_it = a_it.reorder_if_differ_in(b_it.get_desc());
374           }
375         }
376 
377         auto a_options = exp_a.options();
378         auto a_out = at::native::new_with_itensor_mkldnn(
379             std::move(a_it),
380             c10::optTypeMetaToScalarType(a_options.dtype_opt()),
381             a_options.device_opt());
382         push(stack, a_out);
383         auto b_options = exp_b.options();
384         auto b_out = at::native::new_with_itensor_mkldnn(
385             std::move(b_it),
386             c10::optTypeMetaToScalarType(b_options.dtype_opt()),
387             b_options.device_opt());
388         push(stack, b_out);
389       };
390     }
391   };
392 }
393 
394 static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
395     const Node* n) {
396   auto min_val = n->f(attr::min_val);
397   auto max_val = n->f(attr::max_val);
398   return [min_val, max_val](at::Tensor output, at::Tensor input) {
399     at::cpu::hardtanh_out(output, input, min_val, max_val);
400   };
401 }
402 
403 static std::function<void(at::Tensor output, at::Tensor input)> clamp_helper(
404     const Node* n) {
405   auto min_val = n->f(attr::min_val);
406   auto max_val = n->f(attr::max_val);
407   return [min_val, max_val](at::Tensor output, at::Tensor input) {
408     at::cpu::clamp_out(output, input, min_val, max_val);
409   };
410 }
411 
412 // any op added to this registry needs to meet
413 // the precondition: `aten_op(0) == 0`
414 const RegisterOperators MKLDNNHardSwishOpReg({
415     torch::jit::Operator(
416         "prim::MKLDNNHardSwish_(Tensor(a!) self) -> Tensor(a!)",
417         createUnaryOp(
418             [](at::Tensor output, at::Tensor input) {
419               at::cpu::hardswish_out(output, input);
420             },
421             true),
422         AliasAnalysisKind::FROM_SCHEMA),
423     torch::jit::Operator(
424         "prim::MKLDNNHardSigmoid_(Tensor(a!) self) -> Tensor(a!)",
425         createUnaryOp(
426             [](at::Tensor output, at::Tensor input) {
427               at::cpu::hardsigmoid_out(output, input);
428             },
429             true),
430         AliasAnalysisKind::FROM_SCHEMA),
431     torch::jit::Operator(
432         "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
433         [](const Node* n) -> Operation {
434           return createUnaryOp(hardtanh_helper(n), true);
435         },
436         AliasAnalysisKind::FROM_SCHEMA),
437     torch::jit::Operator(
438         "prim::MKLDNNClamp_(Tensor(a!) self) -> Tensor(a!)",
439         [](const Node* n) -> Operation {
440           return createUnaryOp(clamp_helper(n), true);
441         },
442         AliasAnalysisKind::FROM_SCHEMA),
443     torch::jit::Operator(
444         "prim::MKLDNNHardSwish(Tensor a) -> Tensor",
445         createUnaryOp(
446             [](at::Tensor output, at::Tensor input) {
447               at::cpu::hardswish_out(output, input);
448             },
449             false),
450         AliasAnalysisKind::FROM_SCHEMA),
451     torch::jit::Operator(
452         "prim::MKLDNNHardSigmoid(Tensor a) -> Tensor",
453         createUnaryOp(
454             [](at::Tensor output, at::Tensor input) {
455               at::cpu::hardsigmoid_out(output, input);
456             },
457             false),
458         AliasAnalysisKind::FROM_SCHEMA),
459     torch::jit::Operator(
460         "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
461         [](const Node* n) -> Operation {
462           return createUnaryOp(hardtanh_helper(n), false);
463         },
464         AliasAnalysisKind::FROM_SCHEMA),
465     torch::jit::Operator(
466         "prim::MKLDNNClamp(Tensor self) -> Tensor",
467         [](const Node* n) -> Operation {
468           return createUnaryOp(clamp_helper(n), false);
469         },
470         AliasAnalysisKind::FROM_SCHEMA),
471 });
472 
473 const RegisterOperators BroadOpReg({
474     torch::jit::Operator(
475         prim::BroadcastMKLDNNTensors,
476         BroadOp,
477         AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
478 });
479 
480 const RegisterOperators MKLDNNLayerNormOpReg({
481     torch::jit::Operator(
482         "prim::MKLDNNLayerNorm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor",
483         [](Stack& stack) { MKLDNNLayerNormOp(stack, false); },
484         AliasAnalysisKind::FROM_SCHEMA),
485     torch::jit::Operator(
486         "prim::MKLDNNLayerNorm_(Tensor(a!) input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor(a!)",
487         [](Stack& stack) { MKLDNNLayerNormOp(stack, true); },
488         AliasAnalysisKind::FROM_SCHEMA),
489 });
490 
491 Operation ConstantMKLDNNTensorOp(const Node* node) {
492   const auto& t = node->t(attr::value);
493   return [t](Stack& stack) {
494     push(stack, t);
495     return 0;
496   };
497 }
498 
499 Tensor mkldnn_tensor_scalar_mul(Tensor& tensor, Tensor& out, float scalar) {
500   ideep::tensor& x = at::native::itensor_from_mkldnn(tensor);
501   ideep::tensor& z = at::native::itensor_from_mkldnn(out);
502   ideep::eltwise_forward::compute(
503       x,
504       z,
505       ideep::algorithm::eltwise_linear,
506       ideep::prop_kind::forward_inference,
507       /*alpha*/ scalar);
508   return out;
509 }
510 
511 // aten::convolution does a lot of precomputation and dispatching before
512 // mkldnn_convolution is called. By registering the op here we can invoke it
513 // directly and avoid that overhead. Avoiding dispatch overhead for other
514 // operators - relu, add, etc. - did not benchmark as speeding up models
515 // noticeably; the additional overhead of `convolution` warrants the custom operator.
516 jit::RegisterOperators reg_fut_ops({
517     jit::Operator(
518         // XXX: this follows the schema convention of conv2d/conv3d, not
519         // aten::mkldnn_convolution, which is different for some reason!
520         "prim::mkldnn_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
521         [](jit::Stack& stack) {
522           int64_t groups = pop(stack).toInt();
523           auto dilation = pop(stack).toIntVector();
524           auto padding = pop(stack).toIntVector();
525           auto stride = pop(stack).toIntVector();
526 
527           Tensor bias;
528           IValue bias_ival = pop(stack);
529           if (!bias_ival.isNone()) {
530             bias = bias_ival.toTensor();
531           }
532           Tensor weight = pop(stack).toTensor();
533           Tensor input = pop(stack).toTensor();
534 
535           at::AutoDispatchBelowAutograd mode;
536           // aten::convolution takes care of 0 dim case before calls into
537           // backends
538           if (input.size(0) == 0) {
539             std::vector<int64_t> o = at::native::conv_output_size(
540                 input.sizes(), weight.sizes(), padding, stride, dilation);
541             push(
542                 stack,
543                 at::native::empty_mkldnn(
544                     o,
545                     c10::optTypeMetaToScalarType(input.options().dtype_opt()),
546                     input.options().layout_opt(),
547                     input.options().device_opt(),
548                     input.options().pinned_memory_opt()));
549             return;
550           }
551           // aten::convolution also checks dtype mismatches
552           TORCH_CHECK(
553               input.options().type_equal(weight.options()),
554               "Input type (",
555               input.toString(),
556               ") and weight type (",
557               weight.toString(),
558               ") should be the same");
559 
560           push(
561               stack,
562               at::native::mkldnn_convolution(
563                   input, weight, bias, padding, stride, dilation, groups));
564         },
565         aliasAnalysisFromSchema()),
566     // registering as custom operators avoids Scalar->Tensor->Scalar conversion
567     // in default bindings
568     jit::Operator(
569         "prim::MKLDNNScalarMul(Tensor self, Scalar other) -> Tensor",
570         [](jit::Stack& stack) {
571           c10::impl::ExcludeDispatchKeyGuard edkg(
572               c10::autograd_dispatch_keyset);
573           float other = pop(stack).toScalar().toFloat();
574           Tensor self = pop(stack).toTensor();
575           auto out = at::native::empty_mkldnn(
576               self.sizes(),
577               c10::optTypeMetaToScalarType(self.options().dtype_opt()),
578               self.options().layout_opt(),
579               self.options().device_opt(),
580               self.options().pinned_memory_opt());
581 
582           mkldnn_tensor_scalar_mul(self, out, other);
583           push(stack, out);
584         },
585         aliasAnalysisFromSchema()),
586     jit::Operator(
587         "prim::MKLDNNScalarMul_(Tensor(a!) self, Scalar other) -> Tensor(a!)",
588         [](jit::Stack& stack) {
589           c10::impl::ExcludeDispatchKeyGuard edkg(
590               c10::autograd_dispatch_keyset);
591           float other = pop(stack).toScalar().toFloat();
592           Tensor self = pop(stack).toTensor();
593           mkldnn_tensor_scalar_mul(self, self, other);
594           push(stack, self);
595         },
596         aliasAnalysisFromSchema()),
597 });
598 
599 // This is registered as its own op instead of as prim::Constant because it does
600 // not serialize, which is an invariant of prim::Constant
601 // TODO: make mkldnn tensor serialize...
602 const RegisterOperators MKLDNNConstantOp({
603     torch::jit::Operator(
604         prim::ConstantMKLDNNTensor,
605         ConstantMKLDNNTensorOp,
606         AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
607 });
608 
609 Node* createConstantMKLDNNTensorOp(Graph* g, const Tensor& mkldnn_tensor) {
610   TORCH_INTERNAL_ASSERT(mkldnn_tensor.is_mkldnn());
611   auto op = g->create(prim::ConstantMKLDNNTensor);
612   op->t_(attr::value, mkldnn_tensor);
613   return op;
614 }
615 
616 bool supportedMKLDNNWeight(const Tensor& weight) {
617   return weight.device().is_cpu() && weight.dtype() == c10::ScalarType::Float &&
618       weight.ndimension() != 0;
619 }
620 
621 void replaceInputWithMKLDNNTensor(Node* n, size_t index) {
622   Value* input = n->inputs().at(index);
623   auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn();
624   auto mkldnn_tensor_value =
625       createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor)
626           ->insertBefore(n)
627           ->output();
628   mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn");
629   n->replaceInputWith(input, mkldnn_tensor_value);
630 }
631 
632 void replaceInputWithMKLDNNTensor(
633     Node* n,
634     const std::string& name,
635     const at::Tensor& mkldnn_tensor) {
636   Value* input = n->namedInput(name);
637   auto mkldnn_tensor_value =
638       createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor)
639           ->insertBefore(n)
640           ->output();
641   mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn");
642   n->replaceInputWith(input, mkldnn_tensor_value);
643 }
644 
645 void replaceInputWithMKLDNNTensor(Node* n, const std::string& name) {
646   Value* input = n->namedInput(name);
647   auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn();
648   replaceInputWithMKLDNNTensor(n, name, mkldnn_tensor);
649 }
650 
651 void moveConvWeightsToMKLDNN(Node* conv) {
652   auto conv_w_mkldnn =
653       constant_as<Tensor>(conv->namedInput("weight")).value().to_mkldnn();
654   std::vector<int64_t> padding =
655       toIValue(conv->namedInput("padding"))->toIntVector();
656   std::vector<int64_t> stride =
657       toIValue(conv->namedInput("stride"))->toIntVector();
658   std::vector<int64_t> dilation =
659       toIValue(conv->namedInput("dilation"))->toIntVector();
660   auto groups = constant_as<int64_t>(conv->namedInput("groups")).value();
661 
662   if (conv->kind() == aten::conv2d) {
663     conv_w_mkldnn = mkldnn_reorder_conv2d_weight(
664         conv_w_mkldnn, padding, stride, dilation, groups);
665   } else if (conv->kind() == aten::conv3d) {
666     conv_w_mkldnn = mkldnn_reorder_conv3d_weight(
667         conv_w_mkldnn, padding, stride, dilation, groups);
668   } else {
669     TORCH_INTERNAL_ASSERT(false);
670   }
671   replaceInputWithMKLDNNTensor(conv, "weight", conv_w_mkldnn);
672 
673   if (conv->namedInput("bias")->type() != NoneType::get()) {
674     replaceInputWithMKLDNNTensor(conv, "bias");
675   }
676 }
677 
678 void moveWeightsToMKLDNN(Node* n) {
679   // conv goes through a special pathway so we can call the mkldnn reorder conv
680   // primitive
681   if (n->kind() == aten::conv2d || n->kind() == aten::conv3d) {
682     moveConvWeightsToMKLDNN(n);
683   } else {
684     for (size_t i = 0; i < n->inputs().size(); ++i) {
685       if (!n->input(i)->type()->cast<TensorType>() ||
686           n->input(i)->node()->kind() != prim::Constant) {
687         continue;
688       }
689       replaceInputWithMKLDNNTensor(n, i);
690     }
691   }
692 }
693 
694 static void clamp_node_creator(
695     Node* body_node,
696     c10::Symbol kind,
697     double min_val,
698     double max_val) {
699   WithInsertPoint insert_guard{body_node};
700   auto out_node =
701       body_node->owningGraph()->create({kind}, {body_node->input(0)}, 1);
702   // N.B. we can't use `insert` as it calls `getOperation` (via
703   // `emitBuiltinCall`) which uses `min_val` and `max_val` attrs which we
704   // haven't set yet.
705   body_node->owningGraph()->insertNode(out_node);
706   auto out_val = out_node->output();
707   out_node->f_(attr::min_val, min_val);
708   out_node->f_(attr::max_val, max_val);
709   out_val->copyMetadata(body_node->output());
710   body_node->output()->replaceAllUsesWith(out_val);
711   body_node->destroy();
712 }
713 
714 void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
715   auto graph = subgraph_node->owningGraph();
716   Value* none_value = nullptr;
717   {
718     WithInsertPoint guard(subgraph_node);
719     none_value = graph->insertConstant(IValue());
720   }
721   for (size_t i = 0; i < subgraph_node->inputs().size(); ++i) {
722     Value* v = subgraph_node->inputs().at(i);
723     if (!v->type()->cast<TensorType>()) {
724       continue;
725     }
726     auto to_mkldnn =
727         graph->create(c10::Symbol::fromQualString("aten::to_mkldnn"), 1)
728             ->insertBefore(subgraph_node);
729     to_mkldnn->addInput(v);
730     to_mkldnn->addInput(none_value);
731     subgraph_node->replaceInput(i, to_mkldnn->output());
732   }
733 
734   for (size_t i = 0; i < subgraph_node->outputs().size(); ++i) {
735     Value* v = subgraph_node->outputs().at(i);
736     if (!v->type()->cast<TensorType>()) {
737       continue;
738     }
739     auto from_mkldnn = graph
740                            ->create(
741                                c10::Symbol::fromQualString("aten::to_dense"),
742                                {v, none_value, none_value})
743                            ->insertAfter(subgraph_node);
744     v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output());
745   }
746 
747   auto subgraph = SubgraphUtils::getSubgraph(subgraph_node);
748   for (auto it = subgraph->block()->nodes().begin();
749        it != subgraph->block()->nodes().end();) {
750     Node* body_node = *it;
751     it++;
752 
753     moveWeightsToMKLDNN(body_node);
754 
755     if (body_node->kind() == aten::add ||
756         (body_node->kind() == aten::mul &&
757          body_node->input(1)->type()->cast<TensorType>())) {
758       auto node = body_node->owningGraph()->create(
759           Symbol::prim("BroadcastMKLDNNTensors"),
760           {body_node->inputs().at(0), body_node->inputs().at(1)},
761           2);
762       node->insertBefore(body_node);
763       body_node->replaceInput(0, node->outputs().at(0));
764       body_node->replaceInput(1, node->outputs().at(1));
765     }
766     if (body_node->kind() == aten::mul &&
767         body_node->input(1)->type()->isSubtypeOf(*NumberType::get())) {
768       body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNScalarMul"));
769       body_node->destroy();
770       continue;
771     }
772 
773     if (body_node->matches(
774             "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")) {
775       body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNLayerNorm"));
776       body_node->destroy();
777       continue;
778     }
779 
780     if (body_node->kind() == aten::hardswish) {
781       body_node->replaceWithNewSymbol(prim::MKLDNNHardSwish);
782       body_node->destroy();
783       continue;
784     }
785 
786     if (body_node->kind() == aten::hardsigmoid) {
787       body_node->replaceWithNewSymbol(prim::MKLDNNHardSigmoid);
788       body_node->destroy();
789       continue;
790     }
791 
792     if (body_node->kind() == aten::relu6) {
793       clamp_node_creator(body_node, prim::MKLDNNHardTanh, 0., 6.);
794       continue;
795     }
796 
797     if (body_node->kind() == aten::hardtanh) {
798       auto min_val =
799           constant_as<double>(body_node->namedInput("min_val")).value();
800       auto max_val =
801           constant_as<double>(body_node->namedInput("max_val")).value();
802       clamp_node_creator(body_node, prim::MKLDNNHardTanh, min_val, max_val);
803       continue;
804     }
805 
806     if (body_node->kind() == aten::clamp) {
807       auto min_val = constant_as<double>(body_node->namedInput("min")).value();
808       auto max_val = constant_as<double>(body_node->namedInput("max")).value();
809       clamp_node_creator(body_node, prim::MKLDNNClamp, min_val, max_val);
810       continue;
811     }
812 
813     if (body_node->kind() == aten::conv2d ||
814         body_node->kind() == aten::conv3d) {
815       // this node doesnt handle string padding yet...
816       if (!body_node->namedInput("padding")->type()->cast<StringType>()) {
817         body_node->replaceWithNewSymbol(Symbol::prim("mkldnn_convolution"));
818         body_node->destroy();
819         continue;
820       }
821     }
822   }
823 }
824 
825 bool nonConstantParameters(Node* n) {
826   for (size_t i = 1; i < n->inputs().size(); i++) {
827     if (n->inputs().at(i)->node()->kind() != prim::Constant) {
828       return true;
829     }
830   }
831   return false;
832 }
833 
834 bool frozenMkldnnCompatibleLinearNode(Node* n) {
835   if (nonConstantParameters(n)) {
836     return false;
837   }
838 
839   if (n->kind() != aten::linear) {
840     return false;
841   }
842 
843   auto weight = constant_as<Tensor>(n->namedInput("weight")).value();
844   return supportedMKLDNNWeight(weight);
845 }
846 
847 bool frozenMkldnnCompatibleConvNode(Node* n) {
848   if (nonConstantParameters(n)) {
849     return false;
850   }
851   // mkldnn does not support conv1d
852   // _convolution is rewritten before this pass is invoked
853   if (n->kind() != aten::conv2d && n->kind() != aten::conv3d) {
854     return false;
855   }
856 
857   auto weight = constant_as<Tensor>(n->namedInput("weight")).value();
858   return supportedMKLDNNWeight(weight);
859 }
860 
861 // [mkldnn perf strategy]
862 // Certain ops - aten::linear, aten::conv2d, aten::conv3d - provide a huge speed
863 // up just by converting the constant weights to MKLDNN AOT, and then at runtime
864 // converting the non-constant input to_mkldnn before the op, and then back to
865 // its original layout after the op. The speed up holds even if you end up
866 // converting the input to_mkldnn and output back to_dense. We start groups of
867 // ops to compute in MKLDNN only from the ops that are a strict speedup. Then,
868 // we expand the groups to include operators which are computable in MKLDNN &
869 // are roughly perf equal to eager. We do this in the hopes of joining multiple
870 // fast nodes together, saving to_mkldnn and to_dense conversions.
871 //
872 // MKLDNN only supports float32 inputs for aten::linear, aten::conv2d &
873 // aten::conv3d. We only fuse these nodes if the weights are float32, and then
874 // we only include operators which we can prove will execute in float32. By
875 // fusing topologically we can maintain the invariant that all tensor types in
876 // the graph are floating point. In fusing Conv -> Add -> Relu -> Conv we start
877 // with the first Conv, know that the output is float, and can then safely merge
878 // Add and Relu. If we started with the last Conv, it would be difficult to
879 // prove in our first pass that the Add's inputs were both float32 without first
880 // fusing the first conv.
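//
// A rough sketch of the end state (illustrative IR, value names made up):
//   %y = aten::conv2d(%x, %w, %b, ...)
//   %z = aten::relu(%y)
// becomes a single prim::MKLDNNGroup subgraph whose weights are
// prim::ConstantMKLDNNTensor constants, with aten::to_mkldnn inserted on the
// non-constant input and aten::to_dense on the output, so a layout conversion
// is paid only once on each side of the group.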
881 
882 class MKLDNNSubgraphSlicer {
883  public:
884   MKLDNNSubgraphSlicer(
885       Block* block,
886       std::shared_ptr<Graph> graph,
887       AliasDb& aliasDb)
888       : block_(block), graph_(std::move(graph)), aliasDb_(aliasDb) {}
889 
890   void run() {
891     // We maintain alias db correctness in-place while building up the MKLDNN
892     // subgraphs; however, it is difficult to preserve correctness when
893     // un-inlining them. We first recursively construct all
894     // subgraphs and then unmerge them into the graph
895     buildupSubgraphs();
896     computeSubgraphsInMKLDNN();
897     // Run CSE globally once to eliminate duplicates that may have occurred
898     // while inlining subgraphs.
899     EliminateCommonSubexpression(graph_);
900   }
901 
902   void buildupSubgraphs() {
903     // We need to run the slicer multiple times in order to get all merge
904     // opportunities. This is because moveBeforeTopologicalValid may reorder
905     // nodes to be AFTER the current iteration point. In order to properly
906     // consider those nodes for merging, we need to run the pass until no changes
907     // have been made.
908     //
909     // Example:
910     //   c = f(a, b)
911     //   d = f(c)
912     //   e = f(d)  <- iter is here, moving upward
913     // After c.moveBeforeTopologicallyValid(e), we have:
914     //   c = f(a, b)
915     //   e = f(d)  <- iter still here
916     //   d = f(c)  <- this node was moved to the other side.
917 
918     bool any_changed = true;
919     while (any_changed) {
920       any_changed = false;
921       for (auto it = block_->nodes().begin(); it != block_->nodes().end();) {
922         bool changed = false;
923         std::tie(it, changed) = scanNode(*it);
924         any_changed |= changed;
925       }
926     }
927 
928     // Construct Subgraphs Recursively
929     for (Node* n : block_->nodes()) {
930       for (auto subBlock : n->blocks()) {
931         MKLDNNSubgraphSlicer(subBlock, graph_, aliasDb_).buildupSubgraphs();
932       }
933     }
934   }
935 
936   static bool MKLDNNGroupStart(Node* node) {
937     // if we're already in the process of merging
938     if (node->kind() == prim::MKLDNNGroup) {
939       return true;
940     }
941     // see [mkldnn perf strategy]
942     return frozenMkldnnCompatibleConvNode(node);
943   }
944 
945  private:
946   // MKLDNN only supports float tensors of dimension > 0, so we only support
947   // Tensors that have a known type or were previously verified
948   // to be usable in an MKLDNN Group
949   bool tensorInputIsMKLDNNSupported(Value* v, Node* v_use) {
950     auto const_tensor = constant_as<Tensor>(v);
951     if (const_tensor) {
952       return supportedMKLDNNWeight(*const_tensor);
953     }
954     auto k = v->node()->kind();
955     if (k == prim::MKLDNNGroup || k == prim::ConstantMKLDNNTensor ||
956         k == aten::to_mkldnn) {
957       return true;
958     }
959     for (const auto& use : v->uses()) {
960       if (use.user->kind() == aten::to_mkldnn &&
961           v_use->owningBlock() == use.user->owningBlock()) {
962         return true;
963       }
964     }
965     return false;
966   }
967 
968   // We include ops here which are roughly perf-equivalent between mkldnn and
969   // aten (single & multithreaded) and whose inputs & outputs are float32.
970   bool computableInMKLDNN(Node* n) {
971     for (Value* v : n->inputs()) {
972       if (v->type()->cast<TensorType>() &&
973           !(tensorInputIsMKLDNNSupported(v, n))) {
974         return false;
975       }
976     }
977 
978     if (n->matches(
979             "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor") &&
980         n->namedInput("weight")->type() != NoneType::get() &&
981         n->namedInput("bias")->type() != NoneType::get()) {
982       auto norm_shape =
983           constant_as<std::vector<int64_t>>(n->namedInput("normalized_shape"));
984       return norm_shape.has_value() && norm_shape->size() == 1;
985     }
986 
987     // for unary ops we don't need to prove anything other than
988     // that the input is mkldnn supported
989     switch (n->kind()) {
990       case aten::relu:
991       case aten::relu6:
992       case aten::gelu:
993       case aten::prelu:
994       case aten::sigmoid:
995       case aten::hardsigmoid:
996       case aten::hardswish:
997       case aten::tanh:
998       case aten::batch_norm:
999       case aten::max_pool2d:
1000       case aten::max_pool3d:
1001       case aten::avg_pool2d:
1002       case aten::adaptive_avg_pool2d:
1003       case aten::avg_pool3d:
1004         // case aten::adaptive_max_pool2d: // return tuples which break fusion
1005         // case aten::adaptive_max_pool3d: // return tuples which break fusion
1006         // case aten::adaptive_avg_pool3d: // no ideep binding
1007         return true;
1008     }
1009 
1010     if ((n->kind() == aten::hardtanh || n->kind() == aten::clamp) &&
1011         !nonConstantParameters(n)) {
1012       const size_t MIN_INDEX = 1, MAX_INDEX = 2;
1013       auto min_val = constant_as<double>(n->input(MIN_INDEX)).value();
1014       auto max_val = constant_as<double>(n->input(MAX_INDEX)).value();
1015       // we need to maintain the following invariant `pointwise_func(0) == 0`,
1016       // see `createUnaryOp`
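      // e.g. (illustrative values) clamp(x, -1., 1.) maps 0 -> 0 and is ok,
      // while clamp(x, 2., 3.) maps 0 -> 2, which would corrupt the zero
      // padding of blocked layouts, so it is rejected.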
1017       if (min_val <= 0. && max_val >= 0.) {
1018         return true;
1019       }
1020     }
1021 
1022     if (n->kind() == aten::add) {
1023       // mkldnn doesn't currently support Tensor-Scalar add
1024       for (const auto i : c10::irange(2)) {
1025         if (!n->inputs().at(i)->type()->cast<TensorType>()) {
1026           return false;
1027         }
1028       }
1029       return true;
1030     }
1031     if (n->kind() == aten::mul) {
1032       return n->input(0)->type()->cast<TensorType>() &&
1033           (n->input(1)->type()->cast<TensorType>() ||
1034            n->input(1)->type()->isSubtypeOf(*NumberType::get()));
1035     }
1036 
1037     if (n->kind() == aten::dropout) {
1038       auto train = constant_as<bool>(n->namedInput("train")).value();
1039       return train == false;
1040     }
1041     return false;
1042   }
1043 
1044   void computeSubgraphsInMKLDNN() {
1045     auto curNode = *block_->nodes().begin();
1046     while (curNode != *block_->nodes().end()) {
1047       auto nextNode = curNode->next();
1048       if (curNode->kind() == prim::MKLDNNGroup) {
1049         ComputeSubgraphInMKLDNN(curNode);
1050         InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode));
1051         SubgraphUtils::unmergeSubgraph(curNode);
1052       }
1053       curNode = nextNode;
1054     }
1055     for (Node* n : block_->nodes()) {
1056       for (Block* b : n->blocks()) {
1057         MKLDNNSubgraphSlicer(b, graph_, aliasDb_).computeSubgraphsInMKLDNN();
1058       }
1059     }
1060   }
1061 
1062   bool shouldConsiderForMerge(Node* node) {
1063     // if we're already in the process of merging
1064     if (node->kind() == prim::MKLDNNGroup) {
1065       return true;
1066     }
1067     return frozenMkldnnCompatibleLinearNode(node) ||
1068         frozenMkldnnCompatibleConvNode(node) || computableInMKLDNN(node);
1069   }
1070 
1071   std::pair<graph_node_list::iterator, bool> scanNode(Node* producer) {
1072     if (MKLDNNGroupStart(producer)) {
1073       if (producer->kind() != prim::MKLDNNGroup) {
1074         producer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
1075             producer, prim::MKLDNNGroup, aliasDb_);
1076       }
1077       std::vector<Node*> output_nodes;
1078       for (Value* v : producer->outputs()) {
1079         for (const auto& use : v->uses()) {
1080           output_nodes.push_back(use.user);
1081         }
1082       }
1083       std::sort(
1084           output_nodes.begin(), output_nodes.end(), [&](Node* a, Node* b) {
1085             return a->isBefore(b);
1086           });
1087       for (auto output_node : output_nodes) {
1088         if (auto group = tryMerge(producer, output_node)) {
1089           // we successfully merged, so the new group's `outputs` may have
1090           // changed. So rescan the new group for more merging opportunities.
1091           return std::make_pair(group.value()->iterator()++, true);
1092         }
1093       }
1094     }
1095 
1096     return std::make_pair(++producer->iterator(), false);
1097   }
1098 
1099   // Try to merge `consumer` into `producer`. If successful, this destroys
1100   // `consumer` and returns the `producer` group.
1101   std::optional<Node*> tryMerge(Node* producer, Node* consumer) {
1102     AT_ASSERT(producer->kind() == prim::MKLDNNGroup);
1103     bool canMerge = shouldConsiderForMerge(consumer) &&
1104         aliasDb_.moveAfterTopologicallyValid(consumer, producer);
1105 
1106     if (!canMerge) {
1107       return std::nullopt;
1108     }
1109 
1110     SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
1111         consumer, producer, aliasDb_);
1112 
1113     return producer;
1114   }
1115 
1116   Block* block_;
1117   std::shared_ptr<Graph> graph_;
1118   AliasDb& aliasDb_;
1119 };
1120 
1121 bool containsMKLDNNGroup(Block* b) {
1122   for (Node* n : b->nodes()) {
1123     for (Block* block : n->blocks()) {
1124       if (containsMKLDNNGroup(block)) {
1125         return true;
1126       }
1127     }
1128     if (MKLDNNSubgraphSlicer::MKLDNNGroupStart(n)) {
1129       return true;
1130     }
1131   }
1132   return false;
1133 }
1134 
1135 } // namespace
1136 
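// A minimal usage sketch (assuming the usual freezing workflow; illustrative
// only, not part of this file):
//
//   Module frozen = freeze_module(module);
//   auto graph = frozen.get_method("forward").graph();
//   ConvertFrozenOpsToMKLDNN(graph);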
1137 void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
1138   GRAPH_DUMP("Before convert frozen ops to mkldnn", graph);
1139   // TODO: replace conv1d with conv2d ?
1140   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
1141   if (containsMKLDNNGroup(graph->block())) {
1142     // Only remove tensor mutation if we know we're going to create speedups
1143     // with mkldnn. Only supporting functional ops simplifies this pass because
1144     // running an op in mkldnn removes the aliasing relationships that
1145     // previously existed between input and output.
1146     RemoveTensorMutation(graph, [](Node* node_to_functionalize) {
1147       static std::unordered_set<Symbol> mkldnn_ops = {
1148           aten::add_,
1149           aten::mul_,
1150           aten::relu_,
1151           aten::relu6_,
1152           aten::gelu_,
1153           aten::hardswish_,
1154           aten::dropout_,
1155           aten::sigmoid_,
1156           aten::hardsigmoid_,
1157           aten::hardtanh_,
1158           aten::tanh_,
1159           aten::clamp_,
1160       };
1161       return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
1162     });
1163 
1164     AliasDb db(graph);
1165     MKLDNNSubgraphSlicer(graph->block(), graph, db).run();
1166     EliminateDeadCode(graph);
1167     GRAPH_DUMP("After convert frozen ops to mkldnn", graph);
1168   } else {
1169     GRAPH_DUMP("No mkldnn compatible frozen nodes", graph);
1170   }
1171 }
1172 
1173 #else
1174 
1175 void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
1176   GRAPH_DUMP("MKLDNN Not enabled", graph);
1177 }
1178 
1179 #endif
1180 
1181 } // namespace torch::jit
1182