#include <torch/csrc/jit/passes/tensorexpr_fuser.h>

#include <ATen/core/interned_strings.h>
#include <ATen/core/symbol.h>
#include <ATen/record_function.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/jit_opt_limit.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <utility>

C10_DEFINE_bool(
    torch_jit_disable_cat,
    false,
    "disable aten::cat in TE fusion groups");

C10_DEFINE_bool(
    torch_jit_enable_dynamic_shape_fusion,
    false,
    "enable TE fusion using dynamic shapes");

namespace torch::jit {

static bool texpr_reductions_enabled = false;

static bool isSupportedForBlock(Node* node) {
  switch (node->kind()) {
    case aten::add:
    case aten::mul:
      return true;
    default:
      return false;
  }
}

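// Returns true if every use of `v` is an aten::size call, i.e. only the
// value's shape is needed, not its data.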
bool usedOnlyInSize(Value* v) {
  const auto& uses = v->uses();
  return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
    return u.user->matches("aten::size(Tensor self) -> int[]");
  });
}

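// Emits a prim::BroadcastSizes node over the given size lists and registers
// its output with the alias database; the output is the broadcasted shape.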
Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) {
  AT_ASSERT(!sizes.empty());
  Graph* graph = sizes[0]->owningGraph();
  Node* broadcast_n =
      graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
  broadcast_n->output()->setType(ListType::ofInts());
  db->createValue(broadcast_n->output());
  return broadcast_n->output();
}

namespace tensorexpr {

OperatorSet& getCustomOperatorSet() {
  static OperatorSet _g_custom_operator_set{};
  return _g_custom_operator_set;
}

static const OperatorSet& supported_non_eltwise_set() {
  // clang-format off
  static const OperatorSet supported_non_eltwise_set{
      "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
      "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
      "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
      "aten::matmul(Tensor self, Tensor other) -> Tensor",
  };
  // clang-format on
  return supported_non_eltwise_set;
};

bool isSupported(Node* node) {
  // For Block codegen we allow limited ops.
  if (tensorexpr::getTEGenerateBlockCode()) {
    return isSupportedForBlock(node);
  }

  static const OperatorSet supported_reduction_set{
      "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
      "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
      "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
      "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
  };
  static const OperatorSet supported_misc_set{
      "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
      "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
  };
  // clang-format on

  if (get_tensorexpr_elementwise_set().contains(node) ||
      node->isMemberOf(supported_non_eltwise_set()) ||
      node->isMemberOf(supported_misc_set) ||
      node->isMemberOf(getCustomOperatorSet()) ||
      (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
    // We only insert guards on Tensor types, so we rely on the output
    // of a node being uniquely determined by its input types.
    // Bail if any non-Tensor input affects the output type
    // and cannot be reasoned about statically.

    // Value is either an int or a float (can occur from .item())
    for (Value* v : node->inputs()) {
      if (v->type()->cast<NumberType>()) {
        return false;
      }
    }

    // non-const dtype / device
    for (auto arg_name : {"dtype", "device"}) {
      if (auto index = node->schema().argumentIndexWithName(arg_name)) {
        if (!toIValue(node->input(*index))) {
          return false;
        }
      }
    }

    if (FLAGS_torch_jit_disable_cat && node->kind() == aten::cat) {
      return false;
    }

    return true;
  }

  // unschematized ops
  switch (node->kind()) {
    case prim::ConstantChunk:
    case prim::ListConstruct:
    case prim::TensorExprGroup:
      return true;
  }

  return false;
}
} // namespace tensorexpr

static bool texpr_fuser_enabled_ = true;

void setTensorExprFuserEnabled(bool val) {
  texpr_fuser_enabled_ = val;
}

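// The PYTORCH_TENSOREXPR environment variable overrides the programmatic
// flag: "0" disables the fuser, any other value enables it.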
bool tensorExprFuserEnabled() {
  static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR");
  if (!enable_c_str) {
    return texpr_fuser_enabled_;
  }
  if (std::string(enable_c_str) == "0") {
    return false;
  }
  return true;
}

bool tensorExprDynamicShapeFusionEnabled() {
  return FLAGS_torch_jit_enable_dynamic_shape_fusion;
}

void setTensorExprDynamicShapeFusionEnabled(bool val) {
  FLAGS_torch_jit_enable_dynamic_shape_fusion = val;
}

bool setTexprReductionsEnabled(bool value) {
  bool old_value = texpr_reductions_enabled;
  texpr_reductions_enabled = value;
  return old_value;
}

bool texprReductionsEnabled() {
  return texpr_reductions_enabled;
}

static void removeProfileNodesAndSpecializeTypes(Block* b) {
  for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
    if (it->kind() == prim::profile) {
      GRAPH_DEBUG("Removing prim::profile: %", it->output()->debugName());
      it->output()->replaceAllUsesWith(it->input());
      auto profiled_type = it->ty(attr::profiled_type)->expect<TensorType>();

      TensorTypePtr input_tensor_type = nullptr;
      bool input_is_optional = false;
      if (it->input()->type()->kind() == c10::TypeKind::TensorType) {
        input_tensor_type = it->input()->type()->expect<TensorType>();
      } else {
        input_tensor_type = it->input()
                                ->type()
                                ->expectRef<OptionalType>()
                                .getElementType()
                                ->expect<TensorType>();
        input_is_optional = true;
      }

      if (input_is_optional) {
        it.destroyCurrent();
        continue;
      }

      // A value can be profiled with differently typed uses.
      // This can occur from:
      // - having a use which is not executed, so the type will be
      //   TensorType::get()
      // - control-flow that depends on tensor type:
      //   if x.size() == 2 op(x) else op(x)
      // - mutation of the value on a field represented in the tensor type
      //   op(x); x.resize_([...]); op(x)

      // The most common case today with num_profiles = 1 is from the first
      // case. Here we can just ignore non-profiled uses, and choose any of the
      // profiled uses. Because we guard all tensor types in the runtime, even
      // if we set a Value to have a profiled type from one use and then execute
      // a use with a different profiled type, we will still be correct.
      // In the future we could consider unifying the types of uses, or adding a
      // type refinement node so uses can have the correct corresponding type.
      if (profiled_type == TensorType::get()) {
        continue;
      }

      // If we encounter non-identical profiled types for the same value, merge
      // them.  This situation can happen if, e.g., loop unrolling duplicates
      // profiled types in a loop body in a manner that isn't logically
      // consistent (see TestTEFuser.test_unrolled_cat).
      if (input_tensor_type == TensorType::get()) {
        it->input()->setType(profiled_type);
      } else {
        it->input()->setType(input_tensor_type->merge(*profiled_type));
      }

      it.destroyCurrent();
    } else {
      for (Block* ib : it->blocks()) {
        removeProfileNodesAndSpecializeTypes(ib);
      }
    }
  }
}

void RemoveProfileNodesAndSpecializeTypes(std::shared_ptr<Graph>& graph) {
  GRAPH_DEBUG("Before removeProfileNodesAndSpecializeTypes:\n", *graph);
  removeProfileNodesAndSpecializeTypes(graph->block());
  GRAPH_DEBUG("After removeProfileNodesAndSpecializeTypes:\n", *graph);
}

bool hasTensorTypeSpecialization(Value* v) {
  if (!v->type()->cast<TensorType>()) {
    return false;
  }
  // Constants & TensorExprGroup will always produce specialized tensor type,
  // TypeCheck are inserted by this pass and only used by fusion groups that
  // insert proper guards
  if (v->node()->kind() == prim::Constant ||
      v->node()->kind() == prim::TypeCheck ||
      v->node()->kind() == prim::TensorExprGroup) {
    return false;
  }
  if (v->type() == TensorType::get()) {
    return false;
  }
  return true;
}

static void removeTensorTypeSpecialization(Value* v) {
  if (hasTensorTypeSpecialization(v)) {
    v->setType(TensorType::get());
  }
}

void removeTensorTypeSpecializations(Block* block) {
  for (Value* v : block->inputs()) {
    removeTensorTypeSpecialization(v);
  }
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      removeTensorTypeSpecializations(b);
    }
    for (Value* v : n->outputs()) {
      removeTensorTypeSpecialization(v);
    }
  }
}

void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph) {
  removeTensorTypeSpecializations(graph->block());
}

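// Wraps the guarded fusion node in a prim::If controlled by a type-check node
// of the given `kind`: the true branch runs the fusion group on the checked
// inputs, the false branch falls back to an unfused copy of the subgraph.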
void insertTypeGuard(
    Node* guarded_node,
    tensor_type_converter_t type_converter,
    Symbol kind) {
  GRAPH_DEBUG("Inserting a typecheck guard for a node", *guarded_node);
  auto subgraph = SubgraphUtils::getSubgraph(guarded_node);

  // Fixup types of the subgraph inputs
  std::vector<Value*> inputs_to_check;
  std::vector<TypePtr> guard_types;
  for (Value* input : guarded_node->inputs()) {
    // We only check inputs of the guarded nodes and expect the user to infer
    // intermediate and output shapes
    if (!input->type()->cast<TensorType>()) {
      continue;
    }

    // fusion outputs are already guarded
    if (input->node()->kind() == prim::Constant ||
        input->node()->kind() == prim::FusionGroup) {
      continue;
    }
    inputs_to_check.push_back(input);
    guard_types.emplace_back(
        type_converter(input->type()->expect<TensorType>()));
  }
  if (inputs_to_check.empty()) {
    return;
  }

  // Add prim::TypeCheck node
  //
  // TypeCheck nodes look like the following:
  //   %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool =
  //   prim::TypeCheck(%inp1 : Tensor, %inp2 : Tensor)
  //
  // They have N inputs whose types we are going to check and N+1 outputs. The
  // first N outputs specify expected types and the (N+1)-th output holds the
  // result of the check (bool).
  Node* typecheck_node =
      guarded_node->owningGraph()
          ->create(kind, inputs_to_check, inputs_to_check.size() + 1)
          ->insertBefore(guarded_node);
  typecheck_node->tys_(attr::types, std::move(guard_types));
  Value* typecheck_result = typecheck_node->output(inputs_to_check.size());

  std::unordered_map<Value*, Value*> typechecked_inputs;
  for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
    typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i);
  }

  // Fixup types of the typecheck node outputs, which are used by the op in
  // execution
  typecheck_node->output(inputs_to_check.size())->setType(BoolType::get());
  for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
    typecheck_node->output(i)->setType(typecheck_node->input(i)->type());
  }

  // Insert if
  auto versioning_if =
      guarded_node->owningGraph()
          ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
          ->insertAfter(typecheck_node);
  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
  }
  auto true_block = versioning_if->addBlock();
  auto false_block = versioning_if->addBlock();

  // Fill in the false block. It should contain the unoptimized
  // copy of the fused subgraph.
  WithInsertPoint guard(false_block->return_node());
  const auto subgraph_outputs = insertGraph(
      *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
  for (Value* output : subgraph_outputs) {
    false_block->registerOutput(output);
  }

  // types get copied to the fallback graph, so remove specializations before
  // replacing
  removeTensorTypeSpecializations(false_block);
  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());

  // Fill in the true block. It has all inputs type-checked and its
  // body should be the fusion group node.
  guarded_node->moveBefore(true_block->return_node());
  for (size_t idx = 0; idx < guarded_node->inputs().size(); ++idx) {
    if (typechecked_inputs.count(guarded_node->input(idx))) {
      guarded_node->replaceInput(
          idx, typechecked_inputs.at(guarded_node->input(idx)));
    }
  }
  for (Value* output : guarded_node->outputs()) {
    true_block->registerOutput(output);
  }
}

namespace {
bool has_unsupported_pin_memory(const Node* node) {
  // can't support non-constant pin_memory or pin_memory=True
  if (auto maybe_index = node->schema().argumentIndexWithName("pin_memory")) {
    int index = *maybe_index;
    auto inp = node->input(index);
    if (inp->type() != NoneType::get() &&
        constant_as<bool>(inp).value_or(true)) {
      return true;
    }
  }
  return false;
}
} // namespace

class TensorExprFuser {
 public:
  TensorExprFuser(
      std::shared_ptr<Graph> graph,
      size_t min_group_size,
      bool add_composed_op,
      bool fuse_to_dynamic_shapes)
      : graph_(std::move(graph)),
        min_group_size_(min_group_size),
        add_composed_op_(add_composed_op),
        fuse_to_dynamic_shapes_(fuse_to_dynamic_shapes) {
    parseTENotFuseOption();
  }

  // Builds up expressions that compute shapes of all intermediates (and
  // outputs) of the fusion group, based on the sizes of inputs. You should run
  // DCE to remove those that you end up not using.
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
    GRAPH_DUMP("buildShapeExpressions for ", fusion_group->g(attr::Subgraph));
    WithInsertPoint insert_guard{fusion_group->next()};
    std::unordered_map<Value*, Value*> shape_of;

    Graph* graph = fusion_group->owningGraph();
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto inputs = fusion_group->inputs();
    auto sinputs = subgraph->inputs();
    AT_ASSERT(inputs.size() == sinputs.size());
    for (const auto i : c10::irange(inputs.size())) {
      if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
        Value* soutput = graph->insert(aten::size, {inputs[i]});
        aliasDb_->createValue(soutput);
        GRAPH_DEBUG(
            "Adding a mapping for %",
            sinputs[i]->debugName(),
            " ",
            getHeader(soutput->node()));
        shape_of[sinputs[i]] = soutput;
      }
    }

    // When we have a guarantee that an output won't be removed, because it's
    // used in expressions that don't involve size checks, we can use its size
    // instead of computing a long chain of broadcasts, starting from the
    // beginning of the kernel.
    auto outputs = fusion_group->outputs();
    auto soutputs = subgraph->outputs();
    AT_ASSERT(outputs.size() == soutputs.size());
    for (const auto i : c10::irange(outputs.size())) {
      if (usedOnlyInSize(outputs[i]))
        continue;
      Value* soutput = graph->insert(aten::size, {outputs[i]});
      aliasDb_->createValue(soutput);
      shape_of[soutputs[i]] = soutput;
    }

    for (Node* n : subgraph->nodes()) {
      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
        return v->type()->isSubtypeOf(*TensorType::get());
      });
      GRAPH_DEBUG("Building sizes for ", getHeader(n));
      bool all_inputs_have_sizes = true;
      auto shapes = fmap(tensor_inputs, [&](Value* v) {
        GRAPH_DEBUG("Getting aten::size for %", v->debugName());
        all_inputs_have_sizes &= shape_of.count(v);
        return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr;
      });
      if (!all_inputs_have_sizes) {
        GRAPH_DEBUG(
            "Not all tensor arguments have sizes available to compute the broadcasted size",
            getHeader(n));
        continue;
      }

      if (n->kind() == prim::ConstantChunk) {
        Node* sizes_node = graph->insertNode(
            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
        sizes_node->i_(attr::dim, n->i(attr::dim));
        sizes_node->i_(attr::chunks, n->i(attr::chunks));
        for (Value* output : sizes_node->outputs()) {
          aliasDb_->createValue(output);
        }
        Value* regular_size = sizes_node->outputs().at(0);
        Value* last_size = sizes_node->outputs().at(1);
        regular_size->setType(ListType::ofInts());
        last_size->setType(ListType::ofInts());
        auto outputs = n->outputs();
        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
          shape_of.emplace(o, regular_size);
        }
        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
        continue;
      }

      // We only support shape calculations for elementwise ops, some
      // non-elementwise ops like batch_norm, conv, and matmul, and
      // a few exceptions (e.g. prim::ConstantChunk, etc.) listed above.
      if (!(get_tensorexpr_elementwise_set().contains(n)) &&
          !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) {
        continue;
      }

      shape_of.emplace(
          n->output(),
          shapes.size() == 1 ? shapes[0]
                             : broadcastSizes(shapes, aliasDb_.get()));
    }
    return shape_of;
  }

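  // Drop fusion group outputs that are only consumed by aten::size: the size
  // uses are rewired to shape expressions built outside the group by
  // buildShapeExpressions, and the corresponding outputs are erased.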
  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
    if (fusion_group->kind() != prim::TensorExprGroup)
      return;
    auto subgraph = fusion_group->g(attr::Subgraph);

    auto shape_of = buildShapeExpressions(fusion_group);
    auto outputs = fusion_group->outputs().vec();
    auto soutputs = subgraph->outputs().vec();
    // XXX: Iterating in this order is not only good for performance reasons!
    // It is also crucial for correctness (i has to reflect the current true
    // index of outputs[i])!
    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
      auto output = outputs[i];
      auto soutput = soutputs[i];
      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
        auto uses = output->uses();
        for (Use u : uses) {
          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
          u.user->destroy();
        }
        fusion_group->eraseOutput(i);
        subgraph->eraseOutput(i);
      }
    }
  }

  void run() {
    aliasDb_ = std::make_unique<AliasDb>(graph_);
    RemoveRedundantProfiles(graph_);
    GRAPH_DUMP("After removing redundant profile nodes: ", graph_);
    createFusionGroups(graph_->block());
    GRAPH_DUMP("After creating fusion groups: ", graph_);
    // we maintain alias db correctness during initial fusion, but it is
    // difficult to maintain correctness after inlining, so inline only after
    // fusion is done.
    inlineSmallFusionGroups(graph_->block());
    GRAPH_DUMP("After inlining small fusion groups: ", graph_);
    if (fuse_to_dynamic_shapes_) {
      VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled" << '\n';
      generalizeFusionGroups(graph_->block());
      GRAPH_DUMP("After generalizing fusion groups: ", graph_);
    } else {
      prepareFusionGroupAndGuardOutputs(graph_->block());
      GRAPH_DUMP("After guarding fusion groups: ", graph_);
    }
  }

 private:
  Node* getOrCreateTensorExprSubgraph(Node* n) {
    if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::TensorExprGroup) {
      return n;
    }
    GRAPH_UPDATE("Creating a tensorexpr::Group node from: ", *n);
    return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
        n, prim::TensorExprGroup, *aliasDb_);
  }

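  // Collect the subset of `inputs` whose producing nodes live in block `b`,
  // ordered so that later-defined values come first (reverse topological
  // order).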
  value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* b) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == b) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  // Create a fusion group starting from the node N.
  // We then try to pull inputs into the fusion group and repeat that process
  // until there is nothing we can pull in.
  std::pair<graph_node_list::iterator, bool> createFusionGroup(
      Node* fusion_node) {
    // Allow single-node groups containing conv2d, since we'll only select
    // those in cases where the tensorexpr implementation is faster than the
    // aten implementation.
    if (min_group_size_ == 1 || fusion_node->kind() == aten::conv2d) {
      fusion_node = getOrCreateTensorExprSubgraph(fusion_node);
    }

    GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n");
    auto inputs = sortReverseTopological(
        fusion_node->inputs(), fusion_node->owningBlock());
    for (auto input : inputs) {
      debugDumpFusionGroup("Current fusion group: ", fusion_node);
      GRAPH_DEBUG("Trying to merge: ", *input->node());
      if (auto maybe_fusion_group = tryMerge(fusion_node, input->node())) {
        // we successfully merged, so the new group's `inputs` may have
        // changed. So rescan the new group for more merging opportunities.
        return std::make_pair(
            maybe_fusion_group.value()->reverseIterator(), true);
      }
    }

    return std::make_pair(++fusion_node->reverseIterator(), false);
  }

  static void debugDumpFusionGroup(const std::string& msg, Node* n) {
    GRAPH_DEBUG(msg, *n);
    if (n->kind() == prim::TensorExprGroup) {
      GRAPH_DEBUG(*n->g(attr::Subgraph));
    }
  }

  // Ops that are no-ops in eager shouldn't be outputs of fusion groups because
  // that would degrade perf and change aliasing relationships
  static bool unexecutedEagerOp(Node* n) {
    if (n->kind() != aten::to &&
        n->kind() != aten::_autocast_to_reduced_precision &&
        n->kind() != aten::_autocast_to_full_precision) {
      return false;
    }

    return *n->input(0)->type()->expect<TensorType>() ==
        *n->output()->type()->expect<TensorType>();
  }

  std::pair<graph_node_list::iterator, bool> scanNode(Node* n) {
    GRAPH_DEBUG("Considering node:", *n)

    if (!canHandle(n)) {
      return std::make_pair(++n->reverseIterator(), false);
    }
    // There are some nodes that we can support, but we don't want to start a
    // fusion group from - skip them.
    if (n->kind() == prim::ListConstruct || n->kind() == aten::slice ||
        n->kind() == aten::unsqueeze || n->kind() == prim::ConstantChunk ||
        n->kind() == prim::Constant || unexecutedEagerOp(n)) {
      return std::make_pair(++n->reverseIterator(), false);
    }
    return createFusionGroup(n);
  }

  // Merge fusible nodes into subgraphs in prim::TensorExprGroup nodes.
  void createFusionGroups(Block* block) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
        auto [tmp_it, changed] = scanNode(*it);
        it = tmp_it;
        any_changed |= changed;
      }
    }

    for (Node* n : block->nodes()) {
      for (Block* b : n->blocks()) {
        createFusionGroups(b);
      }
    }

    // Try to merge adjacent fusion groups together. Because we have only merged
    // by looking at graph inputs, without this we would not attempt to merge
    // adjacent fusion groups that don't have a dependency on each other.

    std::vector<Node*> initial_fusion_groups;
    for (Node* n : block->nodes()) {
      if (n->kind() == prim::TensorExprGroup) {
        initial_fusion_groups.push_back(n);
      }
    }

    Node* prev_fusion_group =
        !initial_fusion_groups.empty() ? initial_fusion_groups[0] : nullptr;

    for (const auto i : c10::irange(1, initial_fusion_groups.size())) {
      // Try merging the just created fusion group into the previous one.
      // If it did not work, then put the previous fusion group into
      // fusion_groups vector - we will not touch it anymore in this loop.
      // If merging succeeded, save the merged group as the "previous" fusion
      // group so that we can try to merge the next one into it.

      Node* fusion_group = initial_fusion_groups[i];
      debugDumpFusionGroup(
          "Trying to merge into the previous fusion group: ",
          prev_fusion_group);
      if (auto merged_fusion_group =
              tryMerge(prev_fusion_group, fusion_group)) {
        prev_fusion_group = *merged_fusion_group;
        debugDumpFusionGroup(
            "Successfully merged into the previous fusion group: ",
            prev_fusion_group);
      } else {
        GRAPH_DEBUG("Cannot merge into the previous fusion group");
        prev_fusion_group = fusion_group;
      }
    }
  }

  size_t blockSize(Block* block) {
    size_t num = 0;
    for (Node* n : block->nodes()) {
      // Don't count prim::Constants and prim::ListConstructs as these are nodes
      // we only pull in along with another, "main", node. E.g. the
      // ListConstruct nodes would also be pulled into a fusion group if they
      // are inputs of an aten::cat node.
      if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) {
        continue;
      }
      for (Block* b : n->blocks()) {
        num += blockSize(b);
      }
      num++;
    }
    return num;
  }

  bool hasConv(Block* block) {
    for (Node* n : block->nodes()) {
      if (n->kind() == aten::conv2d) {
        return true;
      }
    }
    return false;
  }

  bool inlineIfTooSmall(Node* n) {
    if (n->kind() != prim::TensorExprGroup) {
      return false;
    }
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t num_nodes = blockSize(subgraph->block());
    // Allow small subgraphs containing conv2d, since we'll only select those
    // in cases where the tensorexpr implementation is faster than the aten
    // implementation.
    if (num_nodes < min_group_size_ && !hasConv(subgraph->block())) {
      GRAPH_UPDATE("Fusion group is too small, unmerging: ", *n);
      SubgraphUtils::unmergeSubgraph(n);
      return true;
    }
    // Cleanup the subgraph from duplicated constants while we're at it.
    ConstantPooling(subgraph);

    if (GRAPH_DEBUG_ENABLED) {
      GRAPH_EXPORT("", subgraph);
    }
    return false;
  }

  void inlineSmallFusionGroups(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++;

      for (Block* b : n->blocks()) {
        inlineSmallFusionGroups(b);
      }
      inlineIfTooSmall(n);
    }
  }

  std::optional<Node*> tryMerge(Node* fusion_group, Node* to_merge) {
    if (!canMerge(fusion_group, to_merge)) {
      return std::nullopt;
    }

    std::vector<Node*> nodes_to_merge = {to_merge};

    if (to_merge->kind() == aten::cat) {
      Node* listconstruct = to_merge->input(0)->node();
      nodes_to_merge.push_back(listconstruct);
    }

    // First, try to move all the nodes we want to fuse next to the fusion
    // group.
    Node* move_point = fusion_group;
    for (auto n : nodes_to_merge) {
      GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n));
      if (!aliasDb_->moveBeforeTopologicallyValid(n, move_point)) {
        GRAPH_UPDATE("Failed to move because of AliasDB checks!");
        return std::nullopt;
      }
      move_point = n;
    }

    // Now all the nodes that we're going to fuse are moved next to the fusion
    // group, so we can safely merge them into the fusion group subgraph.
    fusion_group = getOrCreateTensorExprSubgraph(fusion_group);

    for (auto n : nodes_to_merge) {
      GRAPH_UPDATE("Merging ", getHeader(n));
      SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
          n, fusion_group, *aliasDb_);
    }
    return fusion_group;
  }

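  // Tensor values must carry complete (static) shape information to be
  // fusible; non-tensor values are always considered known.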
  bool shapeIsKnown(Value* v) {
    if (v->type()->cast<TensorType>()) {
      if (!v->isCompleteTensor()) {
        return false;
      }
    }
    return true;
  }

  bool allShapesAreKnown(Node* node) {
    // TODO: Relax the checks to support dynamic shapes
    for (Value* input : node->inputs()) {
      if (!shapeIsKnown(input)) {
        return false;
      }
      if (input->node()->kind() == prim::ListConstruct) {
        if (!allShapesAreKnown(input->node())) {
          return false;
        }
      }
    }
    for (Value* output : node->outputs()) {
      if (!shapeIsKnown(output)) {
        return false;
      }
    }
    return true;
  }

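  // Fusion is only attempted on CPU (when the CPU fuser is enabled) and CUDA
  // (when the GPU fuser is enabled); all other devices, including XPU, are
  // rejected. Non-tensor values don't constrain the device.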
  bool canFuseOnDevice(Value* v) {
    auto type = v->type()->cast<TensorType>();
    if (!type) {
      return true;
    }
    auto device = type->device();
    if (!device) {
      return false;
    }
    if (device->is_cpu()) {
      return canFuseOnCPU();
    } else if (device->is_cuda()) {
      return canFuseOnGPU();
    } else if (device->is_xpu()) {
      return false;
    }
    return false;
  }

  bool isFusableOnDevice(Node* node) {
    for (const auto& input : node->inputs()) {
      if (input->node()->kind() == prim::ListConstruct) {
        if (!isFusableOnDevice(input->node())) {
          return false;
        }
      }
      if (!canFuseOnDevice(input)) {
        return false;
      }
    }
    return true;
  }

  bool typesAreSupported(Node* node) {
    // clang-format off
    // clang-format would break up the schema strings so they are no longer
    // discoverable with ctrl-F
    static const OperatorSet float_only_operator_set{
      "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor",
      "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
      "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
    };
    static const OperatorSet int_only_operator_set{
      "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
      "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
      "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
    };
    static const OperatorSet cpu_compute_heavy_set{
      "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
      "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
      "aten::matmul(Tensor self, Tensor other) -> Tensor",
    };
    static const OperatorSet gpu_only_operator_set{
      // On CPU, these are slower and less accurate than ATen kernels, because
      // ATen is able to use MKL-VML, whereas the fuser currently can't.  The
      // fuser uses sleef instead because sleef provides functions that operate
      // on vectors, instead of large buffers.
      "aten::erf(Tensor self) -> Tensor",
      "aten::erfc(Tensor self) -> Tensor",
    };
    static const OperatorSet pow{
      "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor",
    };
    // clang-format on

    // Check types of input values.
    for (const Value* v : node->inputs()) {
      if (auto const& tt = v->type()->cast<TensorType>()) {
        auto const& st = tt->scalarType();
        auto const& device = tt->device();

        // All tensors must be typed.
        if (!st || !device) {
          return false;
        }

        // Byte tensors introduce too many corner cases in type promotion.
        // Better not to try to handle them.
        if (*st == c10::ScalarType::Byte) {
          return false;
        }

        // Float16 support has some issues (see e.g. #61336 and #61382), so for
        // now it's disabled. There seem to be some problems in HalfRewriter,
        // but on top of that Float16 has a few kinks on LLVM.  Thus, on CPU we
        // additionally disable it until we either move to a more stable version
        // or find workarounds.
        if (*st == c10::ScalarType::Half && *device == c10::kCPU) {
          return false;
        }

        if (*st == c10::ScalarType::BFloat16 && *device == c10::kCPU) {
#ifndef TORCH_ENABLE_LLVM
          return false;
#endif
        }

        // These operators only support floats, because integer divisors need to
        // raise ZeroDivisionError.
        if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
          return false;
        }

        // These operators have complicated casting rules for floats.
        if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
          return false;
        }
      } else if (node->isMemberOf(float_only_operator_set)) {
        // Check scalar operands of float-only ops.
        if (!v->type()->cast<FloatType>()) {
          return false;
        }
      } else if (node->isMemberOf(int_only_operator_set)) {
        if (!v->type()->cast<IntType>()) {
          return false;
        }
      }
    }

    // aten::pow has special rules to avoid complicated integer cases.  We
    // expect the first arg to be a floating point tensor, and if that's the
    // case the type of the scalar exponent doesn't matter.
    if (node->isMemberOf(pow)) {
      auto const& tt = node->input(0)->type()->cast<TensorType>();
      if (!tt) {
        return false;
      }
      auto const& st = tt->scalarType();
      if (!st || !isFloatingType(*st)) {
        return false;
      }
    }

    // Operator is only supported on CPU.
    if (node->isMemberOf(cpu_compute_heavy_set)) {
      if (fuse_to_dynamic_shapes_) {
        return false;
      }

      auto device = tensorexpr::pickDeviceType(node->inputs());
      if (!device) {
        device = tensorexpr::pickDeviceType(node->outputs());
      }
      if (!device || !device->is_cpu()) {
        return false;
      }
    }

    // Operator is only supported on GPU.
    if (node->isMemberOf(gpu_only_operator_set)) {
      auto device = tensorexpr::pickDeviceType(node->inputs());
      if (!device) {
        device = tensorexpr::pickDeviceType(node->outputs());
      }
      if (!device || !device->is_cuda()) {
        return false;
      }
    }

    if (node->kind() == aten::to) {
      // only support same-device conversion
      auto device = tensorexpr::pickDeviceType(node->inputs());
      auto output_device = tensorexpr::pickDeviceType(node->outputs());
      if (!device || !output_device || *device != *output_device) {
        return false;
      }
      // non_blocking only applies in cross-device conversion, which we bail on;
      // the copy arg only applies if the op is a no-op, which we don't start a
      // fusion group from; memory format is separately handled in NNC output.

      // all non-Tensor arguments must be constant
      for (size_t i = 1; i < node->inputs().size(); i++) {
        if (node->inputs().at(i)->node()->kind() != prim::Constant) {
          return false;
        }
      }

      if (has_unsupported_pin_memory(node)) {
        return false;
      }
    }

    if (node->kind() == aten::_autocast_to_reduced_precision ||
        node->kind() == aten::_autocast_to_full_precision) {
      for (auto i : c10::irange(1, node->inputs().size())) {
        if (node->inputs().at(i)->node()->kind() != prim::Constant) {
          return false;
        }
      }

      bool is_reduced_precision =
          node->kind() == aten::_autocast_to_reduced_precision;
      bool is_full_precision =
          node->kind() == aten::_autocast_to_full_precision;
      auto self_tensor = node->inputs()[0]; // input tensor

      if (auto const& tt = self_tensor->type()->cast<TensorType>()) {
        auto st = tt->scalarType();
        if (!st.has_value()) {
          return false;
        }

        auto device = tt->device();
        if (!device.has_value()) {
          return false;
        }

        bool is_cpu = device->is_cpu();

        if (*st != at::kFloat && is_reduced_precision && is_cpu) {
          // Regarding CPU, aten would do nothing if the data type is
          // float. Then the aten performance is better than NNC. So NNC
          // does not pull it into its fusion group.
          return false;
        }

        if (*st != at::kBFloat16 && is_full_precision && is_cpu) {
          // Regarding CPU, aten would do nothing if the data type is
          // BFloat16. Then the aten performance is better than NNC. So NNC
          // does not pull it into its fusion group.
          return false;
        }
      }

      if (has_unsupported_pin_memory(node)) {
        return false;
      }
    }

    if (node->kind() == aten::unsqueeze) {
      // `dim` argument must be a constant.
      if (node->input(1)->node()->kind() != prim::Constant) {
        return false;
      }
    }

    if (node->kind() == aten::_convolution && !tensorexpr::isConv2d(node)) {
      GRAPH_DEBUG("This aten::_convolution node is not a 2D conv");
      return false;
    }
    if (node->kind() == aten::_convolution || node->kind() == aten::conv2d) {
      if (!tensorexpr::conv2dIsSupportedJit(node) &&
          !tensorexpr::mkldnnPrepackedConvIsSupportedJit(node)) {
        GRAPH_DEBUG("Params of conv2d are not supported");
        return false;
      }
    }
    if (node->kind() == aten::matmul) {
      if (!tensorexpr::matmulIsSupported(node)) {
        GRAPH_DEBUG("Shapes of matmul inputs are not supported");
        return false;
      }
    }
    return true;
  }

#define REQ(cond)                           \
  if (!(cond)) {                            \
    GRAPH_DEBUG("Failed cond " #cond "\n"); \
    return false;                           \
  }

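  // Node-level fusibility check: shapes, device, dtype, operator support, the
  // PYTORCH_TENSOREXPR_DONT_FUSE denylist, and (for dynamic-shape fusion) the
  // presence of a registered shape function.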
  bool canHandle(Node* node) {
    REQ(allShapesAreKnown(node));
    REQ(isFusableOnDevice(node));
    REQ(operators_not_to_fuse.find(node->kind()) ==
        operators_not_to_fuse.end());

    for (Value* input : node->inputs()) {
      if (auto const& tt = input->type()->cast<TensorType>()) {
        auto st = tt->scalarType();
        if (!st) {
          // All tensor types should be known.
          return false;
        }
        if (c10::isComplexType(*st) || c10::isQIntType(*st)) {
          return false;
        }
      }
    }
    if (node->kind() == aten::cat) {
      REQ(node->input(0)->node()->kind() == prim::ListConstruct);
      REQ(node->input(0)->uses().size() == 1);
      REQ(node->input(1)->node()->kind() == prim::Constant);
      auto const& listconstruct = node->input(0)->node();
      REQ(tensorexpr::pickDeviceType(listconstruct->inputs()));
    } else {
      REQ(tensorexpr::pickDeviceType(node->inputs()));
    }

    // Only fuse aten::batch_norm when the parameter 'training' is false
    if (node->kind() == aten::batch_norm) {
      REQ(node->input(5)->node()->kind() == prim::Constant);
      REQ(!toIValue(node->input(5)).value().toBool());
    }

    REQ(tensorexpr::isSupported(node));
    REQ(typesAreSupported(node));

    // A hook into the optimization limiter to allow bisecting the pass
    REQ(JIT_OPT_ALLOWED);

    if (fuse_to_dynamic_shapes_) {
      // Allow only if the node has a shape function defined.
      // ListConstruct node is an exception since that is needed to fuse
      // aten::cat, though it does not have a shape function.
      REQ(node->kind() == prim::ListConstruct ||
          node->kind() == prim::TensorExprGroup ||
          node->isMemberOf(tensorexpr::getCustomOperatorSet()) ||
          (node->maybeSchema() && shapeComputeGraphForSchema(node->schema())));
    }

    return true;
  }

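  // Checks whether `producer` can be merged into `consumer`: same block, both
  // fusible, within the kernel argument limit, matching devices, alias-safe to
  // move, plus the aten::cat special cases handled below.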
  bool canMerge(Node* consumer, Node* producer) {
    // Only fuse within a block
    REQ(consumer->owningBlock() == producer->owningBlock());

    // Symbolic checks
    REQ(canHandle(producer) || producer->kind() == prim::TensorExprGroup);
    TORCH_INTERNAL_ASSERT(
        consumer->kind() == prim::TensorExprGroup || canHandle(consumer));

    // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
    // The specific limit is a function of constant memory size, amount
    // available to pass arguments, and some implementation dependence. Select a
    // safe limit here.
    constexpr size_t subgraphArgLimit = 128;
    auto const nInputs = consumer->inputs().size() +
        consumer->outputs().size() + producer->inputs().size() +
        producer->outputs().size();
    REQ(nInputs <= subgraphArgLimit);

    // Device checks
    if (consumer->kind() != aten::cat && producer->kind() != aten::cat) {
      // aten::cat needs special handling because it takes a Tensor[] as its
      // input. We deal with that in the code below.
      auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
      REQ(consumer_device);
      auto producer_device = tensorexpr::pickDeviceType(producer->inputs());
      REQ(producer_device);
      REQ(*consumer_device == *producer_device);
    }

    // Alias checks
    REQ(aliasDb_->couldMoveBeforeTopologically(producer, consumer));

    // Ops that return aliases can only be folded if this is the only use.
    if (producer->kind() == aten::slice ||
        producer->kind() == aten::unsqueeze ||
        producer->kind() == prim::ConstantChunk) {
      for (auto& use : producer->output(0)->uses()) {
        REQ(use.user == consumer);
      }
    }

    if (!consumer->hasAttribute(attr::Subgraph) &&
        consumer->kind() != prim::TensorExprGroup) {
      // Don't initiate a fusion group from prim::ListConstruct
      REQ(consumer->kind() != prim::ListConstruct);
      REQ(consumer->kind() != aten::slice);
      REQ(consumer->kind() != aten::unsqueeze);
      REQ(consumer->kind() != prim::ConstantChunk);

      // Don't initiate a fusion group just for a constant operand
      REQ(producer->kind() != prim::Constant);
    }

    if (producer->kind() == aten::cat) {
      REQ(producer->input(0)->node()->kind() == prim::ListConstruct);
      REQ(producer->input(0)->uses().size() == 1);
      REQ(producer->input(1)->node()->kind() == prim::Constant);
      auto const& listConstruct = producer->input(0)->node();
      // We're merging listconstruct->cat->consumer. cat is the producer here
      // and we cannot determine its device type - we should use the device of
      // the listconstruct instead
      auto listconstruct_device =
          tensorexpr::pickDeviceType(listConstruct->inputs());
      auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
      REQ(listconstruct_device);
      REQ(consumer_device);
      REQ(*listconstruct_device == *consumer_device);
      for (auto const& input : listConstruct->inputs()) {
        REQ(isFusableOnDevice(input->node()));
      }
      REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
    } else if (consumer->kind() == aten::cat) {
      REQ(consumer->input(0)->node()->kind() == prim::ListConstruct);
      REQ(consumer->input(0)->uses().size() == 1);
      REQ(consumer->input(1)->node()->kind() == prim::Constant);
      auto const& listConstruct = consumer->input(0)->node();
      // We're merging listconstruct->cat. cat is the consumer and listconstruct
      // is the producer. cat doesn't carry a device type of its own, so the
      // only thing we should check is that listconstruct's device is well
      // defined (e.g. all its inputs have the same device).
      auto listconstruct_device =
          tensorexpr::pickDeviceType(listConstruct->inputs());
      REQ(listconstruct_device);
      REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
    } else {
      REQ(isFusableOnDevice(producer));
    }

    return true;
  }
#undef REQ

  void prepareFusionGroupAndGuardOutputs(Block* block) {
    std::vector<Node*> fusion_groups;
    for (Node* n : block->nodes()) {
      for (Block* b : n->blocks()) {
        prepareFusionGroupAndGuardOutputs(b);
      }
      if (n->kind() == prim::TensorExprGroup) {
        fusion_groups.push_back(n);
      }
    }
    for (Node* fusion_group : fusion_groups) {
      removeOutputsUsedOnlyInSize(fusion_group);
      insertTypeGuard(
          fusion_group,
          [](const TensorTypePtr& t) { return t; },
          prim::TypeCheck);
    }
  }

  void generalizeFusionGroups(Block* block) {
    std::vector<Node*> fusion_groups;
    for (Node* n : block->nodes()) {
      for (Block* b : n->blocks()) {
        generalizeFusionGroups(b);
      }
      if (n->kind() == prim::TensorExprGroup) {
        fusion_groups.push_back(n);
      }
    }
    for (Node* fusion_group : fusion_groups) {
      removeOutputsUsedOnlyInSize(fusion_group);
      VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
      if (!GenerateGuard(fusion_group, add_composed_op_)) {
        VLOG(1) << "  Unfusing the fusion group because GenerateGuard failed"
                << '\n';
        SubgraphUtils::unmergeSubgraph(fusion_group);
      }
    }
  }

  // This function parses the option provided by the environment variable
  // "PYTORCH_TENSOREXPR_DONT_FUSE".
  // This variable allows users to disable fusion on a list of specified
  // operators that are separated by ':'. e.g.,
  // 'PYTORCH_TENSOREXPR_DONT_FUSE="clamp:mul:add"' disables fusion on
  // aten::clamp, aten::mul and aten::add.
  void parseTENotFuseOption() {
    const char* option = std::getenv("PYTORCH_TENSOREXPR_DONT_FUSE");
    std::stringstream in_ss;
    if (option) {
      in_ss << option;
    }

    std::string line;
    while (std::getline(in_ss, line, ':')) {
      if (line.empty()) {
        continue;
      }
      operators_not_to_fuse.insert(c10::Symbol::aten(line));
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;

  std::set<NodeKind> operators_not_to_fuse;
  // Minimal size of a fusion group
  size_t min_group_size_;
  // compose Runtime Type Guard and Kernel in one op
  bool add_composed_op_;
  // generalize static shapes to dynamic shapes
  bool fuse_to_dynamic_shapes_;
};

void FuseTensorExprs(
    std::shared_ptr<Graph>& graph,
    size_t min_group_size,
    bool add_composed_op,
    bool fuse_to_dynamic_shapes) {
  GRAPH_DUMP("Before TExprFuser: ", graph);

  // Temporary change for Block code generation.
  if (tensorexpr::getTEGenerateBlockCode()) {
    min_group_size = 1;
  }

  if (add_composed_op) {
    TORCH_INTERNAL_ASSERT(
        fuse_to_dynamic_shapes, "Fusing static shapes with composed op NYI");
  }

  // Get rid of dead code so that we don't waste effort fusing it.
  EliminateDeadCode(graph);

  TensorExprFuser fuser(
      graph, min_group_size, add_composed_op, fuse_to_dynamic_shapes);
  fuser.run();

  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);

  GRAPH_DUMP("After TExprFuser: ", graph);
}

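// Builds the Operation executed for a prim::TensorExprGroup node. In the
// static-shape case the compiled kernel is run directly; a dynamic-shape
// fusion node (identified by the striding_inputs_desc attribute) first
// reconstructs symbolic shape and stride information from node attributes.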
static Operation createTensorExprOp(const Node* node) {
  bool dynamic_shape_fusion_node =
      node->hasAttribute(attr::striding_inputs_desc);
  if (!dynamic_shape_fusion_node) {
    auto kernel =
        std::make_shared<tensorexpr::TensorExprKernel>(node->g(attr::Subgraph));
    return [kernel](Stack& stack) {
      RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());
      kernel->run(stack);
      return 0;
    };
  }

  // Handle the case when dynamic shape fusion is enabled.
  VLOG(1) << "Compiling a new kernel for " << *node;
  std::vector<int64_t> sym_shapes;
  if (node->hasAttribute(attr::symbolic_shape_inputs)) {
    sym_shapes = node->is(attr::symbolic_shape_inputs);
  }
  bool allow_stack_outputs = false;
  if (node->hasAttribute(attr::allow_stack_outputs)) {
    allow_stack_outputs = node->i(attr::allow_stack_outputs) == 1;
  }

  std::unordered_map<c10::Symbol, tensorexpr::NNCLoweringFunction>
      custom_lowerings;
  auto subgraph = node->g(attr::Subgraph);
  IValue sym_strides = node->ival(attr::striding_inputs_desc);

  // The striding descriptor is serialized on the node as a vector of vectors
  // of strings; translate it back to the StrideInput enum.
  std::vector<std::vector<std::string>> sym_strides_strs =
      sym_strides.to<std::vector<std::vector<std::string>>>();
  std::vector<std::vector<StrideInput>> striding_inputs;
  for (const auto& vec : sym_strides_strs) {
    std::vector<StrideInput> input_desc;
    input_desc.reserve(vec.size());
    for (const std::string& str : vec) {
      input_desc.push_back(strideInputFromString(str));
    }
    striding_inputs.push_back(input_desc);
  }
  std::unordered_map<const Value*, std::vector<StrideInput>> stride_map;
  size_t index = 0;
  for (Value* v : subgraph->inputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    stride_map[v] = striding_inputs[index];
    index++;
  }
  std::vector<std::string> output_desc =
      node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
  for (size_t i = 0; i < subgraph->outputs().size(); ++i) {
    stride_map[subgraph->outputs().at(i)] = {
        strideInputFromString(output_desc.at(i))};
  }

  std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
      std::make_shared<tensorexpr::TensorExprKernel>(
          subgraph,
          custom_lowerings,
          sym_shapes,
          /*pre_alloc*/ false,
          stride_map);

  auto num_subgraph_inputs = subgraph->inputs().size();
  return [kernel, num_subgraph_inputs, allow_stack_outputs](Stack& stack) {
    RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());

    // Stack contents:
    //   [<outputs>] <inputs>
    //
    // If the number of graph inputs is the same as the stack size, then no
    // outputs are being passed in. Otherwise, output tensors are passed in
    // at the bottom of the stack. So, we call the appropriate run function
    // in TensorExprKernel.
    if (num_subgraph_inputs == stack.size() || !allow_stack_outputs) {
      kernel->run(stack);
    } else {
      kernel->runWithAllocatedOutputs(stack);
    }
    return 0;
  };
}

RegisterOperators TensorExprOps({
    torch::jit::Operator(
        prim::TensorExprGroup,
        createTensorExprOp,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

} // namespace torch::jit