xref: /aosp_15_r20/external/tensorflow/tensorflow/core/graph/graph_def_builder.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
17 #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
18 
19 #include <vector>
20 
21 #include "tensorflow/core/framework/function.pb.h"
22 #include "tensorflow/core/framework/graph.pb.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/graph/graph.h"
25 #include "tensorflow/core/graph/node_builder.h"
26 #include "tensorflow/core/lib/core/status.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/gtl/array_slice.h"
29 
30 namespace tensorflow {
31 
32 // Given a function like:
33 //   namespace ops {
34 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
35 //     if (opts.HaveError()) return nullptr;
36 //     static const string kOpName = "Identity";
37 //     NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,
38 //                              opts.op_registry());
39 //     node_builder.Input(input);
40 //     return opts.FinalizeBuilder(&node_builder);
41 //   }
42 //   }  // namespace ops
43 //
44 //   // Or, alternatively:
45 //   namespace ops {
46 //   Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) {
47 //     static const string kOpName = "Identity";
48 //     return UnaryOp(kOpName, input, opts);
49 //   }
50 //   }  // namespace ops
51 //
52 // You call it like:
53 //   GraphDefBuilder b;
54 //   using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
55 //   Node* na = Const(7, b.opts());
56 //   // Note: WithName() returns a copy, opts is unchanged.
57 //   Node* nb = Const(5, b.opts().WithName("control-input"));
58 //   Node* nc = Identity(na, b.opts().WithControlInput(nb));
59 //   GraphDef graph_def;
60 //   Status status = b.ToGraphDef(&graph_def);
61 //   if (!status.ok()) { /* Handle error */ }
62 //
63 // In tests you can skip the status handling via:
64 //   GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
65 //   ...
66 //   b.ToGraphDef(&graph_def);
67 
68 class GraphDefBuilder {
69  public:
70   // Options for adding a Node to a Graph.
71   class Options {
72    public:
73     // Sets the Graph (that Nodes will be added to) and the status.  The
74     // status may be set to nullptr, in which case errors cause CHECK
75     // failures.  The graph and status must outlive *this.
76     Options(Graph* graph, Status* status);
77     ~Options();
78 
79     // Methods for setting options.  These are const methods: they
80     // return a copy of *this with the option set.
81     Options WithName(StringPiece name) const;
82     Options WithDevice(StringPiece device) const;
83     Options WithControlInput(Node* control_input) const;
84     Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const;
85 
86     // Override the default value for an optional attr.
87     template <class T>
WithAttr(StringPiece attr_name,T && value)88     Options WithAttr(StringPiece attr_name, T&& value) const {
89       return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value));
90     }
91     // Note: overload needed to allow {...} expressions for value.
92     template <class T>
WithAttr(StringPiece attr_name,std::initializer_list<T> value)93     Options WithAttr(StringPiece attr_name,
94                      std::initializer_list<T> value) const {
95       return WithAttr<std::initializer_list<T>>(attr_name, std::move(value));
96     }
97 
98     // Methods for using options from a function that creates a Node.
99 
100     // Returns true if the status associated with *this has an error.
101     // Use this to skip processing that may depend on prior results.
HaveError()102     bool HaveError() const { return status_ != nullptr && !status_->ok(); }
103 
104     // Returns a string representation of the status associated with *this.
105     // Returns the string `"OK"` if the status doesn't have any error.
StatusToString()106     string StatusToString() const {
107       return status_->ok() ? "OK" : status_->error_message();
108     }
109 
110     // Given the Op type name, return a name for a node of that type.
111     // Uses the value set in WithName() if that has been called.  Otherwise,
112     // returns a name built out of the Op type name.
113     string GetNameForOp(StringPiece op) const;
114 
115     // Sets the device, adds control inputs, adds attrs, and calls Finalize().
116     // If Finalize returns an error, it is saved and this function returns
117     // nullptr.
118     Node* FinalizeBuilder(NodeBuilder* builder) const;
119 
120     // Updates the associated status, if any, or calls TF_CHECK_OK if none.
121     void UpdateStatus(const Status& status) const;
122 
123     // Accessor
op_registry()124     const OpRegistryInterface* op_registry() const {
125       return graph_->op_registry();
126     }
127 
128    private:
129     Options WithNameImpl(StringPiece name);
130     Options WithDeviceImpl(StringPiece device);
131     Options WithControlInputImpl(Node* control_input);
132     Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
133     template <class T>
WithAttrImpl(StringPiece name,T && value)134     Options WithAttrImpl(StringPiece name, T&& value) {
135       attrs_.emplace_back(string(name), AttrValue());
136       SetAttrValue(std::forward<T>(value), &attrs_.back().second);
137       return *this;
138     }
139 
140     Graph* const graph_;
141     Status* const status_;
142     string name_;
143     string device_;
144     std::vector<Node*> control_inputs_;
145     std::vector<std::pair<string, AttrValue>> attrs_;
146   };
147 
148   // Start building a new graph.
149   explicit GraphDefBuilder(
150       const OpRegistryInterface* op_registry = OpRegistry::Global())
graph_(op_registry)151       : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {}
152 
153   // For use in tests, where you want to fail immediately on error instead
154   // of checking the status at the end.
155   enum TestFailImmediatelyType { kFailImmediately };
156   explicit GraphDefBuilder(
157       TestFailImmediatelyType,
158       const OpRegistryInterface* op_registry = OpRegistry::Global())
graph_(op_registry)159       : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {}
160 
161   // Gets the Options with the associated Graph and Status.
opts()162   const Options& opts() const { return opts_; }
163 
164   // Once all the nodes have been added, call this to get whether it was
165   // successful, and if so fill *graph_def.
166   Status ToGraphDef(GraphDef* graph_def) const;
167 
168   // Adds the function and gradient definitions in `fdef_lib` to this graph's op
169   // registry. Ignores duplicate functions, and returns a bad status if an
170   // imported function differs from an existing function or op with the same
171   // name.
AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)172   Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
173     return flib_def_.AddLibrary(fdef_lib);
174   }
175 
176   // Returns whether a user-defined function with `name` already exists in the
177   // graph.
HasFunction(const string & name)178   bool HasFunction(const string& name) {
179     return flib_def_.Find(name) != nullptr;
180   }
181 
182  private:
183   Graph graph_;
184   FunctionLibraryDefinition flib_def_;
185   Status status_;
186   Options opts_;
187 };
188 
189 namespace ops {
190 
191 // A NodeOut may either be a regular input or back input.  Regular
192 // inputs are specified via either a Node* or a Node* and an output
193 // index.  Back inputs are specified by a node name, output index, and
194 // output type.
195 typedef NodeBuilder::NodeOut NodeOut;
196 
197 // For adding an Op with no inputs to a GraphDefBuilder.
198 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts);
199 
200 // For adding an Op with one input to a GraphDefBuilder.
201 Node* UnaryOp(const string& op_name, NodeOut input,
202               const GraphDefBuilder::Options& opts);
203 
204 // For adding an Op with two inputs to a GraphDefBuilder.
205 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
206                const GraphDefBuilder::Options& opts);
207 
208 // For adding an Op with three inputs to a GraphDefBuilder.
209 Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c,
210                 const GraphDefBuilder::Options& opts);
211 
212 }  // namespace ops
213 }  // namespace tensorflow
214 
215 #endif  // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
216