1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/node_builder.h"
17
18 #include <unordered_map>
19 #include <vector>
20
21 #include "tensorflow/core/framework/node_def_util.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/framework/versions.pb.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/platform/statusor.h"
26 #include "tensorflow/core/protobuf/error_codes.pb.h"
27
28 namespace tensorflow {
29
NodeOut(Node * n,int32_t i)30 NodeBuilder::NodeOut::NodeOut(Node* n, int32_t i) // NOLINT(runtime/explicit)
31 : node(n),
32 error(false),
33 name(node != nullptr ? node->name() : (error = true, "")),
34 index(i),
35 dt(SafeGetOutput(node, i, &error)) {}
36
NodeOut(OutputTensor t)37 NodeBuilder::NodeOut::NodeOut(OutputTensor t) : NodeOut(t.node, t.index) {}
38
NodeOut(StringPiece n,int32_t i,DataType t)39 NodeBuilder::NodeOut::NodeOut(StringPiece n, int32_t i, DataType t)
40 : node(nullptr), error(false), name(n), index(i), dt(t) {}
41
NodeOut()42 NodeBuilder::NodeOut::NodeOut()
43 : node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
44
NodeBuilder(StringPiece name,StringPiece op_name,const OpRegistryInterface * op_registry,const NodeDebugInfo * debug)45 NodeBuilder::NodeBuilder(StringPiece name, StringPiece op_name,
46 const OpRegistryInterface* op_registry,
47 const NodeDebugInfo* debug)
48 : def_builder_(name, op_name, op_registry, debug) {}
49
NodeBuilder(StringPiece name,const OpDef * op_def)50 NodeBuilder::NodeBuilder(StringPiece name, const OpDef* op_def)
51 : def_builder_(name, op_def) {}
52
NodeBuilder(const NodeDefBuilder & def_builder)53 NodeBuilder::NodeBuilder(const NodeDefBuilder& def_builder)
54 : def_builder_(def_builder) {}
55
Input(Node * src_node,int src_index)56 NodeBuilder& NodeBuilder::Input(Node* src_node, int src_index) {
57 inputs_.emplace_back(src_node, src_index);
58 DataType dt;
59 if (GetOutputType(src_node, src_index, &dt)) {
60 def_builder_.Input(src_node->name(), src_index, dt);
61 }
62 return *this;
63 }
64
Input(NodeOut src)65 NodeBuilder& NodeBuilder::Input(NodeOut src) {
66 if (src.error) {
67 AddIndexError(src.node, src.index);
68 } else {
69 inputs_.emplace_back(src.node, src.index);
70 def_builder_.Input(src.name, src.index, src.dt);
71 }
72 return *this;
73 }
74
Input(gtl::ArraySlice<NodeOut> src_list)75 NodeBuilder& NodeBuilder::Input(gtl::ArraySlice<NodeOut> src_list) {
76 std::vector<NodeDefBuilder::NodeOut> srcs;
77 srcs.reserve(src_list.size());
78 for (const auto& node_out : src_list) {
79 if (node_out.error) {
80 AddIndexError(node_out.node, node_out.index);
81 } else {
82 srcs.emplace_back(node_out.name, node_out.index, node_out.dt);
83 inputs_.emplace_back(node_out.node, node_out.index);
84 }
85 }
86 def_builder_.Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
87 return *this;
88 }
89
ControlInput(Node * src_node)90 NodeBuilder& NodeBuilder::ControlInput(Node* src_node) {
91 control_inputs_.emplace_back(src_node);
92 def_builder_.ControlInput(src_node->name());
93 return *this;
94 }
95
ControlInputs(gtl::ArraySlice<Node * > src_nodes)96 NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice<Node*> src_nodes) {
97 control_inputs_.insert(control_inputs_.end(), src_nodes.begin(),
98 src_nodes.end());
99 for (const Node* src_node : src_nodes) {
100 def_builder_.ControlInput(src_node->name());
101 }
102 return *this;
103 }
104
Device(StringPiece device_spec)105 NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
106 def_builder_.Device(device_spec);
107 return *this;
108 }
109
AssignedDevice(StringPiece device)110 NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
111 assigned_device_ = string(device);
112 return *this;
113 }
114
XlaCluster(StringPiece xla_cluster)115 NodeBuilder& NodeBuilder::XlaCluster(StringPiece xla_cluster) {
116 def_builder_.Attr("_XlaCluster", xla_cluster);
117 return *this;
118 }
119
Finalize(Graph * graph,bool consume)120 StatusOr<Node*> NodeBuilder::Finalize(Graph* graph, bool consume) {
121 Node* out;
122 TF_RETURN_IF_ERROR(Finalize(graph, &out, consume));
123 return out;
124 }
125
Finalize(Graph * graph,Node ** created_node,bool consume)126 Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) {
127 // In case of error, set *created_node to nullptr.
128 if (created_node != nullptr) {
129 *created_node = nullptr;
130 }
131 if (!errors_.empty()) {
132 return errors::InvalidArgument(absl::StrJoin(errors_, "\n"));
133 }
134
135 NodeDef node_def;
136 TF_RETURN_IF_ERROR(def_builder_.Finalize(&node_def, consume));
137 TF_RETURN_IF_ERROR(ValidateNodeDef(node_def, def_builder_.op_def()));
138 TF_RETURN_IF_ERROR(
139 CheckOpDeprecation(def_builder_.op_def(), graph->versions().producer()));
140
141 TF_ASSIGN_OR_RETURN(Node * node, graph->AddNode(std::move(node_def)));
142
143 node->set_assigned_device_name(assigned_device_);
144
145 for (size_t i = 0; i < inputs_.size(); ++i) {
146 if (inputs_[i].node != nullptr) { // Skip back edges.
147 graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
148 }
149 }
150 for (Node* control_input : control_inputs_) {
151 graph->AddControlEdge(control_input, node);
152 }
153
154 if (created_node != nullptr) *created_node = node;
155
156 return OkStatus();
157 }
158
AddIndexError(const Node * node,int i)159 void NodeBuilder::AddIndexError(const Node* node, int i) {
160 if (node == nullptr) {
161 errors_.emplace_back(
162 strings::StrCat("Attempt to add nullptr Node to node with type ",
163 def_builder_.op_def().name()));
164 } else {
165 errors_.emplace_back(strings::StrCat(
166 "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
167 node->num_outputs(), ") to node with type ",
168 def_builder_.op_def().name(), ". Node: ", FormatNodeForError(*node)));
169 }
170 }
171
GetOutputType(const Node * node,int i,DataType * dt)172 bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) {
173 bool error;
174 *dt = SafeGetOutput(node, i, &error);
175 if (error) AddIndexError(node, i);
176 return !error;
177 }
178
179 } // namespace tensorflow
180