#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>

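// Process-wide cache mapping fused subgraphs to their KernelSpecs. Entries
// are keyed two ways: by an integer fusion key issued at store() time, and
// by the pretty-printed text of the normalized graph, which lets callers
// detect that an equivalent graph has already been cached.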
namespace torch::jit::fuser {

struct KernelCacheImpl {
  // Note: std::unordered_map never invalidates references to its elements,
  // even when rehashing occurs. This property is critical for thread safety:
  // the KernelSpec* pointers handed out below remain valid across later
  // insertions into specMap_.
  std::mutex mutex_;
  int64_t kernel_counter{0};

  // Map of fusion key to KernelSpec
  std::unordered_map<int64_t, KernelSpec> specMap_;

  // Map of pretty-printed graph string to fusion key
  // Used to check if a graph has already been cached in specMap_
  std::unordered_map<std::string, int64_t> graphToKey_;
};

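// Returns the singleton cache instance. The function-local static is
// initialized in a thread-safe manner (a Meyers singleton), so no locking
// is needed to obtain the cache itself; its contents are guarded by mutex_.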
static KernelCacheImpl& getKernelCache() {
  static KernelCacheImpl cache;
  return cache;
}

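// Debugging helper: reports how many KernelSpecs are currently cached.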
int64_t debugNumCachedKernelSpecs() {
  auto& cache = getKernelCache();
  std::lock_guard<std::mutex> guard{cache.mutex_};
  return static_cast<int64_t>(cache.specMap_.size());
}

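// Produces the canonical form of a graph used for cache keying: unique value
// names are discarded and shape information is erased, so two graphs that
// differ only in naming or in observed shapes serialize to the same string.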
std::shared_ptr<Graph> normalizeGraphForCache(
    const std::shared_ptr<Graph>& graph) {
  auto result = Canonicalize(graph, /*keep_unique_names=*/false);
  EraseShapeInformation(result);
  return result;
}

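// A minimal sketch of the intended call pattern (illustrative only;
// `fusion_group` is a hypothetical stand-in for a fusion subgraph):
//
//   auto normalized = normalizeGraphForCache(fusion_group);
//   auto spec = lookupGraph(normalized);
//   const int64_t key = spec ? (*spec)->key() : store(normalized);
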
// TODO: lookup by historic string key to start, then issue key
// as appropriate for faster lookup in the future
// precondition: graph has been normalized via normalizeGraphForCache
int64_t store(std::shared_ptr<Graph> graph) {
  auto& cache = getKernelCache();
  std::string repr = graph->toString(false);

  std::lock_guard<std::mutex> guard{cache.mutex_};
  const auto key = cache.kernel_counter++;
  // Constructs the KernelSpec in place from (key, graph); piecewise
  // construction avoids requiring KernelSpec to be copyable or movable.
  cache.specMap_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(key),
      std::forward_as_tuple(key, graph));
  cache.graphToKey_.emplace(std::move(repr), key);
  return key;
}

// XXX: Does not grab the mutex; the caller must already hold cache.mutex_.
static std::optional<KernelSpec*> nolock_retrieve(
    KernelCacheImpl& cache,
    const int64_t key) {
  auto it = cache.specMap_.find(key);
  if (it == cache.specMap_.end())
    return std::nullopt;
  return &(it->second);
}

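// Thread-safe lookup of a KernelSpec by its fusion key. The returned pointer
// remains valid for the lifetime of the process (see the note on reference
// stability in KernelCacheImpl).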
std::optional<KernelSpec*> retrieve(const int64_t key) {
  auto& cache = getKernelCache();
  std::lock_guard<std::mutex> guard{cache.mutex_};
  return nolock_retrieve(cache, key);
}

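// Thread-safe lookup of a KernelSpec by graph identity. Note this serializes
// the graph to text on every call, so it is more expensive than retrieve().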
// precondition: graph has been normalized via normalizeGraphForCache
std::optional<KernelSpec*> lookupGraph(const std::shared_ptr<Graph>& graph) {
  auto& cache = getKernelCache();
  std::string repr = graph->toString(false);

  std::lock_guard<std::mutex> guard{cache.mutex_};
  auto it = cache.graphToKey_.find(repr);
  if (it == cache.graphToKey_.end())
    return std::nullopt;
  return nolock_retrieve(cache, it->second);
}

} // namespace torch::jit::fuser