xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/op.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_OP_H_
18 
19 #include <functional>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/full_type.pb.h"
24 #include "tensorflow/core/framework/full_type_inference_util.h"
25 #include "tensorflow/core/framework/full_type_util.h"
26 #include "tensorflow/core/framework/op_def_builder.h"
27 #include "tensorflow/core/framework/op_def_util.h"
28 #include "tensorflow/core/framework/registration/registration.h"
29 #include "tensorflow/core/lib/core/errors.h"
30 #include "tensorflow/core/lib/core/status.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/mutex.h"
36 #include "tensorflow/core/platform/thread_annotations.h"
37 #include "tensorflow/core/platform/types.h"
38 
39 namespace tensorflow {
40 
41 // Users that want to look up an OpDef by type name should take an
42 // OpRegistryInterface.  Functions accepting a
43 // (const) OpRegistryInterface* may call LookUp() from multiple threads.
44 class OpRegistryInterface {
45  public:
46   virtual ~OpRegistryInterface();
47 
48   // Returns an error status and sets *op_reg_data to nullptr if no OpDef is
49   // registered under that name, otherwise returns the registered OpDef.
50   // Caller must not delete the returned pointer.
51   virtual Status LookUp(const std::string& op_type_name,
52                         const OpRegistrationData** op_reg_data) const = 0;
53 
54   // Shorthand for calling LookUp to get the OpDef.
55   Status LookUpOpDef(const std::string& op_type_name,
56                      const OpDef** op_def) const;
57 };
58 
59 // The standard implementation of OpRegistryInterface, along with a
60 // global singleton used for registering ops via the REGISTER
61 // macros below.  Thread-safe.
62 //
63 // Example registration:
64 //   OpRegistry::Global()->Register(
65 //     [](OpRegistrationData* op_reg_data)->Status {
66 //       // Populate *op_reg_data here.
67 //       return Status::OK();
68 //   });
69 class OpRegistry : public OpRegistryInterface {
70  public:
71   typedef std::function<Status(OpRegistrationData*)> OpRegistrationDataFactory;
72 
73   OpRegistry();
74   ~OpRegistry() override;
75 
76   void Register(const OpRegistrationDataFactory& op_data_factory);
77 
78   Status LookUp(const std::string& op_type_name,
79                 const OpRegistrationData** op_reg_data) const override;
80 
81   // Returns OpRegistrationData* of registered op type, else returns nullptr.
82   const OpRegistrationData* LookUp(const std::string& op_type_name) const;
83 
84   // Fills *ops with all registered OpDefs (except those with names
85   // starting with '_' if include_internal == false) sorted in
86   // ascending alphabetical order.
87   void Export(bool include_internal, OpList* ops) const;
88 
89   // Returns ASCII-format OpList for all registered OpDefs (except
90   // those with names starting with '_' if include_internal == false).
91   std::string DebugString(bool include_internal) const;
92 
93   // A singleton available at startup.
94   static OpRegistry* Global();
95 
96   // Get all registered ops.
97   void GetRegisteredOps(std::vector<OpDef>* op_defs);
98 
99   // Get all `OpRegistrationData`s.
100   void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
101 
102   // Registers a function that validates op registry.
RegisterValidator(std::function<Status (const OpRegistryInterface &)> validator)103   void RegisterValidator(
104       std::function<Status(const OpRegistryInterface&)> validator) {
105     op_registry_validator_ = std::move(validator);
106   }
107 
108   // Watcher, a function object.
109   // The watcher, if set by SetWatcher(), is called every time an op is
110   // registered via the Register function. The watcher is passed the Status
111   // obtained from building and adding the OpDef to the registry, and the OpDef
112   // itself if it was successfully built. A watcher returns a Status which is in
113   // turn returned as the final registration status.
114   typedef std::function<Status(const Status&, const OpDef&)> Watcher;
115 
116   // An OpRegistry object has only one watcher. This interface is not thread
117   // safe, as different clients are free to set the watcher any time.
118   // Clients are expected to atomically perform the following sequence of
119   // operations :
120   // SetWatcher(a_watcher);
121   // Register some ops;
122   // op_registry->ProcessRegistrations();
123   // SetWatcher(nullptr);
124   // Returns a non-OK status if a non-null watcher is over-written by another
125   // non-null watcher.
126   Status SetWatcher(const Watcher& watcher);
127 
128   // Process the current list of deferred registrations. Note that calls to
129   // Export, LookUp and DebugString would also implicitly process the deferred
130   // registrations. Returns the status of the first failed op registration or
131   // Status::OK() otherwise.
132   Status ProcessRegistrations() const;
133 
134   // Defer the registrations until a later call to a function that processes
135   // deferred registrations are made. Normally, registrations that happen after
136   // calls to Export, LookUp, ProcessRegistrations and DebugString are processed
137   // immediately. Call this to defer future registrations.
138   void DeferRegistrations();
139 
140   // Clear the registrations that have been deferred.
141   void ClearDeferredRegistrations();
142 
143  private:
144   // Ensures that all the functions in deferred_ get called, their OpDef's
145   // registered, and returns with deferred_ empty.  Returns true the first
146   // time it is called. Prints a fatal log if any op registration fails.
147   bool MustCallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
148 
149   // Calls the functions in deferred_ and registers their OpDef's
150   // It returns the Status of the first failed op registration or Status::OK()
151   // otherwise.
152   Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
153 
154   // Add 'def' to the registry with additional data 'data'. On failure, or if
155   // there is already an OpDef with that name registered, returns a non-okay
156   // status.
157   Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory)
158       const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
159 
160   const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const;
161 
162   mutable mutex mu_;
163   // Functions in deferred_ may only be called with mu_ held.
164   mutable std::vector<OpRegistrationDataFactory> deferred_ TF_GUARDED_BY(mu_);
165   // Values are owned.
166   mutable std::unordered_map<string, const OpRegistrationData*> registry_
167       TF_GUARDED_BY(mu_);
168   mutable bool initialized_ TF_GUARDED_BY(mu_);
169 
170   // Registry watcher.
171   mutable Watcher watcher_ TF_GUARDED_BY(mu_);
172 
173   std::function<Status(const OpRegistryInterface&)> op_registry_validator_;
174 };
175 
176 // An adapter to allow an OpList to be used as an OpRegistryInterface.
177 //
178 // Note that shape inference functions are not passed in to OpListOpRegistry, so
179 // it will return an unusable shape inference function for every op it supports;
180 // therefore, it should only be used in contexts where this is okay.
181 class OpListOpRegistry : public OpRegistryInterface {
182  public:
183   // Does not take ownership of op_list, *op_list must outlive *this.
184   explicit OpListOpRegistry(const OpList* op_list);
185   ~OpListOpRegistry() override;
186   Status LookUp(const std::string& op_type_name,
187                 const OpRegistrationData** op_reg_data) const override;
188 
189   // Returns OpRegistrationData* of op type in list, else returns nullptr.
190   const OpRegistrationData* LookUp(const std::string& op_type_name) const;
191 
192  private:
193   // Values are owned.
194   std::unordered_map<string, const OpRegistrationData*> index_;
195 };
196 
197 // Support for defining the OpDef (specifying the semantics of the Op and how
198 // it should be created) and registering it in the OpRegistry::Global()
199 // registry.  Usage:
200 //
201 // REGISTER_OP("my_op_name")
202 //     .Attr("<name>:<type>")
203 //     .Attr("<name>:<type>=<default>")
204 //     .Input("<name>:<type-expr>")
205 //     .Input("<name>:Ref(<type-expr>)")
206 //     .Output("<name>:<type-expr>")
207 //     .Doc(R"(
208 // <1-line summary>
209 // <rest of the description (potentially many lines)>
210 // <name-of-attr-input-or-output>: <description of name>
211 // <name-of-attr-input-or-output>: <description of name;
212 //   if long, indent the description on subsequent lines>
213 // )");
214 //
215 // Note: .Doc() should be last.
216 // For details, see the OpDefBuilder class in op_def_builder.h.
217 
218 namespace register_op {
219 
220 class OpDefBuilderWrapper {
221  public:
OpDefBuilderWrapper(const char name[])222   explicit OpDefBuilderWrapper(const char name[]) : builder_(name) {}
Attr(std::string spec)223   OpDefBuilderWrapper& Attr(std::string spec) {
224     builder_.Attr(std::move(spec));
225     return *this;
226   }
Attr(const char * spec)227   OpDefBuilderWrapper& Attr(const char* spec) TF_ATTRIBUTE_NOINLINE {
228     return Attr(std::string(spec));
229   }
Input(std::string spec)230   OpDefBuilderWrapper& Input(std::string spec) {
231     builder_.Input(std::move(spec));
232     return *this;
233   }
Input(const char * spec)234   OpDefBuilderWrapper& Input(const char* spec) TF_ATTRIBUTE_NOINLINE {
235     return Input(std::string(spec));
236   }
Output(std::string spec)237   OpDefBuilderWrapper& Output(std::string spec) {
238     builder_.Output(std::move(spec));
239     return *this;
240   }
Output(const char * spec)241   OpDefBuilderWrapper& Output(const char* spec) TF_ATTRIBUTE_NOINLINE {
242     return Output(std::string(spec));
243   }
SetIsCommutative()244   OpDefBuilderWrapper& SetIsCommutative() {
245     builder_.SetIsCommutative();
246     return *this;
247   }
SetIsAggregate()248   OpDefBuilderWrapper& SetIsAggregate() {
249     builder_.SetIsAggregate();
250     return *this;
251   }
SetIsStateful()252   OpDefBuilderWrapper& SetIsStateful() {
253     builder_.SetIsStateful();
254     return *this;
255   }
SetDoNotOptimize()256   OpDefBuilderWrapper& SetDoNotOptimize() {
257     // We don't have a separate flag to disable optimizations such as constant
258     // folding and CSE so we reuse the stateful flag.
259     builder_.SetIsStateful();
260     return *this;
261   }
SetAllowsUninitializedInput()262   OpDefBuilderWrapper& SetAllowsUninitializedInput() {
263     builder_.SetAllowsUninitializedInput();
264     return *this;
265   }
Deprecated(int version,std::string explanation)266   OpDefBuilderWrapper& Deprecated(int version, std::string explanation) {
267     builder_.Deprecated(version, std::move(explanation));
268     return *this;
269   }
Doc(std::string text)270   OpDefBuilderWrapper& Doc(std::string text) {
271     builder_.Doc(std::move(text));
272     return *this;
273   }
SetShapeFn(OpShapeInferenceFn fn)274   OpDefBuilderWrapper& SetShapeFn(OpShapeInferenceFn fn) {
275     builder_.SetShapeFn(std::move(fn));
276     return *this;
277   }
SetIsDistributedCommunication()278   OpDefBuilderWrapper& SetIsDistributedCommunication() {
279     builder_.SetIsDistributedCommunication();
280     return *this;
281   }
282 
SetTypeConstructor(OpTypeConstructor fn)283   OpDefBuilderWrapper& SetTypeConstructor(OpTypeConstructor fn) {
284     builder_.SetTypeConstructor(std::move(fn));
285     return *this;
286   }
287 
SetForwardTypeFn(ForwardTypeInferenceFn fn)288   OpDefBuilderWrapper& SetForwardTypeFn(ForwardTypeInferenceFn fn) {
289     builder_.SetForwardTypeFn(std::move(fn));
290     return *this;
291   }
292 
SetReverseTypeFn(int input_number,ForwardTypeInferenceFn fn)293   OpDefBuilderWrapper& SetReverseTypeFn(int input_number,
294                                         ForwardTypeInferenceFn fn) {
295     builder_.SetReverseTypeFn(input_number, std::move(fn));
296     return *this;
297   }
298 
builder()299   const ::tensorflow::OpDefBuilder& builder() const { return builder_; }
300 
301   InitOnStartupMarker operator()();
302 
303  private:
304   mutable ::tensorflow::OpDefBuilder builder_;
305 };
306 
307 }  // namespace register_op
308 
// Implementation detail shared by REGISTER_OP and REGISTER_SYSTEM_OP.
//
// Defines a static (internal-linkage) InitOnStartupMarker named
// register_op<ctr>. Its initializer passes an OpDefBuilderWrapper for `name`
// via `<<` to TF_INIT_ON_STARTUP_IF(cond) (declared in registration.h), with
// cond = `is_system_op || SHOULD_REGISTER_OP(name)`. The expansion ends with
// the wrapper expression, so builder calls written after REGISTER_OP(...)
// (.Input(...), .Attr(...), ...) chain onto that wrapper.
// `is_system_op` is parenthesized so that an arbitrary expression argument
// cannot bind incorrectly against the surrounding `||`.
#define REGISTER_OP_IMPL(ctr, name, is_system_op)                           \
  static ::tensorflow::InitOnStartupMarker const register_op##ctr           \
      TF_ATTRIBUTE_UNUSED =                                                 \
          TF_INIT_ON_STARTUP_IF((is_system_op) || SHOULD_REGISTER_OP(name)) \
          << ::tensorflow::register_op::OpDefBuilderWrapper(name)

// Registers an op subject to selective registration (SHOULD_REGISTER_OP).
#define REGISTER_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, false)

// The `REGISTER_SYSTEM_OP()` macro acts as `REGISTER_OP()` except
// that the op is registered unconditionally even when selective
// registration is used.
#define REGISTER_SYSTEM_OP(name)        \
  TF_ATTRIBUTE_ANNOTATE("tf:op")        \
  TF_ATTRIBUTE_ANNOTATE("tf:op:system") \
  TF_NEW_ID_FOR_INIT(REGISTER_OP_IMPL, name, true)
326 
327 }  // namespace tensorflow
328 
329 #endif  // TENSORFLOW_CORE_FRAMEWORK_OP_H_
330