/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_EAGER_OP_CACHE_H_
#define TENSORFLOW_CORE_TFRT_EAGER_OP_CACHE_H_

#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tfrt/utils/utils.h"
#include "tfrt/core_runtime/core_runtime.h"  // from @tf_runtime
#include "tfrt/core_runtime/core_runtime_op.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_attrs.h"  // from @tf_runtime
#include "tfrt/core_runtime/op_handler.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/support/error_util.h"  // from @tf_runtime
#include "tfrt/support/mutex.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
#include "tfrt/support/string_util.h"  // from @tf_runtime

namespace tfrt {
namespace tf {

class ContextInterface;
class OperationInterface;

// Cache for a single core runtime op. Thread safe.
43 class OpCache { 44 public: 45 // Helper function to look up the cache. If miss, insert the CoreRuntimeOp 46 // to the cache. 47 Expected<CoreRuntimeOp*> GetOrAddOp(string_view op_name, 48 OpHandler* op_handler, 49 string_view device_name, 50 llvm::SmallVector<string_view, 4> dtypes, 51 OperationInterface* const op_interface) 52 TFRT_EXCLUDES(cache_mu_); 53 54 // Compile with XLA is currently supported via fallback, and the compilation 55 // result is a CoreRuntimeOp. 56 // TODO(tfrt-devs): Native support of compile_with_xla. 57 Expected<CoreRuntimeOp*> GetOrAddXlaOp(string_view op_name, 58 ContextInterface* context) 59 TFRT_EXCLUDES(cache_mu_); 60 61 // The following helper functions are for debugging and testing only. Size()62 size_t Size() const { 63 mutex_lock l(cache_mu_); 64 return cache_.size(); 65 } 66 Contains(string_view op_name,OpHandler * op_handler,string_view device_name,llvm::SmallVector<string_view,4> dtypes)67 bool Contains(string_view op_name, OpHandler* op_handler, 68 string_view device_name, 69 llvm::SmallVector<string_view, 4> dtypes) const { 70 const CacheKey& cache_key{op_name, op_handler, 71 (op_handler == nullptr ? device_name : ""), 72 dtypes}; 73 mutex_lock l(cache_mu_); 74 return cache_.find(cache_key) != cache_.end(); 75 } 76 77 private: 78 class CacheKey { 79 public: CacheKey(string_view op_name,OpHandler * op_handler,string_view device_name,llvm::SmallVector<string_view,4> dtypes)80 CacheKey(string_view op_name, OpHandler* op_handler, 81 string_view device_name, llvm::SmallVector<string_view, 4> dtypes) 82 : op_handler_(op_handler), 83 op_name_(op_name), 84 device_name_(device_name), 85 dtypes_(dtypes) {} 86 CacheKey(const CacheKey & other)87 CacheKey(const CacheKey& other) 88 : op_handler_(other.op_handler_), 89 op_name_(other.op_name_), 90 device_name_(other.device_name_), 91 dtypes_(other.dtypes_) { 92 // Copy the concrete strings if the key is concrete, and set the 93 // string_views to refer to the concrete strings. 
94 if (other.is_concrete_) { 95 op_name_concrete_ = other.op_name_concrete_; 96 op_name_ = op_name_concrete_.data(); 97 device_name_concrete_ = other.device_name_concrete_; 98 device_name_ = device_name_concrete_.data(); 99 size_t n = other.dtypes_concrete_.size(); 100 dtypes_concrete_.reserve(n); 101 dtypes_.clear(); 102 for (size_t i = 0; i < n; ++i) { 103 dtypes_concrete_.push_back(other.dtypes_concrete_[i]); 104 dtypes_.push_back(dtypes_concrete_[i].data()); 105 } 106 is_concrete_ = true; 107 } 108 } 109 110 // Make the cache key concrete by copying the key components (strings) to 111 // internal storage. MakeConcrete()112 void MakeConcrete() { 113 op_name_concrete_ = op_name_.str(); 114 device_name_concrete_ = device_name_.str(); 115 dtypes_concrete_.reserve(dtypes_.size()); 116 for (const auto& dtype : dtypes_) dtypes_concrete_.push_back(dtype.str()); 117 is_concrete_ = true; 118 } 119 120 bool operator==(const CacheKey& other) const { 121 // During comparing keys, self or other can be either concrete or not. 122 // If a CacheKey is concrete, it's likely that the string_view fields 123 // are not valid (for example the key is obtained from the cache). We 124 // need to make the string_view fields refer to the concrete fields 125 // by constructing copies of them. 
126 CacheKey lhs{*this}; 127 CacheKey rhs{other}; 128 129 if (lhs.op_handler_ != rhs.op_handler_) return false; 130 if (lhs.dtypes_.size() != rhs.dtypes_.size()) return false; 131 132 for (size_t i = 0, n = lhs.dtypes_.size(); i < n; ++i) { 133 if (lhs.dtypes_[i] != rhs.dtypes_[i]) return false; 134 } 135 return (lhs.op_name_ == rhs.op_name_ && 136 lhs.device_name_ == rhs.device_name_); 137 } 138 OpName()139 string_view OpName() { return op_name_; } 140 DeviceName()141 string_view DeviceName() { return device_name_; } 142 Dtypes()143 const llvm::SmallVector<string_view, 4>& Dtypes() { return dtypes_; } 144 145 private: 146 class OpHandler* op_handler_; 147 // friend size_t CacheKeyHash::operator()(const CacheKey& input_key); 148 // string_view is used for efficient cache look up to avoid string copy. 149 string_view op_name_, device_name_; 150 llvm::SmallVector<string_view, 4> dtypes_; 151 152 // Concrete string is used for storing cache key, since the lifetime 153 // of the strings should be the same as the container. 
154 bool is_concrete_ = false; 155 std::string op_name_concrete_, device_name_concrete_; 156 llvm::SmallVector<std::string, 4> dtypes_concrete_; 157 }; 158 159 class CacheKeyHash { 160 public: FingerprintCat128(const tensorflow::Fprint128 & a,const tensorflow::Fprint128 & b)161 tensorflow::Fprint128 FingerprintCat128( 162 const tensorflow::Fprint128& a, const tensorflow::Fprint128& b) const { 163 return {tensorflow::FingerprintCat64(a.low64, b.low64), 164 tensorflow::FingerprintCat64(a.high64, b.high64)}; 165 } 166 operator()167 size_t operator()(const CacheKey& input_key) const { 168 CacheKey key{input_key}; 169 tensorflow::Fprint128 hash = tensorflow::Fingerprint128( 170 {key.OpName().data(), key.OpName().size()}); 171 hash = FingerprintCat128( 172 hash, tensorflow::Fingerprint128( 173 {key.DeviceName().data(), key.DeviceName().size()})); 174 for (const auto& dtype : key.Dtypes()) 175 hash = FingerprintCat128( 176 hash, tensorflow::Fingerprint128({dtype.data(), dtype.size()})); 177 return hash.high64 ^ hash.low64; 178 } 179 }; 180 181 mutable mutex cache_mu_; 182 std::unordered_map<CacheKey, CoreRuntimeOp, CacheKeyHash> cache_ 183 TFRT_GUARDED_BY(cache_mu_); 184 }; 185 186 } // namespace tf 187 } // namespace tfrt 188 189 #endif // TENSORFLOW_CORE_TFRT_EAGER_OP_CACHE_H_ 190