#include <torch/csrc/jit/passes/onnx/peephole.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

#include <ATen/ScalarOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/ones_like_native.h>
#endif

#include <optional>

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

bool isRNN(const Node* node) {
  auto k = node->kind();
  return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
}

bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
    if (perm[i] != i) {
      return false;
    }
  }
  return true;
}

// returns a vector `ret` such that transposing by `ret` is equivalent
// to transposing by `t1` and then by `t2`
//
// This fires in the case that we have transpose ops T1 -> T2. We are
// fusing the transpose op T1 into T2 and discarding T1. We assume the elements
// of the permutation in `t1` are raw indices into its input, since a previous
// iteration would have folded all the transposes up to that point. Thus,
// `ret[i] = t1[t2[i]]` says "the output of t2 at position i takes the value of
// the input tensor index contained in t1 at position `t2[i]`".
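//
// For example (illustrative values, not from any particular model): with
// t1 = {1, 0, 2} and t2 = {2, 0, 1}, the composed permutation is
// ret = {t1[2], t1[0], t1[1]} = {2, 1, 0}, i.e. applying t1 and then t2 is
// equivalent to a single transpose by {2, 1, 0}.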
std::vector<int64_t> composeTransposes(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  TORCH_INTERNAL_ASSERT(t1.size() == t2.size());
  std::vector<int64_t> ret;
  ret.reserve(t1.size());
  for (const auto& i : t2) {
    TORCH_INTERNAL_ASSERT(i < int64_t(t1.size()));
    ret.push_back(t1[i]);
  }
  return ret;
}

std::vector<size_t> getBroadcastPositions(Node* node) {
  // Most of the element-wise ops in ONNX support numpy broadcasting.
  // Only Gemm supports one-directional broadcasting, which broadcasts the bias
  // to the product.
  static std::unordered_map<NodeKind, std::vector<size_t>> broadcast_positions =
      {
          {onnx::Add, {0, 1}},
          {onnx::Div, {0, 1}},
          {onnx::Mul, {0, 1}},
          {onnx::Pow, {0, 1}},
          {onnx::Sub, {0, 1}},
          {onnx::Gemm, {2}},
          {onnx::Equal, {0, 1}},
          {onnx::Greater, {0, 1}},
          {onnx::Less, {0, 1}},
      };
  static std::vector<size_t> no_positions;
  std::vector<size_t> positions;

  auto iter = broadcast_positions.find(node->kind());
  if (iter != broadcast_positions.end()) {
    // skip optional input if not provided
    for (size_t position : iter->second) {
      if (position < node->inputs().size()) {
        positions.emplace_back(position);
      }
    }
    return positions;
  }
  return no_positions;
}

// Determine whether `from` can broadcast to `to`, and if so at which
// position. `from` must be a suffix of `to`, except that any
// occurrences of 1 in `from` are treated as wildcards.
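//
// For example (illustrative shapes): from = {1, 4} broadcasts to
// to = {2, 3, 4}; the trailing 4 matches and the 1 acts as a wildcard, so the
// returned position is to.size() - from.size() == 1. from = {3, 4} does not
// broadcast to to = {2, 5, 4}, and std::nullopt is returned.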
std::optional<size_t> fusibleExpandTo(
    at::IntArrayRef from,
    at::IntArrayRef to) {
  if (from.size() > to.size()) {
    return std::nullopt;
  }

  for (const auto i : c10::irange(from.size())) {
    auto fdim = from[from.size() - 1 - i];
    auto tdim = to[to.size() - 1 - i];
    if (fdim != 1 && fdim != tdim) {
      return std::nullopt;
    }
  }

  return to.size() - from.size();
}

// Fuses expand calls into ONNX operators: non-strided backends can perform
// broadcasts more efficiently when this information is local to the consuming
// op. This optimization is not useful for PyTorch itself, as 'expand' is free.
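//
// A rough before/after sketch (illustrative IR):
//   %y = aten::expand(%x, %shape, %implicit)
//   %z = onnx::Add(%w, %y)
// becomes
//   %z = onnx::Add(%w, %x)
// relying on the consuming op's implicit numpy-style broadcasting; the
// aten::expand node is destroyed if it has no remaining uses.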
void fuseBroadcast(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseBroadcast(child_block);
    }

    auto broadcast_positions = getBroadcastPositions(n);
    if (!broadcast_positions.empty()) {
      TORCH_INTERNAL_ASSERT(!n->hasAttribute(attr::axis));
    }

    for (size_t position : broadcast_positions) {
      auto* expand_node = n->input(position)->node();

      // Confirm it is expand node.
      if (expand_node->kind() != aten::expand ||
          expand_node->input(1)->node()->kind() != onnx::Constant ||
          expand_node->input(2)->node()->kind() != onnx::Constant) {
        continue;
      }

      auto* unexpanded_input = expand_node->input(0);

      // We need to know what the type pre-expand is. We should basically
      // always have this information (because expands are only ever traced,
      // not generated from symbolic), but if for some reason we don't
      // have it, we need to skip.
      if (!unexpanded_input->isCompleteTensor() ||
          !n->output()->isCompleteTensor()) {
        continue;
      }

      // Not all broadcasts are supported by ONNX broadcast.
      std::optional<size_t> axis = fusibleExpandTo(
          unexpanded_input->type()
              ->expectRef<TensorType>()
              .sizes()
              .concrete_sizes()
              .value(), // from
          n->output()
              ->type()
              ->expectRef<TensorType>()
              .sizes()
              .concrete_sizes()
              .value()); // to
      if (axis == std::nullopt) {
        continue;
      }

      n->replaceInput(position, unexpanded_input);
      if (!expand_node->hasUses()) {
        expand_node->destroy();
      }
    }
  }
}

void fuseConsecutiveTransposes(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseConsecutiveTransposes(child_block);
    }
    if (n->kind() == onnx::Transpose &&
        n->input()->node()->kind() == onnx::Transpose &&
        n->owningBlock() == n->input()->node()->owningBlock()) {
      auto origInput = n->input();
      n->is_(
          attr::perm,
          composeTransposes(
              origInput->node()->is(attr::perm), n->is(attr::perm)));
      n->replaceInput(0, origInput->node()->input());
      if (origInput->uses().empty()) {
        origInput->node()->destroy();
      }
      continue;
    }
  }
}

void eliminateNopTranspose(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      eliminateNopTranspose(child_block);
    }
    if (n->kind() == onnx::Transpose) {
      if (isNopTranspose(n->is(attr::perm))) {
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}

void fuseTransposeIntoGemm(Block* b) {
  static const std::vector<int64_t> simpleTransPerm({1, 0});

  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseTransposeIntoGemm(child_block);
    }
    if (n->kind() == onnx::Gemm) {
      for (size_t i : {0, 1}) {
        auto inp = n->inputs()[i];
        auto trans = i == 0 ? attr::transA : attr::transB;
        if (inp->node()->kind() == onnx::Transpose &&
            inp->node()->is(attr::perm) == simpleTransPerm) {
          n->replaceInput(i, inp->node()->input());
          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
          if (inp->uses().empty()) {
            inp->node()->destroy();
          }
        }
      }
    }
  }
}

// Why this is here:
//
// Pytorch has a "packed" representation of sequences, as well as a
// "padded" representation. ONNX has only one representation,
// corresponding to pytorch's "padded". Therefore, we need to remove
// any use of packed sequences before exporting.
//
// What this does:
//
// This code uses the observation that
//     RNN(PackPadded(x)) == PackPadded(RNN(x))
// and converts the first form to the second whenever possible,
// "pushing" the packing operation past the RNN operation. Then,
// the removeNopPacking pass removes the packing operations
// entirely by pairing them with their inverse PadPacked. If the
// input graph does not pair the operations, export will fail.
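//
// Roughly (illustrative IR; details such as the squeeze/transpose handling
// and batch_sizes rewiring are elided):
//   %data, %batch_sizes = prim::PackPadded(%input, %lengths)
//   %y, %h = onnx::LSTM(%data, ...)
//   %out = onnx::Squeeze(%y)
// becomes
//   %y, %h = onnx::LSTM(%input, ...)
//   %out = onnx::Squeeze(%y)
//   %data, %batch_sizes = prim::PackPadded(%out, %lengths)
// so that the PackPadded/PadPacked pair can later cancel in removeNopPacking.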
void pushPackingPastRnn(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      pushPackingPastRnn(child_block);
    }

    if (n->kind() != prim::PackPadded) {
      continue;
    }
    if (n->outputs().at(0)->uses().size() != 1) {
      // For now, only handle the case where there is one consumer.
      continue;
    }
    Node* rnn = n->outputs()[0]->uses()[0].user;
    if (!isRNN(rnn)) {
      continue;
    }

    if (rnn->owningBlock() != n->owningBlock()) {
      continue;
    }

    // Packing only has an effect on a network when its outputs are actually
    // used, so we can remove it here.
    if (rnn->outputs().at(0)->uses().empty() &&
        n->outputs().at(1)->uses().size() == 1) {
      n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
      n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
      it.destroyCurrent();
      continue;
    }

    // The rnn is followed by a transpose and a reshape (if
    // bidirectional), or by a squeeze (if unidirectional).
    Node* next = rnn->outputs().at(0)->uses().at(0).user;
    if (next->kind() == onnx::Transpose) {
      next = next->outputs().at(0)->uses().at(0).user;
      if (next->kind() != onnx::Reshape) {
        continue;
      }
    } else if (next->kind() != onnx::Squeeze) {
      continue;
    }

    // remove PackPadded from in front of the RNN
    n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));

    Value* batch_sizes = n->outputs().at(1);
    while (!batch_sizes->uses().empty()) {
      Use use_0 = batch_sizes->uses().at(0);
      Node* user = use_0.user;
      // Make calculation of max_batch_size not depend on batch_sizes.
      // This looks for a pattern generated by code such as
      // https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
      //
      // Replace onnx::Gather[axis=0](batch_sizes, 0)
      // with onnx::Gather[axis=0](onnx::Shape(rnn_input), 1)
      if (use_0.offset == 0 && user->kind() == onnx::Gather &&
          user->i(attr::axis) == 0 &&
          user->inputs().at(1)->node()->kind() == onnx::Constant &&
          user->inputs().at(1)->node()->hasAttribute(attr::value)) {
        const at::Tensor& const_val_t =
            user->inputs().at(1)->node()->t(attr::value);
        if (const_val_t.item().toInt() != 0) {
          // We'll likely produce an invalid graph if this happens.
          break;
        }
        Value* rnn_input = rnn->inputs().at(0);
        Node* shape = b->owningGraph()->create(onnx::Shape);
        shape->insertAfter(rnn_input->node());
        shape->addInput(rnn_input);
        shape->copyMetadata(n);
        batch_sizes->replaceFirstUseWith(shape->output());
        // A new Constant node is needed, as the existing constant 0 node
        // might be shared with other users.
        Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1);
        gather_indices->t_(attr::value, at::native::ones_like(const_val_t));
        gather_indices->copyMetadata(n);
        gather_indices->insertBefore(user);
        user->replaceInput(1, gather_indices->output());
      }
      // Make RNN not depend on batch_sizes.
      else if (user == rnn) {
        batch_sizes->replaceFirstUseWith(n->inputs().at(1));
      } else {
        // If there are other uses that are not:
        // * PadPacked (which will be removed in removeNopPacking),
        // * Dead code (which will be removed in dead code elimination),
        // then we likely have produced an invalid graph, since there will be a
        // use of the output of PackPadded, but the PackPadded (and that output)
        // will be removed.
        break;
      }
    }

    // and insert new PackPadded after the RNN
    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
    newPackPadded->copyMetadata(n);
    newPackPadded->insertAfter(next);
    newPackPadded->copyMetadata(next);

    // make things consume from the new PackPadded
    next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
    n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));

    // set up the new PackPadded's inputs
    newPackPadded->addInput(next->outputs().at(0));
    newPackPadded->addInput(n->inputs().at(1));

    // See https://github.com/pytorch/pytorch/issues/9043 for a full
    // description. Since PackPadded is for now treated in an
    // unhygienic way, PyTorch ends up propagating an incorrect type.
    // Until a long-term cleanup comes around, we can fix this by
    // resetting the size to the correct value.
    TensorTypePtr oldType = rnn->inputs().at(0)->type()->cast<TensorType>();
    if (oldType && oldType->isComplete()) {
      std::vector<int64_t> new_sizes;
      new_sizes.push_back(*oldType->sizes()[0]);
      new_sizes.push_back(*oldType->sizes()[1]);
      if (next->kind() == onnx::Reshape) {
        // bidirectional
        new_sizes.push_back(rnn->i(attr::hidden_size) * 2);
      } else {
        // unidirectional
        new_sizes.push_back(rnn->i(attr::hidden_size));
      }
      TensorTypePtr newType = TensorType::createContiguous(
          *oldType->scalarType(), *oldType->device(), new_sizes);
      next->outputs().at(0)->setType(newType);
    }

    it.destroyCurrent();
  }
}

// Despite the name, this actually removes the PadPacked node and leaves
// the PackPadded node. The PackPadded should become dead code which will
// be eliminated later.
void removeNopPacking(Block* graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    if (input->kind() != prim::PackPadded) {
      continue;
    }
    if (input->outputs()[0] != n->inputs()[0]) {
      continue;
    }
    if (input->outputs()[1] != n->inputs()[1]) {
      continue;
    }
    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

    n->removeAllInputs();
    it.destroyCurrent();
  }
}

void hackFixupPadPackedShapes(Block* graph) {
  // FIXME: the input to the fictional PadPacked node has an incorrect shape.
  // For now, just copy the shape of the PadPacked output onto the shape of
  // its input.
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      hackFixupPadPackedShapes(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    input->outputs()[0]->setType(n->outputs()[0]->type());
  }
}

void fixDefaultRNNState(
    Graph* graph,
    Node* n,
    int input_index,
    int opset_version) {
  auto initial_state = n->inputs()[input_index];

  // The RNN code in pytorch accepts an optional hidden state.
  // 1- When it is provided as an input, everything works great.
  // 2- When it is not provided, it is default-initialized by constructing a
  //    new Variable, which gets traced as a ConstantOfShape with the expected
  //    Shape.
  // 3- When the batch size is fixed, everything works great as well.
  // 4- When h0 and c0 are specified but are not inputs of the model (they are
  //    Constants) and the batch size is variable, the model should be saved
  //    with a batch size of 1 (or an error will occur), and we save the value
  //    of h0 and c0 with a batch size of 1. When the model is then called with
  //    a different batch size value, h0 and c0 are broadcasted to get the right
  //    shape.
  // Recognize that last pattern here (4) and fix the shape.
  // Note that for multi-layer RNNs there will be a Slice operation between the
  // Constant and the RNN.
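  //
  // The fix below rebuilds the initial state shape dynamically, roughly
  // (illustrative IR):
  //   %batch = onnx::Gather[axis=0](onnx::Shape(%input), 1)
  //   %dims  = onnx::Concat[axis=0](Unsqueeze(%num_directions),
  //                                 Unsqueeze(%batch), %hidden_size)
  //   %h0    = onnx::Expand(%initial_state, %dims)
  // so the saved batch-size-1 constant is broadcast to the runtime batch size.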
  bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
      (initial_state->node()->kind() == onnx::Slice &&
       initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);

  if (!needsFixing) {
    return;
  }

  Node* shape_of_input = graph->create(onnx::Shape, 1);
  shape_of_input->copyMetadata(n);
  shape_of_input->insertBefore(n);
  shape_of_input->addInput(n->inputs()[0]);

  Node* gather_indices = graph->create(onnx::Constant, 1);
  gather_indices->copyMetadata(n);
  gather_indices->insertBefore(n);
  gather_indices->t_(attr::value, at::scalar_to_tensor(at::Scalar(1)));

  Node* batch_size = graph->create(onnx::Gather, 1);
  batch_size->copyMetadata(n);
  batch_size->insertBefore(n);
  batch_size->addInput(shape_of_input->outputs()[0]);
  batch_size->addInput(gather_indices->outputs()[0]);

  Node* unsqueezed_batch_size =
      createONNXUnsqueeze(graph, n, batch_size->outputs()[0], 0, opset_version);

  Node* hidden_size = graph->create(onnx::Constant, 1);
  hidden_size->copyMetadata(n);
  hidden_size->insertBefore(n);
  hidden_size->t_(
      attr::value,
      at::full(
          {1},
          n->i(attr::hidden_size),
          at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());

  Node* num_directions = graph->create(onnx::Constant, 1);
  num_directions->copyMetadata(n);
  num_directions->insertBefore(n);
  num_directions->t_(
      attr::value,
      scalar_to_tensor(at::Scalar(
          n->hasAttribute(attr::direction) &&
                  n->s(attr::direction) == "bidirectional"
              ? 2
              : 1)));

  Node* unsqueezed_num_directions = createONNXUnsqueeze(
      graph, n, num_directions->outputs()[0], 0, opset_version);

  Node* concated_dims = graph->create(onnx::Concat, 1);
  concated_dims->copyMetadata(n);
  concated_dims->insertBefore(n);
  concated_dims->i_(attr::axis, 0);
  concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
  concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
  concated_dims->addInput(hidden_size->outputs()[0]);

  Node* fixed_init_state = graph->create(onnx::Expand, 1);
  fixed_init_state->copyMetadata(n);
  fixed_init_state->insertBefore(n);
  fixed_init_state->addInput(initial_state);
  fixed_init_state->addInput(concated_dims->outputs()[0]);
  n->replaceInput(input_index, fixed_init_state->outputs()[0]);

  if (initial_state->uses().empty()) {
    initial_state->node()->destroy();
  }
}

void fixDefaultRnnHiddenState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultRnnHiddenState(child_block, opset_version);
    }

    if (!isRNN(n)) {
      continue;
    }
    // Hidden state is the sixth input for RNN, LSTM, GRU.
    // See https://pytorch.org/docs/main/nn.html#torch.nn.RNN
    if (n->inputs().size() < 6) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 5, opset_version);
  }
}

void fixDefaultLstmCellState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultLstmCellState(child_block, opset_version);
    }

    if (n->kind() != onnx::LSTM) {
      continue;
    }
    // Cell state is the seventh input for LSTM.
    // See https://pytorch.org/docs/main/nn.html#torch.nn.LSTM
    if (n->inputs().size() < 7) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 6, opset_version);
  }
}

static bool isSafeToSpeculate(Node* n) {
  return n->kind() == onnx::Transpose;
}

// Moves ops outside of control flow blocks so that they are always executed,
// no matter the result of the control flow conditions.
// Needed only so that the split pass of the ONNX optimizer will put the ops
// into the init_net.
// TODO: Once the code in caffe2/python/onnx/backend.py no longer calls
// optimize_onnx, delete this function.
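//
// For example (illustrative): a Transpose that lives inside an If branch but
// whose input is defined outside the If is hoisted out, e.g.
//   %y = prim::If(%cond)
//     block0():
//       %t = onnx::Transpose[perm=[1, 0]](%x)
//       ...
// becomes
//   %t = onnx::Transpose[perm=[1, 0]](%x)
//   %y = prim::If(%cond)
//     ...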
static void speculateOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it; // note: increment first so that it is safe to move the node if needed

    for (auto b : n->blocks()) {
      speculateOps(b);
    }
    if (!isSafeToSpeculate(n)) {
      continue;
    }
    // XXX - only works for nodes with a single input
    // move node n outside of the control flow it is nested in
    auto node_input = n->input()->node();
    if (node_input->owningBlock() == n->owningBlock()) {
      continue;
    }
    // Skip if output of this node is part of block output.
    bool is_block_output = false;
    for (auto node_output : n->outputs()) {
      for (auto node_output_use : node_output->uses()) {
        if (node_output_use.user == n->owningBlock()->return_node()) {
          is_block_output = true;
          break;
        }
      }
      if (is_block_output) {
        break;
      }
    }
    if (is_block_output) {
      continue;
    }
    // find the control flow node in the same block as node_input that contains
    // Node n
    auto control_flow_node = n->owningBlock()->owningNode();
    while (control_flow_node->owningBlock() != node_input->owningBlock()) {
      control_flow_node = control_flow_node->owningBlock()->owningNode();
    }
    // put the node right before this flow node
    n->moveBefore(control_flow_node);
  }
}

static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
  node->removeInput(i);
  for (auto* to_val : to) {
    TORCH_INTERNAL_ASSERT(to_val->owningGraph() == node->owningGraph());
    node->insertInput(i++, to_val);
  }
}

static void eraseListConstruct(Block* block, int opset_version);

static void eraseListConstruct(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListConstruct(b, opset_version);
  }
  std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;

  auto block = n->owningBlock();
  size_t i = 0;
  for (auto* input : n->inputs()) {
    if (input->node()->kind() == prim::ListConstruct) {
      auto* lc_node = input->node();
      TypePtr elem =
          lc_node->output()->type()->castRaw<ListType>()->getElementType();
      if (elem->cast<IntType>() &&
          isValidToTransformToONNXConcatNode(lc_node)) {
        auto concat_node = transformToONNXConcatNode(
            block->owningGraph(), input->node(), false, opset_version);
        concat_node->copyMetadata(n);
        // Use the Concat node's output as the new input; the ListConstruct
        // should then become dead.
        replacements.emplace_back(
            i, std::vector<Value*>({concat_node->output()}));
      } else {
        if (opset_version >= OPSET_VERSION_11) {
          c10::Symbol seq_node_kind = !lc_node->inputs().empty()
              ? onnx::SequenceConstruct
              : onnx::SequenceEmpty;
          Node* seq_node = block->owningGraph()->create(
              seq_node_kind, {lc_node->inputs()}, 1);
          seq_node->copyMetadata(n);
          seq_node->insertBefore(lc_node);
          seq_node->output()->copyMetadata(lc_node->output());
          seq_node->copyMetadata(lc_node);
          lc_node->replaceAllUsesWith(seq_node);
        }
      }
    }
    i++;
  }

  for (auto ritr = replacements.rbegin(); ritr != replacements.rend(); ++ritr) {
    replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
  }
}

static void eraseListConstruct(Block* block, int opset_version) {
  // TODO: Fix this pass/maybe get rid of this part.
  // Tensor lists might be used for meshgrid and such ops as well.
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    eraseListConstruct(n, opset_version);
  }
  eraseListConstruct(block->return_node(), opset_version);
}

static void eraseListUnpack(Block* block, int opset_version);

// Replace prim::ListUnpack with onnx::SequenceAt.
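//
// For example (illustrative): a ListUnpack with two outputs,
//   %a, %b = prim::ListUnpack(%seq)
// is rewritten as
//   %i0 = onnx::Constant[value={0}]()
//   %a  = onnx::SequenceAt(%seq, %i0)
//   %i1 = onnx::Constant[value={1}]()
//   %b  = onnx::SequenceAt(%seq, %i1)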
static void eraseListUnpack(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListUnpack(b, opset_version);
  }

  if (n->kind() == prim::ListUnpack) {
    if (opset_version < OPSET_VERSION_11) {
      // onnx::SequenceAt was introduced in onnx opset version 11
      throw std::runtime_error(
          "Unsupported: ONNX export of prim::ListUnpack in opset " +
          std::to_string(opset_version) + ". Please try opset version 11.");
    }

    auto g = n->owningGraph();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto seq_idx_n = g->create(onnx::Constant, 1);
      seq_idx_n->t_(attr::value, at::scalar_to_tensor(at::Scalar(int64_t(i))));
      seq_idx_n->insertBefore(n);

      auto seq_at_n = g->create(onnx::SequenceAt, 1);
      seq_at_n->addInput(n->input());
      seq_at_n->addInput(seq_idx_n->output());
      seq_at_n->output()->setType(n->output(i)->type());
      seq_at_n->insertBefore(n);
      seq_at_n->copyMetadata(n);
      n->output(i)->replaceAllUsesWith(seq_at_n->output());
    }
  }
}

static void eraseListUnpack(Block* block, int opset_version) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    eraseListUnpack(n, opset_version);
  }
}

// From:
//   %list = ListConstruct(%x);
//   %unpacked = ListUnpack(%list);
//   do_something(%unpacked);
//
// To:
//   %list = ListConstruct(%x);
//   %unpacked = ListUnpack(%list);
//   do_something(%x)
//
// The ListConstruct and ListUnpack may now be dead code.
static void fuseListConstructListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseListConstructListUnpack(child_block);
    }
    if (it->kind() == prim::ListUnpack &&
        it->input()->node()->kind() == prim::ListConstruct) {
      for (const auto i : c10::irange(it->outputs().size())) {
        auto output = it->outputs().at(i);
        output->replaceAllUsesWith(it->input()->node()->inputs().at(i));
      }
    }
  }
}

// https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
static void eraseTupleConstruct(Block* block) {
  std::vector<Value*> new_block_outputs;
  bool found_tuple_construct = false;
  // TupleConstruct is generated from the symbolics in the quantized domain and
  // consumed by other quantized operators. Any remaining TupleConstruct should
  // be at the outputs of the block.
  for (auto* output : block->outputs()) {
    auto output_node = output->node();
    if (output_node->kind() == prim::TupleConstruct) {
      found_tuple_construct = true;
      for (auto* input : output_node->inputs()) {
        new_block_outputs.emplace_back(input);
      }
    } else {
      new_block_outputs.emplace_back(output);
    }
  }
  if (found_tuple_construct) {
    block->removeAllOutputs();
    for (auto* output : new_block_outputs) {
      block->registerOutput(output);
    }
  }
}

void removeMaxPoolUnusedOutput(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      removeMaxPoolUnusedOutput(child_block);
    }
    if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) {
      if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) {
        it->eraseOutput(1);
      }
    }
  }
}

// This optimization fuses LogSoftmax and NegativeLogLikelihoodLoss into a
// single SoftmaxCrossEntropyLoss operator. Depending on the dimensionality of
// the input and on the attributes involved, the LogSoftmax and
// NegativeLogLikelihoodLoss nodes appear in several different subgraph
// patterns, all handled below.
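//
// In the simplest case the rewrite is roughly (illustrative IR):
//   %log_probs = onnx::LogSoftmax[axis=1](%input)
//   %loss = onnx::NegativeLogLikelihoodLoss[reduction="mean"](%log_probs, %target)
// becomes
//   %loss = onnx::SoftmaxCrossEntropyLoss[reduction="mean"](%input, %target)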
static void fuseLogSoftmaxNllLoss(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseLogSoftmaxNllLoss(child_block);
    }
    if (it->kind() == onnx::NegativeLogLikelihoodLoss) {
      auto prev = it->input(0)->node();
      Node* origNllLossNode = *it;
      Node* origLogSoftmaxNode = nullptr;

      // Check for patterns, especially in cases with autocasting enabled, in
      // which a Cast node is inserted before the NegativeLogLikelihoodLoss
      // node; this would otherwise keep the patterns below from being
      // recognized.
      // For example, if the input is 2D:
      // graph(%input : Half(3, 5),
      //       %target : Long(3)):
      //   %4 : Half(3, 5) = onnx::LogSoftmax[axis=1](%input)
      //   %8 : Float = onnx::Cast[to=1](%4)
      //   %9 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%8, %target)
      //   return (%9)
      Node* castNode = nullptr;
      if (prev->kind() == onnx::Cast) {
        castNode = prev;
        prev = prev->input(0)->node();
      }

      if (prev->kind() == onnx::LogSoftmax) {
        // if the input is 2D
        // graph(%input : Float(3, 5),
        //       %target : Long(3)):
        //   %4 : Float(3, 5) = onnx::LogSoftmax[axis=1](%input)
        //   %8 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%4, %target)
        //   return (%8)
        origLogSoftmaxNode = prev;
      } else if (
          prev->kind() == onnx::Transpose &&
          prev->input(0)->node()->kind() == onnx::LogSoftmax) {
        // if the input is 4D
        // graph(%input : Float(3, 5, 2, 7),
        //       %target : Long(3, 2, 7)):
        //   %4 : Tensor = onnx::Transpose[perm=[0, 3, 2, 1]] (%input)
        //   %5 : Tensor = onnx::LogSoftmax[axis=3] (%4)
        //   %6 : Float(3, 5, 2, 7) = onnx::Transpose[perm=[0, 3, 2, 1]] (%5)
        //   %10 : Float(3, 2, 7) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%6, %target)
        //   return (%10)
        origLogSoftmaxNode = prev->input(0)->node();
        auto transpose = origLogSoftmaxNode->input(0)->node();
        if (!transpose->inputs().empty()) {
          origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0));
        }
      } else if (
          prev->kind() == onnx::Reshape &&
          prev->input(0)->node()->kind() == onnx::Transpose &&
          prev->input(0)->node()->input(0)->node()->kind() ==
              onnx::LogSoftmax) {
        // if the input is 3D or > 4D
        // graph(%input : Float(3, 5, 2),
        //       %target.1 : Long(3, 2)):
        //   %4 : Tensor = onnx::Transpose[perm=[0, 2, 1]] (%input)
        //   %5 : Tensor = onnx::LogSoftmax[axis=2] (%4)
        //   %6 : Float(3, 5, 2) = onnx::Transpose[perm=[0, 2, 1]] (%5)
        //   %8 : Tensor = onnx::Shape(%6)
        //   %10 : Tensor = onnx::Constant[value={0}]()
        //   %11 : Long() = onnx::Gather[axis=0] (%8, %10)
        //   %13 : Tensor = onnx::Shape(%6)
        //   %15 : Tensor = onnx::Constant[value={1}]()
        //   %16 : Long() = onnx::Gather[axis=0] (%13, %15)
        //   ...
        //   %22 : Float(3, 5, 1, 2) = onnx::Reshape(%6, %21)
        //   ...
        //   %26 : Long(3, 1, 2) = onnx::Reshape(%target.1, %25)
        //   %30 : Float() = onnx::NegativeLogLikelihoodLoss[reduction="sum"](%22, %26)
        //   return (%30)
        origLogSoftmaxNode = prev->input(0)->node()->input(0)->node();
        auto transpose = origLogSoftmaxNode->input(0)->node();
        TORCH_INTERNAL_ASSERT(transpose->kind() == onnx::Transpose);
        origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0));
        auto reshape = origNllLossNode->input(1)->node();
        TORCH_INTERNAL_ASSERT(reshape->kind() == onnx::Reshape);
        origNllLossNode->replaceInput(1, reshape->inputs().at(0));
        if (origNllLossNode->s(attr::reduction) == "none") {
          // when reduction=none a different graph is created and the graph
          // doesn't end with node NegativeLogLikelihoodLoss like in all other
          // cases.
          // graph(%input : Float(3, 5, 2), %target.1 : Long(3, 2)):
          //   %4 : Tensor = onnx::Transpose[perm=[0, 2, 1]](%input)
          //   %5 : Tensor = onnx::LogSoftmax[axis=2](%4)
          //   %6 : Float(3, 5, 2) = onnx::Transpose[perm=[0, 2, 1]](%5)
          //   ...
          //   %27 : Float(3, 5, 1, 2) = onnx::Reshape(%6, %26)
          //   %31 : Long(3, 1, 2) = onnx::Reshape(%target.1, %30)
          //   %35 : Float(3, 1, 2) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%27, %31)
          //   %36 : int[] = prim::ListConstruct(%11, %21)
          //   %37 : Float(3, 2) = onnx::Reshape(%35, %36)
          //   return (%37)
          auto nllloss_output = origNllLossNode->output(0)->uses()[0].user;
          TORCH_INTERNAL_ASSERT(nllloss_output->kind() == onnx::Reshape);
          // make output of reshape the output of nllloss
          nllloss_output->replaceAllUsesWith(origNllLossNode);
          origNllLossNode->output(0)->copyMetadata(nllloss_output->output(0));
        }
      } else {
        continue;
      }

      // If the pattern indeed consists of a cast node before the
      // NegativeLogLikelihoodLoss node, place a cast node in the beginning
      // of the pattern instead
      if (castNode != nullptr) {
        auto onnx_type = castNode->i(attr::to);
        Node* cast_node = b->owningGraph()->create(onnx::Cast, 1);
        cast_node->addInput(origLogSoftmaxNode->inputs().at(0));
        cast_node->i_(attr::to, onnx_type);
        cast_node->insertBefore(origLogSoftmaxNode);
        cast_node->copyMetadata(castNode);
        origLogSoftmaxNode->replaceInputWith(
            origLogSoftmaxNode->inputs().at(0), cast_node->output());
      }

      Node* softmaxCrossEntropyNode = b->owningGraph()->create(
          onnx::SoftmaxCrossEntropyLoss, it->outputs().size());
      for (size_t i = 0; i < softmaxCrossEntropyNode->outputs().size(); ++i) {
        softmaxCrossEntropyNode->outputs()[i]->copyMetadata(it->outputs()[i]);
      }
      softmaxCrossEntropyNode->copyMetadata(origNllLossNode);
      softmaxCrossEntropyNode->copyAttributes(*origNllLossNode);
      softmaxCrossEntropyNode->insertBefore(origNllLossNode);
      softmaxCrossEntropyNode->addInput(origLogSoftmaxNode->inputs().at(0));
      softmaxCrossEntropyNode->addInput(origNllLossNode->inputs().at(1));
      softmaxCrossEntropyNode->copyMetadata(origNllLossNode);
      // optional weight input is provided
      if (origNllLossNode->inputs().size() == 3) {
        softmaxCrossEntropyNode->addInput(origNllLossNode->inputs().at(2));
      }

      it->replaceAllUsesWith(softmaxCrossEntropyNode);
      it->removeAllInputs();
      it.destroyCurrent();
    }
  }
}

// This optimization removes consecutive SplitToSequence and ConcatFromSequence
// operators. The optimization only happens when
// 1. The output of SplitToSequence is not used by any other nodes.
// 2. The axis attributes of SplitToSequence and ConcatFromSequence are equal,
//    and the keepdims attribute of the split is the complement of the
//    new_axis attribute of the concat.
// In that case, the two ops combined are a no-op and can be safely removed.
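//
// Roughly (illustrative IR):
//   %seq = onnx::SplitToSequence[axis=0, keepdims=1](%x)
//   %y = onnx::ConcatFromSequence[axis=0, new_axis=0](%seq)
// is replaced by %x itself, with the dead pair left for dead-code elimination.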
static void removeSequenceSplitConcat(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      removeSequenceSplitConcat(child_block);
    }
    if (it->kind() == onnx::ConcatFromSequence &&
        it->input()->node()->kind() == onnx::SplitToSequence) {
      if (it->input()->uses().size() > 1) {
        continue;
      }

      auto split_node = it->input()->node();
      auto concat_node = *it;

      const auto split_axis =
          split_node->hasAttribute(attr::axis) ? split_node->i(attr::axis) : 0;
      const auto split_keepdims = split_node->hasAttribute(attr::keepdims)
          ? split_node->i(attr::keepdims)
          : 1;
      const auto concat_axis = concat_node->i(attr::axis);
      const auto concat_new_axis = concat_node->hasAttribute(attr::new_axis)
          ? concat_node->i(attr::new_axis)
          : 0;
      const bool has_input_split = split_node->inputs().size() == 2;

      if (has_input_split) {
        continue;
      }

      if (split_keepdims == concat_new_axis) {
        continue;
      }

      if (split_axis != concat_axis) {
        continue;
      }

      concat_node->output()->replaceAllUsesWith(split_node->input());
    }
  }
}

// Work around an ONNX limitation that a block input cannot be used directly
// as a block output. Inserts an Identity node inside the block and has the
// block return the Identity's output instead.
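//
// For example (illustrative): a Loop body block that returns its own input
//   block0(%i, %cond, %x):
//     -> (%cond, %x)
// becomes
//   block0(%i, %cond, %x):
//     %x.id = onnx::Identity(%x)
//     -> (%cond, %x.id)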
static void insertIdentityForInputUsedAsOutput(Block* b) {
  for (auto out : b->outputs()) {
    auto n = out->node();
    if (nullptr != n && n->kind() == prim::Param) {
      Node* id_node = b->owningGraph()->create(onnx::Identity);
      id_node->insertBefore(b->return_node());
      id_node->addInput(out);
      id_node->output()->setType(out->type());
      b->return_node()->replaceInputWith(out, id_node->output());
    }
  }

  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      insertIdentityForInputUsedAsOutput(child_block);
    }
  }
}

// This pass performs ONNX-specific peephole optimizations.
//
// Before you write an optimization here, ask yourself, "Could I do this
// optimization on ATen operators?" If so, you should seriously consider
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well. The
// optimizations here are ONLY applied on ONNX export.
void PeepholeOptimizeONNX(
    std::shared_ptr<Graph>& graph,
    int opset_version,
    bool fixed_batch_size) {
  // TODO: decide on fixpoint strategy
  // TODO: make it easier not to do O(k) iterations over the graph, where
  // k is the number of distinct peephole optimizations
  hackFixupPadPackedShapes(graph->block());
  pushPackingPastRnn(graph->block());
  removeNopPacking(graph->block());
  // we only need to fix the size of hidden state and cell state if the batch
  // size is variable
  if (!fixed_batch_size) {
    fixDefaultRnnHiddenState(graph->block(), opset_version);
    fixDefaultLstmCellState(graph->block(), opset_version);
  }
  fuseBroadcast(graph->block());
  fuseConsecutiveTransposes(graph->block());
  eliminateNopTranspose(graph->block());
  fuseTransposeIntoGemm(graph->block());
  speculateOps(graph->block());
  fuseListConstructListUnpack(graph->block());
  fuseLogSoftmaxNllLoss(graph->block());
  eraseListConstruct(graph->block(), opset_version);
  eraseTupleConstruct(graph->block());
  EliminateDeadCode(
      graph->block(),
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
  eraseListUnpack(graph->block(), opset_version);
  removeMaxPoolUnusedOutput(graph->block());
  removeSequenceSplitConcat(graph->block());
  insertIdentityForInputUsedAsOutput(graph->block());

  GRAPH_DUMP("After PeepholeOptimizeONNX", graph);
}

} // namespace torch::jit