xref: /aosp_15_r20/external/tensorflow/tensorflow/core/graph/node_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/node_builder.h"
17 
18 #include <unordered_map>
19 #include <vector>
20 
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/platform/statusor.h"
26 #include "tensorflow/core/protobuf/error_codes.pb.h"
27 
28 namespace tensorflow {
29 
NodeOut(Node * n,int32_t i)30 NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i)  // NOLINT(runtime/explicit)
31     : node(n),
32       error(false),
33       name(node != nullptr ? node->name() : (error = true, "")),
34       index(i),
35       dt(SafeGetOutput(node, i, &error)) {}
36 
NodeOut(OutputTensor t)37 NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
38 
NodeOut(StringPiece n,int32_t i,DataType t)39 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t)
40     : node(nullptr), error(false), name(n), index(i), dt(t) {}
41 
NodeOut()42 NodeBuilder::NodeOut::NodeOut()
43     : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
44 
NodeBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)45 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
46                          const OpRegistryInterface* op_registry,
47                          const NodeDebugInfo* debug)
48     : def_builder_(name, op_name, op_registry, debug) {}
49 
NodeBuilder(StringPiece name,const OpDef * op_def)50 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
51     : def_builder_(name, op_def) {}
52 
NodeBuilder(const NodeDefBuilder & def_builder)53 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
54     : def_builder_(def_builder) {}
55 
Input(Node * src_node,int src_index)56 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
57   inputs_.emplace_back(src_node, src_index);
58   DataType dt;
59   if (GetOutputType(src_node, src_index, &dt)) {
60     def_builder_.Input(src_node->name(), src_index, dt);
61   }
62   return *this;
63 }
64 
Input(NodeOut src)65 NodeBuilder& NodeBuilder::Input(NodeOut src) {
66   if (src.error) {
67     AddIndexError(src.node, src.index);
68   } else {
69     inputs_.emplace_back(src.node, src.index);
70     def_builder_.Input(src.name, src.index, src.dt);
71   }
72   return *this;
73 }
74 
Input(gtl::ArraySlice<NodeOut> src_list)75 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
76   std::vector<NodeDefBuilder::NodeOut> srcs;
77   srcs.reserve(src_list.size());
78   for (const auto& node_out : src_list) {
79     if (node_out.error) {
80       AddIndexError(node_out.node, node_out.index);
81     } else {
82       srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
83       inputs_.emplace_back(node_out.node, node_out.index);
84     }
85   }
86   def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
87   return *this;
88 }
89 
ControlInput(Node * src_node)90 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
91   control_inputs_.emplace_back(src_node);
92   def_builder_.ControlInput(src_node->name());
93   return *this;
94 }
95 
ControlInputs(gtl::ArraySlice<Node * > src_nodes)96 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
97   control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
98                          src_nodes.end());
99   for (const Node* src_node : src_nodes) {
100     def_builder_.ControlInput(src_node->name());
101   }
102   return *this;
103 }
104 
Device(StringPiece device_spec)105 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
106   def_builder_.Device(device_spec);
107   return *this;
108 }
109 
AssignedDevice(StringPiece device)110 NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
111   assigned_device_ = string(device);
112   return *this;
113 }
114 
XlaCluster(StringPiece xla_cluster)115 NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
116   def_builder_.Attr("_XlaCluster", xla_cluster);
117   return *this;
118 }
119 
Finalize(Graph * graph,bool consume)120 StatusOr<Node*> NodeBuilder::Finalize(Graph* graph, bool consume) {
121   Node* out;
122   TF_RETURN_IF_ERROR(Finalize(graph, &out, consume));
123   return out;
124 }
125 
Finalize(Graph * graph,Node ** created_node,bool consume)126 Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
127   // In case of error, set *created_node to nullptr.
128   if (created_node != nullptr) {
129     *created_node = nullptr;
130   }
131   if (!errors_.empty()) {
132     return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
133   }
134 
135   NodeDef node_def;
136   TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
137   TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
138   TF_RETURN_IF_ERROR(
139       CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
140 
141   TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def)));
142 
143   node->set_assigned_device_name(assigned_device_);
144 
145   for (size_t i = 0; i < inputs_.size(); ++i) {
146     if (inputs_[i].node != nullptr) {  // Skip back edges.
147       graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
148     }
149   }
150   for (Node* control_input : control_inputs_) {
151     graph->AddControlEdge(control_input, node);
152   }
153 
154   if (created_node != nullptr) *created_node = node;
155 
156   return OkStatus();
157 }
158 
AddIndexError(const Node * node,int i)159 void NodeBuilder::AddIndexError(const Node* node, int i) {
160   if (node == nullptr) {
161     errors_.emplace_back(
162         strings::StrCat("Attempt to add nullptr Node to node with type ",
163                         def_builder_.op_def().name()));
164   } else {
165     errors_.emplace_back(strings::StrCat(
166         "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
167         node->num_outputs(), ") to node with type ",
168         def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
169   }
170 }
171 
GetOutputType(const Node * node,int i,DataType * dt)172 bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
173   bool error;
174   *dt = SafeGetOutput(node, i, &error);
175   if (error) AddIndexError(node, i);
176   return !error;
177 }
178 
179 }  // namespace tensorflow
180