xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
2 
3 #include <ATen/ScalarOps.h>
4 #include <c10/util/irange.h>
5 
6 #include <torch/csrc/jit/jit_log.h>
7 #include <torch/csrc/jit/passes/onnx/helper.h>
8 
9 namespace torch::jit {
10 
11 namespace onnx {
12 using namespace ::c10::onnx;
13 }
14 
15 namespace {
16 
FindFusibleListUnpack(Node * n)17 std::optional<Node*> FindFusibleListUnpack(Node* n) {
18   // 1. number of outputs is restricted to 1.
19   // 2. output is only used by prim::ListUnpack.
20   if (n->outputs().size() != 1) {
21     return std::nullopt;
22   }
23   if (n->output()->uses().size() != 1) {
24     return std::nullopt;
25   }
26   auto listUnpackNode = n->output()->uses()[0].user;
27   if (listUnpackNode->kind() != prim::ListUnpack) {
28     return std::nullopt;
29   }
30   return listUnpackNode;
31 }
32 
33 // Fuse node + ListUnpack
34 // Node such as split/unbind produces tensor[] of static size,
35 // that is later unpacked by ListUnpack.
36 // This pass fuses the two nodes, and adds an additional input "_outputs" such
37 // that the symbolic function is aware of the number of outputs.
38 //
39 // Example IR
40 //  split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor[]
41 //  split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
42 //
43 // graph(%input : Float(5, 4, 3, strides=[12, 3, 1])):
44 //   %13 : int[] = prim::Constant[value=[2, 1, 2]]()
45 //   %7 : int = prim::Constant[value=0]()
46 //   %8 : Tensor[] = aten::split_with_sizes(%input, %13, %7)
47 //   %9 : Float(2, 4, 3, strides=[12, 3, 1]), %10 : Float(1, 4, 3, strides=[12,
48 //   3, 1]), %11 : Float(2, 4, 3, strides=[12, 3, 1]) = prim::ListUnpack(%8)
49 //   return (%9, %10, %11)
50 //
51 // After fusion
52 // graph(%input : Float(5, 4, 3, strides=[12, 3, 1])):
53 //   %13 : int[] = prim::Constant[value=[2, 1, 2]]()
54 //   %7 : int = prim::Constant[value=0]()
55 //   %8 : int = prim::Constant[value=3]()  # Adding additional input of value 3
56 //      representing the number of outputs.
57 //   %14 : Float(2, 4, 3, strides=[12, 3, 1]), %15 : Float(1, 4, 3, strides=[12,
58 //      3, 1]), %16 : Float(2, 4, 3, strides=[12, 3, 1]) =
59 //      aten::split_with_sizes(%input, %13, %7, %8) return (%14, %15, %16)
FuseWithListUnpack(Node * n)60 void FuseWithListUnpack(Node* n) {
61   auto found_listUnpack = FindFusibleListUnpack(n);
62   if (!found_listUnpack) {
63     return;
64   }
65 
66   auto listUnpack_node = found_listUnpack.value();
67 
68   TORCH_INTERNAL_ASSERT(n->outputs().size() == 1);
69   // 1. Add internal input "_outputs" to node, so that later symbolic function
70   //    conversion is aware of the number of outputs.
71   // 2. Add the exact number of outputs to n, copy metadata and replace uses of
72   //    listUnpack outputs.
73   n->i_(
74       Symbol::fromQualString("attr::_outputs"),
75       static_cast<int64_t>(listUnpack_node->outputs().size()));
76 
77   for (size_t i = 0; i < listUnpack_node->outputs().size(); ++i) {
78     auto new_output = n->addOutput();
79     new_output->copyMetadata(listUnpack_node->output(i));
80   }
81   listUnpack_node->removeAllInputs();
82   // remove original output, which is input to listUnpack node.
83   n->eraseOutput(0);
84   listUnpack_node->replaceAllUsesWith(n);
85 }
86 
FuseWithListUnpack(Block * b)87 static void FuseWithListUnpack(Block* b) {
88   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
89     for (auto* child_block : it->blocks()) {
90       FuseWithListUnpack(child_block);
91     }
92 
93     auto n_kind = it->kind();
94     switch (n_kind) {
95       case aten::split:
96       case aten::split_with_sizes:
97       case aten::unsafe_split:
98       case aten::unsafe_split_with_sizes:
99       case aten::unbind:
100       case aten::unsafe_chunk:
101       case aten::where:
102       case aten::nonzero_numpy:
103         FuseWithListUnpack(*it);
104         break;
105       default:
106         break;
107     }
108   }
109 }
110 
111 // Replace aten::add with onnx::Concat
112 // when inputs to the add node are two int lists
113 //
114 // before the pass:
115 // graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu),
116 //  %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)):
117 //  %2 : None = prim::Constant()
118 //  %3 : int[] = aten::size(%x.1)
119 //  %l1.1 : int[] = aten::list(%3)
120 //  %5 : int[] = aten::size(%y.1)
121 //  %l2.1 : int[] = aten::list(%5)
122 //  %7 : int[] = aten::add(%l1.1, %l2.1)
123 //  %8 : Tensor = aten::new_zeros(%x.1, %7, %2, %2, %2, %2)
124 //  return (%8)
125 //
126 // after the pass:
127 // graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu),
128 //  %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)):
129 //  %2 : None = prim::Constant()
130 //  %3 : int[] = aten::size(%x.1)
131 //  %l1.1 : int[] = aten::list(%3)
132 //  %5 : int[] = aten::size(%y.1)
133 //  %l2.1 : int[] = aten::list(%5)
134 //  %9 : Tensor = onnx::Concat[axis=0](%l1.1, %l2.1)
135 //  %8 : Tensor = aten::new_zeros(%x.1, %9, %2, %2, %2, %2)
136 //  return (%8)
ReplaceAddWithConcat(Block * b)137 static void ReplaceAddWithConcat(Block* b) {
138   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
139     for (auto* child_block : it->blocks()) {
140       ReplaceAddWithConcat(child_block);
141     }
142     if (it->kind() == aten::add) {
143       if (!it->input(0)->type()->cast<ListType>() ||
144           !it->input(1)->type()->cast<ListType>()) {
145         continue;
146       }
147 
148       const auto& elem =
149           it->input(0)->type()->castRaw<ListType>()->getElementType();
150       if (elem->cast<IntType>()) {
151         Node* concat_node = b->owningGraph()->create(onnx::Concat, 1);
152         concat_node->i_(attr::axis, 0);
153         concat_node->insertBefore(*it);
154         concat_node->addInput(it->input(0));
155         concat_node->addInput(it->input(1));
156         concat_node->outputs()[0]->setType(TensorType::fromNumberType(*elem));
157         concat_node->copyMetadata(*it);
158         it->replaceAllUsesWith(concat_node);
159         it->removeAllInputs();
160         it.destroyCurrent();
161       }
162     }
163   }
164 }
165 
166 // This pass also covers the case when the input to ListUnpack
167 // is int[] coming from some other op than ListConstruct (like Slice or Shape)
168 //
169 // before the pass
170 // graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
171 //   %1 : None = prim::Constant()
172 //   %2 : int[] = aten::size(%x.1)
173 //   %a.1 : int, %b.1 : int = prim::ListUnpack(%2)
174 //   %5 : int[] = prim::ListConstruct(%a.1, %b.1)
175 //   %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1)
176 //
177 // after the pass:
178 // graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
179 //   %1 : None = prim::Constant()
180 //   %2 : int[] = aten::size(%x.1)
181 //   %7 : Tensor = onnx::Constant[value={0}]()
182 //   %8 : Tensor = onnx::Gather(%2, %7)
183 //   %9 : Tensor = onnx::Constant[value={1}]()
184 //   %10 : Tensor = onnx::Gather(%2, %9)
185 //   %a.1 : int, %b.1 : int = prim::ListUnpack(%2)
186 //   %5 : int[] = prim::ListConstruct(%8, %10)
187 //   %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1)
fuseListAndListUnpack(Block * b)188 static void fuseListAndListUnpack(Block* b) {
189   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
190     for (auto* child_block : it->blocks()) {
191       fuseListAndListUnpack(child_block);
192     }
193     if (it->kind() == prim::ListUnpack) {
194       for (const auto i : c10::irange(it->outputs().size())) {
195         auto output = it->outputs().at(i);
196         if (it->inputs().size() == 1 &&
197             it->input()->node()->kind() != prim::ListConstruct &&
198             it->input()->type()->cast<ListType>() &&
199             it->input()
200                 ->type()
201                 ->castRaw<ListType>()
202                 ->getElementType()
203                 ->cast<IntType>()) {
204           Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1);
205           gather_indices->insertBefore(*it);
206           gather_indices->t_(
207               attr::value, at::scalar_to_tensor(at::Scalar(int(i))));
208           Node* gather_node = b->owningGraph()->create(onnx::Gather, 1);
209           gather_node->insertBefore(*it);
210           gather_node->addInput(it->input());
211           gather_node->addInput(gather_indices->output());
212           gather_node->copyMetadata(*it);
213           output->replaceAllUsesWith(gather_node->output());
214         }
215       }
216     }
217   }
218 }
219 
220 } // namespace
221 
// Entry point of this file: runs the ONNX preprocessing passes over the whole
// graph, in a fixed order, dumping the graph after each pass for debugging.
void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
  // Fuse list-producing ops (split/unbind/...) with their prim::ListUnpack
  // consumer so the symbolic conversion knows the output arity.
  FuseWithListUnpack(graph->block());
  GRAPH_DUMP("After FuseWithListUnpack: ", graph);
  // Rewrite aten::add over two int lists as onnx::Concat[axis=0].
  ReplaceAddWithConcat(graph->block());
  GRAPH_DUMP("After ReplaceAddWithConcat: ", graph);
  // Rewrite ListUnpack of int[] not built by ListConstruct into onnx::Gather.
  fuseListAndListUnpack(graph->block());
  GRAPH_DUMP("After fuseListAndListUnpack: ", graph);
}
230 
231 } // namespace torch::jit
232