xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/full_type.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/map_util.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/host_info.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/protobuf.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 
DefaultValidator(const OpRegistryInterface & op_registry)35 Status DefaultValidator(const OpRegistryInterface& op_registry) {
36   LOG(WARNING) << "No kernel validator registered with OpRegistry.";
37   return OkStatus();
38 }
39 
40 // OpRegistry -----------------------------------------------------------------
41 
~OpRegistryInterface()42 OpRegistryInterface::~OpRegistryInterface() {}
43 
LookUpOpDef(const string & op_type_name,const OpDef ** op_def) const44 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
45                                         const OpDef** op_def) const {
46   *op_def = nullptr;
47   const OpRegistrationData* op_reg_data = nullptr;
48   TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
49   *op_def = &op_reg_data->op_def;
50   return OkStatus();
51 }
52 
OpRegistry()53 OpRegistry::OpRegistry()
54     : initialized_(false), op_registry_validator_(DefaultValidator) {}
55 
~OpRegistry()56 OpRegistry::~OpRegistry() {
57   for (const auto& e : registry_) delete e.second;
58 }
59 
Register(const OpRegistrationDataFactory & op_data_factory)60 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
61   mutex_lock lock(mu_);
62   if (initialized_) {
63     TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
64   } else {
65     deferred_.push_back(op_data_factory);
66   }
67 }
68 
69 namespace {
70 // Helper function that returns Status message for failed LookUp.
OpNotFound(const string & op_type_name)71 Status OpNotFound(const string& op_type_name) {
72   Status status = errors::NotFound(
73       "Op type not registered '", op_type_name, "' in binary running on ",
74       port::Hostname(), ". ",
75       "Make sure the Op and Kernel are registered in the binary running in "
76       "this process. Note that if you are loading a saved graph which used ops "
77       "from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
78       "before importing the graph, as contrib ops are lazily registered when "
79       "the module is first accessed.");
80   VLOG(1) << status.ToString();
81   return status;
82 }
83 }  // namespace
84 
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const85 Status OpRegistry::LookUp(const string& op_type_name,
86                           const OpRegistrationData** op_reg_data) const {
87   if ((*op_reg_data = LookUp(op_type_name))) return OkStatus();
88   return OpNotFound(op_type_name);
89 }
90 
LookUp(const string & op_type_name) const91 const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const {
92   {
93     tf_shared_lock l(mu_);
94     if (initialized_) {
95       if (const OpRegistrationData* res =
96               gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
97         return res;
98       }
99     }
100   }
101   return LookUpSlow(op_type_name);
102 }
103 
LookUpSlow(const string & op_type_name) const104 const OpRegistrationData* OpRegistry::LookUpSlow(
105     const string& op_type_name) const {
106   const OpRegistrationData* res = nullptr;
107 
108   bool first_call = false;
109   bool first_unregistered = false;
110   {  // Scope for lock.
111     mutex_lock lock(mu_);
112     first_call = MustCallDeferred();
113     res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
114 
115     static bool unregistered_before = false;
116     first_unregistered = !unregistered_before && (res == nullptr);
117     if (first_unregistered) {
118       unregistered_before = true;
119     }
120     // Note: Can't hold mu_ while calling Export() below.
121   }
122   if (first_call) {
123     TF_QCHECK_OK(op_registry_validator_(*this));
124   }
125   if (res == nullptr) {
126     if (first_unregistered) {
127       OpList op_list;
128       Export(true, &op_list);
129       if (VLOG_IS_ON(3)) {
130         LOG(INFO) << "All registered Ops:";
131         for (const auto& op : op_list.op()) {
132           LOG(INFO) << SummarizeOpDef(op);
133         }
134       }
135     }
136   }
137   return res;
138 }
139 
GetRegisteredOps(std::vector<OpDef> * op_defs)140 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
141   mutex_lock lock(mu_);
142   MustCallDeferred();
143   for (const auto& p : registry_) {
144     op_defs->push_back(p.second->op_def);
145   }
146 }
147 
GetOpRegistrationData(std::vector<OpRegistrationData> * op_data)148 void OpRegistry::GetOpRegistrationData(
149     std::vector<OpRegistrationData>* op_data) {
150   mutex_lock lock(mu_);
151   MustCallDeferred();
152   for (const auto& p : registry_) {
153     op_data->push_back(*p.second);
154   }
155 }
156 
SetWatcher(const Watcher & watcher)157 Status OpRegistry::SetWatcher(const Watcher& watcher) {
158   mutex_lock lock(mu_);
159   if (watcher_ && watcher) {
160     return errors::AlreadyExists(
161         "Cannot over-write a valid watcher with another.");
162   }
163   watcher_ = watcher;
164   return OkStatus();
165 }
166 
Export(bool include_internal,OpList * ops) const167 void OpRegistry::Export(bool include_internal, OpList* ops) const {
168   mutex_lock lock(mu_);
169   MustCallDeferred();
170 
171   std::vector<std::pair<string, const OpRegistrationData*>> sorted(
172       registry_.begin(), registry_.end());
173   std::sort(sorted.begin(), sorted.end());
174 
175   auto out = ops->mutable_op();
176   out->Clear();
177   out->Reserve(sorted.size());
178 
179   for (const auto& item : sorted) {
180     if (include_internal || !absl::StartsWith(item.first, "_")) {
181       *out->Add() = item.second->op_def;
182     }
183   }
184 }
185 
DeferRegistrations()186 void OpRegistry::DeferRegistrations() {
187   mutex_lock lock(mu_);
188   initialized_ = false;
189 }
190 
ClearDeferredRegistrations()191 void OpRegistry::ClearDeferredRegistrations() {
192   mutex_lock lock(mu_);
193   deferred_.clear();
194 }
195 
ProcessRegistrations() const196 Status OpRegistry::ProcessRegistrations() const {
197   mutex_lock lock(mu_);
198   return CallDeferred();
199 }
200 
DebugString(bool include_internal) const201 string OpRegistry::DebugString(bool include_internal) const {
202   OpList op_list;
203   Export(include_internal, &op_list);
204   string ret;
205   for (const auto& op : op_list.op()) {
206     strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
207   }
208   return ret;
209 }
210 
MustCallDeferred() const211 bool OpRegistry::MustCallDeferred() const {
212   if (initialized_) return false;
213   initialized_ = true;
214   for (size_t i = 0; i < deferred_.size(); ++i) {
215     TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
216   }
217   deferred_.clear();
218   return true;
219 }
220 
CallDeferred() const221 Status OpRegistry::CallDeferred() const {
222   if (initialized_) return OkStatus();
223   initialized_ = true;
224   for (size_t i = 0; i < deferred_.size(); ++i) {
225     Status s = RegisterAlreadyLocked(deferred_[i]);
226     if (!s.ok()) {
227       return s;
228     }
229   }
230   deferred_.clear();
231   return OkStatus();
232 }
233 
RegisterAlreadyLocked(const OpRegistrationDataFactory & op_data_factory) const234 Status OpRegistry::RegisterAlreadyLocked(
235     const OpRegistrationDataFactory& op_data_factory) const {
236   std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
237   Status s = op_data_factory(op_reg_data.get());
238   if (s.ok()) {
239     s = ValidateOpDef(op_reg_data->op_def);
240     if (s.ok() &&
241         !gtl::InsertIfNotPresent(&registry_, op_reg_data->op_def.name(),
242                                  op_reg_data.get())) {
243       s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
244     }
245   }
246   Status watcher_status = s;
247   if (watcher_) {
248     watcher_status = watcher_(s, op_reg_data->op_def);
249   }
250   if (s.ok()) {
251     op_reg_data.release();
252   } else {
253     op_reg_data.reset();
254   }
255   return watcher_status;
256 }
257 
258 // static
Global()259 OpRegistry* OpRegistry::Global() {
260   static OpRegistry* global_op_registry = new OpRegistry;
261   return global_op_registry;
262 }
263 
264 // OpListOpRegistry -----------------------------------------------------------
265 
OpListOpRegistry(const OpList * op_list)266 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
267   for (const OpDef& op_def : op_list->op()) {
268     auto* op_reg_data = new OpRegistrationData();
269     op_reg_data->op_def = op_def;
270     index_[op_def.name()] = op_reg_data;
271   }
272 }
273 
~OpListOpRegistry()274 OpListOpRegistry::~OpListOpRegistry() {
275   for (const auto& e : index_) delete e.second;
276 }
277 
LookUp(const string & op_type_name) const278 const OpRegistrationData* OpListOpRegistry::LookUp(
279     const string& op_type_name) const {
280   auto iter = index_.find(op_type_name);
281   if (iter == index_.end()) {
282     return nullptr;
283   }
284   return iter->second;
285 }
286 
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const287 Status OpListOpRegistry::LookUp(const string& op_type_name,
288                                 const OpRegistrationData** op_reg_data) const {
289   if ((*op_reg_data = LookUp(op_type_name))) return OkStatus();
290   return OpNotFound(op_type_name);
291 }
292 
293 namespace register_op {
294 
operator ()()295 InitOnStartupMarker OpDefBuilderWrapper::operator()() {
296   OpRegistry::Global()->Register(
297       [builder =
298            std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
299         return builder.Finalize(op_reg_data);
300       });
301   return {};
302 }
303 
304 }  //  namespace register_op
305 
306 }  // namespace tensorflow
307