#pragma once

#include <atomic>
#include <list>
#include <memory>
#include <string>

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/metrics.h>

namespace torch {
namespace lazy {

// A node in the trie cache: holds one lazy IR node, a hit counter, and the
// list of successor TrieNodes that have followed it in previous traces.
struct TORCH_API TrieNode {
  static size_t GetNextUniqueId() {
    static thread_local size_t id_generator = 0;
    return id_generator++;
  }

  size_t unique_id;
  size_t hit_counter;
  NodePtr ir_node;
  std::list<std::shared_ptr<TrieNode>> successors;

  TrieNode() : unique_id(GetNextUniqueId()), hit_counter(0), ir_node(nullptr) {}
  explicit TrieNode(NodePtr node)
      : unique_id(GetNextUniqueId()),
        hit_counter(0),
        ir_node(std::move(node)) {}
};

// Caches traced IR nodes in a trie so that re-tracing the same program can
// reuse the nodes recorded by earlier traces instead of rebuilding them.
class TORCH_API TrieCache {
 public:
  static TrieCache* Get();

  TrieNode* Current() const;
  // Takes an iterator as input because we want to move the corresponding
  // node within the successor list to achieve an LRU caching effect.
  void SetCurrent(std::list<std::shared_ptr<TrieNode>>::iterator& iter);
  // Used in MarkStep to indicate the end of one tracing step.
  void ResetCurrent();

  // Creates a new TrieNode for ir_node and inserts it into the TrieCache.
  void Insert(NodePtr ir_node);

  // Clears all TrieCache nodes.
  // TODO: Because we don't expect users to explicitly call this function via
  // a Python API, we may need to introduce a threshold on the cache size to
  // avoid holding tensors for too long.
  void Clear();

  void DumpToDotFile(const std::string& file_name);

 private:
  TrieCache();

  std::shared_ptr<TrieNode> root_;
  TrieNode* current_;
};

// Scans the successors of the current trie position for a cached node of type
// T that can be reused with the given arguments. On a hit, bumps the node's
// hit counter, advances the cache's current position, and returns the node;
// on a miss, returns nullptr.
template <typename T, typename... Args>
NodePtr LookupNodeFromTrieCache(Args&&... args) {
  auto& successors = TrieCache::Get()->Current()->successors;
  for (auto it = successors.begin(); it != successors.end(); it++) {
    NodePtr ir_node = (*it)->ir_node;
    const T* concrete_node = NodeCast<T>(ir_node.get());
    if (concrete_node &&
        concrete_node->CanBeReused(std::forward<Args>(args)...)) {
      TORCH_LAZY_COUNTER(
          "IrNodeReused_" + c10::demangle((typeid(T).name())), 1);
      (*it)->hit_counter++;
      TrieCache::Get()->SetCurrent(it);
      return ir_node;
    }
  }
  return nullptr;
}

} // namespace lazy
} // namespace torch
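
// Usage sketch (illustrative only, not part of this header): the expected
// caller-side pattern pairs LookupNodeFromTrieCache with TrieCache::Insert so
// that a re-traced program reuses cached IR nodes instead of rebuilding them.
// MakeNode<T> below stands in for whatever factory actually constructs a fresh
// node (assumed to be declared elsewhere, e.g. in ir_builder.h).
//
//   template <typename T, typename... Args>
//   NodePtr ReuseOrMakeNode(Args&&... args) {
//     // Hit: an equivalent node already hangs off the current trie position;
//     // LookupNodeFromTrieCache also advances the cache's current pointer.
//     if (NodePtr cached = LookupNodeFromTrieCache<T>(args...)) {
//       return cached;
//     }
//     // Miss: build a new node and append it as a successor of the current
//     // trie position so that the next trace can reuse it.
//     NodePtr node = MakeNode<T>(std::forward<Args>(args)...);
//     TrieCache::Get()->Insert(node);
//     return node;
//   }
//
// At the end of each trace (e.g. in MarkStep), TrieCache::Get()->ResetCurrent()
// rewinds the current position back to the root so the next trace starts
// matching from the beginning of the trie.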