1 #include <torch/csrc/jit/passes/insert_guards.h> 2 #include <torch/csrc/jit/runtime/profiling_record.h> 3 #include <memory> 4 #include <unordered_set> 5 6 namespace torch::jit { 7 8 struct GuardInserter { GuardInsertertorch::jit::GuardInserter9 GuardInserter(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {} 10 runtorch::jit::GuardInserter11 void run() { 12 insertGuards(graph_->block()); 13 ProfilingRecord::removeProfilingNodes(graph_->block()); 14 } 15 16 private: insertGuardstorch::jit::GuardInserter17 void insertGuards(Block* b) { 18 for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) { 19 auto n = *it; 20 if (n->kind() == prim::profile) { 21 auto pttp = n->ty(attr::profiled_type)->cast<TensorType>(); 22 if (pttp) { 23 auto guard = graph_->create(prim::Guard, {n->input()}, 1); 24 auto go = guard->output(); 25 go->setType(pttp); 26 guard->insertBefore(n); 27 n->output()->replaceAllUsesWith(go); 28 } else { 29 // we didn't go down this path i.e 30 // no profiling information is available 31 n->output()->replaceAllUsesWith(n->input()); 32 } 33 it.destroyCurrent(); 34 } else { 35 for (Block* ib : n->blocks()) { 36 insertGuards(ib); 37 } 38 } 39 } 40 } 41 42 std::shared_ptr<Graph> graph_; 43 }; 44 InsertGuards(std::shared_ptr<Graph> graph)45void InsertGuards(std::shared_ptr<Graph> graph) { 46 GuardInserter gi(std::move(graph)); 47 gi.run(); 48 } 49 50 } // namespace torch::jit 51