1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
16
17 #include <set>
18 #include <utility>
19
20 #include "tensorflow/core/platform/mutex.h"
21
22 #if GOOGLE_CUDA && GOOGLE_TENSORRT
23
24 namespace tensorflow {
25 namespace tensorrt {
26 namespace convert {
27
28 struct OpConverterRegistration {
29 OpConverter converter;
30 int priority;
31 };
32 class OpConverterRegistry::Impl {
33 public:
34 ~Impl() = default;
35
Register(const string & name,const int priority,OpConverter converter)36 InitOnStartupMarker Register(const string& name, const int priority,
37 OpConverter converter) {
38 mutex_lock lock(mu_);
39 auto item = registry_.find(name);
40 if (item != registry_.end()) {
41 const int existing_priority = item->second.priority;
42 if (priority <= existing_priority) {
43 LOG(WARNING) << absl::StrCat(
44 "Ignoring TF->TRT ", name, " op converter with priority ",
45 existing_priority, " due to another converter with priority ",
46 priority);
47 return {};
48 } else {
49 LOG(WARNING) << absl::StrCat(
50 "Overwriting TF->TRT ", name, " op converter with priority ",
51 existing_priority, " using another converter with priority ",
52 priority);
53 registry_.erase(item);
54 }
55 }
56 registry_.insert({name, OpConverterRegistration{converter, priority}});
57 return {};
58 }
59
LookUp(const string & name)60 StatusOr<OpConverter> LookUp(const string& name) {
61 mutex_lock lock(mu_);
62 auto found = registry_.find(name);
63 if (found != registry_.end()) {
64 return found->second.converter;
65 }
66 return errors::NotFound("No converter for op ", name);
67 }
68
Clear(const std::string & name)69 void Clear(const std::string& name) {
70 mutex_lock lock(mu_);
71 auto itr = registry_.find(name);
72 if (itr == registry_.end()) {
73 return;
74 }
75 registry_.erase(itr);
76 }
77
ListRegisteredOps() const78 std::vector<std::string> ListRegisteredOps() const {
79 mutex_lock lock(mu_);
80 std::vector<std::string> result;
81 result.reserve(registry_.size());
82 for (const auto& item : registry_) {
83 result.push_back(item.first);
84 }
85 return result;
86 }
87
88 private:
89 mutable mutex mu_;
90 mutable std::unordered_map<std::string, OpConverterRegistration> registry_
91 TF_GUARDED_BY(mu_);
92 };
93
OpConverterRegistry()94 OpConverterRegistry::OpConverterRegistry() : impl_(std::make_unique<Impl>()) {}
95
LookUp(const string & name)96 StatusOr<OpConverter> OpConverterRegistry::LookUp(const string& name) {
97 return impl_->LookUp(name);
98 }
99
Register(const string & name,const int priority,OpConverter converter)100 InitOnStartupMarker OpConverterRegistry::Register(const string& name,
101 const int priority,
102 OpConverter converter) {
103 return impl_->Register(name, priority, converter);
104 }
105
ListRegisteredOps() const106 std::vector<std::string> OpConverterRegistry::ListRegisteredOps() const {
107 return impl_->ListRegisteredOps();
108 }
109
Clear(const std::string & name)110 void OpConverterRegistry::Clear(const std::string& name) { impl_->Clear(name); }
111
GetOpConverterRegistry()112 OpConverterRegistry* GetOpConverterRegistry() {
113 static OpConverterRegistry* registry = new OpConverterRegistry();
114 return registry;
115 }
116
117 } // namespace convert
118 } // namespace tensorrt
119 } // namespace tensorflow
120
121 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
122