1 #include <torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h> 2 3 #include <torch/csrc/jit/ir/alias_analysis.h> 4 #include <torch/csrc/jit/jit_log.h> 5 #include <torch/csrc/jit/passes/quantization/helper.h> 6 #include <torch/csrc/jit/runtime/graph_iterator.h> 7 8 namespace torch { 9 namespace jit { 10 11 namespace { 12 DBRQuantRemoveRedundantAliasesImpl(const Method & method)13void DBRQuantRemoveRedundantAliasesImpl(const Method& method) { 14 auto g = method.graph(); 15 const bool is_frozen = false; 16 const bool descend_function_calls = true; 17 AliasDb alias_db(g, is_frozen, descend_function_calls); 18 // find the alias nodes 19 std::vector<Node*> alias_nodes; 20 DepthFirstGraphNodeIterator it(g); 21 Node* node = nullptr; 22 while ((node = it.next()) != nullptr) { 23 if (node->kind() == Symbol::aten("alias")) { 24 alias_nodes.push_back(node); 25 } 26 } 27 28 // remove the alias nodes, if it is safe to do so 29 for (auto* node : alias_nodes) { 30 GRAPH_DEBUG(*node); 31 32 Value* input_value = node->input(); 33 Value* output_value = node->output(); 34 35 bool always_safe_to_mutate = alias_db.safeToChangeAliasingRelationship( 36 node->inputs(), node->outputs()); 37 38 const auto g_in = g->inputs(); 39 const auto g_out = g->outputs(); 40 bool is_input = 41 std::find(g_in.begin(), g_in.end(), input_value) != g_in.end(); 42 bool is_output = 43 std::find(g_out.begin(), g_out.end(), output_value) != g_out.end(); 44 // We assume that aliasing is safe to update on inputs and outputs if they 45 // do not have writers. 46 bool input_safe_to_mutate = 47 (is_input && !alias_db.hasWriters(input_value) && 48 !alias_db.hasWriters(output_value)); 49 bool output_safe_to_mutate = 50 (is_output && !alias_db.hasWriters(input_value) && 51 !alias_db.hasWriters(output_value)); 52 53 if (always_safe_to_mutate || input_safe_to_mutate || 54 output_safe_to_mutate) { 55 output_value->replaceAllUsesWith(input_value); 56 node->destroy(); 57 } 58 } 59 } 60 61 } // namespace 62 DBRQuantRemoveRedundantAliases(Module & module)63Module DBRQuantRemoveRedundantAliases(Module& module) { 64 for (const auto& child : module.modules()) { 65 for (const auto& method : child.get_methods()) { 66 DBRQuantRemoveRedundantAliasesImpl(method); 67 } 68 } 69 70 return module; 71 } 72 73 } // namespace jit 74 } // namespace torch 75