xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h"
16 
17 #include <set>
18 #include <utility>
19 
20 #include "tensorflow/core/platform/mutex.h"
21 
22 #if GOOGLE_CUDA && GOOGLE_TENSORRT
23 
24 namespace tensorflow {
25 namespace tensorrt {
26 namespace convert {
27 
28 struct OpConverterRegistration {
29   OpConverter converter;
30   int priority;
31 };
32 class OpConverterRegistry::Impl {
33  public:
34   ~Impl() = default;
35 
Register(const string & name,const int priority,OpConverter converter)36   InitOnStartupMarker Register(const string& name, const int priority,
37                                OpConverter converter) {
38     mutex_lock lock(mu_);
39     auto item = registry_.find(name);
40     if (item != registry_.end()) {
41       const int existing_priority = item->second.priority;
42       if (priority <= existing_priority) {
43         LOG(WARNING) << absl::StrCat(
44             "Ignoring TF->TRT ", name, " op converter with priority ",
45             existing_priority, " due to another converter with priority ",
46             priority);
47         return {};
48       } else {
49         LOG(WARNING) << absl::StrCat(
50             "Overwriting TF->TRT ", name, " op converter with priority ",
51             existing_priority, " using another converter with priority ",
52             priority);
53         registry_.erase(item);
54       }
55     }
56     registry_.insert({name, OpConverterRegistration{converter, priority}});
57     return {};
58   }
59 
LookUp(const string & name)60   StatusOr<OpConverter> LookUp(const string& name) {
61     mutex_lock lock(mu_);
62     auto found = registry_.find(name);
63     if (found != registry_.end()) {
64       return found->second.converter;
65     }
66     return errors::NotFound("No converter for op ", name);
67   }
68 
Clear(const std::string & name)69   void Clear(const std::string& name) {
70     mutex_lock lock(mu_);
71     auto itr = registry_.find(name);
72     if (itr == registry_.end()) {
73       return;
74     }
75     registry_.erase(itr);
76   }
77 
ListRegisteredOps() const78   std::vector<std::string> ListRegisteredOps() const {
79     mutex_lock lock(mu_);
80     std::vector<std::string> result;
81     result.reserve(registry_.size());
82     for (const auto& item : registry_) {
83       result.push_back(item.first);
84     }
85     return result;
86   }
87 
88  private:
89   mutable mutex mu_;
90   mutable std::unordered_map<std::string, OpConverterRegistration> registry_
91       TF_GUARDED_BY(mu_);
92 };
93 
OpConverterRegistry()94 OpConverterRegistry::OpConverterRegistry() : impl_(std::make_unique<Impl>()) {}
95 
LookUp(const string & name)96 StatusOr<OpConverter> OpConverterRegistry::LookUp(const string& name) {
97   return impl_->LookUp(name);
98 }
99 
Register(const string & name,const int priority,OpConverter converter)100 InitOnStartupMarker OpConverterRegistry::Register(const string& name,
101                                                   const int priority,
102                                                   OpConverter converter) {
103   return impl_->Register(name, priority, converter);
104 }
105 
ListRegisteredOps() const106 std::vector<std::string> OpConverterRegistry::ListRegisteredOps() const {
107   return impl_->ListRegisteredOps();
108 }
109 
Clear(const std::string & name)110 void OpConverterRegistry::Clear(const std::string& name) { impl_->Clear(name); }
111 
GetOpConverterRegistry()112 OpConverterRegistry* GetOpConverterRegistry() {
113   static OpConverterRegistry* registry = new OpConverterRegistry();
114   return registry;
115 }
116 
117 }  // namespace convert
118 }  // namespace tensorrt
119 }  // namespace tensorflow
120 
121 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
122