xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/insert_guards.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/insert_guards.h>
2 #include <torch/csrc/jit/runtime/profiling_record.h>
3 #include <memory>
4 #include <unordered_set>
5 
6 namespace torch::jit {
7 
8 struct GuardInserter {
GuardInsertertorch::jit::GuardInserter9   GuardInserter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}
10 
runtorch::jit::GuardInserter11   void run() {
12     insertGuards(graph_->block());
13     ProfilingRecord::removeProfilingNodes(graph_->block());
14   }
15 
16  private:
insertGuardstorch::jit::GuardInserter17   void insertGuards(Block* b) {
18     for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
19       auto n = *it;
20       if (n->kind() == prim::profile) {
21         auto pttp = n->ty(attr::profiled_type)->cast<TensorType>();
22         if (pttp) {
23           auto guard = graph_->create(prim::Guard, {n->input()}, 1);
24           auto go = guard->output();
25           go->setType(pttp);
26           guard->insertBefore(n);
27           n->output()->replaceAllUsesWith(go);
28         } else {
29           // we didn't go down this path i.e
30           // no profiling information is available
31           n->output()->replaceAllUsesWith(n->input());
32         }
33         it.destroyCurrent();
34       } else {
35         for (Block* ib : n->blocks()) {
36           insertGuards(ib);
37         }
38       }
39     }
40   }
41 
42   std::shared_ptr<Graph> graph_;
43 };
44 
InsertGuards(std::shared_ptr<Graph> graph)45 void InsertGuards(std::shared_ptr<Graph> graph) {
46   GuardInserter gi(std::move(graph));
47   gi.run();
48 }
49 
50 } // namespace torch::jit
51