1 #include <torch/csrc/jit/passes/graph_fuser.h>
2
3 #include <c10/util/Exception.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/codegen/fuser/interface.h>
6 #include <torch/csrc/jit/frontend/ir_emitter.h>
7 #include <torch/csrc/jit/ir/alias_analysis.h>
8 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
9 #include <torch/csrc/jit/passes/constant_pooling.h>
10 #include <torch/csrc/jit/passes/dead_code_elimination.h>
11 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
12 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
13 #include <torch/csrc/jit/runtime/autodiff.h>
14 #include <torch/csrc/jit/runtime/custom_operator.h>
15 #include <torch/csrc/jit/runtime/operator.h>
16
17 #include <queue>
18 #include <unordered_map>
19 #include <utility>
20
21 namespace torch::jit {
22
23 namespace {
24
// What is a simple mappable operator? It:
// - Has a single tensor output
// - Output and all tensor inputs have the same shape
// - Output and all tensor inputs have the same scalar type
// or all tensor inputs have the same scalar type and
// output is identified in PropagateInputShapes
// - Output and all tensor inputs should be on the same device
// - Produces dense non-overlapping outputs
// Some of these restrictions may be relaxable, but you should
// carefully read the code first, as we rely on these assumptions.
bool isSimpleMap(Node* node) {
  // Whitelist of elementwise ops the legacy fuser knows how to codegen.
  // Matching is done against the full operator schema, so overloads are
  // listed individually.
  static OperatorSet simple_mappable{{
      "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",

      "aten::abs(Tensor self) -> Tensor",
      "aten::acos(Tensor self) -> Tensor",
      "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::asin(Tensor self) -> Tensor",
      "aten::atan(Tensor self) -> Tensor",
      "aten::atan2(Tensor self, Tensor other) -> Tensor",
      "aten::ceil(Tensor self) -> Tensor",
      "aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
      "aten::cos(Tensor self) -> Tensor",
      "aten::cosh(Tensor self) -> Tensor",
      "aten::div(Tensor self, Tensor other) -> Tensor",
      "aten::exp(Tensor self) -> Tensor",
      "aten::expm1(Tensor self) -> Tensor",
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
      "aten::floor(Tensor self) -> Tensor",
      "aten::fmod(Tensor self, Tensor other) -> Tensor",
      "aten::frac(Tensor self) -> Tensor",
      "aten::lgamma(Tensor self) -> Tensor",
      "aten::log(Tensor self) -> Tensor",
      "aten::log10(Tensor self) -> Tensor",
      "aten::log1p(Tensor self) -> Tensor",
      "aten::log2(Tensor self) -> Tensor",
      "aten::logit(Tensor self, float? eps=None) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
      "aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
      "aten::max(Tensor self, Tensor other) -> Tensor",
      "aten::min(Tensor self, Tensor other) -> Tensor",
      "aten::mul(Tensor self, Tensor other) -> Tensor",
      "aten::neg(Tensor self) -> Tensor",
      "aten::pow(Tensor self, Tensor exponent) -> Tensor",
      "aten::pow(Tensor self, Scalar exponent) -> Tensor",
      "aten::pow(Scalar self, Tensor exponent) -> Tensor",
      "aten::reciprocal(Tensor self) -> Tensor",
      "aten::relu(Tensor self) -> Tensor",
      "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
      "aten::remainder(Tensor self, Tensor other) -> Tensor",
      "aten::round(Tensor self) -> Tensor",
      "aten::rsqrt(Tensor self) -> Tensor",
      "aten::sigmoid(Tensor self) -> Tensor",
      "aten::sin(Tensor self) -> Tensor",
      "aten::sinh(Tensor self) -> Tensor",
      "aten::sqrt(Tensor self) -> Tensor",
      "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
      "aten::tan(Tensor self) -> Tensor",
      "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::tanh(Tensor self) -> Tensor",
      "aten::trunc(Tensor self) -> Tensor",
      "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
      "aten::mul(Tensor self, Scalar other) -> Tensor",
      "aten::div(Tensor self, Scalar other) -> Tensor",

      "aten::eq(Tensor self, Tensor other) -> Tensor",
      "aten::eq(Tensor self, Scalar other) -> Tensor",
      "aten::ne(Tensor self, Tensor other) -> Tensor",
      "aten::ne(Tensor self, Scalar other) -> Tensor",
      "aten::ge(Tensor self, Tensor other) -> Tensor",
      "aten::ge(Tensor self, Scalar other) -> Tensor",
      "aten::gt(Tensor self, Tensor other) -> Tensor",
      "aten::gt(Tensor self, Scalar other) -> Tensor",
      "aten::le(Tensor self, Tensor other) -> Tensor",
      "aten::le(Tensor self, Scalar other) -> Tensor",
      "aten::lt(Tensor self, Tensor other) -> Tensor",
      "aten::lt(Tensor self, Scalar other) -> Tensor",

      "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor",
      "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",

      "aten::type_as(Tensor self, Tensor other) -> Tensor",
  }};
  if (!node->isMemberOf(simple_mappable)) {
    return false;
  }
  // Inputs that are neither tensors nor floats must be compile-time
  // constants: they cannot be passed into a fused kernel at runtime, but
  // constants can be inlined into the kernel body (see mergeNodeIntoGroup).
  for (Value* input : node->inputs()) {
    if (input->type()->isSubtypeOf(*TensorType::get()) ||
        input->type()->isSubtypeOf(*FloatType::get())) {
      continue;
    }
    if (input->node()->kind() != prim::Constant) {
      return false;
    }
  }
  return true;
}
124
struct GraphFuser {
  using FusionCallback = std::function<bool(GraphFuser*, Node*)>;

  // The block whose nodes we are fusing; isFusableMap refuses nodes that
  // live in any other block.
  Block* block_;
  // Alias analysis used both to validate node movement
  // (moveBeforeTopologicallyValid) and to keep aliasing info in sync as
  // values are created/replaced during fusion.
  AliasDb* aliasDb_;
  std::shared_ptr<Graph> graph_;
  // Predicate deciding whether a node may be fused; by default it defers to
  // isFusableDefault. Custom passes may supply their own callback.
  FusionCallback callback_ = [](GraphFuser* gf, Node* n) {
    return gf->isFusableDefault(n, gf->strict_fuser_check_);
  };
  // Node kind used for the fusion-group nodes this pass creates.
  Symbol kind_ = prim::FusionGroup;
  // When true, tensors with unknown device are rejected by isFusableDevice.
  bool strict_fuser_check_ = false;

  // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
  // The specific limit is a function of constant memory size, amount available
  // to pass arguments, and some implementation dependence. Select a safe
  // limit here.
  // This limit is also applied to other devices in the fuser by default.
  // Change with setInputArgLimit
  size_t subgraph_arg_limit_ = 128;
144
  // Construct a fuser over `block` using the default fusability callback
  // and the default kind (prim::FusionGroup).
  GraphFuser(AliasDb* aliasDb, Block* block, bool strict_fuser_check)
      : block_(block),
        aliasDb_(aliasDb),
        strict_fuser_check_(strict_fuser_check) {}
149
  // Custom passes require the subgraph kind to be specified, along with
  // their own fusability callback.
  GraphFuser(
      AliasDb* aliasDb,
      Block* block,
      FusionCallback callback,
      Symbol kind,
      bool strict_fuser_check = false)
      : block_(block),
        aliasDb_(aliasDb),
        callback_(std::move(callback)),
        kind_(kind),
        strict_fuser_check_(strict_fuser_check) {}
162
  // Override the argument-count cap applied to fusion groups
  // (see subgraph_arg_limit_ above).
  void setInputArgLimit(size_t limit) {
    subgraph_arg_limit_ = limit;
  }
166
tensorInputstorch::jit::__anon688491630111::GraphFuser167 value_list tensorInputs(Node* node) {
168 return filter(node->inputs(), [](Value* v) {
169 return v->type()->isSubtypeOf(*TensorType::get());
170 });
171 }
172
  // Delegate the fusability decision to the configured callback.
  bool isFusable(Node* node) {
    return callback_(this, node);
  }
176
isFusableDevicetorch::jit::__anon688491630111::GraphFuser177 bool isFusableDevice(Value* v, bool strict_fuser_check) {
178 if (!v->type()->isSubtypeOf(*TensorType::get())) {
179 return true;
180 }
181 auto device = v->type()->expectRef<TensorType>().device();
182 if (!device) {
183 return !strict_fuser_check;
184 }
185 if ((*device).is_cpu()) {
186 return canFuseOnCPULegacy();
187 } else if ((*device).is_cuda()) {
188 return canFuseOnGPU();
189 } else if ((*device).is_xpu()) {
190 return false;
191 } else {
192 TORCH_CHECK_NOT_IMPLEMENTED(false, "Unknown device for graph fuser");
193 }
194 }
195
196 // Default fusability check - used when the user doesn't pass in
197 // a callback.
isFusableDefaulttorch::jit::__anon688491630111::GraphFuser198 bool isFusableDefault(Node* node, bool strict_fuser_check) {
199 bool fusableDevice = true;
200 for (const auto& output : node->outputs()) {
201 if (!output->uses().empty()) {
202 fusableDevice &= isFusableDevice(output, strict_fuser_check);
203 }
204 }
205 return fusableDevice && isFusableMap(node);
206 }
207
isFusableMaptorch::jit::__anon688491630111::GraphFuser208 bool isFusableMap(Node* node) {
209 // We don't want to bother with cross-block node movements, as they
210 // are not necessarily correct.
211 if (node->owningBlock() != block_)
212 return false;
213 return node->kind() == prim::FusionGroup || isSimpleMap(node);
214 }
215
isFusableCatNodetorch::jit::__anon688491630111::GraphFuser216 bool isFusableCatNode(Node* node) {
217 if (node->kind() != aten::cat)
218 return false;
219 if (!node->is_constant(attr::dim))
220 return false;
221
222 auto tensors_node = node->namedInput(attr::tensors)->node();
223 if ((tensors_node->inputs().size() + node->outputs().size()) >
224 subgraph_arg_limit_) {
225 return false;
226 }
227 if (tensors_node->kind() != prim::ListConstruct)
228 return false;
229 // NB: Note that technically other uses of the list aren't a big problem for
230 // us. It would be enough to place the prim::FusedConcat before the
231 // prim::ListConstruct, and allUsersAreThisConsumerOrOccurAfterIt would
232 // still be satisfied. However, I don't expect this to be necessary any time
233 // soon, and so we're simply assuming that we don't have to deal with it.
234 if (tensors_node->output()->uses().size() > 1)
235 return false;
236 return true;
237 }
238
  // True if `node` is an aten::size call. Such users are tolerated on a
  // producer when deciding chunk movement (see
  // allUsersAreThisConsumerOrCalcSizes).
  bool calculatesSize(Node* node) {
    return node->matches("aten::size(Tensor self) -> int[]");
  }
242
allUsersAreThisConsumerOrCalcSizestorch::jit::__anon688491630111::GraphFuser243 bool allUsersAreThisConsumerOrCalcSizes(Node* consumer, Value* producer) {
244 auto defining_node = producer->node();
245 for (auto o : defining_node->outputs()) {
246 for (auto u : o->uses()) {
247 if (u.user != consumer && !calculatesSize(u.user))
248 return false;
249 }
250 }
251 return true;
252 }
253
  // Return the subgraph owned by fusion-group node `n` (must be of kind_).
  Graph& getSubgraph(Node* n) {
    AT_ASSERT(n->kind() == kind_);
    return *n->g(attr::Subgraph);
  }
258
  // Merge producer_group's nodes into consumer_group. Strategy: first
  // un-inline the producer's subgraph back into the outer graph, destroy the
  // producer group, then merge the temporary outer nodes into the consumer
  // group in reverse order (so definitions are merged before their users).
  void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
    // Now we have two fusion groups!
    // Revert the fusion - place all inner nodes of producer back in the outer
    // graph.
    std::vector<Node*> temporary_nodes;
    auto producer_subgraph = &getSubgraph(producer_group);

    // Initialize a map of inner graph values to outer graph values
    std::unordered_map<Value*, Value*> inner_to_outer;
    auto inner_inputs = producer_subgraph->inputs();
    auto outer_inputs = producer_group->inputs();
    for (const auto i : c10::irange(inner_inputs.size())) {
      inner_to_outer[inner_inputs[i]] = outer_inputs[i];
    }

    // Clone all nodes
    for (auto inner : producer_subgraph->nodes()) {
      Node* outer = block_->owningGraph()->createClone(
          inner, [&](Value* k) -> Value* { return inner_to_outer.at(k); });
      outer->insertBefore(producer_group);
      temporary_nodes.emplace_back(outer);
      auto inner_outputs = inner->outputs();
      auto outer_outputs = outer->outputs();
      // Record the clone's outputs so later nodes can remap their inputs.
      for (const auto i : c10::irange(inner_outputs.size())) {
        inner_to_outer[inner_outputs[i]] = outer_outputs[i];
      }
    }

    // Replace uses of producer_group outputs and destroy the producer
    auto subgraph_outputs = producer_subgraph->outputs();
    for (const auto i : c10::irange(subgraph_outputs.size())) {
      auto outer_output = inner_to_outer.at(subgraph_outputs[i]);
      producer_group->outputs()[i]->replaceAllUsesWith(outer_output);
      // new producer outputs have same aliasing properties as outer_output
      aliasDb_->replaceWithNewValue(producer_group->outputs()[i], outer_output);
    }
    producer_group->destroy();
    producer_group =
        nullptr; // Just to get a clear error in case someone uses it

    // Inline the temporary nodes into the first group
    auto consumer_subgraph = &getSubgraph(consumer_group);
    for (auto it = temporary_nodes.rbegin(); it != temporary_nodes.rend();
         ++it) {
      Node* node = *it;
      Node* merged = mergeNodeIntoGroup(consumer_group, node);
      // If any of the outputs are still used then we need to add them
      auto outputs = node->outputs();
      for (const auto i : c10::irange(outputs.size())) {
        auto output = outputs[i];
        if (output->uses().empty())
          continue;
        // Expose the merged value as an output of the consumer group and
        // re-route remaining external users to it.
        consumer_subgraph->registerOutput(merged->outputs()[i]);
        auto new_output = consumer_group->addOutput();
        output->replaceAllUsesWith(new_output);
        aliasDb_->replaceWithNewValue(output, new_output);
        new_output->setType(output->type());
      }
      node->destroy();
    }
  }
320
  // insert a producer node into a consuming fusion group.
  // DOES NOT WORK if n is a consumer of an output of the fusion group
  // returns the node _inside_ the group that represents the node
  Node* mergeNodeIntoGroup(Node* group, Node* n) {
    AT_ASSERT(n->kind() != kind_);
    auto& subgraph = getSubgraph(group);
    // map from nodes in the surrounding graph to parameters in the fusion
    // group's subgraph that correspond to them
    std::unordered_map<Value*, Value*> inputs_map;
    size_t i = 0;
    // Index just past the last tensor input seen; new tensor inputs are
    // inserted here so all tensors stay ahead of non-tensor inputs.
    size_t tensor_insert_idx = 0;
    AT_ASSERT(group->inputs().size() == subgraph.inputs().size());
    for (auto input : group->inputs()) {
      inputs_map[input] = subgraph.inputs()[i++];
      if (input->type()->isSubtypeOf(*TensorType::get()))
        tensor_insert_idx = i;
    }
    // add n's inputs to the fusion group's input list if we don't already have
    // them
    // we insert tensors first because the fuser assumes that to be the case
    // (as a legacy from tensors only)
    WithInsertPoint guard(*subgraph.nodes().begin());
    for (auto input : n->inputs()) {
      if (inputs_map.count(input) == 0) {
        if (input->type()->isSubtypeOf(*TensorType::get())) {
          // Tensor: add as a new input to both the outer group node and its
          // subgraph, keeping the two input lists aligned.
          auto in_group = subgraph.insertInput(tensor_insert_idx);
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->insertInput(tensor_insert_idx, input);
          tensor_insert_idx++;
        } else if (
            // Non-constant floats, and int-list inputs of _grad_sum_to_size,
            // are passed through as (trailing) scalar inputs.
            (input->type()->isSubtypeOf(*FloatType::get()) &&
             input->node()->kind() != prim::Constant) ||
            (n->kind() == aten::_grad_sum_to_size &&
             input->type()->isSubtypeOf(*ListType::ofInts()))) {
          auto in_group = subgraph.addInput();
          in_group->setType(input->type());
          inputs_map[input] = in_group;
          group->addInput(input);
        } else {
          // We don't support passing in scalars as arguments to fused kernels,
          // so we generally don't allow fusing tensor-scalar operations unless
          // the scalar is constant. In those cases we inline the constants
          // directly in the body of the fused group.
          AT_ASSERT(input->node()->kind() == prim::Constant);
          Node* in_const =
              subgraph.createClone(input->node(), [](Value*) -> Value* {
                throw std::runtime_error("unexpected input");
              });
          subgraph.insertNode(in_const);
          inputs_map[input] = in_const->output();
        }
      }
    }
    // copy n into the graph, remapping its inputs to internal nodes
    Node* in_graph = subgraph.createClone(
        n, [&](Value* k) -> Value* { return inputs_map[k]; });
    // if n's outputs are already inputs to the fusion group,
    // we need to remove them because n is now inside the fusion group.
    //
    // i.e.,
    // x = f(w); group(x, y, z) becomes group(w, y, z).
    // x, y, z = f(w); group(x, y, z) becomes group(w).
    //
    // remapping nodes that used the input to the newly-merged node
    // n is not an input when the fusion group is empty
    auto inputs = group->inputs();
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      auto it = std::find(inputs.begin(), inputs.end(), n->outputs()[i]);
      if (it != inputs.end()) {
        size_t p = it - inputs.begin();
        group->removeInput(p);
        subgraph.inputs()[p]->replaceAllUsesWith(in_graph->outputs()[i]);
        subgraph.eraseInput(p);
      }
    }
    return subgraph.insertNode(in_graph);
  }
399
  // turn consumer node n into a fusion group with just n inside
  // to prepare for fusion and replace uses of n with the new group
  Node* createSingletonFusionGroup(Node* n) {
    auto group = block_->owningGraph()->createWithSubgraph(kind_);
    // propagate position information for the new node so we can always
    // have a valid mapping
    group->insertBefore(n);
    Node* mergedNode = mergeNodeIntoGroup(group, n);
    // Mirror n's (single) output on the group node, then transfer metadata,
    // aliasing info, and all uses before destroying the original node.
    getSubgraph(group).registerOutput(mergedNode->output());
    auto sel = group->addOutput();
    sel->copyMetadata(n->output());
    aliasDb_->replaceWithNewValue(n->output(), sel);
    n->replaceAllUsesWith(group);
    n->destroy();
    return group;
  }
416
  // Try to fuse `producer` into `consumer`'s fusion group, creating a
  // singleton group around `consumer` first if needed. Returns the
  // (possibly new) group node on success, std::nullopt otherwise.
  std::optional<Node*> tryFuse(Node* consumer, Value* producer) {
    // this handles cases where producer can be moved _into_ the fusion group of
    // consumer.
    // TODO: extend to fusion of consumer into _producer's_ fusion blob
    // if the consumer allInputsAreThisProducer(consumer,producer)
    // we can move the consumer up into the producer.
    // but this requires better handling of merging fusion groups so it is not
    // done now
    bool shouldFuse = isFusable(producer->node()) &&
        // Rearrange nodes such that all uses of producer are after the
        // consumer. Fusion will rewrite those later uses to use the version of
        // producer generated by the fused blob. In this case, producer becomes
        // an output of the fusion group.
        aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);

    if (!shouldFuse) {
      return std::nullopt;
    }

    // Stay under the kernel argument cap (see subgraph_arg_limit_).
    if ((consumer->inputs().size() + consumer->outputs().size() +
         producer->node()->inputs().size() +
         producer->node()->outputs().size()) > subgraph_arg_limit_) {
      return std::nullopt;
    }

    auto group = consumer;
    if (consumer->kind() != kind_) {
      group = createSingletonFusionGroup(consumer);
    }

    if (producer->node()->kind() == kind_) {
      // Producer is itself a fusion group: merge group into group.
      mergeFusionGroups(group, producer->node());
      return group;
    }
    AT_ASSERT(producer->node()->outputs().size() == 1);
    Node* merged = mergeNodeIntoGroup(group, producer->node());
    // remaining uses of this producer can occur because we allow
    // fusion in cases where uses remain after the consumer
    // if these exist, re-route them to the version of producer
    // created in FusionGroup
    if (!producer->uses().empty()) {
      getSubgraph(group).registerOutput(merged->output());
      Value* new_producer = group->addOutput();
      new_producer->copyMetadata(producer);
      aliasDb_->replaceWithNewValue(producer, new_producer);
      producer->replaceAllUsesWith(new_producer);
    }
    producer->node()->destroy();
    return group;
  }
467
canFuseChunktorch::jit::__anon688491630111::GraphFuser468 bool canFuseChunk(Node* consumer, Value* producer) {
469 if (consumer->kind() != prim::FusionGroup) {
470 return false;
471 }
472 // Does the chunk have constant chunks/dim?
473 auto* chunk = producer->node();
474 if (chunk->kind() != prim::ConstantChunk)
475 return false;
476 // And all uses of the chunk are in this consumer
477 for (auto s : chunk->outputs()) {
478 for (auto u : s->uses()) {
479 if (u.user != consumer) {
480 return false;
481 }
482 }
483 }
484 // And isn't a no-op chunk (chunks == 1). Have CSE clean this up.
485 // We could fuse this but it's better to just delete the node.
486 if (chunk->i(attr::chunks) == 1) {
487 return false;
488 }
489 return true;
490 }
491
findFusedChunktorch::jit::__anon688491630111::GraphFuser492 std::optional<Node*> findFusedChunk(Node* group, Value* input) {
493 AT_ASSERT(group->kind() == prim::FusionGroup);
494 auto it = std::find(group->inputs().begin(), group->inputs().end(), input);
495 if (it == group->inputs().end()) {
496 return std::nullopt;
497 }
498 size_t input_index = it - group->inputs().begin();
499 auto& subgraph = getSubgraph(group);
500 auto* subgraph_input = subgraph.inputs().at(input_index);
501 // If subgraph_input is an input to prim::ConstantChunk, it will have 1 use
502 auto* node = subgraph_input->uses().at(0).user;
503 if (node->kind() == prim::ConstantChunk) {
504 AT_ASSERT(subgraph_input->uses().size() == 1);
505 return node;
506 }
507 return std::nullopt;
508 }
509
  // Replace the outputs of `chunk` (an outer prim::ConstantChunk feeding
  // `group`) with the outputs of `existingFusedChunk`, a ConstantChunk that
  // already lives inside the group's subgraph and chunks the same tensor
  // (invariant (2) in fuseChunk). No-op if the two chunk arities differ.
  void fuseChunkByReusingExistingFusedChunk(
      Node* group,
      Node* chunk,
      Node* existingFusedChunk) {
    if (chunk->outputs().size() != existingFusedChunk->outputs().size()) {
      return;
    }
    auto& subgraph = getSubgraph(group);
    for (size_t i = 0; i < chunk->outputs().size(); ++i) {
      // Find the input to the FusionGroup (group)
      auto* replacement_val = existingFusedChunk->outputs().at(i);
      auto* val = chunk->outputs().at(i);
      auto it = std::find(group->inputs().begin(), group->inputs().end(), val);
      auto input_index = it - group->inputs().begin();

      // Rewrite the graph to use replacement_val
      auto group_input = subgraph.inputs().at(input_index);
      group_input->replaceAllUsesWith(replacement_val);

      // Remove the input, it's no longer needed
      group->removeInput(input_index);
      subgraph.eraseInput(input_index);
    }
    chunk->destroy();
  }
535
  // There are two invariants for prim::ConstantChunk:
  // (1) the tensor input to prim::ConstantChunk must be an input to the fusion
  // group (2) no two ConstantChunks in the same FusionGroup can share a tensor
  // input.
  // Fuse the chunk producing `producer` into `consumer` (a FusionGroup),
  // either by reusing an equivalent chunk already inside the group or by
  // merging this one in. Returns the iterator to resume scanning from.
  graph_node_list::iterator fuseChunk(Node* consumer, Value* producer) {
    auto* chunk = producer->node();
    AT_ASSERT(consumer->kind() == prim::FusionGroup);
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    // if producer's input is already an input to a prim::ConstantChunk node,
    // we cannot add a new prim::ConstantChunk node because of invariant (2).
    auto* chunked_tensor = producer->node()->input();
    if (auto existingFusedChunk = findFusedChunk(consumer, chunked_tensor)) {
      fuseChunkByReusingExistingFusedChunk(
          consumer, chunk, *existingFusedChunk);
      return consumer->reverseIterator();
    }

    // Move prim::ConstantChunk into the FusionGroup
    mergeNodeIntoGroup(consumer, chunk);
    chunk->destroy();
    return consumer->reverseIterator();
  }
559
sortReverseTopologicaltorch::jit::__anon688491630111::GraphFuser560 value_list sortReverseTopological(ArrayRef<Value*> inputs) {
561 value_list result;
562 for (auto i : inputs) {
563 if (i->node()->owningBlock() == block_) {
564 result.push_back(i);
565 }
566 }
567 // Sort in reverse topological order
568 std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
569 return a->node()->isAfter(b->node());
570 });
571 return result;
572 }
573
scanNodeForChunkstorch::jit::__anon688491630111::GraphFuser574 graph_node_list::iterator scanNodeForChunks(Node* consumer) {
575 if (consumer->kind() == prim::FusionGroup) {
576 auto inputs = sortReverseTopological(consumer->inputs());
577 for (auto producer : inputs) {
578 if (!canFuseChunk(consumer, producer)) {
579 continue;
580 }
581 return fuseChunk(consumer, producer);
582 }
583 }
584 return ++consumer->reverseIterator();
585 }
586
  // Insert aten::broadcast_tensors over `inputs` at the current insert point
  // and return the unpacked broadcasted values, one per input, in order.
  // All values created here get aliasing info registered in aliasDb_.
  at::ArrayRef<Value*> broadcast_tensors(value_list inputs) {
    AT_ASSERT(!inputs.empty());
    auto* g = inputs[0]->owningGraph();
    auto* input_list =
        g->insertNode(g->createList(TensorType::get(), inputs))->output();
    aliasDb_->createValue(input_list);
    auto* output_list = g->insert(aten::broadcast_tensors, {input_list});
    aliasDb_->createValue(output_list);
    auto* unpack_node = g->insertNode(
        g->create(prim::ListUnpack, {output_list}, inputs.size()));

    // We are doing:
    //   input_list = listConstruct(a, b, ...)
    //   output_list = broadcast_tensors(input_list)
    //   a_broadcasted, b_broadcasted = listUnpack(output_list)
    // `a_broadcasted` should receive the same aliasing info as `a`
    TORCH_INTERNAL_ASSERT(unpack_node->outputs().size() == inputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      Value* original_input = inputs[i];
      Value* broadcasted_output = unpack_node->outputs()[i];
      aliasDb_->copyValue(original_input, broadcasted_output);
    }

    return unpack_node->outputs();
  }
612
  // Replace every tensor input of `node` with an explicitly broadcasted
  // version, produced by broadcast_tensors inserted just before `node`.
  // Non-tensor inputs are left untouched.
  void insertExplicitBroadcast(Node* node) {
    WithInsertPoint insert_guard{node};
    auto tensors = tensorInputs(node);
    auto new_tensors = broadcast_tensors(std::move(tensors));

    // Replace tensors inputs with broadcasted values. new_tensors holds one
    // broadcasted value per tensor input, in the same order.
    auto new_tensors_it = new_tensors.begin();
    for (size_t i = 0; i < node->inputs().size(); ++i) {
      if (node->inputs()[i]->type()->isSubtypeOf(*TensorType::get())) {
        AT_ASSERT(new_tensors_it != new_tensors.end());
        node->replaceInput(i, *(new_tensors_it++));
      }
    }
  }
627
  // Replace a prim::ConstantChunk with an equivalent prim::BroadcastingChunk
  // node: same input, same attributes, outputs carrying over metadata,
  // aliasing info, and all uses. The original chunk is destroyed.
  Node* promoteChunkToBroadcastingChunk(Node* chunk) {
    AT_ASSERT(chunk->kind() == prim::ConstantChunk);

    size_t nchunks = chunk->i(attr::chunks);
    Node* bchunk =
        chunk->owningGraph()->create(prim::BroadcastingChunk, nchunks);
    bchunk->addInput(chunk->input());
    for (const auto i : c10::irange(nchunks)) {
      auto* old_output = chunk->outputs().at(i);
      auto* new_output = bchunk->outputs().at(i);
      new_output->copyMetadata(old_output);
      aliasDb_->replaceWithNewValue(old_output, new_output);
      old_output->replaceAllUsesWith(new_output);
    }
    bchunk->copyAttributes(*chunk);
    bchunk->insertAfter(chunk);
    chunk->destroy();
    return bchunk;
  }
647
648 // in places where op can be fused into a consumer but chunk is in the way
649 // distribute chunk to op's operands:
650 // replace a,b = chunk(op(x,y,z)) with:
651 // x', y', z' = broadcast_tensors([x, y, z])
652 // x0,x1 = chunk(x') (x0 has a's type, x1 has b's type)
653 // y0,y1 = chunk(y') (y0 has a's type, y1 has b's type)
654 // z0,z1 = chunk(z') (z0 has a's type, z1 has b's type)
655 // a = op(x0,y0,z0) (a,b have their same size but are now contiguous)
656 // b = op(x1,y1,x1)
657 //
658 // The graph fuser uses an intermediate prim::BroadcastingChunk node to
659 // represent this behavior concisely. BroadcastingChunk(x, y, z) broadcasts
660 // all of its inputs and then chunks each input, in order, the same way.
661 // The above graph is equivalent to:
662 // x0, x1, y0, y1, z0, z1 = BroadcastingChunk(x, y, z)
663 // a = op(x0,y0,z0)
664 // b = op(x1,y1,x1)
665 //
666 // NB: The explicit broadcast is important for correctness.
667 // Let's say we have:
668 // %z = aten::mul(%x, %y)
669 // %z.1, %z.2 = aten::chunk(%z, ...)
670 // ... = prim::FusionGroup(%z.1, %z.2, ...)
671 // It's possible that %x and %y do not have the same size as %z and
672 // need to be expanded first so that they can be chunked like %z
673 //
674 // NB: Chunk motion only occurs with fusable consumers, which implies
675 // that there is always some other operation, e.g., a+b, that happens
676 // after the chunk, and will be put into the fusion group. This is
677 // important, because distributing the chunk changes the contiguity
678 // of a and b, and so the results would be invalid, except that we know
679 // that simple_mappable operations will restore contiguity before
680 // we exit the fusion group.
681 //
682 // NB: The intermediate BroadcastingChunk is important for moving chunks past
683 // more than one operation: the graph fuser is not able to easily move
684 // operations around broadcast_tensors + chunk nodes. Let f, g, h be fusible
685 // ops
686 // x = f(v, w)
687 // z = g(x, y)
688 // a, b = chunk(z)
689 // c = h(a, b)
690 // becomes (with the broadcast_tensors + chunk approach):
691 // x = f(v, w)
692 // x', y' = broadcast_tensors([x, y])
693 // ax, bx = chunk(x')
694 // ay, by = chunk(y')
695 // a = g(ax, ay)
696 // b = g(bx, by)
697 // c = h(a, b)
698 // The broadcast_tensors node makes it harder to move f into the resulting
699 // FusionGroup of g, g, and h. Keeping the broadcasting and chunk behavior
700 // together results in:
701 // x = f(v, w)
702 // ax, bx, ay, by = BroadcastingChunk(x, y)
703 // a = g(ax, ay)
704 // b = g(bx, by)
705 // c = h(a, b)
706 // making it easier to move f after the BroadcastingChunk:
707 // ay, by, av, bv, aw, bw = BroadcastingChunk(y, v, w)
708 // ax = f(av, aw)
709 // by = f(bv, bw)
710 // a = g(ax, ay)
711 // b = g(bx, by)
712 // c = h(a, b)
713
tryToMoveChunktorch::jit::__anon688491630111::GraphFuser714 bool tryToMoveChunk(Node* consumer, Value* producer) {
715 // is the output from a chunk/bchunk node?
716 auto* chunk = producer->node();
717 if (chunk->kind() != prim::ConstantChunk &&
718 chunk->kind() != prim::BroadcastingChunk)
719 return false;
720
721 // try to find a producer to move after the chunk/bchunk. The producer must
722 // be fusible into the consumer.
723 auto it = std::find_if(
724 chunk->inputs().begin(),
725 chunk->inputs().end(),
726 [&](Value* producer_for_chunk) {
727 return isFusableMap(producer_for_chunk->node()) &&
728 allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
729 });
730 if (it == chunk->inputs().end()) {
731 return false;
732 }
733 Value* producer_for_chunk = *it;
734 size_t producer_index = it - chunk->inputs().begin();
735
736 // all uses of the chunk must be in this consumer
737 for (auto s : chunk->outputs()) {
738 for (auto u : s->uses()) {
739 if (u.user != consumer)
740 return false;
741 }
742 }
743 // multiple return operators
744 Node* producer_for_chunk_node = producer_for_chunk->node();
745 AT_ASSERT(producer_for_chunk_node->outputs().size() == 1);
746
747 // Convert chunk to bchunk, if it isn't one already. The bchunk represents a
748 // broadcast and one or more chunk operations.
749 auto* bchunk = chunk;
750 if (chunk->kind() == prim::ConstantChunk) {
751 bchunk = promoteChunkToBroadcastingChunk(chunk);
752 }
753 size_t nchunks = bchunk->i(attr::chunks);
754 WithInsertPoint guard(bchunk->next());
755
756 std::vector<Value*> producer_chunk_outputs;
757 for (const auto i : c10::irange(nchunks)) {
758 producer_chunk_outputs.push_back(
759 bchunk->output(nchunks * producer_index + i));
760 }
761
762 // Add each of op's operands to the bchunk node.
763 // chunked_inputs[input_nr][chunk_output_idx]
764 // = Node* for chunk_output_idx'th output of the chunk(inputs[input_nr])
765 std::vector<std::vector<Value*>> chunked_inputs;
766
767 for (auto input : producer_for_chunk_node->inputs()) {
768 // XXX: we only work with pointwise ops in here, so we know it is valid to
769 // push the concat only through tensor arguments (and all other args can
770 // be safely ignored).
771 if (!input->type()->isSubtypeOf(*TensorType::get()))
772 continue;
773
774 // if 'input' is already an input to the bchunk, reuse it.
775 auto bchunk_inputs = bchunk->inputs();
776 auto it = std::find(bchunk_inputs.begin(), bchunk_inputs.end(), input);
777 if (it != bchunk_inputs.end()) {
778 chunked_inputs.emplace_back();
779 auto input_index = std::distance(bchunk_inputs.begin(), it);
780 for (const auto chunki : c10::irange(nchunks)) {
781 chunked_inputs.back().push_back(
782 bchunk->outputs().at(nchunks * input_index + chunki));
783 }
784 continue;
785 }
786
787 // NB: I decided not to use cloneFrom here, because if we make cloneFrom
788 // copy selects one day, it is definitely not what you want here (selects
789 // have different types).
790 // TODO: Perhaps we should use cloneFrom now, as it seems unlikely
791 // to copy select nodes now that we have refactored to have a Value
792 // distinct from Node.
793 bchunk->addInput(input);
794 chunked_inputs.emplace_back(); // alas, to not be C++17
795 for (auto chunk_sel : producer_chunk_outputs) {
796 Value* input_chunk_sel = bchunk->addOutput();
797 input_chunk_sel->setType(chunk_sel->type());
798 // Add a fresh value for each output element of the broadcasting chunk
799 // node. This is safe because it will be consumed only by the chunked
800 // ops.
801 aliasDb_->createValue(input_chunk_sel);
802 chunked_inputs.back().push_back(input_chunk_sel);
803 }
804 }
805
806 // apply the op to each chunk of the chunked operands,
807 // and then rewrite the graph to use them!
808 for (auto chunk_sel : producer_chunk_outputs) {
809 auto original_inputs = producer_for_chunk_node->inputs();
810 Node* chunked_op =
811 block_->owningGraph()->create(producer_for_chunk_node->kind());
812 chunked_op->copyAttributes(*producer_for_chunk_node);
813 chunked_op->output()->setType(chunk_sel->type());
814 auto chunked_inputs_it = chunked_inputs.begin();
815 for (Value* original_input : original_inputs) {
816 if (original_input->type()->isSubtypeOf(*TensorType::get())) {
817 AT_ASSERT(chunked_inputs_it != chunked_inputs.end());
818 chunked_op->addInput(
819 // NOLINTNEXTLINE(clang-analyzer-core.DivideZero)
820 chunked_inputs_it->at(chunk_sel->offset() % nchunks));
821 ++chunked_inputs_it;
822 } else {
823 chunked_op->addInput(original_input);
824 }
825 }
826 bchunk->owningGraph()->insertNode(chunked_op);
827 chunk_sel->replaceAllUsesWith(chunked_op->output());
828 aliasDb_->replaceWithNewValue(chunk_sel, chunked_op->output());
829 }
830
831 bchunk->removeInput(producer_index);
832 for (const auto i : c10::irange(nchunks)) {
833 (void)i; // Suppress unused variable warning
834 bchunk->eraseOutput(nchunks * producer_index);
835 }
836
837 // The output of producer_for_chunk_node could have been used in some
838 // aten::size operators, so we need to clean those up as well (we simply
839 // broadcast all its tensor inputs).
840 // We need to insert these early in the graph, i.e. immediately after
841 // the producer_for_chunk_node as we will have the _size_if_not_same
842 // that may be before the bchunk.
843 WithInsertPoint guard2(producer_for_chunk_node);
844 auto size_calc_uses = producer_for_chunk_node->output()->uses();
845 if (!size_calc_uses.empty()) {
846 auto tensor_inputs = filter(
847 producer_for_chunk_node->inputs(),
848 [](Value* v) { return v->type()->isSubtypeOf(*TensorType::get()); });
849 auto tensor_sizes = fmap(tensor_inputs, [&](Value* v) {
850 Value* output = v->owningGraph()->insert(aten::size, {v});
851 aliasDb_->createValue(output);
852 return output;
853 });
854 AT_ASSERT(!tensor_sizes.empty());
855 Value* output_size = tensor_sizes.size() == 1
856 ? tensor_sizes[0]
857 : broadcastSizes(tensor_sizes, aliasDb_);
858 for (Use u : size_calc_uses) {
859 u.user->output()->replaceAllUsesWith(output_size);
860 u.user->destroy();
861 }
862 }
863 producer_for_chunk_node->destroy();
864 return true;
865 }
866
867 // returns where to continue scanning, and whether any fusion was made
scanNodetorch::jit::__anon688491630111::GraphFuser868 std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
869 if (isFusable(consumer)) {
870 // handle inputs in reverse topological order as well...
871 // otherwise in f(a,a+b) it will appear a is used twice if we consider
872 // the f-a fusion before the f-(a+b) fusion first.
873 auto inputs = sortReverseTopological(consumer->inputs());
874 for (auto producer : inputs) {
875 if (tryToMoveChunk(consumer, producer)) {
876 // the chunk before this consumer was re-arranged to allow fusion,
877 // we scan this consumer again to perform the fusion
878 return std::make_pair(consumer->reverseIterator(), true);
879 }
880 auto fusion_group = tryFuse(consumer, producer);
881 if (fusion_group) {
882 // after fusion, consumer moves into a FusionGroup, so inputs is no
883 // longer valid so we rescan the new FusionGroup for more fusions...
884 return std::make_pair(fusion_group.value()->reverseIterator(), true);
885 }
886 }
887 }
888 return std::make_pair(++consumer->reverseIterator(), false);
889 }
890
  // Rewrites every prim::BroadcastingChunk node in this block into an explicit
  // broadcast of its inputs followed by one prim::ConstantChunk node per
  // input. BroadcastingChunk is only an internal intermediate of this pass;
  // after this runs, none remain in the block.
  void replaceIntermediateBroadcastingChunks() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      auto* node = *it;
      ++it; // We might delete node, so increment the iterator now.
      if (node->kind() != prim::BroadcastingChunk) {
        continue;
      }
      auto* bchunk = node;
      // Materialize the broadcast so each input is expanded before chunking.
      insertExplicitBroadcast(bchunk);

      auto* graph = block_->owningGraph();
      size_t nchunks = bchunk->i(attr::chunks);
      WithInsertPoint guard(bchunk->next());

      // Split the bchunk into bchunks.inputs().size() number of chunk nodes.
      // The bchunk's outputs are laid out input-major: outputs
      // [input_offset * nchunks, (input_offset + 1) * nchunks) all correspond
      // to the input at input_offset.
      for (size_t input_offset = 0; input_offset < bchunk->inputs().size();
           input_offset++) {
        auto* input = bchunk->inputs().at(input_offset);

        Node* new_chunk =
            graph->insertNode(graph->create(prim::ConstantChunk, input, 0));
        new_chunk->copyAttributes(*bchunk);
        for (const auto output_offset : c10::irange(nchunks)) {
          auto new_output = new_chunk->addOutput();
          auto old_output =
              bchunk->outputs().at(input_offset * nchunks + output_offset);
          new_output->copyMetadata(old_output);
          // Keep the alias database consistent with the rewritten values.
          aliasDb_->replaceWithNewValue(old_output, new_output);
          old_output->replaceAllUsesWith(new_output);
        }
      }
      bchunk->destroy();
    }
  }
925
926 // Builds up expressions that compute shapes of all intermediates (and
927 // outputs) of the fusion group, based on the sizes of inputs. You should run
928 // DCE to remove those that you end up not using.
  // Builds up expressions that compute shapes of all intermediates (and
  // outputs) of the fusion group, based on the sizes of inputs. You should run
  // DCE to remove those that you end up not using.
  // Returns a map from values inside the fusion subgraph to values in the
  // outer graph holding their aten::size.
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
    // All shape computations are emitted right after the fusion group.
    WithInsertPoint insert_guard{fusion_group->next()};
    std::unordered_map<Value*, Value*> shape_of;

    Graph* graph = fusion_group->owningGraph();
    auto subgraph = fusion_group->g(attr::Subgraph);

    // Seed the map: the shape of every tensor input to the subgraph is the
    // aten::size of the corresponding outer input.
    auto inputs = fusion_group->inputs();
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
        Value* soutput = graph->insert(aten::size, {inputs[i]});
        aliasDb_->createValue(soutput);
        shape_of[sinputs[i]] = soutput;
      }
    }

    // When we have a guarantee that an output won't be removed, because it's
    // used in expressions that don't involve size checks, we can use its size
    // instead of computing a long chain of broadcasts, starting from the
    // beginning of the kernel.
    auto outputs = fusion_group->outputs();
    auto soutputs = subgraph->outputs();
    AT_ASSERT(outputs.size() == soutputs.size());
    for (const auto i : c10::irange(outputs.size())) {
      if (usedOnlyInSize(outputs[i]))
        continue;
      Value* soutput = graph->insert(aten::size, {outputs[i]});
      aliasDb_->createValue(soutput);
      shape_of[soutputs[i]] = soutput;
    }

    for (Node* n : subgraph->nodes()) {
      // XXX: Use of shape_of.emplace is crucial to the output shape
      // optimization! (emplace keeps the earlier, cheaper entry when a value
      // was already seeded above.)
      if (n->kind() == prim::FusedConcat) {
        // This is a bit more involved, because we have to account for the case
        // when inputs have different shapes, but fortunately those tensors are
        // always outputs, and so we can simply avoid replacing their queries,
        // because it won't help us.
        continue;
      }
      if (n->kind() == prim::Constant) {
        continue;
      }
      if (n->kind() == prim::ConstantChunk) {
        // prim::ChunkSizes produces two int-list outputs computed from the
        // input's size list: one for the regular chunks and one for the last
        // chunk (whose size may differ).
        Node* sizes_node = graph->insertNode(
            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
        sizes_node->i_(attr::dim, n->i(attr::dim));
        sizes_node->i_(attr::chunks, n->i(attr::chunks));
        for (Value* output : sizes_node->outputs()) {
          aliasDb_->createValue(output);
        }
        Value* regular_size = sizes_node->outputs().at(0);
        Value* last_size = sizes_node->outputs().at(1);
        regular_size->setType(ListType::ofInts());
        last_size->setType(ListType::ofInts());
        auto outputs = n->outputs();
        // All chunk outputs but the last share the regular size.
        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
          shape_of.emplace(o, regular_size);
        }
        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
        continue;
      }
      // Remaining (pointwise) ops: the output shape is the broadcast of the
      // shapes of the tensor inputs.
      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
        return v->type()->isSubtypeOf(*TensorType::get());
      });
      auto shapes =
          fmap(tensor_inputs, [&](Value* v) { return shape_of.at(v); });
      AT_ASSERT(!shapes.empty());
      shape_of.emplace(
          n->output(),
          shapes.size() == 1 ? shapes[0] : broadcastSizes(shapes, aliasDb_));
    }
    return shape_of;
  }
1006
  // Drops fusion-group outputs that are consumed only by aten::size calls:
  // each such size query is replaced with an equivalent shape expression
  // built by buildShapeExpressions, and the now-unused output is erased from
  // both the fusion node and its subgraph. No-op for non-FusionGroup nodes.
  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
    if (fusion_group->kind() != prim::FusionGroup)
      return;
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto shape_of = buildShapeExpressions(fusion_group);
    auto outputs = fusion_group->outputs().vec();
    auto soutputs = subgraph->outputs().vec();
    // XXX: Iterating in this order is not only good for performance reasons!
    // It is also crucial for correctness (i has to reflect the current true
    // index of outputs[i])! Erasing output i shifts later outputs down, so we
    // must walk from the back.
    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
      auto output = outputs[i];
      auto soutput = soutputs[i];
      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
        auto uses = output->uses();
        for (Use u : uses) {
          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
          u.user->destroy();
        }
        fusion_group->eraseOutput(i);
        subgraph->eraseOutput(i);
      }
    }
  }
1033
canFuseWithConcattorch::jit::__anon688491630111::GraphFuser1034 bool canFuseWithConcat(Value* producer, Node* before_check) {
1035 if (!isFusable(producer->node())) {
1036 return false;
1037 }
1038 // NB: it is important that this check happens after isFusable, which checks
1039 // that the blocks match, and it's not a special node like prim::Param
1040 if (!aliasDb_->couldMoveBeforeTopologically(
1041 producer->node(), before_check)) {
1042 return false;
1043 }
1044
1045 // If the number of kernel args could exceed the limit, skip.
1046 if ((before_check->inputs().size() + before_check->outputs().size() +
1047 producer->node()->inputs().size() +
1048 producer->node()->outputs().size()) > subgraph_arg_limit_) {
1049 return false;
1050 }
1051
1052 // Fusion groups can be merged with concat's group if and only if
1053 // the value they produce isn't already coming from a concat
1054 if (producer->node()->kind() == prim::FusionGroup) {
1055 auto subgraph = producer->node()->g(attr::Subgraph);
1056 auto* node = subgraph->outputs().at(producer->offset())->node();
1057 return node->kind() != prim::FusedConcat;
1058 }
1059 return true;
1060 }
1061
  // Replaces an aten::cat node (fed by a ListConstruct of tensors) with a
  // prim::FusedConcat node wrapped in a fresh singleton FusionGroup. The
  // original cat node and its list input are left in place; the caller is
  // responsible for rewiring uses and cleaning them up.
  Node* createFusedConcat(Node* node) {
    AT_ASSERT(node->kind() == aten::cat);

    Graph* graph = node->owningGraph();
    Node* list_construct = node->namedInput(attr::tensors)->node();
    int64_t dim = node->get<int64_t>(attr::dim).value();

    // FusedConcat takes the tensors directly (no list) and carries the
    // concatenation dimension as an attribute.
    Node* fused_cat = graph->create(prim::FusedConcat, list_construct->inputs())
                          ->i_(attr::dim, dim);
    fused_cat->insertBefore(list_construct);
    fused_cat->output()->copyMetadata(node->output());
    aliasDb_->copyValue(node->output(), fused_cat->output());

    // NB: this deletes the fused_cat node from the original graph
    return createSingletonFusionGroup(fused_cat);
  }
1078
  // Finds fusable aten::cat nodes and turns each into a FusedConcat fusion
  // group, then greedily pulls the cat's producers into that group. If no
  // producer could be fused in, the rewrite is rolled back and the plain cat
  // is kept.
  void fuseConcats() {
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();
         ++it) {
      Node* cat = *it;
      if (!isFusableCatNode(cat)) {
        continue;
      }
      Node* list_construct = cat->namedInput(attr::tensors)->node();
      Node* fused_cat = createFusedConcat(cat);
      Value* fused_cat_out = fused_cat->output();

      auto sorted_inputs = sortReverseTopological(fused_cat->inputs());
      size_t input_idx = 0;
      bool any_fused = false;
      while (input_idx < sorted_inputs.size()) {
        Value* input = sorted_inputs[input_idx++];
        if (!canFuseWithConcat(input, fused_cat)) {
          continue;
        }
        any_fused = true;
        auto maybe_group = tryFuse(fused_cat, input);
        AT_ASSERT(maybe_group && maybe_group == fused_cat);
        // We could have destroyed multiple inputs when performing this fusion,
        // so we have to recompute the list and iterate over it again.
        sorted_inputs = sortReverseTopological(fused_cat->inputs());
        input_idx = 0;
      }

      if (any_fused) {
        // Commit: route all uses of the original cat to the fused group and
        // delete the cat (plus its list input if now dead).
        cat->output()->replaceAllUsesWith(fused_cat_out);
        it.destroyCurrent();
        if (list_construct->output()->uses().empty()) {
          list_construct->destroy();
        }
      } else {
        // Roll back: nothing joined the group, keep the original cat.
        fused_cat->destroy();
      }
    }
  }
1118
optimizeFusedGraphstorch::jit::__anon688491630111::GraphFuser1119 void optimizeFusedGraphs() {
1120 for (Node* node : block_->nodes()) {
1121 if (node->kind() != prim::FusionGroup) {
1122 continue;
1123 }
1124 auto subgraph = node->g(attr::Subgraph);
1125 EliminateDeadCode(subgraph);
1126 EliminateCommonSubexpression(subgraph);
1127 ConstantPooling(subgraph);
1128 }
1129 }
1130
  // Entry point: iterates the fusion scan to a fixed point, then runs the
  // concat-fusion, subgraph-cleanup, and shape-expression phases, and
  // finally recurses into sub-blocks.
  void run() {
// TODO: old fuser is not maintained internally, somewhere it is being turned on
// inadvertently for certain workflows. make this a no-op until we identify
// location
#if defined(FBCODE_CAFFE2)
    return;
#endif

    // Run the pass until no changes are made.
    // This is necessary, because the algorithm can miss out on certain fusion
    // opportunities if ran only once. Consider this graph:
    //
    //   %1 = f(...)
    //   %2 = g(%1)
    //   %3 = h(%1)
    //   %4 = l(%3)
    //   return (%4, %2)
    //
    // where f, g, h, l are simple map ops.
    // The first iteration will fuse %4 and %3, and see that %1 is an input, but
    // can't be fused, because it has a different use before the fusion group
    // in our topological ordering. Then, %2 will be considered, and fused with
    // %1. If we do another iteration, the algorithm will consider the fusion of
    // these two groups and fix the situation.
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
        auto [tmp_it, changed] = scanNode(*it);
        it = tmp_it;
        any_changed |= changed;
      }
    }

    fuseConcats();

    optimizeFusedGraphs();

    // The graph fuser can add intermediate prim::BroadcastingChunk nodes.
    // Replace them with broadcasts + chunks.
    replaceIntermediateBroadcastingChunks();

    // Fuse starting chunks into the group.
    for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
      it = scanNodeForChunks(*it);
    }

    // Remove outputs that have been added only because we need their size
    for (Node* n : block_->nodes()) {
      removeOutputsUsedOnlyInSize(n);
    }

    // Recurse into nested blocks (e.g. loop and if bodies).
    for (Node* node : block_->nodes()) {
      for (Block* sub_block : node->blocks()) {
        GraphFuser(aliasDb_, sub_block, callback_, kind_, strict_fuser_check_)
            .run();
      }
    }
  }
1190 };
1191
// Simplifies prim::BroadcastSizes expressions emitted by the fuser:
// removes single-input (no-op) broadcasts, deduplicates repeated inputs,
// and folds a broadcast whose only use is another broadcast into that user.
// Recurses into sub-blocks.
void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) {
  auto nodes = block->nodes();
  for (auto it = nodes.begin(); it != nodes.end(); ++it) {
    Node* node = *it;
    for (Block* subblock : node->blocks()) {
      PeepholeOptimizeShapeExpressions(subblock, db);
    }
    if (node->kind() == prim::BroadcastSizes) {
      // Remove no-op broadcasts.
      if (node->inputs().size() == 1) {
        node->output()->replaceAllUsesWith(node->input());
        it.destroyCurrent();
        continue;
      }
      // Deduplicate inputs, but use their unique() values to ensure
      // this process only depends on the graph.
      std::map<size_t, Value*> unique_to_value;
      for (Value* input : node->inputs()) {
        unique_to_value.emplace(input->unique(), input);
      }
      if (unique_to_value.size() != node->inputs().size()) {
        std::vector<Value*> inputs;
        inputs.reserve(unique_to_value.size());
        for (auto& entry : unique_to_value) {
          inputs.push_back(entry.second);
        }
        if (inputs.size() == 1) {
          // Dedup left a single input: the broadcast is a no-op.
          node->output()->replaceAllUsesWith(inputs[0]);
        } else {
          WithInsertPoint insert_guard{node};
          node->output()->replaceAllUsesWith(broadcastSizes(inputs, db));
        }
        it.destroyCurrent();
        --it; // Revisit the node with deduplicated inputs
        continue;
      }
      // Compose simple chains of broadcasts into a single node.
      const auto& uses = node->output()->uses();
      if (uses.size() == 1 && uses[0].user->kind() == prim::BroadcastSizes) {
        Node* user = uses[0].user;
        user->removeInput(uses[0].offset);
        // NB: we don't care about deduplication in here, as we will visit user
        // later.
        for (Value* i : node->inputs()) {
          user->addInput(i);
        }
        it.destroyCurrent();
      }
    }
  }
}
1243
1244 } // anonymous namespace
1245
// Global flag for the legacy CPU fuser; off by default and mutated only via
// overrideCanFuseOnCPULegacy().
static bool cpu_fuser_enabled_legacy = false;
1247
// Returns whether the legacy CPU fuser is currently enabled.
bool canFuseOnCPULegacy() {
  return cpu_fuser_enabled_legacy;
}
1251
// Globally enables or disables the legacy CPU fuser.
void overrideCanFuseOnCPULegacy(bool value) {
  cpu_fuser_enabled_legacy = value;
}
1255
// Runs the legacy pointwise graph fuser over `graph`, then cleanup passes
// (CSE, DCE) and peephole simplification of the shape expressions the fuser
// emits. `strict_fuser_check` is forwarded to the GraphFuser.
void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
  AliasDb db(graph);
  GraphFuser(&db, graph->block(), strict_fuser_check).run();
  Lint(&db);
  // After FuseGraph some common subexpressions may come back
  EliminateCommonSubexpression(graph);
  // We might have emitted a fair amount of useless shape propagating code, so
  // remove it
  EliminateDeadCode(graph);
  // Improve the quality of shape propagation code that was left
  PeepholeOptimizeShapeExpressions(graph->block(), &db);
}
1268
// Runs the graph fuser with a caller-supplied fusability predicate `fn` and a
// custom node kind for the resulting groups. `arg_limit` caps the number of
// arguments a fusion group may accumulate.
void CustomFuseGraph(
    std::shared_ptr<Graph>& graph,
    const std::function<bool(Node*)>& fn,
    Symbol kind,
    size_t arg_limit) {
  AliasDb db(graph);
  auto g = GraphFuser(
      &db,
      graph->block(),
      // A node is fusable if the callback accepts it, or if it is already a
      // group of the custom kind.
      [=](GraphFuser* gf, Node* n) { return fn(n) || n->kind() == kind; },
      kind);
  g.setInputArgLimit(arg_limit);
  g.run();
  Lint(&db);
}
1284
1285 } // namespace torch::jit
1286