xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/data/graph_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/dataset_metadata.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/util/ptr_util.h"
27 
28 namespace tensorflow {
29 namespace grappler {
30 namespace graph_utils {
31 namespace {
32 
// Op names referenced by the helpers in this file.
constexpr char kConstOpName[]{"Const"};
constexpr char kRetValOp[]{"_Retval"};

// Attribute names that describe a dataset's element signature. Dataset ops do
// not name the types attribute consistently: some use `Toutput_types`.
constexpr char kOutputShapes[]{"output_shapes"};
constexpr char kOutputTypes[]{"output_types"};
constexpr char kToutputTypes[]{"Toutput_types"};
39 
// Returns the positions, in iteration order, of every element of `collection`
// for which `predicate` returns true. Returns an empty vector if none match.
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> indices;
  // Signed counter so the stored value has the same type as the elements of
  // `indices`; an unsigned counter forced an implicit unsigned->int narrowing
  // conversion on every push_back.
  int idx = 0;
  for (auto&& element : collection) {
    if (predicate(element)) {
      indices.push_back(idx);
    }
    ++idx;
  }
  return indices;
}
53 
CreateNameIndex(const GraphDef & graph)54 std::vector<int> CreateNameIndex(const GraphDef& graph) {
55   std::map<string, int> names;
56   for (int i = 0; i < graph.node_size(); ++i) {
57     names[graph.node(i).name()] = i;
58   }
59   std::vector<int> index(graph.node_size());
60   int i = 0;
61   for (const auto& pair : names) {
62     index[i++] = pair.second;
63   }
64   return index;
65 }
66 
CreateInputIndex(const NodeDef & node)67 std::vector<int> CreateInputIndex(const NodeDef& node) {
68   std::map<string, int> inputs;
69   for (int i = 0; i < node.input_size(); ++i) {
70     inputs[node.input(i)] = i;
71   }
72   std::vector<int> index(node.input_size());
73   int i = 0;
74   for (const auto& pair : inputs) {
75     index[i++] = pair.second;
76   }
77   return index;
78 }
79 
AddScalarConstNodeHelper(DataType dtype,const std::function<void (TensorProto *)> & add_value,MutableGraphView * graph)80 NodeDef* AddScalarConstNodeHelper(
81     DataType dtype, const std::function<void(TensorProto*)>& add_value,
82     MutableGraphView* graph) {
83   NodeDef node;
84   node.set_op(kConstOpName);
85   SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
86 
87   (*node.mutable_attr())["dtype"].set_type(dtype);
88   std::unique_ptr<tensorflow::TensorProto> tensor =
89       tensorflow::MakeUnique<tensorflow::TensorProto>();
90   std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
91       tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
92   tensor->set_allocated_tensor_shape(tensor_shape.release());
93   tensor->set_dtype(dtype);
94   add_value(tensor.get());
95   (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
96 
97   return graph->AddNode(std::move(node));
98 }
99 
100 }  // namespace
101 
AddScalarPlaceholder(DataType dtype,MutableGraphView * graph)102 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
103   NodeDef node;
104   node.set_op("Placeholder");
105   SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
106   (*node.mutable_attr())["dtype"].set_type(dtype);
107   TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
108   shape->set_unknown_rank(false);
109   return graph->AddNode(std::move(node));
110 }
111 
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,MutableGraphView * graph)112 NodeDef* AddNode(StringPiece name, StringPiece op,
113                  const std::vector<string>& inputs,
114                  const std::vector<std::pair<string, AttrValue>>& attributes,
115                  MutableGraphView* graph) {
116   NodeDef node;
117   if (!name.empty()) {
118     node.set_name(string(name));
119   } else {
120     SetUniqueGraphNodeName(op, graph->graph(), &node);
121   }
122   node.set_op(string(op));
123   for (const string& input : inputs) {
124     node.add_input(input);
125   }
126   for (const auto& attr : attributes) {
127     (*node.mutable_attr())[attr.first] = attr.second;
128   }
129   return graph->AddNode(std::move(node));
130 }
131 
132 template <>
AddScalarConstNode(bool v,MutableGraphView * graph)133 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
134   return AddScalarConstNodeHelper(
135       DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
136 }
137 
138 template <>
AddScalarConstNode(double v,MutableGraphView * graph)139 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
140   return AddScalarConstNodeHelper(
141       DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
142 }
143 
144 template <>
AddScalarConstNode(float v,MutableGraphView * graph)145 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
146   return AddScalarConstNodeHelper(
147       DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
148 }
149 
150 template <>
AddScalarConstNode(int v,MutableGraphView * graph)151 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
152   return AddScalarConstNodeHelper(
153       DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
154 }
155 
156 template <>
AddScalarConstNode(int64_t v,MutableGraphView * graph)157 NodeDef* AddScalarConstNode(int64_t v, MutableGraphView* graph) {
158   return AddScalarConstNodeHelper(
159       DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
160 }
161 
162 template <>
AddScalarConstNode(StringPiece v,MutableGraphView * graph)163 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
164   return AddScalarConstNodeHelper(
165       DT_STRING,
166       [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
167       graph);
168 }
169 
GetScalarConstNodeValueHelper(const NodeDef & node,DataType dtype,const std::function<void (const Tensor &)> & get_value)170 Status GetScalarConstNodeValueHelper(
171     const NodeDef& node, DataType dtype,
172     const std::function<void(const Tensor&)>& get_value) {
173   if (node.op() != kConstOpName)
174     return errors::InvalidArgument("Node ", node.name(),
175                                    " is not a Const node. Op: ", node.op());
176 
177   Tensor tensor;
178   TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor));
179   if (!TensorShapeUtils::IsScalar(tensor.shape())) {
180     return errors::InvalidArgument(
181         "Node ", node.name(),
182         " should be a scalar but has shape: ", tensor.shape());
183   }
184 
185   if (tensor.dtype() != dtype) {
186     return errors::InvalidArgument(
187         "Node ", node.name(), " should have type ", DataTypeString(dtype),
188         " but has type: ", DataTypeString(tensor.dtype()));
189   }
190 
191   get_value(tensor);
192 
193   return OkStatus();
194 }
195 
196 template <>
GetScalarConstNodeValue(const NodeDef & node,int64_t * value)197 Status GetScalarConstNodeValue(const NodeDef& node, int64_t* value) {
198   return GetScalarConstNodeValueHelper(
199       node, DT_INT64,
200       [value](const Tensor& tensor) { *value = tensor.scalar<int64_t>()(); });
201 }
202 
203 template <>
GetScalarConstNodeValue(const NodeDef & node,bool * value)204 Status GetScalarConstNodeValue(const NodeDef& node, bool* value) {
205   return GetScalarConstNodeValueHelper(
206       node, DT_BOOL,
207       [value](const Tensor& tensor) { *value = tensor.scalar<bool>()(); });
208 }
209 
Compare(const GraphDef & g1,const GraphDef & g2)210 bool Compare(const GraphDef& g1, const GraphDef& g2) {
211   if (g1.node_size() != g2.node_size()) {
212     return false;
213   }
214   std::vector<int> name_index1 = CreateNameIndex(g1);
215   std::vector<int> name_index2 = CreateNameIndex(g2);
216   for (int i = 0; i < g1.node_size(); ++i) {
217     int idx1 = name_index1[i];
218     int idx2 = name_index2[i];
219     if (g1.node(idx1).op() != g2.node(idx2).op()) {
220       return false;
221     }
222     if (g1.node(idx1).name() != g2.node(idx2).name()) {
223       return false;
224     }
225     if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
226       return false;
227     }
228     std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
229     std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
230     for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
231       if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
232                        g2.node(idx2).input(input_index2[j]))) {
233         return false;
234       }
235     }
236   }
237   return true;
238 }
239 
ContainsGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)240 bool ContainsGraphFunctionWithName(StringPiece name,
241                                    const FunctionDefLibrary& library) {
242   return FindGraphFunctionWithName(name, library) != -1;
243 }
244 
ContainsGraphNodeWithName(StringPiece name,const GraphDef & graph)245 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
246   return FindGraphNodeWithName(name, graph) != -1;
247 }
248 
ContainsNodeWithOp(StringPiece op,const GraphDef & graph)249 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
250   return FindGraphNodeWithOp(op, graph) != -1;
251 }
252 
FindGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)253 int FindGraphFunctionWithName(StringPiece name,
254                               const FunctionDefLibrary& library) {
255   return GetFirstElementIndexWithPredicate(
256       [&name](const FunctionDef& function) {
257         return function.signature().name() == name;
258       },
259       library.function());
260 }
261 
FindGraphNodeWithName(StringPiece name,const GraphDef & graph)262 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
263   return GetFirstElementIndexWithPredicate(
264       [&name](const NodeDef& node) { return node.name() == name; },
265       graph.node());
266 }
267 
FindGraphNodeWithOp(StringPiece op,const GraphDef & graph)268 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
269   return GetFirstElementIndexWithPredicate(
270       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
271 }
272 
FindAllGraphNodesWithOp(const string & op,const GraphDef & graph)273 std::vector<int> FindAllGraphNodesWithOp(const string& op,
274                                          const GraphDef& graph) {
275   return GetElementIndicesWithPredicate(
276       [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
277 }
278 
// Returns the node that feeds regular input `i` of `node` in `graph`, or
// nullptr if `node` has no more than `i` inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64_t i) {
  if (i >= node.input_size()) return nullptr;
  const MutableGraphView::InputPort port = graph.GetInputPort(node.name(), i);
  return graph.GetRegularFanin(port).node;
}

// Returns the node that feeds the first input of `node` in `graph`, or
// nullptr if `node` has no inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  return GetInputNode(node, graph, /*i=*/0);
}
291 
GetDatasetOutputTypesAttr(const NodeDef & node,DataTypeVector * output_types)292 Status GetDatasetOutputTypesAttr(const NodeDef& node,
293                                  DataTypeVector* output_types) {
294   // We don't name the output_types attr consistently, so should check for both.
295   for (const string& attr_name : {"output_types", "Toutput_types"}) {
296     if (node.attr().contains(attr_name)) {
297       return GetNodeAttr(node, attr_name, output_types);
298     }
299   }
300   return errors::InvalidArgument("Could not find output_types attr for node: ",
301                                  node.name(), " with op: ", node.op());
302 }
303 
SetUniqueGraphNodeName(StringPiece prefix,GraphDef * graph,NodeDef * node)304 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
305                             NodeDef* node) {
306   string name = string(prefix);
307   int id = graph->node_size();
308   while (ContainsGraphNodeWithName(name, *graph)) {
309     if (name.rfind("_generated") != string::npos &&
310         (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
311       name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
312     } else {
313       name = strings::StrCat(prefix, "/_", id);
314     }
315     ++id;
316   }
317   node->set_name(std::move(name));
318 }
319 
SetUniqueGraphFunctionName(StringPiece prefix,const FunctionDefLibrary * library,FunctionDef * function)320 void SetUniqueGraphFunctionName(StringPiece prefix,
321                                 const FunctionDefLibrary* library,
322                                 FunctionDef* function) {
323   string name = string(prefix);
324   int id = library->function_size();
325   while (ContainsGraphFunctionWithName(name, *library)) {
326     name = strings::StrCat(prefix, "/_", id);
327     ++id;
328   }
329   function->mutable_signature()->set_name(std::move(name));
330 }
331 
CopyAttribute(const string & attribute_name,const NodeDef & from,NodeDef * to_node)332 void CopyAttribute(const string& attribute_name, const NodeDef& from,
333                    NodeDef* to_node) {
334   (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
335 }
336 
ConcatAttributeList(const string & attribute_name,const NodeDef & first,const NodeDef & second,NodeDef * to_node)337 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
338                          const NodeDef& second, NodeDef* to_node) {
339   CopyAttribute(attribute_name, first, to_node);
340   (*to_node->mutable_attr())
341       .at(attribute_name)
342       .mutable_list()
343       ->MergeFrom(second.attr().at(attribute_name).list());
344 }
345 
EnsureNodeNamesUnique(Graph * g)346 Status EnsureNodeNamesUnique(Graph* g) {
347   // Modeled after Scope::Impl::GetUniqueName
348   std::unordered_map<string, int> name_map;
349 
350   for (auto node : g->op_nodes()) {
351     const string& prefix = node->name();
352     if (auto entry = gtl::FindOrNull(name_map, prefix)) {
353       string unique_name;
354       do {
355         unique_name = strings::StrCat(prefix, "_", ++(*entry));
356       } while (name_map.find(unique_name) != name_map.end());
357       name_map.insert({unique_name, 0});
358       node->set_name(std::move(unique_name));
359     } else {
360       name_map.insert({node->name(), 0});
361     }
362   }
363 
364   return OkStatus();
365 }
366 
GetFetchNode(const MutableGraphView & graph,const GrapplerItem & item,NodeDef ** fetch_node)367 Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item,
368                     NodeDef** fetch_node) {
369   if (item.fetch.size() != 1) {
370     return errors::InvalidArgument(
371         "Expected only one fetch node but there were ", item.fetch.size(), ": ",
372         absl::StrJoin(item.fetch, ", "));
373   }
374 
375   *fetch_node = graph.GetNode(item.fetch.at(0));
376 
377   return OkStatus();
378 }
379 
IsItemDerivedFromFunctionDef(const GrapplerItem & item,const MutableGraphView & graph_view)380 bool IsItemDerivedFromFunctionDef(const GrapplerItem& item,
381                                   const MutableGraphView& graph_view) {
382   for (const auto& fetch_name : item.fetch) {
383     auto fetch = graph_view.GetNode(fetch_name);
384     if (fetch != nullptr && fetch->op() != kRetValOp) {
385       // We found a fetch node which is not a `Retval` op.
386       return false;
387     }
388   }
389   // All fetch nodes are `Retval` ops (or we don't have any fetch nodes).
390   return true;
391 }
392 
MaybeSetFusedMetadata(const NodeDef & node1,const NodeDef & node2,NodeDef * fused_node)393 void MaybeSetFusedMetadata(const NodeDef& node1, const NodeDef& node2,
394                            NodeDef* fused_node) {
395   data::Metadata metadata1;
396   if (node1.attr().contains("metadata")) {
397     metadata1.ParseFromString(node1.attr().at("metadata").s());
398   }
399   data::Metadata metadata2;
400   if (node2.attr().contains("metadata")) {
401     metadata2.ParseFromString(node2.attr().at("metadata").s());
402   }
403   data::Metadata fused_metadata;
404   auto normalize_name = [](const string& name) {
405     return name.empty() ? "?" : name;
406   };
407   *fused_metadata.mutable_name() =
408       strings::StrCat("fused(", normalize_name(metadata1.name()), ",",
409                       normalize_name(metadata2.name()), ")");
410   fused_metadata.SerializeToString(
411       (*fused_node->mutable_attr())["metadata"].mutable_s());
412 }
413 
CopyShapesAndTypesAttrs(const NodeDef & from,NodeDef * to_node)414 bool CopyShapesAndTypesAttrs(const NodeDef& from, NodeDef* to_node) {
415   auto* attr = gtl::FindOrNull(from.attr(), kOutputTypes);
416   attr = (attr == nullptr ? gtl::FindOrNull(from.attr(), kToutputTypes) : attr);
417 
418   if (attr == nullptr) return false;
419   (*to_node->mutable_attr())[kOutputTypes] = *attr;
420 
421   attr = gtl::FindOrNull(from.attr(), kOutputShapes);
422   if (attr == nullptr) return false;
423   (*to_node->mutable_attr())[kOutputShapes] = *attr;
424   return true;
425 }
426 
427 namespace {
428 const auto* kSloppyAttrOps = new absl::flat_hash_set<string>{
429     "ParallelInterleaveDatasetV2",
430     "ParallelMapDataset",
431     "ParseExampleDataset",
432 };
433 
434 const auto* kReplicateOnSplitAttrOps = new absl::flat_hash_set<string>{
435     "TensorSliceDataset",
436     "RangeDataset",
437 };
438 
439 const auto* kDeterministicAttrOps = new absl::flat_hash_set<string>{
440     "LegacyParallelInterleaveDatasetV2",
441     "ParallelInterleaveDatasetV3",
442     "ParallelInterleaveDatasetV4",
443     "ParallelMapDatasetV2",
444     "ParallelBatchDataset",
445 };
446 }  // anonymous namespace
447 
HasSloppyAttr(const string & op)448 bool HasSloppyAttr(const string& op) { return kSloppyAttrOps->contains(op); }
449 
HasReplicateOnSplitAttr(const string & op)450 bool HasReplicateOnSplitAttr(const string& op) {
451   return kReplicateOnSplitAttrOps->contains(op);
452 }
453 
HasDeterministicAttr(const string & op)454 bool HasDeterministicAttr(const string& op) {
455   return kDeterministicAttrOps->contains(op);
456 }
457 
SetMetadataName(const std::string & name,NodeDef * node)458 Status SetMetadataName(const std::string& name, NodeDef* node) {
459   data::Metadata metadata;
460   if (node->attr().contains("metadata")) {
461     metadata.ParseFromString(node->attr().at("metadata").s());
462   }
463   if (!metadata.name().empty()) {
464     return errors::InvalidArgument("Node ", node->name(),
465                                    " already has a metadata name \"",
466                                    metadata.name(), "\".");
467   }
468   *metadata.mutable_name() = name;
469   metadata.SerializeToString((*node->mutable_attr())["metadata"].mutable_s());
470   return OkStatus();
471 }
472 
473 }  // namespace graph_utils
474 }  // namespace grappler
475 }  // namespace tensorflow
476