xref: /aosp_15_r20/external/tensorflow/tensorflow/cc/tools/freeze_saved_model.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/cc/tools/freeze_saved_model.h"
17 
18 #include <iostream>
19 #include <queue>
20 
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/graph.pb.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/versions.pb.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/strings/str_util.h"
28 #include "tensorflow/core/platform/statusor.h"
29 #include "tensorflow/core/protobuf/meta_graph.pb.h"
30 
31 namespace tensorflow {
32 
33 namespace {
34 
35 // Gets tensor names from tensor_info and inserts them into the set of tensor
36 // names.
GetTensorNamesFromTensorInfo(const TensorInfo & tensor_info,std::unordered_set<string> * tensor_names)37 void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
38                                   std::unordered_set<string>* tensor_names) {
39   if (tensor_info.has_coo_sparse()) {
40     // If the tensor is sparse we have to add all three tensors of the sparse
41     // representations.
42     const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
43     tensor_names->insert(coo_sparse.values_tensor_name());
44     tensor_names->insert(coo_sparse.indices_tensor_name());
45     tensor_names->insert(coo_sparse.dense_shape_tensor_name());
46   } else if (tensor_info.has_composite_tensor()) {
47     for (const auto& component : tensor_info.composite_tensor().components()) {
48       tensor_names->insert(component.name());
49     }
50   } else {
51     tensor_names->insert(tensor_info.name());
52   }
53 }
54 
55 // Gets the union of all inputs and outputs of all SignatureDefs in the bundle
GetSignatureDefsInputsAndOutputs(const SavedModelBundle & saved_model_bundle,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)56 void GetSignatureDefsInputsAndOutputs(
57     const SavedModelBundle& saved_model_bundle,
58     std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
59   for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
60     const SignatureDef& signature_def = sigdef_elem.second;
61     for (auto& input_elem : signature_def.inputs()) {
62       GetTensorNamesFromTensorInfo(input_elem.second, inputs);
63     }
64     for (auto& output_elem : signature_def.outputs()) {
65       GetTensorNamesFromTensorInfo(output_elem.second, outputs);
66     }
67   }
68 }
69 
70 // Gets a map from string node name to NodeDef.
GetNodeNameToNodeDefMap(GraphDef * graph_def,std::unordered_map<string,NodeDef * > * name_to_node_map)71 void GetNodeNameToNodeDefMap(
72     GraphDef* graph_def,
73     std::unordered_map<string, NodeDef*>* name_to_node_map) {
74   for (size_t i = 0; i < graph_def->node_size(); i++) {
75     NodeDef* node = graph_def->mutable_node(i);
76     (*name_to_node_map)[node->name()] = node;
77   }
78 }
79 
// Reduces a tensor reference to the name of the node that produces it.
//
// A single leading '^' is dropped, and everything from the first ':' onwards
// (e.g. ":0") is cut off. Examples:
//   "foo:0"  -> "foo"
//   "^foo"   -> "foo"
//   "foo"    -> "foo"
//   ""       -> ""
//
// The parameter is taken by value and edited in place, so no intermediate
// container of split pieces is allocated.
const std::string GetNodeNameFromTensorName(std::string tensor_name) {
  if (!tensor_name.empty() && tensor_name.front() == '^') {
    tensor_name.erase(0, 1);
  }
  const std::string::size_type colon_pos = tensor_name.find(':');
  if (colon_pos != std::string::npos) {
    // Drop the ":<suffix>" part; only the node name before it is kept.
    tensor_name.erase(colon_pos);
  }
  return tensor_name;
}
88 
89 // Gets the set of node names needed by `outputs` and the corresponding set of
90 // variable nodes to convert.
GetReachableNodesAndVariables(GraphDef * graph_def,const std::unordered_set<string> & outputs,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> * reachable_node_names,std::unordered_set<string> * variable_node_names)91 void GetReachableNodesAndVariables(
92     GraphDef* graph_def, const std::unordered_set<string>& outputs,
93     const std::unordered_map<string, NodeDef*>& name_to_node_map,
94     std::unordered_set<string>* reachable_node_names,
95     std::unordered_set<string>* variable_node_names) {
96   // TODO(suharshs): Add support for ResourceVariables.
97   static const std::unordered_set<string>* kVariableTypes =
98       new std::unordered_set<string>({"Variable", "VariableV2", "VarHandleOp"});
99 
100   std::queue<string> nodes_to_visit;
101   for (const string& output_tensor_name : outputs) {
102     nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name));
103   }
104   // We do a traversal backwards from the outputs specified in the MetaGraphDef.
105   while (!nodes_to_visit.empty()) {
106     const string node_name = nodes_to_visit.front();
107     nodes_to_visit.pop();
108     if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
109       continue;
110     }
111     reachable_node_names->insert(node_name);
112     NodeDef* node = name_to_node_map.at(node_name);
113     if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
114       variable_node_names->insert(node->name());
115     }
116     for (const string& input_tensor_name : node->input()) {
117       nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name));
118     }
119   }
120 }
121 
122 // Gets a map from variable name to variable value.
GetVariableNameToTensorMap(Session * session,const std::unordered_map<string,NodeDef * > & name_to_node_map,std::unordered_set<string> variable_names_set,std::unordered_map<string,Tensor> * variable_name_to_value_map)123 Status GetVariableNameToTensorMap(
124     Session* session,
125     const std::unordered_map<string, NodeDef*>& name_to_node_map,
126     std::unordered_set<string> variable_names_set,
127     std::unordered_map<string, Tensor>* variable_name_to_value_map) {
128   if (variable_names_set.empty()) {
129     return OkStatus();
130   }
131   std::vector<string> variable_names;
132   variable_names.reserve(variable_names_set.size());
133   std::vector<string> tensor_names;
134   tensor_names.reserve(variable_names_set.size());
135   for (const string& node_name : variable_names_set) {
136     variable_names.push_back(node_name);
137     NodeDef* node_def = name_to_node_map.at(node_name);
138     if (node_def->op() == "VarHandleOp") {
139       // If this is a resource variable, we have to run the corresponding
140       // ReadVariableOp.
141       tensor_names.push_back(node_name + "/Read/ReadVariableOp:0");
142     } else {
143       tensor_names.push_back(node_name + ":0");
144     }
145   }
146   std::vector<Tensor> outputs;
147   TF_RETURN_IF_ERROR(
148       session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
149   for (size_t i = 0; i < variable_names.size(); i++) {
150     (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
151   }
152   return OkStatus();
153 }
154 
155 // Converts a Variable NodeDef into a Constant NodeDef.
ConvertVariableToConstant(const NodeDef & variable_node,const Tensor & variable_value,NodeDef * const_node)156 void ConvertVariableToConstant(const NodeDef& variable_node,
157                                const Tensor& variable_value,
158                                NodeDef* const_node) {
159   const_node->set_name(variable_node.name());
160   const_node->set_op("Const");
161   (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
162   variable_value.AsProtoTensorContent(
163       (*const_node->mutable_attr())["value"].mutable_tensor());
164 }
165 
166 // Converts a ReadVariableOp NodeDef to an Identity NodeDef.
ConvertReadVariableOpToIdentity(const NodeDef & node,NodeDef * identity_node)167 void ConvertReadVariableOpToIdentity(const NodeDef& node,
168                                      NodeDef* identity_node) {
169   identity_node->set_name(node.name());
170   identity_node->set_op("Identity");
171   (*identity_node->mutable_attr())["T"] = node.attr().at("dtype");
172   identity_node->add_input(node.input(0));
173 }
174 
175 // Returns the name of the VarHandleOp that provides input (possibly indirectly)
176 // to node with node_name. A typical indirect chain of nodes (that can occur due
177 // to graph inlining) is the following: VarHandleOp -> Identity -> Identity ->
178 // ReadVariableOp. Calling the function on any of these nodes would return the
179 // name of the VarHandleOp.
GetVarHandleName(const std::unordered_map<string,NodeDef * > & name_to_node_map,string node_name)180 StatusOr<string> GetVarHandleName(
181     const std::unordered_map<string, NodeDef*>& name_to_node_map,
182     string node_name) {
183   const NodeDef* node = name_to_node_map.at(node_name);
184   while (node->input_size() > 0) {
185     auto parent = name_to_node_map.find(node->input(0));
186     if (parent == name_to_node_map.end()) break;
187     node = parent->second;
188     if (node->op() != "Identity") {
189       VLOG(2) << "Stopping at non-identity node " << node->op();
190       break;
191     }
192   }
193   if (node->op() == "VarHandleOp") {
194     return node->name();
195   }
196   return errors::NotFound("No VarHandleOp ancestor found");
197 }
198 
199 // Looks up the variable handle that provides input to node with node_name,
200 // and returns the handle name if the handle corresponds to a variable that we
201 // want to freeze (i.e. its name is contained in variable_node_names). If there
202 // is no such handle in the graph (or we do not want to save that variable)
203 // then NotFound error is returned.
GetHandleNameIfNeedsToFreeze(const std::unordered_map<string,NodeDef * > & name_to_node_map,string node_name,const std::unordered_set<string> & variable_node_names)204 StatusOr<string> GetHandleNameIfNeedsToFreeze(
205     const std::unordered_map<string, NodeDef*>& name_to_node_map,
206     string node_name, const std::unordered_set<string>& variable_node_names) {
207   StatusOr<string> var_handle_name =
208       GetVarHandleName(name_to_node_map, node_name);
209   if (var_handle_name.ok() && variable_node_names.count(*var_handle_name)) {
210     return var_handle_name;
211   }
212   return errors::NotFound("No VarHandleOp ancestor found");
213 }
214 
215 // Freezes the subgraph of all nodes needed by `outputs`.
FreezeGraphDef(const SavedModelBundle & saved_model_bundle,const std::unordered_set<string> & outputs,GraphDef * frozen_graph_def)216 Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
217                       const std::unordered_set<string>& outputs,
218                       GraphDef* frozen_graph_def) {
219   GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
220   // Copy versions and library as-is from original graph.
221   *frozen_graph_def->mutable_versions() = graph_def.versions();
222   *frozen_graph_def->mutable_library() = graph_def.library();
223   // If the graph is empty there is nothing left to do.
224   if (graph_def.node_size() == 0) {
225     return OkStatus();
226   }
227   // name_to_node_map is needed to get the inputs from the NodeDef corresponding
228   // the a string node name. These inputs are used when doing our backwards
229   // traversal.
230   std::unordered_map<string, NodeDef*> name_to_node_map;
231   GetNodeNameToNodeDefMap(&graph_def, &name_to_node_map);
232   std::unordered_set<string> reachable_node_names;
233   std::unordered_set<string> variable_node_names;
234   GetReachableNodesAndVariables(&graph_def, outputs, name_to_node_map,
235                                 &reachable_node_names, &variable_node_names);
236   std::unordered_map<string, Tensor> variable_to_value_map;
237   TF_RETURN_IF_ERROR(GetVariableNameToTensorMap(
238       saved_model_bundle.session.get(), name_to_node_map, variable_node_names,
239       &variable_to_value_map));
240   // We copy the nodes in the same order they were in the original graph_def.
241   for (const NodeDef& node : graph_def.node()) {
242     if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
243       continue;
244     }
245     if (variable_node_names.find(node.name()) != variable_node_names.end()) {
246       ConvertVariableToConstant(node, variable_to_value_map[node.name()],
247                                 frozen_graph_def->add_node());
248       continue;
249     } else if (node.op() == "ReadVariableOp" &&
250                GetHandleNameIfNeedsToFreeze(name_to_node_map, node.name(),
251                                             variable_node_names)
252                    .ok()) {
253       // If the node is a ReadVariableOp, its input VarHandleOp will be
254       // converted to a Constant, so we will need to convert it to an Identity.
255       ConvertReadVariableOpToIdentity(node, frozen_graph_def->add_node());
256       continue;
257     } else if (node.op() == "Identity") {
258       StatusOr<string> handle_name = GetHandleNameIfNeedsToFreeze(
259           name_to_node_map, node.name(), variable_node_names);
260       if (handle_name.ok()) {
261         // Identity node that is forwarding the value of a frozen
262         // VarhandleOp. We ensure that the dtype matches of the variable dtype.
263         NodeDef* new_node = frozen_graph_def->add_node();
264         *new_node = node;
265         (*new_node->mutable_attr())["T"] =
266             name_to_node_map.at(*handle_name)->attr().at("dtype");
267         continue;
268       }
269     }
270     // If the node isn't a variable, just copy the node as-is.
271     *frozen_graph_def->add_node() = node;
272   }
273   return OkStatus();
274 }
275 
276 }  // namespace
277 
FreezeSavedModel(const SavedModelBundle & saved_model_bundle,GraphDef * frozen_graph_def,std::unordered_set<string> * inputs,std::unordered_set<string> * outputs)278 Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
279                         GraphDef* frozen_graph_def,
280                         std::unordered_set<string>* inputs,
281                         std::unordered_set<string>* outputs) {
282   GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
283   TF_RETURN_IF_ERROR(
284       FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
285   return OkStatus();
286 }
287 
288 }  // namespace tensorflow
289