#include <torch/csrc/jit/passes/graph_fuser.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <queue>
#include <unordered_map>
#include <utility>

namespace torch::jit {

namespace {

// What is a simple mappable operator?  It:
//    - Has a single tensor output
//    - Output and all tensor inputs have the same shape
//    - Output and all tensor inputs have the same scalar type
//      or all tensor inputs have the same scalar type and
//         output is identified in PropagateInputShapes
//    - Output and all tensor inputs should be on the same device
//    - Produces dense non-overlapping outputs
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
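// For example, aten::mul(Tensor, Tensor) over two tensors with identical
// shape, dtype, and device is simple mappable, while aten::mm is not, since
// its output shape differs from its inputs' shapes.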
bool isSimpleMap(Node* node) {
  static OperatorSet simple_mappable{{
      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",

      "aten::abs(Tensor self) -> Tensor",
      "aten::acos(Tensor self) -> Tensor",
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::asin(Tensor self) -> Tensor",
      "aten::atan(Tensor self) -> Tensor",
      "aten::atan2(Tensor self, Tensor other) -> Tensor",
      "aten::ceil(Tensor self) -> Tensor",
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
      "aten::cos(Tensor self) -> Tensor",
      "aten::cosh(Tensor self) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::expm1(Tensor self) -> Tensor",
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
      "aten::floor(Tensor self) -> Tensor",
      "aten::fmod(Tensor self, Tensor other) -> Tensor",
      "aten::frac(Tensor self) -> Tensor",
      "aten::lgamma(Tensor self) -> Tensor",
      "aten::log(Tensor self) -> Tensor",
      "aten::log10(Tensor self) -> Tensor",
      "aten::log1p(Tensor self) -> Tensor",
      "aten::log2(Tensor self) -> Tensor",
      "aten::logit(Tensor self, float? eps=None) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::pow(Tensor self, Tensor exponent) -> Tensor",
      "aten::pow(Tensor self, Scalar exponent) -> Tensor",
      "aten::pow(Scalar self, Tensor exponent) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::relu(Tensor self) -> Tensor",
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
      "aten::remainder(Tensor self, Tensor other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
      "aten::rsqrt(Tensor self) -> Tensor",
      "aten::sigmoid(Tensor self) -> Tensor",
      "aten::sin(Tensor self) -> Tensor",
      "aten::sinh(Tensor self) -> Tensor",
      "aten::sqrt(Tensor self) -> Tensor",
      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::tan(Tensor self) -> Tensor",
      "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::tanh(Tensor self) -> Tensor",
      "aten::trunc(Tensor self) -> Tensor",
      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",

      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",

      "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",

      "aten::type_as(Tensor self, Tensor other) -> Tensor",
  }};
  if (!node->isMemberOf(simple_mappable)) {
    return false;
  }
  for (Value* input : node->inputs()) {
    if (input->type()->isSubtypeOf(*TensorType::get()) ||
        input->type()->isSubtypeOf(*FloatType::get())) {
      continue;
    }
    if (input->node()->kind() != prim::Constant) {
      return false;
    }
  }
  return true;
}

struct GraphFuser {
  using FusionCallback = std::function<bool(GraphFuser*, Node*)>;

  Block* block_;
  AliasDb* aliasDb_;
  std::shared_ptr<Graph> graph_;
  FusionCallback callback_ = [](GraphFuser* gf, Node* n) {
    return gf->isFusableDefault(n, gf->strict_fuser_check_);
  };
  Symbol kind_ = prim::FusionGroup;
  bool strict_fuser_check_ = false;

  // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
  // The specific limit is a function of constant memory size, amount available
  // to pass arguments, and some implementation dependence. Select a safe
  // limit here.
  // This limit is also applied to other devices in the fuser by default.
  // Change with setInputArgLimit
  size_t subgraph_arg_limit_ = 128;

  GraphFuser(AliasDb* aliasDb, Block* block, bool strict_fuser_check)
      : block_(block),
        aliasDb_(aliasDb),
        strict_fuser_check_(strict_fuser_check) {}

  // Custom passes require kind to be specified
  GraphFuser(
      AliasDb* aliasDb,
      Block* block,
      FusionCallback callback,
      Symbol kind,
      bool strict_fuser_check = false)
      : block_(block),
        aliasDb_(aliasDb),
        callback_(std::move(callback)),
        kind_(kind),
        strict_fuser_check_(strict_fuser_check) {}

  void setInputArgLimit(size_t limit) {
    subgraph_arg_limit_ = limit;
  }

  value_list tensorInputs(Node* node) {
    return filter(node->inputs(), [](Value* v) {
      return v->type()->isSubtypeOf(*TensorType::get());
    });
  }

  bool isFusable(Node* node) {
    return callback_(this, node);
  }

  bool isFusableDevice(Value* v, bool strict_fuser_check) {
    if (!v->type()->isSubtypeOf(*TensorType::get())) {
      return true;
    }
    auto device = v->type()->expectRef<TensorType>().device();
    if (!device) {
      return !strict_fuser_check;
    }
    if ((*device).is_cpu()) {
      return canFuseOnCPULegacy();
    } else if ((*device).is_cuda()) {
      return canFuseOnGPU();
    } else if ((*device).is_xpu()) {
      return false;
    } else {
      TORCH_CHECK_NOT_IMPLEMENTED(false, "Unknown device for graph fuser");
    }
  }

  // Default fusability check - used when the user doesn't pass in
  // a callback.
  bool isFusableDefault(Node* node, bool strict_fuser_check) {
    bool fusableDevice = true;
    for (const auto& output : node->outputs()) {
      if (!output->uses().empty()) {
        fusableDevice &= isFusableDevice(output, strict_fuser_check);
      }
    }
    return fusableDevice && isFusableMap(node);
  }

  bool isFusableMap(Node* node) {
    // We don't want to bother with cross-block node movements, as they
    // are not necessarily correct.
    if (node->owningBlock() != block_)
      return false;
    return node->kind() == prim::FusionGroup || isSimpleMap(node);
  }

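  // For example (illustrative IR), a cat of this form is fusable, provided
  // the list is used only by the cat and the arg limit is respected:
  //   %tensors : Tensor[] = prim::ListConstruct(%a, %b)
  //   %out : Tensor = aten::cat(%tensors, %dim)  // %dim must be constant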
  bool isFusableCatNode(Node* node) {
    if (node->kind() != aten::cat)
      return false;
    if (!node->is_constant(attr::dim))
      return false;

    auto tensors_node = node->namedInput(attr::tensors)->node();
    if ((tensors_node->inputs().size() + node->outputs().size()) >
        subgraph_arg_limit_) {
      return false;
    }
    if (tensors_node->kind() != prim::ListConstruct)
      return false;
    // NB: Note that technically other uses of the list aren't a big problem for
    // us. It would be enough to place the prim::FusedConcat before the
    // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would
    // still be satisfied. However, I don't expect this to be necessary any time
    // soon, and so we're simply assuming that we don't have to deal with it.
    if (tensors_node->output()->uses().size() > 1)
      return false;
    return true;
  }

  bool calculatesSize(Node* node) {
    return node->matches("aten::size(Tensor self) -> int[]");
  }

  bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
    auto defining_node = producer->node();
    for (auto o : defining_node->outputs()) {
      for (auto u : o->uses()) {
        if (u.user != consumer && !calculatesSize(u.user))
          return false;
      }
    }
    return true;
  }

  Graph& getSubgraph(Node* n) {
    AT_ASSERT(n->kind() == kind_);
    return *n->g(attr::Subgraph);
  }

  void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
    // Now we have two fusion groups!
    // Revert the fusion - place all inner nodes of producer back in the outer
    // graph.
    std::vector<Node*> temporary_nodes;
    auto producer_subgraph = &getSubgraph(producer_group);

    // Initialize a map of inner graph values to outer graph values
    std::unordered_map<Value*, Value*> inner_to_outer;
    auto inner_inputs = producer_subgraph->inputs();
    auto outer_inputs = producer_group->inputs();
    for (const auto i : c10::irange(inner_inputs.size())) {
      inner_to_outer[inner_inputs[i]] = outer_inputs[i];
    }

    // Clone all nodes
    for (auto inner : producer_subgraph->nodes()) {
      Node* outer = block_->owningGraph()->createClone(
          inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
      outer->insertBefore(producer_group);
      temporary_nodes.emplace_back(outer);
      auto inner_outputs = inner->outputs();
      auto outer_outputs = outer->outputs();
      for (const auto i : c10::irange(inner_outputs.size())) {
        inner_to_outer[inner_outputs[i]] = outer_outputs[i];
      }
    }

    // Replace uses of producer_group outputs and destroy the producer
    auto subgraph_outputs = producer_subgraph->outputs();
    for (const auto i : c10::irange(subgraph_outputs.size())) {
      auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
      producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
      // new producer outputs have same aliasing properties as outer_output
      aliasDb_->replaceWithNewValue(producer_group->outputs()[i], outer_output);
    }
    producer_group->destroy();
    producer_group =
        nullptr; // Just to get a clear error in case someone uses it

    // Inline the temporary nodes into the first group
    auto consumer_subgraph = &getSubgraph(consumer_group);
    for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
         ++it) {
      Node* node = *it;
      Node* merged = mergeNodeIntoGroup(consumer_group, node);
      // If any of the outputs are still used then we need to add them
      auto outputs = node->outputs();
      for (const auto i : c10::irange(outputs.size())) {
        auto output = outputs[i];
        if (output->uses().empty())
          continue;
        consumer_subgraph->registerOutput(merged->outputs()[i]);
        auto new_output = consumer_group->addOutput();
        output->replaceAllUsesWith(new_output);
        aliasDb_->replaceWithNewValue(output, new_output);
        new_output->setType(output->type());
      }
      node->destroy();
    }
  }

  // insert a producer node into a consuming fusion group.
  // DOES NOT WORK if n is a consumer of an output of the fusion group
  // returns the node _inside_ the group that represents the node
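  // For example (illustrative), merging %p = aten::mul(%a, %b) into a group
  // that already takes %a as an input adds only %b as a new group input; if
  // %p itself was an input to the group, that input is removed and its uses
  // inside the subgraph are rewired to the merged node's output.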
  Node* mergeNodeIntoGroup(Node* group, Node* n) {
    AT_ASSERT(n->kind() != kind_);
    auto& subgraph = getSubgraph(group);
    // map from nodes in the surrounding graph to parameters in the fusion
    // group's subgraph that correspond to them
    std::unordered_map<Value*, Value*> inputs_map;
    size_t i = 0;
    size_t tensor_insert_idx = 0;
    AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
    for (auto input : group->inputs()) {
      inputs_map[input] = subgraph.inputs()[i++];
      if (input->type()->isSubtypeOf(*TensorType::get()))
        tensor_insert_idx = i;
    }
    // add n's inputs to the fusion group's input list if we don't already have
    // them
    // we insert tensors first because the fuser assumes that to be the case
    // (as a legacy from tensors only)
    WithInsertPoint guard(*subgraph.nodes().begin());
    for (auto input : n->inputs()) {
      if (inputs_map.count(input) == 0) {
        if (input->type()->isSubtypeOf(*TensorType::get())) {
          auto in_group = subgraph.insertInput(tensor_insert_idx);
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->insertInput(tensor_insert_idx, input);
          tensor_insert_idx++;
        } else if (
            (input->type()->isSubtypeOf(*FloatType::get()) &&
             input->node()->kind() != prim::Constant) ||
            (n->kind() == aten::_grad_sum_to_size &&
             input->type()->isSubtypeOf(*ListType::ofInts()))) {
          auto in_group = subgraph.addInput();
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->addInput(input);
        } else {
          // We don't support passing in scalars as arguments to fused kernels,
          // so we generally don't allow fusing tensor-scalar operations unless
          // the scalar is constant. In those cases we inline the constants
          // directly in the body of the fused group.
          AT_ASSERT(input->node()->kind() == prim::Constant);
          Node* in_const =
              subgraph.createClone(input->node(), [](Value*) -> Value* {
                throw std::runtime_error("unexpected input");
              });
          subgraph.insertNode(in_const);
          inputs_map[input] = in_const->output();
        }
      }
    }
    // copy n into the graph, remapping its inputs to internal nodes
    Node* in_graph = subgraph.createClone(
        n, [&](Value* k) -> Value* { return inputs_map[k]; });
    // if n's outputs are already inputs to the fusion group,
    // we need to remove them because n is now inside the fusion group.
    //
    // i.e.,
    // x = f(w); group(x, y, z) becomes group(w, y, z).
    // x, y, z = f(w); group(x, y, z) becomes group(w).
    //
    // remapping nodes that used the input to the newly-merged node
    // n is not an input when the fusion group is empty
    auto inputs = group->inputs();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
      if (it != inputs.end()) {
        size_t p = it - inputs.begin();
        group->removeInput(p);
        subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
        subgraph.eraseInput(p);
      }
    }
    return subgraph.insertNode(in_graph);
  }

  // turn consumer node n into a fusion group with just n inside
  // to prepare for fusion and replace uses of n with the new group
  Node* createSingletonFusionGroup(Node* n) {
    auto group = block_->owningGraph()->createWithSubgraph(kind_);
    // propagate position information for the new node so we can always
    // have a valid mapping
    group->insertBefore(n);
    Node* mergedNode = mergeNodeIntoGroup(group, n);
    getSubgraph(group).registerOutput(mergedNode->output());
    auto sel = group->addOutput();
    sel->copyMetadata(n->output());
    aliasDb_->replaceWithNewValue(n->output(), sel);
    n->replaceAllUsesWith(group);
    n->destroy();
    return group;
  }

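  // For example (illustrative IR), given
  //   %p = aten::mul(%a, %b)
  //   %c = aten::add(%p, %d, 1)
  // fusing producer %p into consumer %c first wraps the consumer in a
  // singleton prim::FusionGroup, then merges the producer into it, yielding
  // a single group that computes both ops.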
  std::optional<Node*> tryFuse(Node* consumer, Value* producer) {
    // this handles cases where producer can be moved _into_ the fusion group of
    // consumer.
    // TODO: extend to fusion of consumer into _producer's_ fusion blob
    // if the consumer allInputsAreThisProducer(consumer,producer)
    // we can move the consumer up into the producer.
    // but this requires better handling of merging fusion groups so it is not
    // done now
    bool shouldFuse = isFusable(producer->node()) &&
        // Rearrange nodes such that all uses of producer are after the
        // consumer. Fusion will rewrite those later uses to use the version of
        // producer generated by the fused blob. In this case, producer becomes
        // an output of the fusion group.
        aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);

    if (!shouldFuse) {
      return std::nullopt;
    }

    if ((consumer->inputs().size() + consumer->outputs().size() +
         producer->node()->inputs().size() +
         producer->node()->outputs().size()) > subgraph_arg_limit_) {
      return std::nullopt;
    }

    auto group = consumer;
    if (consumer->kind() != kind_) {
      group = createSingletonFusionGroup(consumer);
    }

    if (producer->node()->kind() == kind_) {
      mergeFusionGroups(group, producer->node());
      return group;
    }
    AT_ASSERT(producer->node()->outputs().size() == 1);
    Node* merged = mergeNodeIntoGroup(group, producer->node());
    // remaining uses of this producer can occur because we allow
    // fusion in cases where uses remain after the consumer
    // if these exist, re-route them to the version of producer
    // created in FusionGroup
    if (!producer->uses().empty()) {
      getSubgraph(group).registerOutput(merged->output());
      Value* new_producer = group->addOutput();
      new_producer->copyMetadata(producer);
      aliasDb_->replaceWithNewValue(producer, new_producer);
      producer->replaceAllUsesWith(new_producer);
    }
    producer->node()->destroy();
    return group;
  }

  bool canFuseChunk(Node* consumer, Value* producer) {
    if (consumer->kind() != prim::FusionGroup) {
      return false;
    }
    // Does the chunk have constant chunks/dim?
    auto* chunk = producer->node();
    if (chunk->kind() != prim::ConstantChunk)
      return false;
    // And all uses of the chunk are in this consumer
    for (auto s : chunk->outputs()) {
      for (auto u : s->uses()) {
        if (u.user != consumer) {
          return false;
        }
      }
    }
    // And isn't a no-op chunk (chunks == 1). Have CSE clean this up.
    // We could fuse this but it's better to just delete the node.
    if (chunk->i(attr::chunks) == 1) {
      return false;
    }
    return true;
  }

  std::optional<Node*> findFusedChunk(Node* group, Value* input) {
    AT_ASSERT(group->kind() == prim::FusionGroup);
    auto it = std::find(group->inputs().begin(), group->inputs().end(), input);
    if (it == group->inputs().end()) {
      return std::nullopt;
    }
    size_t input_index = it - group->inputs().begin();
    auto& subgraph = getSubgraph(group);
    auto* subgraph_input = subgraph.inputs().at(input_index);
    // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
    auto* node = subgraph_input->uses().at(0).user;
    if (node->kind() == prim::ConstantChunk) {
      AT_ASSERT(subgraph_input->uses().size() == 1);
      return node;
    }
    return std::nullopt;
  }

  void fuseChunkByReusingExistingFusedChunk(
      Node* group,
      Node* chunk,
      Node* existingFusedChunk) {
    if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
      return;
    }
    auto& subgraph = getSubgraph(group);
    for (size_t i = 0; i < chunk->outputs().size(); ++i) {
      // Find the input to the FusionGroup (group)
      auto* replacement_val = existingFusedChunk->outputs().at(i);
      auto* val = chunk->outputs().at(i);
      auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
      auto input_index = it - group->inputs().begin();

      // Rewrite the graph to use replacement_val
      auto group_input = subgraph.inputs().at(input_index);
      group_input->replaceAllUsesWith(replacement_val);

      // Remove the input, it's no longer needed
      group->removeInput(input_index);
      subgraph.eraseInput(input_index);
    }
    chunk->destroy();
  }

  // There are two invariants for prim::ConstantChunk:
  // (1) the tensor input to prim::ConstantChunk must be an input to the
  //     fusion group
  // (2) no two ConstantChunks in the same FusionGroup can share a tensor
  //     input.
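  // For example (illustrative), if the group already chunks input %x, a
  // second prim::ConstantChunk of %x is not merged as a new node; instead
  // its outputs are rewired to the existing fused chunk's outputs,
  // preserving invariant (2).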
  graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
    auto* chunk = producer->node();
    AT_ASSERT(consumer->kind() == prim::FusionGroup);
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    // if producer's input is already an input to a prim::ConstantChunk node,
    // we cannot add a new prim::ConstantChunk node because of invariant (2).
    auto* chunked_tensor = producer->node()->input();
    if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
      fuseChunkByReusingExistingFusedChunk(
          consumer, chunk, *existingFusedChunk);
      return consumer->reverseIterator();
    }

    // Move prim::ConstantChunk into the FusionGroup
    mergeNodeIntoGroup(consumer, chunk);
    chunk->destroy();
    return consumer->reverseIterator();
  }

  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  graph_node_list::iterator scanNodeForChunks(Node* consumer) {
    if (consumer->kind() == prim::FusionGroup) {
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto producer : inputs) {
        if (!canFuseChunk(consumer, producer)) {
          continue;
        }
        return fuseChunk(consumer, producer);
      }
    }
    return ++consumer->reverseIterator();
  }

  at::ArrayRef<Value*> broadcast_tensors(value_list inputs) {
    AT_ASSERT(!inputs.empty());
    auto* g = inputs[0]->owningGraph();
    auto* input_list =
        g->insertNode(g->createList(TensorType::get(), inputs))->output();
    aliasDb_->createValue(input_list);
    auto* output_list = g->insert(aten::broadcast_tensors, {input_list});
    aliasDb_->createValue(output_list);
    auto* unpack_node = g->insertNode(
        g->create(prim::ListUnpack, {output_list}, inputs.size()));

    // We are doing:
    //   input_list = listConstruct(a, b, ...)
    //   output_list = broadcast_tensors(input_list)
    //   a_broadcasted, b_broadcasted = listUnpack(output_list)
    // `a_broadcasted` should receive the same aliasing info as `a`
    TORCH_INTERNAL_ASSERT(unpack_node->outputs().size() == inputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      Value* original_input = inputs[i];
      Value* broadcasted_output = unpack_node->outputs()[i];
      aliasDb_->copyValue(original_input, broadcasted_output);
    }

    return unpack_node->outputs();
  }

  void insertExplicitBroadcast(Node* node) {
    WithInsertPoint insert_guard{node};
    auto tensors = tensorInputs(node);
    auto new_tensors = broadcast_tensors(std::move(tensors));

    // Replace tensor inputs with broadcasted values
    auto new_tensors_it = new_tensors.begin();
    for (size_t i = 0; i < node->inputs().size(); ++i) {
      if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) {
        AT_ASSERT(new_tensors_it != new_tensors.end());
        node->replaceInput(i, *(new_tensors_it++));
      }
    }
  }

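  // Replace a prim::ConstantChunk with an equivalent prim::BroadcastingChunk
  // over the same single input, e.g. (illustrative IR):
  //   %a, %b = prim::ConstantChunk(%x)
  // becomes
  //   %a, %b = prim::BroadcastingChunk(%x)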
  Node* promoteChunkToBroadcastingChunk(Node* chunk) {
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    size_t nchunks = chunk->i(attr::chunks);
    Node* bchunk =
        chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
    bchunk->addInput(chunk->input());
    for (const auto i : c10::irange(nchunks)) {
      auto* old_output = chunk->outputs().at(i);
      auto* new_output = bchunk->outputs().at(i);
      new_output->copyMetadata(old_output);
      aliasDb_->replaceWithNewValue(old_output, new_output);
      old_output->replaceAllUsesWith(new_output);
    }
    bchunk->copyAttributes(*chunk);
    bchunk->insertAfter(chunk);
    chunk->destroy();
    return bchunk;
  }

  // in places where op can be fused into a consumer but chunk is in the way
  // distribute chunk to op's operands:
  // replace a,b = chunk(op(x,y,z)) with:
  // x', y', z' = broadcast_tensors([x, y, z])
  // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type)
  // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type)
  // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type)
  // a = op(x0,y0,z0) (a,b have their same size but are now contiguous)
  // b = op(x1,y1,z1)
  //
  // The graph fuser uses an intermediate prim::BroadcastingChunk node to
  // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts
  // all of its inputs and then chunks each input, in order, the same way.
  // The above graph is equivalent to:
  // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z)
  // a = op(x0,y0,z0)
  // b = op(x1,y1,z1)
  //
  // NB: The explicit broadcast is important for correctness.
  // Let's say we have:
  // %z = aten::mul(%x, %y)
  // %z.1, %z.2 = aten::chunk(%z, ...)
  // ... = prim::FusionGroup(%z.1, %z.2, ...)
  // It's possible that %x and %y do not have the same size as %z and
  // need to be expanded first so that they can be chunked like %z
  //
  // NB: Chunk motion only occurs with fusable consumers, which implies
  // that there is always some other operation, e.g., a+b, that happens
  // after the chunk, and will be put into the fusion group. This is
  // important, because distributing the chunk changes the contiguity
  // of a and b, and so the results would be invalid, except that we know
  // that simple_mappable operations will restore contiguity before
  // we exit the fusion group.
  //
  // NB: The intermediate BroadcastingChunk is important for moving chunks past
  // more than one operation: the graph fuser is not able to easily move
  // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible
  // ops
  //   x = f(v, w)
  //   z = g(x, y)
  //   a, b = chunk(z)
  //   c = h(a, b)
  // becomes (with the broadcast_tensors + chunk approach):
  //   x = f(v, w)
  //   x', y' = broadcast_tensors([x, y])
  //   ax, bx = chunk(x')
  //   ay, by = chunk(y')
  //   a = g(ax, ay)
  //   b = g(bx, by)
  //   c = h(a, b)
  // The broadcast_tensors node makes it harder to move f into the resulting
  // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior
  // together results in:
  //   x = f(v, w)
  //   ax, bx, ay, by = BroadcastingChunk(x, y)
  //   a = g(ax, ay)
  //   b = g(bx, by)
  //   c = h(a, b)
  // making it easier to move f after the BroadcastingChunk:
  //   ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
  //   ax = f(av, aw)
  //   bx = f(bv, bw)
  //   a = g(ax, ay)
  //   b = g(bx, by)
  //   c = h(a, b)

  bool tryToMoveChunk(Node* consumer, Value* producer) {
    // is the output from a chunk/bchunk node?
    auto* chunk = producer->node();
    if (chunk->kind() != prim::ConstantChunk &&
        chunk->kind() != prim::BroadcastingChunk)
      return false;

    // try to find a producer to move after the chunk/bchunk. The producer must
    // be fusible into the consumer.
    auto it = std::find_if(
        chunk->inputs().begin(),
        chunk->inputs().end(),
        [&](Value* producer_for_chunk) {
          return isFusableMap(producer_for_chunk->node()) &&
              allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
        });
    if (it == chunk->inputs().end()) {
      return false;
    }
    Value* producer_for_chunk = *it;
    size_t producer_index = it - chunk->inputs().begin();

    // all uses of the chunk must be in this consumer
    for (auto s : chunk->outputs()) {
      for (auto u : s->uses()) {
        if (u.user != consumer)
          return false;
      }
    }
    // producers with multiple outputs are not supported
    Node* producer_for_chunk_node = producer_for_chunk->node();
    AT_ASSERT(producer_for_chunk_node->outputs().size() == 1);

    // Convert chunk to bchunk, if it isn't one already. The bchunk represents a
    // broadcast and one or more chunk operations.
    auto* bchunk = chunk;
    if (chunk->kind() == prim::ConstantChunk) {
      bchunk = promoteChunkToBroadcastingChunk(chunk);
    }
    size_t nchunks = bchunk->i(attr::chunks);
    WithInsertPoint guard(bchunk->next());

    std::vector<Value*> producer_chunk_outputs;
    for (const auto i : c10::irange(nchunks)) {
      producer_chunk_outputs.push_back(
          bchunk->output(nchunks * producer_index + i));
    }

    // Add each of op's operands to the bchunk node.
    // chunked_inputs[input_nr][chunk_output_idx]
    //  = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
    std::vector<std::vector<Value*>> chunked_inputs;

    for (auto input : producer_for_chunk_node->inputs()) {
      // XXX: we only work with pointwise ops in here, so we know it is valid to
      // push the concat only through tensor arguments (and all other args can
      // be safely ignored).
      if (!input->type()->isSubtypeOf(*TensorType::get()))
        continue;

      // if 'input' is already an input to the bchunk, reuse it.
      auto bchunk_inputs = bchunk->inputs();
      auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input);
      if (it != bchunk_inputs.end()) {
        chunked_inputs.emplace_back();
        auto input_index = std::distance(bchunk_inputs.begin(), it);
        for (const auto chunki : c10::irange(nchunks)) {
          chunked_inputs.back().push_back(
              bchunk->outputs().at(nchunks * input_index + chunki));
        }
        continue;
      }

      // NB: I decided not to use cloneFrom here, because if we make cloneFrom
      // copy selects one day, it is definitely not what you want here (selects
      // have different types).
      // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
      // to copy select nodes now that we have refactored to have a Value
      // distinct from Node.
      bchunk->addInput(input);
      chunked_inputs.emplace_back(); // alas, to not be C++17
      for (auto chunk_sel : producer_chunk_outputs) {
        Value* input_chunk_sel = bchunk->addOutput();
        input_chunk_sel->setType(chunk_sel->type());
        // Add a fresh value for each output element of the broadcasting chunk
        // node. This is safe because it will be consumed only by the chunked
        // ops.
        aliasDb_->createValue(input_chunk_sel);
        chunked_inputs.back().push_back(input_chunk_sel);
      }
    }

    // apply the op to each chunk of the chunked operands,
    // and then rewrite the graph to use them!
    for (auto chunk_sel : producer_chunk_outputs) {
      auto original_inputs = producer_for_chunk_node->inputs();
      Node* chunked_op =
          block_->owningGraph()->create(producer_for_chunk_node->kind());
      chunked_op->copyAttributes(*producer_for_chunk_node);
      chunked_op->output()->setType(chunk_sel->type());
      auto chunked_inputs_it = chunked_inputs.begin();
      for (Value* original_input : original_inputs) {
        if (original_input->type()->isSubtypeOf(*TensorType::get())) {
          AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
          chunked_op->addInput(
              // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
              chunked_inputs_it->at(chunk_sel->offset() % nchunks));
          ++chunked_inputs_it;
        } else {
          chunked_op->addInput(original_input);
        }
      }
      bchunk->owningGraph()->insertNode(chunked_op);
      chunk_sel->replaceAllUsesWith(chunked_op->output());
      aliasDb_->replaceWithNewValue(chunk_sel, chunked_op->output());
    }

    bchunk->removeInput(producer_index);
    for (const auto i : c10::irange(nchunks)) {
      (void)i; // Suppress unused variable warning
      bchunk->eraseOutput(nchunks * producer_index);
    }

    // The output of producer_for_chunk_node could have been used in some
    // aten::size operators, so we need to clean those up as well (we simply
    // broadcast all its tensor inputs).
    // We need to insert these early in the graph, i.e. immediately after
    // the producer_for_chunk_node as we will have the _size_if_not_same
    // that may be before the bchunk.
    WithInsertPoint guard2(producer_for_chunk_node);
    auto size_calc_uses = producer_for_chunk_node->output()->uses();
    if (!size_calc_uses.empty()) {
      auto tensor_inputs = filter(
          producer_for_chunk_node->inputs(),
          [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
      auto tensor_sizes = fmap(tensor_inputs, [&](Value* v) {
        Value* output = v->owningGraph()->insert(aten::size, {v});
        aliasDb_->createValue(output);
        return output;
      });
      AT_ASSERT(!tensor_sizes.empty());
      Value* output_size = tensor_sizes.size() == 1
          ? tensor_sizes[0]
          : broadcastSizes(tensor_sizes, aliasDb_);
      for (Use u : size_calc_uses) {
        u.user->output()->replaceAllUsesWith(output_size);
        u.user->destroy();
      }
    }
    producer_for_chunk_node->destroy();
    return true;
  }

  // returns where to continue scanning, and whether any fusion was made
  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (isFusable(consumer)) {
      // handle inputs in reverse topological order as well...
      // otherwise in f(a,a+b) it will appear a is used twice if we consider
      // the f-a fusion before the f-(a+b) fusion first.
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto producer : inputs) {
        if (tryToMoveChunk(consumer, producer)) {
          // the chunk before this consumer was re-arranged to allow fusion,
          // we scan this consumer again to perform the fusion
          return std::make_pair(consumer->reverseIterator(), true);
        }
        auto fusion_group = tryFuse(consumer, producer);
        if (fusion_group) {
          // after fusion, consumer moves into a FusionGroup, so inputs is no
          // longer valid so we rescan the new FusionGroup for more fusions...
          return std::make_pair(fusion_group.value()->reverseIterator(), true);
        }
      }
    }
    return std::make_pair(++consumer->reverseIterator(), false);
  }

  void replaceIntermediateBroadcastingChunks() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      auto* node = *it;
      ++it; // We might delete node, so increment the iterator now.
      if (node->kind() != prim::BroadcastingChunk) {
        continue;
      }
      auto* bchunk = node;
      insertExplicitBroadcast(bchunk);

      auto* graph = block_->owningGraph();
      size_t nchunks = bchunk->i(attr::chunks);
      WithInsertPoint guard(bchunk->next());

      // Split the bchunk into bchunks.inputs().size() number of chunk nodes.
      for (size_t input_offset = 0; input_offset < bchunk->inputs().size();
           input_offset++) {
        auto* input = bchunk->inputs().at(input_offset);

        Node* new_chunk =
            graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
        new_chunk->copyAttributes(*bchunk);
        for (const auto output_offset : c10::irange(nchunks)) {
          auto new_output = new_chunk->addOutput();
          auto old_output =
              bchunk->outputs().at(input_offset * nchunks + output_offset);
          new_output->copyMetadata(old_output);
          aliasDb_->replaceWithNewValue(old_output, new_output);
          old_output->replaceAllUsesWith(new_output);
        }
      }
      bchunk->destroy();
    }
  }

  // Builds up expressions that compute shapes of all intermediates (and
  // outputs) of the fusion group, based on the sizes of inputs. You should run
  // DCE to remove those that you end up not using.
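  // For example (illustrative IR), for a group computing
  //   %c = aten::mul(%a, %b)
  // this emits, after the fusion group:
  //   %a_size = aten::size(%a)
  //   %b_size = aten::size(%b)
  //   %c_size = prim::BroadcastSizes(%a_size, %b_size)
  // and maps the subgraph value for %c to %c_size.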
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
    WithInsertPoint insert_guard{fusion_group->next()};
    std::unordered_map<Value*, Value*> shape_of;

    Graph* graph = fusion_group->owningGraph();
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto inputs = fusion_group->inputs();
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
        Value* soutput = graph->insert(aten::size, {inputs[i]});
        aliasDb_->createValue(soutput);
        shape_of[sinputs[i]] = soutput;
      }
    }

    // When we have a guarantee that an output won't be removed, because it's
    // used in expressions that don't involve size checks, we can use its size
    // instead of computing a long chain of broadcasts, starting from the
    // beginning of the kernel.
    auto outputs = fusion_group->outputs();
    auto soutputs = subgraph->outputs();
    AT_ASSERT(outputs.size() == soutputs.size());
    for (const auto i : c10::irange(outputs.size())) {
      if (usedOnlyInSize(outputs[i]))
        continue;
      Value* soutput = graph->insert(aten::size, {outputs[i]});
      aliasDb_->createValue(soutput);
      shape_of[soutputs[i]] = soutput;
    }

    for (Node* n : subgraph->nodes()) {
      // XXX: Use of shape_of.emplace is crucial to the output shape
      // optimization!
      if (n->kind() == prim::FusedConcat) {
        // This is a bit more involved, because we have to account for the case
        // when inputs have different shapes, but fortunately those tensors are
        // always outputs, and so we can simply avoid replacing their queries,
        // because it won't help us.
        continue;
      }
      if (n->kind() == prim::Constant) {
        continue;
      }
      if (n->kind() == prim::ConstantChunk) {
        Node* sizes_node = graph->insertNode(
            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
        sizes_node->i_(attr::dim, n->i(attr::dim));
        sizes_node->i_(attr::chunks, n->i(attr::chunks));
        for (Value* output : sizes_node->outputs()) {
          aliasDb_->createValue(output);
        }
        Value* regular_size = sizes_node->outputs().at(0);
        Value* last_size = sizes_node->outputs().at(1);
        regular_size->setType(ListType::ofInts());
        last_size->setType(ListType::ofInts());
        auto outputs = n->outputs();
        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
          shape_of.emplace(o, regular_size);
        }
        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
        continue;
      }
      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
        return v->type()->isSubtypeOf(*TensorType::get());
      });
      auto shapes =
          fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
      AT_ASSERT(!shapes.empty());
      shape_of.emplace(
          n->output(),
          shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes, aliasDb_));
    }
    return shape_of;
  }

  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
    if (fusion_group->kind() != prim::FusionGroup)
      return;
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto shape_of = buildShapeExpressions(fusion_group);
    auto outputs = fusion_group->outputs().vec();
    auto soutputs = subgraph->outputs().vec();
    // XXX: Iterating in this order is not only good for performance reasons!
    // It is also crucial for correctness (i has to reflect the current true
    // index of outputs[i])!
    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
      auto output = outputs[i];
      auto soutput = soutputs[i];
      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
        auto uses = output->uses();
        for (Use u : uses) {
          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
          u.user->destroy();
        }
        fusion_group->eraseOutput(i);
        subgraph->eraseOutput(i);
      }
    }
  }

  bool canFuseWithConcat(Value* producer, Node* before_check) {
    if (!isFusable(producer->node())) {
      return false;
    }
    // NB: it is important that this check happens after isFusable, which checks
    // that the blocks match, and it's not a special node like prim::Param
    if (!aliasDb_->couldMoveBeforeTopologically(
            producer->node(), before_check)) {
      return false;
    }

    // If the number of kernel args could exceed the limit, skip.
    if ((before_check->inputs().size() + before_check->outputs().size() +
         producer->node()->inputs().size() +
         producer->node()->outputs().size()) > subgraph_arg_limit_) {
      return false;
    }

    // Fusion groups can be merged with concat's group if and only if
    // the value they produce isn't already coming from a concat
    if (producer->node()->kind() == prim::FusionGroup) {
      auto subgraph = producer->node()->g(attr::Subgraph);
      auto* node = subgraph->outputs().at(producer->offset())->node();
      return node->kind() != prim::FusedConcat;
    }
    return true;
  }

  Node* createFusedConcat(Node* node) {
    AT_ASSERT(node->kind() == aten::cat);

    Graph* graph = node->owningGraph();
    Node* list_construct = node->namedInput(attr::tensors)->node();
    int64_t dim = node->get<int64_t>(attr::dim).value();

    Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
                          ->i_(attr::dim, dim);
    fused_cat->insertBefore(list_construct);
    fused_cat->output()->copyMetadata(node->output());
    aliasDb_->copyValue(node->output(), fused_cat->output());

    // NB: this deletes the fused_cat node from the original graph
    return createSingletonFusionGroup(fused_cat);
  }

  void fuseConcats() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
         ++it) {
      Node* cat = *it;
      if (!isFusableCatNode(cat)) {
        continue;
      }
      Node* list_construct = cat->namedInput(attr::tensors)->node();
      Node* fused_cat = createFusedConcat(cat);
      Value* fused_cat_out = fused_cat->output();

      auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
      size_t input_idx = 0;
      bool any_fused = false;
      while (input_idx < sorted_inputs.size()) {
        Value* input = sorted_inputs[input_idx++];
        if (!canFuseWithConcat(input, fused_cat)) {
          continue;
        }
        any_fused = true;
        auto maybe_group = tryFuse(fused_cat, input);
        AT_ASSERT(maybe_group && maybe_group == fused_cat);
        // We could have destroyed multiple inputs when performing this fusion,
        // so we have to recompute the list and iterate over it again.
        sorted_inputs = sortReverseTopological(fused_cat->inputs());
        input_idx = 0;
      }

      if (any_fused) {
        cat->output()->replaceAllUsesWith(fused_cat_out);
        it.destroyCurrent();
        if (list_construct->output()->uses().empty()) {
          list_construct->destroy();
        }
      } else {
        fused_cat->destroy();
      }
    }
  }

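  // Run local cleanup passes inside each fusion subgraph; merging groups
  // tends to leave behind dead code, duplicated subexpressions, and
  // repeated constants.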
  void optimizeFusedGraphs() {
    for (Node* node : block_->nodes()) {
      if (node->kind() != prim::FusionGroup) {
        continue;
      }
      auto subgraph = node->g(attr::Subgraph);
      EliminateDeadCode(subgraph);
      EliminateCommonSubexpression(subgraph);
      ConstantPooling(subgraph);
    }
  }

  void run() {
// TODO: old fuser is not maintained internally, somewhere it is being turned on
// inadvertently for certain workflows. make this a no-op until we identify
// location
#if defined(FBCODE_CAFFE2)
    return;
#endif

    // Run the pass until no changes are made.
    // This is necessary, because the algorithm can miss out on certain fusion
    // opportunities if run only once. Consider this graph:
    //
    // %1 = f(...)
    // %2 = g(%1)
    // %3 = h(%1)
    // %4 = l(%3)
    // return (%4, %2)
    //
    // where f, g, h, l are simple map ops.
    // The first iteration will fuse %4 and %3, and see that %1 is an input, but
    // can't be fused, because it has a different use before the fusion group
    // in our topological ordering. Then, %2 will be considered, and fused with
    // %1. If we do another iteration, the algorithm will consider the fusion of
    // these two groups and fix the situation.
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
        auto [tmp_it, changed] = scanNode(*it);
        it = tmp_it;
        any_changed |= changed;
      }
    }

    fuseConcats();

    optimizeFusedGraphs();

    // The graph fuser can add intermediate prim::BroadcastingChunk nodes.
    // Replace them with broadcasts + chunks.
    replaceIntermediateBroadcastingChunks();

    // Fuse starting chunks into the group.
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      it = scanNodeForChunks(*it);
    }

    // Remove outputs that have been added only because we need their size
    for (Node* n : block_->nodes()) {
      removeOutputsUsedOnlyInSize(n);
    }

    for (Node* node : block_->nodes()) {
      for (Block* sub_block : node->blocks()) {
        GraphFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_)
            .run();
      }
    }
  }
};

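// Peephole-simplifies prim::BroadcastSizes expressions left behind by shape
// propagation: removes single-input broadcasts, deduplicates repeated inputs,
// and folds a broadcast whose only use is another broadcast into that user.
// For example (illustrative IR), with %s used only by %t:
//   %s = prim::BroadcastSizes(%a, %b)
//   %t = prim::BroadcastSizes(%s, %c)
// becomes
//   %t = prim::BroadcastSizes(%c, %a, %b)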
void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) {
  auto nodes = block->nodes();
  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
    Node* node = *it;
    for (Block* subblock : node->blocks()) {
      PeepholeOptimizeShapeExpressions(subblock, db);
    }
    if (node->kind() == prim::BroadcastSizes) {
      // Remove no-op broadcasts.
      if (node->inputs().size() == 1) {
        node->output()->replaceAllUsesWith(node->input());
        it.destroyCurrent();
        continue;
      }
      // Deduplicate inputs, but use their unique() values to ensure
      // this process only depends on the graph.
      std::map<size_t, Value*> unique_to_value;
      for (Value* input : node->inputs()) {
        unique_to_value.emplace(input->unique(), input);
      }
      if (unique_to_value.size() != node->inputs().size()) {
        std::vector<Value*> inputs;
        inputs.reserve(unique_to_value.size());
        for (auto& entry : unique_to_value) {
          inputs.push_back(entry.second);
        }
        if (inputs.size() == 1) {
          node->output()->replaceAllUsesWith(inputs[0]);
        } else {
          WithInsertPoint insert_guard{node};
          node->output()->replaceAllUsesWith(broadcastSizes(inputs, db));
        }
        it.destroyCurrent();
        --it; // Revisit the node with deduplicated inputs
        continue;
      }
      // Compose simple chains of broadcasts into a single node.
      const auto& uses = node->output()->uses();
      if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
        Node* user = uses[0].user;
        user->removeInput(uses[0].offset);
        // NB: we don't care about deduplication in here, as we will visit user
        // later.
        for (Value* i : node->inputs()) {
          user->addInput(i);
        }
        it.destroyCurrent();
      }
    }
  }
}

} // anonymous namespace

static bool cpu_fuser_enabled_legacy = false;

bool canFuseOnCPULegacy() {
  return cpu_fuser_enabled_legacy;
}

void overrideCanFuseOnCPULegacy(bool value) {
  cpu_fuser_enabled_legacy = value;
}

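// Typical usage (sketch, assuming a graph obtained from tracing or scripting,
// with the legacy CPU fuser enabled first if CPU tensors should be fused):
//   std::shared_ptr<Graph> graph = ...;
//   overrideCanFuseOnCPULegacy(true);
//   FuseGraph(graph, /*strict_fuser_check=*/false);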
void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
  AliasDb db(graph);
  GraphFuser(&db, graph->block(), strict_fuser_check).run();
  Lint(&db);
  // After FuseGraph some common subexpressions may come back
  EliminateCommonSubexpression(graph);
  // We might have emitted a fair amount of useless shape propagating code, so
  // remove it
  EliminateDeadCode(graph);
  // Improve the quality of shape propagation code that was left
  PeepholeOptimizeShapeExpressions(graph->block(), &db);
}

void CustomFuseGraph(
    std::shared_ptr<Graph>& graph,
    const std::function<bool(Node*)>& fn,
    Symbol kind,
    size_t arg_limit) {
  AliasDb db(graph);
  auto g = GraphFuser(
      &db,
      graph->block(),
      [=](GraphFuser* gf, Node* n) { return fn(n) || n->kind() == kind; },
      kind);
  g.setInputArgLimit(arg_limit);
  g.run();
  Lint(&db);
}

} // namespace torch::jit