xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 #include <torch/csrc/jit/passes/onnx.h>
3 #include <torch/csrc/jit/passes/onnx/pattern_conversion/common.h>
4 #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
5 #include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
6 
7 // EDITING THIS FILE? READ THIS FIRST!
8 // see Note [Edit Pattern Encapsulation] in pattern_encapsulation.h
9 
10 namespace torch {
11 namespace jit {
12 
13 namespace {
14 
15 // Trace back all the slice & select nodes associated with the index_put node,
16 // and copy them under the placeholder subblock.
17 // E.g. The IR for x[1:3, 0] = update
18 //    ...
19 //    %8 : Float(2, 4) = aten::slice(%0, %4, %5, %6, %7)
20 //    ...
21 //    %11 : Float(2) = aten::select(%8, %9, %10)
22 //    ...
23 //    %13 : Tensor?[] = prim::ListConstruct()
24 //    ...
25 //    %16 : Float(2) = aten::index_put(%11, %13, %14, %15)
26 // The aten::index_put node alone does not contain any indices (%13 : Tensor?[]
27 // = prim::ListConstruct()).
EncapsulateInplaceIndexPutForONNX(Node * index_put_node)28 Node* EncapsulateInplaceIndexPutForONNX(Node* index_put_node) {
29   auto graph = index_put_node->owningGraph();
30 
31   // Find slice and select operators that are associated with this index
32   // operator. E.g. x[1:3, 0] = y will generate one slice operator(1:3) and one
33   // select operator(0).
34   std::vector<Node*> slice_and_select_nodes =
35       IndexingPatternFinder::FetchSliceAndSelect(index_put_node);
36   Node* last_node = !slice_and_select_nodes.empty()
37       ? slice_and_select_nodes.back()
38       : index_put_node;
39   Value* orig_data = last_node->input(0);
40 
41   // Copy related nodes into subblock of a new special placeholder node.
42   Node* placeholder_node =
43       graph->create(Symbol::fromQualString("onnx::Placeholder"));
44   placeholder_node->s_(attr::name, index_put_node->kind().toUnqualString());
45   placeholder_node->addInput(orig_data);
46 
47   // Construct subblock
48   auto subblock = placeholder_node->addBlock();
49   std::unordered_map<Value*, Value*> env;
50 
51   // slice_and_select_nodes are in reversed order.
52   for (auto it = slice_and_select_nodes.rbegin();
53        it != slice_and_select_nodes.rend();
54        ++it) {
55     auto n = *it;
56     auto cloned_n = subblock->appendNode(graph->createClone(
57         n, [&](Value* v) { return env.find(v) != env.end() ? env[v] : v; }));
58     for (size_t i = 0; i < cloned_n->outputs().size(); ++i) {
59       env[n->outputs().at(i)] = cloned_n->outputs().at(i);
60     }
61   }
62 
63   Node* new_index_put_node =
64       subblock->appendNode(graph->createClone(index_put_node, [&](Value* v) {
65         return env.find(v) != env.end() ? env[v] : v;
66       }));
67   for (auto o : new_index_put_node->outputs()) {
68     subblock->registerOutput(o);
69   }
70 
71   placeholder_node->insertBefore(index_put_node);
72   placeholder_node->copyMetadata(index_put_node);
73   index_put_node->replaceAllUsesWith(placeholder_node);
74 
75   return placeholder_node;
76 }
77 
78 } // namespace
79 
EncapsulatePatternIntoSubblock(Node * n)80 std::optional<Node*> EncapsulatePatternIntoSubblock(Node* n) {
81   switch (n->kind()) {
82     case aten::index_put_:
83     case aten::index_put: {
84       return EncapsulateInplaceIndexPutForONNX(n);
85     }
86   }
87   return std::nullopt;
88 }
89 
90 } // namespace jit
91 } // namespace torch
92