xref: /aosp_15_r20/external/tensorflow/tensorflow/core/graph/graph_def_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/graph/graph_def_builder.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/graph/tensor_id.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 
23 namespace tensorflow {
24 
Options(Graph * graph,Status * status)25 GraphDefBuilder::Options::Options(Graph* graph, Status* status)
26     : graph_(graph), status_(status) {}
~Options()27 GraphDefBuilder::Options::~Options() {}
28 
WithName(StringPiece name) const29 GraphDefBuilder::Options GraphDefBuilder::Options::WithName(
30     StringPiece name) const {
31   return Options(*this).WithNameImpl(name);
32 }
WithDevice(StringPiece device) const33 GraphDefBuilder::Options GraphDefBuilder::Options::WithDevice(
34     StringPiece device) const {
35   return Options(*this).WithDeviceImpl(device);
36 }
WithControlInput(Node * control_input) const37 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInput(
38     Node* control_input) const {
39   return Options(*this).WithControlInputImpl(control_input);
40 }
WithControlInputs(gtl::ArraySlice<Node * > control_inputs) const41 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
42     gtl::ArraySlice<Node*> control_inputs) const {
43   return Options(*this).WithControlInputsImpl(control_inputs);
44 }
WithNameImpl(StringPiece name)45 GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
46     StringPiece name) {
47   name_ = string(name);
48   return *this;
49 }
WithDeviceImpl(StringPiece device)50 GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
51     StringPiece device) {
52   device_ = string(device);
53   return *this;
54 }
WithControlInputImpl(Node * control_input)55 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
56     Node* control_input) {
57   control_inputs_.push_back(control_input);
58   return *this;
59 }
WithControlInputsImpl(gtl::ArraySlice<Node * > control_inputs)60 GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl(
61     gtl::ArraySlice<Node*> control_inputs) {
62   control_inputs_.insert(control_inputs_.end(), control_inputs.begin(),
63                          control_inputs.end());
64   return *this;
65 }
66 
ToGraphDef(GraphDef * graph_def) const67 Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const {
68   if (status_.ok()) {
69     graph_.ToGraphDef(graph_def);
70     *graph_def->mutable_library() = flib_def_.ToProto();
71   }
72   return status_;
73 }
74 
GetNameForOp(StringPiece op) const75 string GraphDefBuilder::Options::GetNameForOp(StringPiece op) const {
76   if (name_.empty()) return graph_->NewName(op);
77   return name_;
78 }
79 
FinalizeBuilder(NodeBuilder * builder) const80 Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const {
81   builder->ControlInputs(control_inputs_);
82   if (!device_.empty()) builder->Device(device_);
83   for (const auto& attr : attrs_) {
84     builder->Attr(attr.first, attr.second);
85   }
86 
87   Node* returned_node;
88   UpdateStatus(builder->Finalize(graph_, &returned_node));
89   return returned_node;
90 }
91 
UpdateStatus(const Status & status) const92 void GraphDefBuilder::Options::UpdateStatus(const Status& status) const {
93   if (status_ == nullptr) {
94     TF_CHECK_OK(status);
95   } else {
96     status_->Update(status);
97   }
98 }
99 
100 namespace ops {
101 
SourceOp(const string & op_name,const GraphDefBuilder::Options & opts)102 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts) {
103   if (opts.HaveError()) return nullptr;
104   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
105                            opts.op_registry());
106   return opts.FinalizeBuilder(&node_builder);
107 }
108 
UnaryOp(const string & op_name,NodeOut input,const GraphDefBuilder::Options & opts)109 Node* UnaryOp(const string& op_name, NodeOut input,
110               const GraphDefBuilder::Options& opts) {
111   if (opts.HaveError()) return nullptr;
112   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
113                            opts.op_registry());
114   node_builder.Input(std::move(input));
115   return opts.FinalizeBuilder(&node_builder);
116 }
117 
BinaryOp(const string & op_name,NodeOut a,NodeOut b,const GraphDefBuilder::Options & opts)118 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
119                const GraphDefBuilder::Options& opts) {
120   if (opts.HaveError()) return nullptr;
121   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
122                            opts.op_registry());
123   node_builder.Input(std::move(a)).Input(std::move(b));
124   return opts.FinalizeBuilder(&node_builder);
125 }
126 
TernaryOp(const string & op_name,NodeOut a,NodeOut b,NodeOut c,const GraphDefBuilder::Options & opts)127 Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c,
128                 const GraphDefBuilder::Options& opts) {
129   if (opts.HaveError()) return nullptr;
130   NodeBuilder node_builder(opts.GetNameForOp(op_name), op_name,
131                            opts.op_registry());
132   node_builder.Input(std::move(a)).Input(std::move(b)).Input(std::move(c));
133   return opts.FinalizeBuilder(&node_builder);
134 }
135 
136 }  // end namespace ops
137 }  // end namespace tensorflow
138