1 #include <torch/csrc/lazy/core/trie.h> 2 3 #include <torch/csrc/lazy/core/hash.h> 4 #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> 5 #include <torch/csrc/lazy/core/ir_metadata.h> 6 #include <torch/csrc/lazy/core/metrics.h> 7 #include <fstream> 8 #include <sstream> 9 10 namespace torch { 11 namespace lazy { 12 namespace { 13 TraverseTrie(TrieNode * node,std::stringstream & ss)14void TraverseTrie(TrieNode* node, std::stringstream& ss) { 15 if (!node) { 16 return; 17 } 18 if (node->ir_node) { 19 ss << node->unique_id << "[label=\"" << node->ir_node->op().ToString() 20 << ", " << node->hit_counter << " hits\"]\n"; 21 } 22 for (auto& successor : node->successors) { 23 ss << node->unique_id << " -> " << successor->unique_id << "\n"; 24 TraverseTrie(successor.get(), ss); 25 } 26 } 27 } // namespace 28 Get()29TrieCache* TrieCache::Get() { 30 static thread_local TrieCache* trie = new TrieCache(); 31 return trie; 32 } 33 TrieCache()34TrieCache::TrieCache() 35 : root_(std::make_shared<TrieNode>()), current_(root_.get()) {} 36 Current() const37TrieNode* TrieCache::Current() const { 38 return current_; 39 } 40 SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator & iter)41void TrieCache::SetCurrent( 42 std::list<std::shared_ptr<TrieNode>>::iterator& iter) { 43 auto& successors = current_->successors; 44 // Update current_ before iter gets destroyed 45 current_ = (*iter).get(); 46 47 // Insert this node to the front of its parent's successor list 48 if (iter != successors.begin()) { 49 successors.push_front(std::move(*iter)); 50 successors.erase(iter); 51 } 52 } 53 ResetCurrent()54void TrieCache::ResetCurrent() { 55 current_ = root_.get(); 56 } 57 Insert(NodePtr ir_node)58void TrieCache::Insert(NodePtr ir_node) { 59 TORCH_CHECK(current_); 60 if (!current_->successors.empty()) { 61 TORCH_LAZY_COUNTER("TrieForked", 1); 62 } 63 auto new_node = std::make_shared<TrieNode>(std::move(ir_node)); 64 current_->successors.push_front(std::move(new_node)); 65 // Update current_ to the newly inserted node 66 current_ = current_->successors.front().get(); 67 } 68 Clear()69void TrieCache::Clear() { 70 ResetCurrent(); 71 // Clear at the root level should be sufficient because all the nodes 72 // are created as shared_ptr. 73 root_->successors.clear(); 74 } 75 DumpToDotFile(const std::string & file_name)76void TrieCache::DumpToDotFile(const std::string& file_name) { 77 std::stringstream ss; 78 ss << "digraph G {\n"; 79 TraverseTrie(root_.get(), ss); 80 ss << "}\n"; 81 82 std::ofstream graph_file(file_name); 83 graph_file << ss.str(); 84 } 85 86 } // namespace lazy 87 } // namespace torch 88