1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
17
18 #include <cstddef>
19
20 #include "tensorflow/core/framework/dataset_metadata.pb.h"
21 #include "tensorflow/core/framework/device_base.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/map_util.h"
25 #include "tensorflow/core/platform/strcat.h"
26 #include "tensorflow/core/util/ptr_util.h"
27
28 namespace tensorflow {
29 namespace grappler {
30 namespace graph_utils {
31 namespace {
32
// Op name given to the scalar constant nodes created by this file and
// required of nodes whose scalar value is read back.
constexpr char kConstOpName[] = "Const";
// Op name that every resolvable fetch node must have for an item to be
// treated as derived from a function definition.
constexpr char kRetValOp[] = "_Retval";

// Attribute keys looked up and written when copying shape/type attributes
// between nodes. The types attribute may appear under either of the two
// spellings below.
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";
constexpr char kToutputTypes[] = "Toutput_types";
39
// Returns the zero-based positions, in iteration order, of every element of
// `collection` for which `predicate` evaluates to true. The result is empty
// when nothing matches or when `collection` is empty.
template <typename Predicate, typename Collection>
std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
                                                const Collection& collection) {
  std::vector<int> matches;
  unsigned position = 0;
  for (auto&& item : collection) {
    // Advance the running position up front so that skipped elements are
    // still counted.
    const unsigned current = position++;
    if (!predicate(item)) continue;
    matches.push_back(current);
  }
  return matches;
}
53
CreateNameIndex(const GraphDef & graph)54 std::vector<int> CreateNameIndex(const GraphDef& graph) {
55 std::map<string, int> names;
56 for (int i = 0; i < graph.node_size(); ++i) {
57 names[graph.node(i).name()] = i;
58 }
59 std::vector<int> index(graph.node_size());
60 int i = 0;
61 for (const auto& pair : names) {
62 index[i++] = pair.second;
63 }
64 return index;
65 }
66
CreateInputIndex(const NodeDef & node)67 std::vector<int> CreateInputIndex(const NodeDef& node) {
68 std::map<string, int> inputs;
69 for (int i = 0; i < node.input_size(); ++i) {
70 inputs[node.input(i)] = i;
71 }
72 std::vector<int> index(node.input_size());
73 int i = 0;
74 for (const auto& pair : inputs) {
75 index[i++] = pair.second;
76 }
77 return index;
78 }
79
AddScalarConstNodeHelper(DataType dtype,const std::function<void (TensorProto *)> & add_value,MutableGraphView * graph)80 NodeDef* AddScalarConstNodeHelper(
81 DataType dtype, const std::function<void(TensorProto*)>& add_value,
82 MutableGraphView* graph) {
83 NodeDef node;
84 node.set_op(kConstOpName);
85 SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
86
87 (*node.mutable_attr())["dtype"].set_type(dtype);
88 std::unique_ptr<tensorflow::TensorProto> tensor =
89 tensorflow::MakeUnique<tensorflow::TensorProto>();
90 std::unique_ptr<tensorflow::TensorShapeProto> tensor_shape =
91 tensorflow::MakeUnique<tensorflow::TensorShapeProto>();
92 tensor->set_allocated_tensor_shape(tensor_shape.release());
93 tensor->set_dtype(dtype);
94 add_value(tensor.get());
95 (*node.mutable_attr())["value"].set_allocated_tensor(tensor.release());
96
97 return graph->AddNode(std::move(node));
98 }
99
100 } // namespace
101
AddScalarPlaceholder(DataType dtype,MutableGraphView * graph)102 NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
103 NodeDef node;
104 node.set_op("Placeholder");
105 SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
106 (*node.mutable_attr())["dtype"].set_type(dtype);
107 TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
108 shape->set_unknown_rank(false);
109 return graph->AddNode(std::move(node));
110 }
111
AddNode(StringPiece name,StringPiece op,const std::vector<string> & inputs,const std::vector<std::pair<string,AttrValue>> & attributes,MutableGraphView * graph)112 NodeDef* AddNode(StringPiece name, StringPiece op,
113 const std::vector<string>& inputs,
114 const std::vector<std::pair<string, AttrValue>>& attributes,
115 MutableGraphView* graph) {
116 NodeDef node;
117 if (!name.empty()) {
118 node.set_name(string(name));
119 } else {
120 SetUniqueGraphNodeName(op, graph->graph(), &node);
121 }
122 node.set_op(string(op));
123 for (const string& input : inputs) {
124 node.add_input(input);
125 }
126 for (const auto& attr : attributes) {
127 (*node.mutable_attr())[attr.first] = attr.second;
128 }
129 return graph->AddNode(std::move(node));
130 }
131
132 template <>
AddScalarConstNode(bool v,MutableGraphView * graph)133 NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) {
134 return AddScalarConstNodeHelper(
135 DT_BOOL, [v](TensorProto* proto) { proto->add_bool_val(v); }, graph);
136 }
137
138 template <>
AddScalarConstNode(double v,MutableGraphView * graph)139 NodeDef* AddScalarConstNode(double v, MutableGraphView* graph) {
140 return AddScalarConstNodeHelper(
141 DT_DOUBLE, [v](TensorProto* proto) { proto->add_double_val(v); }, graph);
142 }
143
144 template <>
AddScalarConstNode(float v,MutableGraphView * graph)145 NodeDef* AddScalarConstNode(float v, MutableGraphView* graph) {
146 return AddScalarConstNodeHelper(
147 DT_FLOAT, [v](TensorProto* proto) { proto->add_float_val(v); }, graph);
148 }
149
150 template <>
AddScalarConstNode(int v,MutableGraphView * graph)151 NodeDef* AddScalarConstNode(int v, MutableGraphView* graph) {
152 return AddScalarConstNodeHelper(
153 DT_INT32, [v](TensorProto* proto) { proto->add_int_val(v); }, graph);
154 }
155
156 template <>
AddScalarConstNode(int64_t v,MutableGraphView * graph)157 NodeDef* AddScalarConstNode(int64_t v, MutableGraphView* graph) {
158 return AddScalarConstNodeHelper(
159 DT_INT64, [v](TensorProto* proto) { proto->add_int64_val(v); }, graph);
160 }
161
162 template <>
AddScalarConstNode(StringPiece v,MutableGraphView * graph)163 NodeDef* AddScalarConstNode(StringPiece v, MutableGraphView* graph) {
164 return AddScalarConstNodeHelper(
165 DT_STRING,
166 [v](TensorProto* proto) { proto->add_string_val(v.data(), v.size()); },
167 graph);
168 }
169
GetScalarConstNodeValueHelper(const NodeDef & node,DataType dtype,const std::function<void (const Tensor &)> & get_value)170 Status GetScalarConstNodeValueHelper(
171 const NodeDef& node, DataType dtype,
172 const std::function<void(const Tensor&)>& get_value) {
173 if (node.op() != kConstOpName)
174 return errors::InvalidArgument("Node ", node.name(),
175 " is not a Const node. Op: ", node.op());
176
177 Tensor tensor;
178 TF_RETURN_IF_ERROR(GetNodeAttr(node, "value", &tensor));
179 if (!TensorShapeUtils::IsScalar(tensor.shape())) {
180 return errors::InvalidArgument(
181 "Node ", node.name(),
182 " should be a scalar but has shape: ", tensor.shape());
183 }
184
185 if (tensor.dtype() != dtype) {
186 return errors::InvalidArgument(
187 "Node ", node.name(), " should have type ", DataTypeString(dtype),
188 " but has type: ", DataTypeString(tensor.dtype()));
189 }
190
191 get_value(tensor);
192
193 return OkStatus();
194 }
195
196 template <>
GetScalarConstNodeValue(const NodeDef & node,int64_t * value)197 Status GetScalarConstNodeValue(const NodeDef& node, int64_t* value) {
198 return GetScalarConstNodeValueHelper(
199 node, DT_INT64,
200 [value](const Tensor& tensor) { *value = tensor.scalar<int64_t>()(); });
201 }
202
203 template <>
GetScalarConstNodeValue(const NodeDef & node,bool * value)204 Status GetScalarConstNodeValue(const NodeDef& node, bool* value) {
205 return GetScalarConstNodeValueHelper(
206 node, DT_BOOL,
207 [value](const Tensor& tensor) { *value = tensor.scalar<bool>()(); });
208 }
209
Compare(const GraphDef & g1,const GraphDef & g2)210 bool Compare(const GraphDef& g1, const GraphDef& g2) {
211 if (g1.node_size() != g2.node_size()) {
212 return false;
213 }
214 std::vector<int> name_index1 = CreateNameIndex(g1);
215 std::vector<int> name_index2 = CreateNameIndex(g2);
216 for (int i = 0; i < g1.node_size(); ++i) {
217 int idx1 = name_index1[i];
218 int idx2 = name_index2[i];
219 if (g1.node(idx1).op() != g2.node(idx2).op()) {
220 return false;
221 }
222 if (g1.node(idx1).name() != g2.node(idx2).name()) {
223 return false;
224 }
225 if (g1.node(idx1).input_size() != g2.node(idx2).input_size()) {
226 return false;
227 }
228 std::vector<int> input_index1 = CreateInputIndex(g1.node(idx1));
229 std::vector<int> input_index2 = CreateInputIndex(g2.node(idx2));
230 for (int j = 0; j < g1.node(idx1).input_size(); ++j) {
231 if (!IsSameInput(g1.node(idx1).input(input_index1[j]),
232 g2.node(idx2).input(input_index2[j]))) {
233 return false;
234 }
235 }
236 }
237 return true;
238 }
239
ContainsGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)240 bool ContainsGraphFunctionWithName(StringPiece name,
241 const FunctionDefLibrary& library) {
242 return FindGraphFunctionWithName(name, library) != -1;
243 }
244
ContainsGraphNodeWithName(StringPiece name,const GraphDef & graph)245 bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
246 return FindGraphNodeWithName(name, graph) != -1;
247 }
248
ContainsNodeWithOp(StringPiece op,const GraphDef & graph)249 bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
250 return FindGraphNodeWithOp(op, graph) != -1;
251 }
252
FindGraphFunctionWithName(StringPiece name,const FunctionDefLibrary & library)253 int FindGraphFunctionWithName(StringPiece name,
254 const FunctionDefLibrary& library) {
255 return GetFirstElementIndexWithPredicate(
256 [&name](const FunctionDef& function) {
257 return function.signature().name() == name;
258 },
259 library.function());
260 }
261
FindGraphNodeWithName(StringPiece name,const GraphDef & graph)262 int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
263 return GetFirstElementIndexWithPredicate(
264 [&name](const NodeDef& node) { return node.name() == name; },
265 graph.node());
266 }
267
FindGraphNodeWithOp(StringPiece op,const GraphDef & graph)268 int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
269 return GetFirstElementIndexWithPredicate(
270 [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
271 }
272
FindAllGraphNodesWithOp(const string & op,const GraphDef & graph)273 std::vector<int> FindAllGraphNodesWithOp(const string& op,
274 const GraphDef& graph) {
275 return GetElementIndicesWithPredicate(
276 [&op](const NodeDef& node) { return node.op() == op; }, graph.node());
277 }
278
// Returns the node that feeds the first input (port 0) of `node`, looked up
// through `graph`, or nullptr if `node` has no inputs.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
  const bool has_inputs = node.input_size() != 0;
  if (!has_inputs) {
    return nullptr;
  }
  const MutableGraphView::InputPort first_port =
      graph.GetInputPort(node.name(), 0);
  return graph.GetRegularFanin(first_port).node;
}
284
// Returns the node that feeds input `i` of `node`, looked up through `graph`.
//
// Returns nullptr when `i` is not a valid input position, i.e. when `i` is
// negative or `i >= node.input_size()`.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
                      int64_t i) {
  // Reject out-of-range positions on both sides. A negative `i` would
  // otherwise be handed to the graph view as a port id.
  if (i < 0 || i >= node.input_size()) return nullptr;

  // `i` is now known to lie in [0, input_size()), so narrowing it to the
  // `int` range of input_size() is lossless.
  const MutableGraphView::InputPort input_port =
      graph.GetInputPort(node.name(), static_cast<int>(i));
  return graph.GetRegularFanin(input_port).node;
}
291
GetDatasetOutputTypesAttr(const NodeDef & node,DataTypeVector * output_types)292 Status GetDatasetOutputTypesAttr(const NodeDef& node,
293 DataTypeVector* output_types) {
294 // We don't name the output_types attr consistently, so should check for both.
295 for (const string& attr_name : {"output_types", "Toutput_types"}) {
296 if (node.attr().contains(attr_name)) {
297 return GetNodeAttr(node, attr_name, output_types);
298 }
299 }
300 return errors::InvalidArgument("Could not find output_types attr for node: ",
301 node.name(), " with op: ", node.op());
302 }
303
SetUniqueGraphNodeName(StringPiece prefix,GraphDef * graph,NodeDef * node)304 void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
305 NodeDef* node) {
306 string name = string(prefix);
307 int id = graph->node_size();
308 while (ContainsGraphNodeWithName(name, *graph)) {
309 if (name.rfind("_generated") != string::npos &&
310 (name.rfind("_generated") == (name.size() - strlen("_generated")))) {
311 name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
312 } else {
313 name = strings::StrCat(prefix, "/_", id);
314 }
315 ++id;
316 }
317 node->set_name(std::move(name));
318 }
319
SetUniqueGraphFunctionName(StringPiece prefix,const FunctionDefLibrary * library,FunctionDef * function)320 void SetUniqueGraphFunctionName(StringPiece prefix,
321 const FunctionDefLibrary* library,
322 FunctionDef* function) {
323 string name = string(prefix);
324 int id = library->function_size();
325 while (ContainsGraphFunctionWithName(name, *library)) {
326 name = strings::StrCat(prefix, "/_", id);
327 ++id;
328 }
329 function->mutable_signature()->set_name(std::move(name));
330 }
331
CopyAttribute(const string & attribute_name,const NodeDef & from,NodeDef * to_node)332 void CopyAttribute(const string& attribute_name, const NodeDef& from,
333 NodeDef* to_node) {
334 (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
335 }
336
ConcatAttributeList(const string & attribute_name,const NodeDef & first,const NodeDef & second,NodeDef * to_node)337 void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
338 const NodeDef& second, NodeDef* to_node) {
339 CopyAttribute(attribute_name, first, to_node);
340 (*to_node->mutable_attr())
341 .at(attribute_name)
342 .mutable_list()
343 ->MergeFrom(second.attr().at(attribute_name).list());
344 }
345
EnsureNodeNamesUnique(Graph * g)346 Status EnsureNodeNamesUnique(Graph* g) {
347 // Modeled after Scope::Impl::GetUniqueName
348 std::unordered_map<string, int> name_map;
349
350 for (auto node : g->op_nodes()) {
351 const string& prefix = node->name();
352 if (auto entry = gtl::FindOrNull(name_map, prefix)) {
353 string unique_name;
354 do {
355 unique_name = strings::StrCat(prefix, "_", ++(*entry));
356 } while (name_map.find(unique_name) != name_map.end());
357 name_map.insert({unique_name, 0});
358 node->set_name(std::move(unique_name));
359 } else {
360 name_map.insert({node->name(), 0});
361 }
362 }
363
364 return OkStatus();
365 }
366
GetFetchNode(const MutableGraphView & graph,const GrapplerItem & item,NodeDef ** fetch_node)367 Status GetFetchNode(const MutableGraphView& graph, const GrapplerItem& item,
368 NodeDef** fetch_node) {
369 if (item.fetch.size() != 1) {
370 return errors::InvalidArgument(
371 "Expected only one fetch node but there were ", item.fetch.size(), ": ",
372 absl::StrJoin(item.fetch, ", "));
373 }
374
375 *fetch_node = graph.GetNode(item.fetch.at(0));
376
377 return OkStatus();
378 }
379
IsItemDerivedFromFunctionDef(const GrapplerItem & item,const MutableGraphView & graph_view)380 bool IsItemDerivedFromFunctionDef(const GrapplerItem& item,
381 const MutableGraphView& graph_view) {
382 for (const auto& fetch_name : item.fetch) {
383 auto fetch = graph_view.GetNode(fetch_name);
384 if (fetch != nullptr && fetch->op() != kRetValOp) {
385 // We found a fetch node which is not a `Retval` op.
386 return false;
387 }
388 }
389 // All fetch nodes are `Retval` ops (or we don't have any fetch nodes).
390 return true;
391 }
392
MaybeSetFusedMetadata(const NodeDef & node1,const NodeDef & node2,NodeDef * fused_node)393 void MaybeSetFusedMetadata(const NodeDef& node1, const NodeDef& node2,
394 NodeDef* fused_node) {
395 data::Metadata metadata1;
396 if (node1.attr().contains("metadata")) {
397 metadata1.ParseFromString(node1.attr().at("metadata").s());
398 }
399 data::Metadata metadata2;
400 if (node2.attr().contains("metadata")) {
401 metadata2.ParseFromString(node2.attr().at("metadata").s());
402 }
403 data::Metadata fused_metadata;
404 auto normalize_name = [](const string& name) {
405 return name.empty() ? "?" : name;
406 };
407 *fused_metadata.mutable_name() =
408 strings::StrCat("fused(", normalize_name(metadata1.name()), ",",
409 normalize_name(metadata2.name()), ")");
410 fused_metadata.SerializeToString(
411 (*fused_node->mutable_attr())["metadata"].mutable_s());
412 }
413
CopyShapesAndTypesAttrs(const NodeDef & from,NodeDef * to_node)414 bool CopyShapesAndTypesAttrs(const NodeDef& from, NodeDef* to_node) {
415 auto* attr = gtl::FindOrNull(from.attr(), kOutputTypes);
416 attr = (attr == nullptr ? gtl::FindOrNull(from.attr(), kToutputTypes) : attr);
417
418 if (attr == nullptr) return false;
419 (*to_node->mutable_attr())[kOutputTypes] = *attr;
420
421 attr = gtl::FindOrNull(from.attr(), kOutputShapes);
422 if (attr == nullptr) return false;
423 (*to_node->mutable_attr())[kOutputShapes] = *attr;
424 return true;
425 }
426
namespace {
// The sets below are heap-allocated at static-initialization time and are not
// deleted anywhere in this file.

// Op names for which HasSloppyAttr() returns true.
const auto* kSloppyAttrOps = new absl::flat_hash_set<string>{
    "ParallelInterleaveDatasetV2",
    "ParallelMapDataset",
    "ParseExampleDataset",
};

// Op names for which HasReplicateOnSplitAttr() returns true.
const auto* kReplicateOnSplitAttrOps = new absl::flat_hash_set<string>{
    "TensorSliceDataset",
    "RangeDataset",
};

// Op names for which HasDeterministicAttr() returns true.
const auto* kDeterministicAttrOps = new absl::flat_hash_set<string>{
    "LegacyParallelInterleaveDatasetV2",
    "ParallelInterleaveDatasetV3",
    "ParallelInterleaveDatasetV4",
    "ParallelMapDatasetV2",
    "ParallelBatchDataset",
};
}  // anonymous namespace
447
HasSloppyAttr(const string & op)448 bool HasSloppyAttr(const string& op) { return kSloppyAttrOps->contains(op); }
449
HasReplicateOnSplitAttr(const string & op)450 bool HasReplicateOnSplitAttr(const string& op) {
451 return kReplicateOnSplitAttrOps->contains(op);
452 }
453
HasDeterministicAttr(const string & op)454 bool HasDeterministicAttr(const string& op) {
455 return kDeterministicAttrOps->contains(op);
456 }
457
SetMetadataName(const std::string & name,NodeDef * node)458 Status SetMetadataName(const std::string& name, NodeDef* node) {
459 data::Metadata metadata;
460 if (node->attr().contains("metadata")) {
461 metadata.ParseFromString(node->attr().at("metadata").s());
462 }
463 if (!metadata.name().empty()) {
464 return errors::InvalidArgument("Node ", node->name(),
465 " already has a metadata name \"",
466 metadata.name(), "\".");
467 }
468 *metadata.mutable_name() = name;
469 metadata.SerializeToString((*node->mutable_attr())["metadata"].mutable_s());
470 return OkStatus();
471 }
472
473 } // namespace graph_utils
474 } // namespace grappler
475 } // namespace tensorflow
476