1 #include <ATen/ATen.h>
2 #include <ATen/Config.h>
3 #include <ATen/NativeFunctions.h>
4 #include <ATen/Utils.h>
5 #include <ATen/core/symbol.h>
6 #include <ATen/native/layer_norm.h>
7 #include <c10/core/ScalarType.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/irange.h>
10
11 #include <torch/csrc/jit/ir/alias_analysis.h>
12 #include <torch/csrc/jit/ir/constants.h>
13 #include <torch/csrc/jit/ir/ir.h>
14 #include <torch/csrc/jit/jit_log.h>
15 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
16 #include <torch/csrc/jit/passes/constant_propagation.h>
17 #include <torch/csrc/jit/passes/dead_code_elimination.h>
18 #include <torch/csrc/jit/passes/fold_conv_bn.h>
19 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
20 #include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
21 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
22 #include <torch/csrc/jit/passes/peephole.h>
23 #include <torch/csrc/jit/passes/remove_mutation.h>
24 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
25 #include <torch/csrc/jit/runtime/custom_operator.h>
26 #include <torch/csrc/jit/runtime/operator_options.h>
27 #include <torch/csrc/jit/tensorexpr/types.h>
28 // clang-format off
29 // moving ConvUtils include induces import cycle
30 #include <ATen/native/ConvUtils.h>
31 #include <algorithm>
32 #include <memory>
33 #include <ATen/core/stack.h>
34 #include <c10/core/Layout.h>
35 #include <c10/util/StringUtil.h>
36
37 #if AT_MKLDNN_ENABLED()
38 #include <ATen/CPUFunctions.h>
39 #include <dnnl_types.h>
40 #include <ATen/native/mkldnn/Utils.h>
41 #include <ATen/native/mkldnn/MKLDNNCommon.h>
42 #include <ideep.hpp>
43 #endif
44
45 // clang-format on
46
47 namespace torch::jit {
48
49 #if AT_MKLDNN_ENABLED()
50
51 using Tensor = at::Tensor;
52
53 namespace {
54
55 c10::AliasAnalysisKind aliasAnalysisFromSchema() {
56 return AliasAnalysisKind::FROM_SCHEMA;
57 }
58
59 using ValueSet = std::unordered_set<Value*>;
60 using ValueSetPtr = std::shared_ptr<std::unordered_set<Value*>>;
61
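// Returns the node at which `v` is last used (or `v`'s defining node if it has
// no uses); used below to compute the liveness of aliasing sets.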
62 Node* getLastUse(Value* v) {
63 auto last_use_node = v->node();
64 for (const auto& use : v->uses()) {
65 if (use.user->isAfter(last_use_node)) {
66 last_use_node = use.user;
67 }
68 }
69 return last_use_node;
70 }
71
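// Unions the aliasing sets of `existing` and `new_v`: every value in `new_v`'s
// set is inserted into `existing`'s set and remapped to point at it.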
72 void merge_sets(
73 std::unordered_map<Value*, ValueSetPtr>& alias_mapping,
74 Value* existing,
75 Value* new_v) {
76 if (alias_mapping[existing] == alias_mapping[new_v]) {
77 return;
78 }
79 auto existing_set = alias_mapping[existing];
80 auto set_to_remove = alias_mapping[new_v];
81 for (auto it = set_to_remove->begin(); it != set_to_remove->end(); it++) {
82 existing_set->insert(*it);
83 alias_mapping[*it] = existing_set;
84 }
85 }
86
87 // no uses of tensors in container types
88 void assertNonTensorTypeDoesNotContainTensors(TypePtr type) {
89 if (type->cast<TensorType>()) {
90 return;
91 }
92 for (const auto& t : type->containedTypes()) {
93 TORCH_INTERNAL_ASSERT(!t->cast<TensorType>());
94 }
95 }
96
97 void InplaceMKLDNNSubgraph(std::shared_ptr<Graph> graph) {
98 // This function first calculates aliasing sets,
99 // then calculates the last node each aliasing set is alive for.
100 // Then we go through each node; if a node has an equivalent
101 // inplace variant and the aliasing set for its input is dead after this node, we
102 // inplace it. Then we merge the aliasing sets for the input and output of the
103 // node and extend the liveness of the set. To inplace a node you need to
104 // prove device and dtype of the input and output are the same, which we've
105 // already done, and prove that the output size is the same as the input size,
106 // which is achieved by explicit Broadcast nodes (which we inserted for other
107 // reasons).
108 // The graphs here are simple subgraphs without uses of Tensors in
109 // containers (Lists, GetAttrs, etc)
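// As a rough illustration (IR shown schematically, not verbatim): given
//   %y = aten::relu(%x)
// where %x's aliasing set has no uses after this node, the pass rewrites it to
//   %y = aten::relu_(%x)
// by appending "_" to the op's schema name (see `nodes_to_inplace` below).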
110
111 // CALCULATE ALIASING SETS
112
113 auto aliasDb = std::make_unique<AliasDb>(graph);
114
115 // map from Value to its Aliasing Set
116 std::unordered_map<Value*, ValueSetPtr> alias_mapping;
117 ValueSet set;
118 ValueSetPtr input_set = std::make_shared<ValueSet>(set);
119 for (Value* v : graph->inputs()) {
120 if (v->type()->cast<TensorType>()) {
121 input_set->insert(v);
122 alias_mapping[v] = input_set;
123 } else {
124 assertNonTensorTypeDoesNotContainTensors(v->type());
125 }
126 }
127
128 for (Node* n : graph->nodes()) {
129 for (Value* output : n->outputs()) {
130 if (!output->type()->cast<TensorType>()) {
131 assertNonTensorTypeDoesNotContainTensors(output->type());
132 continue;
133 }
134
135 std::unordered_set<Value*> new_set = {output};
136 alias_mapping[output] = std::make_shared<ValueSet>(new_set);
137 for (Value* input : n->inputs()) {
138 if (aliasDb->mayAlias(input, output)) {
139 merge_sets(alias_mapping, input, output);
140 }
141 }
142 }
143 }
144
145 // CALCULATE ALIASING SET LIVENESS
146
147 // map from aliased set -> last use of set
148 std::unordered_map<ValueSetPtr, Node*> set_liveness;
149 for (auto& set : alias_mapping) {
150 if (set_liveness.count(set.second)) {
151 continue;
152 }
153 Node* last = nullptr;
154 for (const auto& v : *set.second) {
155 auto k = v->node()->kind();
156 if (k == prim::Constant || k == prim::ConstantMKLDNNTensor ||
157 k == prim::Param) {
158 last = graph->return_node();
159 continue;
160 }
161
162 auto last_use = getLastUse(v);
163 if (!last || last_use->isAfter(last)) {
164 last = last_use;
165 }
166 }
167 set_liveness[set.second] = last;
168 }
169
170 // REUSING MEMORY BY REINPLACING NODES
171 std::vector<Node*> nodes_to_inplace;
172
173 auto add_to_inplace_set = [&](Node* node) {
174 // defer making the inplacing change because that would invalidate the old
175 // Node output Value*
176 nodes_to_inplace.push_back(node);
177 TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
178 auto output_liveness_end =
179 set_liveness[alias_mapping[node->outputs().at(0)]];
180 merge_sets(alias_mapping, node->inputs().at(0), node->output());
181 set_liveness[alias_mapping[node->output()]] = output_liveness_end;
182 };
183
184 for (Node* node : graph->nodes()) {
185 auto k = node->kind();
186 if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
187 k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
188 k == prim::MKLDNNHardTanh || k == aten::tanh ||
189 k == prim::MKLDNNClamp || k == Symbol::prim("MKLDNNScalarMul") ||
190 k == Symbol::prim("MKLDNNLayerNorm")) {
191 if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
192 continue;
193 }
194 add_to_inplace_set(node);
195 } else if (k == aten::mul || k == aten::add) {
196 // the binary operators (add/mul) are commutative and only take tensor
197 // inputs, so we can inplace either the first or second input
198 int64_t reusable_value_index = -1;
199 for (const auto i : c10::irange(2)) {
200 TORCH_INTERNAL_ASSERT(node->inputs().at(i)->type()->cast<TensorType>());
201 if (!set_liveness[alias_mapping[node->inputs().at(i)]]->isAfter(node)) {
202 reusable_value_index = i;
203 break;
204 }
205 }
206
207 if (reusable_value_index == -1) {
208 continue;
209 }
210
211 if (reusable_value_index == 1) {
212 node->insertInput(0, node->inputs().at(1));
213 node->removeInput(2);
214 }
215 add_to_inplace_set(node);
216 }
217 }
218
219 for (Node* node : nodes_to_inplace) {
220 node->replaceWithNewSymbol(
221 Symbol::fromQualString(node->schema().name() + "_"));
222 node->destroy();
223 }
224 }
225
226 // This is a factory function that creates an Operation that takes
227 // MKLDNN tensors and unpacks them into 1D contiguous tensors that we can
228 // run aten operations on. The precondition for using this function is that the
229 // aten operation in `aten_op` should be an identity for zero inputs. In other
230 // words, it should satisfy `aten_op(0) == 0`. The reason for this precondition has to
231 // do with the blocked formats MKLDNN uses to lay out tensor elements (nChw8c, nChw16c,
232 // etc). It splits the channel dimension into chunks of 8/16 and makes it the
233 // innermost dimension. Whenever the channel dim isn't divisible by 8/16, the
234 // innermost dimension is padded with 0s. The precondition `aten_op(0) == 0`
235 // allows us to avoid any special casing of padded elements.
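// For example (illustrative): hardswish(0) == 0, and hardtanh(0) == 0 whenever
// min_val <= 0 <= max_val, so the padded zero elements stay zero and we can run
// the aten op directly over the raw, physically padded buffer.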
236 Operation createUnaryOp(
237 std::function<void(at::Tensor output, at::Tensor input)> aten_op,
238 bool inplace = false) {
239 return [aten_op, inplace](Stack& stack) {
240 auto a = pop(stack).toTensor();
241 c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
242 // we cast `a` to an `ideep::tensor`, so we can get at its descriptor,
243 // which we then use to set up the `out` tensor with the same props as `a`
244 auto a_it = at::native::itensor_from_mkldnn(a);
245 auto mkldnn_raw_data = a_it.get_data_handle();
246 auto a_options_with_strided = a.options().layout(c10::kStrided);
247
248 // a tensor's physical size could be bigger than its logical one
249 // `a_it.get_desc().get_size()` returns the real physical size (in bytes)
250 // we use it to compute `nelem` for `aten` ops
251 auto nelem = static_cast<int64_t>(
252 a_it.get_desc().get_size() / elementSize(a.scalar_type()));
253 // we also wrap `a` storage into an aten tensor
254 auto in_aten =
255 at::from_blob(mkldnn_raw_data, {nelem}, a_options_with_strided);
256
257 auto out_raw_data = mkldnn_raw_data;
258 auto out = a;
259 if (!inplace) {
260 // `a_it.get_desc()` will allocate a tensor
261 // of the right physical size.
262 auto it_empty = ideep::tensor(a_it.get_desc());
263 TORCH_INTERNAL_ASSERT(it_empty.get_desc() == a_it.get_desc());
264 out = at::native::new_with_itensor_mkldnn(
265 std::move(it_empty),
266 c10::optTypeMetaToScalarType(a.options().dtype_opt()),
267 a.options().device_opt());
268
269 out_raw_data = at::native::itensor_from_mkldnn(out).get_data_handle();
270 }
271
272 TORCH_INTERNAL_ASSERT(
273 a_it.get_desc().get_size() % elementSize(a.scalar_type()) == 0);
274
275 auto out_aten = at::from_blob(
276 out_raw_data, {static_cast<int64_t>(nelem)}, a_options_with_strided);
277 aten_op(out_aten, in_aten);
278 push(stack, out);
279 };
280 }
281
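// Pops the prim::MKLDNNLayerNorm(_) arguments off the stack (input,
// normalized_shape, weight, bias, eps, cudnn_enable) and dispatches to the
// mkldnn layer_norm kernel; only the normalized output is pushed back.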
282 void MKLDNNLayerNormOp(Stack& stack, bool inplace) {
283 c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
284
285 // enable_cudnn not used
286 pop(stack);
287 auto eps = pop(stack).toDouble();
288
289 Tensor bias{};
290 Tensor weight{};
291 auto bias_ival = pop(stack);
292 TORCH_INTERNAL_ASSERT(bias_ival.isTensor());
293 bias = bias_ival.toTensor();
294
295 auto weight_ival = pop(stack);
296 TORCH_INTERNAL_ASSERT(weight_ival.isTensor());
297 weight = weight_ival.toTensor();
298
299 auto shape = pop(stack).toDimVector();
300 auto input = pop(stack).toTensor();
301
302 auto [dst, mean, rstd] =
303 at::native::mkldnn_layer_norm_last_index_weight_bias_f32(
304 input, shape, weight, bias, eps, inplace);
305 push(stack, dst);
306 }
307
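// Broadcasts two MKLDNN tensors to a common size before a binary op: if the
// sizes already match, both are pushed unchanged; if only a reshape is needed
// we reshape; otherwise we round-trip through to_dense -> expand -> to_mkldnn,
// and finally try to reorder a public-format arg to match the other arg's
// blocked descriptor (see the heuristic note below).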
308 Operation BroadOp(const Node* node) {
309 return [](Stack& stack) {
310 auto b = pop(stack).toTensor();
311 auto a = pop(stack).toTensor();
312 auto b_size = b.sizes();
313 auto a_size = a.sizes();
314 if (a_size.equals(b_size)) {
315 // TODO: follow up with MKLDNN on the best way
316 // to handle perf-incompatible formats
317 push(stack, a, b);
318 return;
319 } else {
320 auto out_size = at::infer_size(a_size, b_size);
321 int64_t out_numel = out_size[0];
322 for (size_t i = 1, end = out_size.size(); i < end; ++i) {
323 out_numel = out_numel * out_size[i];
324 }
325
326 auto exp_a = a;
327 auto exp_b = b;
328 int stacked = 0;
329 // mkldnn tensors only support reshape, not expand or view operators
330 if (a_size.equals(out_size)) {
331 push(stack, a);
332 ++stacked;
333 } else if (out_numel == a.numel()) {
334 exp_a = a.reshape(out_size);
335 } else {
336 // TODO: consider initializing to a blocked layout
337 // directly if needed
338 exp_a = a.to_dense().expand(out_size).to_mkldnn();
339 }
340
341 if (b_size.equals(out_size)) {
342 push(stack, b);
343 ++stacked;
344 } else if (out_numel == b.numel()) {
345 exp_b = b.reshape(out_size);
346 } else {
347 exp_b = b.to_dense().expand(out_size).to_mkldnn();
348 }
349
350 if (stacked < 2) {
351 if (stacked == 1) {
352 pop(stack);
353 }
354 // If one of the inputs was expanded and converted to nchw/nhwc
355 // we might end up in a very bad spot if the second argument
356 // is in a blocked format. In this case, MKLDNN uses its
357 // reference implementation for a binary operation that follows
358 // these broadcasts and it could be up to ~100x slower.
359 // We use a very simple heuristic to convert an arg in nchw
360 // to the blocked format of the other argument.
361 c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
362 auto a_it = at::native::itensor_from_mkldnn(exp_a);
363 auto b_it = at::native::itensor_from_mkldnn(exp_b);
364
365 // `is_public_format` means a tensor's physical layout isn't an MKLDNN
366 // blocked layout, e.g. nchw or nhwc but not nChw8c
367 if (!a_it.is_public_format()) {
368 if (b_it.is_public_format()) {
369 b_it = b_it.reorder_if_differ_in(a_it.get_desc());
370 }
371 } else if (!b_it.is_public_format()) {
372 if (a_it.is_public_format()) {
373 a_it = a_it.reorder_if_differ_in(b_it.get_desc());
374 }
375 }
376
377 auto a_options = exp_a.options();
378 auto a_out = at::native::new_with_itensor_mkldnn(
379 std::move(a_it),
380 c10::optTypeMetaToScalarType(a_options.dtype_opt()),
381 a_options.device_opt());
382 push(stack, a_out);
383 auto b_options = exp_b.options();
384 auto b_out = at::native::new_with_itensor_mkldnn(
385 std::move(b_it),
386 c10::optTypeMetaToScalarType(b_options.dtype_opt()),
387 b_options.device_opt());
388 push(stack, b_out);
389 };
390 }
391 };
392 }
393
394 static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
395 const Node* n) {
396 auto min_val = n->f(attr::min_val);
397 auto max_val = n->f(attr::max_val);
398 return [min_val, max_val](at::Tensor output, at::Tensor input) {
399 at::cpu::hardtanh_out(output, input, min_val, max_val);
400 };
401 }
402
403 static std::function<void(at::Tensor output, at::Tensor input)> clamp_helper(
404 const Node* n) {
405 auto min_val = n->f(attr::min_val);
406 auto max_val = n->f(attr::max_val);
407 return [min_val, max_val](at::Tensor output, at::Tensor input) {
408 at::cpu::clamp_out(output, input, min_val, max_val);
409 };
410 }
411
412 // any op added to this registry needs to meet
413 // the precondition: `aten_op(0) == 0`
414 const RegisterOperators MKLDNNHardSwishOpReg({
415 torch::jit::Operator(
416 "prim::MKLDNNHardSwish_(Tensor(a!) self) -> Tensor(a!)",
417 createUnaryOp(
418 [](at::Tensor output, at::Tensor input) {
419 at::cpu::hardswish_out(output, input);
420 },
421 true),
422 AliasAnalysisKind::FROM_SCHEMA),
423 torch::jit::Operator(
424 "prim::MKLDNNHardSigmoid_(Tensor(a!) self) -> Tensor(a!)",
425 createUnaryOp(
426 [](at::Tensor output, at::Tensor input) {
427 at::cpu::hardsigmoid_out(output, input);
428 },
429 true),
430 AliasAnalysisKind::FROM_SCHEMA),
431 torch::jit::Operator(
432 "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
433 [](const Node* n) -> Operation {
434 return createUnaryOp(hardtanh_helper(n), true);
435 },
436 AliasAnalysisKind::FROM_SCHEMA),
437 torch::jit::Operator(
438 "prim::MKLDNNClamp_(Tensor(a!) self) -> Tensor(a!)",
439 [](const Node* n) -> Operation {
440 return createUnaryOp(clamp_helper(n), true);
441 },
442 AliasAnalysisKind::FROM_SCHEMA),
443 torch::jit::Operator(
444 "prim::MKLDNNHardSwish(Tensor a) -> Tensor",
445 createUnaryOp(
446 [](at::Tensor output, at::Tensor input) {
447 at::cpu::hardswish_out(output, input);
448 },
449 false),
450 AliasAnalysisKind::FROM_SCHEMA),
451 torch::jit::Operator(
452 "prim::MKLDNNHardSigmoid(Tensor a) -> Tensor",
453 createUnaryOp(
454 [](at::Tensor output, at::Tensor input) {
455 at::cpu::hardsigmoid_out(output, input);
456 },
457 false),
458 AliasAnalysisKind::FROM_SCHEMA),
459 torch::jit::Operator(
460 "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
461 [](const Node* n) -> Operation {
462 return createUnaryOp(hardtanh_helper(n), false);
463 },
464 AliasAnalysisKind::FROM_SCHEMA),
465 torch::jit::Operator(
466 "prim::MKLDNNClamp(Tensor self) -> Tensor",
467 [](const Node* n) -> Operation {
468 return createUnaryOp(clamp_helper(n), false);
469 },
470 AliasAnalysisKind::FROM_SCHEMA),
471 });
472
473 const RegisterOperators BroadOpReg({
474 torch::jit::Operator(
475 prim::BroadcastMKLDNNTensors,
476 BroadOp,
477 AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
478 });
479
480 const RegisterOperators MKLDNNLayerNormOpReg({
481 torch::jit::Operator(
482 "prim::MKLDNNLayerNorm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor",
483 [](Stack& stack) { MKLDNNLayerNormOp(stack, false); },
484 AliasAnalysisKind::FROM_SCHEMA),
485 torch::jit::Operator(
486 "prim::MKLDNNLayerNorm_(Tensor(a!) input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor(a!)",
487 [](Stack& stack) { MKLDNNLayerNormOp(stack, true); },
488 AliasAnalysisKind::FROM_SCHEMA),
489 });
490
491 Operation ConstantMKLDNNTensorOp(const Node* node) {
492 const auto& t = node->t(attr::value);
493 return [t](Stack& stack) {
494 push(stack, t);
495 return 0;
496 };
497 }
498
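// Computes `out = tensor * scalar` on MKLDNN tensors via ideep's eltwise_linear
// (alpha = scalar); `out` may be the same tensor as `tensor` for the in-place
// variant registered below.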
499 Tensor mkldnn_tensor_scalar_mul(Tensor& tensor, Tensor& out, float scalar) {
500 ideep::tensor& x = at::native::itensor_from_mkldnn(tensor);
501 ideep::tensor& z = at::native::itensor_from_mkldnn(out);
502 ideep::eltwise_forward::compute(
503 x,
504 z,
505 ideep::algorithm::eltwise_linear,
506 ideep::prop_kind::forward_inference,
507 /*alpha*/ scalar);
508 return out;
509 }
510
511 // aten::convolution does a lot of precomputation and dispatching before
512 // mkldnn_convolution is called. By registering the op here we can invoke it
513 // directly and avoid that overhead. Avoiding dispatch overhead for other
514 // operators - relu, add, etc - did not benchmark as speeding up models
515 // noticeably. The additional overhead of `convolution` warrants the custom operator.
516 jit::RegisterOperators reg_fut_ops({
517 jit::Operator(
518 // XXX: this follows the schema convention of conv2d/conv3d, not
519 // aten::mkldnn_convolution, which is different for some reason!
520 "prim::mkldnn_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
521 [](jit::Stack& stack) {
522 int64_t groups = pop(stack).toInt();
523 auto dilation = pop(stack).toIntVector();
524 auto padding = pop(stack).toIntVector();
525 auto stride = pop(stack).toIntVector();
526
527 Tensor bias;
528 IValue bias_ival = pop(stack);
529 if (!bias_ival.isNone()) {
530 bias = bias_ival.toTensor();
531 }
532 Tensor weight = pop(stack).toTensor();
533 Tensor input = pop(stack).toTensor();
534
535 at::AutoDispatchBelowAutograd mode;
536 // aten::convolution takes care of the 0-dim case before calling into
537 // backends
538 if (input.size(0) == 0) {
539 std::vector<int64_t> o = at::native::conv_output_size(
540 input.sizes(), weight.sizes(), padding, stride, dilation);
541 push(
542 stack,
543 at::native::empty_mkldnn(
544 o,
545 c10::optTypeMetaToScalarType(input.options().dtype_opt()),
546 input.options().layout_opt(),
547 input.options().device_opt(),
548 input.options().pinned_memory_opt()));
549 return;
550 }
551 // aten::convolution also checks dtype mismatches
552 TORCH_CHECK(
553 input.options().type_equal(weight.options()),
554 "Input type (",
555 input.toString(),
556 ") and weight type (",
557 weight.toString(),
558 ") should be the same");
559
560 push(
561 stack,
562 at::native::mkldnn_convolution(
563 input, weight, bias, padding, stride, dilation, groups));
564 },
565 aliasAnalysisFromSchema()),
566 // registering as custom operators avoids Scalar->Tensor->Scalar conversion
567 // in default bindings
568 jit::Operator(
569 "prim::MKLDNNScalarMul(Tensor self, Scalar other) -> Tensor",
570 [](jit::Stack& stack) {
571 c10::impl::ExcludeDispatchKeyGuard edkg(
572 c10::autograd_dispatch_keyset);
573 float other = pop(stack).toScalar().toFloat();
574 Tensor self = pop(stack).toTensor();
575 auto out = at::native::empty_mkldnn(
576 self.sizes(),
577 c10::optTypeMetaToScalarType(self.options().dtype_opt()),
578 self.options().layout_opt(),
579 self.options().device_opt(),
580 self.options().pinned_memory_opt());
581
582 mkldnn_tensor_scalar_mul(self, out, other);
583 push(stack, out);
584 },
585 aliasAnalysisFromSchema()),
586 jit::Operator(
587 "prim::MKLDNNScalarMul_(Tensor(a!) self, Scalar other) -> Tensor(a!)",
588 [](jit::Stack& stack) {
589 c10::impl::ExcludeDispatchKeyGuard edkg(
590 c10::autograd_dispatch_keyset);
591 float other = pop(stack).toScalar().toFloat();
592 Tensor self = pop(stack).toTensor();
593 mkldnn_tensor_scalar_mul(self, self, other);
594 push(stack, self);
595 },
596 aliasAnalysisFromSchema()),
597 });
598
599 // This is registered as its own op instead of as prim::Constant because it does
600 // not serialize, which is an invariant of prim::Constant
601 // TODO: make mkldnn tensor serialize...
602 const RegisterOperators MKLDNNConstantOp({
603 torch::jit::Operator(
604 prim::ConstantMKLDNNTensor,
605 ConstantMKLDNNTensorOp,
606 AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
607 });
608
609 Node* createConstantMKLDNNTensorOp(Graph* g, const Tensor& mkldnn_tensor) {
610 TORCH_INTERNAL_ASSERT(mkldnn_tensor.is_mkldnn());
611 auto op = g->create(prim::ConstantMKLDNNTensor);
612 op->t_(attr::value, mkldnn_tensor);
613 return op;
614 }
615
616 bool supportedMKLDNNWeight(const Tensor& weight) {
617 return weight.device().is_cpu() && weight.dtype() == c10::ScalarType::Float &&
618 weight.ndimension() != 0;
619 }
620
621 void replaceInputWithMKLDNNTensor(Node* n, size_t index) {
622 Value* input = n->inputs().at(index);
623 auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn();
624 auto mkldnn_tensor_value =
625 createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor)
626 ->insertBefore(n)
627 ->output();
628 mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn");
629 n->replaceInputWith(input, mkldnn_tensor_value);
630 }
631
632 void replaceInputWithMKLDNNTensor(
633 Node* n,
634 const std::string& name,
635 const at::Tensor& mkldnn_tensor) {
636 Value* input = n->namedInput(name);
637 auto mkldnn_tensor_value =
638 createConstantMKLDNNTensorOp(n->owningGraph(), mkldnn_tensor)
639 ->insertBefore(n)
640 ->output();
641 mkldnn_tensor_value->setDebugName(input->debugName() + "_mkldnn");
642 n->replaceInputWith(input, mkldnn_tensor_value);
643 }
644
645 void replaceInputWithMKLDNNTensor(Node* n, const std::string& name) {
646 Value* input = n->namedInput(name);
647 auto mkldnn_tensor = constant_as<Tensor>(input)->to_mkldnn();
648 replaceInputWithMKLDNNTensor(n, name, mkldnn_tensor);
649 }
650
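// Reorders a conv's constant weight (and bias, if present) into MKLDNN's
// preferred blocked layout ahead of time, using the conv's stride/padding/
// dilation/groups so the reorder matches the conv primitive that will run.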
651 void moveConvWeightsToMKLDNN(Node* conv) {
652 auto conv_w_mkldnn =
653 constant_as<Tensor>(conv->namedInput("weight")).value().to_mkldnn();
654 std::vector<int64_t> padding =
655 toIValue(conv->namedInput("padding"))->toIntVector();
656 std::vector<int64_t> stride =
657 toIValue(conv->namedInput("stride"))->toIntVector();
658 std::vector<int64_t> dilation =
659 toIValue(conv->namedInput("dilation"))->toIntVector();
660 auto groups = constant_as<int64_t>(conv->namedInput("groups")).value();
661
662 if (conv->kind() == aten::conv2d) {
663 conv_w_mkldnn = mkldnn_reorder_conv2d_weight(
664 conv_w_mkldnn, padding, stride, dilation, groups);
665 } else if (conv->kind() == aten::conv3d) {
666 conv_w_mkldnn = mkldnn_reorder_conv3d_weight(
667 conv_w_mkldnn, padding, stride, dilation, groups);
668 } else {
669 TORCH_INTERNAL_ASSERT(false);
670 }
671 replaceInputWithMKLDNNTensor(conv, "weight", conv_w_mkldnn);
672
673 if (conv->namedInput("bias")->type() != NoneType::get()) {
674 replaceInputWithMKLDNNTensor(conv, "bias");
675 }
676 }
677
678 void moveWeightsToMKLDNN(Node* n) {
679 // conv goes through a special pathway so we can call the mkldnn reorder conv
680 // primitive
681 if (n->kind() == aten::conv2d || n->kind() == aten::conv3d) {
682 moveConvWeightsToMKLDNN(n);
683 } else {
684 for (size_t i = 0; i < n->inputs().size(); ++i) {
685 if (!n->input(i)->type()->cast<TensorType>() ||
686 n->input(i)->node()->kind() != prim::Constant) {
687 continue;
688 }
689 replaceInputWithMKLDNNTensor(n, i);
690 }
691 }
692 }
693
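// Replaces `body_node` with a prim::MKLDNNHardTanh/MKLDNNClamp node carrying
// min_val/max_val as attributes (read later by hardtanh_helper/clamp_helper).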
694 static void clamp_node_creator(
695 Node* body_node,
696 c10::Symbol kind,
697 double min_val,
698 double max_val) {
699 WithInsertPoint insert_guard{body_node};
700 auto out_node =
701 body_node->owningGraph()->create({kind}, {body_node->input(0)}, 1);
702 // N.B. we can't use `insert` as it calls `getOperation` (via
703 // `emitBuiltinCall`), which uses the `min_val` and `max_val` attrs that we
704 // haven't set yet.
705 body_node->owningGraph()->insertNode(out_node);
706 auto out_val = out_node->output();
707 out_node->f_(attr::min_val, min_val);
708 out_node->f_(attr::max_val, max_val);
709 out_val->copyMetadata(body_node->output());
710 body_node->output()->replaceAllUsesWith(out_val);
711 body_node->destroy();
712 }
713
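// Rewrites a prim::MKLDNNGroup subgraph to run on MKLDNN tensors: tensor inputs
// are converted with aten::to_mkldnn, tensor outputs back with aten::to_dense,
// constant weights are moved to MKLDNN, and supported aten ops in the body are
// swapped for their prim::MKLDNN* counterparts.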
714 void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
715 auto graph = subgraph_node->owningGraph();
716 Value* none_value = nullptr;
717 {
718 WithInsertPoint guard(subgraph_node);
719 none_value = graph->insertConstant(IValue());
720 }
721 for (size_t i = 0; i < subgraph_node->inputs().size(); ++i) {
722 Value* v = subgraph_node->inputs().at(i);
723 if (!v->type()->cast<TensorType>()) {
724 continue;
725 }
726 auto to_mkldnn =
727 graph->create(c10::Symbol::fromQualString("aten::to_mkldnn"), 1)
728 ->insertBefore(subgraph_node);
729 to_mkldnn->addInput(v);
730 to_mkldnn->addInput(none_value);
731 subgraph_node->replaceInput(i, to_mkldnn->output());
732 }
733
734 for (size_t i = 0; i < subgraph_node->outputs().size(); ++i) {
735 Value* v = subgraph_node->outputs().at(i);
736 if (!v->type()->cast<TensorType>()) {
737 continue;
738 }
739 auto from_mkldnn = graph
740 ->create(
741 c10::Symbol::fromQualString("aten::to_dense"),
742 {v, none_value, none_value})
743 ->insertAfter(subgraph_node);
744 v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output());
745 }
746
747 auto subgraph = SubgraphUtils::getSubgraph(subgraph_node);
748 for (auto it = subgraph->block()->nodes().begin();
749 it != subgraph->block()->nodes().end();) {
750 Node* body_node = *it;
751 it++;
752
753 moveWeightsToMKLDNN(body_node);
754
755 if (body_node->kind() == aten::add ||
756 (body_node->kind() == aten::mul &&
757 body_node->input(1)->type()->cast<TensorType>())) {
758 auto node = body_node->owningGraph()->create(
759 Symbol::prim("BroadcastMKLDNNTensors"),
760 {body_node->inputs().at(0), body_node->inputs().at(1)},
761 2);
762 node->insertBefore(body_node);
763 body_node->replaceInput(0, node->outputs().at(0));
764 body_node->replaceInput(1, node->outputs().at(1));
765 }
766 if (body_node->kind() == aten::mul &&
767 body_node->input(1)->type()->isSubtypeOf(*NumberType::get())) {
768 body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNScalarMul"));
769 body_node->destroy();
770 continue;
771 }
772
773 if (body_node->matches(
774 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor")) {
775 body_node->replaceWithNewSymbol(Symbol::prim("MKLDNNLayerNorm"));
776 body_node->destroy();
777 continue;
778 }
779
780 if (body_node->kind() == aten::hardswish) {
781 body_node->replaceWithNewSymbol(prim::MKLDNNHardSwish);
782 body_node->destroy();
783 continue;
784 }
785
786 if (body_node->kind() == aten::hardsigmoid) {
787 body_node->replaceWithNewSymbol(prim::MKLDNNHardSigmoid);
788 body_node->destroy();
789 continue;
790 }
791
792 if (body_node->kind() == aten::relu6) {
793 clamp_node_creator(body_node, prim::MKLDNNHardTanh, 0., 6.);
794 continue;
795 }
796
797 if (body_node->kind() == aten::hardtanh) {
798 auto min_val =
799 constant_as<double>(body_node->namedInput("min_val")).value();
800 auto max_val =
801 constant_as<double>(body_node->namedInput("max_val")).value();
802 clamp_node_creator(body_node, prim::MKLDNNHardTanh, min_val, max_val);
803 continue;
804 }
805
806 if (body_node->kind() == aten::clamp) {
807 auto min_val = constant_as<double>(body_node->namedInput("min")).value();
808 auto max_val = constant_as<double>(body_node->namedInput("max")).value();
809 clamp_node_creator(body_node, prim::MKLDNNClamp, min_val, max_val);
810 continue;
811 }
812
813 if (body_node->kind() == aten::conv2d ||
814 body_node->kind() == aten::conv3d) {
815 // this node doesn't handle string padding yet...
816 if (!body_node->namedInput("padding")->type()->cast<StringType>()) {
817 body_node->replaceWithNewSymbol(Symbol::prim("mkldnn_convolution"));
818 body_node->destroy();
819 continue;
820 }
821 }
822 }
823 }
824
825 bool nonConstantParameters(Node* n) {
826 for (size_t i = 1; i < n->inputs().size(); i++) {
827 if (n->inputs().at(i)->node()->kind() != prim::Constant) {
828 return true;
829 }
830 }
831 return false;
832 }
833
834 bool frozenMkldnnCompatibleLinearNode(Node* n) {
835 if (nonConstantParameters(n)) {
836 return false;
837 }
838
839 if (n->kind() != aten::linear) {
840 return false;
841 }
842
843 auto weight = constant_as<Tensor>(n->namedInput("weight")).value();
844 return supportedMKLDNNWeight(weight);
845 }
846
847 bool frozenMkldnnCompatibleConvNode(Node* n) {
848 if (nonConstantParameters(n)) {
849 return false;
850 }
851 // mkldnn does not support conv1d
852 // _convolution is rewritten before this pass is invoked
853 if (n->kind() != aten::conv2d && n->kind() != aten::conv3d) {
854 return false;
855 }
856
857 auto weight = constant_as<Tensor>(n->namedInput("weight")).value();
858 return supportedMKLDNNWeight(weight);
859 }
860
861 // [mkldnn perf strategy]
862 // Certain ops - aten::linear, aten::conv2d, aten::conv3d - provide a huge speed
863 // up just by converting the constant weights to MKLDNN AOT, and then at runtime
864 // converting the non-constant input to_mkldnn before the op, and then back to
865 // its original layout after the op. The speed up holds even if you end up
866 // converting the input to_mkldnn and output back to_dense. We start groups of
867 // ops to compute in MKLDNN only from ops that are a strict speedup. Then,
868 // we expand the groups to include operators which are computable in MKLDNN &
869 // are roughly perf equal to eager. We do this in the hopes of joining multiple
870 // fast nodes together, saving to_mkldnn and to_dense conversions.
871 //
872 // MKLDNN only supports float32 inputs for aten::linear, aten::conv2d &
873 // aten::conv3d. We only fuse these nodes if the weights are float32, and then
874 // we only include operators which we can prove will execute in float32. By
875 // fusing topologically we can maintain the invariant that all tensor types in
876 // the graph are floating point. In fusing Conv -> Add -> Relu -> Conv we start
877 // with the first Conv, know that the output is float, and can then safely merge
878 // Add and Relu. If we started with the last Conv, it would be difficult to
879 // prove in our first pass that the Add's inputs were both float32 without first
880 // fusing the first conv.
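// As a rough sketch of the resulting rewrite (schematic, not verbatim IR):
//   %c = aten::conv2d(%x, ...); %r = aten::relu(%c)
// becomes a single prim::MKLDNNGroup whose tensor inputs are wrapped in
// aten::to_mkldnn and whose outputs are converted back with aten::to_dense
// (see ComputeSubgraphInMKLDNN above), so the conv and the relu both run on
// MKLDNN tensors with one layout conversion on each side of the group.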
881
882 class MKLDNNSubgraphSlicer {
883 public:
884 MKLDNNSubgraphSlicer(
885 Block* block,
886 std::shared_ptr<Graph> graph,
887 AliasDb& aliasDb)
888 : block_(block), graph_(std::move(graph)), aliasDb_(aliasDb) {}
889
890 void run() {
891 // We maintain alias db correctness in-place while building up the MKLDNN
892 // subgraphs; however, it is difficult to preserve correctness when
893 // un-inlining those subgraphs. We first recursively construct all
894 // subgraphs and then unmerge them into the graph.
895 buildupSubgraphs();
896 computeSubgraphsInMKLDNN();
897 // Run CSE globally once to eliminate duplicates that may have occurred
898 // while inlining subgraphs.
899 EliminateCommonSubexpression(graph_);
900 }
901
902 void buildupSubgraphs() {
903 // We need to run the slicer multiple times in order to get all merge
904 // opportunities. This is because moveBeforeTopologicallyValid may reorder
905 // nodes to be AFTER the current iteration point. In order to properly
906 // consider those nodes for merging, we need to run the pass until no changes
907 // have been made.
908 //
909 // Example:
910 // c = f(a, b)
911 // d = f(c)
912 // e = f(d) <- iter is here, moving upward
913 // After c.moveBeforeTopologicallyValid(e), we have:
914 // c = f(a, b)
915 // e = f(d) <- iter still here
916 // d = f(c) <- this is the node that was moved to the other side.
917
918 bool any_changed = true;
919 while (any_changed) {
920 any_changed = false;
921 for (auto it = block_->nodes().begin(); it != block_->nodes().end();) {
922 bool changed = false;
923 std::tie(it, changed) = scanNode(*it);
924 any_changed |= changed;
925 }
926 }
927
928 // Construct Subgraphs Recursively
929 for (Node* n : block_->nodes()) {
930 for (auto subBlock : n->blocks()) {
931 MKLDNNSubgraphSlicer(subBlock, graph_, aliasDb_).buildupSubgraphs();
932 }
933 }
934 }
935
936 static bool MKLDNNGroupStart(Node* node) {
937 // if we're already in the process of merging
938 if (node->kind() == prim::MKLDNNGroup) {
939 return true;
940 }
941 // see [mkldnn perf strategy]
942 return frozenMkldnnCompatibleConvNode(node);
943 }
944
945 private:
946 // MKLDNN only supports floats of dimension > 0, so we only support
947 // Tensors that have a known type or were previously verified
948 // to be usable in an MKLDNN Group
949 bool tensorInputIsMKLDNNSupported(Value* v, Node* v_use) {
950 auto const_tensor = constant_as<Tensor>(v);
951 if (const_tensor) {
952 return supportedMKLDNNWeight(*const_tensor);
953 }
954 auto k = v->node()->kind();
955 if (k == prim::MKLDNNGroup || k == prim::ConstantMKLDNNTensor ||
956 k == aten::to_mkldnn) {
957 return true;
958 }
959 for (const auto& use : v->uses()) {
960 if (use.user->kind() == aten::to_mkldnn &&
961 v_use->owningBlock() == use.user->owningBlock()) {
962 return true;
963 }
964 }
965 return false;
966 }
967
968 // We include ops here which are roughly perf-equivalent between mkldnn and
969 // aten (single & multithreaded) and whose inputs & outputs are float32.
970 bool computableInMKLDNN(Node* n) {
971 for (Value* v : n->inputs()) {
972 if (v->type()->cast<TensorType>() &&
973 !(tensorInputIsMKLDNNSupported(v, n))) {
974 return false;
975 }
976 }
977
978 if (n->matches(
979 "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor") &&
980 n->namedInput("weight")->type() != NoneType::get() &&
981 n->namedInput("bias")->type() != NoneType::get()) {
982 auto norm_shape =
983 constant_as<std::vector<int64_t>>(n->namedInput("normalized_shape"));
984 return norm_shape.has_value() && norm_shape->size() == 1;
985 }
986
987 // for unary ops we don't need to prove anything other than that
988 // the input is mkldnn supported
989 switch (n->kind()) {
990 case aten::relu:
991 case aten::relu6:
992 case aten::gelu:
993 case aten::prelu:
994 case aten::sigmoid:
995 case aten::hardsigmoid:
996 case aten::hardswish:
997 case aten::tanh:
998 case aten::batch_norm:
999 case aten::max_pool2d:
1000 case aten::max_pool3d:
1001 case aten::avg_pool2d:
1002 case aten::adaptive_avg_pool2d:
1003 case aten::avg_pool3d:
1004 // case aten::adaptive_max_pool2d: // return tuples which break fusion
1005 // case aten::adaptive_max_pool3d: // return tuples which break fusion
1006 // case aten::adaptive_avg_pool3d: // no ideep binding
1007 return true;
1008 }
1009
1010 if ((n->kind() == aten::hardtanh || n->kind() == aten::clamp) &&
1011 !nonConstantParameters(n)) {
1012 const size_t MIN_INDEX = 1, MAX_INDEX = 2;
1013 auto min_val = constant_as<double>(n->input(MIN_INDEX)).value();
1014 auto max_val = constant_as<double>(n->input(MAX_INDEX)).value();
1015 // we need to maintain the following invariant `pointwise_func(0) == 0`,
1016 // see `createUnaryOp`
1017 if (min_val <= 0. && max_val >= 0.) {
1018 return true;
1019 }
1020 }
1021
1022 if (n->kind() == aten::add) {
1023 // mkldnn doesn't currently support Tensor-Scalar add
1024 for (const auto i : c10::irange(2)) {
1025 if (!n->inputs().at(i)->type()->cast<TensorType>()) {
1026 return false;
1027 }
1028 }
1029 return true;
1030 }
1031 if (n->kind() == aten::mul) {
1032 return n->input(0)->type()->cast<TensorType>() &&
1033 (n->input(1)->type()->cast<TensorType>() ||
1034 n->input(1)->type()->isSubtypeOf(*NumberType::get()));
1035 }
1036
1037 if (n->kind() == aten::dropout) {
1038 auto train = constant_as<bool>(n->namedInput("train")).value();
1039 return train == false;
1040 }
1041 return false;
1042 }
1043
1044 void computeSubgraphsInMKLDNN() {
1045 auto curNode = *block_->nodes().begin();
1046 while (curNode != *block_->nodes().end()) {
1047 auto nextNode = curNode->next();
1048 if (curNode->kind() == prim::MKLDNNGroup) {
1049 ComputeSubgraphInMKLDNN(curNode);
1050 InplaceMKLDNNSubgraph(SubgraphUtils::getSubgraph(curNode));
1051 SubgraphUtils::unmergeSubgraph(curNode);
1052 }
1053 curNode = nextNode;
1054 }
1055 for (Node* n : block_->nodes()) {
1056 for (Block* b : n->blocks()) {
1057 MKLDNNSubgraphSlicer(b, graph_, aliasDb_).computeSubgraphsInMKLDNN();
1058 }
1059 }
1060 }
1061
1062 bool shouldConsiderForMerge(Node* node) {
1063 // if we're already in the process of merging
1064 if (node->kind() == prim::MKLDNNGroup) {
1065 return true;
1066 }
1067 return frozenMkldnnCompatibleLinearNode(node) ||
1068 frozenMkldnnCompatibleConvNode(node) || computableInMKLDNN(node);
1069 }
1070
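// Starting from `producer`, tries to pull each user of its outputs into the
// MKLDNN group (creating the group first if needed). Returns the iterator to
// resume scanning from and whether anything was merged.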
1071 std::pair<graph_node_list::iterator, bool> scanNode(Node* producer) {
1072 if (MKLDNNGroupStart(producer)) {
1073 if (producer->kind() != prim::MKLDNNGroup) {
1074 producer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
1075 producer, prim::MKLDNNGroup, aliasDb_);
1076 }
1077 std::vector<Node*> output_nodes;
1078 for (Value* v : producer->outputs()) {
1079 for (const auto& use : v->uses()) {
1080 output_nodes.push_back(use.user);
1081 }
1082 }
1083 std::sort(
1084 output_nodes.begin(), output_nodes.end(), [&](Node* a, Node* b) {
1085 return a->isBefore(b);
1086 });
1087 for (auto output_node : output_nodes) {
1088 if (auto group = tryMerge(producer, output_node)) {
1089 // we successfully merged, so the new group's `outputs` may have
1090 // changed. So rescan the new group for more merging opportunities.
1091 return std::make_pair(group.value()->iterator()++, true);
1092 }
1093 }
1094 }
1095
1096 return std::make_pair(++producer->iterator(), false);
1097 }
1098
1099 // Try to merge `consumer` into `producer`. If successful, this destroys
1100 // `consumer` and returns the `producer` group.
1101 std::optional<Node*> tryMerge(Node* producer, Node* consumer) {
1102 AT_ASSERT(producer->kind() == prim::MKLDNNGroup);
1103 bool canMerge = shouldConsiderForMerge(consumer) &&
1104 aliasDb_.moveAfterTopologicallyValid(consumer, producer);
1105
1106 if (!canMerge) {
1107 return std::nullopt;
1108 }
1109
1110 SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
1111 consumer, producer, aliasDb_);
1112
1113 return producer;
1114 }
1115
1116 Block* block_;
1117 std::shared_ptr<Graph> graph_;
1118 AliasDb& aliasDb_;
1119 };
1120
1121 bool containsMKLDNNGroup(Block* b) {
1122 for (Node* n : b->nodes()) {
1123 for (Block* block : n->blocks()) {
1124 if (containsMKLDNNGroup(block)) {
1125 return true;
1126 }
1127 }
1128 if (MKLDNNSubgraphSlicer::MKLDNNGroupStart(n)) {
1129 return true;
1130 }
1131 }
1132 return false;
1133 }
1134
1135 } // namespace
1136
1137 void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
1138 GRAPH_DUMP("Before convert frozen ops to mkldnn", graph);
1139 // TODO: replace conv1d with conv2d ?
1140 graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
1141 if (containsMKLDNNGroup(graph->block())) {
1142 // Only remove tensor mutation if we know we're going to create speedups
1143 // with mkldnn. Only supporting functional ops simplifies this pass because
1144 // running an op in mkldnn removes the aliasing relationships that
1145 // previously existed between input and output.
1146 RemoveTensorMutation(graph, [](Node* node_to_functionalize) {
1147 static std::unordered_set<Symbol> mkldnn_ops = {
1148 aten::add_,
1149 aten::mul_,
1150 aten::relu_,
1151 aten::relu6_,
1152 aten::gelu_,
1153 aten::hardswish_,
1154 aten::dropout_,
1155 aten::sigmoid_,
1156 aten::hardsigmoid_,
1157 aten::hardtanh_,
1158 aten::tanh_,
1159 aten::clamp_,
1160 };
1161 return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
1162 });
1163
1164 AliasDb db(graph);
1165 MKLDNNSubgraphSlicer(graph->block(), graph, db).run();
1166 EliminateDeadCode(graph);
1167 GRAPH_DUMP("After convert frozen ops to mkldnn", graph);
1168 } else {
1169 GRAPH_DUMP("No mkldnn compatible frozen nodes", graph);
1170 }
1171 }
1172
1173 #else
1174
1175 void ConvertFrozenOpsToMKLDNN(std::shared_ptr<Graph>& graph) {
1176 GRAPH_DUMP("MKLDNN Not enabled", graph);
1177 }
1178
1179 #endif
1180
1181 } // namespace torch::jit
1182