// external/pytorch/torch/csrc/jit/codegen/onednn/graph_helper.cpp
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/codegen/onednn/graph_helper.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using opkind = dnnl::graph::op::kind;

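// If the optional conv bias may be None, replace it with an explicit
// prim::Constant None so that input 2 is always present and the bias input
// of the corresponding oneDNN Graph Convolution can be wired up uniformly.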
static void fixConvOptionalBias(Node* node) {
  if (!node->namedInput("bias")->mustNotBeNone()) {
    // Replace the missing optional bias with a constant None
    auto g = node->owningGraph();
    auto n = g->createNone();
    auto v = n->insertBefore(node)->output();
    node->replaceInput(2, v);
  }
}

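// Return the rank of a tensor-typed value if it is known, or std::nullopt
// for non-tensor values.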
static std::optional<size_t> getDimensions(Value* v) {
  if (v->type()->isSubtypeOf(TensorType::get())) {
    return v->type()->cast<TensorType>()->sizes().size();
  } else {
    return std::nullopt;
  }
}

// PyTorch ops that can't otherwise be mapped to oneDNN Graph ops are mapped
// as Wildcards instead. They simplify the integration with PyTorch: every op
// can be passed to the oneDNN Graph library in the add_op call, with no need
// to check beforehand whether the op is supported by oneDNN Graph. oneDNN
// Graph ops separated by wildcards don't end up in the same partition.
static Operator makeWildcardOp(Node* node) {
  auto o = Operator(node, opkind::Wildcard);
  // a wildcard op contains only topology info, so register every input and
  // output of the node
  for (size_t i = 0; i < node->inputs().size(); i++) {
    o.setInput(i);
  }
  for (size_t i = 0; i < node->outputs().size(); i++) {
    o.setOutput(i);
  }
  return o;
}

// If we don't meet a certain condition to map a PyTorch op to a oneDNN Graph
// op, then we create a wildcard op corresponding to that PyTorch op instead.
#define REQUIRE(cond)                                 \
  if (!(cond)) {                                      \
    GRAPH_DEBUG("Unsupported condition " #cond "\n"); \
    return makeWildcardOp(node);                      \
  }

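// Map a unary elementwise PyTorch op: input 0 is the op's only input and
// output 0 its only output.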
Operator LlgaGraphHelper::makeEltwiseOp(Node* node, opkind kind) {
  return Operator(node, kind).setInput(0).setOutput(dnnl_graph_, 0);
}

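// Map a binary PyTorch op. Both operands must be tensors; otherwise the op
// falls back to a wildcard via REQUIRE.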
Operator LlgaGraphHelper::makeBinaryOp(Node* node, opkind kind) {
  REQUIRE(
      node->input(0)->type()->isSubtypeOf(TensorType::get()) &&
      node->input(1)->type()->isSubtypeOf(TensorType::get()))
  return Operator(node, kind).setInput(0, 1).setOutput(dnnl_graph_, 0);
}

// Map a PyTorch op to its corresponding oneDNN Graph op.
// If mapping isn't possible, then create a wildcard op instead.
// The mapping is done as per the oneDNN Graph op schema defined in
// third_party/ideep/mkl-dnn/src/interface/op_def.hpp.
Operator LlgaGraphHelper::createOperator(Node* node) {
  auto nodeKind = node->kind();
  // We're using an if-else chain instead of a switch statement because we
  // would soon be adding custom ops with function schemas. We would have to
  // use Symbol::fromQualString at that time anyway, and this code is not on
  // the hot path.
  if (nodeKind == Symbol::fromQualString("aten::conv2d")) {
    fixConvOptionalBias(node);
    return Operator(node, opkind::Convolution)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
        .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 6)
        .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (
      (nodeKind == Symbol::fromQualString("aten::_convolution")) ||
      (nodeKind == Symbol::fromQualString("aten::convolution"))) {
    bool transposed = toIValue(node->namedInput("transposed"))->toBool();
    REQUIRE(!transposed);
    return Operator(node, opkind::Convolution)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 4)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 5)
        .setAttr(dnnl::graph::op::attr::groups, Operator::Int, 8)
        .setAttr(dnnl::graph::op::attr::weights_format, std::string("OIX"))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::batch_norm")) {
    auto training = toIValue(node->namedInput("training"));
    REQUIRE(training.has_value()); // cannot get training status in script mode
    if (!training->toBool()) {
      return Operator(node, opkind::BatchNormInference)
          .setInput(0, 1, 2, 3, 4)
          .setOutput(dnnl_graph_, 0)
          .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 7)
          .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
    }
  } else if (nodeKind == Symbol::fromQualString("aten::layer_norm")) {
    auto normalized_shape = toIValue(node->namedInput("normalized_shape"));
    REQUIRE(normalized_shape->toIntList().size() == 1);
    return Operator(node, opkind::LayerNorm)
        .setInput(0, 2, 3)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::epsilon, Operator::Float, 4)
        .setAttr(dnnl::graph::op::attr::keep_stats, false);
  } else if (nodeKind == Symbol::fromQualString("aten::addmm")) {
    auto alpha = toIValue(node->namedInput("alpha"));
    auto beta = toIValue(node->namedInput("beta"));
    if (alpha.has_value() && beta.has_value()) {
      if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 1.0)) {
        return Operator(node, opkind::MatMul)
            .setInput(1, 2, 0)
            .setOutput(dnnl_graph_, 0);
      } else if ((alpha->toDouble() == 1.0) && (beta->toDouble() == 0.0)) {
        return Operator(node, opkind::MatMul)
            .setInput(1, 2)
            .setOutput(dnnl_graph_, 0);
      }
    }
  } else if (nodeKind == Symbol::fromQualString("aten::add"))
    return makeBinaryOp(node, opkind::Add);
  else if (nodeKind == Symbol::fromQualString("aten::mul"))
    return makeBinaryOp(node, opkind::Multiply);
  else if (nodeKind == Symbol::fromQualString("aten::div"))
    return makeBinaryOp(node, opkind::Divide);
  else if (nodeKind == Symbol::fromQualString("aten::tanh"))
    return makeEltwiseOp(node, opkind::Tanh);
  else if (nodeKind == Symbol::fromQualString("aten::relu"))
    return makeEltwiseOp(node, opkind::ReLU);
  else if (nodeKind == Symbol::fromQualString("aten::elu"))
    return makeEltwiseOp(node, opkind::Elu)
        .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
  else if (nodeKind == Symbol::fromQualString("aten::sigmoid"))
    return makeEltwiseOp(node, opkind::Sigmoid);
  else if (nodeKind == Symbol::fromQualString("aten::gelu"))
    return makeEltwiseOp(node, opkind::GELU);
  else if (nodeKind == Symbol::fromQualString("aten::round"))
    return makeEltwiseOp(node, opkind::Round);
  else if (nodeKind == Symbol::fromQualString("aten::exp"))
    return makeEltwiseOp(node, opkind::Exp);
  else if (nodeKind == Symbol::fromQualString("aten::sqrt"))
    return makeEltwiseOp(node, opkind::Sqrt);
  else if (nodeKind == Symbol::fromQualString("aten::abs"))
    return makeEltwiseOp(node, opkind::Abs);
  else if (nodeKind == Symbol::fromQualString("aten::square"))
    return makeEltwiseOp(node, opkind::Square);
  else if (nodeKind == Symbol::fromQualString("aten::clamp")) {
    // The PyTorch API already checks that min & max are not both None,
    // but we check it here nevertheless.
    auto clamp_min = toIValue(node->input(1));
    auto clamp_max = toIValue(node->input(2));
    REQUIRE(!(clamp_max->isNone() && clamp_min->isNone()));
    auto clamp_min_value = (clamp_min->isNone())
        ? -std::numeric_limits<float>::infinity()
        : Operator::ScalarToFloat(node, 1);
    auto clamp_max_value = (clamp_max->isNone())
        ? std::numeric_limits<float>::infinity()
        : Operator::ScalarToFloat(node, 2);
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, clamp_min_value)
        .setAttr(dnnl::graph::op::attr::max, clamp_max_value);
  } else if (nodeKind == Symbol::fromQualString("aten::hardtanh")) {
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, Operator::ScalarToFloat, 1)
        .setAttr(dnnl::graph::op::attr::max, Operator::ScalarToFloat, 2);
  } else if (nodeKind == Symbol::fromQualString("aten::hardswish"))
    return makeEltwiseOp(node, opkind::HardSwish);
  else if (nodeKind == Symbol::fromQualString("aten::log"))
    return makeEltwiseOp(node, opkind::Log);
  else if (nodeKind == Symbol::fromQualString("aten::leaky_relu")) {
    return makeEltwiseOp(node, opkind::LeakyReLU)
        .setAttr(dnnl::graph::op::attr::alpha, Operator::Float, 1);
  } else if (nodeKind == Symbol::fromQualString("aten::relu6")) {
    return makeEltwiseOp(node, opkind::Clamp)
        .setAttr(dnnl::graph::op::attr::min, 0.f)
        .setAttr(dnnl::graph::op::attr::max, 6.f);
  } else if (
      (nodeKind == Symbol::fromQualString("aten::softmax")) ||
      (nodeKind == Symbol::fromQualString("aten::_softmax"))) {
    auto axis = toIValue(node->namedInput("dim"))->toInt();
    return Operator(node, opkind::SoftMax)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, axis);
  } else if (nodeKind == Symbol::fromQualString("aten::_log_softmax")) {
    auto axis = toIValue(node->namedInput("dim"))->toInt();
    return Operator(node, opkind::LogSoftmax)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, axis);
  } else if (nodeKind == Symbol::fromQualString("aten::cat")) {
    auto o = Operator(node, opkind::Concat);
    REQUIRE(node->namedInput("tensors")->node()->kind() == prim::ListConstruct);
    REQUIRE(node->namedInput("tensors")->uses().size() == 1);
    REQUIRE(node->namedInput("dim")->node()->kind() == prim::Constant);
    // aten::cat needs special handling since it takes a Tensor[] as input.
    // We set the inputs of ListConstruct as the inputs of cat.
    //
    // PyTorch IR:                              LLGA sees:
    //     %a    %b     %c          %dim              %a    %b    %c
    //      \     |     /             |                \     |    /
    //   prim::ListConstruct   prim::Constant     llga::Concat[axis=%dim]
    //                    \      /
    //                    aten::cat
    auto listConstruct = node->input(0)->node();
    for (auto input : listConstruct->inputs())
      o.setInputValue(input);
    return o.setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::axis, Operator::Int, 1);
  } else if (
      (nodeKind == Symbol::fromQualString("aten::max_pool2d")) ||
      (nodeKind == Symbol::fromQualString("aten::max_pool2d_with_indices"))) {
    // Currently, LLGA lacks support for creating an indices mask. Once it's
    // supported, max_pool2d_with_indices should be mapped differently.
    REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
    auto rounding_type =
        toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
    return Operator(node, opkind::MaxPool)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::dilations, Operator::Ints, 4)
        .setAttr(
            dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::avg_pool2d")) {
    // TODO: do we need to add checks for all Constants?
    REQUIRE(node->namedInput("kernel_size")->node()->kind() == prim::Constant);
    auto rounding_type =
        toIValue(node->namedInput("ceil_mode"))->toBool() ? "ceil" : "floor";
    auto divisor_override = toIValue(node->namedInput("divisor_override"));
    REQUIRE(divisor_override->isNone());
    return Operator(node, opkind::AvgPool)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::kernel, Operator::Ints, 1)
        .setAttr(dnnl::graph::op::attr::strides, Operator::Ints, 2)
        .setAttr(dnnl::graph::op::attr::pads_begin, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::pads_end, Operator::Ints, 3)
        .setAttr(dnnl::graph::op::attr::exclude_pad, !Operator::Bool(node, 5))
        .setAttr(
            dnnl::graph::op::attr::rounding_type, std::string(rounding_type))
        .setAttr(dnnl::graph::op::attr::data_format, std::string("NCX"));
  } else if (nodeKind == Symbol::fromQualString("aten::matmul")) {
    auto dim0 = getDimensions(node->namedInput("self")).value_or(-1);
    auto dim1 = getDimensions(node->namedInput("other")).value_or(-1);
    // TODO: support all shape combinations
    REQUIRE(
        (dim0 == 2 && dim1 == 2) || (dim0 == 4 && dim1 == 4) ||
        (dim0 == 3 && dim1 == 2));
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::mm")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::bmm")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1)
        .setOutput(dnnl_graph_, 0);
  } else if (nodeKind == Symbol::fromQualString("aten::linear")) {
    return Operator(node, opkind::MatMul)
        .setInput(0, 1, 2)
        .setOutput(dnnl_graph_, 0)
        .setAttr(dnnl::graph::op::attr::transpose_b, true);
  } else if (nodeKind == Symbol::fromQualString("aten::permute")) {
    REQUIRE(!aliasDb_->hasInputWriters(node));
    return Operator(node, opkind::StaticTranspose)
        .setInput(0)
        .setOutput(dnnl_graph_, 0)
        .setAttr(
            dnnl::graph::op::attr::order,
            toIValue(node->namedInput("dims"))->toIntVector());
  } else if (nodeKind == Symbol::fromQualString("aten::contiguous")) {
    // Contiguous should only be mapped to oneDNN Graph if the destination
    // memory layout is different from the source memory format.
    // The strides would differ, but the shape would be the same.
    auto typeOfInput = node->input(0)->type()->expect<TensorType>();
    auto typeOfOutput = node->output(0)->type()->expect<TensorType>();
    auto inputStrides = typeOfInput->strides().concrete_sizes();
    auto outputStrides = typeOfOutput->strides().concrete_sizes();
    REQUIRE(inputStrides != outputStrides);
    return Operator(node, opkind::Reorder)
        .setInput(0)
        .setOutput(dnnl_graph_, 0);
  }
  GRAPH_DEBUG("Making ", nodeKind.toQualString(), " a wildcard");
  return makeWildcardOp(node);
}

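// Infer the device type of a value from its tensor type. Non-tensor values
// and tensors without a known device default to CPU.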
static DeviceType inferDeviceFromValue(Value* v) {
  auto tt = v->type()->cast<TensorType>();
  if (!tt) {
    return at::kCPU;
  }
  auto device = tt->device();
  if (!device) {
    return at::kCPU;
  }
  return device->type();
}

static DeviceType inferDevice(const std::shared_ptr<Graph>& graph) {
  auto dt = inferDeviceFromValue(graph->inputs()[0]);
  TORCH_CHECK(
      std::all_of(
          graph->inputs().begin(),
          graph->inputs().end(),
          [dt](Value* v) { return inferDeviceFromValue(v) == dt; }),
      "All inputs must have the same device type");
  return dt;
}

static dnnl::engine::kind getLlgaEngineKind(DeviceType type) {
  switch (type) {
    case DeviceType::CPU:
      return dnnl::engine::kind::cpu;
    default:
      TORCH_CHECK(false, "Unsupported device type ", type);
  }
}

static void mayAddListConstructIntoConcatPartition(
    Node* n,
    OpPartitionMap& opToOwningPartition) {
  // Since prim::ListConstruct is not visible to LLGA, it will not be in any
  // partition returned from the partitioning results. We need to rewrite
  // opToOwningPartition to make the prim::ListConstruct 'virtually' belong
  // to the same partition as the aten::cat, so that prim::ListConstruct can
  // be fused into the fusion group by the graph fuser. We emphasize
  // 'virtually' because get_num_ops() for cat's partition would still
  // return 1.
  if (n->kind() == aten::cat && opToOwningPartition.has(n)) {
    auto listConstruct = n->namedInput("tensors")->node();
    auto partitionId = opToOwningPartition.get(n);
    opToOwningPartition.add(listConstruct, partitionId);
  }
}

// Verify that input tensors are compatible with oneDNN Graph.
// Scalars would be converted to 1-D tensors later anyway, but they
// shouldn't be complex doubles. If this check fails, map the op to a
// wildcard instead.
static bool checkInputCompatibility(Node* node) {
  auto allInputs = node->inputs();
  for (auto input : allInputs) {
    c10::IValue inputIValue = toIValue(input);
    if (inputIValue.isTensor()) {
      const at::Tensor& tensor = inputIValue.toTensor();
      if (tensor.device() != at::kCPU) {
        return false;
      }
      auto dtype = tensor.scalar_type();
      if ((dtype != at::ScalarType::BFloat16) &&
          (dtype != at::ScalarType::Float) && (dtype != at::ScalarType::Long)) {
        // We've allowed Long dtype here although oneDNN Graph does not support
        // Long dtype because oneDNN Graph will end up not handling the op that
        // has an input with Long dtype, so it'd be handled by PyTorch.
        return false;
      }
    } else if (inputIValue.isScalar()) {
      if (inputIValue.isComplexDouble()) {
        return false;
      }
    } else if (input->type()->isSubtypeOf(TensorType::get())) {
      auto input_typeptr = input->type()->cast<TensorType>();
      if (input_typeptr->scalarType().has_value()) {
        at::ScalarType dtype = input_typeptr->scalarType().value();
        if ((dtype != at::ScalarType::Float) &&
            (dtype != at::ScalarType::BFloat16)) {
          return false;
        }
      }
    }
  }
  return true;
}

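// Build the LLGA graph for a JIT graph: map each node in the top-level block
// to a oneDNN Graph op (or a wildcard), finalize the graph, keep only the
// supported partitions, and record which partition owns each op.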
LlgaGraphHelper::LlgaGraphHelper(
    const std::shared_ptr<Graph>& graph,
    dnnl::graph::partition::policy policy) {
  auto deviceType = inferDevice(graph);
  auto engineKind = getLlgaEngineKind(deviceType);
  dnnl_graph_ = std::make_unique<dnnl::graph::graph>(engineKind);
  aliasDb_ = std::make_unique<torch::jit::AliasDb>(graph);
  GRAPH_DEBUG("Constructing LLGA graph");
  // TODO: select nodes in top-level block for now
  for (auto* node : graph->block()->nodes()) {
    auto kindOfNode = node->kind();
    GRAPH_DEBUG("Trying to add ", kindOfNode.toQualString());
    if (checkInputCompatibility(node)) {
      auto op = createOperator(node);
      dnnl_graph_->add_op(op.llgaOp());
      GRAPH_DEBUG("  Added node ", kindOfNode.toQualString());
    } else {
      GRAPH_DEBUG("Incompatible inputs for ", kindOfNode.toQualString());
      dnnl_graph_->add_op(makeWildcardOp(node).llgaOp());
    }

    for (Value* input : node->inputs()) {
      tensorIdToValue_.emplace(input->unique(), input);
    }
  }

  dnnl_graph_->finalize();

  GRAPH_DEBUG("Get Partitions");
  std::vector<dnnl::graph::partition> partitions =
      dnnl_graph_->get_partitions(policy);
  // exclude unsupported Wildcard partitions
  for (auto& partition : partitions) {
    if (partition.is_supported()) {
      partitions_.push_back(partition);
    }
  }

  GRAPH_DEBUG("  Got #partitions: ", partitions_.size());
  for (size_t partId = 0; partId < partitions_.size(); partId++) {
    for (auto opId : partitions_[partId].get_ops()) {
      opToOwningPartition_.add(opId, partId);
    }
  }

  // Scan the graph again for post-processing
  for (auto* node : graph->block()->nodes()) {
    mayAddListConstructIntoConcatPartition(node, opToOwningPartition_);
  }
}

bool LlgaGraphHelper::isLlgaSubgraph(const Node* node) {
  return node->hasAttribute(attr::Subgraph) &&
      node->kind() == prim::oneDNNFusionGroup;
}

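// A node may be merged into an existing fusion group only if LLGA assigned
// both the node and the group to the same partition.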
bool LlgaGraphHelper::shouldMerge(Node* toMerge, Node* subgraph) {
  TORCH_CHECK(
      isLlgaSubgraph(subgraph),
      "The consumer node does not contain a subgraph");
  if (!shouldConsiderForMerge(toMerge)) {
    return false;
  }
  return opToOwningPartition_.get(toMerge) ==
      opToOwningPartition_.get(subgraph);
}

// Except for conv & GEMMs, which should always be handled by oneDNN Graph,
// only use single-op partitions for ops unsupported by NNC, or ops that
// oneDNN executes faster. prim::ListConstruct is an exception, since we
// simply want to fuse it with cat.
static bool isBetterSuitedForLLGA(NodeKind kindOfOp) {
  return (
      (kindOfOp == aten::layer_norm) || (kindOfOp == aten::avg_pool2d) ||
      (kindOfOp == aten::matmul) || (kindOfOp == aten::max_pool2d) ||
      (kindOfOp == aten::conv2d) || (kindOfOp == aten::_convolution) ||
      (kindOfOp == aten::mm) || (kindOfOp == aten::linear) ||
      (kindOfOp == aten::cat) || (kindOfOp == prim::ListConstruct));
}

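// A node qualifies for fusion if its partition contains multiple ops, or if
// it forms a single-op partition whose op is better suited for LLGA.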
bool LlgaGraphHelper::checkForSingleOpPartition(Node* node) {
  if (opToOwningPartition_.has(node)) {
    auto partitionId = opToOwningPartition_.get(node);
    if (partitions_[partitionId].get_ops_num() == 1) {
      auto kindOfNode = node->kind();
      return isBetterSuitedForLLGA(kindOfNode);
    } else {
      // multi-op partition
      return true;
    }
  } else {
    // this op isn't present in any partition
    return false;
  }
}

bool LlgaGraphHelper::shouldConsiderForMerge(Node* node) {
  // if we're already in the process of merging
  if (isLlgaSubgraph(node)) {
    return true;
  }
  return checkForSingleOpPartition(node);
}

Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) {
  auto partitionId = opToOwningPartition_.get(n);
  GRAPH_DEBUG(
      "Creating FusionGroup_", partitionId, " for ", n->kind().toQualString());
  auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
      n, prim::oneDNNFusionGroup, aliasDb);
  opToOwningPartition_.add(group, partitionId);
  return group;
}

void LlgaGraphHelper::mergeNodeIntoSubgraph(
    Node* toMerge,
    Node* subgraphNode,
    AliasDb& aliasDb) {
  if (isLlgaSubgraph(toMerge)) {
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        "_",
        opToOwningPartition_.get(toMerge),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  } else {
    GRAPH_DEBUG(
        "Merging ",
        toMerge->kind().toQualString(),
        " into ",
        subgraphNode->kind().toQualString(),
        "_",
        opToOwningPartition_.get(subgraphNode));
  }

  SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
      toMerge, subgraphNode, aliasDb);
}

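// If the fusion group contains fewer ops than its LLGA partition expects
// (e.g. some node could not be merged in), undo the fusion.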
void LlgaGraphHelper::unmergeIfAnyNodeIsMissing(Node* subgraphNode) {
  TORCH_CHECK(isLlgaSubgraph(subgraphNode), "Cannot unmerge a non-LLGA node");

  auto partitionId = opToOwningPartition_.get(subgraphNode);
  auto expectOpNum = partitions_[partitionId].get_ops_num();
  auto actualOpNum = countSupportedOps(subgraphNode->g(attr::Subgraph));

  if (expectOpNum != actualOpNum) {
    GRAPH_DEBUG(
        "Unmerging FusionGroup_",
        partitionId,
        ". Expected ",
        expectOpNum,
        " ops, but got ",
        actualOpNum,
        " ops.");
    SubgraphUtils::unmergeSubgraph(subgraphNode);
  }
}

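// Count the nodes in a subgraph that an LLGA partition accounts for as ops;
// prim::Constant and prim::ListConstruct are excluded since they don't count
// towards a partition's op number.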
size_t LlgaGraphHelper::countSupportedOps(
    const std::shared_ptr<Graph>& graph) const {
  // TODO: count nodes in top-level block for now
  size_t cnt = 0;
  for (auto* node : graph->block()->nodes()) {
    auto nodeKind = node->kind();
    if ((nodeKind != prim::Constant) && (nodeKind != prim::ListConstruct)) {
      cnt++;
    }
  }
  return cnt;
}

std::vector<dnnl::graph::partition> LlgaGraphHelper::getPartitions() const {
  return partitions_;
}

std::map<size_t, Value*> LlgaGraphHelper::getTensorIdToValue() const {
  return tensorIdToValue_;
}

LlgaNodeWrapper::LlgaNodeWrapper(const Node* node)
    : n(const_cast<Node*>(node)) { // NOLINT
  TORCH_CHECK(
      LlgaGraphHelper::isLlgaSubgraph(n), "Cannot wrap a non-LLGA fusion node");
}

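// Mark the output at `offset` to use LLGA's opaque tensor layout instead of
// a PyTorch strided layout.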
void LlgaNodeWrapper::setOpaqueLayout(size_t offset) {
  const auto num_output = n->is(attr::output_layouts).size();
  TORCH_CHECK(
      offset < num_output,
      "Out of range. (Invalid index ",
      offset,
      " for attr::output_layouts with size ",
      num_output,
      ")");
  auto& layouts =
      const_cast<std::vector<int64_t>&>(n->is(attr::output_layouts)); // NOLINT
  layouts.at(offset) = OPAQUE_LAYOUT;
}

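// Check whether the output at `offset` has been marked to use the opaque
// layout.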
bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const {
  const auto num_output = n->is(attr::output_layouts).size();
  TORCH_CHECK(
      offset < num_output,
      "Out of range. (Invalid index ",
      offset,
      " for attr::output_layouts with size ",
      num_output,
      ")");
  return n->is(attr::output_layouts)[offset] == OPAQUE_LAYOUT;
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch