xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/kernel_cache.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
2 #include <torch/csrc/jit/passes/canonicalize.h>
3 #include <torch/csrc/jit/passes/shape_analysis.h>
4 
5 #include <cstdint>
6 #include <mutex>
7 #include <unordered_map>
8 
9 namespace torch::jit::fuser {
10 
11 struct KernelCacheImpl {
12   // Note: std::unordered_map does not invalidate references even if rehashing
13   // occurs. This is a critical property for thread-safety.
14   std::mutex mutex_;
15   int64_t kernel_counter{0};
16 
17   // Map of fusion key to KernelSpec
18   std::unordered_map<int64_t, KernelSpec> specMap_;
19 
20   // Map of pretty-printed graph string to fusion key
21   // Used to check if a graph has already been cached in specMap_
22   std::unordered_map<std::string, int64_t> graphToKey_;
23 };
24 
getKernelCache()25 static KernelCacheImpl& getKernelCache() {
26   static KernelCacheImpl cache;
27   return cache;
28 }
29 
debugNumCachedKernelSpecs()30 int64_t debugNumCachedKernelSpecs() {
31   auto& cache = getKernelCache();
32   std::lock_guard<std::mutex> guard{cache.mutex_};
33   return cache.specMap_.size();
34 }
35 
normalizeGraphForCache(const std::shared_ptr<Graph> & graph)36 std::shared_ptr<Graph> normalizeGraphForCache(
37     const std::shared_ptr<Graph>& graph) {
38   auto result = Canonicalize(graph, /*keep_unique_names=*/false);
39   EraseShapeInformation(result);
40   return result;
41 }
42 
43 // TODO: lookup by historic string key to start, then issue key
44 // as appropriate for faster lookup in the future
45 // precondition: graph has been normalized via normalizeGraphForCache
store(std::shared_ptr<Graph> graph)46 int64_t store(std::shared_ptr<Graph> graph) {
47   auto& cache = getKernelCache();
48   std::string repr = graph->toString(false);
49 
50   std::lock_guard<std::mutex> guard{cache.mutex_};
51   const auto key = cache.kernel_counter++;
52   cache.specMap_.emplace(
53       std::piecewise_construct,
54       std::forward_as_tuple(key),
55       std::forward_as_tuple(key, graph));
56   cache.graphToKey_.emplace(std::move(repr), key);
57   return key;
58 }
59 
60 // XXX: Does not grab mutex
nolock_retrieve(KernelCacheImpl & cache,const int64_t key)61 static std::optional<KernelSpec*> nolock_retrieve(
62     KernelCacheImpl& cache,
63     const int64_t key) {
64   auto it = cache.specMap_.find(key);
65   if (it == cache.specMap_.end())
66     return std::nullopt;
67   return &(it->second);
68 }
69 
retrieve(const int64_t key)70 std::optional<KernelSpec*> retrieve(const int64_t key) {
71   auto& cache = getKernelCache();
72   std::lock_guard<std::mutex> guard{cache.mutex_};
73   return nolock_retrieve(cache, key);
74 }
75 
76 // precondition: graph has been normalized via normalizeGraphForCache
lookupGraph(const std::shared_ptr<Graph> & graph)77 std::optional<KernelSpec*> lookupGraph(const std::shared_ptr<Graph>& graph) {
78   auto& cache = getKernelCache();
79   std::string repr = graph->toString(false);
80 
81   std::lock_guard<std::mutex> guard{cache.mutex_};
82   auto it = cache.graphToKey_.find(repr);
83   if (it == cache.graphToKey_.end())
84     return std::nullopt;
85   return nolock_retrieve(cache, it->second);
86 }
87 
88 } // namespace torch::jit::fuser
89