/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
#define TENSORFLOW_CORE_FRAMEWORK_OP_H_

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_inference_util.h"
#include "tensorflow/core/framework/full_type_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/registration/registration.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Users that want to look up an OpDef by type name should take an
// OpRegistryInterface. Functions accepting a
// (const) OpRegistryInterface* may call LookUp() from multiple threads.
44 class OpRegistryInterface { 45 public: 46 virtual ~OpRegistryInterface(); 47 48 // Returns an error status and sets *op_reg_data to nullptr if no OpDef is 49 // registered under that name, otherwise returns the registered OpDef. 50 // Caller must not delete the returned pointer. 51 virtual Status LookUp(const std::string& op_type_name, 52 const OpRegistrationData** op_reg_data) const = 0; 53 54 // Shorthand for calling LookUp to get the OpDef. 55 Status LookUpOpDef(const std::string& op_type_name, 56 const OpDef** op_def) const; 57 }; 58 59 // The standard implementation of OpRegistryInterface, along with a 60 // global singleton used for registering ops via the REGISTER 61 // macros below. Thread-safe. 62 // 63 // Example registration: 64 // OpRegistry::Global()->Register( 65 // [](OpRegistrationData* op_reg_data)->Status { 66 // // Populate *op_reg_data here. 67 // return Status::OK(); 68 // }); 69 class OpRegistry : public OpRegistryInterface { 70 public: 71 typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory; 72 73 OpRegistry(); 74 ~OpRegistry() override; 75 76 void Register(const OpRegistrationDataFactory& op_data_factory); 77 78 Status LookUp(const std::string& op_type_name, 79 const OpRegistrationData** op_reg_data) const override; 80 81 // Returns OpRegistrationData* of registered op type, else returns nullptr. 82 const OpRegistrationData* LookUp(const std::string& op_type_name) const; 83 84 // Fills *ops with all registered OpDefs (except those with names 85 // starting with '_' if include_internal == false) sorted in 86 // ascending alphabetical order. 87 void Export(bool include_internal, OpList* ops) const; 88 89 // Returns ASCII-format OpList for all registered OpDefs (except 90 // those with names starting with '_' if include_internal == false). 91 std::string DebugString(bool include_internal) const; 92 93 // A singleton available at startup. 94 static OpRegistry* Global(); 95 96 // Get all registered ops. 
97 void GetRegisteredOps(std::vector<OpDef>* op_defs); 98 99 // Get all `OpRegistrationData`s. 100 void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data); 101 102 // Registers a function that validates op registry. RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)103 void RegisterValidator( 104 std::function<Status(const OpRegistryInterface&)> validator) { 105 op_registry_validator_ = std::move(validator); 106 } 107 108 // Watcher, a function object. 109 // The watcher, if set by SetWatcher(), is called every time an op is 110 // registered via the Register function. The watcher is passed the Status 111 // obtained from building and adding the OpDef to the registry, and the OpDef 112 // itself if it was successfully built. A watcher returns a Status which is in 113 // turn returned as the final registration status. 114 typedef std::function<Status(const Status&, const OpDef&)> Watcher; 115 116 // An OpRegistry object has only one watcher. This interface is not thread 117 // safe, as different clients are free to set the watcher any time. 118 // Clients are expected to atomically perform the following sequence of 119 // operations : 120 // SetWatcher(a_watcher); 121 // Register some ops; 122 // op_registry->ProcessRegistrations(); 123 // SetWatcher(nullptr); 124 // Returns a non-OK status if a non-null watcher is over-written by another 125 // non-null watcher. 126 Status SetWatcher(const Watcher& watcher); 127 128 // Process the current list of deferred registrations. Note that calls to 129 // Export, LookUp and DebugString would also implicitly process the deferred 130 // registrations. Returns the status of the first failed op registration or 131 // Status::OK() otherwise. 132 Status ProcessRegistrations() const; 133 134 // Defer the registrations until a later call to a function that processes 135 // deferred registrations are made. 
Normally, registrations that happen after 136 // calls to Export, LookUp, ProcessRegistrations and DebugString are processed 137 // immediately. Call this to defer future registrations. 138 void DeferRegistrations(); 139 140 // Clear the registrations that have been deferred. 141 void ClearDeferredRegistrations(); 142 143 private: 144 // Ensures that all the functions in deferred_ get called, their OpDef's 145 // registered, and returns with deferred_ empty. Returns true the first 146 // time it is called. Prints a fatal log if any op registration fails. 147 bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 148 149 // Calls the functions in deferred_ and registers their OpDef's 150 // It returns the Status of the first failed op registration or Status::OK() 151 // otherwise. 152 Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 153 154 // Add 'def' to the registry with additional data 'data'. On failure, or if 155 // there is already an OpDef with that name registered, returns a non-okay 156 // status. 157 Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) 158 const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 159 160 const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const; 161 162 mutable mutex mu_; 163 // Functions in deferred_ may only be called with mu_ held. 164 mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_); 165 // Values are owned. 166 mutable std::unordered_map<string, const OpRegistrationData*> registry_ 167 TF_GUARDED_BY(mu_); 168 mutable bool initialized_ TF_GUARDED_BY(mu_); 169 170 // Registry watcher. 171 mutable Watcher watcher_ TF_GUARDED_BY(mu_); 172 173 std::function<Status(const OpRegistryInterface&)> op_registry_validator_; 174 }; 175 176 // An adapter to allow an OpList to be used as an OpRegistryInterface. 
177 // 178 // Note that shape inference functions are not passed in to OpListOpRegistry, so 179 // it will return an unusable shape inference function for every op it supports; 180 // therefore, it should only be used in contexts where this is okay. 181 class OpListOpRegistry : public OpRegistryInterface { 182 public: 183 // Does not take ownership of op_list, *op_list must outlive *this. 184 explicit OpListOpRegistry(const OpList* op_list); 185 ~OpListOpRegistry() override; 186 Status LookUp(const std::string& op_type_name, 187 const OpRegistrationData** op_reg_data) const override; 188 189 // Returns OpRegistrationData* of op type in list, else returns nullptr. 190 const OpRegistrationData* LookUp(const std::string& op_type_name) const; 191 192 private: 193 // Values are owned. 194 std::unordered_map<string, const OpRegistrationData*> index_; 195 }; 196 197 // Support for defining the OpDef (specifying the semantics of the Op and how 198 // it should be created) and registering it in the OpRegistry::Global() 199 // registry. Usage: 200 // 201 // REGISTER_OP("my_op_name") 202 // .Attr("<name>:<type>") 203 // .Attr("<name>:<type>=<default>") 204 // .Input("<name>:<type-expr>") 205 // .Input("<name>:Ref(<type-expr>)") 206 // .Output("<name>:<type-expr>") 207 // .Doc(R"( 208 // <1-line summary> 209 // <rest of the description (potentially many lines)> 210 // <name-of-attr-input-or-output>: <description of name> 211 // <name-of-attr-input-or-output>: <description of name; 212 // if long, indent the description on subsequent lines> 213 // )"); 214 // 215 // Note: .Doc() should be last. 216 // For details, see the OpDefBuilder class in op_def_builder.h. 
217 218 namespace register_op { 219 220 class OpDefBuilderWrapper { 221 public: OpDefBuilderWrapper(const char name[])222 explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {} Attr(std::string spec)223 OpDefBuilderWrapper& Attr(std::string spec) { 224 builder_.Attr(std::move(spec)); 225 return *this; 226 } Attr(const char * spec)227 OpDefBuilderWrapper& Attr(const char* spec) TF_ATTRIBUTE_NOINLINE { 228 return Attr(std::string(spec)); 229 } Input(std::string spec)230 OpDefBuilderWrapper& Input(std::string spec) { 231 builder_.Input(std::move(spec)); 232 return *this; 233 } Input(const char * spec)234 OpDefBuilderWrapper& Input(const char* spec) TF_ATTRIBUTE_NOINLINE { 235 return Input(std::string(spec)); 236 } Output(std::string spec)237 OpDefBuilderWrapper& Output(std::string spec) { 238 builder_.Output(std::move(spec)); 239 return *this; 240 } Output(const char * spec)241 OpDefBuilderWrapper& Output(const char* spec) TF_ATTRIBUTE_NOINLINE { 242 return Output(std::string(spec)); 243 } SetIsCommutative()244 OpDefBuilderWrapper& SetIsCommutative() { 245 builder_.SetIsCommutative(); 246 return *this; 247 } SetIsAggregate()248 OpDefBuilderWrapper& SetIsAggregate() { 249 builder_.SetIsAggregate(); 250 return *this; 251 } SetIsStateful()252 OpDefBuilderWrapper& SetIsStateful() { 253 builder_.SetIsStateful(); 254 return *this; 255 } SetDoNotOptimize()256 OpDefBuilderWrapper& SetDoNotOptimize() { 257 // We don't have a separate flag to disable optimizations such as constant 258 // folding and CSE so we reuse the stateful flag. 
259 builder_.SetIsStateful(); 260 return *this; 261 } SetAllowsUninitializedInput()262 OpDefBuilderWrapper& SetAllowsUninitializedInput() { 263 builder_.SetAllowsUninitializedInput(); 264 return *this; 265 } Deprecated(int version,std::string explanation)266 OpDefBuilderWrapper& Deprecated(int version, std::string explanation) { 267 builder_.Deprecated(version, std::move(explanation)); 268 return *this; 269 } Doc(std::string text)270 OpDefBuilderWrapper& Doc(std::string text) { 271 builder_.Doc(std::move(text)); 272 return *this; 273 } SetShapeFn(OpShapeInferenceFn fn)274 OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) { 275 builder_.SetShapeFn(std::move(fn)); 276 return *this; 277 } SetIsDistributedCommunication()278 OpDefBuilderWrapper& SetIsDistributedCommunication() { 279 builder_.SetIsDistributedCommunication(); 280 return *this; 281 } 282 SetTypeConstructor(OpTypeConstructor fn)283 OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) { 284 builder_.SetTypeConstructor(std::move(fn)); 285 return *this; 286 } 287 SetForwardTypeFn(ForwardTypeInferenceFn fn)288 OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) { 289 builder_.SetForwardTypeFn(std::move(fn)); 290 return *this; 291 } 292 SetReverseTypeFn(int input_number,ForwardTypeInferenceFn fn)293 OpDefBuilderWrapper& SetReverseTypeFn(int input_number, 294 ForwardTypeInferenceFn fn) { 295 builder_.SetReverseTypeFn(input_number, std::move(fn)); 296 return *this; 297 } 298 builder()299 const ::tensorflow::OpDefBuilder& builder() const { return builder_; } 300 301 InitOnStartupMarker operator()(); 302 303 private: 304 mutable ::tensorflow::OpDefBuilder builder_; 305 }; 306 307 } // namespace register_op 308 309 #define REGISTER_OP_IMPL(ctr, name, is_system_op) \ 310 static ::tensorflow::InitOnStartupMarker const register_op##ctr \ 311 TF_ATTRIBUTE_UNUSED = \ 312 TF_INIT_ON_STARTUP_IF(is_system_op || SHOULD_REGISTER_OP(name)) \ 313 << ::tensorflow::register_op::OpDefBuilderWrapper(name) 
314 315 #define REGISTER_OP(name) \ 316 TF_ATTRIBUTE_ANNOTATE("tf:op") \ 317 TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false) 318 319 // The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except 320 // that the op is registered unconditionally even when selective 321 // registration is used. 322 #define REGISTER_SYSTEM_OP(name) \ 323 TF_ATTRIBUTE_ANNOTATE("tf:op") \ 324 TF_ATTRIBUTE_ANNOTATE("tf:op:system") \ 325 TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true) 326 327 } // namespace tensorflow 328 329 #endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_ 330