# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for CrossDeviceOps."""

import collections
import os
import threading
import time

from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
CommunicationImplementation = collective_util.CommunicationImplementation
ReduceOp = reduce_util.ReduceOp
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
IndexedSlices = indexed_slices.IndexedSlices


def make_per_replica_value(value, devices):
  """Creates a `PerReplica` object whose values reside in `devices`.

  Args:
    value: a tensor-convertible value or an `IndexedSlicesValue`, or a callable
      that takes one argument (`device_idx`) and returns the value to be
      created on devices[device_idx].
    devices: a list of device strings to create `PerReplica` values on.

  Returns:
    A `PerReplica` object.
69 """ 70 values = [] 71 for device_idx, device in enumerate(devices): 72 if callable(value): 73 v = value(device_idx) 74 elif isinstance(value, list): 75 v = value[device_idx] 76 else: 77 v = value 78 if isinstance(v, IndexedSlicesValue): 79 with ops.device(device): 80 values.append( 81 IndexedSlices( 82 values=array_ops.identity(v.values), 83 indices=array_ops.identity(v.indices), 84 dense_shape=array_ops.identity(v.dense_shape))) 85 else: 86 with ops.device(device): 87 values.append(array_ops.identity(v)) 88 return value_lib.PerReplica(values) 89 90 91def enable_collective_ops(): 92 """Enable collectives in the current process.""" 93 cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver() 94 context.context().configure_collective_ops( 95 collective_leader="'/job:worker/replica:0/task:0'") 96 config_proto = config_pb2.ConfigProto() 97 config_proto.experimental.collective_group_leader = ( 98 "/job:worker/replica:0/task:0") 99 server_def = tensorflow_server_pb2.ServerDef( 100 cluster=cluster_resolver.cluster_spec().as_cluster_def(), 101 default_session_config=config_proto, 102 job_name=cluster_resolver.task_type, 103 task_index=cluster_resolver.task_id, 104 protocol=cluster_resolver.rpc_layer) 105 context.context().enable_collective_ops(server_def) 106 # Recover default flag values. 107 CollectiveReplicaLauncher._prefer_unique_instance_key = True 108 CollectiveReplicaLauncher._prefer_ordering_token = False 109 110 111class MultiProcessPoolRunner(): 112 113 def __init__(self, num_processes): 114 cluster_spec_dict = multi_worker_test_base.create_cluster_spec( 115 num_workers=num_processes) 116 self.runner = multi_process_runner.MultiProcessPoolRunner(cluster_spec_dict) 117 118 119# Global MultiProcessPoolRunners that can be shared by test cases to avoid 120# expensive initialization cost of TensorFlow in new processes. 121# 122# Note that they have to be globals and can't be owned by test classes because 123# usually fn usually captures the test class instance, and test class 124# instance can't be pickled if it has mpr as a member (it is not allowed to 125# pickle Process objects). 126# TODO(crccw): Use `num_workers` combination once it is ready. 127global_mpr_2p = MultiProcessPoolRunner(num_processes=2) 128global_mpr_1p = MultiProcessPoolRunner(num_processes=1) 129 130 131def get_global_mpr(num_processes): 132 if num_processes == 1: 133 return global_mpr_1p.runner 134 elif num_processes == 2: 135 return global_mpr_2p.runner 136 else: 137 raise ValueError("get_global_mpr: num_processes must be 1 or 2, got %d" % 138 num_processes) 139 140 141class CollectiveOpsTest(test.TestCase, parameterized.TestCase): 142 143 def setUp(self): 144 super().setUp() 145 # Enabling collectives can be done in "setUpClass", but requires using 146 # different collective_keys in different tests as collectives are reused 147 # across tests. Always resetting collective ops before each test offers 148 # better test isolation. 149 global_mpr_1p.runner.run(enable_collective_ops) 150 global_mpr_2p.runner.run(enable_collective_ops) 151 152 def make_collective(self, num_processes, gpu_per_process): 153 """Returns collectives and other info to be used in tests. 154 155 Args: 156 num_processes: an integer indicating the number of processes that 157 participate in the collective. 158 gpu_per_process: number of GPUs (0 if no GPUs) used by each process. 

    Returns:
      A tuple of (collective, devices, pid) where collective is an instance of
      `CollectiveAllReduce`, devices is a list of local devices (str) attached
      to the current process, and pid is the id of this process among all
      participating processes.
    """

    cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
    devices = [
        "/job:worker/replica:0/task:%d/device:CPU:0" % cluster_resolver.task_id
    ]
    if gpu_per_process > 0:
      devices = [
          "/job:worker/replica:0/task:%d/device:GPU:%d" %
          (cluster_resolver.task_id, i) for i in range(gpu_per_process)
      ]
    group_size = num_processes * len(devices)
    collective = cross_device_ops_lib.CollectiveAllReduce(
        devices=devices,
        group_size=group_size,
        options=collective_util.Options())
    return collective, devices, cluster_resolver.task_id

  def as_list(self, value):
    """A utility to convert a `Mirrored`, `Tensor` or `IndexedSlices` to a list.

    The reason it exists is to provide a uniform view of the returned value of
    "reduce" calls, especially across tf.function boundaries. Returning
    `Mirrored` from a tf.function only evaluates the primary value, which
    causes collective ops on non-primary devices to be pruned and eventually
    leads to hanging.

    Args:
      value: the value to convert, can be one of `Mirrored`, `Tensor` and
        `IndexedSlices`.

    Returns:
      A list of `Tensor` or `IndexedSlices`.
    """
    if isinstance(value, ops.Tensor):
      return [value]
    elif isinstance(value, IndexedSlices):
      return [value]
    elif isinstance(value, value_lib.Mirrored):
      return value.values
    else:
      raise ValueError("unwrap: unsupported input type: %s" % type(value))

  RunOptions = collections.namedtuple(  # pylint: disable=invalid-name
      "RunOptions",
      [
          "mode",  # A list of str from ["eager", "func_graph"]
          "num_processes",
          "gpus_per_process",
          "reduce_op",
          "communication_options",
          "prefer_unique_instance_key",
      ])
  RunOptions.__new__.__defaults__ = (["eager", "func_graph"], 2, 0,
                                     ReduceOp.SUM, collective_util.Options(),
                                     True)

  def reduce_and_verify(self, inputs, expect, options):
    """Reduce the given `inputs` and verify the output matches `expect`.

    Args:
      inputs: a list of `Tensor` or `IndexedSlices`, where the i-th value will
        be fed to the i-th replica.
      expect: a `Tensor` or `IndexedSlices`. This should be the expected value
        for one replica.
      options: a `RunOptions` instance.
231 """ 232 233 def replica_fn(): 234 CollectiveReplicaLauncher._prefer_unique_instance_key = ( 235 options.prefer_unique_instance_key) 236 collective, devices, pid = self.make_collective(options.num_processes, 237 options.gpus_per_process) 238 239 def reduce_fn(): 240 value_fn = lambda device_idx: inputs[pid * len(devices) + device_idx] 241 per_replica_value = make_per_replica_value(value_fn, devices) 242 reduced_values = collective.reduce(options.reduce_op, per_replica_value, 243 per_replica_value, 244 options.communication_options) 245 if options.gpus_per_process > 1: 246 self.assertIsInstance(reduced_values, value_lib.Mirrored) 247 reduced_values = self.as_list(reduced_values) 248 self.assertAllEqual(devices, [v.device for v in reduced_values]) 249 return [ops.convert_to_tensor(v) for v in reduced_values] 250 251 per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices) 252 253 if "eager" in options.mode: 254 got = reduce_fn() 255 self.assertAllClose(got, per_replica_expect) 256 257 if "func_graph" in options.mode: 258 got = def_function.function(reduce_fn)() 259 self.assertAllClose(got, per_replica_expect) 260 261 get_global_mpr(options.num_processes).run(replica_fn) 262 263 def batch_reduce_and_verify(self, inputs, expect, options): 264 """Batch reduce the given `inputs` and verify the output matches `expect`. 265 266 Args: 267 inputs: a 2-level nested list of `Tensor` or `IndexedSlices`, where i-th 268 value will be fed to i-th replica. 269 expect: a list of `Tensor` or `IndexedSlices`. This should be the expected 270 value for one replica. 271 options: a `RunOpotions` instance. 272 """ 273 274 def replica_fn(): 275 CollectiveReplicaLauncher._prefer_unique_instance_key = ( 276 options.prefer_unique_instance_key) 277 collective, devices, pid = self.make_collective(options.num_processes, 278 options.gpus_per_process) 279 280 def batch_reduce_fn(): 281 batch_size = len(inputs[0]) 282 value_dst_pairs = [] 283 for i in range(batch_size): 284 285 def value_fn(device_idx, idx=i): 286 return inputs[pid * len(devices) + device_idx][idx] 287 288 per_replica_value = make_per_replica_value(value_fn, devices) 289 value_dst_pairs.append((per_replica_value, per_replica_value)) 290 reduced_values = collective.batch_reduce(options.reduce_op, 291 value_dst_pairs, 292 options.communication_options) 293 if options.gpus_per_process > 1: 294 for v in reduced_values: 295 self.assertIsInstance(v, value_lib.Mirrored) 296 reduced_values = [self.as_list(v) for v in reduced_values] 297 for v in reduced_values: 298 self.assertAllEqual(devices, [t.device for t in v]) 299 return nest.map_structure(ops.convert_to_tensor, reduced_values) 300 301 per_replica_expect = nest.map_structure( 302 lambda x: [ops.convert_to_tensor(x)] * len(devices), expect) 303 304 if "eager" in options.mode: 305 got = batch_reduce_fn() 306 self.assertAllClose(got, per_replica_expect) 307 308 if "func_graph" in options.mode: 309 got = def_function.function(batch_reduce_fn)() 310 self.assertAllClose(got, per_replica_expect) 311 312 get_global_mpr(options.num_processes).run(replica_fn) 313 314 @combinations.generate( 315 combinations.combine( 316 num_processes=[1, 2], 317 required_gpus=[0, 1, 2], 318 implementation=[ 319 CommunicationImplementation.AUTO, 320 CommunicationImplementation.RING, 321 CommunicationImplementation.NCCL, 322 ], 323 reduce_op=[ReduceOp.SUM, ReduceOp.MEAN], 324 prefer_unique_instance_key=[True, False])) 325 def testReduceDense(self, num_processes, required_gpus, implementation, 326 reduce_op, 
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")
    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [1.0, 2.0, 3.0, 4.0]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = 1.0
    elif group_size == 2:
      expect = 3.0 if reduce_op == ReduceOp.SUM else 1.5
    elif group_size == 4:
      expect = 10.0 if reduce_op == ReduceOp.SUM else 2.5

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testReduceSparse(self, num_processes, required_gpus, implementation,
                       reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")
    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[7, 8], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[3, 2], dense_shape=[10, 1]),
    ]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = IndexedSlices(
          values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1])
    elif group_size == 2:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.]],
          indices=[0, 1, 1, 2],
          dense_shape=[10, 1])
    elif group_size == 4:
      expect = IndexedSlices(
          values=[[1.], [2.], [3.], [4.], [5.], [6.], [7.], [8.]],
          indices=[0, 1, 1, 2, 7, 8, 3, 2],
          dense_shape=[10, 1])

    self.reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(prefer_unique_instance_key=[True, False]))
  def testReduceSparseVariableLength(self, prefer_unique_instance_key):
    # One device per process, 2 processes, 2 replicas in total.
    inputs = [
        IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[2.], [3.], [4.]], indices=[0, 1, 2], dense_shape=[10, 1]),
    ]
    expect = IndexedSlices(
        values=[[1.], [2.], [3.], [4.]],
        indices=[0, 0, 1, 2],
        dense_shape=[10, 1])
    self.reduce_and_verify(
        inputs,
        expect,
        self.RunOptions(
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=2,
            reduce_op=ReduceOp.SUM,
            prefer_unique_instance_key=prefer_unique_instance_key))

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
          prefer_unique_instance_key=[True, False]))
  def testBatchReduceDense(self, num_processes, required_gpus, implementation,
                           reduce_op, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    options = self.RunOptions(
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [1.0, 2.0]
    elif group_size == 2:
      expect = [4.0, 6.0] if reduce_op == ReduceOp.SUM else [2.0, 3.0]
    elif group_size == 4:
      expect = [16.0, 20.0] if reduce_op == ReduceOp.SUM else [4.0, 5.0]

    self.batch_reduce_and_verify(inputs, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
          prefer_unique_instance_key=[True, False]))
  def testBatchReduceSparse(self, num_processes, required_gpus,
                            implementation, reduce_op,
                            prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    options = self.RunOptions(
        mode=["func_graph"],  # Sparse reduce is not supported in eager.
        num_processes=num_processes,
        gpus_per_process=required_gpus,
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)

    inputs_data = ([
        IndexedSlicesValue(
            values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[9.], [10.]], indices=[3, 4], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[11.], [12.]], indices=[3, 4], dense_shape=[5, 1])
    ], [
        IndexedSlicesValue(
            values=[[13.], [14.]], indices=[8, 9], dense_shape=[10, 1]),
        IndexedSlicesValue(
            values=[[15.], [16.]], indices=[3, 4], dense_shape=[5, 1])
    ])
    inputs = inputs_data[0:group_size]

    if group_size == 1:
      expect = [
          IndexedSlices(
              values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
      ]
    elif group_size == 2:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.]],
              indices=[0, 1, 1, 2],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.]],
              indices=[1, 2, 0, 1],
              dense_shape=[5, 1])
      ]
    elif group_size == 4:
      expect = [
          IndexedSlices(
              values=[[1.], [2.], [5.], [6.], [9.], [10.], [13.], [14.]],
              indices=[0, 1, 1, 2, 3, 4, 8, 9],
              dense_shape=[10, 1]),
          IndexedSlices(
              values=[[3.], [4.], [7.], [8.], [11.], [12.], [15.], [16.]],
              indices=[1, 2, 0, 1, 3, 4, 3, 4],
              dense_shape=[5, 1])
      ]
    self.batch_reduce_and_verify(inputs, expect, options)

  def testBatchReduceMixedDenseAndSparse(self):

    options = self.RunOptions(
        num_processes=2,
        gpus_per_process=0,
        reduce_op=ReduceOp.SUM,
        mode=["func_graph"])

    inputs_data = [
        [
            1.0, 2.0,
            IndexedSlicesValue(
                values=[[1.], [2.]], indices=[0, 1], dense_shape=[10, 1]),
            IndexedSlicesValue(
                values=[[3.], [4.]], indices=[1, 2], dense_shape=[5, 1])
        ],
        [
            3.0, 4.0,
            IndexedSlicesValue(
                values=[[5.], [6.]], indices=[1, 2], dense_shape=[10, 1]),
            IndexedSlicesValue(
                values=[[7.], [8.]], indices=[0, 1], dense_shape=[5, 1])
        ],
    ]

    expect = [
        4.0, 6.0,
        IndexedSlices(
            values=[[1.], [2.], [5.], [6.]],
            indices=[0, 1, 1, 2],
            dense_shape=[10, 1]),
        IndexedSlices(
            values=[[3.], [4.], [7.], [8.]],
            indices=[1, 2, 0, 1],
            dense_shape=[5, 1])
    ]

    self.batch_reduce_and_verify(inputs_data, expect, options)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
      ))
  def testAllReduceDense(self, num_processes, required_gpus, implementation,
                         reduce_op):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip NCCL combination with mismatched process and GPU " 642 "count. NCCL requires physical GPUs for every process.") 643 if (num_processes != required_gpus and 644 implementation == CommunicationImplementation.AUTO): 645 self.skipTest("Skip potential NCCL combination (AUTO) with mismatched " 646 "process and GPU count. NCCL requires physical GPUs for " 647 "every process.") 648 649 def replica_fn(): 650 collective, devices, _ = self.make_collective(num_processes, 651 required_gpus) 652 options = collective_util.Options(implementation=implementation) 653 group_size = num_processes * (required_gpus or 1) 654 655 @def_function.function 656 def collective_all_reduce(): 657 results = [] 658 for replica_id, device in enumerate(devices): 659 with ops.device(device): 660 value = constant_op.constant(1.0) 661 results.append( 662 collective._all_reduce(reduce_op, value, replica_id, options)) 663 return results 664 665 got = collective_all_reduce() 666 if reduce_op == ReduceOp.SUM: 667 expect = [1.0 * group_size] * len(devices) 668 elif reduce_op == ReduceOp.MEAN: 669 expect = [1.0] * len(devices) 670 self.assertAllClose(got, expect) 671 672 @def_function.function 673 def collective_batch_all_reduce(): 674 results = [] 675 for replica_id, device in enumerate(devices): 676 with ops.device(device): 677 value = (constant_op.constant(1.0), constant_op.constant(2.0)) 678 results.append( 679 collective._all_reduce(reduce_op, value, replica_id, options)) 680 return results 681 682 got = collective_batch_all_reduce() 683 if reduce_op == ReduceOp.SUM: 684 expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices) 685 elif reduce_op == ReduceOp.MEAN: 686 expect = [(1.0, 2.0)] * len(devices) 687 self.assertAllClose(got, expect) 688 689 get_global_mpr(num_processes).run(replica_fn) 690 691 @combinations.generate( 692 combinations.combine( 693 num_processes=[1, 2], 694 required_gpus=[0, 1, 2], 695 implementation=[ 696 CommunicationImplementation.AUTO, 697 CommunicationImplementation.RING, 698 CommunicationImplementation.NCCL, 699 ], 700 reduce_op=[ReduceOp.SUM, ReduceOp.MEAN], 701 )) 702 def testAllReduceSparse(self, num_processes, required_gpus, implementation, 703 reduce_op): 704 if (required_gpus == 0 and 705 implementation == CommunicationImplementation.NCCL): 706 self.skipTest("Skip CPU + NCCL combination") 707 if (num_processes != required_gpus and 708 implementation == CommunicationImplementation.NCCL): 709 self.skipTest("Skip NCCL combination with mismatched process and GPU " 710 "count. NCCL requires physical GPUs for every process.") 711 if (num_processes != required_gpus and 712 implementation == CommunicationImplementation.AUTO): 713 self.skipTest("Skip potential NCCL combination (AUTO) with mismatched " 714 "process and GPU count. 

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = IndexedSlices(
                values=array_ops.identity([[1.]]),
                indices=array_ops.identity([0]),
                dense_shape=array_ops.identity([5, 1]))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [IndexedSlices([[1. * group_size]], [0], [5, 1])
                 ] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [IndexedSlices([[1.]], [0], [5, 1])] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (IndexedSlices(
                array_ops.identity([[1.]]), array_ops.identity([0]),
                array_ops.identity([5, 1])),
                     IndexedSlices(
                         array_ops.identity([[3.]]), array_ops.identity([2]),
                         array_ops.identity([5, 1])))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      if reduce_op == ReduceOp.SUM:
        expect = [(IndexedSlices([[1. * group_size]], [0], [5, 1]),
                   IndexedSlices([[3. * group_size]], [2], [5, 1]))
                 ] * len(devices)
      elif reduce_op == ReduceOp.MEAN:
        expect = [(IndexedSlices([[1.]], [0], [5, 1]),
                   IndexedSlices([[3.]], [2], [5, 1]))] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=0,
          implementation=CommunicationImplementation.AUTO,
          reduce_op=ReduceOp.SUM))
  def testAllReduceMixedDenseAndSparse(self, num_processes, required_gpus,
                                       implementation, reduce_op):

    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      group_size = num_processes * (required_gpus or 1)

      @def_function.function
      def collective_batch_all_reduce():
        results = []
        for replica_id, device in enumerate(devices):
          with ops.device(device):
            value = (IndexedSlices(
                array_ops.identity([[1.]]), array_ops.identity([0]),
                array_ops.identity([5, 1])), array_ops.identity(1.0),
                     IndexedSlices(
                         array_ops.identity([[3.]]), array_ops.identity([2]),
                         array_ops.identity([5, 1])), array_ops.identity(2.0))
            results.append(
                collective._all_reduce(reduce_op, value, replica_id, options))
        return results

      got = collective_batch_all_reduce()
      expect = [
          (IndexedSlices([[1. * group_size]], [0], [5, 1]), 1.0 * group_size,
           IndexedSlices([[3. * group_size]], [2], [5, 1]), 2.0 * group_size)
      ] * len(devices)
      self.assertAllClose(
          nest.map_structure(ops.convert_to_tensor, got),
          nest.map_structure(ops.convert_to_tensor, expect))

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          axis=[0, 1, 2],
          func_mode=["eager", "func_graph"],
          implementation=[
              CommunicationImplementation.AUTO,
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testAllGatherSameShape(self, num_processes, required_gpus,
                             implementation, func_mode, axis,
                             prefer_unique_instance_key):

    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = constant_op.constant([[[1, 2], [1, 2]]], dtype=dtypes.float32)

      def gather_fn():
        per_replica_value = make_per_replica_value(value, devices)
        gathered_values = collective._gather(
            per_replica_value, per_replica_value, axis=axis, options=options)
        gathered_values = self.as_list(gathered_values)
        # Skip checking devices in eager. In eager the device attribute doesn't
        # reflect the actual device of the tensor.
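        # (Inside a tf.function, executing_eagerly() is False, so the device
        # check below still runs in the "func_graph" mode.)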
        if not context.executing_eagerly():
          self.assertAllEqual(devices, [v.device for v in gathered_values])
        return [ops.convert_to_tensor(v) for v in gathered_values]

      group_size = num_processes * (required_gpus or 1)
      expect = array_ops.concat([value] * group_size, axis=axis)
      per_replica_expect = [ops.convert_to_tensor(expect)] * len(devices)

      if func_mode == "eager":
        result = gather_fn()
        self.assertAllClose(result, per_replica_expect)

      if func_mode == "func_graph":
        result = def_function.function(gather_fn)()
        self.assertAllClose(result, per_replica_expect)

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=[1, 2],
          required_gpus=[0, 1, 2],
          implementation=[CommunicationImplementation.RING]))
  def testCollectiveV2ControlFlow(self, num_processes, required_gpus,
                                  implementation):

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
      value = make_per_replica_value(constant_op.constant([1.]), devices)

      @def_function.function
      def reduce_fn():

        def cond_body():
          reduced = collective.reduce(reduce_util.ReduceOp.SUM, value, value,
                                      options)
          return math_ops.add_n(self.as_list(reduced)) / len(devices)

        return control_flow_ops.cond(
            array_ops.identity(False), cond_body, cond_body)

      num_replicas = num_processes * len(devices)
      self.assertAllEqual(reduce_fn(), [1. * num_replicas])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testMultiThreadedCollectiveLaunchNoInterleave(
      self, num_processes, required_gpus, implementation,
      prefer_unique_instance_key):

    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      # We would like to simulate the following sequence:
      #   thread-0  device0  device1
      #   thread-1  device0  device1
      # If the kernel launch sequence is as-is the program will deadlock since
      # NCCL requires the launch order to be the same on each device.
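      # With the delays injected below, device0 would otherwise launch
      # thread-0's collective first while device1 launches thread-1's first;
      # the test passes only if the implementation keeps the launch order
      # consistent across devices.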
      v0 = make_per_replica_value(1.0, devices)
      v1 = make_per_replica_value(2.0, devices)

      # Add a delay to collective_ops.all_reduce according to the input
      # tensor's index in `sequence`.
      sequence = [v0.values[0], v1.values[0], v1.values[1], v0.values[1]]
      all_reduce = collective_ops.all_reduce

      def delayed_all_reduce(input_tensor, *args, **kwargs):
        for idx, v in enumerate(sequence):
          if input_tensor is v:
            time.sleep(idx)
            break
        return all_reduce(input_tensor, *args, **kwargs)

      with test.mock.patch.object(collective_ops, "all_reduce",
                                  delayed_all_reduce):
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.

        def thread_fn():
          reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                            [(v0, v0), (v0, v0)], options)
          self.assertAllEqual(reduced[0].values, [2.0, 2.0])
          self.assertAllEqual(reduced[1].values, [2.0, 2.0])

        t = threading.Thread(target=thread_fn)
        t.start()
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                          [(v1, v1), (v1, v1)], options)
        self.assertAllEqual(reduced[0].values, [4.0, 4.0])
        self.assertAllEqual(reduced[1].values, [4.0, 4.0])
        t.join()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=1,
          required_gpus=2,
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testInputsAreFunctionArgs(self, num_processes, required_gpus,
                                implementation, prefer_unique_instance_key):

    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)

      @def_function.function
      def reduce_fn(v):
        # Function inputs don't have device placement.
        self.assertEqual(v.values[0].device, "")
        self.assertEqual(v.values[1].device, "")
        # We only use NCCL for batch reduce with two or more values, so we use
        # two values here.
        reduced = collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                          [(v, v), (v, v)], options)
        self.assertEqual(reduced[0].values[0].device, devices[0])
        self.assertEqual(reduced[0].values[1].device, devices[1])
        self.assertEqual(reduced[1].values[0].device, devices[0])
        self.assertEqual(reduced[1].values[1].device, devices[1])
        # Returning Mirrored only evaluates the primary value, which causes
        # hanging.
        return [reduced[0].values, reduced[1].values]

      v = make_per_replica_value(1.0, devices)
      reduced = reduce_fn(v)
      self.assertAllClose(reduced, [[2.0, 2.0], [2.0, 2.0]])

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceDense(self, num_processes, implementation,
                             required_gpus, prefer_unique_instance_key):

    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def reduce_dense():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceDense(self, num_processes, implementation,
                                  required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(1.0, devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def batch_reduce_dense():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_dense()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceSparse(self, num_processes, implementation,
                              required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def reduce_sparse():
        return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
          implementation=[
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
                                   implementation, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")
    if (num_processes != required_gpus and
        implementation == CommunicationImplementation.AUTO):
      self.skipTest("Skip potential NCCL combination (AUTO) with mismatched "
                    "process and GPU count. NCCL requires physical GPUs for "
                    "every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = (
          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
        return

      v = make_per_replica_value(
          IndexedSlicesValue(
              values=[[4., 6.]], indices=[1], dense_shape=[5, 2]), devices)
      options = collective_util.Options(
          timeout_seconds=1., implementation=implementation)

      @def_function.function
      def batch_reduce_sparse():
        return collective.batch_reduce(reduce_util.ReduceOp.SUM,
                                       [(v, v), (v, v)], options)

      # The collective should time out because we only launch it on worker-0,
      # while there are two workers in total.
      with self.assertRaises(errors.DeadlineExceededError):
        batch_reduce_sparse()

    get_global_mpr(num_processes).run(replica_fn)

  @combinations.generate(
      combinations.combine(num_processes=1, required_gpus=2))
  def testNcclOrdering(self, num_processes, required_gpus):

    if num_processes != required_gpus:
      self.skipTest("Skip NCCL combination with mismatched process and GPU "
                    "count. NCCL requires physical GPUs for every process.")

    def replica_fn():
      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      CollectiveReplicaLauncher._prefer_ordering_token = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(
          implementation=CommunicationImplementation.NCCL)

      v_dense = make_per_replica_value([1.0, 1.0], devices)
      v_sparse = make_per_replica_value([
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
      ], devices)

      @def_function.function
      def nested_dense():
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)

      @def_function.function
      def nested_sparse():
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)

      # All collectives, function calls, if clauses and while loops should be
      # chained by control dependencies, so that the execution order is
      # deterministic.
      @def_function.function
      def f():
        # pylint: disable=pointless-statement
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reducing dense value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        # reducing sparse value.
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # reduce dense value in nested tf.function.
        nested_dense()
        # reduce sparse value in nested tf.function.
        nested_sparse()
        # reduce dense value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        else:
          v_dense
        # reduce sparse value in tf.cond.
        if array_ops.identity(1.0) > array_ops.identity(2.0):
          v_sparse
        else:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
        # reduce dense value in tf.while_loop.
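        # (The loop variable below is a tensor, so autograph lowers these
        # Python loops to `While` ops, which the graph check below expects.)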
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
          i += 1
        # reduce sparse value in tf.while_loop.
        i = array_ops.identity(1)
        while i < 3:
          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
                            options)
          i += 1
        # reducing dense and sparse value again.
        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
        # pylint: enable=pointless-statement

      graph = f.get_concrete_function().graph
      should_be_ordered = set([
          "CollectiveReduceV2", "CollectiveGatherV2", "If", "While",
          "StatefulPartitionedCall"
      ])
      nodes_by_device = {}
      for op in graph.get_operations():
        if op.type in should_be_ordered:
          if op.device not in nodes_by_device:
            nodes_by_device[op.device] = []
          nodes_by_device[op.device].append(op)
      order = test_util.topological_sort_operations(graph.get_operations())
      for device in devices:
        device = device_util.canonicalize(device)
        # Those function ops don't have device annotations, but they contain
        # collectives for both devices, so we always include them.
        operations = nodes_by_device[device] + nodes_by_device[""]
        # Verify that we get all types of nodes we want.
        self.assertEqual(set(op.type for op in operations), should_be_ordered)
        test_util.assert_sequential_execution(order, operations)

    get_global_mpr(num_processes).run(replica_fn)


if __name__ == "__main__":
  # Set the default inter-op thread pool size to one to ensure we don't exhaust
  # the thread pool with the additional executors used to run collectives in
  # eager.
  os.environ["TF_NUM_INTEROP_THREADS"] = "1"
  # TODO(b/172304955): figure out why logical devices don't work.
  test_util.main(config_logical_devices=False)