/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/data/make_deterministic.h"

#include <algorithm>
#include <array>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/split_utils.h"
#include "tensorflow/core/grappler/utils.h"

namespace tensorflow {
namespace grappler {
namespace {

constexpr char kInterleaveOp[] = "InterleaveDataset";
constexpr char kParallelInterleaveOp[] = "ParallelInterleaveDataset";
constexpr char kLegacyParallelInterleaveOp[] =
    "LegacyParallelInterleaveDatasetV2";
constexpr char kMapOp[] = "MapDataset";
constexpr char kParallelMapOp[] = "ParallelMapDataset";
constexpr char kParallelMapOpV2[] = "ParallelMapDatasetV2";
constexpr char kMapAndBatchOp[] = "MapAndBatchDataset";
constexpr char kBatchOp[] = "BatchDataset";
constexpr char kBatchV2Op[] = "BatchDatasetV2";
constexpr char kParallelBatchOp[] = "ParallelBatchDataset";
constexpr char kPrefetchOp[] = "PrefetchDataset";

// List of stateful ops which do not introduce nondeterminism when run as part
// of a Dataset function, e.g. within an InterleaveDataset's function. These are
// stateful dataset ops which do not read or modify TensorFlow state. Stateful
// ops not in this list can introduce nondeterminism, either due to the fact
// they are run in parallel (e.g. in a ParallelMapDataset with
// num_parallel_calls > 1) or because they can run asynchronously (e.g. a
// PrefetchDataset can cause ops in a MapDataset to run at the same time as ops
// outside a dataset).
//
// Ops in this list are allowed to read from files, as we do not make any
// guarantees on determinism if files are modified while a dataset is running.
// TODO(reedwm): Expand this list.
constexpr std::array<const char*, 9> kDeterministicStatefulOps = {
    "TextLineDataset", "FixedLengthRecordDataset", "TFRecordDataset",
    "TensorSliceDataset", "RangeDataset", "SSTableDataset", "RecordIODataset",
    // Because Print and Assert are on this list, the order of Print and Assert
    // ops may not be deterministic. This is acceptable, as it doesn't affect
    // model outputs or weights or other numeric values.
    "Print", "Assert"};

// List of stateful ops which do not introduce nondeterminism when run
// asynchronously as part of a Dataset function, but may introduce
// nondeterminism when run in parallel. All legacy random ops can be put on this
// list, since the state is internal to the op itself, and so there is no risk
// of ops outside the dataset reading or modifying the state.
constexpr std::array<const char*, 13> kDeterministicStatefulOpsWhenAsync = {
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomCrop",
    "SampleDistortedBoundingBox",
    "SampleDistortedBoundingBoxV2"};

bool IsDeterministicWhenRunInParallel(const std::string& stateful_op) {
  for (auto op_in_array : kDeterministicStatefulOps) {
    if (data::MatchesAnyVersion(op_in_array, stateful_op)) {
      return true;
    }
  }
  return false;
}

bool IsDeterministicWhenRunAsynchronously(const std::string& stateful_op) {
  for (auto op_in_array : kDeterministicStatefulOps) {
    if (data::MatchesAnyVersion(op_in_array, stateful_op)) {
      return true;
    }
  }
  for (auto op_in_array : kDeterministicStatefulOpsWhenAsync) {
    if (data::MatchesAnyVersion(op_in_array, stateful_op)) {
      return true;
    }
  }
  return false;
}

bool IsParallelInterleave(const std::string& op) {
  return data::MatchesAnyVersion(kParallelInterleaveOp, op) ||
         op == kLegacyParallelInterleaveOp;
}

bool IsParallelMap(const std::string& op) {
  return data::MatchesAnyVersion(kParallelMapOp, op);
}

bool IsParallelBatch(const std::string& op) {
  return data::MatchesAnyVersion(kParallelBatchOp, op);
}

bool IsMapAndBatch(const std::string& op) {
  return data::MatchesAnyVersion(kMapAndBatchOp, op);
}

bool IsPrefetch(const std::string& op) {
  return data::MatchesAnyVersion(kPrefetchOp, op);
}

// Returns whether the op is a dataset op which runs a function multiple times
// in parallel.
bool IntroducesFunctionParallelism(const std::string& op) {
  return IsParallelInterleave(op) || IsParallelMap(op) || IsMapAndBatch(op);
}

// Returns whether the op is a dataset op which can cause functions in the
// input pipeline to run asynchronously.
bool IntroducesAsynchrony(const std::string& op) {
  // Currently, every op that introduces parallelism also introduces
  // asynchrony.
  return IntroducesFunctionParallelism(op) || IsPrefetch(op) ||
         IsParallelBatch(op);
}

// Returns map from node name to NodeDef in a function.
absl::flat_hash_map<absl::string_view, const NodeDef*> NameToNode(
    const FunctionDef& function) {
  absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node;
  for (const NodeDef& node : function.node_def()) {
    name_to_node.insert({node.name(), &node});
  }
  return name_to_node;
}

NodeDef* GetMutableNode(const string& node_name, MutableGraphView* graph) {
  int index = graph_utils::FindGraphNodeWithName(node_name, *graph->graph());
  DCHECK_NE(index, -1) << "Failed to find node " << node_name
                       << " in the optimized graph.";
  return graph->graph()->mutable_node(index);
}

// Converts a ParallelInterleaveDataset or ParallelMapDataset to the equivalent
// non-parallel version, to make it deterministic.
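//
// An illustrative rewrite for a ParallelMapDatasetV2 node with no captured
// arguments (hypothetical node names):
//
//   ParallelMapDatasetV2(range, num_parallel_calls) {f=g, deterministic=...}
// becomes
//   MapDataset(range, ^num_parallel_calls) {f=g}
//
// i.e. the op type changes, the extra inputs of the parallel version become
// control inputs, and attrs that exist only on the parallel version are
// removed.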
Status ConvertMapOrInterleave(const string& node_name,
                              MutableGraphView* graph) {
  NodeDef* node = GetMutableNode(node_name, graph);

  auto Targuments = node->attr().find("Targuments");
  if (Targuments == node->attr().end()) {
    return errors::Internal("Failed to find Targuments attribute for node ",
                            node_name);
  }

  int num_inputs_after_rewrite;
  if (IsParallelInterleave(node->op())) {
    node->set_op(kInterleaveOp);
    num_inputs_after_rewrite = 3 + Targuments->second.list().type_size();
  } else {
    DCHECK(IsParallelMap(node->op()));
    node->set_op(kMapOp);
    num_inputs_after_rewrite = 1 + Targuments->second.list().type_size();
  }

  // ParallelInterleave and ParallelMap ops take in more inputs than the
  // corresponding non-parallel versions, so turn extra inputs into control
  // inputs. These extra inputs are for performance and are safe to ignore.
  int inputs_processed = 0;
  for (int i = 0; i < node->input_size(); i++) {
    std::string input = node->input(i);
    if (IsControlInput(input)) {
      continue;
    }
    if (inputs_processed >= num_inputs_after_rewrite) {
      node->set_input(i, absl::StrCat("^", input));
    }
    inputs_processed++;
  }
  if (inputs_processed < num_inputs_after_rewrite) {
    return errors::Internal("Found only ", inputs_processed, " inputs to node ",
                            node_name, ", but expected to find at least ",
                            num_inputs_after_rewrite);
  }

  // Remove extra attributes not in Interleave or Map.
  node->mutable_attr()->erase("deterministic");
  node->mutable_attr()->erase("sloppy");
  return OkStatus();
}

// Returns all transitive dependencies of a set of nodes, including the nodes
// themselves.
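// For example, if the function contains nodes A -> B -> C (C depends on B,
// which depends on A) and "nodes" is {C}, the returned set is {A, B, C}.
// Function arguments are not included, since they are not nodes.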
absl::flat_hash_set<absl::string_view> GetAllTransitiveDependencies(
    const FunctionDef& function_def,
    const absl::flat_hash_set<absl::string_view>& nodes) {
  std::vector<absl::string_view> nodes_to_process;
  std::copy(nodes.begin(), nodes.end(), std::back_inserter(nodes_to_process));

  absl::flat_hash_map<absl::string_view, const NodeDef*> name_to_node =
      NameToNode(function_def);
  absl::flat_hash_set<absl::string_view> dependencies;
  while (!nodes_to_process.empty()) {
    absl::string_view node_name = nodes_to_process.back();
    nodes_to_process.pop_back();
    if (dependencies.contains(node_name)) {
      continue;
    }
    dependencies.insert(node_name);
    auto iter = name_to_node.find(node_name);
    if (iter == name_to_node.end()) {
      // If the node doesn't exist, the function is malformed, so just ignore
      // the node for now.
      continue;
    }
    for (absl::string_view inp : iter->second->input()) {
      absl::string_view inp_node = inp.substr(0, inp.find(':'));
      if (inp_node.at(0) == '^') {
        inp_node = inp_node.substr(1);
      }
      // Input may be an argument instead of a node, so explicitly check if
      // name is in name_to_node.
      if (name_to_node.contains(inp_node)) {
        nodes_to_process.push_back(inp_node);
      }
    }
  }
  return dependencies;
}

// Makes a ParallelMapV2 op deterministic by splitting it into separate Map and
// ParallelMapV2 ops, or a MapAndBatch op deterministic by splitting it into
// separate Map and MapAndBatch ops. All the nondeterministic nodes and their
// dependencies are moved to the Map node.
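//
// An illustrative split (hypothetical function): if the mapped function
// computes
//   arg -> Cast -> RandomUniform -> Add -> output
// and RandomUniform is the only nondeterministic node, then Cast and
// RandomUniform (the node plus its dependencies) move into the function of the
// new sequential Map node, while Add stays in the function of the new parallel
// node, which can then safely run with deterministic=true.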
Status SplitMap(
    const FunctionLibraryDefinition& library, const string& map_node_name,
    MutableGraphView* graph,
    const absl::flat_hash_set<absl::string_view>& nondeterministic_nodes) {
  NodeDef* map_node = GetMutableNode(map_node_name, graph);
  NameAttrList func = map_node->attr().at("f").func();
  const FunctionDef* function_def = library.Find(func.name());
  if (!function_def) {
    return errors::Internal("Could not look up function ", func.name(),
                            " in FunctionLibraryDefinition");
  }

  absl::flat_hash_set<absl::string_view> nodes_to_move =
      GetAllTransitiveDependencies(*function_def, nondeterministic_nodes);

  VLOG(2) << "Will move nodes to nonparallel function: "
          << absl::StrJoin(nodes_to_move, ", ");

  int64_t num_captured_arguments =
      map_node->attr().find("Targuments")->second.list().type_size();

  TF_ASSIGN_OR_RETURN(
      split_utils::SplitResults split_results,
      split_utils::SplitFunction(*function_def, nodes_to_move,
                                 num_captured_arguments, library));

  if (split_results.first_function_output_types.empty()) {
    // Map datasets require there to be at least one output.
    return errors::Unimplemented(
        "The case where the first function has no outputs is unimplemented.");
  }

  bool is_map_and_batch = map_node->op() == kMapAndBatchOp;

  NodeDef* first_map_node_ptr;
  {
    NodeDef first_map_node;
    graph_utils::SetUniqueGraphNodeName(
        strings::StrCat("make_deterministic_sequential_map/",
                        map_node->name()),
        graph->graph(), &first_map_node);
    first_map_node.set_op(kMapOp);
    int num_control_deps = NumControlInputs(*map_node);
    // ParallelMap and MapAndBatch nodes have "num_extra_inputs" more inputs
    // than Map. All inputs are copied to the Map node, but the
    // "num_extra_inputs" inputs are converted to control dependencies.
    int num_extra_inputs = is_map_and_batch ? 3 : 1;
    int control_deps_index = map_node->input_size() - num_control_deps;
    int extra_inputs_index = control_deps_index - num_extra_inputs;
    for (int i = 0; i < extra_inputs_index; i++) {
      // Copy inputs that are also inputs to Map.
      DCHECK(!IsControlInput(map_node->input(i)));
      first_map_node.add_input(map_node->input(i));
    }
    for (int i = extra_inputs_index; i < control_deps_index; i++) {
      // Copy the extra inputs, converting them to control dependencies.
      DCHECK(!IsControlInput(map_node->input(i)));
      first_map_node.add_input(absl::StrCat("^", map_node->input(i)));
    }
    for (int i = control_deps_index; i < map_node->input_size(); i++) {
      // Copy the control dependencies.
      DCHECK(IsControlInput(map_node->input(i)));
      first_map_node.add_input(map_node->input(i));
    }

    NameAttrList* name_attr_list =
        (*first_map_node.mutable_attr())["f"].mutable_func();
    // TODO(reedwm): Set attrs?
    name_attr_list->set_name(split_results.first_function.signature().name());

    graph_utils::CopyAttribute("Targuments", *map_node, &first_map_node);
    for (auto key : {"use_inter_op_parallelism", "preserve_cardinality"}) {
      if (gtl::FindOrNull(map_node->attr(), key)) {
        graph_utils::CopyAttribute(key, *map_node, &first_map_node);
      }
    }
    AddNodeAttr("output_types", split_results.first_function_output_types,
                &first_map_node);
    TensorShapeProto unknown_shape;
    unknown_shape.set_unknown_rank(true);
    std::vector<TensorShapeProto> output_shapes(
        split_results.first_function_output_types.size(), unknown_shape);
    AddNodeAttr("output_shapes", output_shapes, &first_map_node);
    first_map_node_ptr = graph->AddNode(std::move(first_map_node));
  }

  NodeDef* second_map_node_ptr;
  {
    NodeDef second_map_node;
    string node_name =
        map_node->op() == kMapAndBatchOp ? "map_and_batch" : "parallel_map";
    graph_utils::SetUniqueGraphNodeName(
        strings::StrCat("make_deterministic_parallel_", node_name, "/",
                        map_node->name()),
        graph->graph(), &second_map_node);
    second_map_node.set_op(map_node->op());
    second_map_node.add_input(first_map_node_ptr->name());
    for (int i = 1; i < map_node->input_size(); i++) {
      second_map_node.add_input(map_node->input(i));
    }
    NameAttrList* name_attr_list =
        (*second_map_node.mutable_attr())["f"].mutable_func();
    // TODO(reedwm): Set attrs?
    name_attr_list->set_name(split_results.second_function.signature().name());
    graph_utils::CopyAttribute("Targuments", *map_node, &second_map_node);
    graph_utils::CopyAttribute("output_types", *map_node, &second_map_node);
    graph_utils::CopyAttribute("output_shapes", *map_node, &second_map_node);
    if (!is_map_and_batch) {
      AddNodeAttr("deterministic", "true", &second_map_node);
    }
    for (auto key : {"use_inter_op_parallelism", "preserve_cardinality"}) {
      if (gtl::FindOrNull(map_node->attr(), key)) {
        graph_utils::CopyAttribute(key, *map_node, &second_map_node);
      }
    }
    second_map_node_ptr = graph->AddNode(std::move(second_map_node));
  }

  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(map_node->name(), second_map_node_ptr->name()));
  *graph->graph()->mutable_library()->mutable_function()->Add() =
      split_results.first_function;
  *graph->graph()->mutable_library()->mutable_function()->Add() =
      split_results.second_function;
  return OkStatus();
}

// Converts a ParallelBatch dataset to a Batch dataset, to make it
// deterministic.
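//
// Illustrative rewrite:
//   ParallelBatchDataset(input, batch_size, num_parallel_calls, drop_remainder)
// becomes
//   BatchDatasetV2(input, batch_size, drop_remainder, ^num_parallel_calls)
// and the "deterministic" attr, which BatchDatasetV2 does not have, is removed.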
Status ConvertBatch(const string& node_name, MutableGraphView* graph) {
  NodeDef* node = GetMutableNode(node_name, graph);
  node->set_op(kBatchV2Op);
  std::string num_parallel_calls_input = node->input(2);
  node->set_input(2, node->input(3));
  node->set_input(3, absl::StrCat("^", num_parallel_calls_input));
  node->mutable_attr()->erase("deterministic");
  return OkStatus();
}

// Converts a MapAndBatch node to a separate Map node and Batch node, to make
// it deterministic. Caller should delete the MapAndBatch node afterwards.
// TODO(reedwm): Handle 'metadata' attribute. Currently the Map node and Batch
// node will have an empty 'metadata' attribute.
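//
// Illustrative rewrite, assuming no captured arguments (hypothetical names):
//   MapAndBatchDataset(in, batch_size, num_parallel_calls, drop_remainder)
//       {f, output_shapes=[[?, 2]]}
// becomes
//   MapDataset(in, ^batch_size, ^num_parallel_calls, ^drop_remainder)
//       {f, output_shapes=[[2]]}
//   BatchDatasetV2(map, batch_size, drop_remainder) {output_shapes=[[?, 2]]}
// i.e. the Map keeps the function and the per-element shapes (leading batch
// dimension dropped), while the Batch keeps the original batched shapes.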
Status ConvertMapAndBatch(const string& node_name, MutableGraphView* graph) {
  int index = graph_utils::FindGraphNodeWithName(node_name, *graph->graph());
  DCHECK_NE(index, -1) << "Failed to find node " << node_name
                       << " in the optimized graph.";
  const NodeDef& orig_node = graph->graph()->node(index);

  auto Targuments = orig_node.attr().find("Targuments");
  if (Targuments == orig_node.attr().end()) {
    return errors::Internal("Failed to find Targuments attribute for node ",
                            node_name);
  }

  // Create map node.
  NodeDef new_map_node;
  new_map_node.set_op(kMapOp);
  graph_utils::SetUniqueGraphNodeName(kMapOp, graph->graph(), &new_map_node);
  int num_map_inputs = 1 + Targuments->second.list().type_size();
  for (int i = 0; i < num_map_inputs; i++) {
    new_map_node.add_input(orig_node.input(i));
  }
  for (int i = num_map_inputs; i < orig_node.input_size(); i++) {
    if (IsControlInput(orig_node.input(i))) {
      new_map_node.add_input(orig_node.input(i));
    } else {
      new_map_node.add_input(absl::StrCat("^", orig_node.input(i)));
    }
  }
  for (auto key : {"f", "Targuments", "output_types"}) {
    graph_utils::CopyAttribute(key, orig_node, &new_map_node);
  }
  for (auto key : {"preserve_cardinality"}) {
    // Check the original node, since the attr has not been copied yet.
    if (gtl::FindOrNull(orig_node.attr(), key)) {
      graph_utils::CopyAttribute(key, orig_node, &new_map_node);
    }
  }
  auto orig_output_shapes = orig_node.attr().find("output_shapes");
  if (orig_output_shapes == orig_node.attr().end()) {
    return errors::Internal("Failed to find output_shapes attribute for node ",
                            node_name);
  }

  // Set "output_shapes" attr of Map to be "output_shapes" of MapAndBatch with
  // the leading dimension removed for each shape.
  AttrValue& map_output_shapes =
      (*new_map_node.mutable_attr())["output_shapes"];
  for (const TensorShapeProto& orig_shape :
       orig_output_shapes->second.list().shape()) {
    TensorShapeProto* new_shape = map_output_shapes.mutable_list()->add_shape();
    if (orig_shape.unknown_rank()) {
      new_shape->set_unknown_rank(true);
    } else if (orig_shape.dim_size() == 0) {
      return errors::Internal(
          "Output shape of MapAndBatch node cannot be scalar");
    } else {
      for (int i = 1; i < orig_shape.dim_size(); i++) {
        *new_shape->add_dim() = orig_shape.dim(i);
      }
    }
  }

  // Create batch node.
  NodeDef new_batch_node;
  new_batch_node.set_op(kBatchV2Op);
  graph_utils::SetUniqueGraphNodeName(kBatchOp, graph->graph(),
                                      &new_batch_node);
  new_batch_node.add_input(new_map_node.name());
  new_batch_node.add_input(orig_node.input(num_map_inputs));  // batch_size
  new_batch_node.add_input(
      orig_node.input(num_map_inputs + 2));  // drop_remainder
  graph_utils::CopyShapesAndTypesAttrs(orig_node, &new_batch_node);

  graph->AddNode(std::move(new_map_node));
  NodeDef* graph_batch_node = graph->AddNode(std::move(new_batch_node));
  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(orig_node.name(), graph_batch_node->name()));
  return OkStatus();
}

// Change the buffer_size of a Prefetch node to zero, effectively disabling it,
// to make it deterministic.
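//
// Illustrative rewrite:
//   PrefetchDataset(input, buffer_size)
// becomes
//   PrefetchDataset(input, Const(0), ^buffer_size)
// where the original buffer_size input is retained as a control input.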
Status ConvertPrefetch(const string& node_name, MutableGraphView* graph) {
  NodeDef* node = GetMutableNode(node_name, graph);
  constexpr int buffer_size_index = 1;
  node->add_input(absl::StrCat("^", node->input(buffer_size_index)));
  NodeDef* tmp = graph_utils::AddScalarConstNode<int64_t>(0, graph);
  node->set_input(buffer_size_index, tmp->name());
  return OkStatus();
}

// The two ways nondeterminism can occur in an input pipeline when there are
// stateful ops.
enum class NondeterminismType { PARALLELISM, ASYNCHRONY };

// Returns whether the stateful op is deterministic if run in parallel or
// asynchronously.
bool IsDeterministicStatefulOp(NondeterminismType type,
                               const std::string& stateful_op) {
  return type == NondeterminismType::PARALLELISM
             ? IsDeterministicWhenRunInParallel(stateful_op)
             : IsDeterministicWhenRunAsynchronously(stateful_op);
}

// Defined below. Mutually recursive with FunctionMayIntroduceNondeterminism.
bool FunctionNodeMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const NodeDef& node_def,
    NondeterminismType nondeterminism_type,
    absl::flat_hash_set<std::string>* functions_processed);

// Returns true if the function may introduce nondeterminism. Depending on
// 'nondeterminism_type', either checks if nondeterminism can occur when the
// function is run several times in parallel or when run asynchronously.
// Recursively checks any function attributes of ops within the function.
// "functions_processed" is the list of functions already processed, so that the
// same function is not recursively checked twice. If not null, nodes causing
// nondeterminism will be added to "nondeterministic_nodes".
bool FunctionMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const std::string& function_name,
    NondeterminismType nondeterminism_type,
    absl::flat_hash_set<std::string>* functions_processed,
    absl::flat_hash_set<absl::string_view>* nondeterministic_nodes) {
  if (functions_processed->contains(function_name)) {
    return false;
  }
  functions_processed->insert(function_name);
  const FunctionDef* function_def = library.Find(function_name);
  if (!function_def) {
    VLOG(2) << "Could not look up function " << function_name
            << " in FunctionLibraryDefinition, so rewriting op to be safe";
    return true;
  }
  bool found = false;
  for (const NodeDef& node_def : function_def->node_def()) {
    bool nondeterministic = FunctionNodeMayIntroduceNondeterminism(
        library, node_def, nondeterminism_type, functions_processed);
    if (nondeterministic) {
      if (nondeterministic_nodes) {
        nondeterministic_nodes->insert(node_def.name());
        found = true;
      } else {
        return true;
      }
    }
  }
  return found;
}

bool FunctionMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const std::string& function_name,
    NondeterminismType nondeterminism_type) {
  absl::flat_hash_set<string> functions_processed;
  return FunctionMayIntroduceNondeterminism(library, function_name,
                                            nondeterminism_type,
                                            &functions_processed, nullptr);
}

// Returns true if the given NodeDef inside a function may cause
// nondeterminism.
bool FunctionNodeMayIntroduceNondeterminism(
    const FunctionLibraryDefinition& library, const NodeDef& node_def,
    NondeterminismType nondeterminism_type,
    absl::flat_hash_set<std::string>* functions_processed) {
  const OpRegistrationData* op_reg_data = nullptr;
  Status s = library.LookUp(node_def.op(), &op_reg_data);
  if (!s.ok()) {
    VLOG(2) << "Could not look up op " << node_def.op()
            << " in FunctionLibraryDefinition, so rewriting op to be safe";
    return true;
  }
  bool is_function_op = op_reg_data->is_function_op;

  bool is_stateful = false;
  if (!is_function_op) {
    const OpDef* op_def;
    s = OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def);
    if (!s.ok()) {
      VLOG(2) << "Could not look up op " << node_def.op()
              << " in OpRegistry, so rewriting op to be safe";
      return true;
    }
    is_stateful = op_def->is_stateful();
  }

  // Rewrite nondeterministic stateful ops. Function ops and If/While ops are
  // skipped, since we instead look at the ops within the function(s).
  if (is_stateful && !IsStatefulPartitionedCall(node_def) &&
      !IsIf(node_def) && !IsWhile(node_def) &&
      !IsDeterministicStatefulOp(nondeterminism_type, node_def.op())) {
    VLOG(2) << "Will rewrite due to op: " << node_def.op();
    return true;
  }

  // Recursively check for nondeterminism in all function attributes.
  std::vector<std::string> attr_func_names;
  for (const auto& attr : node_def.attr()) {
    if (attr.second.has_func()) {
      attr_func_names.push_back(attr.second.func().name());
    }
    for (const auto& name_attr_list : attr.second.list().func()) {
      attr_func_names.push_back(name_attr_list.name());
    }
  }
  if (is_function_op) {
    attr_func_names.push_back(node_def.op());
  }
  for (const std::string& inner_function_name : attr_func_names) {
    if (FunctionMayIntroduceNondeterminism(library, inner_function_name,
                                           nondeterminism_type,
                                           functions_processed, nullptr)) {
      return true;
    }
  }
  return false;
}

// Returns true if "node" is a dataset node whose function can introduce
// nondeterminism when run asynchronously.
bool NodeMayIntroduceNondeterminismWhenAsync(
    const FunctionLibraryDefinition& library, const NodeDef& node) {
  const OpDef* op_def;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (s.code() == error::NOT_FOUND) {
    return false;
  } else if (!s.ok()) {
    return true;
  }
  if (data::DatasetOpKernel::IsDatasetOp(*op_def)) {
    std::vector<std::string> attr_func_names;
    for (const auto& attr : node.attr()) {
      if (attr.second.has_func()) {
        attr_func_names.push_back(attr.second.func().name());
      }
      for (const auto& name_attr_list : attr.second.list().func()) {
        attr_func_names.push_back(name_attr_list.name());
      }
    }
    for (const std::string& inner_function_name : attr_func_names) {
      if (FunctionMayIntroduceNondeterminism(library, inner_function_name,
                                             NondeterminismType::ASYNCHRONY)) {
        return true;
      }
    }
  }
  return false;
}

// Returns true if the graph has any dataset node whose function can introduce
// nondeterminism when run asynchronously.
bool GraphMayHaveAsyncNondeterminism(const FunctionLibraryDefinition& library,
                                     const GraphDef& graph) {
  for (const NodeDef& node : graph.node()) {
    if (NodeMayIntroduceNondeterminismWhenAsync(library, node)) {
      return true;
    }
  }
  for (const string& function_name : library.ListFunctionNames()) {
    const FunctionDef* function_def = library.Find(function_name);
    CHECK(function_def);  // Crash Ok
    for (const NodeDef& node : function_def->node_def()) {
      if (NodeMayIntroduceNondeterminismWhenAsync(library, node)) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

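// Rewrites the graph to remove sources of nondeterminism in tf.data functions.
// In outline:
//  1. "sloppy" attrs are set to false and "deterministic" attrs to "true".
//  2. If any dataset function in the graph may be nondeterministic when run
//     asynchronously, every op introducing asynchrony (prefetch, parallel
//     map/interleave/batch, map-and-batch) is rewritten to a synchronous
//     equivalent.
//  3. Parallel map/interleave/map-and-batch ops whose functions may be
//     nondeterministic when run in parallel are rewritten as well: where
//     possible, ParallelMapV2 and MapAndBatch nodes are split into a
//     sequential Map plus a parallel remainder; otherwise they are converted
//     to their sequential equivalents.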
Status MakeDeterministic::OptimizeAndCollectStats(Cluster* cluster,
                                                  const GrapplerItem& item,
                                                  GraphDef* output,
                                                  OptimizationStats* stats) {
  *output = item.graph;
  MutableGraphView graph(output);
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             item.graph.library());
  absl::flat_hash_set<string> nodes_to_delete;
  bool remove_async_nodes =
      GraphMayHaveAsyncNondeterminism(function_library, item.graph);

  for (const NodeDef& node : item.graph.node()) {
    if (graph_utils::HasSloppyAttr(node.op())) {
      NodeDef* mutable_node = GetMutableNode(node.name(), &graph);
      (*mutable_node->mutable_attr())["sloppy"].set_b(false);
      stats->num_changes++;
    }
    if (graph_utils::HasDeterministicAttr(node.op())) {
      NodeDef* mutable_node = GetMutableNode(node.name(), &graph);
      (*mutable_node->mutable_attr())["deterministic"].set_s("true");
      stats->num_changes++;
    }

    bool rewrite_due_to_async =
        IntroducesAsynchrony(node.op()) && remove_async_nodes;
    absl::flat_hash_set<std::string> functions_processed;
    absl::flat_hash_set<absl::string_view> nondeterministic_nodes;
    bool rewrite_due_to_parallelism =
        IntroducesFunctionParallelism(node.op()) &&
        FunctionMayIntroduceNondeterminism(
            function_library, node.attr().at("f").func().name(),
            NondeterminismType::PARALLELISM, &functions_processed,
            &nondeterministic_nodes);
    if (!rewrite_due_to_async && !rewrite_due_to_parallelism) {
      continue;
    }

    VLOG(1) << "Rewriting node " << node.name() << " (" << node.op()
            << ") because it introduces nondeterminism through "
            << (rewrite_due_to_async ? "asynchrony" : "parallelism");

    bool maybe_can_split =
        !rewrite_due_to_async &&
        (node.op() == kParallelMapOpV2 || IsMapAndBatch(node.op()));
    if (maybe_can_split) {
      Status s = SplitMap(function_library, node.name(), &graph,
                          nondeterministic_nodes);
      if (s.ok()) {
        VLOG(1) << "Split node " << node.name() << " (" << node.op()
                << ") into two map nodes: a nonparallel version and a "
                   "parallel version.";
        nodes_to_delete.insert(node.name());
        continue;
      } else if (s.code() == error::UNIMPLEMENTED) {
        // If splitting the function is unimplemented, instead convert the node
        // to a nonparallel version below.
        VLOG(1) << "Could not move stateful ops to their own function, so will "
                   "convert node "
                << node.name()
                << " to a nonparallel version instead. Reason: " << s;
      } else {
        return s;
      }
    }

    if (IsPrefetch(node.op())) {
      TF_RETURN_IF_ERROR(ConvertPrefetch(node.name(), &graph));
    } else if (IsMapAndBatch(node.op())) {
      TF_RETURN_IF_ERROR(ConvertMapAndBatch(node.name(), &graph));
      nodes_to_delete.insert(node.name());
    } else if (IsParallelBatch(node.op())) {
      TF_RETURN_IF_ERROR(ConvertBatch(node.name(), &graph));
    } else {
      DCHECK(IsParallelInterleave(node.op()) || IsParallelMap(node.op()));
      TF_RETURN_IF_ERROR(ConvertMapOrInterleave(node.name(), &graph));
    }
    stats->num_changes++;
  }

  TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
  return OkStatus();
}

REGISTER_GRAPH_OPTIMIZER_AS(MakeDeterministic, "make_deterministic");

}  // namespace grappler
}  // namespace tensorflow