xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/pattern_conversion/common.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
2 
3 namespace torch {
4 namespace jit {
5 
IsSameSource(const Node * n,const Node * m)6 bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
7   const auto source_n = n->sourceRange().source();
8   const auto source_m = m->sourceRange().source();
9   return (
10       (source_n->text_str() == source_m->text_str()) &&
11       (source_n->starting_line_no() == source_m->starting_line_no()));
12 }
13 
14 // Trace back all the slice & select nodes associated with the index_put node.
15 // E.g. The IR for x[1:3, 0] = update
16 //    ...
17 //    %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
18 //    ...
19 //    %11 : Float(2) = aten::select(%8, %9, %10)
20 //    ...
21 //    %13 : Tensor?[] = prim::ListConstruct()
22 //    ...
23 //    %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
24 //
25 // We collect %11 and %8, to construct the index tensors.
26 // The vector slice_and_select_node contains all the associated slice and
27 // select node, in the reversed order.
FetchSliceAndSelect(const Node * node)28 std::vector<Node*> IndexingPatternFinder::FetchSliceAndSelect(
29     const Node* node) {
30   std::vector<Node*> slice_and_select_node;
31   auto src_node = node->input(0)->node();
32   while (src_node) {
33     if ((src_node->kind() == aten::slice || src_node->kind() == aten::select) &&
34         IsSameSource(src_node, node)) {
35       slice_and_select_node.emplace_back(src_node);
36       src_node = src_node->input(0)->node();
37     } else {
38       src_node = nullptr;
39     }
40   }
41   return slice_and_select_node;
42 }
43 
44 } // namespace jit
45 } // namespace torch
46