1 #include <torch/csrc/jit/codegen/onednn/graph_helper.h> 2 #include <torch/csrc/jit/codegen/onednn/layout_propagation.h> 3 #include <torch/csrc/jit/jit_log.h> 4 5 namespace torch { 6 namespace jit { 7 namespace fuser { 8 namespace onednn { 9 LayoutPropagation(Node * n)10static void LayoutPropagation(Node* n) { 11 if (!LlgaGraphHelper::isLlgaSubgraph(n)) 12 return; 13 14 // initial attr::output_layouts if undefined 15 if (!n->hasAttribute(attr::output_layouts)) { 16 const auto num_output = n->outputs().size(); 17 GRAPH_DEBUG("Initial output_layouts of size ", num_output); 18 std::vector<int64_t> layouts(num_output, STRIDED_LAYOUT); 19 n->is_(attr::output_layouts, layouts); 20 } 21 22 for (auto input : n->inputs()) { 23 auto prev = input->node(); 24 auto offset = input->offset(); 25 if (LlgaGraphHelper::isLlgaSubgraph(prev)) { 26 bool useOpaqueLayout = true; 27 for (auto& use : input->uses()) { 28 if (!LlgaGraphHelper::isLlgaSubgraph(use.user)) { 29 useOpaqueLayout = false; 30 break; 31 } 32 } 33 if (useOpaqueLayout) { 34 LlgaNodeWrapper(prev).setOpaqueLayout(offset); 35 } 36 } 37 } 38 } 39 LayoutPropagation(at::ArrayRef<Block * > blocks)40static void LayoutPropagation(at::ArrayRef<Block*> blocks) { 41 for (Block* block : blocks) 42 for (Node* node : block->nodes()) 43 LayoutPropagation(node); 44 } 45 PropagateLayout(const std::shared_ptr<Graph> & graph)46void PropagateLayout(const std::shared_ptr<Graph>& graph) { 47 LayoutPropagation(graph->block()); 48 } 49 50 } // namespace onednn 51 } // namespace fuser 52 } // namespace jit 53 } // namespace torch 54