xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/clear_profiling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/clear_profiling.h>
2 
3 #include <torch/csrc/jit/jit_log.h>
4 
5 namespace torch::jit {
6 
unprofileGraphInputs(const std::shared_ptr<Graph> & graph)7 void unprofileGraphInputs(const std::shared_ptr<Graph>& graph) {
8   for (auto i : graph->inputs()) {
9     if (i->type()->isSubtypeOf(*TensorType::get())) {
10       i->setType(unshapedType(i->type()));
11     }
12   }
13 }
14 
unprofileBlock(Block * start_block)15 void unprofileBlock(Block* start_block) {
16   std::vector<Block*> stack;
17   stack.push_back(start_block);
18 
19   while (!stack.empty()) {
20     Block* block = stack.back();
21     stack.pop_back();
22 
23     for (auto n : block->nodes()) {
24       for (auto o : n->outputs()) {
25         if (o->type()->isSubtypeOf(*TensorType::get())) {
26           o->setType(unshapedType(o->type()));
27         }
28       }
29       stack.insert(stack.end(), n->blocks().begin(), n->blocks().end());
30     }
31   }
32 }
33 
34 // We need to make sure that passes that use profiling information
35 // use it **only after** guards validating it are inserted
36 // Ideally, we would run any pass that relies on profiling information
37 // after `InsertBailOuts`, however, practically, some passes
38 // (e.g. Peephole) useful to run both w/ and w/o profiling information
39 // so we could run them in `preoptimizeGraph` and
40 // in `runProfilingInsensitiveOptimizations`
ClearProfilingInformation(const std::shared_ptr<Graph> & graph)41 void ClearProfilingInformation(const std::shared_ptr<Graph>& graph) {
42   unprofileGraphInputs(graph);
43   unprofileBlock(graph->block());
44   GRAPH_DUMP("After ClearProfilingInformation: ", graph);
45 }
46 
47 } // namespace torch::jit
48