1 #include <torch/csrc/jit/passes/dead_code_elimination.h> 2 #include <torch/csrc/jit/passes/remove_redundant_profiles.h> 3 4 #include <torch/csrc/jit/ir/alias_analysis.h> 5 #include <torch/csrc/jit/ir/ir_views.h> 6 #include <torch/csrc/jit/jit_log.h> 7 8 namespace torch::jit { 9 RemoveRedundantProfiles(Block * block,AliasDb & db)10void RemoveRedundantProfiles(Block* block, AliasDb& db) { 11 for (auto it = block->nodes().end()->reverseIterator(); 12 it != block->nodes().begin();) { 13 Node* n = *it; 14 it++; 15 16 for (Block* b : n->blocks()) { 17 RemoveRedundantProfiles(b, db); 18 } 19 20 // we only check prim::profile and not prim::profile_ivalue bc profile 21 // is inserted on each use, while profile_ivalue is inserted on the def 22 if (n->kind() != prim::profile || 23 n->input()->node()->kind() != prim::profile) { 24 continue; 25 } 26 27 Node* input_node = n->input()->node(); 28 if (input_node->ty(attr::profiled_type) != n->ty(attr::profiled_type)) { 29 continue; 30 } 31 32 if (!db.moveBeforeTopologicallyValid(input_node, n)) { 33 continue; 34 } 35 36 n->output()->replaceAllUsesWith(n->input()); 37 n->destroy(); 38 } 39 } 40 RemoveRedundantProfiles(std::shared_ptr<Graph> & graph)41void RemoveRedundantProfiles(std::shared_ptr<Graph>& graph) { 42 AliasDb db(graph); 43 RemoveRedundantProfiles(graph->block(), db); 44 } 45 46 } // namespace torch::jit 47