1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21
22 #include "tensorflow/core/framework/full_type.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/gtl/map_util.h"
26 #include "tensorflow/core/lib/strings/str_util.h"
27 #include "tensorflow/core/platform/host_info.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/protobuf.h"
31 #include "tensorflow/core/platform/types.h"
32
33 namespace tensorflow {
34
DefaultValidator(const OpRegistryInterface & op_registry)35 Status DefaultValidator(const OpRegistryInterface& op_registry) {
36 LOG(WARNING) << "No kernel validator registered with OpRegistry.";
37 return OkStatus();
38 }
39
40 // OpRegistry -----------------------------------------------------------------
41
~OpRegistryInterface()42 OpRegistryInterface::~OpRegistryInterface() {}
43
LookUpOpDef(const string & op_type_name,const OpDef ** op_def) const44 Status OpRegistryInterface::LookUpOpDef(const string& op_type_name,
45 const OpDef** op_def) const {
46 *op_def = nullptr;
47 const OpRegistrationData* op_reg_data = nullptr;
48 TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data));
49 *op_def = &op_reg_data->op_def;
50 return OkStatus();
51 }
52
OpRegistry()53 OpRegistry::OpRegistry()
54 : initialized_(false), op_registry_validator_(DefaultValidator) {}
55
~OpRegistry()56 OpRegistry::~OpRegistry() {
57 for (const auto& e : registry_) delete e.second;
58 }
59
Register(const OpRegistrationDataFactory & op_data_factory)60 void OpRegistry::Register(const OpRegistrationDataFactory& op_data_factory) {
61 mutex_lock lock(mu_);
62 if (initialized_) {
63 TF_QCHECK_OK(RegisterAlreadyLocked(op_data_factory));
64 } else {
65 deferred_.push_back(op_data_factory);
66 }
67 }
68
69 namespace {
70 // Helper function that returns Status message for failed LookUp.
OpNotFound(const string & op_type_name)71 Status OpNotFound(const string& op_type_name) {
72 Status status = errors::NotFound(
73 "Op type not registered '", op_type_name, "' in binary running on ",
74 port::Hostname(), ". ",
75 "Make sure the Op and Kernel are registered in the binary running in "
76 "this process. Note that if you are loading a saved graph which used ops "
77 "from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done "
78 "before importing the graph, as contrib ops are lazily registered when "
79 "the module is first accessed.");
80 VLOG(1) << status.ToString();
81 return status;
82 }
83 } // namespace
84
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const85 Status OpRegistry::LookUp(const string& op_type_name,
86 const OpRegistrationData** op_reg_data) const {
87 if ((*op_reg_data = LookUp(op_type_name))) return OkStatus();
88 return OpNotFound(op_type_name);
89 }
90
LookUp(const string & op_type_name) const91 const OpRegistrationData* OpRegistry::LookUp(const string& op_type_name) const {
92 {
93 tf_shared_lock l(mu_);
94 if (initialized_) {
95 if (const OpRegistrationData* res =
96 gtl::FindWithDefault(registry_, op_type_name, nullptr)) {
97 return res;
98 }
99 }
100 }
101 return LookUpSlow(op_type_name);
102 }
103
LookUpSlow(const string & op_type_name) const104 const OpRegistrationData* OpRegistry::LookUpSlow(
105 const string& op_type_name) const {
106 const OpRegistrationData* res = nullptr;
107
108 bool first_call = false;
109 bool first_unregistered = false;
110 { // Scope for lock.
111 mutex_lock lock(mu_);
112 first_call = MustCallDeferred();
113 res = gtl::FindWithDefault(registry_, op_type_name, nullptr);
114
115 static bool unregistered_before = false;
116 first_unregistered = !unregistered_before && (res == nullptr);
117 if (first_unregistered) {
118 unregistered_before = true;
119 }
120 // Note: Can't hold mu_ while calling Export() below.
121 }
122 if (first_call) {
123 TF_QCHECK_OK(op_registry_validator_(*this));
124 }
125 if (res == nullptr) {
126 if (first_unregistered) {
127 OpList op_list;
128 Export(true, &op_list);
129 if (VLOG_IS_ON(3)) {
130 LOG(INFO) << "All registered Ops:";
131 for (const auto& op : op_list.op()) {
132 LOG(INFO) << SummarizeOpDef(op);
133 }
134 }
135 }
136 }
137 return res;
138 }
139
GetRegisteredOps(std::vector<OpDef> * op_defs)140 void OpRegistry::GetRegisteredOps(std::vector<OpDef>* op_defs) {
141 mutex_lock lock(mu_);
142 MustCallDeferred();
143 for (const auto& p : registry_) {
144 op_defs->push_back(p.second->op_def);
145 }
146 }
147
GetOpRegistrationData(std::vector<OpRegistrationData> * op_data)148 void OpRegistry::GetOpRegistrationData(
149 std::vector<OpRegistrationData>* op_data) {
150 mutex_lock lock(mu_);
151 MustCallDeferred();
152 for (const auto& p : registry_) {
153 op_data->push_back(*p.second);
154 }
155 }
156
SetWatcher(const Watcher & watcher)157 Status OpRegistry::SetWatcher(const Watcher& watcher) {
158 mutex_lock lock(mu_);
159 if (watcher_ && watcher) {
160 return errors::AlreadyExists(
161 "Cannot over-write a valid watcher with another.");
162 }
163 watcher_ = watcher;
164 return OkStatus();
165 }
166
Export(bool include_internal,OpList * ops) const167 void OpRegistry::Export(bool include_internal, OpList* ops) const {
168 mutex_lock lock(mu_);
169 MustCallDeferred();
170
171 std::vector<std::pair<string, const OpRegistrationData*>> sorted(
172 registry_.begin(), registry_.end());
173 std::sort(sorted.begin(), sorted.end());
174
175 auto out = ops->mutable_op();
176 out->Clear();
177 out->Reserve(sorted.size());
178
179 for (const auto& item : sorted) {
180 if (include_internal || !absl::StartsWith(item.first, "_")) {
181 *out->Add() = item.second->op_def;
182 }
183 }
184 }
185
DeferRegistrations()186 void OpRegistry::DeferRegistrations() {
187 mutex_lock lock(mu_);
188 initialized_ = false;
189 }
190
ClearDeferredRegistrations()191 void OpRegistry::ClearDeferredRegistrations() {
192 mutex_lock lock(mu_);
193 deferred_.clear();
194 }
195
ProcessRegistrations() const196 Status OpRegistry::ProcessRegistrations() const {
197 mutex_lock lock(mu_);
198 return CallDeferred();
199 }
200
DebugString(bool include_internal) const201 string OpRegistry::DebugString(bool include_internal) const {
202 OpList op_list;
203 Export(include_internal, &op_list);
204 string ret;
205 for (const auto& op : op_list.op()) {
206 strings::StrAppend(&ret, SummarizeOpDef(op), "\n");
207 }
208 return ret;
209 }
210
MustCallDeferred() const211 bool OpRegistry::MustCallDeferred() const {
212 if (initialized_) return false;
213 initialized_ = true;
214 for (size_t i = 0; i < deferred_.size(); ++i) {
215 TF_QCHECK_OK(RegisterAlreadyLocked(deferred_[i]));
216 }
217 deferred_.clear();
218 return true;
219 }
220
CallDeferred() const221 Status OpRegistry::CallDeferred() const {
222 if (initialized_) return OkStatus();
223 initialized_ = true;
224 for (size_t i = 0; i < deferred_.size(); ++i) {
225 Status s = RegisterAlreadyLocked(deferred_[i]);
226 if (!s.ok()) {
227 return s;
228 }
229 }
230 deferred_.clear();
231 return OkStatus();
232 }
233
RegisterAlreadyLocked(const OpRegistrationDataFactory & op_data_factory) const234 Status OpRegistry::RegisterAlreadyLocked(
235 const OpRegistrationDataFactory& op_data_factory) const {
236 std::unique_ptr<OpRegistrationData> op_reg_data(new OpRegistrationData);
237 Status s = op_data_factory(op_reg_data.get());
238 if (s.ok()) {
239 s = ValidateOpDef(op_reg_data->op_def);
240 if (s.ok() &&
241 !gtl::InsertIfNotPresent(®istry_, op_reg_data->op_def.name(),
242 op_reg_data.get())) {
243 s = errors::AlreadyExists("Op with name ", op_reg_data->op_def.name());
244 }
245 }
246 Status watcher_status = s;
247 if (watcher_) {
248 watcher_status = watcher_(s, op_reg_data->op_def);
249 }
250 if (s.ok()) {
251 op_reg_data.release();
252 } else {
253 op_reg_data.reset();
254 }
255 return watcher_status;
256 }
257
258 // static
Global()259 OpRegistry* OpRegistry::Global() {
260 static OpRegistry* global_op_registry = new OpRegistry;
261 return global_op_registry;
262 }
263
264 // OpListOpRegistry -----------------------------------------------------------
265
OpListOpRegistry(const OpList * op_list)266 OpListOpRegistry::OpListOpRegistry(const OpList* op_list) {
267 for (const OpDef& op_def : op_list->op()) {
268 auto* op_reg_data = new OpRegistrationData();
269 op_reg_data->op_def = op_def;
270 index_[op_def.name()] = op_reg_data;
271 }
272 }
273
~OpListOpRegistry()274 OpListOpRegistry::~OpListOpRegistry() {
275 for (const auto& e : index_) delete e.second;
276 }
277
LookUp(const string & op_type_name) const278 const OpRegistrationData* OpListOpRegistry::LookUp(
279 const string& op_type_name) const {
280 auto iter = index_.find(op_type_name);
281 if (iter == index_.end()) {
282 return nullptr;
283 }
284 return iter->second;
285 }
286
LookUp(const string & op_type_name,const OpRegistrationData ** op_reg_data) const287 Status OpListOpRegistry::LookUp(const string& op_type_name,
288 const OpRegistrationData** op_reg_data) const {
289 if ((*op_reg_data = LookUp(op_type_name))) return OkStatus();
290 return OpNotFound(op_type_name);
291 }
292
293 namespace register_op {
294
operator ()()295 InitOnStartupMarker OpDefBuilderWrapper::operator()() {
296 OpRegistry::Global()->Register(
297 [builder =
298 std::move(builder_)](OpRegistrationData* op_reg_data) -> Status {
299 return builder.Finalize(op_reg_data);
300 });
301 return {};
302 }
303
304 } // namespace register_op
305
306 } // namespace tensorflow
307