xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/graph_def_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/graph_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <vector>
22 
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/function.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_def_util.h"
30 #include "tensorflow/core/framework/versions.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/strings/str_util.h"
34 #include "tensorflow/core/lib/strings/strcat.h"
35 
36 namespace tensorflow {
37 
SummarizeGraphDef(const GraphDef & graph_def)38 string SummarizeGraphDef(const GraphDef& graph_def) {
39   string ret;
40   strings::StrAppend(
41       &ret, "versions = ", graph_def.versions().ShortDebugString(), ";\n");
42   for (const NodeDef& node : graph_def.node()) {
43     strings::StrAppend(&ret, SummarizeNodeDef(node), ";\n");
44   }
45   return ret;
46 }
47 
ValidateExternalGraphDefSyntax(const GraphDef & graph_def)48 Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) {
49   for (const NodeDef& node : graph_def.node()) {
50     TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node));
51   }
52   return OkStatus();
53 }
54 
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset)55 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
56                                  const OpRegistryInterface& op_registry,
57                                  int node_offset) {
58   return AddDefaultAttrsToGraphDef(graph_def, op_registry, node_offset, false);
59 }
60 
AddDefaultAttrsToGraphDef(GraphDef * graph_def,const OpRegistryInterface & op_registry,int node_offset,bool skip_unknown_ops)61 Status AddDefaultAttrsToGraphDef(GraphDef* graph_def,
62                                  const OpRegistryInterface& op_registry,
63                                  int node_offset, bool skip_unknown_ops) {
64   if (node_offset > graph_def->node_size()) {
65     return errors::InvalidArgument(
66         "Tried to add default attrs to GraphDef "
67         "starting at offset ",
68         node_offset, " with total nodes in graph: ", graph_def->node_size());
69   }
70 
71   for (int i = node_offset; i < graph_def->node_size(); ++i) {
72     NodeDef* node_def = graph_def->mutable_node(i);
73     const OpDef* op_def;
74     Status s = op_registry.LookUpOpDef(node_def->op(), &op_def);
75     if (s.ok()) {
76       AddDefaultsToNodeDef(*op_def, node_def);
77     } else if (!skip_unknown_ops) {
78       return s;
79     }
80   }
81 
82   return OkStatus();
83 }
84 
RemoveNewDefaultAttrsFromNodeDef(NodeDef * node_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)85 static Status RemoveNewDefaultAttrsFromNodeDef(
86     NodeDef* node_def, const OpRegistryInterface& consumer_op_registry,
87     const OpRegistryInterface& producer_op_registry,
88     std::set<std::pair<string, string>>* op_attr_removed) {
89   const OpDef* producer_op_def;
90   const OpDef* consumer_op_def;
91   TF_RETURN_IF_ERROR(
92       producer_op_registry.LookUpOpDef(node_def->op(), &producer_op_def));
93   TF_RETURN_IF_ERROR(
94       consumer_op_registry.LookUpOpDef(node_def->op(), &consumer_op_def));
95 
96   std::vector<string> to_remove;
97   for (const auto& attr : node_def->attr()) {
98     // If the attr is not in consumer_op_def and doesn't start with '_'...
99     if (!absl::StartsWith(attr.first, "_") &&
100         FindAttr(attr.first, *consumer_op_def) == nullptr) {
101       const OpDef::AttrDef* producer_attr_def =
102           FindAttr(attr.first, *producer_op_def);
103       if (producer_attr_def == nullptr) {
104         return errors::InvalidArgument(
105             "Attr '", attr.first,
106             "' missing in producer's OpDef: ", SummarizeOpDef(*producer_op_def),
107             " but found in node: ", FormatNodeDefForError(*node_def));
108       }
109       // ...and it has the same value as the default in producer,
110       if (producer_attr_def->has_default_value() &&
111           AreAttrValuesEqual(producer_attr_def->default_value(), attr.second)) {
112         // then we will remove it below.
113         to_remove.emplace_back(attr.first);
114       }
115     }
116   }
117   // We separate identifying which attrs should be removed from
118   // actually removing them to avoid invalidating the loop iterators
119   // above.
120   for (const string& attr_name : to_remove) {
121     node_def->mutable_attr()->erase(attr_name);
122     if (op_attr_removed != nullptr) {
123       op_attr_removed->insert(std::make_pair(node_def->op(), attr_name));
124     }
125   }
126 
127   return OkStatus();
128 }
129 
IsFunction(const GraphDef & graph_def,const string & op_name)130 static bool IsFunction(const GraphDef& graph_def, const string& op_name) {
131   for (const auto& func_def : graph_def.library().function()) {
132     if (op_name == func_def.signature().name()) return true;
133   }
134   return false;
135 }
136 
RemoveNewDefaultAttrsFromGraphDef(GraphDef * graph_def,const OpRegistryInterface & consumer_op_registry,const OpRegistryInterface & producer_op_registry,std::set<std::pair<string,string>> * op_attr_removed)137 Status RemoveNewDefaultAttrsFromGraphDef(
138     GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry,
139     const OpRegistryInterface& producer_op_registry,
140     std::set<std::pair<string, string>>* op_attr_removed) {
141   // TODO(joshL): Make IsFunction() faster by collecting the names of
142   // all functions as a preprocessing step.
143   for (int n = 0; n < graph_def->node_size(); ++n) {
144     NodeDef* node_def = graph_def->mutable_node(n);
145     if (!IsFunction(*graph_def, node_def->op())) {
146       TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
147           node_def, consumer_op_registry, producer_op_registry,
148           op_attr_removed));
149     }
150   }
151   for (int f = 0; f < graph_def->library().function_size(); ++f) {
152     FunctionDef* func_def = graph_def->mutable_library()->mutable_function(f);
153     for (int n = 0; n < func_def->node_def_size(); ++n) {
154       NodeDef* node_def = func_def->mutable_node_def(n);
155       if (!IsFunction(*graph_def, node_def->op())) {
156         // TODO(josh11b): Better handling of attrs with placeholder values.
157         TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromNodeDef(
158             node_def, consumer_op_registry, producer_op_registry,
159             op_attr_removed));
160       }
161     }
162   }
163 
164   return OkStatus();
165 }
166 
StripDefaultAttributes(const OpRegistryInterface & op_registry,protobuf::RepeatedPtrField<NodeDef> * nodes)167 void StripDefaultAttributes(const OpRegistryInterface& op_registry,
168                             protobuf::RepeatedPtrField<NodeDef>* nodes) {
169   for (int i = 0; i < nodes->size(); ++i) {
170     NodeDef* node = nodes->Mutable(i);
171 
172     const OpDef* op_def;
173     const OpRegistrationData* op_reg_data = nullptr;
174     Status s = op_registry.LookUp(node->op(), &op_reg_data);
175     if (!s.ok()) {
176       VLOG(1) << "Ignoring encountered unknown operation "
177               << SummarizeNodeDef(*node)
178               << " when stripping default attributes. It is likely a function, "
179                  "in which case ignoring it is fine";
180       continue;
181     }
182     op_def = &op_reg_data->op_def;
183 
184     for (const OpDef::AttrDef& attr_def : op_def->attr()) {
185       if (attr_def.has_default_value()) {
186         AttrValueMap* attrs = node->mutable_attr();
187         const string& name = attr_def.name();
188         auto iter = attrs->find(name);
189         if (iter != attrs->end()) {
190           const AttrValue& default_value = attr_def.default_value();
191           // There should never be an attribute whose default value is a tensor
192           // larger than 32MB so allow false negatives  for efficient
193           // comparison.
194           if (AreAttrValuesEqual(iter->second, default_value,
195                                  /*allow_false_negatives=*/true)) {
196             attrs->erase(name);
197           }
198         }
199       }
200     }
201   }
202 }
203 
OpsUsedByGraph(const GraphDef & graph_def,std::set<string> * ops_used_in_graph)204 void OpsUsedByGraph(const GraphDef& graph_def,
205                     std::set<string>* ops_used_in_graph) {
206   // Map function names to definitions.
207   std::unordered_map<string, const FunctionDef*> name_to_function;
208   for (const auto& function : graph_def.library().function()) {
209     name_to_function.insert(
210         std::make_pair(function.signature().name(), &function));
211   }
212 
213   // Collect the sorted list of op names.  Since functions can reference
214   // functions, we need a recursive traversal.
215   std::set<string> used_ops;  // Includes both primitive ops and functions
216   std::vector<const FunctionDef*> functions_to_process;  // A subset of used_ops
217   // Collect the logic to mark an op in a lambda; it'll be used twice below.
218   const auto mark_op_as_used = [&used_ops, &functions_to_process,
219                                 &name_to_function](const string& op) {
220     if (used_ops.insert(op).second) {
221       // If it's a function, we'll need to process further
222       const auto it = name_to_function.find(op);
223       if (it != name_to_function.end()) {
224         functions_to_process.push_back(it->second);
225       }
226     }
227   };
228   for (const auto& node : graph_def.node()) {
229     mark_op_as_used(node.op());
230   }
231   while (!functions_to_process.empty()) {
232     const FunctionDef* fun = functions_to_process.back();
233     functions_to_process.pop_back();
234     for (const auto& node : fun->node_def()) {
235       mark_op_as_used(node.op());
236     }
237   }
238 
239   // Filter out function names to produce output.
240   // TODO(josh11b): Change the above code to produce this directly.
241   ops_used_in_graph->clear();
242   for (const string& op_name : used_ops) {
243     if (name_to_function.find(op_name) == name_to_function.end()) {
244       ops_used_in_graph->insert(op_name);
245     }
246   }
247 }
248 
StrippedOpListForGraph(const GraphDef & graph_def,const OpRegistryInterface & op_registry,OpList * stripped_op_list)249 Status StrippedOpListForGraph(const GraphDef& graph_def,
250                               const OpRegistryInterface& op_registry,
251                               OpList* stripped_op_list) {
252   std::set<string> used_ops;
253   OpsUsedByGraph(graph_def, &used_ops);
254 
255   // Build the stripped op list in sorted order, ignoring functions.
256   stripped_op_list->clear_op();
257   for (const string& op_name : used_ops) {
258     const OpDef* op_def;
259     TF_RETURN_IF_ERROR(op_registry.LookUpOpDef(op_name, &op_def));
260     OpDef* stripped_op = stripped_op_list->add_op();
261     stripped_op->CopyFrom(*op_def);
262     RemoveDescriptionsFromOpDef(stripped_op);
263   }
264   return OkStatus();
265 }
266 
267 }  // namespace tensorflow
268