xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/remove_redundant_profiles.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/dead_code_elimination.h>
2 #include <torch/csrc/jit/passes/remove_redundant_profiles.h>
3 
4 #include <torch/csrc/jit/ir/alias_analysis.h>
5 #include <torch/csrc/jit/ir/ir_views.h>
6 #include <torch/csrc/jit/jit_log.h>
7 
8 namespace torch::jit {
9 
RemoveRedundantProfiles(Block * block,AliasDb & db)10 void RemoveRedundantProfiles(Block* block, AliasDb& db) {
11   for (auto it = block->nodes().end()->reverseIterator();
12        it != block->nodes().begin();) {
13     Node* n = *it;
14     it++;
15 
16     for (Block* b : n->blocks()) {
17       RemoveRedundantProfiles(b, db);
18     }
19 
20     // we only check prim::profile and not prim::profile_ivalue bc profile
21     // is inserted on each use, while profile_ivalue is inserted on the def
22     if (n->kind() != prim::profile ||
23         n->input()->node()->kind() != prim::profile) {
24       continue;
25     }
26 
27     Node* input_node = n->input()->node();
28     if (input_node->ty(attr::profiled_type) != n->ty(attr::profiled_type)) {
29       continue;
30     }
31 
32     if (!db.moveBeforeTopologicallyValid(input_node, n)) {
33       continue;
34     }
35 
36     n->output()->replaceAllUsesWith(n->input());
37     n->destroy();
38   }
39 }
40 
RemoveRedundantProfiles(std::shared_ptr<Graph> & graph)41 void RemoveRedundantProfiles(std::shared_ptr<Graph>& graph) {
42   AliasDb db(graph);
43   RemoveRedundantProfiles(graph->block(), db);
44 }
45 
46 } // namespace torch::jit
47