1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17
18 #include <iostream>
19 #include <queue>
20
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/statusor.h"
29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
30
31 namespace tensorflow {
32
33 namespace {
34
35 // Gets tensor names from tensor_info and inserts them into the set of tensor
36 // names.
GetTensorNamesFromTensorInfo(const TensorInfo & tensor_info,std::unordered_set<string> * tensor_names)37 void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
38 std::unordered_set<string>* tensor_names) {
39 if (tensor_info.has_coo_sparse()) {
40 // If the tensor is sparse we have to add all three tensors of the sparse
41 // representations.
42 const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
43 tensor_names->insert(coo_sparse.values_tensor_name());
44 tensor_names->insert(coo_sparse.indices_tensor_name());
45 tensor_names->insert(coo_sparse.dense_shape_tensor_name());
46 } else if (tensor_info.has_composite_tensor()) {
47 for (const auto& component : tensor_info.composite_tensor().components()) {
48 tensor_names->insert(component.name());
49 }
50 } else {
51 tensor_names->insert(tensor_info.name());
52 }
53 }
54
55 // Gets the union of all inputs and outputs of all SignatureDefs in the bundle
GetSignatureDefsInputsAndOutputs(const SavedModelBundle & saved_model_bundle,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)56 void GetSignatureDefsInputsAndOutputs(
57 const SavedModelBundle& saved_model_bundle,
58 std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
59 for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
60 const SignatureDef& signature_def = sigdef_elem.second;
61 for (auto& input_elem : signature_def.inputs()) {
62 GetTensorNamesFromTensorInfo(input_elem.second, inputs);
63 }
64 for (auto& output_elem : signature_def.outputs()) {
65 GetTensorNamesFromTensorInfo(output_elem.second, outputs);
66 }
67 }
68 }
69
70 // Gets a map from string node name to NodeDef.
GetNodeNameToNodeDefMap(GraphDef * graph_def,std::unordered_map<string,NodeDef * > * name_to_node_map)71 void GetNodeNameToNodeDefMap(
72 GraphDef* graph_def,
73 std::unordered_map<string, NodeDef*>* name_to_node_map) {
74 for (size_t i = 0; i < graph_def->node_size(); i++) {
75 NodeDef* node = graph_def->mutable_node(i);
76 (*name_to_node_map)[node->name()] = node;
77 }
78 }
79
// Strips off the tensor suffix of `tensor_name` to get the node name:
// "foo:0" -> "foo". A leading '^' (control-dependency marker) is removed
// first, so "^bar" -> "bar".
//
// Unlike the previous implementation, this is safe on an empty string
// (str_util::Split("") produced an empty vector, making parts[0] out of
// range) and avoids materializing a vector of all ':'-separated parts.
const std::string GetNodeNameFromTensorName(std::string tensor_name) {
  if (!tensor_name.empty() && tensor_name[0] == '^') {
    tensor_name.erase(0, 1);
  }
  // Everything before the first ':' is the node name; if there is no ':',
  // the whole string is the node name.
  return tensor_name.substr(0, tensor_name.find(':'));
}
88
89 // Gets the set of node names needed by `outputs` and the corresponding set of
90 // variable nodes to convert.
GetReachableNodesAndVariables(GraphDef * graph_def,const std::unordered_set<string> & outputs,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> * reachable_node_names,std::unordered_set<string> * variable_node_names)91 void GetReachableNodesAndVariables(
92 GraphDef* graph_def, const std::unordered_set<string>& outputs,
93 const std::unordered_map<string, NodeDef*>& name_to_node_map,
94 std::unordered_set<string>* reachable_node_names,
95 std::unordered_set<string>* variable_node_names) {
96 // TODO(suharshs): Add support for ResourceVariables.
97 static const std::unordered_set<string>* kVariableTypes =
98 new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
99
100 std::queue<string> nodes_to_visit;
101 for (const string& output_tensor_name : outputs) {
102 nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
103 }
104 // We do a traversal backwards from the outputs specified in the MetaGraphDef.
105 while (!nodes_to_visit.empty()) {
106 const string node_name = nodes_to_visit.front();
107 nodes_to_visit.pop();
108 if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
109 continue;
110 }
111 reachable_node_names->insert(node_name);
112 NodeDef* node = name_to_node_map.at(node_name);
113 if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
114 variable_node_names->insert(node->name());
115 }
116 for (const string& input_tensor_name : node->input()) {
117 nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
118 }
119 }
120 }
121
122 // Gets a map from variable name to variable value.
GetVariableNameToTensorMap(Session * session,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> variable_names_set,std::unordered_map<string,Tensor> * variable_name_to_value_map)123 Status GetVariableNameToTensorMap(
124 Session* session,
125 const std::unordered_map<string, NodeDef*>& name_to_node_map,
126 std::unordered_set<string> variable_names_set,
127 std::unordered_map<string, Tensor>* variable_name_to_value_map) {
128 if (variable_names_set.empty()) {
129 return OkStatus();
130 }
131 std::vector<string> variable_names;
132 variable_names.reserve(variable_names_set.size());
133 std::vector<string> tensor_names;
134 tensor_names.reserve(variable_names_set.size());
135 for (const string& node_name : variable_names_set) {
136 variable_names.push_back(node_name);
137 NodeDef* node_def = name_to_node_map.at(node_name);
138 if (node_def->op() == "VarHandleOp") {
139 // If this is a resource variable, we have to run the corresponding
140 // ReadVariableOp.
141 tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
142 } else {
143 tensor_names.push_back(node_name + ":0");
144 }
145 }
146 std::vector<Tensor> outputs;
147 TF_RETURN_IF_ERROR(
148 session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
149 for (size_t i = 0; i < variable_names.size(); i++) {
150 (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
151 }
152 return OkStatus();
153 }
154
155 // Converts a Variable NodeDef into a Constant NodeDef.
ConvertVariableToConstant(const NodeDef & variable_node,const Tensor & variable_value,NodeDef * const_node)156 void ConvertVariableToConstant(const NodeDef& variable_node,
157 const Tensor& variable_value,
158 NodeDef* const_node) {
159 const_node->set_name(variable_node.name());
160 const_node->set_op("Const");
161 (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
162 variable_value.AsProtoTensorContent(
163 (*const_node->mutable_attr())["value"].mutable_tensor());
164 }
165
166 // Converts a ReadVariableOp NodeDef to an Identity NodeDef.
ConvertReadVariableOpToIdentity(const NodeDef & node,NodeDef * identity_node)167 void ConvertReadVariableOpToIdentity(const NodeDef& node,
168 NodeDef* identity_node) {
169 identity_node->set_name(node.name());
170 identity_node->set_op("Identity");
171 (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
172 identity_node->add_input(node.input(0));
173 }
174
175 // Returns the name of the VarHandleOp that provides input (possibly indirectly)
176 // to node with node_name. A typical indirect chain of nodes (that can occur due
177 // to graph inlining) is the following: VarHandleOp -> Identity -> Identity ->
178 // ReadVariableOp. Calling the function on any of these nodes would return the
179 // name of the VarHandleOp.
GetVarHandleName(const std::unordered_map<string,NodeDef * > & name_to_node_map,string node_name)180 StatusOr<string> GetVarHandleName(
181 const std::unordered_map<string, NodeDef*>& name_to_node_map,
182 string node_name) {
183 const NodeDef* node = name_to_node_map.at(node_name);
184 while (node->input_size() > 0) {
185 auto parent = name_to_node_map.find(node->input(0));
186 if (parent == name_to_node_map.end()) break;
187 node = parent->second;
188 if (node->op() != "Identity") {
189 VLOG(2) << "Stopping at non-identity node " << node->op();
190 break;
191 }
192 }
193 if (node->op() == "VarHandleOp") {
194 return node->name();
195 }
196 return errors::NotFound("No VarHandleOp ancestor found");
197 }
198
199 // Looks up the variable handle that provides input to node with node_name,
200 // and returns the handle name if the handle corresponds to a variable that we
201 // want to freeze (i.e. its name is contained in variable_node_names). If there
202 // is no such handle in the graph (or we do not want to save that variable)
203 // then NotFound error is returned.
GetHandleNameIfNeedsToFreeze(const std::unordered_map<string,NodeDef * > & name_to_node_map,string node_name,const std::unordered_set<string> & variable_node_names)204 StatusOr<string> GetHandleNameIfNeedsToFreeze(
205 const std::unordered_map<string, NodeDef*>& name_to_node_map,
206 string node_name, const std::unordered_set<string>& variable_node_names) {
207 StatusOr<string> var_handle_name =
208 GetVarHandleName(name_to_node_map, node_name);
209 if (var_handle_name.ok() && variable_node_names.count(*var_handle_name)) {
210 return var_handle_name;
211 }
212 return errors::NotFound("No VarHandleOp ancestor found");
213 }
214
215 // Freezes the subgraph of all nodes needed by `outputs`.
FreezeGraphDef(const SavedModelBundle & saved_model_bundle,const std::unordered_set<string> & outputs,GraphDef * frozen_graph_def)216 Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
217 const std::unordered_set<string>& outputs,
218 GraphDef* frozen_graph_def) {
219 GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
220 // Copy versions and library as-is from original graph.
221 *frozen_graph_def->mutable_versions() = graph_def.versions();
222 *frozen_graph_def->mutable_library() = graph_def.library();
223 // If the graph is empty there is nothing left to do.
224 if (graph_def.node_size() == 0) {
225 return OkStatus();
226 }
227 // name_to_node_map is needed to get the inputs from the NodeDef corresponding
228 // the a string node name. These inputs are used when doing our backwards
229 // traversal.
230 std::unordered_map<string, NodeDef*> name_to_node_map;
231 GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
232 std::unordered_set<string> reachable_node_names;
233 std::unordered_set<string> variable_node_names;
234 GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
235 &reachable_node_names, &variable_node_names);
236 std::unordered_map<string, Tensor> variable_to_value_map;
237 TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
238 saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
239 &variable_to_value_map));
240 // We copy the nodes in the same order they were in the original graph_def.
241 for (const NodeDef& node : graph_def.node()) {
242 if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
243 continue;
244 }
245 if (variable_node_names.find(node.name()) != variable_node_names.end()) {
246 ConvertVariableToConstant(node, variable_to_value_map[node.name()],
247 frozen_graph_def->add_node());
248 continue;
249 } else if (node.op() == "ReadVariableOp" &&
250 GetHandleNameIfNeedsToFreeze(name_to_node_map, node.name(),
251 variable_node_names)
252 .ok()) {
253 // If the node is a ReadVariableOp, its input VarHandleOp will be
254 // converted to a Constant, so we will need to convert it to an Identity.
255 ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
256 continue;
257 } else if (node.op() == "Identity") {
258 StatusOr<string> handle_name = GetHandleNameIfNeedsToFreeze(
259 name_to_node_map, node.name(), variable_node_names);
260 if (handle_name.ok()) {
261 // Identity node that is forwarding the value of a frozen
262 // VarhandleOp. We ensure that the dtype matches of the variable dtype.
263 NodeDef* new_node = frozen_graph_def->add_node();
264 *new_node = node;
265 (*new_node->mutable_attr())["T"] =
266 name_to_node_map.at(*handle_name)->attr().at("dtype");
267 continue;
268 }
269 }
270 // If the node isn't a variable, just copy the node as-is.
271 *frozen_graph_def->add_node() = node;
272 }
273 return OkStatus();
274 }
275
276 } // namespace
277
FreezeSavedModel(const SavedModelBundle & saved_model_bundle,GraphDef * frozen_graph_def,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)278 Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
279 GraphDef* frozen_graph_def,
280 std::unordered_set<string>* inputs,
281 std::unordered_set<string>* outputs) {
282 GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
283 TF_RETURN_IF_ERROR(
284 FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
285 return OkStatus();
286 }
287
288 } // namespace tensorflow
289