#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using opkind = dnnl::graph::op::kind;

static void fixConvOptionalBias(Node* node) {
  if (node->namedInput("bias")->mustNotBeNone() == false) {
    // Replace non-existent optional bias with const None
    auto g = node->owningGraph();
    auto n = g->createNone();
    auto v = n->insertBefore(node)->output();
    node->replaceInput(2, v);
  }
}

static std::optional<size_t> getDimensions(Value* v) {
  if (v->type()->isSubtypeOf(TensorType::get())) {
    return v->type()->cast<TensorType>()->sizes().size();
  } else {
    return std::nullopt;
  }
}

// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped as
// Wildcards instead. They keep the integration code with PyTorch simple: every
// op is passed to the oneDNN Graph library in the add_op call, so there's no
// need to check beforehand whether an op is supported by oneDNN Graph or not.
// oneDNN Graph ops separated by wildcards don't end up in the same partition.
static Operator makeWildcardOp(Node* node) {
  auto o = Operator(node, opkind::Wildcard);
  // wildcard op contains only topology info
  for (size_t i = 0; i < node->inputs().size(); i++) {
    o.setInput(0, i);
  }
  for (size_t i = 0; i < node->outputs().size(); i++) {
    o.setOutput(i);
  }
  return o;
}

// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph
// op, then we create a wildcard op corresponding to that PyTorch op instead.
#define REQUIRE(cond)                                 \
  if (!(cond)) {                                      \
    GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
    return makeWildcardOp(node);                      \
  }

Operator LlgaGraphHelper::makeEltwiseOp(Node* node, opkind kind) {
  return Operator(node, kind).setInput(0).setOutput(dnnl_graph_, 0);
}

Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) {
  REQUIRE(
      node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
      node->input(1)->type()->isSubtypeOf(TensorType::get()))
  return Operator(node, kind).setInput(0, 1).setOutput(dnnl_graph_, 0);
}

// Map a PyTorch op to its corresponding oneDNN Graph op.
// If mapping isn't possible, then create a wildcard op instead.
// The mapping is done as per the oneDNN Graph op schema defined in
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
Operator LlgaGraphHelper::createOperator(Node* node) {
  auto nodeKind = node->kind();
  // We're using an if-else chain instead of a switch statement
  // because we would soon be adding custom ops with function schemas.
  // We would have to use Symbol::fromQualString at that time anyway,
  // but we are okay with this choice, since this code is not on the hot path.
  if (nodeKind == Symbol::fromQualString("aten::conv2d")) {
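    // Note: the trailing integer in each setInput/setAttr call below is the
    // position of the corresponding argument in the aten schema. For
    // aten::conv2d(input, weight, bias, stride, padding, dilation, groups),
    // argument 4 (padding) supplies both pads_begin and pads_end.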
    fixConvOptionalBias(node);
    return Operator(node, opkind::Convolution)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
        .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 6)
        .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (
      (nodeKind == Symbol::fromQualString("aten::_convolution")) ||
      (nodeKind == Symbol::fromQualString("aten::convolution"))) {
    bool transposed = toIValue(node->namedInput("transposed"))->toBool();
    REQUIRE(!transposed);
    return Operator(node, opkind::Convolution)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
        .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 8)
        .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::batch_norm")) {
    auto training = toIValue(node->namedInput("training"));
    REQUIRE(training.has_value()); // cannot get training status in script mode
    if (!training->toBool()) {
      return Operator(node, opkind::BatchNormInference)
          .setInput(0, 1, 2, 3, 4)
          .setOutput(dnnl_graph_, 0)
          .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 7)
          .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
    }
  } else if (nodeKind == Symbol::fromQualString("aten::layer_norm")) {
    auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
    REQUIRE(normalized_shape->toIntList().size() == 1);
    return Operator(node, opkind::LayerNorm)
        .setInput(0, 2, 3)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 4)
        .setAttr(dnnl::graph::op::attr::keep_stats, false);
  } else if (nodeKind == Symbol::fromQualString("aten::addmm")) {
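    // aten::addmm(self, mat1, mat2, beta, alpha) computes
    // beta * self + alpha * (mat1 @ mat2). With alpha == beta == 1, it maps to
    // a MatMul with bias (inputs mat1, mat2, self, where self acts as the
    // bias); with alpha == 1 and beta == 0, the bias term vanishes and self is
    // dropped. Other scalar combinations fall through to the wildcard path.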
    auto alpha = toIValue(node->namedInput("alpha"));
    auto beta = toIValue(node->namedInput("beta"));
    if (alpha.has_value() && beta.has_value()) {
      if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 1.0)) {
        return Operator(node, opkind::MatMul)
            .setInput(1, 2, 0)
            .setOutput(dnnl_graph_, 0);
      } else if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 0.0)) {
        return Operator(node, opkind::MatMul)
            .setInput(1, 2)
            .setOutput(dnnl_graph_, 0);
      }
    }
  } else if (nodeKind == Symbol::fromQualString("aten::add"))
    return makeBinaryOp(node, opkind::Add);
  else if (nodeKind == Symbol::fromQualString("aten::mul"))
    return makeBinaryOp(node, opkind::Multiply);
  else if (nodeKind == Symbol::fromQualString("aten::div"))
    return makeBinaryOp(node, opkind::Divide);
  else if (nodeKind == Symbol::fromQualString("aten::tanh"))
    return makeEltwiseOp(node, opkind::Tanh);
  else if (nodeKind == Symbol::fromQualString("aten::relu"))
    return makeEltwiseOp(node, opkind::ReLU);
  else if (nodeKind == Symbol::fromQualString("aten::elu"))
    return makeEltwiseOp(node, opkind::Elu)
        .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
  else if (nodeKind == Symbol::fromQualString("aten::sigmoid"))
    return makeEltwiseOp(node, opkind::Sigmoid);
  else if (nodeKind == Symbol::fromQualString("aten::gelu"))
    return makeEltwiseOp(node, opkind::GELU);
  else if (nodeKind == Symbol::fromQualString("aten::round"))
    return makeEltwiseOp(node, opkind::Round);
  else if (nodeKind == Symbol::fromQualString("aten::exp"))
    return makeEltwiseOp(node, opkind::Exp);
  else if (nodeKind == Symbol::fromQualString("aten::sqrt"))
    return makeEltwiseOp(node, opkind::Sqrt);
  else if (nodeKind == Symbol::fromQualString("aten::abs"))
    return makeEltwiseOp(node, opkind::Abs);
  else if (nodeKind == Symbol::fromQualString("aten::square"))
    return makeEltwiseOp(node, opkind::Square);
  else if (nodeKind == Symbol::fromQualString("aten::clamp")) {
    // The PyTorch API already checks that min & max are not both None,
    // but we check it here nevertheless.
    auto clamp_min = toIValue(node->input(1));
    auto clamp_max = toIValue(node->input(2));
    REQUIRE(!(clamp_max->isNone() && clamp_min->isNone()));
    auto clamp_min_value = (clamp_min->isNone())
        ? -std::numeric_limits<float>::infinity()
        : Operator::ScalarToFloat(node, 1);
    auto clamp_max_value = (clamp_max->isNone())
        ? std::numeric_limits<float>::infinity()
        : Operator::ScalarToFloat(node, 2);
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, clamp_min_value)
        .setAttr(dnnl::graph::op::attr::max, clamp_max_value);
  } else if (nodeKind == Symbol::fromQualString("aten::hardtanh")) {
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, Operator::ScalarToFloat, 1)
        .setAttr(dnnl::graph::op::attr::max, Operator::ScalarToFloat, 2);
  } else if (nodeKind == Symbol::fromQualString("aten::hardswish"))
    return makeEltwiseOp(node, opkind::HardSwish);
  else if (nodeKind == Symbol::fromQualString("aten::log"))
    return makeEltwiseOp(node, opkind::Log);
  else if (nodeKind == Symbol::fromQualString("aten::leaky_relu")) {
    return makeEltwiseOp(node, opkind::LeakyReLU)
        .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
  } else if (nodeKind == Symbol::fromQualString("aten::relu6")) {
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, 0.f)
        .setAttr(dnnl::graph::op::attr::max, 6.f);
  } else if (
      (nodeKind == Symbol::fromQualString("aten::softmax")) ||
      (nodeKind == Symbol::fromQualString("aten::_softmax"))) {
    auto axis = toIValue(node->namedInput("dim"))->toInt();
    return Operator(node, opkind::SoftMax)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, axis);
  } else if (nodeKind == Symbol::fromQualString("aten::_log_softmax")) {
    auto axis = toIValue(node->namedInput("dim"))->toInt();
    return Operator(node, opkind::LogSoftmax)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, axis);
  } else if (nodeKind == Symbol::fromQualString("aten::cat")) {
    auto o = Operator(node, opkind::Concat);
    REQUIRE(node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
    REQUIRE(node->namedInput("tensors")->uses().size() == 1);
    REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
    // aten::cat needs special handling since it takes a Tensor[] as input.
    // We set the inputs of ListConstruct as the inputs of cat.
    //
    // PyTorch IR:                              LLGA sees:
    //     %a    %b    %c        %dim               %a    %b    %c
    //       \   |    /            |                  \    |    /
    //    prim::ListConstruct  prim::Constant    llga::Concat[axis=%dim]
    //                 \          /
    //                  aten::cat
    auto listConstruct = node->input(0)->node();
    for (auto input : listConstruct->inputs())
      o.setInputValue(input);
    return o.setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, Operator::Int, 1);
  } else if (
      (nodeKind == Symbol::fromQualString("aten::max_pool2d")) ||
      (nodeKind == Symbol::fromQualString("aten::max_pool2d_with_indices"))) {
    // Currently, LLGA lacks support for creating an indices mask.
    // Once that's supported, max_pool2d_with_indices should be mapped
    // differently.
    REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
    auto rounding_type =
        toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
    return Operator(node, opkind::MaxPool)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 4)
        .setAttr(
            dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::avg_pool2d")) {
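    // For aten::avg_pool2d(input, kernel_size, stride, padding, ceil_mode,
    // count_include_pad, divisor_override), argument 5 is count_include_pad,
    // so oneDNN Graph's exclude_pad attribute is set to its negation.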
    // TODO: do we need to add checks for all Constants?
    REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
    auto rounding_type =
        toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
    auto divisor_override = toIValue(node->namedInput("divisor_override"));
    REQUIRE(divisor_override->isNone());
    return Operator(node, opkind::AvgPool)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::exclude_pad, !Operator::Bool(node, 5))
        .setAttr(
            dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::matmul")) {
    auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
    auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
    // TODO: support all shape combinations
    REQUIRE(
        (dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
        (dim0 == 3 && dim1 == 2));
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } // fall through
  else if (nodeKind == Symbol::fromQualString("aten::mm")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::bmm")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::linear")) {
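    // aten::linear stores the weight as [out_features, in_features] and
    // computes input @ weight.T + bias, hence transpose_b on the MatMul.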
    return Operator(node, opkind::MatMul)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::transpose_b, true);
  } else if (nodeKind == Symbol::fromQualString("aten::permute")) {
    REQUIRE(aliasDb_->hasInputWriters(node) == false);
    return Operator(node, opkind::StaticTranspose)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(
            dnnl::graph::op::attr::order,
            toIValue(node->namedInput("dims"))->toIntVector());
  } else if (nodeKind == Symbol::fromQualString("aten::contiguous")) {
    // Contiguous should only be mapped to oneDNN Graph if the destination
    // memory format is different from the source memory format:
    // the strides would differ, but the shape would stay the same.
    auto typeOfInput = node->input(0)->type()->expect<TensorType>();
    auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
    auto inputStrides = typeOfInput->strides().concrete_sizes();
    auto outputStrides = typeOfOutput->strides().concrete_sizes();
    REQUIRE(inputStrides != outputStrides);
    return Operator(node, opkind::Reorder)
        .setInput(0)
        .setOutput(dnnl_graph_, 0);
  }
  GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
  return makeWildcardOp(node);
}

static DeviceType inferDeviceFromValue(Value* v) {
  auto tt = v->type()->cast<TensorType>();
  if (!tt) {
    return at::kCPU;
  }
  auto device = tt->device();
  if (!device) {
    return at::kCPU;
  }
  return device->type();
}

static DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
  auto dt = inferDeviceFromValue(graph->inputs()[0]);
  TORCH_CHECK(
      std::all_of(
          graph->inputs().begin(),
          graph->inputs().end(),
          [dt](Value* v) { return inferDeviceFromValue(v) == dt; }),
      "All inputs must have the same device type");
  return dt;
}

static dnnl::engine::kind getLlgaEngineKind(DeviceType type) {
  switch (type) {
    case DeviceType::CPU:
      return dnnl::engine::kind::cpu;
    default:
      TORCH_CHECK(false, "Unsupported device type ", type);
  }
}

static void mayAddListConstructIntoConcatPartition(
    Node* n,
    OpPartitionMap& opToOwningPartition) {
  // Since prim::ListConstruct is not visible to LLGA, it will not be in any
  // partition returned from the partitioning results. We need to rewrite
  // opToOwningPartition so that the prim::ListConstruct is 'virtually' in the
  // same partition as the aten::cat, which lets the graph fuser pull
  // prim::ListConstruct into the fusion group. We say 'virtually' because
  // get_num_ops() for cat's partition would still return 1.
  if (n->kind() == aten::cat && opToOwningPartition.has(n)) {
    auto listConstruct = n->namedInput("tensors")->node();
    auto partitionId = opToOwningPartition.get(n);
    opToOwningPartition.add(listConstruct, partitionId);
  }
}

// Verify that input tensors are compatible with oneDNN Graph.
// Scalars would be converted to 1-D tensors later anyway,
// but they shouldn't be complex doubles.
// If this check fails, the op is converted to a wildcard.
static bool checkInputCompatibility(Node* node) {
  auto allInputs = node->inputs();
  for (auto input : allInputs) {
    c10::IValue inputIValue = toIValue(input);
    if (inputIValue.isTensor()) {
      const at::Tensor& tensor = inputIValue.toTensor();
      if (tensor.device() != at::kCPU) {
        return false;
      }
      auto dtype = tensor.scalar_type();
      if ((dtype != at::ScalarType::BFloat16) &&
          (dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
        // Long is allowed here even though oneDNN Graph does not support it,
        // because oneDNN Graph will simply not claim an op that has a Long
        // input, so such an op would be handled by PyTorch instead.
        return false;
      }
    } else if (inputIValue.isScalar()) {
      if (inputIValue.isComplexDouble()) {
        return false;
      }
    } else if (input->type()->isSubtypeOf(TensorType::get())) {
      auto input_typeptr = input->type()->cast<TensorType>();
      if (input_typeptr->scalarType().has_value()) {
        at::ScalarType dtype = input_typeptr->scalarType().value();
        if ((dtype != at::ScalarType::Float) &&
            (dtype != at::ScalarType::BFloat16)) {
          return false;
        }
      }
    }
  }
  return true;
}

LlgaGraphHelper::LlgaGraphHelper(
    const std::shared_ptr<Graph>& graph,
    dnnl::graph::partition::policy policy) {
  auto deviceType = inferDevice(graph);
  auto engineKind = getLlgaEngineKind(deviceType);
  dnnl_graph_ = std::make_unique<dnnl::graph::graph>(engineKind);
  aliasDb_ = std::make_unique<torch::jit::AliasDb>(graph);
  GRAPH_DEBUG("Constructing LLGA graph");
  // TODO: select nodes in top-level block for now
  for (auto* node : graph->block()->nodes()) {
    auto kindOfNode = node->kind();
    GRAPH_DEBUG("Trying to add ", kindOfNode.toQualString());
    if (checkInputCompatibility(node)) {
      auto op = createOperator(node);
      dnnl_graph_->add_op(op.llgaOp());
      GRAPH_DEBUG(" Added node ", kindOfNode.toQualString());
    } else {
      GRAPH_DEBUG("Incompatible inputs for ", kindOfNode.toQualString());
      dnnl_graph_->add_op(makeWildcardOp(node).llgaOp());
    }

    for (Value* input : node->inputs()) {
      tensorIdToValue_.emplace(input->unique(), input);
    }
  }

  dnnl_graph_->finalize();

  GRAPH_DEBUG("Get Partitions");
  std::vector<dnnl::graph::partition> partitions =
      dnnl_graph_->get_partitions(policy);
  // exclude unsupported Wildcard partitions
  for (auto& partition : partitions) {
    if (partition.is_supported()) {
      partitions_.push_back(partition);
    }
  }

  GRAPH_DEBUG(" Got #partitions: ", partitions_.size());
  for (size_t partId = 0; partId < partitions_.size(); partId++) {
    for (auto opId : partitions_[partId].get_ops()) {
      opToOwningPartition_.add(opId, partId);
    }
  }

  // Scan the graph again for post-processing
  for (auto* node : graph->block()->nodes()) {
    mayAddListConstructIntoConcatPartition(node, opToOwningPartition_);
  }
}

bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) {
  return node->hasAttribute(attr::Subgraph) &&
      node->kind() == prim::oneDNNFusionGroup;
}

bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
  TORCH_CHECK(
      isLlgaSubgraph(subgraph),
      "The consumer node does not contain a subgraph");
  if (!shouldConsiderForMerge(toMerge)) {
    return false;
  }
  return opToOwningPartition_.get(toMerge) ==
      opToOwningPartition_.get(subgraph);
}

// Except for conv & GEMMs, which should always be handled by oneDNN Graph,
// only use single-op partitions for ops unsupported by NNC, or ops
// that oneDNN executes faster. prim::ListConstruct is an exception, since
// we simply want to fuse it with cat.
static bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
  return (
      (kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) ||
      (kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) ||
      (kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) ||
      (kindOfOp == aten::mm) || (kindOfOp == aten::linear) ||
      (kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct));
}

bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) {
  if (opToOwningPartition_.has(node)) {
    auto partitionId = opToOwningPartition_.get(node);
    if (partitions_[partitionId].get_ops_num() == 1) {
      auto kindOfNode = node->kind();
      return isBetterSuitedForLLGA(kindOfNode);
    } else {
      // multi-op partition
      return true;
    }
  } else {
    // this op isn't present in any partition
    return false;
  }
}

bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
  // if we're already in the process of merging
  if (isLlgaSubgraph(node)) {
    return true;
  }
  return checkForSingleOpPartition(node);
}

Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) {
  auto partitionId = opToOwningPartition_.get(n);
  GRAPH_DEBUG(
      "Creating FusionGroup_", partitionId, " for ", n->kind().toQualString());
  auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
      n, prim::oneDNNFusionGroup, aliasDb);
  opToOwningPartition_.add(group, partitionId);
  return group;
}

void LlgaGraphHelper::mergeNodeIntoSubgraph(
    Node* toMerge,
    Node* subgraphNode,
    AliasDb& aliasDb) {
  if (isLlgaSubgraph(toMerge)) {
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        "_",
        opToOwningPartition_.get(toMerge),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  } else {
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  }

  SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
      toMerge, subgraphNode, aliasDb);
}

void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) {
  TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node");

  auto partitionId = opToOwningPartition_.get(subgraphNode);
  auto expectOpNum = partitions_[partitionId].get_ops_num();
  auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph));

  if (expectOpNum != actualOpNum) {
    GRAPH_DEBUG(
        "Unmerging FusionGroup_",
        partitionId,
        ". Expected ",
        expectOpNum,
        " ops, but got ",
        actualOpNum,
        " ops.");
    SubgraphUtils::unmergeSubgraph(subgraphNode);
  }
}

size_t LlgaGraphHelper::countSupportedOps(
    const std::shared_ptr<Graph>& graph) const {
  // TODO: count nodes in top-level block for now
  size_t cnt = 0;
  for (auto* node : graph->block()->nodes()) {
    auto nodeKind = node->kind();
    if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) {
      cnt++;
    }
  }
  return cnt;
}

std::vector<dnnl::graph::partition> LlgaGraphHelper::getPartitions() const {
  return partitions_;
}

std::map<size_t, Value*> LlgaGraphHelper::getTensorIdToValue() const {
  return tensorIdToValue_;
}

LlgaNodeWrapper::LlgaNodeWrapper(const Node* node)
    : n(const_cast<Node*>(node)) { // NOLINT
  TORCH_CHECK(
      LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node");
}

void LlgaNodeWrapper::setOpaqueLayout(size_t offset) {
  const auto num_output = n->is(attr::output_layouts).size();
  TORCH_CHECK(
      offset < num_output,
      "Out of range. (Invalid index ",
      offset,
      " for attr::output_layouts with size ",
      num_output,
      ")");
  auto& layouts =
      const_cast<std::vector<int64_t>&>(n->is(attr::output_layouts)); // NOLINT
  layouts.at(offset) = OPAQUE_LAYOUT;
}

bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const {
  const auto num_output = n->is(attr::output_layouts).size();
  TORCH_CHECK(
      offset < num_output,
      "Out of range. (Invalid index ",
      offset,
      " for attr::output_layouts with size ",
      num_output,
      ")");
  return n->is(attr::output_layouts)[offset] == OPAQUE_LAYOUT;
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch