xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/eager/op_cache.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_EAGER_OP_CACHE_H_
16 #define TENSORFLOW_CORE_TFRT_EAGER_OP_CACHE_H_
17 
18 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
19 #include "tensorflow/core/common_runtime/eager/context.h"
20 #include "tensorflow/core/framework/function.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/fingerprint.h"
24 #include "tensorflow/core/platform/status.h"
25 #include "tensorflow/core/tfrt/utils/utils.h"
26 #include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
27 #include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
28 #include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
29 #include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
30 #include "tfrt/host_context/host_context.h"  // from @tf_runtime
31 #include "tfrt/support/error_util.h"  // from @tf_runtime
32 #include "tfrt/support/mutex.h"  // from @tf_runtime
33 #include "tfrt/support/ref_count.h"  // from @tf_runtime
34 #include "tfrt/support/string_util.h"  // from @tf_runtime
35 
36 namespace tfrt {
37 namespace tf {
38 
39 class ContextInterface;
40 class OperationInterface;
41 
42 // Cache for a single core runtime op. Thread safe.
43 class OpCache {
44  public:
45   // Helper function to look up the cache. If miss, insert the CoreRuntimeOp
46   // to the cache.
47   Expected<CoreRuntimeOp*> GetOrAddOp(string_view op_name,
48                                       OpHandler* op_handler,
49                                       string_view device_name,
50                                       llvm::SmallVector<string_view, 4> dtypes,
51                                       OperationInterface* const op_interface)
52       TFRT_EXCLUDES(cache_mu_);
53 
54   // Compile with XLA is currently supported via fallback, and the compilation
55   // result is a CoreRuntimeOp.
56   // TODO(tfrt-devs): Native support of compile_with_xla.
57   Expected<CoreRuntimeOp*> GetOrAddXlaOp(string_view op_name,
58                                          ContextInterface* context)
59       TFRT_EXCLUDES(cache_mu_);
60 
61   // The following helper functions are for debugging and testing only.
Size()62   size_t Size() const {
63     mutex_lock l(cache_mu_);
64     return cache_.size();
65   }
66 
Contains(string_view op_name,OpHandler * op_handler,string_view device_name,llvm::SmallVector<string_view,4> dtypes)67   bool Contains(string_view op_name, OpHandler* op_handler,
68                 string_view device_name,
69                 llvm::SmallVector<string_view, 4> dtypes) const {
70     const CacheKey& cache_key{op_name, op_handler,
71                               (op_handler == nullptr ? device_name : ""),
72                               dtypes};
73     mutex_lock l(cache_mu_);
74     return cache_.find(cache_key) != cache_.end();
75   }
76 
77  private:
78   class CacheKey {
79    public:
CacheKey(string_view op_name,OpHandler * op_handler,string_view device_name,llvm::SmallVector<string_view,4> dtypes)80     CacheKey(string_view op_name, OpHandler* op_handler,
81              string_view device_name, llvm::SmallVector<string_view, 4> dtypes)
82         : op_handler_(op_handler),
83           op_name_(op_name),
84           device_name_(device_name),
85           dtypes_(dtypes) {}
86 
CacheKey(const CacheKey & other)87     CacheKey(const CacheKey& other)
88         : op_handler_(other.op_handler_),
89           op_name_(other.op_name_),
90           device_name_(other.device_name_),
91           dtypes_(other.dtypes_) {
92       // Copy the concrete strings if the key is concrete, and set the
93       // string_views to refer to the concrete strings.
94       if (other.is_concrete_) {
95         op_name_concrete_ = other.op_name_concrete_;
96         op_name_ = op_name_concrete_.data();
97         device_name_concrete_ = other.device_name_concrete_;
98         device_name_ = device_name_concrete_.data();
99         size_t n = other.dtypes_concrete_.size();
100         dtypes_concrete_.reserve(n);
101         dtypes_.clear();
102         for (size_t i = 0; i < n; ++i) {
103           dtypes_concrete_.push_back(other.dtypes_concrete_[i]);
104           dtypes_.push_back(dtypes_concrete_[i].data());
105         }
106         is_concrete_ = true;
107       }
108     }
109 
110     // Make the cache key concrete by copying the key components (strings) to
111     // internal storage.
MakeConcrete()112     void MakeConcrete() {
113       op_name_concrete_ = op_name_.str();
114       device_name_concrete_ = device_name_.str();
115       dtypes_concrete_.reserve(dtypes_.size());
116       for (const auto& dtype : dtypes_) dtypes_concrete_.push_back(dtype.str());
117       is_concrete_ = true;
118     }
119 
120     bool operator==(const CacheKey& other) const {
121       // During comparing keys, self or other can be either concrete or not.
122       // If a CacheKey is concrete, it's likely that the string_view fields
123       // are not valid (for example the key is obtained from the cache). We
124       // need to make the string_view fields refer to the concrete fields
125       // by constructing copies of them.
126       CacheKey lhs{*this};
127       CacheKey rhs{other};
128 
129       if (lhs.op_handler_ != rhs.op_handler_) return false;
130       if (lhs.dtypes_.size() != rhs.dtypes_.size()) return false;
131 
132       for (size_t i = 0, n = lhs.dtypes_.size(); i < n; ++i) {
133         if (lhs.dtypes_[i] != rhs.dtypes_[i]) return false;
134       }
135       return (lhs.op_name_ == rhs.op_name_ &&
136               lhs.device_name_ == rhs.device_name_);
137     }
138 
OpName()139     string_view OpName() { return op_name_; }
140 
DeviceName()141     string_view DeviceName() { return device_name_; }
142 
Dtypes()143     const llvm::SmallVector<string_view, 4>& Dtypes() { return dtypes_; }
144 
145    private:
146     class OpHandler* op_handler_;
147     // friend size_t CacheKeyHash::operator()(const CacheKey& input_key);
148     // string_view is used for efficient cache look up to avoid string copy.
149     string_view op_name_, device_name_;
150     llvm::SmallVector<string_view, 4> dtypes_;
151 
152     // Concrete string is used for storing cache key, since the lifetime
153     // of the strings should be the same as the container.
154     bool is_concrete_ = false;
155     std::string op_name_concrete_, device_name_concrete_;
156     llvm::SmallVector<std::string, 4> dtypes_concrete_;
157   };
158 
159   class CacheKeyHash {
160    public:
FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)161     tensorflow::Fprint128 FingerprintCat128(
162         const tensorflow::Fprint128& a, const tensorflow::Fprint128& b) const {
163       return {tensorflow::FingerprintCat64(a.low64, b.low64),
164               tensorflow::FingerprintCat64(a.high64, b.high64)};
165     }
166 
operator()167     size_t operator()(const CacheKey& input_key) const {
168       CacheKey key{input_key};
169       tensorflow::Fprint128 hash = tensorflow::Fingerprint128(
170           {key.OpName().data(), key.OpName().size()});
171       hash = FingerprintCat128(
172           hash, tensorflow::Fingerprint128(
173                     {key.DeviceName().data(), key.DeviceName().size()}));
174       for (const auto& dtype : key.Dtypes())
175         hash = FingerprintCat128(
176             hash, tensorflow::Fingerprint128({dtype.data(), dtype.size()}));
177       return hash.high64 ^ hash.low64;
178     }
179   };
180 
181   mutable mutex cache_mu_;
182   std::unordered_map<CacheKey, CoreRuntimeOp, CacheKeyHash> cache_
183       TFRT_GUARDED_BY(cache_mu_);
184 };
185 
186 }  // namespace tf
187 }  // namespace tfrt
188 
189 #endif  // TENSORFLOW_CORE_TFRT_EAGER_OP_CACHE_H_
190