xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/trie.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <atomic>
#include <list>

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/metrics.h>

namespace torch {
namespace lazy {
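
// A TrieNode is one entry in the IR-node reuse trie maintained by TrieCache
// below. Each node carries an id drawn from a thread-local counter, a hit
// counter that is bumped every time the node is reused, the cached IR node
// itself (nullptr for a default-constructed node such as the trie root), and
// the list of successor nodes observed immediately after it during tracing.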
struct TORCH_API TrieNode {
  static size_t GetNextUniqueId() {
    static thread_local size_t id_generator = 0;
    return id_generator++;
  }

  size_t unique_id;
  size_t hit_counter;
  NodePtr ir_node;
  std::list<std::shared_ptr<TrieNode>> successors;

  TrieNode() : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(nullptr) {}
  explicit TrieNode(NodePtr node)
      : unique_id(GetNextUniqueId()),
        hit_counter(0),
        ir_node(std::move(node)) {}
};
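
// TrieCache caches the traced IR graph as a trie so that nodes can be reused
// across steps. Conceptually, current_ points at the trie position matching
// the trace so far: a successful lookup (or an Insert of a newly created
// node) moves current_ one level deeper, and ResetCurrent() rewinds it to
// root_ when a tracing step ends (e.g. at MarkStep). The definitions live in
// trie.cpp.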

class TORCH_API TrieCache {
 public:
  static TrieCache* Get();

  TrieNode* Current() const;
  // Takes an iterator as input because we want to move the corresponding
  // node within the successor list to achieve an LRU caching effect.
  void SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator& iter);
  // Used in MarkStep to indicate the end of one tracing step.
  void ResetCurrent();

  // Create a new TrieNode for ir_node and insert it into the TrieCache.
  void Insert(NodePtr ir_node);

  // Clear all TrieCache nodes.
  // TODO: Because we don't expect users to explicitly call this function via
  // a Python API, we may need to introduce a threshold on the size of the
  // cache to avoid holding tensors for too long.
  void Clear();

  void DumpToDotFile(const std::string& file_name);

 private:
  TrieCache();

  std::shared_ptr<TrieNode> root_;
  TrieNode* current_;
};

// Walk the successors of the current TrieNode and look for a node of type T
// that can be reused with the given arguments. On a hit, bump the node's hit
// counter, make it the current node (moving it to the front of the successor
// list for LRU behavior), and return its IR node; otherwise return nullptr.
template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(Args&&... args) {
  auto& successors = TrieCache::Get()->Current()->successors;
  for (auto it = successors.begin(); it != successors.end(); it++) {
    NodePtr ir_node = (*it)->ir_node;
    const T* concrete_node = NodeCast<T>(ir_node.get());
    if (concrete_node &&
        concrete_node->CanBeReused(std::forward<Args>(args)...)) {
      TORCH_LAZY_COUNTER(
          "IrNodeReused_" + c10::demangle((typeid(T).name())), 1);
      (*it)->hit_counter++;
      TrieCache::Get()->SetCurrent(it);
      return ir_node;
    }
  }
  return nullptr;
}
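
// A minimal usage sketch (illustrative only): `MyOp` and its operands are
// hypothetical placeholders for a concrete Node subclass implementing
// CanBeReused, and node construction is assumed to go through a MakeNode-style
// factory. In PyTorch's lazy core this lookup is normally driven by the IR
// builder rather than called directly.
//
//   NodePtr node = LookupNodeFromTrieCache<MyOp>(operand, dtype);
//   if (!node) {
//     node = MakeNode<MyOp>(operand, dtype);  // hypothetical factory call
//     TrieCache::Get()->Insert(node);
//   }
//   // ... at the end of the traced step (e.g. in MarkStep):
//   TrieCache::Get()->ResetCurrent();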

} // namespace lazy
} // namespace torch