xref: /aosp_15_r20/external/tensorflow/tensorflow/core/graph/node_builder.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
17 #define TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
18 
19 #include <vector>
20 #include "tensorflow/core/framework/node_def_builder.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/graph/graph.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/core/stringpiece.h"
26 #include "tensorflow/core/lib/gtl/array_slice.h"
27 
28 namespace tensorflow {
29 
30 // This is a helper for creating a Node and adding it to a Graph.
31 // Internally, it uses a NodeDefBuilder to automatically set attrs
32 // that can be inferred from the inputs, and use default values
33 // (where they exist) for unspecified attrs.  Example usage:
34 //
35 //  Node* node;
36 //  Status status = NodeBuilder(node_name, op_name)
37 //                           .Input(...)
38 //                           .Attr(...)
39 //                           .Finalize(&graph, &node);
40 //  if (!status.ok()) return status;
41 //  // Use node here.
42 class NodeBuilder {
43  public:
44   // For specifying the output of a Node to provide to one of the Input()
45   // functions below.  It supports both regular inputs (where you are
46   // connecting to an existing Node*), and inputs from outside the graph
47   // (or haven't been added to the graph yet, like back edges, where
48   // you don't have a Node*). Both types can be mixed, e.g. in an
49   // ArraySlice.
50   struct NodeOut {
51     // For referencing an existing Node.
52     NodeOut(Node* n, int32_t i = 0);
53     NodeOut(OutputTensor t);
54 
55     // For referencing Nodes not in the graph being built. It is
56     // useful when preparing a graph for ExtendSession or creating a
57     // back edge to a node that hasn't been added to the graph yet,
58     // but will be.
59     NodeOut(StringPiece name, int32_t i, DataType t);
60 
61     // Default constructor for std::vector<NodeOut>.
62     NodeOut();
63 
64     Node* node;
65     // error is set to true if:
66     // * the NodeOut was default constructed and never overwritten,
67     // * a nullptr Node* was passed to the NodeOut constructor, or
68     // * an out-of-range index was passed to the NodeOut constructor.
69     bool error;
70     string name;
71     int32 index;
72     DataType dt;
73   };
74 
75   // Specify the name and the Op (either via an OpDef or the name of
76   // the Op plus a registry) for the Node.  Other fields are
77   // specified by calling the methods below.
78   // REQUIRES: The OpDef must satisfy ValidateOpDef().
79   NodeBuilder(StringPiece name, StringPiece op_name,
80               const OpRegistryInterface* op_registry = OpRegistry::Global(),
81               const NodeDebugInfo* debug = nullptr);
82   NodeBuilder(StringPiece name, const OpDef* op_def);
83 
84   // Create a NodeBuilder from an existing NodeDefBuilder.
85   NodeBuilder(const NodeDefBuilder& def_builder);
86 
87   // You must call one Input() function per input_arg in the Op,
88   // *and in the same order as the input_args appear in the OpDef.*
89 
90   // For inputs that take a single tensor.
91   NodeBuilder& Input(Node* src_node, int src_index = 0);
92   NodeBuilder& Input(NodeOut src);
93 
94   // For inputs that take a list of tensors.
95   NodeBuilder& Input(gtl::ArraySlice<NodeOut> src_list);
96 
97   // Require that this node run after src_node(s).
98   NodeBuilder& ControlInput(Node* src_node);
99   NodeBuilder& ControlInputs(gtl::ArraySlice<Node*> src_nodes);
100 
101   // Sets the "requested device spec" in the NodeDef (not the
102   // "assigned device" in the Node).
103   NodeBuilder& Device(StringPiece device_spec);
104 
105   // Sets the device name in the "assigned device" field in tensorflow::Node.
106   NodeBuilder& AssignedDevice(StringPiece device);
107 
108   // Sets the _XlaCluster attribute in created node to `xla_cluster`.
109   NodeBuilder& XlaCluster(StringPiece xla_cluster);
110 
111   // Set the value of an attr.  attr_name must match the name of one of
112   // attrs defined by the Op, and value must have the corresponding type
113   // (see SetAttrValue() in ../framework/attr_value_util.h for legal
114   // types for value).  Note that attrs will be set automatically if
115   // they can be determined by the inputs.
116   template <class T>
117   NodeBuilder& Attr(StringPiece attr_name, T&& value);
118   template <class T>
119   NodeBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
120 
121   // Validates the described node and adds it to *graph, adding edges
122   // for all (non-back) inputs.  If created_node is not nullptr,
123   // *created_node will be set to the new node (or nullptr on error).
124   // If `consume` is true, the builder state will be moved into `node_def`,
125   // and the builder will be left in an undefined state.
126   Status Finalize(Graph* graph, Node** created_node, bool consume = false);
127 
128   // Same as `Finalize` above, but using StatusOr to return value. Preferred
129   // form.
130   StatusOr<Node*> Finalize(Graph* graph, bool consume = false);
131 
132   // Accessors for the values set in the constructor.
node_name()133   const string& node_name() const { return def_builder_.node_name(); }
op_def()134   const OpDef& op_def() const { return def_builder_.op_def(); }
135 
136  private:
SafeGetOutput(const Node * node,int i,bool * error)137   static DataType SafeGetOutput(const Node* node, int i, bool* error) {
138     if (node != nullptr && i >= 0 && i < node->num_outputs()) {
139       *error = false;
140       return node->output_type(i);
141     } else {
142       *error = true;
143       return DT_FLOAT;
144     }
145   }
146 
147   // If SafeGetOutput indicates a range error, add it to errors_.
148   void AddIndexError(const Node* node, int i);
149 
150   // Set *dt and returns true if i is in range. Combines
151   // SafeGetOutput() and AddIndexError().
152   bool GetOutputType(const Node* node, int i, DataType* dt);
153 
154   NodeDefBuilder def_builder_;
155   const OpRegistryInterface* op_registry_;
156   std::vector<NodeOut> inputs_;
157   std::vector<Node*> control_inputs_;
158   std::vector<string> errors_;
159   string assigned_device_;
160 };
161 
162 // IMPLEMENTATION -------------------------------------------------------------
163 
164 template <class T>
Attr(StringPiece attr_name,T && value)165 NodeBuilder& NodeBuilder::Attr(StringPiece attr_name, T&& value) {
166   def_builder_.Attr(attr_name, std::forward<T>(value));
167   return *this;
168 }
169 
170 template <class T>
Attr(StringPiece attr_name,std::initializer_list<T> value)171 NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
172                                std::initializer_list<T> value) {
173   def_builder_.Attr(attr_name, value);
174   return *this;
175 }
176 
177 }  // namespace tensorflow
178 
179 #endif  // TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
180