xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/check_strict_fusion.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 
2 #include <torch/csrc/jit/passes/check_strict_fusion.h>
3 
4 #include <c10/util/Exception.h>
5 #include <torch/csrc/jit/frontend/error_report.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/quantization/helper.h>
9 #include <torch/csrc/jit/runtime/graph_iterator.h>
10 
11 namespace torch::jit {
12 
13 namespace {
14 
isStrictFusion(Value * value)15 bool isStrictFusion(Value* value) {
16   const auto class_name = getModuleName(value);
17   return class_name.has_value() &&
18       (*class_name == "__torch__.torch.jit.strict_fusion");
19 }
20 
21 } // namespace
22 
fusionGuardCheck(Symbol k)23 static bool fusionGuardCheck(Symbol k) {
24   return k == Symbol::prim("TensorExprDynamicGuard") || k == prim::TypeCheck ||
25       k == prim::CudaFusionGuard || k == prim::RequiresGradCheck;
26 }
27 
collectValuesUsedInGuard(Node * guarding_if,Node * enter_node)28 static std::unordered_set<Node*> collectValuesUsedInGuard(
29     Node* guarding_if,
30     Node* enter_node) {
31   // DFS to collect
32   std::unordered_set<Node*> visited_nodes;
33   std::vector<Node*> queue = {guarding_if};
34 
35   while (!queue.empty()) {
36     Node* curr = queue[queue.size() - 1];
37     queue.pop_back();
38     visited_nodes.insert(curr);
39     // these nodes directly test Tensor inputs, and are not part of additional
40     // guards inserted
41     if (fusionGuardCheck(curr->kind())) {
42       continue;
43     }
44     for (Value* v : curr->inputs()) {
45       Node* inp_node = v->node();
46       if (inp_node->isBefore(enter_node) ||
47           inp_node->owningBlock() != enter_node->owningBlock()) {
48         continue;
49       }
50       if (visited_nodes.count(inp_node)) {
51         continue;
52       }
53       queue.push_back(inp_node);
54     }
55   }
56   return visited_nodes;
57 }
58 
checkForUnfusedOps(Node * enter_node)59 static void checkForUnfusedOps(Node* enter_node) {
60   std::vector<Node*> unsupported_nodes;
61   std::vector<Node*> guarding_ifs; // if multiple, we will throw
62   for (Node* node = enter_node->next(); node->kind() != prim::Exit;
63        node = node->next()) {
64     if (node->kind() == prim::If &&
65         fusionGuardCheck(node->input()->node()->kind())) {
66       guarding_ifs.push_back(node);
67       continue;
68     }
69     unsupported_nodes.push_back(node);
70   }
71 
72   if (guarding_ifs.size() > 1) {
73     std::stringstream ss;
74     ss << "Found multiple fusions: \n";
75     for (Node* n : guarding_ifs) {
76       ss << *n << "\n";
77     }
78     throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
79   }
80 
81   // autodiff/nnc both insert a number of guards, see
82   // `CudaFusionViewGuard Example Graph`
83   // to check for unfused nodes, look at node's whose outputs
84   // are not depended on by the fusion guard
85   // restrict search for all values after the first
86   // node in the prim::Enter block
87 
88   std::unordered_set<Node*> guarding_check_nodes;
89   if (guarding_ifs.size() == 1) {
90     guarding_check_nodes =
91         collectValuesUsedInGuard(guarding_ifs[0], enter_node);
92   }
93   std::vector<Node*> unfused_nodes_not_used_in_guard;
94   for (Node* unfused : unsupported_nodes) {
95     if (!guarding_check_nodes.count(unfused)) {
96       unfused_nodes_not_used_in_guard.push_back(unfused);
97     }
98   }
99   if (!unfused_nodes_not_used_in_guard.empty()) {
100     std::stringstream ss;
101     ss << "Found unfused operators: \n";
102     for (Node* unfused : unfused_nodes_not_used_in_guard) {
103       ss << "\t";
104       if (unfused->maybeSchema()) {
105         ss << unfused->schema();
106       } else {
107         unfused->kind().toDisplayString();
108       }
109       ss << "\n";
110     }
111     throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str());
112   }
113 }
114 
CheckStrictFusion(std::shared_ptr<Graph> & graph)115 void CheckStrictFusion(std::shared_ptr<Graph>& graph) {
116   DepthFirstGraphNodeIterator it(graph);
117   Node* n = nullptr;
118   while ((n = it.next()) != nullptr) {
119     if (n->kind() == prim::Enter && isStrictFusion(n->input())) {
120       checkForUnfusedOps(n);
121     }
122   }
123 
124   // TODO: remove context manager after checks
125   // TODO: improve control flow not taken, right now always errors
126 }
127 
128 } // namespace torch::jit
129