1 2 #include <torch/csrc/jit/passes/check_strict_fusion.h> 3 4 #include <c10/util/Exception.h> 5 #include <torch/csrc/jit/frontend/error_report.h> 6 #include <torch/csrc/jit/ir/ir.h> 7 #include <torch/csrc/jit/jit_log.h> 8 #include <torch/csrc/jit/passes/quantization/helper.h> 9 #include <torch/csrc/jit/runtime/graph_iterator.h> 10 11 namespace torch::jit { 12 13 namespace { 14 isStrictFusion(Value * value)15bool isStrictFusion(Value* value) { 16 const auto class_name = getModuleName(value); 17 return class_name.has_value() && 18 (*class_name == "__torch__.torch.jit.strict_fusion"); 19 } 20 21 } // namespace 22 fusionGuardCheck(Symbol k)23static bool fusionGuardCheck(Symbol k) { 24 return k == Symbol::prim("TensorExprDynamicGuard") || k == prim::TypeCheck || 25 k == prim::CudaFusionGuard || k == prim::RequiresGradCheck; 26 } 27 collectValuesUsedInGuard(Node * guarding_if,Node * enter_node)28static std::unordered_set<Node*> collectValuesUsedInGuard( 29 Node* guarding_if, 30 Node* enter_node) { 31 // DFS to collect 32 std::unordered_set<Node*> visited_nodes; 33 std::vector<Node*> queue = {guarding_if}; 34 35 while (!queue.empty()) { 36 Node* curr = queue[queue.size() - 1]; 37 queue.pop_back(); 38 visited_nodes.insert(curr); 39 // these nodes directly test Tensor inputs, and are not part of additional 40 // guards inserted 41 if (fusionGuardCheck(curr->kind())) { 42 continue; 43 } 44 for (Value* v : curr->inputs()) { 45 Node* inp_node = v->node(); 46 if (inp_node->isBefore(enter_node) || 47 inp_node->owningBlock() != enter_node->owningBlock()) { 48 continue; 49 } 50 if (visited_nodes.count(inp_node)) { 51 continue; 52 } 53 queue.push_back(inp_node); 54 } 55 } 56 return visited_nodes; 57 } 58 checkForUnfusedOps(Node * enter_node)59static void checkForUnfusedOps(Node* enter_node) { 60 std::vector<Node*> unsupported_nodes; 61 std::vector<Node*> guarding_ifs; // if multiple, we will throw 62 for (Node* node = enter_node->next(); node->kind() != prim::Exit; 63 node = node->next()) { 64 if (node->kind() == prim::If && 65 fusionGuardCheck(node->input()->node()->kind())) { 66 guarding_ifs.push_back(node); 67 continue; 68 } 69 unsupported_nodes.push_back(node); 70 } 71 72 if (guarding_ifs.size() > 1) { 73 std::stringstream ss; 74 ss << "Found multiple fusions: \n"; 75 for (Node* n : guarding_ifs) { 76 ss << *n << "\n"; 77 } 78 throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str()); 79 } 80 81 // autodiff/nnc both insert a number of guards, see 82 // `CudaFusionViewGuard Example Graph` 83 // to check for unfused nodes, look at node's whose outputs 84 // are not depended on by the fusion guard 85 // restrict search for all values after the first 86 // node in the prim::Enter block 87 88 std::unordered_set<Node*> guarding_check_nodes; 89 if (guarding_ifs.size() == 1) { 90 guarding_check_nodes = 91 collectValuesUsedInGuard(guarding_ifs[0], enter_node); 92 } 93 std::vector<Node*> unfused_nodes_not_used_in_guard; 94 for (Node* unfused : unsupported_nodes) { 95 if (!guarding_check_nodes.count(unfused)) { 96 unfused_nodes_not_used_in_guard.push_back(unfused); 97 } 98 } 99 if (!unfused_nodes_not_used_in_guard.empty()) { 100 std::stringstream ss; 101 ss << "Found unfused operators: \n"; 102 for (Node* unfused : unfused_nodes_not_used_in_guard) { 103 ss << "\t"; 104 if (unfused->maybeSchema()) { 105 ss << unfused->schema(); 106 } else { 107 unfused->kind().toDisplayString(); 108 } 109 ss << "\n"; 110 } 111 throw(ErrorReport(enter_node->input()->node()->sourceRange()) << ss.str()); 112 } 113 } 114 CheckStrictFusion(std::shared_ptr<Graph> & graph)115void CheckStrictFusion(std::shared_ptr<Graph>& graph) { 116 DepthFirstGraphNodeIterator it(graph); 117 Node* n = nullptr; 118 while ((n = it.next()) != nullptr) { 119 if (n->kind() == prim::Enter && isStrictFusion(n->input())) { 120 checkForUnfusedOps(n); 121 } 122 } 123 124 // TODO: remove context manager after checks 125 // TODO: improve control flow not taken, right now always errors 126 } 127 128 } // namespace torch::jit 129