1 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
2
3 #include <ATen/core/interned_strings.h>
4 #include <ATen/core/symbol.h>
5 #include <ATen/record_function.h>
6 #include <c10/util/FunctionRef.h>
7 #include <c10/util/irange.h>
8 #include <torch/csrc/jit/codegen/cuda/interface.h>
9 #include <torch/csrc/jit/codegen/fuser/interface.h>
10 #include <torch/csrc/jit/ir/alias_analysis.h>
11 #include <torch/csrc/jit/jit_log.h>
12 #include <torch/csrc/jit/jit_opt_limit.h>
13 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
14 #include <torch/csrc/jit/passes/constant_pooling.h>
15 #include <torch/csrc/jit/passes/dead_code_elimination.h>
16 #include <torch/csrc/jit/passes/pass_manager.h>
17 #include <torch/csrc/jit/passes/remove_redundant_profiles.h>
18 #include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
19 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
20 #include <torch/csrc/jit/runtime/custom_operator.h>
21 #include <torch/csrc/jit/runtime/graph_executor.h>
22 #include <torch/csrc/jit/runtime/operator_options.h>
23 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
24 #include <torch/csrc/jit/runtime/symbolic_shape_registry_util.h>
25 #include <torch/csrc/jit/tensorexpr/kernel.h>
26
27 #include <utility>
28
29 C10_DEFINE_bool(
30 torch_jit_disable_cat,
31 false,
32 "disable aten::cat in TE fusion groups");
33
34 C10_DEFINE_bool(
35 torch_jit_enable_dynamic_shape_fusion,
36 false,
37 "enable TE fusion using dynamic shapes");
38
39 namespace torch::jit {
40
41 static bool texpr_reductions_enabled = false;
42
static bool isSupportedForBlock(Node* node) {
44 switch (node->kind()) {
45 case aten::add:
46 case aten::mul:
47 return true;
48 default:
49 return false;
50 }
51 }
52
bool usedOnlyInSize(Value* v) {
54 const auto& uses = v->uses();
55 return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
56 return u.user->matches("aten::size(Tensor self) -> int[]");
57 });
58 }
59
Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) {
61 AT_ASSERT(!sizes.empty());
62 Graph* graph = sizes[0]->owningGraph();
63 Node* broadcast_n =
64 graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
65 broadcast_n->output()->setType(ListType::ofInts());
66 db->createValue(broadcast_n->output());
67 return broadcast_n->output();
68 }
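
// Illustrative sketch (not from the original source): for two tensor inputs
// %a and %b, the helpers above produce IR along these lines:
//
//   %s1 : int[] = aten::size(%a)
//   %s2 : int[] = aten::size(%b)
//   %s3 : int[] = prim::BroadcastSizes(%s1, %s2)
//
// The prim::BroadcastSizes output is typed as int[] and registered with the
// AliasDb so that later alias queries are aware of the newly created value.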
69
70 namespace tensorexpr {
71
OperatorSet& getCustomOperatorSet() {
73 static OperatorSet _g_custom_operator_set{};
74 return _g_custom_operator_set;
75 }
76
static const OperatorSet& supported_non_eltwise_set() {
78 // clang-format off
79 static const OperatorSet supported_non_eltwise_set{
80 "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
81 "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
82 "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
83 "aten::matmul(Tensor self, Tensor other) -> Tensor",
84 };
85 // clang-format on
86 return supported_non_eltwise_set;
87 };
88
bool isSupported(Node* node) {
90 // For Block codegen we allow limited ops.
91 if (tensorexpr::getTEGenerateBlockCode()) {
92 return isSupportedForBlock(node);
93 }
94
  // clang-format off
  static const OperatorSet supported_reduction_set{
96 "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
97 "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
98 "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
99 "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
100 };
101 static const OperatorSet supported_misc_set{
102 "aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
103 "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)",
104 };
105 // clang-format on
106
107 if (get_tensorexpr_elementwise_set().contains(node) ||
108 node->isMemberOf(supported_non_eltwise_set()) ||
109 node->isMemberOf(supported_misc_set) ||
110 node->isMemberOf(getCustomOperatorSet()) ||
111 (texpr_reductions_enabled && node->isMemberOf(supported_reduction_set))) {
    // We only insert guards on Tensor types, so we rely on the output
    // of a node being uniquely determined by its input types.
    // Bail if any non-Tensor input affects the output type and cannot be
    // reasoned about statically.
116
117 // Value is either an int or a float (can occur from .item())
118 for (Value* v : node->inputs()) {
119 if (v->type()->cast<NumberType>()) {
120 return false;
121 }
122 }
123
124 // non-const dtype / device
125 for (auto arg_name : {"dtype", "device"}) {
126 if (auto index = node->schema().argumentIndexWithName(arg_name)) {
127 if (!toIValue(node->input(*index))) {
128 return false;
129 }
130 }
131 }
132
133 if (FLAGS_torch_jit_disable_cat && node->kind() == aten::cat) {
134 return false;
135 }
136
137 return true;
138 }
139
140 // unschematized ops
141 switch (node->kind()) {
142 case prim::ConstantChunk:
143 case prim::ListConstruct:
144 case prim::TensorExprGroup:
145 return true;
146 }
147
148 return false;
149 }
150 } // namespace tensorexpr
151
152 static bool texpr_fuser_enabled_ = true;
153
void setTensorExprFuserEnabled(bool val) {
155 texpr_fuser_enabled_ = val;
156 }
157
bool tensorExprFuserEnabled() {
159 static const char* enable_c_str = std::getenv("PYTORCH_TENSOREXPR");
160 if (!enable_c_str) {
161 return texpr_fuser_enabled_;
162 }
163 if (std::string(enable_c_str) == "0") {
164 return false;
165 }
166 return true;
167 }
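
// Example of the override behavior above (assumed typical usage, not from the
// original source): the PYTORCH_TENSOREXPR environment variable takes
// precedence over the programmatic toggle.
//
//   PYTORCH_TENSOREXPR=0 python script.py   # fuser disabled, regardless of
//                                           # setTensorExprFuserEnabled(true)
//   PYTORCH_TENSOREXPR=1 python script.py   # fuser enabled
//
// When the variable is unset, texpr_fuser_enabled_ (set via
// setTensorExprFuserEnabled) decides.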
168
bool tensorExprDynamicShapeFusionEnabled() {
170 return FLAGS_torch_jit_enable_dynamic_shape_fusion;
171 }
172
void setTensorExprDynamicShapeFusionEnabled(bool val) {
174 FLAGS_torch_jit_enable_dynamic_shape_fusion = val;
175 }
176
bool setTexprReductionsEnabled(bool value) {
178 bool old_value = texpr_reductions_enabled;
179 texpr_reductions_enabled = value;
180 return old_value;
181 }
182
bool texprReductionsEnabled() {
184 return texpr_reductions_enabled;
185 }
186
static void removeProfileNodesAndSpecializeTypes(Block* b) {
188 for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
189 if (it->kind() == prim::profile) {
190 GRAPH_DEBUG("Removing prim::profile: %", it->output()->debugName());
191 it->output()->replaceAllUsesWith(it->input());
192 auto profiled_type = it->ty(attr::profiled_type)->expect<TensorType>();
193
194 TensorTypePtr input_tensor_type = nullptr;
195 bool input_is_optional = false;
196 if (it->input()->type()->kind() == c10::TypeKind::TensorType) {
197 input_tensor_type = it->input()->type()->expect<TensorType>();
198 } else {
199 input_tensor_type = it->input()
200 ->type()
201 ->expectRef<OptionalType>()
202 .getElementType()
203 ->expect<TensorType>();
204 input_is_optional = true;
205 }
206
207 if (input_is_optional) {
208 it.destroyCurrent();
209 continue;
210 }
211
212 // A value can be profiled with differently typed uses.
213 // This can occur from:
214 // - having a use which is not executed, so the type will be
215 // TensorType::get()
216 // - control-flow that depends on tensor type:
217 // if x.size() == 2 op(x) else op(x)
218 // - mutation of the value on a field represented in the tensor type
219 // op(x); x.resize_([...]); op(x)
220
221 // The most common case today with num_profiles = 1 is from the first
222 // case. Here we can just ignore non-profiled uses, and choose any of the
223 // profiled uses. Because we guard all tensor types in the runtime, even
224 // if we set a Value to have a profiled type from one use and then execute
225 // a use with a different profiled type, we will still be correct.
226 // In the future we could consider unifying the types of uses, or adding a
227 // type refinement node so uses can have the correct corresponding type.
228 if (profiled_type == TensorType::get()) {
229 continue;
230 }
231
232 // If we encounter non-identical profiled types for the same value, merge
233 // them. This situation can happen if, e.g., loop unrolling duplicates
234 // profiled types in a loop body in a manner that isn't logically
235 // consistent (see TestTEFuser.test_unrolled_cat).
236 if (input_tensor_type == TensorType::get()) {
237 it->input()->setType(profiled_type);
238 } else {
239 it->input()->setType(input_tensor_type->merge(*profiled_type));
240 }
241
242 it.destroyCurrent();
243 } else {
244 for (Block* ib : it->blocks()) {
245 removeProfileNodesAndSpecializeTypes(ib);
246 }
247 }
248 }
249 }
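
// Illustrative before/after sketch (assumed IR, not from the original source):
//
//   before:  %p : Tensor = prim::profile[profiled_type=Float(2, 3)](%x)
//            %y : Tensor = aten::relu(%p)
//   after:   %y : Tensor = aten::relu(%x)   // %x now typed as Float(2, 3)
//
// The profile node is destroyed, its uses are redirected to the profiled
// input, and the input's type is specialized (or merged with a previously
// recorded specialization).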
250
void RemoveProfileNodesAndSpecializeTypes(std::shared_ptr<Graph>& graph) {
252 GRAPH_DEBUG("Before removeProfileNodesAndSpecializeTypes:\n", *graph);
253 removeProfileNodesAndSpecializeTypes(graph->block());
254 GRAPH_DEBUG("After removeProfileNodesAndSpecializeTypes:\n", *graph);
255 }
256
bool hasTensorTypeSpecialization(Value* v) {
258 if (!v->type()->cast<TensorType>()) {
259 return false;
260 }
  // Constants & TensorExprGroup will always produce specialized tensor types.
  // TypeCheck nodes are inserted by this pass and are only used by fusion
  // groups that insert proper guards.
264 if (v->node()->kind() == prim::Constant ||
265 v->node()->kind() == prim::TypeCheck ||
266 v->node()->kind() == prim::TensorExprGroup) {
267 return false;
268 }
269 if (v->type() == TensorType::get()) {
270 return false;
271 }
272 return true;
273 }
274
static void removeTensorTypeSpecialization(Value* v) {
276 if (hasTensorTypeSpecialization(v)) {
277 v->setType(TensorType::get());
278 }
279 }
280
void removeTensorTypeSpecializations(Block* block) {
282 for (Value* v : block->inputs()) {
283 removeTensorTypeSpecialization(v);
284 }
285 for (Node* n : block->nodes()) {
286 for (Block* b : n->blocks()) {
287 removeTensorTypeSpecializations(b);
288 }
289 for (Value* v : n->outputs()) {
290 removeTensorTypeSpecialization(v);
291 }
292 }
293 }
294
void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph) {
296 removeTensorTypeSpecializations(graph->block());
297 }
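
// Illustrative sketch (assumed types, not from the original source): this pass
// turns specialized tensor types back into the generic one, e.g.
//
//   %y : Float(2, 3, strides=[3, 1], device=cpu)  ->  %y : Tensor
//
// Values produced by prim::Constant, prim::TypeCheck, and prim::TensorExprGroup
// keep their types, as checked in hasTensorTypeSpecialization above.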
298
void insertTypeGuard(
300 Node* guarded_node,
301 tensor_type_converter_t type_converter,
302 Symbol kind) {
303 GRAPH_DEBUG("Inserting a typecheck guard for a node", *guarded_node);
304 auto subgraph = SubgraphUtils::getSubgraph(guarded_node);
305
306 // Fixup types of the subgraph inputs
307 std::vector<Value*> inputs_to_check;
308 std::vector<TypePtr> guard_types;
309 for (Value* input : guarded_node->inputs()) {
    // We only check inputs of the guarded node and expect the user to infer
    // the shapes of intermediates and outputs.
312 if (!input->type()->cast<TensorType>()) {
313 continue;
314 }
315
316 // fusion outputs are already guarded
317 if (input->node()->kind() == prim::Constant ||
318 input->node()->kind() == prim::FusionGroup) {
319 continue;
320 }
321 inputs_to_check.push_back(input);
322 guard_types.emplace_back(
323 type_converter(input->type()->expect<TensorType>()));
324 }
325 if (inputs_to_check.empty()) {
326 return;
327 }
328
329 // Add prim::TypeCheck node
330 //
331 // TypeCheck nodes look like the following:
332 // %out1 : Float(2, 3), %out2 : Int(10, 30), %types_match : bool =
333 // prim::TypeCheck(%inp1 : Tensor, %inp2 : Tensor)
334 //
  // They have N inputs whose types we are going to check and N+1 outputs. The
  // first N outputs specify expected types and the (N+1)-th output holds the
  // result of the check (bool).
338 Node* typecheck_node =
339 guarded_node->owningGraph()
340 ->create(kind, inputs_to_check, inputs_to_check.size() + 1)
341 ->insertBefore(guarded_node);
342 typecheck_node->tys_(attr::types, std::move(guard_types));
343 Value* typecheck_result = typecheck_node->output(inputs_to_check.size());
344
345 std::unordered_map<Value*, Value*> typechecked_inputs;
346 for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
347 typechecked_inputs[typecheck_node->input(i)] = typecheck_node->output(i);
348 }
349
350 // Fixup types of the typecheck node outputs, which are used by the op in
351 // execution
352 typecheck_node->output(inputs_to_check.size())->setType(BoolType::get());
353 for (size_t i = 0; i < typecheck_node->inputs().size(); ++i) {
354 typecheck_node->output(i)->setType(typecheck_node->input(i)->type());
355 }
356
357 // Insert if
358 auto versioning_if =
359 guarded_node->owningGraph()
360 ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
361 ->insertAfter(typecheck_node);
362 for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
363 versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
364 guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
365 }
366 auto true_block = versioning_if->addBlock();
367 auto false_block = versioning_if->addBlock();
368
369 // Fill in the false block. It should contain the unoptimized
370 // copy of the fused subgraph.
371 WithInsertPoint guard(false_block->return_node());
372 const auto subgraph_outputs = insertGraph(
373 *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
374 for (Value* output : subgraph_outputs) {
375 false_block->registerOutput(output);
376 }
377
378 // types get copied to the fallback graph, so remove specializations before
379 // replacing
380 removeTensorTypeSpecializations(false_block);
381 replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());
382
383 // Fill in the true block. It has all inputs type-checked and its
384 // body should be the fusion group node.
385 guarded_node->moveBefore(true_block->return_node());
386 for (size_t idx = 0; idx < guarded_node->inputs().size(); ++idx) {
387 if (typechecked_inputs.count(guarded_node->input(idx))) {
388 guarded_node->replaceInput(
389 idx, typechecked_inputs.at(guarded_node->input(idx)));
390 }
391 }
392 for (Value* output : guarded_node->outputs()) {
393 true_block->registerOutput(output);
394 }
395 }
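
// Illustrative sketch of the guarded structure built by insertTypeGuard
// (assumed IR shapes and value names, not from the original source):
//
//   %t0 : Float(2, 3), %ok : bool = prim::TypeCheck[types=[...]](%inp0)
//   %out = prim::If(%ok)
//     block0():   // types matched: run the fusion group on checked inputs
//       %r = prim::TensorExprGroup_0(%t0)
//       -> (%r)
//     block1():   // types did not match: run the unoptimized fallback
//       %r2 = prim::FallbackGraph_0(%inp0)
//       -> (%r2)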
396
397 namespace {
bool has_unsupported_pin_memory(const Node* node) {
  // can't support non-constant pin_memory or pin_memory = True
400 if (auto maybe_index = node->schema().argumentIndexWithName("pin_memory")) {
401 int index = *maybe_index;
402 auto inp = node->input(index);
403 if (inp->type() != NoneType::get() &&
404 constant_as<bool>(inp).value_or(true)) {
405 return true;
406 }
407 }
408 return false;
409 }
410 } // namespace
411
412 class TensorExprFuser {
413 public:
  TensorExprFuser(
415 std::shared_ptr<Graph> graph,
416 size_t min_group_size,
417 bool add_composed_op,
418 bool fuse_to_dynamic_shapes)
419 : graph_(std::move(graph)),
420 min_group_size_(min_group_size),
421 add_composed_op_(add_composed_op),
422 fuse_to_dynamic_shapes_(fuse_to_dynamic_shapes) {
423 parseTENotFuseOption();
424 }
425
426 // Builds up expressions that compute shapes of all intermediates (and
427 // outputs) of the fusion group, based on the sizes of inputs. You should run
428 // DCE to remove those that you end up not using.
  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
430 GRAPH_DUMP("buildShapeExpressions for ", fusion_group->g(attr::Subgraph));
431 WithInsertPoint insert_guard{fusion_group->next()};
432 std::unordered_map<Value*, Value*> shape_of;
433
434 Graph* graph = fusion_group->owningGraph();
435 auto subgraph = fusion_group->g(attr::Subgraph);
436
437 auto inputs = fusion_group->inputs();
438 auto sinputs = subgraph->inputs();
439 AT_ASSERT(inputs.size() == sinputs.size());
440 for (const auto i : c10::irange(inputs.size())) {
441 if (inputs[i]->type()->isSubtypeOf(*TensorType::get())) {
442 Value* soutput = graph->insert(aten::size, {inputs[i]});
443 aliasDb_->createValue(soutput);
444 GRAPH_DEBUG(
445 "Adding a mapping for %",
446 sinputs[i]->debugName(),
447 " ",
448 getHeader(soutput->node()));
449 shape_of[sinputs[i]] = soutput;
450 }
451 }
452
453 // When we have a guarantee that an output won't be removed, because it's
454 // used in expressions that don't involve size checks, we can use its size
455 // instead of computing a long chain of broadcasts, starting from the
456 // beginning of the kernel.
457 auto outputs = fusion_group->outputs();
458 auto soutputs = subgraph->outputs();
459 AT_ASSERT(outputs.size() == soutputs.size());
460 for (const auto i : c10::irange(outputs.size())) {
461 if (usedOnlyInSize(outputs[i]))
462 continue;
463 Value* soutput = graph->insert(aten::size, {outputs[i]});
464 aliasDb_->createValue(soutput);
465 shape_of[soutputs[i]] = soutput;
466 }
467
468 for (Node* n : subgraph->nodes()) {
469 auto tensor_inputs = filter(n->inputs(), [](Value* v) {
470 return v->type()->isSubtypeOf(*TensorType::get());
471 });
472 GRAPH_DEBUG("Building sizes for ", getHeader(n));
473 bool all_inputs_have_sizes = true;
474 auto shapes = fmap(tensor_inputs, [&](Value* v) {
475 GRAPH_DEBUG("Getting aten::size for %", v->debugName());
476 all_inputs_have_sizes &= shape_of.count(v);
477 return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr;
478 });
479 if (!all_inputs_have_sizes) {
480 GRAPH_DEBUG(
481 "Not all tensor arguments have sizes available to compute the broadcasted size",
482 getHeader(n));
483 continue;
484 }
485
486 if (n->kind() == prim::ConstantChunk) {
487 Node* sizes_node = graph->insertNode(
488 graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
489 sizes_node->i_(attr::dim, n->i(attr::dim));
490 sizes_node->i_(attr::chunks, n->i(attr::chunks));
491 for (Value* output : sizes_node->outputs()) {
492 aliasDb_->createValue(output);
493 }
494 Value* regular_size = sizes_node->outputs().at(0);
495 Value* last_size = sizes_node->outputs().at(1);
496 regular_size->setType(ListType::ofInts());
497 last_size->setType(ListType::ofInts());
498 auto outputs = n->outputs();
499 for (Value* o : outputs.slice(0, outputs.size() - 1)) {
500 shape_of.emplace(o, regular_size);
501 }
502 shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
503 continue;
504 }
505
      // We only support shape calculations for elementwise ops, for some
      // non-elementwise ops like batch_norm, conv2d, and matmul, and for a few
      // exceptions (e.g. prim::ConstantChunk) handled above.
509 if (!(get_tensorexpr_elementwise_set().contains(n)) &&
510 !n->isMemberOf(tensorexpr::supported_non_eltwise_set())) {
511 continue;
512 }
513
514 shape_of.emplace(
515 n->output(),
516 shapes.size() == 1 ? shapes[0]
517 : broadcastSizes(shapes, aliasDb_.get()));
518 }
519 return shape_of;
520 }
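
  // Illustrative sketch (assumed IR, not from the original source): for a
  // fusion group whose two tensor inputs feed a single elementwise op, the
  // map built above corresponds to IR inserted after the group such as
  //
  //   %s1 : int[] = aten::size(%a)
  //   %s2 : int[] = aten::size(%b)
  //   %s3 : int[] = prim::BroadcastSizes(%s1, %s2)
  //
  // where %s3 is the recorded shape of the elementwise output. Size
  // computations that end up unused are expected to be removed by DCE.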
521
  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
523 if (fusion_group->kind() != prim::TensorExprGroup)
524 return;
525 auto subgraph = fusion_group->g(attr::Subgraph);
526
527 auto shape_of = buildShapeExpressions(fusion_group);
528 auto outputs = fusion_group->outputs().vec();
529 auto soutputs = subgraph->outputs().vec();
530 // XXX: Iterating in this order is not only good for performance reasons!
531 // It is also crucial for correctness (i has to reflect the current true
532 // index of outputs[i])!
533 for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
534 auto output = outputs[i];
535 auto soutput = soutputs[i];
536 if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
537 auto uses = output->uses();
538 for (Use u : uses) {
539 AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
540 u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
541 u.user->destroy();
542 }
543 fusion_group->eraseOutput(i);
544 subgraph->eraseOutput(i);
545 }
546 }
547 }
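
  // Illustrative sketch (assumed IR, not from the original source): if a group
  // output %y is consumed only by aten::size,
  //
  //   %y, %z = prim::TensorExprGroup_0(...)
  //   %sy : int[] = aten::size(%y)
  //
  // the aten::size use is rewired to the shape expression computed by
  // buildShapeExpressions and %y is erased from both the fusion group and its
  // subgraph, shrinking the fused kernel's outputs.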
548
  void run() {
550 aliasDb_ = std::make_unique<AliasDb>(graph_);
551 RemoveRedundantProfiles(graph_);
552 GRAPH_DUMP("After removing redundant profile nodes: ", graph_);
553 createFusionGroups(graph_->block());
554 GRAPH_DUMP("After creating fusion groups: ", graph_);
555 // we maintain alias db correctness during initial fusion, but it is
556 // difficult to maintain correctness after inlining so inline only after
557 // fusion is done.
558 inlineSmallFusionGroups(graph_->block());
559 GRAPH_DUMP("After inlining small fusion groups: ", graph_);
560 if (fuse_to_dynamic_shapes_) {
561 VLOG(1) << "TensorExpr fusion with dynamic shapes is enabled" << '\n';
562 generalizeFusionGroups(graph_->block());
563 GRAPH_DUMP("After generalizing fusion groups: ", graph_);
564 } else {
565 prepareFusionGroupAndGuardOutputs(graph_->block());
566 GRAPH_DUMP("After guarding fusion groups: ", graph_);
567 }
568 }
569
570 private:
  Node* getOrCreateTensorExprSubgraph(Node* n) {
572 if (n->hasAttribute(attr::Subgraph) && n->kind() == prim::TensorExprGroup) {
573 return n;
574 }
575 GRAPH_UPDATE("Creating a tensorexpr::Group node from: ", *n);
576 return SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
577 n, prim::TensorExprGroup, *aliasDb_);
578 }
579
  value_list sortReverseTopological(ArrayRef<Value*> inputs, Block* b) {
581 value_list result;
582 for (auto i : inputs) {
583 if (i->node()->owningBlock() == b) {
584 result.push_back(i);
585 }
586 }
587 // Sort in reverse topological order
588 std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
589 return a->node()->isAfter(b->node());
590 });
591 return result;
592 }
593
594 // Create a fusion group starting from the node N.
595 // We then try to pull inputs into the fusion group and repeat that process
596 // until there is nothing we can pull in.
  std::pair<graph_node_list::iterator, bool> createFusionGroup(
598 Node* fusion_node) {
599 // Allow single-node groups containing conv2d, since we'll only select
600 // those in cases where the tensorexpr implementation is faster than the
601 // aten implementation.
602 if (min_group_size_ == 1 || fusion_node->kind() == aten::conv2d) {
603 fusion_node = getOrCreateTensorExprSubgraph(fusion_node);
604 }
605
606 GRAPH_DEBUG("Iteratively pull input nodes into the fusion group...\n");
607 auto inputs = sortReverseTopological(
608 fusion_node->inputs(), fusion_node->owningBlock());
609 for (auto input : inputs) {
610 debugDumpFusionGroup("Current fusion group: ", fusion_node);
611 GRAPH_DEBUG("Trying to merge: ", *input->node());
612 if (auto maybe_fusion_group = tryMerge(fusion_node, input->node())) {
613 // we successfully merged, so the new group's `inputs` may have
614 // changed. So rescan the new group for more merging opportunities.
615 return std::make_pair(
616 maybe_fusion_group.value()->reverseIterator(), true);
617 }
618 }
619
620 return std::make_pair(++fusion_node->reverseIterator(), false);
621 }
622
  static void debugDumpFusionGroup(const std::string& msg, Node* n) {
624 GRAPH_DEBUG(msg, *n);
625 if (n->kind() == prim::TensorExprGroup) {
626 GRAPH_DEBUG(*n->g(attr::Subgraph));
627 }
628 }
629
  // Ops that are no-ops in eager mode shouldn't be outputs of fusion groups
  // because that would degrade perf and change aliasing relationships.
  static bool unexecutedEagerOp(Node* n) {
633 if (n->kind() != aten::to &&
634 n->kind() != aten::_autocast_to_reduced_precision &&
635 n->kind() != aten::_autocast_to_full_precision) {
636 return false;
637 }
638
639 return *n->input(0)->type()->expect<TensorType>() ==
640 *n->output()->type()->expect<TensorType>();
641 }
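
  // Example of a no-op conversion caught above (assumed types, not from the
  // original source): an aten::to whose input and output tensor types are
  // identical, e.g.
  //
  //   %y : Float(2, 2, device=cpu) = aten::to(%x : Float(2, 2, device=cpu), ...)
  //
  // Such a node executes as a no-op in eager mode, so making it a fusion group
  // output would only add an extra materialized tensor and change aliasing.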
642
  std::pair<graph_node_list::iterator, bool> scanNode(Node* n) {
644 GRAPH_DEBUG("Considering node:", *n)
645
646 if (!canHandle(n)) {
647 return std::make_pair(++n->reverseIterator(), false);
648 }
649 // There are some nodes that we can support, but we don't want to start a
650 // fusion group from - skip them.
651 if (n->kind() == prim::ListConstruct || n->kind() == aten::slice ||
652 n->kind() == aten::unsqueeze || n->kind() == prim::ConstantChunk ||
653 n->kind() == prim::Constant || unexecutedEagerOp(n)) {
654 return std::make_pair(++n->reverseIterator(), false);
655 }
656 return createFusionGroup(n);
657 }
658
659 // Merge fusible nodes into subgraphs in prim::TensorExprGroup nodes.
  void createFusionGroups(Block* block) {
661 bool any_changed = true;
662 while (any_changed) {
663 any_changed = false;
664 for (auto it = block->nodes().rbegin(); it != block->nodes().rend();) {
665 auto [tmp_it, changed] = scanNode(*it);
666 it = tmp_it;
667 any_changed |= changed;
668 }
669 }
670
671 for (Node* n : block->nodes()) {
672 for (Block* b : n->blocks()) {
673 createFusionGroups(b);
674 }
675 }
676
    // Try to merge adjacent fusion groups together. Because we have only
    // merged by looking at graph inputs, without this we would not attempt to
    // merge adjacent fusion groups that don't have a dependency on each other.
680
681 std::vector<Node*> initial_fusion_groups;
682 for (Node* n : block->nodes()) {
683 if (n->kind() == prim::TensorExprGroup) {
684 initial_fusion_groups.push_back(n);
685 }
686 }
687
688 Node* prev_fusion_group =
689 !initial_fusion_groups.empty() ? initial_fusion_groups[0] : nullptr;
690
691 for (const auto i : c10::irange(1, initial_fusion_groups.size())) {
692 // Try merging the just created fusion group into the previous one.
693 // If it did not work, then put the previous fusion group into
694 // fusion_groups vector - we will not touch it anymore in this loop.
695 // If merging succeeded, save the merged group as the "previous" fusion
696 // group so that we can try to merge the next one into it.
697
698 Node* fusion_group = initial_fusion_groups[i];
699 debugDumpFusionGroup(
700 "Trying to merge into the previous fusion group: ",
701 prev_fusion_group);
702 if (auto merged_fusion_group =
703 tryMerge(prev_fusion_group, fusion_group)) {
704 prev_fusion_group = *merged_fusion_group;
705 debugDumpFusionGroup(
706 "Successfully merged into the previous fusion group: ",
707 prev_fusion_group);
708 } else {
709 GRAPH_DEBUG("Cannot merge into the previous fusion group");
710 prev_fusion_group = fusion_group;
711 }
712 }
713 }
714
  size_t blockSize(Block* block) {
716 size_t num = 0;
717 for (Node* n : block->nodes()) {
718 // Don't count prim::Constants and prim::ListConstructs as these are nodes
719 // we only pull in along with another, "main", node. E.g. the
720 // ListConstruct nodes would also be pulled into a fusion group if they
721 // are inputs of an aten::cat node.
722 if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) {
723 continue;
724 }
725 for (Block* b : n->blocks()) {
726 num += blockSize(b);
727 }
728 num++;
729 }
730 return num;
731 }
732
  bool hasConv(Block* block) {
734 for (Node* n : block->nodes()) {
735 if (n->kind() == aten::conv2d) {
736 return true;
737 }
738 }
739 return false;
740 }
741
  bool inlineIfTooSmall(Node* n) {
743 if (n->kind() != prim::TensorExprGroup) {
744 return false;
745 }
746 auto subgraph = SubgraphUtils::getSubgraph(n);
747 size_t num_nodes = blockSize(subgraph->block());
748 // Allow small subgraphs containing conv2d, since we'll only select those
749 // in cases where the tensorexpr implementation is faster than the aten
750 // implementation.
751 if (num_nodes < min_group_size_ && !hasConv(subgraph->block())) {
752 GRAPH_UPDATE("Fusion group is too small, unmerging: ", *n);
753 SubgraphUtils::unmergeSubgraph(n);
754 return true;
755 }
756 // Cleanup the subgraph from duplicated constants while we're at it.
757 ConstantPooling(subgraph);
758
759 if (GRAPH_DEBUG_ENABLED) {
760 GRAPH_EXPORT("", subgraph);
761 }
762 return false;
763 }
764
  void inlineSmallFusionGroups(Block* block) {
766 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
767 Node* n = *it;
768 it++;
769
770 for (Block* b : n->blocks()) {
771 inlineSmallFusionGroups(b);
772 }
773 inlineIfTooSmall(n);
774 }
775 }
776
  std::optional<Node*> tryMerge(Node* fusion_group, Node* to_merge) {
778 if (!canMerge(fusion_group, to_merge)) {
779 return std::nullopt;
780 }
781
782 std::vector<Node*> nodes_to_merge = {to_merge};
783
784 if (to_merge->kind() == aten::cat) {
785 Node* listconstruct = to_merge->input(0)->node();
786 nodes_to_merge.push_back(listconstruct);
787 }
788
789 // First, try to move all the nodes we want to fuse next to the fusion
790 // group.
791 Node* move_point = fusion_group;
792 for (auto n : nodes_to_merge) {
793 GRAPH_UPDATE("Trying to move node next to fusion group: ", getHeader(n));
794 if (!aliasDb_->moveBeforeTopologicallyValid(n, move_point)) {
795 GRAPH_UPDATE("Failed to move because of AliasDB checks!");
796 return std::nullopt;
797 }
798 move_point = n;
799 }
800
801 // Now all the nodes that we're going to fuse are moved next to the fusion
802 // group, so we can safely merge them into the fusion group subgraph.
803 fusion_group = getOrCreateTensorExprSubgraph(fusion_group);
804
805 for (auto n : nodes_to_merge) {
806 GRAPH_UPDATE("Merging ", getHeader(n));
807 SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
808 n, fusion_group, *aliasDb_);
809 }
810 return fusion_group;
811 }
812
  bool shapeIsKnown(Value* v) {
814 if (v->type()->cast<TensorType>()) {
815 if (!v->isCompleteTensor()) {
816 return false;
817 }
818 }
819 return true;
820 }
821
  bool allShapesAreKnown(Node* node) {
823 // TODO: Relax the checks to support dynamic shapes
824 for (Value* input : node->inputs()) {
825 if (!shapeIsKnown(input)) {
826 return false;
827 }
828 if (input->node()->kind() == prim::ListConstruct) {
829 if (!allShapesAreKnown(input->node())) {
830 return false;
831 }
832 }
833 }
834 for (Value* output : node->outputs()) {
835 if (!shapeIsKnown(output)) {
836 return false;
837 }
838 }
839 return true;
840 }
841
  bool canFuseOnDevice(Value* v) {
843 auto type = v->type()->cast<TensorType>();
844 if (!type) {
845 return true;
846 }
847 auto device = type->device();
848 if (!device) {
849 return false;
850 }
851 if (device->is_cpu()) {
852 return canFuseOnCPU();
853 } else if (device->is_cuda()) {
854 return canFuseOnGPU();
855 } else if (device->is_xpu()) {
856 return false;
857 }
858 return false;
859 }
860
  bool isFusableOnDevice(Node* node) {
862 for (const auto& input : node->inputs()) {
863 if (input->node()->kind() == prim::ListConstruct) {
864 if (!isFusableOnDevice(input->node())) {
865 return false;
866 }
867 }
868 if (!canFuseOnDevice(input)) {
869 return false;
870 }
871 }
872 return true;
873 }
874
  bool typesAreSupported(Node* node) {
876 // clang-format off
    // clang-format would break up the schema strings so they would no longer
    // be discoverable with ctrl-F
878 static const OperatorSet float_only_operator_set{
879 "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor",
880 "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor",
881 "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor",
882 "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor",
883 };
884 static const OperatorSet int_only_operator_set{
885 "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor",
886 "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor",
887 "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor",
888 "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor",
889 };
890 static const OperatorSet cpu_compute_heavy_set{
891 "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
892 "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor",
893 "aten::matmul(Tensor self, Tensor other) -> Tensor",
894 };
895 static const OperatorSet gpu_only_operator_set{
896 // On CPU, these are slower and less accurate than ATen kernels, because
897 // ATen is able to use MKL-VML, whereas the fuser currently can't. The
898 // fuser uses sleef instead because sleef provides functions that operate
899 // on vectors, instead of large buffers.
900 "aten::erf(Tensor self) -> Tensor",
901 "aten::erfc(Tensor self) -> Tensor",
902 };
903 static const OperatorSet pow{
904 "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor",
905 };
906 // clang-format on
907
908 // Check types of input values.
909 for (const Value* v : node->inputs()) {
910 if (auto const& tt = v->type()->cast<TensorType>()) {
911 auto const& st = tt->scalarType();
912 auto const& device = tt->device();
913
914 // All tensors must be typed.
915 if (!st || !device) {
916 return false;
917 }
918
919 // Byte tensors introduce too many corner cases in type promotion.
920 // Better not to try to handle them.
921 if (*st == c10::ScalarType::Byte) {
922 return false;
923 }
924
925 // Float16 support has some issues (see e.g. #61336 and #61382), so for
926 // now it's disabled. There seem to be some problems in HalfRewriter,
927 // but on top of that Float16 has a few kinks on LLVM. Thus, on CPU we
928 // additionally disable it until we either move to a more stable version
929 // or find workarounds.
930 if (*st == c10::ScalarType::Half && *device == c10::kCPU) {
931 return false;
932 }
933
934 if (*st == c10::ScalarType::BFloat16 && *device == c10::kCPU) {
935 #ifndef TORCH_ENABLE_LLVM
936 return false;
937 #endif
938 }
939
940 // These operators only support floats, because integer divisors need to
941 // raise ZeroDivisionError.
942 if (node->isMemberOf(float_only_operator_set) && !isFloatingType(*st)) {
943 return false;
944 }
945
946 // These operators have complicated casting rules for floats.
947 if (node->isMemberOf(int_only_operator_set) && isFloatingType(*st)) {
948 return false;
949 }
950 } else if (node->isMemberOf(float_only_operator_set)) {
951 // Check scalar operands of float-only ops.
952 if (!v->type()->cast<FloatType>()) {
953 return false;
954 }
955 } else if (node->isMemberOf(int_only_operator_set)) {
956 if (!v->type()->cast<IntType>()) {
957 return false;
958 }
959 }
960 }
961
962 // aten::pow has special rules to avoid complicated integer cases. We
963 // expect the first arg to be a floating point tensor, and if that's the
964 // case the type of the scalar exponent doesn't matter.
965 if (node->isMemberOf(pow)) {
966 auto const& tt = node->input(0)->type()->cast<TensorType>();
967 if (!tt) {
968 return false;
969 }
970 auto const& st = tt->scalarType();
971 if (!st || !isFloatingType(*st)) {
972 return false;
973 }
974 }
975
976 // Operator is only supported on CPU.
977 if (node->isMemberOf(cpu_compute_heavy_set)) {
978 if (fuse_to_dynamic_shapes_) {
979 return false;
980 }
981
982 auto device = tensorexpr::pickDeviceType(node->inputs());
983 if (!device) {
984 device = tensorexpr::pickDeviceType(node->outputs());
985 }
986 if (!device || !device->is_cpu()) {
987 return false;
988 }
989 }
990
991 // Operator is only supported on GPU.
992 if (node->isMemberOf(gpu_only_operator_set)) {
993 auto device = tensorexpr::pickDeviceType(node->inputs());
994 if (!device) {
995 device = tensorexpr::pickDeviceType(node->outputs());
996 }
997 if (!device || !device->is_cuda()) {
998 return false;
999 }
1000 }
1001
1002 if (node->kind() == aten::to) {
1003 // only support same-device conversion
1004 auto device = tensorexpr::pickDeviceType(node->inputs());
1005 auto output_device = tensorexpr::pickDeviceType(node->outputs());
1006 if (!device || !output_device || *device != *output_device) {
1007 return false;
1008 }
      // non_blocking only applies in cross-device conversion, which we bail
      // on; the copy arg only applies if the op is a no-op, which we don't
      // start a fusion group from; memory format is handled separately in the
      // NNC output.
1012
1013 // all non-Tensor arguments must be constant
1014 for (size_t i = 1; i < node->inputs().size(); i++) {
1015 if (node->inputs().at(i)->node()->kind() != prim::Constant) {
1016 return false;
1017 }
1018 }
1019
1020 if (has_unsupported_pin_memory(node)) {
1021 return false;
1022 }
1023 }
1024
1025 if (node->kind() == aten::_autocast_to_reduced_precision ||
1026 node->kind() == aten::_autocast_to_full_precision) {
1027 for (auto i : c10::irange(1, node->inputs().size())) {
1028 if (node->inputs().at(i)->node()->kind() != prim::Constant) {
1029 return false;
1030 }
1031 }
1032
1033 bool is_reduced_precision =
1034 node->kind() == aten::_autocast_to_reduced_precision;
1035 bool is_full_precision =
1036 node->kind() == aten::_autocast_to_full_precision;
1037 auto self_tensor = node->inputs()[0]; // input tensor
1038
1039 if (auto const& tt = self_tensor->type()->cast<TensorType>()) {
1040 auto st = tt->scalarType();
1041 if (!st.has_value()) {
1042 return false;
1043 }
1044
1045 auto device = tt->device();
1046 if (!device.has_value()) {
1047 return false;
1048 }
1049
1050 bool is_cpu = device->is_cpu();
1051
1052 if (*st != at::kFloat && is_reduced_precision && is_cpu) {
1053 // Regarding CPU, aten would do nothing if the data type is
1054 // float. Then the aten performance is better than NNC. So NNC
1055 // does not pull it into its fusion group.
1056 return false;
1057 }
1058
1059 if (*st != at::kBFloat16 && is_full_precision && is_cpu) {
1060 // Regarding CPU, aten would do nothing if the data type is
1061 // BFloat16. Then the aten performance is better than NNC. So NNC
1062 // does not pull it into its fusion group.
1063 return false;
1064 }
1065 }
1066
1067 if (has_unsupported_pin_memory(node)) {
1068 return false;
1069 }
1070 }
1071
1072 if (node->kind() == aten::unsqueeze) {
1073 // `dim` argument must be a constant.
1074 if (node->input(1)->node()->kind() != prim::Constant) {
1075 return false;
1076 }
1077 }
1078
1079 if (node->kind() == aten::_convolution && !tensorexpr::isConv2d(node)) {
1080 GRAPH_DEBUG("This aten::_convolution node is not a 2D conv");
1081 return false;
1082 }
1083 if (node->kind() == aten::_convolution || node->kind() == aten::conv2d) {
1084 if (!tensorexpr::conv2dIsSupportedJit(node) &&
1085 !tensorexpr::mkldnnPrepackedConvIsSupportedJit(node)) {
1086 GRAPH_DEBUG("Params of conv2d are not supported");
1087 return false;
1088 }
1089 }
1090 if (node->kind() == aten::matmul) {
1091 if (!tensorexpr::matmulIsSupported(node)) {
1092 GRAPH_DEBUG("Shapes of matmul inputs are not supported");
1093 return false;
1094 }
1095 }
1096 return true;
1097 }
1098
1099 #define REQ(cond) \
1100 if (!(cond)) { \
1101 GRAPH_DEBUG("Failed cond " #cond "\n"); \
1102 return false; \
1103 }
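
  // For reference, REQ(cond) expands to an early-return-with-logging check;
  // e.g. REQ(allShapesAreKnown(node)) becomes:
  //
  //   if (!(allShapesAreKnown(node))) {
  //     GRAPH_DEBUG("Failed cond allShapesAreKnown(node)\n");
  //     return false;
  //   }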
1104
  bool canHandle(Node* node) {
1106 REQ(allShapesAreKnown(node));
1107 REQ(isFusableOnDevice(node));
1108 REQ(operators_not_to_fuse.find(node->kind()) ==
1109 operators_not_to_fuse.end());
1110
1111 for (Value* input : node->inputs()) {
1112 if (auto const& tt = input->type()->cast<TensorType>()) {
1113 auto st = tt->scalarType();
1114 if (!st) {
1115 // All tensor types should be known.
1116 return false;
1117 }
1118 if (c10::isComplexType(*st) || c10::isQIntType(*st)) {
1119 return false;
1120 }
1121 }
1122 }
1123 if (node->kind() == aten::cat) {
1124 REQ(node->input(0)->node()->kind() == prim::ListConstruct);
1125 REQ(node->input(0)->uses().size() == 1);
1126 REQ(node->input(1)->node()->kind() == prim::Constant);
1127 auto const& listconstruct = node->input(0)->node();
1128 REQ(tensorexpr::pickDeviceType(listconstruct->inputs()));
1129 } else {
1130 REQ(tensorexpr::pickDeviceType(node->inputs()));
1131 }
1132
1133 // Only fuse aten::batch_norm when the parameter 'training' is false
1134 if (node->kind() == aten::batch_norm) {
1135 REQ(node->input(5)->node()->kind() == prim::Constant);
1136 REQ(!toIValue(node->input(5)).value().toBool());
1137 }
1138
1139 REQ(tensorexpr::isSupported(node));
1140 REQ(typesAreSupported(node));
1141
    // A hook into the optimization limiter, to allow bisecting the pass.
1143 REQ(JIT_OPT_ALLOWED);
1144
1145 if (fuse_to_dynamic_shapes_) {
1146 // Allow only if the node has a shape function defined.
1147 // ListConstruct node is an exception since that is needed to fuse
1148 // aten::cat, though it does not have a shape function.
1149 REQ(node->kind() == prim::ListConstruct ||
1150 node->kind() == prim::TensorExprGroup ||
1151 node->isMemberOf(tensorexpr::getCustomOperatorSet()) ||
1152 (node->maybeSchema() && shapeComputeGraphForSchema(node->schema())));
1153 }
1154
1155 return true;
1156 }
1157
  bool canMerge(Node* consumer, Node* producer) {
1159 // Only fuse within a block
1160 REQ(consumer->owningBlock() == producer->owningBlock());
1161
1162 // Symbolic checks
1163 REQ(canHandle(producer) || producer->kind() == prim::TensorExprGroup);
1164 TORCH_INTERNAL_ASSERT(
1165 consumer->kind() == prim::TensorExprGroup || canHandle(consumer));
1166
1167 // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
1168 // The specific limit is a function of constant memory size, amount
1169 // available to pass arguments, and some implementation dependence. Select a
1170 // safe limit here.
1171 constexpr size_t subgraphArgLimit = 128;
1172 auto const nInputs = consumer->inputs().size() +
1173 consumer->outputs().size() + producer->inputs().size() +
1174 producer->outputs().size();
1175 REQ(nInputs <= subgraphArgLimit);
1176
1177 // Device checks
1178 if (consumer->kind() != aten::cat && producer->kind() != aten::cat) {
      // aten::cat needs special handling because it takes a Tensor[] as its
      // input. We deal with that in the code below.
1181 auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
1182 REQ(consumer_device);
1183 auto producer_device = tensorexpr::pickDeviceType(producer->inputs());
1184 REQ(producer_device);
1185 REQ(*consumer_device == *producer_device);
1186 }
1187
1188 // Alias checks
1189 REQ(aliasDb_->couldMoveBeforeTopologically(producer, consumer));
1190
1191 // Ops that return aliases can only be folded if this is the only use.
1192 if (producer->kind() == aten::slice ||
1193 producer->kind() == aten::unsqueeze ||
1194 producer->kind() == prim::ConstantChunk) {
1195 for (auto& use : producer->output(0)->uses()) {
1196 REQ(use.user == consumer);
1197 }
1198 }
1199
1200 if (!consumer->hasAttribute(attr::Subgraph) &&
1201 consumer->kind() != prim::TensorExprGroup) {
1202 // Don't initiate a fusion group from prim::ListConstruct
1203 REQ(consumer->kind() != prim::ListConstruct);
1204 REQ(consumer->kind() != aten::slice);
1205 REQ(consumer->kind() != aten::unsqueeze);
1206 REQ(consumer->kind() != prim::ConstantChunk);
1207
1208 // Don't initiate a fusion group just for a constant operand
1209 REQ(producer->kind() != prim::Constant);
1210 }
1211
1212 if (producer->kind() == aten::cat) {
1213 REQ(producer->input(0)->node()->kind() == prim::ListConstruct);
1214 REQ(producer->input(0)->uses().size() == 1);
1215 REQ(producer->input(1)->node()->kind() == prim::Constant);
1216 auto const& listConstruct = producer->input(0)->node();
1217 // We're merging listconstruct->cat->consumer. cat is the producer here
1218 // and we cannot determine its device type - we should use device of the
1219 // listconstruct instead
1220 auto listconstruct_device =
1221 tensorexpr::pickDeviceType(listConstruct->inputs());
1222 auto consumer_device = tensorexpr::pickDeviceType(consumer->inputs());
1223 REQ(listconstruct_device);
1224 REQ(consumer_device);
1225 REQ(*listconstruct_device == *consumer_device);
1226 for (auto const& input : listConstruct->inputs()) {
1227 REQ(isFusableOnDevice(input->node()));
1228 }
1229 REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
1230 } else if (consumer->kind() == aten::cat) {
1231 REQ(consumer->input(0)->node()->kind() == prim::ListConstruct);
1232 REQ(consumer->input(0)->uses().size() == 1);
1233 REQ(consumer->input(1)->node()->kind() == prim::Constant);
1234 auto const& listConstruct = consumer->input(0)->node();
      // We're merging listconstruct->cat. cat is the consumer and
      // listconstruct is the producer. cat doesn't have its own device type
      // and thus the only thing we should check is that listconstruct's
      // device is well defined (e.g. all its inputs have the same device).
1239 auto listconstruct_device =
1240 tensorexpr::pickDeviceType(listConstruct->inputs());
1241 REQ(listconstruct_device);
1242 REQ((nInputs + listConstruct->inputs().size()) <= subgraphArgLimit);
1243 } else {
1244 REQ(isFusableOnDevice(producer));
1245 }
1246
1247 return true;
1248 }
1249 #undef REQ
1250
  void prepareFusionGroupAndGuardOutputs(Block* block) {
1252 std::vector<Node*> fusion_groups;
1253 for (Node* n : block->nodes()) {
1254 for (Block* b : n->blocks()) {
1255 prepareFusionGroupAndGuardOutputs(b);
1256 }
1257 if (n->kind() == prim::TensorExprGroup) {
1258 fusion_groups.push_back(n);
1259 }
1260 }
1261 for (Node* fusion_group : fusion_groups) {
1262 removeOutputsUsedOnlyInSize(fusion_group);
1263 insertTypeGuard(
1264 fusion_group,
1265 [](const TensorTypePtr& t) { return t; },
1266 prim::TypeCheck);
1267 }
1268 }
1269
  void generalizeFusionGroups(Block* block) {
1271 std::vector<Node*> fusion_groups;
1272 for (Node* n : block->nodes()) {
1273 for (Block* b : n->blocks()) {
1274 generalizeFusionGroups(b);
1275 }
1276 if (n->kind() == prim::TensorExprGroup) {
1277 fusion_groups.push_back(n);
1278 }
1279 }
1280 for (Node* fusion_group : fusion_groups) {
1281 removeOutputsUsedOnlyInSize(fusion_group);
1282 VLOG(1) << "GenerateGuard for fusion group: " << *fusion_group;
1283 if (!GenerateGuard(fusion_group, add_composed_op_)) {
1284 VLOG(1) << " Unfusing the fusion group because GenerateGuard failed"
1285 << '\n';
1286 SubgraphUtils::unmergeSubgraph(fusion_group);
1287 }
1288 }
1289 }
1290
1291 // This function parses the option provided by the environment variable
1292 // "PYTORCH_TENSOREXPR_DONT_FUSE".
1293 // This variable allows users to disable fusion on a list of specified
1294 // operators that are separated by ':'. e.g.,
1295 // 'PYTORCH_TENSOREXPR_DONT_FUSE="clamp:mul:add"' disables fusion on
1296 // aten::clamp, aten::mul and aten::add.
  void parseTENotFuseOption() {
1298 const char* option = std::getenv("PYTORCH_TENSOREXPR_DONT_FUSE");
1299 std::stringstream in_ss;
1300 if (option) {
1301 in_ss << option;
1302 }
1303
1304 std::string line;
1305 while (std::getline(in_ss, line, ':')) {
1306 if (line.empty()) {
1307 continue;
1308 }
1309 operators_not_to_fuse.insert(c10::Symbol::aten(line));
1310 }
1311 }
1312
1313 std::shared_ptr<Graph> graph_;
1314 std::unique_ptr<AliasDb> aliasDb_ = nullptr;
1315
1316 std::set<NodeKind> operators_not_to_fuse;
1317 // Minimal size of a fusion group
1318 size_t min_group_size_;
1319 // compose Runtime Type Guard and Kernel in one op
1320 bool add_composed_op_;
1321 // generalize static shapes to dynamic shapes
1322 bool fuse_to_dynamic_shapes_;
1323 };
1324
void FuseTensorExprs(
1326 std::shared_ptr<Graph>& graph,
1327 size_t min_group_size,
1328 bool add_composed_op,
1329 bool fuse_to_dynamic_shapes) {
1330 GRAPH_DUMP("Before TExprFuser: ", graph);
1331
1332 // Temporary change for Block code generation.
1333 if (tensorexpr::getTEGenerateBlockCode()) {
1334 min_group_size = 1;
1335 }
1336
1337 if (add_composed_op) {
1338 TORCH_INTERNAL_ASSERT(
1339 fuse_to_dynamic_shapes, "Fusing static shapes with composed op NYI");
1340 }
1341
1342 // Get rid of dead code so that we don't waste effort fusing it.
1343 EliminateDeadCode(graph);
1344
1345 TensorExprFuser fuser(
1346 graph, min_group_size, add_composed_op, fuse_to_dynamic_shapes);
1347 fuser.run();
1348
1349 EliminateCommonSubexpression(graph);
1350 EliminateDeadCode(graph);
1351
1352 GRAPH_DUMP("After TExprFuser: ", graph);
1353 }
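
// Minimal usage sketch (assumed call site and example argument values, not
// from the original source):
//
//   std::shared_ptr<Graph> graph = ...;  // a profiled, type-specialized graph
//   FuseTensorExprs(
//       graph,
//       /*min_group_size=*/2,
//       /*add_composed_op=*/false,
//       /*fuse_to_dynamic_shapes=*/false);
//
// Note that add_composed_op == true is only valid together with
// fuse_to_dynamic_shapes == true, as asserted above.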
1354
static Operation createTensorExprOp(const Node* node) {
1356 bool dynamic_shape_fusion_node =
1357 node->hasAttribute(attr::striding_inputs_desc);
1358 if (!dynamic_shape_fusion_node) {
1359 auto kernel =
1360 std::make_shared<tensorexpr::TensorExprKernel>(node->g(attr::Subgraph));
1361 return [kernel](Stack& stack) {
1362 RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());
1363 kernel->run(stack);
1364 return 0;
1365 };
1366 }
1367
1368 // Handle the case when dynamic shape fusion is enabled.
1369 VLOG(1) << "Compiling a new kernel for " << *node;
1370 std::vector<int64_t> sym_shapes;
1371 if (node->hasAttribute(attr::symbolic_shape_inputs)) {
1372 sym_shapes = node->is(attr::symbolic_shape_inputs);
1373 }
1374 bool allow_stack_outputs = false;
1375 if (node->hasAttribute(attr::allow_stack_outputs)) {
1376 allow_stack_outputs = node->i(attr::allow_stack_outputs) == 1;
1377 }
1378
1379 std::unordered_map<c10::Symbol, tensorexpr::NNCLoweringFunction>
1380 custom_lowerings;
1381 auto subgraph = node->g(attr::Subgraph);
1382 IValue sym_strides = node->ival(attr::striding_inputs_desc);
1383
  // The striding descriptor is serialized on the node as a vector of vectors
  // of strings; translate it back to the StrideInput enum.
1386 std::vector<std::vector<std::string>> sym_strides_strs =
1387 sym_strides.to<std::vector<std::vector<std::string>>>();
1388 std::vector<std::vector<StrideInput>> striding_inputs;
1389 for (const auto& vec : sym_strides_strs) {
1390 std::vector<StrideInput> input_desc;
1391 input_desc.reserve(vec.size());
1392 for (const std::string& str : vec) {
1393 input_desc.push_back(strideInputFromString(str));
1394 }
1395 striding_inputs.push_back(input_desc);
1396 }
1397 std::unordered_map<const Value*, std::vector<StrideInput>> stride_map;
1398 size_t index = 0;
1399 for (Value* v : subgraph->inputs()) {
1400 if (!v->type()->cast<TensorType>()) {
1401 continue;
1402 }
1403 stride_map[v] = striding_inputs[index];
1404 index++;
1405 }
1406 std::vector<std::string> output_desc =
1407 node->ival(attr::striding_outputs_desc).to<std::vector<std::string>>();
1408 for (size_t i = 0; i < subgraph->outputs().size(); ++i) {
1409 stride_map[subgraph->outputs().at(i)] = {
1410 strideInputFromString(output_desc.at(i))};
1411 }
1412
1413 std::shared_ptr<tensorexpr::TensorExprKernel> kernel =
1414 std::make_shared<tensorexpr::TensorExprKernel>(
1415 subgraph,
1416 custom_lowerings,
1417 sym_shapes,
1418 /*pre_alloc*/ false,
1419 stride_map);
1420
1421 auto num_subgraph_inputs = subgraph->inputs().size();
1422 return [kernel, num_subgraph_inputs, allow_stack_outputs](Stack& stack) {
1423 RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());
1424
1425 // Stack contents:
1426 // [<outputs>] <inputs>
1427 //
    // If the number of graph inputs is the same as the stack size, then no
1429 // outputs are being passed in. Otherwise, output tensors are passed in
1430 // at the bottom of the stack. So, we call the appropriate run function
1431 // in TensorExprKernel.
1432 if (num_subgraph_inputs == stack.size() || !allow_stack_outputs) {
1433 kernel->run(stack);
1434 } else {
1435 kernel->runWithAllocatedOutputs(stack);
1436 }
1437 return 0;
1438 };
1439 }
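
// Illustrative stack layouts for the dynamic-shape path above (assumed
// example, not from the original source), for a subgraph with 2 inputs and
// allow_stack_outputs == true:
//
//   [in1, in2]        -> stack.size() == num_subgraph_inputs, so
//                        kernel->run(stack) is called
//   [out1, in1, in2]  -> preallocated outputs sit at the bottom, so
//                        kernel->runWithAllocatedOutputs(stack) is called
//
// This mirrors the "[<outputs>] <inputs>" convention described above.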
1440
1441 RegisterOperators TensorExprOps({
1442 torch::jit::Operator(
1443 prim::TensorExprGroup,
1444 createTensorExprOp,
1445 AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
1446 });
1447
1448 } // namespace torch::jit
1449