xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/trie.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/lazy/core/trie.h>

#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/metrics.h>
#include <fstream>
#include <sstream>
#include <utility>
#include <vector>
9 
10 namespace torch {
11 namespace lazy {
12 namespace {
13 
TraverseTrie(TrieNode * node,std::stringstream & ss)14 void TraverseTrie(TrieNode* node, std::stringstream& ss) {
15   if (!node) {
16     return;
17   }
18   if (node->ir_node) {
19     ss << node->unique_id << "[label=\"" << node->ir_node->op().ToString()
20        << ", " << node->hit_counter << " hits\"]\n";
21   }
22   for (auto& successor : node->successors) {
23     ss << node->unique_id << " -> " << successor->unique_id << "\n";
24     TraverseTrie(successor.get(), ss);
25   }
26 }
27 } // namespace
28 
Get()29 TrieCache* TrieCache::Get() {
30   static thread_local TrieCache* trie = new TrieCache();
31   return trie;
32 }
33 
TrieCache()34 TrieCache::TrieCache()
35     : root_(std::make_shared<TrieNode>()), current_(root_.get()) {}
36 
Current() const37 TrieNode* TrieCache::Current() const {
38   return current_;
39 }
40 
SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator & iter)41 void TrieCache::SetCurrent(
42     std::list<std::shared_ptr<TrieNode>>::iterator& iter) {
43   auto& successors = current_->successors;
44   // Update current_ before iter gets destroyed
45   current_ = (*iter).get();
46 
47   // Insert this node to the front of its parent's successor list
48   if (iter != successors.begin()) {
49     successors.push_front(std::move(*iter));
50     successors.erase(iter);
51   }
52 }
53 
ResetCurrent()54 void TrieCache::ResetCurrent() {
55   current_ = root_.get();
56 }
57 
Insert(NodePtr ir_node)58 void TrieCache::Insert(NodePtr ir_node) {
59   TORCH_CHECK(current_);
60   if (!current_->successors.empty()) {
61     TORCH_LAZY_COUNTER("TrieForked", 1);
62   }
63   auto new_node = std::make_shared<TrieNode>(std::move(ir_node));
64   current_->successors.push_front(std::move(new_node));
65   // Update current_ to the newly inserted node
66   current_ = current_->successors.front().get();
67 }
68 
Clear()69 void TrieCache::Clear() {
70   ResetCurrent();
71   // Clear at the root level should be sufficient because all the nodes
72   // are created as shared_ptr.
73   root_->successors.clear();
74 }
75 
DumpToDotFile(const std::string & file_name)76 void TrieCache::DumpToDotFile(const std::string& file_name) {
77   std::stringstream ss;
78   ss << "digraph G {\n";
79   TraverseTrie(root_.get(), ss);
80   ss << "}\n";
81 
82   std::ofstream graph_file(file_name);
83   graph_file << ss.str();
84 }
85 
86 } // namespace lazy
87 } // namespace torch
88