xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/layout_propagation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/onednn/graph_helper.h>
2 #include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
3 #include <torch/csrc/jit/jit_log.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace fuser {
8 namespace onednn {
9 
LayoutPropagation(Node * n)10 static void LayoutPropagation(Node* n) {
11   if (!LlgaGraphHelper::isLlgaSubgraph(n))
12     return;
13 
14   // initial attr::output_layouts if undefined
15   if (!n->hasAttribute(attr::output_layouts)) {
16     const auto num_output = n->outputs().size();
17     GRAPH_DEBUG("Initial output_layouts of size ", num_output);
18     std::vector<int64_t> layouts(num_output, STRIDED_LAYOUT);
19     n->is_(attr::output_layouts, layouts);
20   }
21 
22   for (auto input : n->inputs()) {
23     auto prev = input->node();
24     auto offset = input->offset();
25     if (LlgaGraphHelper::isLlgaSubgraph(prev)) {
26       bool useOpaqueLayout = true;
27       for (auto& use : input->uses()) {
28         if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) {
29           useOpaqueLayout = false;
30           break;
31         }
32       }
33       if (useOpaqueLayout) {
34         LlgaNodeWrapper(prev).setOpaqueLayout(offset);
35       }
36     }
37   }
38 }
39 
LayoutPropagation(at::ArrayRef<Block * > blocks)40 static void LayoutPropagation(at::ArrayRef<Block*> blocks) {
41   for (Block* block : blocks)
42     for (Node* node : block->nodes())
43       LayoutPropagation(node);
44 }
45 
PropagateLayout(const std::shared_ptr<Graph> & graph)46 void PropagateLayout(const std::shared_ptr<Graph>& graph) {
47   LayoutPropagation(graph->block());
48 }
49 
50 } // namespace onednn
51 } // namespace fuser
52 } // namespace jit
53 } // namespace torch
54