1 #include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
2
3 namespace torch {
4 namespace jit {
5
IsSameSource(const Node * n,const Node * m)6 bool IndexingPatternFinder::IsSameSource(const Node* n, const Node* m) {
7 const auto source_n = n->sourceRange().source();
8 const auto source_m = m->sourceRange().source();
9 return (
10 (source_n->text_str() == source_m->text_str()) &&
11 (source_n->starting_line_no() == source_m->starting_line_no()));
12 }
13
14 // Trace back all the slice & select nodes associated with the index_put node.
15 // E.g. The IR for x[1:3, 0] = update
16 // ...
17 // %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
18 // ...
19 // %11 : Float(2) = aten::select(%8, %9, %10)
20 // ...
21 // %13 : Tensor?[] = prim::ListConstruct()
22 // ...
23 // %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
24 //
25 // We collect %11 and %8, to construct the index tensors.
26 // The vector slice_and_select_node contains all the associated slice and
27 // select node, in the reversed order.
FetchSliceAndSelect(const Node * node)28 std::vector<Node*> IndexingPatternFinder::FetchSliceAndSelect(
29 const Node* node) {
30 std::vector<Node*> slice_and_select_node;
31 auto src_node = node->input(0)->node();
32 while (src_node) {
33 if ((src_node->kind() == aten::slice || src_node->kind() == aten::select) &&
34 IsSameSource(src_node, node)) {
35 slice_and_select_node.emplace_back(src_node);
36 src_node = src_node->input(0)->node();
37 } else {
38 src_node = nullptr;
39 }
40 }
41 return slice_and_select_node;
42 }
43
44 } // namespace jit
45 } // namespace torch
46