/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TIMING_CACHE_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TIMING_CACHE_H_
#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/statusor.h"
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {
namespace convert {

// A registry for holding serialized TensorRT autotuner timing caches, keyed by
// name.
//
// For TensorRT versions < 8.0, the timing cache is not serializable, so these
// operations become no-ops.
class TimingCacheRegistry {
 public:
  TimingCacheRegistry() = default;
  ~TimingCacheRegistry() = default;

#if IS_TRT_VERSION_GE(8, 0, 0, 0)
  using TimingCache = nvinfer1::ITimingCache;
  using TimingCachePtr = std::unique_ptr<TimingCache>;
#else
  // Placeholder type used when building against TensorRT < 8.0 so that the
  // interface below still compiles.
  struct TimingCache {};
  using TimingCachePtr = std::unique_ptr<TimingCache>;
#endif

  // Inserts or updates the timing cache stored under `name`. The cache is
  // serialized before being placed into the map; only the serialized bytes are
  // kept, so the registry does not retain `cache` itself.
  void Upsert(const std::string& name, TimingCache* cache);

  // Finds the timing cache stored under `name`. `builder_config` is used to
  // deserialize the cache; the caller owns the returned object. If no timing
  // cache is found, a new timing cache is returned.
  StatusOr<TimingCachePtr> LookUp(const std::string& name,
                                  nvinfer1::IBuilderConfig* builder_config);

 private:
  // Raw bytes of a serialized TensorRT timing cache.
  using SerializedTimingCache = std::vector<uint8_t>;

  // NOTE(review): `mu_` presumably protects `map_`; consider annotating `map_`
  // with TF_GUARDED_BY(mu_) once the implementation is confirmed to hold the
  // lock on every access.
  mutex mu_;
  std::unordered_map<std::string, SerializedTimingCache> map_;
};

// Returns the TimingCacheRegistry instance used by the converter.
// NOTE(review): presumably a process-wide singleton owned by the
// implementation; confirm in the .cc file.
TimingCacheRegistry* GetTimingCacheRegistry();

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TIMING_CACHE_H_