/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_

#include <algorithm>
#include <list>
#include <memory>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

// A simple LRU cache: keys are kept in a list ordered from most to least
// recently used, and emplace() evicts the least recently used entry when the
// cache is full.
template <class Key, class Value, class HashFunction>
class LRUCache {
 public:
  typedef Value value_type;
  typedef Key key_type;
  typedef HashFunction hasher;
  typedef typename std::unordered_map<key_type, value_type, hasher> map_type;
  typedef typename map_type::iterator iterator;
  typedef typename map_type::const_iterator const_iterator;

  LRUCache() : capacity_(0) {}
  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  size_t capacity() const { return capacity_; }

  void reserve(size_t capacity) {
    capacity_ = capacity;
    DiscardOld();
  }

  size_t size() const { return objects_.size(); }

  size_t count(const key_type& key) const { return objects_.count(key); }

  value_type& at(const key_type& key) { return Touch(key); }

  const_iterator begin() const { return objects_.begin(); }
  const_iterator end() const { return objects_.end(); }

  iterator begin() { return objects_.begin(); }
  iterator end() { return objects_.end(); }

  template <typename... Args>
  std::pair<iterator, bool> emplace(Args&&... args) {
    DiscardOld(1);
    std::pair<iterator, bool> result =
        objects_.emplace(std::forward<Args>(args)...);
    key_type key = result.first->first;
    if (result.second) {
      keys_.push_front(key);
    } else {
      TouchNoCheck(key);  // The key must exist in this case.
    }
    return result;
  }

 private:
  std::unordered_map<key_type, value_type, hasher> objects_;
  std::list<key_type> keys_;
  size_t capacity_;
  value_type not_found_value_;

  value_type& Touch(const key_type& key) {
    // std::unordered_map::at() throws std::out_of_range if the key does not
    // exist.
    value_type& value = objects_.at(key);
    TouchNoCheck(key);
    return value;
  }

  void TouchNoCheck(const key_type& key) {
    auto rank = std::find(keys_.begin(), keys_.end(), key);
    if (rank != keys_.begin()) {
      keys_.erase(rank);
      keys_.push_front(key);
    }
  }

  // Creates n free positions in the cache by discarding the least recently
  // used entries.
  void DiscardOld(size_t n = 0) {
    DCHECK(capacity_ >= n) << "Insufficient capacity in cache (capacity = "
                           << capacity_ << ", requested " << n << ")";
    while (objects_.size() > (capacity_ - n)) {
      key_type discard_key = keys_.back();
      keys_.pop_back();
      objects_.erase(discard_key);
    }
  }
};
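// Illustrative usage sketch (not part of the TF-TRT API; the cache instance
// and keys below are hypothetical):
//
//   LRUCache<int, std::string, std::hash<int>> cache(/*capacity=*/2);
//   cache.emplace(1, "one");
//   cache.emplace(2, "two");
//   cache.emplace(3, "three");       // Evicts key 1, the least recently used.
//   std::string& v = cache.at(2);    // Marks key 2 as most recently used.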
#if GOOGLE_CUDA && GOOGLE_TENSORRT

struct EngineContext {
  EngineContext() {}  // Creates an empty context.
  EngineContext(TrtUniquePtrType<nvinfer1::ICudaEngine>&& cuda_engine,
                ExecutionContext&& execution_context)
      : cuda_engine_(std::move(cuda_engine)) {
    execution_contexts.push_back(std::move(execution_context));
    device_memory_size_ =
        cuda_engine_ ? cuda_engine_->getDeviceMemorySize() : 0;
  }
  EngineContext(TrtUniquePtrType<nvinfer1::ICudaEngine>&& cuda_engine,
                std::vector<ExecutionContext>&& execution_contexts)
      : cuda_engine_(std::move(cuda_engine)),
        execution_contexts(std::move(execution_contexts)) {
    device_memory_size_ =
        cuda_engine_ ? cuda_engine_->getDeviceMemorySize() : 0;
  }

  mutex mu;

  nvinfer1::ICudaEngine* GetCudaEngine() { return cuda_engine_.get(); }

  Status GetExecutionContext(int idx, nvinfer1::IExecutionContext** exec_ctx,
                             bool* has_device_memory)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
    if (idx >= execution_contexts.size()) {
      return errors::Internal("Requested engine context with index ", idx,
                              ", but only ", execution_contexts.size(),
                              " contexts are present.");
    }
    *exec_ctx = execution_contexts[idx].get();
    *has_device_memory = execution_contexts[idx].HasDeviceMemory();
    return Status::OK();
  }

  int GetNumContexts() {
    mutex_lock lock(mu);
    return execution_contexts.size();
  }

  size_t GetDeviceMemorySize() { return device_memory_size_; }

 private:
  // Note: this declaration has to come before execution_contexts, to ensure
  // proper order of destruction.
  TrtUniquePtrType<nvinfer1::ICudaEngine> cuda_engine_;

 public:
  // In explicit batch mode, we maintain a vector of contexts for each engine,
  // where each context is created for a specific profile. This is because it
  // is either not possible or non-trivial to change the profile of a context
  // for the following reasons:
  // - To switch profiles (from TRT 7), one must first ensure that all
  //   inference calls in that context are finished. This would require an
  //   additional synchronization before we call setOptimizationProfile. To
  //   avoid this extra sync call, we maintain a separate execution context
  //   for each profile.
  // An IExecutionContext object is not thread safe: only one thread should
  // use it for inference at a time; therefore we need a mutex. More details
  // at
  // https://docs.nvidia.com/deeplearning/sdk/tensorrt-best-practices/index.html#thread-safety
  // Additional discussion about execution context management and thread
  // safety is at https://github.com/tensorflow/tensorflow/issues/36959
  std::vector<ExecutionContext> execution_contexts TF_GUARDED_BY(mu);

 private:
  // Until TRT 8.4, ICudaEngine::getDeviceMemorySize() has a non-negligible
  // latency. Since its value remains constant, we can cache it.
  size_t device_memory_size_;
};
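// Illustrative usage sketch (hypothetical caller code, not part of this
// header). GetExecutionContext() requires the caller to hold `mu`, e.g.:
//
//   mutex_lock lock(engine_context->mu);
//   nvinfer1::IExecutionContext* exec_ctx = nullptr;
//   bool has_device_memory = false;
//   TF_RETURN_IF_ERROR(engine_context->GetExecutionContext(
//       profile_id, &exec_ctx, &has_device_memory));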
// Contains the context required to build the calibration data.
class CalibrationContext {
 public:
  string TerminateCalibration();

  // Lookup table for temporary staging areas of input tensors for calibration.
  std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;

  // Temporary staging areas for calibration inputs.
  std::vector<Tensor> device_tensors_;

  std::unique_ptr<TRTInt8Calibrator> calibrator_;
  TrtUniquePtrType<nvinfer1::IBuilder> builder_;
  TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
  // TODO(sami): Use threadpool threads!
  std::unique_ptr<std::thread> thr_;

 private:
  mutex mu_;
  bool terminated_ TF_GUARDED_BY(mu_) = false;
  std::string calibration_table_ TF_GUARDED_BY(mu_);
};

ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName;

class TRTEngineCacheResource : public ResourceBase {
 public:
  // According to the TensorRT API, the logger is considered a singleton by the
  // TensorRT library, and multiple instances of IRuntime and/or IBuilder must
  // all use the same logger. So here we make it a singleton.
  //
  // TODO(laigd): use this logger in all places where conversion happens.
  static Logger& GetLogger();

  TRTEngineCacheResource(OpKernelContext* ctx, size_t capacity);

  ~TRTEngineCacheResource() override;

  string DebugString() const override;

  // Returns the EngineContext that is compatible with input_shapes.
  // Returns nullptr if no compatible EngineContext is found in the cache.
  EngineContext* GetEngineContext(const std::vector<TensorShape>& input_shapes);

  // Returns the EngineContext that is compatible with profile_id.
  // This function should only be called in explicit batch mode, where the
  // cache size is expected to be at most one.
  // Returns nullptr if no compatible EngineContext is found in the cache.
  EngineContext* GetEngineContext(const int profile_id);

  // Keeps the device allocator for TRT.
  std::unique_ptr<TRTBaseAllocator> allocator_;

  // Declare the cache after the allocator so that it is destroyed before the
  // allocator is.
  LRUCache<std::vector<TensorShape>, std::unique_ptr<EngineContext>,
           VectorTensorShapeHasher>
      cache_;

  // TODO(hinsu): Use a different calibration context for the available shapes
  // and attach it to each item of the cache.
  std::unique_ptr<CalibrationContext> calib_ctx_;

  // This object maintains all the optimization profiles during profile
  // generation and engine build. During runtime, the list of profiles is used
  // to look up a matching profile for the input data.
  TrtShapeOptimizationProfile profiles_;
};

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_
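// Illustrative lookup sketch (assumed caller code, not part of this header).
// A TRTEngineCacheResource typically lives in the resource manager under
// kTfTrtContainerName; `resource_name`, `ctx`, and `input_shapes` below are
// hypothetical:
//
//   TRTEngineCacheResource* cache_res = nullptr;
//   TF_RETURN_IF_ERROR(ctx->resource_manager()->LookupOrCreate(
//       std::string(kTfTrtContainerName), resource_name, &cache_res,
//       [ctx](TRTEngineCacheResource** r) -> Status {
//         *r = new TRTEngineCacheResource(ctx, /*capacity=*/1);
//         return Status::OK();
//       }));
//   core::ScopedUnref unref(cache_res);
//   EngineContext* engine_context = cache_res->GetEngineContext(input_shapes);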