#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>

#include <ATen/ScalarOps.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

namespace {

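// Return the prim::ListUnpack node that consumes the single output of `n`,
// or std::nullopt if `n` does not match the pattern checked below.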
std::optional<Node*> FindFusibleListUnpack(Node* n) {
  // 1. The number of outputs must be exactly 1.
  // 2. The output must be used only by a prim::ListUnpack node.
  if (n->outputs().size() != 1) {
    return std::nullopt;
  }
  if (n->output()->uses().size() != 1) {
    return std::nullopt;
  }
  auto listUnpackNode = n->output()->uses()[0].user;
  if (listUnpackNode->kind() != prim::ListUnpack) {
    return std::nullopt;
  }
  return listUnpackNode;
}

// Fuse node + ListUnpack
// Nodes such as split/unbind produce a tensor[] of statically known size
// that is later unpacked by prim::ListUnpack.
// This pass fuses the two nodes and adds an additional input "_outputs" so
// that the symbolic function is aware of the number of outputs.
//
// Example IR
// split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor[]
// split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]
//
// graph(%input : Float(5, 4, 3, strides=[12, 3, 1])):
// %13 : int[] = prim::Constant[value=[2, 1, 2]]()
// %7 : int = prim::Constant[value=0]()
// %8 : Tensor[] = aten::split_with_sizes(%input, %13, %7)
// %9 : Float(2, 4, 3, strides=[12, 3, 1]),
// %10 : Float(1, 4, 3, strides=[12, 3, 1]),
// %11 : Float(2, 4, 3, strides=[12, 3, 1]) = prim::ListUnpack(%8)
// return (%9, %10, %11)
//
// After fusion
// graph(%input : Float(5, 4, 3, strides=[12, 3, 1])):
// %13 : int[] = prim::Constant[value=[2, 1, 2]]()
// %7 : int = prim::Constant[value=0]()
// %8 : int = prim::Constant[value=3]() # Additional input of value 3,
//                                        representing the number of outputs.
// %14 : Float(2, 4, 3, strides=[12, 3, 1]),
// %15 : Float(1, 4, 3, strides=[12, 3, 1]),
// %16 : Float(2, 4, 3, strides=[12, 3, 1]) =
//     aten::split_with_sizes(%input, %13, %7, %8)
// return (%14, %15, %16)
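//
// For reference, one Python pattern that lowers to IR of this shape (an
// illustrative sketch, not taken from this file):
//
//   a, b, c = x.split([2, 1, 2], dim=0)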
void FuseWithListUnpack(Node* n) {
  auto found_listUnpack = FindFusibleListUnpack(n);
  if (!found_listUnpack) {
    return;
  }

  auto listUnpack_node = found_listUnpack.value();

  TORCH_INTERNAL_ASSERT(n->outputs().size() == 1);
  // 1. Add the internal attribute "_outputs" to the node, so that the later
  //    symbolic function conversion is aware of the number of outputs.
  // 2. Add the exact number of outputs to n, copy metadata, and replace the
  //    uses of the ListUnpack outputs.
  n->i_(
      Symbol::fromQualString("attr::_outputs"),
      static_cast<int64_t>(listUnpack_node->outputs().size()));

  for (size_t i = 0; i < listUnpack_node->outputs().size(); ++i) {
    auto new_output = n->addOutput();
    new_output->copyMetadata(listUnpack_node->output(i));
  }
  listUnpack_node->removeAllInputs();
  // Remove the original output, which was the input to the ListUnpack node.
  n->eraseOutput(0);
  listUnpack_node->replaceAllUsesWith(n);
}

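// Recursively apply the fusion to every eligible node in `b` and its nested
// sub-blocks.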
static void FuseWithListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      FuseWithListUnpack(child_block);
    }

    // Ops that produce a tensor list of statically known length, which can be
    // fused with their prim::ListUnpack consumer.
    auto n_kind = it->kind();
    switch (n_kind) {
      case aten::split:
      case aten::split_with_sizes:
      case aten::unsafe_split:
      case aten::unsafe_split_with_sizes:
      case aten::unbind:
      case aten::unsafe_chunk:
      case aten::where:
      case aten::nonzero_numpy:
        FuseWithListUnpack(*it);
        break;
      default:
        break;
    }
  }
}

// Replace aten::add with onnx::Concat
// when the inputs to the add node are two int lists.
//
// before the pass:
// graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu),
//       %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)):
// %2 : None = prim::Constant()
// %3 : int[] = aten::size(%x.1)
// %l1.1 : int[] = aten::list(%3)
// %5 : int[] = aten::size(%y.1)
// %l2.1 : int[] = aten::list(%5)
// %7 : int[] = aten::add(%l1.1, %l2.1)
// %8 : Tensor = aten::new_zeros(%x.1, %7, %2, %2, %2, %2)
// return (%8)
//
// after the pass:
// graph(%x.1 : Float(2, 3, 4, strides=[12, 4, 1], requires_grad=0, device=cpu),
//       %y.1 : Float(1, 2, 3, strides=[6, 3, 1], requires_grad=0, device=cpu)):
// %2 : None = prim::Constant()
// %3 : int[] = aten::size(%x.1)
// %l1.1 : int[] = aten::list(%3)
// %5 : int[] = aten::size(%y.1)
// %l2.1 : int[] = aten::list(%5)
// %9 : Tensor = onnx::Concat[axis=0](%l1.1, %l2.1)
// %8 : Tensor = aten::new_zeros(%x.1, %9, %2, %2, %2, %2)
// return (%8)
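//
// A Python pattern that lowers to the "before" IR above (an illustrative
// sketch, not taken from this file):
//
//   out = x.new_zeros(list(x.size()) + list(y.size()))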
static void ReplaceAddWithConcat(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      ReplaceAddWithConcat(child_block);
    }
    if (it->kind() == aten::add) {
      if (!it->input(0)->type()->cast<ListType>() ||
          !it->input(1)->type()->cast<ListType>()) {
        continue;
      }

      // Adding two int lists is concatenation, which maps to onnx::Concat
      // along axis 0.
      const auto& elem =
          it->input(0)->type()->castRaw<ListType>()->getElementType();
      if (elem->cast<IntType>()) {
        Node* concat_node = b->owningGraph()->create(onnx::Concat, 1);
        concat_node->i_(attr::axis, 0);
        concat_node->insertBefore(*it);
        concat_node->addInput(it->input(0));
        concat_node->addInput(it->input(1));
        concat_node->outputs()[0]->setType(TensorType::fromNumberType(*elem));
        concat_node->copyMetadata(*it);
        it->replaceAllUsesWith(concat_node);
        it->removeAllInputs();
        it.destroyCurrent();
      }
    }
  }
}

// This pass also covers the case where the input to prim::ListUnpack
// is an int[] coming from some op other than prim::ListConstruct (like Slice
// or Shape). Each unpacked element is replaced by an onnx::Gather of the list
// at the corresponding index; the original prim::ListUnpack node is left in
// place with its outputs now unused (see the post-pass IR below).
//
// before the pass
// graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
// %1 : None = prim::Constant()
// %2 : int[] = aten::size(%x.1)
// %a.1 : int, %b.1 : int = prim::ListUnpack(%2)
// %5 : int[] = prim::ListConstruct(%a.1, %b.1)
// %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1)
//
// after the pass:
// graph(%x.1 : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):
// %1 : None = prim::Constant()
// %2 : int[] = aten::size(%x.1)
// %7 : Tensor = onnx::Constant[value={0}]()
// %8 : Tensor = onnx::Gather(%2, %7)
// %9 : Tensor = onnx::Constant[value={1}]()
// %10 : Tensor = onnx::Gather(%2, %9)
// %a.1 : int, %b.1 : int = prim::ListUnpack(%2)
// %5 : int[] = prim::ListConstruct(%8, %10)
// %6 : Tensor = aten::new_zeros(%x.1, %5, %1, %1, %1, %1)
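//
// A Python pattern that lowers to the "before" IR above (an illustrative
// sketch, not taken from this file):
//
//   a, b = x.size()
//   out = x.new_zeros([a, b])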
static void fuseListAndListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseListAndListUnpack(child_block);
    }
    if (it->kind() == prim::ListUnpack) {
      for (const auto i : c10::irange(it->outputs().size())) {
        auto output = it->outputs().at(i);
        if (it->inputs().size() == 1 &&
            it->input()->node()->kind() != prim::ListConstruct &&
            it->input()->type()->cast<ListType>() &&
            it->input()
                ->type()
                ->castRaw<ListType>()
                ->getElementType()
                ->cast<IntType>()) {
          // Constant index i, used to gather the i-th element of the list.
          Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1);
          gather_indices->insertBefore(*it);
          gather_indices->t_(
              attr::value, at::scalar_to_tensor(at::Scalar(int(i))));
          Node* gather_node = b->owningGraph()->create(onnx::Gather, 1);
          gather_node->insertBefore(*it);
          gather_node->addInput(it->input());
          gather_node->addInput(gather_indices->output());
          gather_node->copyMetadata(*it);
          output->replaceAllUsesWith(gather_node->output());
        }
      }
    }
  }
}

} // namespace

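// Entry point: run each preprocessing sub-pass over the whole graph, dumping
// the graph after each one for debugging.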
void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
  FuseWithListUnpack(graph->block());
  GRAPH_DUMP("After FuseWithListUnpack: ", graph);
  ReplaceAddWithConcat(graph->block());
  GRAPH_DUMP("After ReplaceAddWithConcat: ", graph);
  fuseListAndListUnpack(graph->block());
  GRAPH_DUMP("After fuseListAndListUnpack: ", graph);
}

} // namespace torch::jit