1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/graph/graph_def_builder.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/graph/tensor_id.h"
21 #include "tensorflow/core/lib/core/errors.h"
22
23 namespace tensorflow {
24
Options(Graph * graph,Status * status)25 GraphDefBuilder::Options::Options(Graph* graph, Status* status)
26 : graph_(graph), status_(status) {}
~Options()27 GraphDefBuilder::Options::~Options() {}
28
WithName(StringPiece name) const29 GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
30 StringPiece name) const {
31 return Options(*this).WithNameImpl(name);
32 }
WithDevice(StringPiece device) const33 GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
34 StringPiece device) const {
35 return Options(*this).WithDeviceImpl(device);
36 }
WithControlInput(Node * control_input) const37 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
38 Node* control_input) const {
39 return Options(*this).WithControlInputImpl(control_input);
40 }
WithControlInputs(gtl::ArraySlice<Node * > control_inputs) const41 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
42 gtl::ArraySlice<Node*> control_inputs) const {
43 return Options(*this).WithControlInputsImpl(control_inputs);
44 }
WithNameImpl(StringPiece name)45 GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
46 StringPiece name) {
47 name_ = string(name);
48 return *this;
49 }
WithDeviceImpl(StringPiece device)50 GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
51 StringPiece device) {
52 device_ = string(device);
53 return *this;
54 }
WithControlInputImpl(Node * control_input)55 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
56 Node* control_input) {
57 control_inputs_.push_back(control_input);
58 return *this;
59 }
WithControlInputsImpl(gtl::ArraySlice<Node * > control_inputs)60 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
61 gtl::ArraySlice<Node*> control_inputs) {
62 control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
63 control_inputs.end());
64 return *this;
65 }
66
ToGraphDef(GraphDef * graph_def) const67 Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
68 if (status_.ok()) {
69 graph_.ToGraphDef(graph_def);
70 *graph_def->mutable_library() = flib_def_.ToProto();
71 }
72 return status_;
73 }
74
GetNameForOp(StringPiece op) const75 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
76 if (name_.empty()) return graph_->NewName(op);
77 return name_;
78 }
79
FinalizeBuilder(NodeBuilder * builder) const80 Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
81 builder->ControlInputs(control_inputs_);
82 if (!device_.empty()) builder->Device(device_);
83 for (const auto& attr : attrs_) {
84 builder->Attr(attr.first, attr.second);
85 }
86
87 Node* returned_node;
88 UpdateStatus(builder->Finalize(graph_, &returned_node));
89 return returned_node;
90 }
91
UpdateStatus(const Status & status) const92 void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
93 if (status_ == nullptr) {
94 TF_CHECK_OK(status);
95 } else {
96 status_->Update(status);
97 }
98 }
99
100 namespace ops {
101
SourceOp(const string & op_name,const GraphDefBuilder::Options & opts)102 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
103 if (opts.HaveError()) return nullptr;
104 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
105 opts.op_registry());
106 return opts.FinalizeBuilder(&node_builder);
107 }
108
UnaryOp(const string & op_name,NodeOut input,const GraphDefBuilder::Options & opts)109 Node* UnaryOp(const string& op_name, NodeOut input,
110 const GraphDefBuilder::Options& opts) {
111 if (opts.HaveError()) return nullptr;
112 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
113 opts.op_registry());
114 node_builder.Input(std::move(input));
115 return opts.FinalizeBuilder(&node_builder);
116 }
117
BinaryOp(const string & op_name,NodeOut a,NodeOut b,const GraphDefBuilder::Options & opts)118 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
119 const GraphDefBuilder::Options& opts) {
120 if (opts.HaveError()) return nullptr;
121 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
122 opts.op_registry());
123 node_builder.Input(std::move(a)).Input(std::move(b));
124 return opts.FinalizeBuilder(&node_builder);
125 }
126
TernaryOp(const string & op_name,NodeOut a,NodeOut b,NodeOut c,const GraphDefBuilder::Options & opts)127 Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c,
128 const GraphDefBuilder::Options& opts) {
129 if (opts.HaveError()) return nullptr;
130 NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
131 opts.op_registry());
132 node_builder.Input(std::move(a)).Input(std::move(b)).Input(std::move(c));
133 return opts.FinalizeBuilder(&node_builder);
134 }
135
136 } // end namespace ops
137 } // end namespace tensorflow
138