1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 #include <torch/csrc/jit/passes/onnx.h>
3 #include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
4 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
5 #include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
6
7 // EDITING THIS FILE? READ THIS FIRST!
8 // see Note [Edit Pattern Encapsulation] in pattern_encapsulation.h
9
10 namespace torch {
11 namespace jit {
12
13 namespace {
14
15 // Trace back all the slice & select nodes associated with the index_put node,
16 // and copy them under the placeholder subblock.
17 // E.g. The IR for x[1:3, 0] = update
18 // ...
19 // %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
20 // ...
21 // %11 : Float(2) = aten::select(%8, %9, %10)
22 // ...
23 // %13 : Tensor?[] = prim::ListConstruct()
24 // ...
25 // %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
26 // The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
27 // = prim::ListConstruct()).
EncapsulateInplaceIndexPutForONNX(Node * index_put_node)28 Node* EncapsulateInplaceIndexPutForONNX(Node* index_put_node) {
29 auto graph = index_put_node->owningGraph();
30
31 // Find slice and select operators that are associated with this index
32 // operator. E.g. x[1:3, 0] = y will generate one slice operator(1:3) and one
33 // select operator(0).
34 std::vector<Node*> slice_and_select_nodes =
35 IndexingPatternFinder::FetchSliceAndSelect(index_put_node);
36 Node* last_node = !slice_and_select_nodes.empty()
37 ? slice_and_select_nodes.back()
38 : index_put_node;
39 Value* orig_data = last_node->input(0);
40
41 // Copy related nodes into subblock of a new special placeholder node.
42 Node* placeholder_node =
43 graph->create(Symbol::fromQualString("onnx::Placeholder"));
44 placeholder_node->s_(attr::name, index_put_node->kind().toUnqualString());
45 placeholder_node->addInput(orig_data);
46
47 // Construct subblock
48 auto subblock = placeholder_node->addBlock();
49 std::unordered_map<Value*, Value*> env;
50
51 // slice_and_select_nodes are in reversed order.
52 for (auto it = slice_and_select_nodes.rbegin();
53 it != slice_and_select_nodes.rend();
54 ++it) {
55 auto n = *it;
56 auto cloned_n = subblock->appendNode(graph->createClone(
57 n, [&](Value* v) { return env.find(v) != env.end() ? env[v] : v; }));
58 for (size_t i = 0; i < cloned_n->outputs().size(); ++i) {
59 env[n->outputs().at(i)] = cloned_n->outputs().at(i);
60 }
61 }
62
63 Node* new_index_put_node =
64 subblock->appendNode(graph->createClone(index_put_node, [&](Value* v) {
65 return env.find(v) != env.end() ? env[v] : v;
66 }));
67 for (auto o : new_index_put_node->outputs()) {
68 subblock->registerOutput(o);
69 }
70
71 placeholder_node->insertBefore(index_put_node);
72 placeholder_node->copyMetadata(index_put_node);
73 index_put_node->replaceAllUsesWith(placeholder_node);
74
75 return placeholder_node;
76 }
77
78 } // namespace
79
EncapsulatePatternIntoSubblock(Node * n)80 std::optional<Node*> EncapsulatePatternIntoSubblock(Node* n) {
81 switch (n->kind()) {
82 case aten::index_put_:
83 case aten::index_put: {
84 return EncapsulateInplaceIndexPutForONNX(n);
85 }
86 }
87 return std::nullopt;
88 }
89
90 } // namespace jit
91 } // namespace torch
92