#include <torch/csrc/jit/passes/onnx/peephole.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

#include <ATen/ScalarOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/full.h>
#include <ATen/ops/ones_like_native.h>
#endif

#include <optional>

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

bool isRNN(const Node* node) {
  auto k = node->kind();
  return k == onnx::RNN || k == onnx::LSTM || k == onnx::GRU;
}

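// A transpose is a no-op when its permutation maps every index to itself,
// e.g. perm = {0, 1, 2}.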
bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
    if (perm[i] != i) {
      return false;
    }
  }
  return true;
}

// Returns a vector `ret` such that transposing by `ret` is equivalent
// to transposing by `t1` and then by `t2`.
//
// This fires in the case that we have transpose ops T1 -> T2. We are
// fusing the transpose op T1 into T2 and discarding T1. We assume the elements
// of the permutation in `t1` are raw indices into its input, since a previous
// iteration would have folded all the transposes up to that point. Thus,
// `ret[i] = t1[t2[i]]` says "the output of t2 at position i takes the value of
// the input tensor index contained in t1 at position `t2[i]`".
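// For example, composing t1 = {2, 0, 1} with t2 = {1, 2, 0} gives
// ret = {t1[1], t1[2], t1[0]} = {0, 1, 2}, i.e. the two transposes cancel.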
std::vector<int64_t> composeTransposes(
    const std::vector<int64_t>& t1,
    const std::vector<int64_t>& t2) {
  TORCH_INTERNAL_ASSERT(t1.size() == t2.size());
  std::vector<int64_t> ret;
  ret.reserve(t1.size());
  for (const auto& i : t2) {
    TORCH_INTERNAL_ASSERT(i < int64_t(t1.size()));
    ret.push_back(t1[i]);
  }
  return ret;
}

std::vector<size_t> getBroadcastPositions(Node* node) {
  // Most of the element-wise ops in ONNX support numpy broadcasting.
  // Only Gemm supports one-directional broadcasting, which broadcasts the bias
  // to the product.
  static std::unordered_map<NodeKind, std::vector<size_t>> broadcast_positions =
      {
          {onnx::Add, {0, 1}},
          {onnx::Div, {0, 1}},
          {onnx::Mul, {0, 1}},
          {onnx::Pow, {0, 1}},
          {onnx::Sub, {0, 1}},
          {onnx::Gemm, {2}},
          {onnx::Equal, {0, 1}},
          {onnx::Greater, {0, 1}},
          {onnx::Less, {0, 1}},
      };
  static std::vector<size_t> no_positions;
  std::vector<size_t> positions;

  auto iter = broadcast_positions.find(node->kind());
  if (iter != broadcast_positions.end()) {
    // Skip optional inputs that are not provided.
    for (size_t position : iter->second) {
      if (position < node->inputs().size()) {
        positions.emplace_back(position);
      }
    }
    return positions;
  }
  return no_positions;
}

// Determine whether `from` can broadcast to `to`, and if so at which
// position. `from` must be a suffix of `to`, except that any
// occurrences of 1 in `from` are treated as wildcards.
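// For example, fusibleExpandTo({3, 1, 5}, {2, 3, 4, 5}) returns 1: the suffix
// {3, 1, 5} lines up with {3, 4, 5} of `to` (the 1 acting as a wildcard),
// starting at position 1.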
std::optional<size_t> fusibleExpandTo(
    at::IntArrayRef from,
    at::IntArrayRef to) {
  if (from.size() > to.size()) {
    return std::nullopt;
  }

  for (const auto i : c10::irange(from.size())) {
    auto fdim = from[from.size() - 1 - i];
    auto tdim = to[to.size() - 1 - i];
    if (fdim != 1 && fdim != tdim) {
      return std::nullopt;
    }
  }

  return to.size() - from.size();
}

// Fuses expand calls into ONNX operators: non-strided backends can perform
// broadcasts more efficiently when this information is local to the consuming
// op. This optimization is not useful for PyTorch itself, where 'expand' is
// free.
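// Schematically, Add(%x, aten::expand(%y, %sizes, %implicit)) becomes
// Add(%x, %y) when the expanded shape is a fusible broadcast of %y's shape.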
void fuseBroadcast(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseBroadcast(child_block);
    }

    auto broadcast_positions = getBroadcastPositions(n);
    if (!broadcast_positions.empty()) {
      TORCH_INTERNAL_ASSERT(!n->hasAttribute(attr::axis));
    }

    for (size_t position : broadcast_positions) {
      auto* expand_node = n->input(position)->node();

      // Confirm it is an expand node.
      if (expand_node->kind() != aten::expand ||
          expand_node->input(1)->node()->kind() != onnx::Constant ||
          expand_node->input(2)->node()->kind() != onnx::Constant) {
        continue;
      }

      auto* unexpanded_input = expand_node->input(0);

      // We need to know what the pre-expand type is.  We should basically
      // always have this information (because expands are only ever traced,
      // not generated from symbolic), but if for some reason we don't
      // have it, we need to skip.
      if (!unexpanded_input->isCompleteTensor() ||
          !n->output()->isCompleteTensor()) {
        continue;
      }

      // Not all broadcasts are supported by ONNX broadcast.
      std::optional<size_t> axis = fusibleExpandTo(
          unexpanded_input->type()
              ->expectRef<TensorType>()
              .sizes()
              .concrete_sizes()
              .value(), // from
          n->output()
              ->type()
              ->expectRef<TensorType>()
              .sizes()
              .concrete_sizes()
              .value()); // to
      if (axis == std::nullopt) {
        continue;
      }

      n->replaceInput(position, unexpanded_input);
      if (!expand_node->hasUses()) {
        expand_node->destroy();
      }
    }
  }
}

void fuseConsecutiveTransposes(Block* b) {
  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseConsecutiveTransposes(child_block);
    }
    if (n->kind() == onnx::Transpose &&
        n->input()->node()->kind() == onnx::Transpose &&
        n->owningBlock() == n->input()->node()->owningBlock()) {
      auto origInput = n->input();
      n->is_(
          attr::perm,
          composeTransposes(
              origInput->node()->is(attr::perm), n->is(attr::perm)));
      n->replaceInput(0, origInput->node()->input());
      if (origInput->uses().empty()) {
        origInput->node()->destroy();
      }
      continue;
    }
  }
}

void eliminateNopTranspose(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      eliminateNopTranspose(child_block);
    }
    if (n->kind() == onnx::Transpose) {
      if (isNopTranspose(n->is(attr::perm))) {
        n->output()->replaceAllUsesWith(n->input());
        it.destroyCurrent();
        continue;
      }
    }
  }
}

void fuseTransposeIntoGemm(Block* b) {
  static const std::vector<int64_t> simpleTransPerm({1, 0});

  for (auto n : b->nodes()) {
    for (auto* child_block : n->blocks()) {
      fuseTransposeIntoGemm(child_block);
    }
    if (n->kind() == onnx::Gemm) {
      for (size_t i : {0, 1}) {
        auto inp = n->inputs()[i];
        auto trans = i == 0 ? attr::transA : attr::transB;
        if (inp->node()->kind() == onnx::Transpose &&
            inp->node()->is(attr::perm) == simpleTransPerm) {
          n->replaceInput(i, inp->node()->input());
          n->i_(trans, n->hasAttribute(trans) ? !n->i(trans) : 1);
          if (inp->uses().empty()) {
            inp->node()->destroy();
          }
        }
      }
    }
  }
}

// Why this is here:
//
//   PyTorch has a "packed" representation of sequences, as well as a
//   "padded" representation. ONNX has only one representation,
//   corresponding to PyTorch's "padded". Therefore, we need to remove
//   any use of packed sequences before exporting.
//
// What this does:
//
//   This code uses the observation that
//     RNN(PackPadded(x)) == PackPadded(RNN(x))
//   and converts the first form to the second whenever possible,
//   "pushing" the packing operation past the RNN operation. Then,
//   the removeNopPacking pass removes the packing operations
//   entirely by pairing them with their inverse PadPacked. If the
//   input graph does not pair the operations, export will fail.
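// Schematically, in the unidirectional case:
//   %packed, %batch_sizes = prim::PackPadded(%padded, %lengths)
//   %y = onnx::RNN(%packed, ...)
//   %out = onnx::Squeeze(%y)
// is rewritten to
//   %y = onnx::RNN(%padded, ...)
//   %out = onnx::Squeeze(%y)
//   %packed_out, %batch_sizes = prim::PackPadded(%out, %lengths)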
void pushPackingPastRnn(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      pushPackingPastRnn(child_block);
    }

    if (n->kind() != prim::PackPadded) {
      continue;
    }
    if (n->outputs().at(0)->uses().size() != 1) {
      // For now, only handle the case where there is one consumer.
      continue;
    }
    Node* rnn = n->outputs()[0]->uses()[0].user;
    if (!isRNN(rnn)) {
      continue;
    }

    if (rnn->owningBlock() != n->owningBlock()) {
      continue;
    }

    // Packing only has an effect on a network when its outputs are actually
    // used, so we can remove it here.
    if (rnn->outputs().at(0)->uses().empty() &&
        n->outputs().at(1)->uses().size() == 1) {
      n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));
      n->outputs().at(1)->replaceFirstUseWith(n->inputs().at(1));
      it.destroyCurrent();
      continue;
    }

    // The rnn is followed by a transpose and a reshape (if
    // bidirectional), or by a squeeze (if unidirectional).
    Node* next = rnn->outputs().at(0)->uses().at(0).user;
    if (next->kind() == onnx::Transpose) {
      next = next->outputs().at(0)->uses().at(0).user;
      if (next->kind() != onnx::Reshape) {
        continue;
      }
    } else if (next->kind() != onnx::Squeeze) {
      continue;
    }

    // Remove PackPadded from in front of the RNN.
    n->outputs().at(0)->replaceAllUsesWith(n->inputs().at(0));

    Value* batch_sizes = n->outputs().at(1);
    while (!batch_sizes->uses().empty()) {
      Use use_0 = batch_sizes->uses().at(0);
      Node* user = use_0.user;
      // Make calculation of max_batch_size not depend on batch_sizes.
      // This looks for a pattern generated by code such as
      // https://github.com/pytorch/pytorch/blob/febff45/torch/nn/modules/rnn.py#L815-L815.
      //
      // Replace onnx::Gather[axis=0](batch_sizes, 0)
      // with    onnx::Gather[axis=0](onnx::Shape(rnn_input), 1)
      if (use_0.offset == 0 && user->kind() == onnx::Gather &&
          user->i(attr::axis) == 0 &&
          user->inputs().at(1)->node()->kind() == onnx::Constant &&
          user->inputs().at(1)->node()->hasAttribute(attr::value)) {
        const at::Tensor& const_val_t =
            user->inputs().at(1)->node()->t(attr::value);
        if (const_val_t.item().toInt() != 0) {
          // We'll likely produce an invalid graph if this happens.
          break;
        }
        Value* rnn_input = rnn->inputs().at(0);
        Node* shape = b->owningGraph()->create(onnx::Shape);
        shape->insertAfter(rnn_input->node());
        shape->addInput(rnn_input);
        shape->copyMetadata(n);
        batch_sizes->replaceFirstUseWith(shape->output());
        // A new Constant node is needed, as the existing Constant 0 might be
        // shared with other users.
        Node* gather_indices = b->owningGraph()->create(onnx::Constant, 1);
        gather_indices->t_(attr::value, at::native::ones_like(const_val_t));
        gather_indices->copyMetadata(n);
        gather_indices->insertBefore(user);
        user->replaceInput(1, gather_indices->output());
      }
      // Make the RNN not depend on batch_sizes.
      else if (user == rnn) {
        batch_sizes->replaceFirstUseWith(n->inputs().at(1));
      } else {
        // If there are other uses that are not:
        // * PadPacked (which will be removed in removeNopPacking),
        // * Dead code (which will be removed in dead code elimination),
        // then we likely have produced an invalid graph, since there will be a
        // use of the output of PackPadded, but the PackPadded (and that output)
        // will be removed.
        break;
      }
    }

    // ... and insert a new PackPadded after the RNN.
    Node* newPackPadded = b->owningGraph()->create(prim::PackPadded, 2);
    newPackPadded->copyMetadata(n);
    newPackPadded->insertAfter(next);
    newPackPadded->copyMetadata(next);

    // Make things consume from the new PackPadded.
    next->outputs().at(0)->replaceAllUsesWith(newPackPadded->outputs().at(0));
    n->outputs().at(1)->replaceAllUsesWith(newPackPadded->outputs().at(1));

    // Set up the new PackPadded's inputs.
    newPackPadded->addInput(next->outputs().at(0));
    newPackPadded->addInput(n->inputs().at(1));

    // See https://github.com/pytorch/pytorch/issues/9043 for a full
    // description.  Since PackPadded is for now treated in an
    // unhygienic way, PyTorch ends up propagating an incorrect type.
    // Until a long-term cleanup comes around, we can fix this by
    // resetting the size to the correct value.
    TensorTypePtr oldType = rnn->inputs().at(0)->type()->cast<TensorType>();
    if (oldType && oldType->isComplete()) {
      std::vector<int64_t> new_sizes;
      new_sizes.push_back(*oldType->sizes()[0]);
      new_sizes.push_back(*oldType->sizes()[1]);
      if (next->kind() == onnx::Reshape) {
        // bidirectional
        new_sizes.push_back(rnn->i(attr::hidden_size) * 2);
      } else {
        // unidirectional
        new_sizes.push_back(rnn->i(attr::hidden_size));
      }
      TensorTypePtr newType = TensorType::createContiguous(
          *oldType->scalarType(), *oldType->device(), new_sizes);
      next->outputs().at(0)->setType(newType);
    }

    it.destroyCurrent();
  }
}

// Despite the name, this actually removes the PadPacked node and leaves
// the PackPadded node. The PackPadded should become dead code, which will
// be eliminated later.
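// i.e. for
//   %data, %sizes = prim::PackPadded(%x, %lens)
//   %y, %lens2 = prim::PadPacked(%data, %sizes)
// uses of %y are redirected to %x and uses of %lens2 to %lens, leaving the
// PadPacked (and usually the PackPadded) dead.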
void removeNopPacking(Block* graph) {
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    if (input->kind() != prim::PackPadded) {
      continue;
    }
    if (input->outputs()[0] != n->inputs()[0]) {
      continue;
    }
    if (input->outputs()[1] != n->inputs()[1]) {
      continue;
    }
    n->outputs()[0]->replaceAllUsesWith(input->inputs()[0]);
    n->outputs()[1]->replaceAllUsesWith(input->inputs()[1]);

    n->removeAllInputs();
    it.destroyCurrent();
  }
}

void hackFixupPadPackedShapes(Block* graph) {
  // FIXME: the input to the fictional PadPacked node has an incorrect shape.
  // For now, just copy the shape of PadPacked to the shape of its input.
  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      removeNopPacking(child_block);
    }

    if (n->kind() != prim::PadPacked) {
      continue;
    }
    Node* input = n->inputs()[0]->node();
    input->outputs()[0]->setType(n->outputs()[0]->type());
  }
}

void fixDefaultRNNState(
    Graph* graph,
    Node* n,
    int input_index,
    int opset_version) {
  auto initial_state = n->inputs()[input_index];

  // The RNN code in PyTorch accepts an optional hidden state.
  // 1- When it is provided as an input, everything works great.
  // 2- When it is not provided, it is default-initialized by constructing a
  //    new Variable, which gets traced as a ConstantOfShape with the expected
  //    shape.
  // 3- When the batch size is fixed, everything works great as well.
  // 4- When h0 and c0 are specified but are not inputs of the model (they are
  //    Constants) and the batch size is variable, the model should be saved
  //    with a batch size of 1 (or an error will occur), and we save the value
  //    of h0 and c0 with a batch size of 1. When the model is then called with
  //    a different batch size value, h0 and c0 are broadcast to get the right
  //    shape.
  // Recognize that last pattern here (4) and fix the shape.
  // Note that for multi-layer RNNs there will be a Slice operation between the
  // Constant and the RNN.
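  // The fixed initial state built below is, schematically,
  //   Expand(initial_state,
  //          Concat(Unsqueeze(num_directions, 0),
  //                 Unsqueeze(Gather(Shape(input), 1), 0),
  //                 hidden_size))
  // so that it broadcasts to whatever batch size is seen at runtime.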
  bool needsFixing = initial_state->node()->kind() == onnx::Constant ||
      (initial_state->node()->kind() == onnx::Slice &&
       initial_state->node()->inputs()[0]->node()->kind() == onnx::Constant);

  if (!needsFixing) {
    return;
  }

  Node* shape_of_input = graph->create(onnx::Shape, 1);
  shape_of_input->copyMetadata(n);
  shape_of_input->insertBefore(n);
  shape_of_input->addInput(n->inputs()[0]);

  Node* gather_indices = graph->create(onnx::Constant, 1);
  gather_indices->copyMetadata(n);
  gather_indices->insertBefore(n);
  gather_indices->t_(attr::value, at::scalar_to_tensor(at::Scalar(1)));

  Node* batch_size = graph->create(onnx::Gather, 1);
  batch_size->copyMetadata(n);
  batch_size->insertBefore(n);
  batch_size->addInput(shape_of_input->outputs()[0]);
  batch_size->addInput(gather_indices->outputs()[0]);

  Node* unsqueezed_batch_size =
      createONNXUnsqueeze(graph, n, batch_size->outputs()[0], 0, opset_version);

  Node* hidden_size = graph->create(onnx::Constant, 1);
  hidden_size->copyMetadata(n);
  hidden_size->insertBefore(n);
  hidden_size->t_(
      attr::value,
      at::full(
          {1},
          n->i(attr::hidden_size),
          at::kLong)); // at::Scalar(n->i(attr::hidden_size)).toTensor());

  Node* num_directions = graph->create(onnx::Constant, 1);
  num_directions->copyMetadata(n);
  num_directions->insertBefore(n);
  num_directions->t_(
      attr::value,
      scalar_to_tensor(at::Scalar(
          n->hasAttribute(attr::direction) &&
                  n->s(attr::direction) == "bidirectional"
              ? 2
              : 1)));

  Node* unsqueezed_num_directions = createONNXUnsqueeze(
      graph, n, num_directions->outputs()[0], 0, opset_version);

  Node* concated_dims = graph->create(onnx::Concat, 1);
  concated_dims->copyMetadata(n);
  concated_dims->insertBefore(n);
  concated_dims->i_(attr::axis, 0);
  concated_dims->addInput(unsqueezed_num_directions->outputs()[0]);
  concated_dims->addInput(unsqueezed_batch_size->outputs()[0]);
  concated_dims->addInput(hidden_size->outputs()[0]);

  Node* fixed_init_state = graph->create(onnx::Expand, 1);
  fixed_init_state->copyMetadata(n);
  fixed_init_state->insertBefore(n);
  fixed_init_state->addInput(initial_state);
  fixed_init_state->addInput(concated_dims->outputs()[0]);
  n->replaceInput(input_index, fixed_init_state->outputs()[0]);

  if (initial_state->uses().empty()) {
    initial_state->node()->destroy();
  }
}

void fixDefaultRnnHiddenState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultRnnHiddenState(child_block, opset_version);
    }

    if (!isRNN(n)) {
      continue;
    }
    // Hidden state is the sixth input for RNN, LSTM, GRU.
    // See https://pytorch.org/docs/main/nn.html#torch.nn.RNN
    if (n->inputs().size() < 6) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 5, opset_version);
  }
}

void fixDefaultLstmCellState(Block* b, int opset_version) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
    auto* n = *it;
    for (auto* child_block : n->blocks()) {
      fixDefaultLstmCellState(child_block, opset_version);
    }

    if (n->kind() != onnx::LSTM) {
      continue;
    }
    // Cell state is the seventh input for LSTM.
    // See https://pytorch.org/docs/main/nn.html#torch.nn.LSTM
    if (n->inputs().size() < 7) {
      continue;
    }
    fixDefaultRNNState(b->owningGraph(), n, 6, opset_version);
  }
}

static bool isSafeToSpeculate(Node* n) {
  return n->kind() == onnx::Transpose;
}

// Moves ops outside of control flow blocks so that they are always executed,
// no matter the result of the control flow conditions.
// Needed only so that the split pass of the ONNX optimizer will put the ops
// into the init_net.
// TODO: Once the code in caffe2/python/onnx/backend.py no longer calls
// optimize_onnx, delete this function.
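// For example, a Transpose nested inside an If block whose input is defined
// outside that block is hoisted out of the If, so it is executed
// unconditionally.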
static void speculateOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it; // note: increment first so that it is safe to move the node if needed

    for (auto b : n->blocks()) {
      speculateOps(b);
    }
    if (!isSafeToSpeculate(n)) {
      continue;
    }
    // XXX - only works for nodes with a single input
    // move node n outside of the control flow it is nested in
    auto node_input = n->input()->node();
    if (node_input->owningBlock() == n->owningBlock()) {
      continue;
    }
    // Skip if output of this node is part of block output.
    bool is_block_output = false;
    for (auto node_output : n->outputs()) {
      for (auto node_output_use : node_output->uses()) {
        if (node_output_use.user == n->owningBlock()->return_node()) {
          is_block_output = true;
          break;
        }
      }
      if (is_block_output) {
        break;
      }
    }
    if (is_block_output) {
      continue;
    }
    // find the control flow node in the same block as node_input that contains
    // Node n
    auto control_flow_node = n->owningBlock()->owningNode();
    while (control_flow_node->owningBlock() != node_input->owningBlock()) {
      control_flow_node = control_flow_node->owningBlock()->owningNode();
    }
    // put the node right before this flow node
    n->moveBefore(control_flow_node);
  }
}

static void replaceInputWithList(Node* node, size_t i, ArrayRef<Value*> to) {
  node->removeInput(i);
  for (auto* to_val : to) {
    TORCH_INTERNAL_ASSERT(to_val->owningGraph() == node->owningGraph());
    node->insertInput(i++, to_val);
  }
}

static void eraseListConstruct(Block* block, int opset_version);

static void eraseListConstruct(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListConstruct(b, opset_version);
  }
  std::vector<std::tuple<size_t, std::vector<Value*>>> replacements;

  auto block = n->owningBlock();
  size_t i = 0;
  for (auto* input : n->inputs()) {
    if (input->node()->kind() == prim::ListConstruct) {
      auto* lc_node = input->node();
      TypePtr elem =
          lc_node->output()->type()->castRaw<ListType>()->getElementType();
      if (elem->cast<IntType>() &&
          isValidToTransformToONNXConcatNode(lc_node)) {
        auto concat_node = transformToONNXConcatNode(
            block->owningGraph(), input->node(), false, opset_version);
        concat_node->copyMetadata(n);
        // Make the Concat node's output the new input; the ListConstruct
        // should then become dead.
        replacements.emplace_back(
            i, std::vector<Value*>({concat_node->output()}));
      } else {
        if (opset_version >= OPSET_VERSION_11) {
          c10::Symbol seq_node_kind = !lc_node->inputs().empty()
              ? onnx::SequenceConstruct
              : onnx::SequenceEmpty;
          Node* seq_node = block->owningGraph()->create(
              seq_node_kind, {lc_node->inputs()}, 1);
          seq_node->copyMetadata(n);
          seq_node->insertBefore(lc_node);
          seq_node->output()->copyMetadata(lc_node->output());
          seq_node->copyMetadata(lc_node);
          lc_node->replaceAllUsesWith(seq_node);
        }
      }
    }
    i++;
  }

  for (auto ritr = replacements.rbegin(); ritr != replacements.rend(); ++ritr) {
    replaceInputWithList(n, std::get<0>(*ritr), std::get<1>(*ritr));
  }
}

static void eraseListConstruct(Block* block, int opset_version) {
  // TODO: Fix this pass/maybe get rid of this part.
  // Tensor lists might be used for meshgrid and such ops as well.
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    eraseListConstruct(n, opset_version);
  }
  eraseListConstruct(block->return_node(), opset_version);
}

static void eraseListUnpack(Block* block, int opset_version);

// Replace prim::ListUnpack with onnx::SequenceAt.
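// e.g. %a, %b = prim::ListUnpack(%seq) becomes
//      %a = onnx::SequenceAt(%seq, Constant(0))
//      %b = onnx::SequenceAt(%seq, Constant(1))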
static void eraseListUnpack(Node* n, int opset_version) {
  for (auto b : n->blocks()) {
    eraseListUnpack(b, opset_version);
  }

  if (n->kind() == prim::ListUnpack) {
    if (opset_version < OPSET_VERSION_11) {
      // onnx::SequenceAt was introduced in onnx opset version 11
      throw std::runtime_error(
          "Unsupported: ONNX export of prim::ListUnpack in opset " +
          std::to_string(opset_version) + ". Please try opset version 11.");
    }

    auto g = n->owningGraph();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto seq_idx_n = g->create(onnx::Constant, 1);
      seq_idx_n->t_(attr::value, at::scalar_to_tensor(at::Scalar(int64_t(i))));
      seq_idx_n->insertBefore(n);

      auto seq_at_n = g->create(onnx::SequenceAt, 1);
      seq_at_n->addInput(n->input());
      seq_at_n->addInput(seq_idx_n->output());
      seq_at_n->output()->setType(n->output(i)->type());
      seq_at_n->insertBefore(n);
      seq_at_n->copyMetadata(n);
      n->output(i)->replaceAllUsesWith(seq_at_n->output());
    }
  }
}

static void eraseListUnpack(Block* block, int opset_version) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* n = *it;
    ++it;

    eraseListUnpack(n, opset_version);
  }
}

// From:
//   %list = ListConstruct(%x);
//   %unpacked = ListUnpack(%list);
//   do_something(%unpacked);
//
// To:
//   %list = ListConstruct(%x);
//   %unpacked = ListUnpack(%list);
//   do_something(%x)
//
// The ListConstruct and ListUnpack may now be dead code.
static void fuseListConstructListUnpack(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseListConstructListUnpack(child_block);
    }
    if (it->kind() == prim::ListUnpack &&
        it->input()->node()->kind() == prim::ListConstruct) {
      for (const auto i : c10::irange(it->outputs().size())) {
        auto output = it->outputs().at(i);
        output->replaceAllUsesWith(it->input()->node()->inputs().at(i));
      }
    }
  }
}

// https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
static void eraseTupleConstruct(Block* block) {
  std::vector<Value*> new_block_outputs;
  bool found_tuple_construct = false;
  // TupleConstruct is generated from the symbolics in the quantized domain and
  // consumed by other quantized operators. Any remaining TupleConstruct should
  // be at the output of the block.
  for (auto* output : block->outputs()) {
    auto output_node = output->node();
    if (output_node->kind() == prim::TupleConstruct) {
      found_tuple_construct = true;
      for (auto* input : output_node->inputs()) {
        new_block_outputs.emplace_back(input);
      }
    } else {
      new_block_outputs.emplace_back(output);
    }
  }
  if (found_tuple_construct) {
    block->removeAllOutputs();
    for (auto* output : new_block_outputs) {
      block->registerOutput(output);
    }
  }
}

void removeMaxPoolUnusedOutput(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    auto n = *it;
    for (auto* child_block : n->blocks()) {
      removeMaxPoolUnusedOutput(child_block);
    }
    if (strcmp(n->kind().toQualString(), "onnx::MaxPool") == 0) {
      if (n->outputs().size() == 2 && n->outputs().at(1)->uses().empty()) {
        it->eraseOutput(1);
      }
    }
  }
}

// This optimization fuses LogSoftmax and NegativeLogLikelihoodLoss operators
// into a single SoftmaxCrossEntropyLoss operator. Depending on the dimensions
// of the input and the attributes involved, different subgraphs of
// LogSoftmax and NegativeLogLikelihoodLoss are matched.
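// In the simplest case,
//   %log_probs = onnx::LogSoftmax[axis=1](%input)
//   %loss = onnx::NegativeLogLikelihoodLoss(%log_probs, %target)
// is replaced with
//   %loss = onnx::SoftmaxCrossEntropyLoss(%input, %target)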
static void fuseLogSoftmaxNllLoss(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseLogSoftmaxNllLoss(child_block);
    }
    if (it->kind() == onnx::NegativeLogLikelihoodLoss) {
      auto prev = it->input(0)->node();
      Node* origNllLossNode = *it;
      Node* origLogSoftmaxNode = nullptr;

      // Check for patterns, especially in cases with autocasting enabled,
      // in which a cast node is inserted before the NegativeLogLikelihoodLoss
      // node and this causes the patterns below not to be recognizable by the
      // fuseLogSoftmaxNllLoss function.
      // For example if the input is 2D
      // graph(%input : Half(3, 5),
      // %target : Long(3)):
      // %4 : Half(3, 5) = onnx::LogSoftmax[axis=1](%input)
      // %8 : Float = onnx::Cast[to=1](%4)
      // %9 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%8, %target)
      // return (%9)
      Node* castNode = nullptr;
      if (prev->kind() == onnx::Cast) {
        castNode = prev;
        prev = prev->input(0)->node();
      }

      if (prev->kind() == onnx::LogSoftmax) {
        // if the input is 2D
        // graph(%input : Float(3, 5),
        // %target : Long(3)):
        // %4 : Float(3, 5) = onnx::LogSoftmax[axis=1](%input)
        // %8 : Float(3) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%4, %target)
        // return (%8)
        origLogSoftmaxNode = prev;
      } else if (
          prev->kind() == onnx::Transpose &&
          prev->input(0)->node()->kind() == onnx::LogSoftmax) {
        // if the input is 4D
        // graph(%input : Float(3, 5, 2, 7),
        // %target : Long(3, 2, 7)):
        // %4 : Tensor = onnx::Transpose[perm=[0, 3, 2, 1]] (%input)
        // %5 : Tensor = onnx::LogSoftmax[axis=3] (%4)
        // %6 : Float(3, 5, 2, 7) = onnx::Transpose[perm=[0, 3, 2, 1]] (%5)
        // %10 : Float(3, 2, 7) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%6, %target)
        // return (%10)
        origLogSoftmaxNode = prev->input(0)->node();
        auto transpose = origLogSoftmaxNode->input(0)->node();
        if (!transpose->inputs().empty()) {
          origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0));
        }
      } else if (
          prev->kind() == onnx::Reshape &&
          prev->input(0)->node()->kind() == onnx::Transpose &&
          prev->input(0)->node()->input(0)->node()->kind() ==
              onnx::LogSoftmax) {
        // if the input is 3D or > 4D
        // graph(%input : Float(3, 5, 2),
        // %target.1 : Long(3, 2)):
        // %4 : Tensor = onnx::Transpose[perm=[0, 2, 1]] (%input)
        // %5 : Tensor = onnx::LogSoftmax[axis=2] (%4)
        // %6 : Float(3, 5, 2) = onnx::Transpose[perm=[0, 2, 1]] (%5)
        // %8 : Tensor = onnx::Shape(%6)
        // %10 : Tensor = onnx::Constant[value={0}]
        // %11 : Long() = onnx::Gather[axis=0] (%8, %10)
        // %13 : Tensor = onnx::Shape(%6)
        // %15 : Tensor = onnx::Constant[value={1}]
        // %16 : Long() = onnx::Gather[axis=0] (%13, %15)
        // ...
        // %22 : Float(3, 5, 1, 2) = onnx::Reshape(%6, %21)
        // ...
        // %26 : Long(3, 1, 2) = onnx::Reshape(%target.1, %25)
        // %30 : Float() = onnx::NegativeLogLikelihoodLoss[reduction="sum"](%22, %26)
        // return (%30)
        origLogSoftmaxNode = prev->input(0)->node()->input(0)->node();
        auto transpose = origLogSoftmaxNode->input(0)->node();
        TORCH_INTERNAL_ASSERT(transpose->kind() == onnx::Transpose);
        origLogSoftmaxNode->replaceInput(0, transpose->inputs().at(0));
        auto reshape = origNllLossNode->input(1)->node();
        TORCH_INTERNAL_ASSERT(reshape->kind() == onnx::Reshape);
        origNllLossNode->replaceInput(1, reshape->inputs().at(0));
        if (origNllLossNode->s(attr::reduction) == "none") {
          // when reduction=none a different graph is created and the graph
          // doesn't end with node NegativeLogLikelihoodLoss like in all other
          // cases.
          // graph(%input : Float(3, 5, 2), %target.1 : Long(3, 2)):
          // %4 : Tensor = onnx::Transpose[perm=[0, 2, 1]] (%input)
          // %5 : Tensor = onnx::LogSoftmax[axis=2] (%4)
          // %6 : Float(3, 5, 2) = onnx::Transpose[perm=[0, 2, 1]] (%5)
          // ...
          // %27 : Float(3, 5, 1, 2) = onnx::Reshape(%6, %26)
          // %31 : Long(3, 1, 2) = onnx::Reshape(%target.1, %30)
          // %35 : Float(3, 1, 2) = onnx::NegativeLogLikelihoodLoss[reduction="none"](%27, %31)
          // %36 : int[] = prim::ListConstruct(%11, %21)
          // %37 : Float(3, 2) = onnx::Reshape(%35, %36)
          // return (%37)
          auto nllloss_output = origNllLossNode->output(0)->uses()[0].user;
          TORCH_INTERNAL_ASSERT(nllloss_output->kind() == onnx::Reshape);
          // make the output of the Reshape the output of the NLLLoss node
          nllloss_output->replaceAllUsesWith(origNllLossNode);
          origNllLossNode->output(0)->copyMetadata(nllloss_output->output(0));
        }
      } else {
        continue;
      }

      // If the pattern indeed consists of a cast node before the
      // NegativeLogLikelihoodLoss node, place a cast node at the beginning
      // of the pattern instead.
      if (castNode != nullptr) {
        auto onnx_type = castNode->i(attr::to);
        Node* cast_node = b->owningGraph()->create(onnx::Cast, 1);
        cast_node->addInput(origLogSoftmaxNode->inputs().at(0));
        cast_node->i_(attr::to, onnx_type);
        cast_node->insertBefore(origLogSoftmaxNode);
        cast_node->copyMetadata(castNode);
        origLogSoftmaxNode->replaceInputWith(
            origLogSoftmaxNode->inputs().at(0), cast_node->output());
      }

      Node* softmaxCrossEntropyNode = b->owningGraph()->create(
          onnx::SoftmaxCrossEntropyLoss, it->outputs().size());
      for (size_t i = 0; i < softmaxCrossEntropyNode->outputs().size(); ++i) {
        softmaxCrossEntropyNode->outputs()[i]->copyMetadata(it->outputs()[i]);
      }
      softmaxCrossEntropyNode->copyMetadata(origNllLossNode);
      softmaxCrossEntropyNode->copyAttributes(*origNllLossNode);
      softmaxCrossEntropyNode->insertBefore(origNllLossNode);
      softmaxCrossEntropyNode->addInput(origLogSoftmaxNode->inputs().at(0));
      softmaxCrossEntropyNode->addInput(origNllLossNode->inputs().at(1));
      softmaxCrossEntropyNode->copyMetadata(origNllLossNode);
      // add the optional weight input if it was provided
      if (origNllLossNode->inputs().size() == 3) {
        softmaxCrossEntropyNode->addInput(origNllLossNode->inputs().at(2));
      }

      it->replaceAllUsesWith(softmaxCrossEntropyNode);
      it->removeAllInputs();
      it.destroyCurrent();
    }
  }
}

// This optimization removes consecutive SplitToSequence and ConcatFromSequence
// operators. The optimization only happens when
//  1. Output of SplitToSequence is not used by any other nodes.
//  2. SplitToSequence has no explicit `split` input, its axis matches the axis
//     of ConcatFromSequence, and its keepdims attribute is the opposite of
//     ConcatFromSequence's new_axis attribute.
// In that case, the two ops combined are a no-op and can be safely removed.
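// e.g. onnx::SplitToSequence[axis=0, keepdims=1](%x) followed by
// onnx::ConcatFromSequence[axis=0, new_axis=0] simply reproduces %x.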
static void removeSequenceSplitConcat(Block* b) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      removeSequenceSplitConcat(child_block);
    }
    if (it->kind() == onnx::ConcatFromSequence &&
        it->input()->node()->kind() == onnx::SplitToSequence) {
      if (it->input()->uses().size() > 1) {
        continue;
      }

      auto split_node = it->input()->node();
      auto concat_node = *it;

      const auto split_axis =
          split_node->hasAttribute(attr::axis) ? split_node->i(attr::axis) : 0;
      const auto split_keepdims = split_node->hasAttribute(attr::keepdims)
          ? split_node->i(attr::keepdims)
          : 1;
      const auto concat_axis = concat_node->i(attr::axis);
      const auto concat_new_axis = concat_node->hasAttribute(attr::new_axis)
          ? concat_node->i(attr::new_axis)
          : 0;
      const bool has_input_split = split_node->inputs().size() == 2;

      if (has_input_split) {
        continue;
      }

      if (split_keepdims == concat_new_axis) {
        continue;
      }

      if (split_axis != concat_axis) {
        continue;
      }

      concat_node->output()->replaceAllUsesWith(split_node->input());
    }
  }
}

// Work around the ONNX limitation that a block input cannot be used directly
// as a block output. Inserts an Identity node inside the block, and has the
// block return the output of the Identity.
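// e.g. a block whose output is one of its own inputs (a prim::Param value) is
// rewritten to return onnx::Identity of that input instead.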
static void insertIdentityForInputUsedAsOutput(Block* b) {
  for (auto out : b->outputs()) {
    auto n = out->node();
    if (nullptr != n && n->kind() == prim::Param) {
      Node* id_node = b->owningGraph()->create(onnx::Identity);
      id_node->insertBefore(b->return_node());
      id_node->addInput(out);
      id_node->output()->setType(out->type());
      b->return_node()->replaceInputWith(out, id_node->output());
    }
  }

  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      insertIdentityForInputUsedAsOutput(child_block);
    }
  }
}

// This optimization does ONNX-specific peephole optimizations.
//
// Before you write an optimization here, ask yourself, "Could I do this
// optimization on ATen operators?"  If so, you should seriously consider
// writing your optimization in jit/passes/peephole.cpp rather than
// here, as it will be generally applicable to the JIT as well.  The
// optimizations here are ONLY applied on ONNX export.
void PeepholeOptimizeONNX(
    std::shared_ptr<Graph>& graph,
    int opset_version,
    bool fixed_batch_size) {
  // TODO: decide on fixpoint strategy
  // TODO: make it easier not to do O(k) iterations over the graph, where
  // k is the number of distinct peephole optimizations
  hackFixupPadPackedShapes(graph->block());
  pushPackingPastRnn(graph->block());
  removeNopPacking(graph->block());
  // We only need to fix the size of the hidden state and cell state if the
  // batch size is variable.
  if (!fixed_batch_size) {
    fixDefaultRnnHiddenState(graph->block(), opset_version);
    fixDefaultLstmCellState(graph->block(), opset_version);
  }
  fuseBroadcast(graph->block());
  fuseConsecutiveTransposes(graph->block());
  eliminateNopTranspose(graph->block());
  fuseTransposeIntoGemm(graph->block());
  speculateOps(graph->block());
  fuseListConstructListUnpack(graph->block());
  fuseLogSoftmaxNllLoss(graph->block());
  eraseListConstruct(graph->block(), opset_version);
  eraseTupleConstruct(graph->block());
  EliminateDeadCode(
      graph->block(),
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
  eraseListUnpack(graph->block(), opset_version);
  removeMaxPoolUnusedOutput(graph->block());
  removeSequenceSplitConcat(graph->block());
  insertIdentityForInputUsedAsOutput(graph->block());

  GRAPH_DUMP("After PeepholeOptimizeONNX", graph);
}

} // namespace torch::jit