1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 17 #define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 18 19 #include <vector> 20 21 #include "tensorflow/core/framework/function.pb.h" 22 #include "tensorflow/core/framework/graph.pb.h" 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/graph/graph.h" 25 #include "tensorflow/core/graph/node_builder.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/lib/core/stringpiece.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 30 namespace tensorflow { 31 32 // Given a function like: 33 // namespace ops { 34 // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { 35 // if (opts.HaveError()) return nullptr; 36 // static const string kOpName = "Identity"; 37 // NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName, 38 // opts.op_registry()); 39 // node_builder.Input(input); 40 // return opts.FinalizeBuilder(&node_builder); 41 // } 42 // } // namespace ops 43 // 44 // // Or, alternatively: 45 // namespace ops { 46 // Node* Identity(NodeOut input, const GraphDefBuilder::Options& opts) { 47 // static const string kOpName = "Identity"; 48 // return UnaryOp(kOpName, input, opts); 49 // } 50 // } // namespace ops 51 // 52 // You call it like: 53 // GraphDefBuilder b; 54 // using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 55 // Node* na = Const(7, b.opts()); 56 // // Note: WithName() returns a copy, opts is unchanged. 57 // Node* nb = Const(5, b.opts().WithName("control-input")); 58 // Node* nc = Identity(na, b.opts().WithControlInput(nb)); 59 // GraphDef graph_def; 60 // Status status = b.ToGraphDef(&graph_def); 61 // if (!status.ok()) { /* Handle error */ } 62 // 63 // In tests you can skip the status handling via: 64 // GraphDefBuilder b(GraphDefBuilder::kFailImmediately); 65 // ... 66 // b.ToGraphDef(&graph_def); 67 68 class GraphDefBuilder { 69 public: 70 // Options for adding a Node to a Graph. 71 class Options { 72 public: 73 // Sets the Graph (that Nodes will be added to) and the status. The 74 // status may be set to nullptr, in which case errors cause CHECK 75 // failures. The graph and status must outlive *this. 76 Options(Graph* graph, Status* status); 77 ~Options(); 78 79 // Methods for setting options. These are const methods: they 80 // return a copy of *this with the option set. 81 Options WithName(StringPiece name) const; 82 Options WithDevice(StringPiece device) const; 83 Options WithControlInput(Node* control_input) const; 84 Options WithControlInputs(gtl::ArraySlice<Node*> control_inputs) const; 85 86 // Override the default value for an optional attr. 87 template <class T> WithAttr(StringPiece attr_name,T && value)88 Options WithAttr(StringPiece attr_name, T&& value) const { 89 return Options(*this).WithAttrImpl(attr_name, std::forward<T>(value)); 90 } 91 // Note: overload needed to allow {...} expressions for value. 92 template <class T> WithAttr(StringPiece attr_name,std::initializer_list<T> value)93 Options WithAttr(StringPiece attr_name, 94 std::initializer_list<T> value) const { 95 return WithAttr<std::initializer_list<T>>(attr_name, std::move(value)); 96 } 97 98 // Methods for using options from a function that creates a Node. 99 100 // Returns true if the status associated with *this has an error. 101 // Use this to skip processing that may depend on prior results. HaveError()102 bool HaveError() const { return status_ != nullptr && !status_->ok(); } 103 104 // Returns a string representation of the status associated with *this. 105 // Returns the string `"OK"` if the status doesn't have any error. StatusToString()106 string StatusToString() const { 107 return status_->ok() ? "OK" : status_->error_message(); 108 } 109 110 // Given the Op type name, return a name for a node of that type. 111 // Uses the value set in WithName() if that has been called. Otherwise, 112 // returns a name built out of the Op type name. 113 string GetNameForOp(StringPiece op) const; 114 115 // Sets the device, adds control inputs, adds attrs, and calls Finalize(). 116 // If Finalize returns an error, it is saved and this function returns 117 // nullptr. 118 Node* FinalizeBuilder(NodeBuilder* builder) const; 119 120 // Updates the associated status, if any, or calls TF_CHECK_OK if none. 121 void UpdateStatus(const Status& status) const; 122 123 // Accessor op_registry()124 const OpRegistryInterface* op_registry() const { 125 return graph_->op_registry(); 126 } 127 128 private: 129 Options WithNameImpl(StringPiece name); 130 Options WithDeviceImpl(StringPiece device); 131 Options WithControlInputImpl(Node* control_input); 132 Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs); 133 template <class T> WithAttrImpl(StringPiece name,T && value)134 Options WithAttrImpl(StringPiece name, T&& value) { 135 attrs_.emplace_back(string(name), AttrValue()); 136 SetAttrValue(std::forward<T>(value), &attrs_.back().second); 137 return *this; 138 } 139 140 Graph* const graph_; 141 Status* const status_; 142 string name_; 143 string device_; 144 std::vector<Node*> control_inputs_; 145 std::vector<std::pair<string, AttrValue>> attrs_; 146 }; 147 148 // Start building a new graph. 149 explicit GraphDefBuilder( 150 const OpRegistryInterface* op_registry = OpRegistry::Global()) graph_(op_registry)151 : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, &status_) {} 152 153 // For use in tests, where you want to fail immediately on error instead 154 // of checking the status at the end. 155 enum TestFailImmediatelyType { kFailImmediately }; 156 explicit GraphDefBuilder( 157 TestFailImmediatelyType, 158 const OpRegistryInterface* op_registry = OpRegistry::Global()) graph_(op_registry)159 : graph_(op_registry), flib_def_(op_registry), opts_(&graph_, nullptr) {} 160 161 // Gets the Options with the associated Graph and Status. opts()162 const Options& opts() const { return opts_; } 163 164 // Once all the nodes have been added, call this to get whether it was 165 // successful, and if so fill *graph_def. 166 Status ToGraphDef(GraphDef* graph_def) const; 167 168 // Adds the function and gradient definitions in `fdef_lib` to this graph's op 169 // registry. Ignores duplicate functions, and returns a bad status if an 170 // imported function differs from an existing function or op with the same 171 // name. AddFunctionLibrary(const FunctionDefLibrary & fdef_lib)172 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { 173 return flib_def_.AddLibrary(fdef_lib); 174 } 175 176 // Returns whether a user-defined function with `name` already exists in the 177 // graph. HasFunction(const string & name)178 bool HasFunction(const string& name) { 179 return flib_def_.Find(name) != nullptr; 180 } 181 182 private: 183 Graph graph_; 184 FunctionLibraryDefinition flib_def_; 185 Status status_; 186 Options opts_; 187 }; 188 189 namespace ops { 190 191 // A NodeOut may either be a regular input or back input. Regular 192 // inputs are specified via either a Node* or a Node* and an output 193 // index. Back inputs are specified by a node name, output index, and 194 // output type. 195 typedef NodeBuilder::NodeOut NodeOut; 196 197 // For adding an Op with no inputs to a GraphDefBuilder. 198 Node* SourceOp(const string& op_name, const GraphDefBuilder::Options& opts); 199 200 // For adding an Op with one input to a GraphDefBuilder. 201 Node* UnaryOp(const string& op_name, NodeOut input, 202 const GraphDefBuilder::Options& opts); 203 204 // For adding an Op with two inputs to a GraphDefBuilder. 205 Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b, 206 const GraphDefBuilder::Options& opts); 207 208 // For adding an Op with three inputs to a GraphDefBuilder. 209 Node* TernaryOp(const string& op_name, NodeOut a, NodeOut b, NodeOut c, 210 const GraphDefBuilder::Options& opts); 211 212 } // namespace ops 213 } // namespace tensorflow 214 215 #endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_ 216