# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""

from absl import logging
from absl.testing import parameterized

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_hardware_feature
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest


FLAGS = flags.FLAGS
flags.DEFINE_string("tpu", "", "Name of TPU to connect to.")
flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver


def get_tpu_strategy(enable_packed_var=False):
  resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  strategy = tpu_lib.TPUStrategyV2(resolver)
  strategy._enable_packed_variable_in_eager_mode = enable_packed_var
  return strategy


# TPU tests which don't use TPUStrategy.
@test_util.with_eager_op_as_function
class TPUTest(test.TestCase):

  # In this case, the entire computation in foo is compiled using JIT
  # compilation.
  def test_single_tpu_jit_compile(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      b = x + get_a_plus_one()
      b = b + get_a_plus_one()
      return b + 1

    with ops.device("/device:TPU:0"):
      result = foo(a)
    self.assertAllEqual(6, result)

  # In this case, the entire computation in foo is compiled using JIT
  # compilation and contains unsupported ops that should be outside compiled.
  def test_single_tpu_jit_compile_with_outside_compilation(self):
    context.enable_jit_compile_rewrite()
    get_tpu_strategy(True)
    config.set_soft_device_placement(True)
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      b = x + get_a_plus_one()
      my_str = string_ops.as_string(b)
      new_str = my_str + "0"
      c = string_ops.string_to_number(new_str, out_type=dtypes.int32)
      logging_ops.print_v2(c)
      b = c + get_a_plus_one()
      return b + 1

    with ops.device("/device:TPU:0"):
      result = foo(a)
    self.assertAllEqual(33, result)

  # In this case, each of the ops in the TPU device scope is compiled and run
  # individually.
  def test_single_tpu_on_demand(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    x = 1
    with ops.device("/device:TPU:0"):
      b = x + get_a_plus_one()
      b = b + get_a_plus_one()
      result = b + 1

    self.assertAllEqual(6, result)

  # In this case, each of the ops in the tf.function and TPU device scope is
  # compiled and run individually.
  def test_single_tpu_on_demand_tf_function(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
        b = b + get_a_plus_one()
        return b + 1

    result = foo(a)
    self.assertAllEqual(6, result)

  def test_multiple_initialize_system(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)

    with test.mock.patch.object(logging, "warning") as mock_log:
      tpu_strategy_util.initialize_tpu_system(resolver)
      self.assertRegex(str(mock_log.call_args), "already been initialized")

  def test_tpu_tf_function_same_device(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(1)

    @function.defun_with_attributes(attributes={"_noinline": True})
    def get_a_plus_one():
      return a + 1

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def foo(x):
      with ops.device("/device:TPU:0"):
        b = x + get_a_plus_one()
      return b + 1

    result = foo(a)
    self.assertAllEqual(4, result)

  def test_tpu_return_int32(self):
    with ops.device("/device:TPU:0"):
      a = variables.Variable(0)

    @def_function.function
    def foo():
      return a + 1

    @def_function.function
    def bar():
      with ops.device("/device:TPU:1"):
        return foo()

    with ops.device("/device:CPU:0"):
      result = bar() + 1
      self.assertAllEqual(result, 2)

  def test_tpu_output_device(self):

    def foo():
      return 1 + 1

    func1 = function.defun_with_attributes(
        foo, attributes={"_XlaMustCompile": False})
    func2 = function.defun_with_attributes(
        foo, attributes={
            "_OutputsOnOpDevice": True,
            "_XlaMustCompile": False
        })

    with ops.device("/device:TPU:0"):
      ret1 = func1()
      ret2 = func2()

    self.assertAllEqual(ret1.backing_device,
                        "/job:localhost/replica:0/task:0/device:CPU:0")
    self.assertAllEqual(ret2.backing_device,
                        "/job:localhost/replica:0/task:0/device:TPU:0")

  def test_on_demand_op_with_dynamic_output(self):
    with ops.device("/device:TPU:0"):
      where_output = array_ops.where([True, False, True])
    self.assertAllEqual(where_output, [[0], [2]])

    with ops.device("/device:TPU:0"):
      repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
    self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])


@parameterized.named_parameters([("PackedVar", True), ("", False)])
@test_util.with_eager_op_as_function
class TPUStrategyTest(test.TestCase, parameterized.TestCase):

  def test_handle_in_cross_replica_context(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(1.0)

    @def_function.function
    def func():
      self.assertEndsWith(v.handle.device, "device:TPU:0")
      return v + 1.0

    ret = func()
    self.assertAllEqual(ret, 2.0)

  def testStaticHashTableDatasetFnHostTrainingLoop(self, enable_packed_var):
    self._dataset_fn_tracing_count = 0
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      vals = [0, 1, 2]
      keys_tensor = constant_op.constant(
          list(range(len(vals))), dtype=dtypes.int64)
      vals_tensor = constant_op.constant(vals)
      initializer = lookup_ops.KeyValueTensorInitializer(
          keys_tensor, vals_tensor)
      per_worker_table = lookup_ops.StaticHashTable(
          initializer, default_value=-1)

    @def_function.function
    def dataset_fn(input_context):
      tensor = constant_op.constant([0, 1, 3], dtype=dtypes.int64)
      global_batch_size = 2
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
          batch_size, drop_remainder=True)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      dataset = dataset.prefetch(2)  # This prefetches 2 batches per device.
      dataset = dataset.map(per_worker_table.lookup)
      self._dataset_fn_tracing_count += 1
      return dataset

    dist_iterator = iter(
        strategy.experimental_distribute_datasets_from_function(dataset_fn))

    @def_function.function
    def step_fn(inputs):
      # inputs should be [0, 1, -1]
      return math_ops.reduce_sum(inputs)

    def train_steps(iterator, steps):

      for _ in math_ops.range(steps):
        strategy.run(step_fn, args=(next(iterator),))

    train_steps(dist_iterator, steps=5)
    self.assertEqual(self._dataset_fn_tracing_count, 1)

  def test_function_compile_with_xla(self, enable_packed_var):
    if FLAGS.tpu_use_tfrt:
      self.skipTest(
          "This test triggers _XlaCompile and XlaLaunch which are not "
          "supported in tfrt yet. We should avoid using these kernels on TPU. "
          "However, it is a workaround to support b/129842431. We need more "
          "discussion about how to support it in the long term.")
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      v = variables.Variable(1.0)

    @def_function.function
    def func():
      return v.read_value() + 1.0

    with ops.device("/device:TPU:0"):
      self.assertAllEqual(func(), 2.0)

  def test_sequential_runs(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    # Computation replicated to all cores.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=2)
    strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    # Computation on the 1st core.
    device_assignment2 = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    strategy2 = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)

    def computation(x):
      return math_ops.square(x)

    @def_function.function
    def train_step():
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([2., 2.],)))
      outputs2 = strategy2.run(
          computation, args=([outputs[0]],))
      return outputs2

    self.assertAllEqual([[16., 16.]], train_step())

  def test_device_switch_case(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      a = variables.Variable(1)

    inference_iteration = variables.Variable(-1)

    def inference_fn(x, i):
      return a + x + i

    @def_function.function
    def run_inference(x):

      def do_inference(device, inference_fn, i):
        with ops.device(device):
          return inference_fn(x, i)

      branch_fns = {
          0: (lambda: do_inference("/device:TPU:0", inference_fn, 0)),
          1: (lambda: do_inference("/device:TPU:1", inference_fn, 1)),
      }
      branch_index = inference_iteration.assign_add(1, use_locking=True) % 2
      return control_flow_ops.switch_case(branch_index, branch_fns)

    self.assertAllEqual(2., run_inference(1))  # Use TPU core 0.
    self.assertAllEqual(3., run_inference(1))  # Use TPU core 1.

  def test_recover_from_compilation_failures(self, enable_packed_var):
    # TODO(b/148150981): Stop skipping this test once recovery works
    # for non-local TPU.
    if FLAGS.tpu:
      self.skipTest("Recovery fails for non-local TPU, see b/148150981")

    # Disable automatic outside compilation.
    config.set_soft_device_placement(False)
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def compilation_failure_run():

      def computation():
        return random_ops.random_gamma([10], [0.5, 1.5])

      return strategy.run(computation)

    with self.assertRaises(errors.OpError):
      compilation_failure_run()

    @def_function.function
    def good_run():

      def computation():
        return random_ops.random_normal([10])

      return strategy.run(computation)

    good_run()

  def test_dynamic_shape_with_outside_compilation_failure(
      self, enable_packed_var):
    # Enable automatic outside compilation.
    config.set_soft_device_placement(True)
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.from_tensors(("string", 1.0)).repeat().batch(
        2, drop_remainder=False)
    dataset = strategy.experimental_distribute_dataset(dataset)
    iterator = iter(dataset)

    @def_function.function
    def train_fn(iterator):

      def step_fn(inputs):
        input0, input1 = inputs
        return array_ops.size(input0), math_ops.reduce_sum(input1)

      return strategy.experimental_local_results(
          strategy.run(step_fn, args=(next(iterator),)))

    with self.assertRaises(errors.InvalidArgumentError):
      logging.info(train_fn(iterator))

  def test_computation_on_subset_cores(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    all_core_strategy = tpu_lib.TPUStrategyV2(resolver)
    all_core_strategy._enable_packed_variable_in_eager_mode = enable_packed_var

    with all_core_strategy.scope():
      v = variables.Variable(0.0,
                             aggregation=variables.VariableAggregation.MEAN)

    # Computation on the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Computation on the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    @def_function.function
    def train_step():

      def step_fn():
        return v + 1.0

      all_core_strategy.run(step_fn)
      r1 = first_core_strategy.run(step_fn)
      r2 = second_core_strategy.run(step_fn)
      return r1 + r2

    train_step()
    self.assertAllEqual(2., train_step())

  def test_worker_devices_on_subset_cores(self, enable_packed_var):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)

    # Strategy for the 1st core.
    device_assignment = device_assignment_lib.DeviceAssignment.build(
        topology, num_replicas=1)
    first_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment)
    first_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    # Strategy for the 2nd core.
    device_assignment2 = device_assignment_lib.DeviceAssignment(
        topology, [[[0, 0, 0, 1]]])
    second_core_strategy = tpu_lib.TPUStrategyV2(
        resolver, experimental_device_assignment=device_assignment2)
    second_core_strategy._enable_packed_variable_in_eager_mode = (
        enable_packed_var)

    self.assertLen(first_core_strategy.extended.worker_devices, 1)
    self.assertEndsWith(first_core_strategy.extended.worker_devices[0],
                        "device:TPU:0")

    self.assertLen(second_core_strategy.extended.worker_devices, 1)
    self.assertEndsWith(second_core_strategy.extended.worker_devices[0],
                        "device:TPU:1")

  def test_control_output_in_while_body_fn(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      v = variables.Variable(
          0.0, aggregation=variables.VariableAggregation.MEAN)

    @def_function.function
    def train_step():

      def step_fn():
        v.assign_add(1)

      for _ in math_ops.range(2):
        strategy.run(step_fn)

    train_step()
    self.assertEqual(2.0, v.numpy())

  def test_cluster_conditional_with_dynamic_shape(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def shape_list(tensor):
        shape = tensor.shape.as_list()

        non_static_indexes = []
        for (index, dim) in enumerate(shape):
          if dim is None:
            non_static_indexes.append(index)

        if not non_static_indexes:
          return shape

        dynamic_shape = array_ops.shape(input=tensor)
        for index in non_static_indexes:
          shape[index] = dynamic_shape[index]

        return shape

      def step_fn(condition):
        where = array_ops.where(condition)
        if array_ops.shape(where)[0] > 0:
          tensor_shape = shape_list(where)
          d1 = tensor_shape[0]
          d2 = tensor_shape[1]
          where = array_ops.reshape(where, [d1, d2])
        return where

      return strategy.run(step_fn, args=([True, False, True],))

    outputs = strategy.experimental_local_results(train_step())
    self.assertAllEqual(outputs[0].numpy(), [[0], [2]])

  def test_cluster_in_graph_and_while_body_fn(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def step_fn(prev):
        s = prev + 1
        return s

      def init_fn():
        return array_ops.zeros(shape=())

      prev = strategy.run(init_fn)
      for _ in math_ops.range(10):
        prev = strategy.run(step_fn, args=(prev,))
      return strategy.reduce(reduce_util.ReduceOp.SUM, prev, axis=None)

    sum_val = train_step().numpy().astype(float)
    self.assertEqual(sum_val, strategy.num_replicas_in_sync * 10)

  def test_two_clusters_with_same_fn(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def foo(x):
      return strategy.run(lambda x: x + 1, (x,))

    @def_function.function
    def bar(x):
      foo(x)
      return foo(x)

    bar(1)

  def test_tpu_variable_run_argument(self, enable_packed_var):
    # TPUStrategy.run() casts inputs to Tensor, but has logic to preserve
    # variables to avoid unintuitive errors.
    # Here we test that a TPUDistributedVariable passed to TPUStrategy.run()
    # remains a variable.

    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      tpu_variable = variables.Variable(1)

    def replica_step(first_arg, variable):
      del first_arg  # Just here to make sure we're not relying on arg position.

      if variable is not None:
        self.assertIsInstance(variable, tpu_values.TPUDistributedVariable)

    @def_function.function
    def step():
      strategy.run(
          replica_step, args=(
              2,
              tpu_variable,
          ))

    step()

  def test_tpu_run_arg_parsing(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    with strategy.scope():
      tpu_vars = [variables.Variable(1)]

    def only_star_args(*args):
      del args

    def pos_and_star_args(first_arg, *args):
      del first_arg
      del args

    def named_args(first_arg, second_arg):
      del first_arg
      del second_arg

    def star_args_and_kw_only(*args, kw):
      del args
      del kw

    # pylint:disable=function-redefined
    @def_function.function
    def step():
      strategy.run(only_star_args, args=(2,))

    step()

    @def_function.function
    def step():
      strategy.run(named_args, kwargs={"first_arg": 2, "second_arg": 3})

    step()

    with self.assertRaisesRegex(TypeError, r"got multiple values for argument"):

      @def_function.function
      def step():
        strategy.run(
            named_args, args=(1,), kwargs={
                "first_arg": 2,
                "second_arg": 3
            })

      step()

    with self.assertRaisesRegex(ValueError,
                                r"cannot handle Variables passed to \*args"):

      @def_function.function
      def step():
        strategy.run(
            only_star_args, args=(
                2,
                tpu_vars,
            ))

      step()

    @def_function.function
    def step():
      strategy.run(pos_and_star_args, args=(2, 3, 4))

    step()

    @def_function.function
    def step():
      strategy.run(star_args_and_kw_only, args=(2, 3), kwargs={"kw": tpu_vars})

    step()

    with self.assertRaisesRegex(ValueError,
                                r"mix of positional args and \*args"):

      @def_function.function
      def step():
        strategy.run(pos_and_star_args, args=(tpu_vars, 3, 4))

      step()

    with self.assertRaisesRegex(ValueError, r"Too many positional arguments"):

      @def_function.function
      def step():
        strategy.run(named_args, args=(2, 3, 4))

      step()

    class DummyClass:

      @def_function.function
      def method(self, arg_1):
        del arg_1

      def step(self):
        strategy.run(self.method, args=(tpu_vars,))

    DummyClass().step()
    # pylint:enable=function-redefined

  def test_using_external_variable_inside_tf_function(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    v = variables.Variable(2.0)

    @def_function.function
    def train_step(data):
      def computation(inputs):
        return inputs + v
      return strategy.run(computation, args=(data,))

    expected_result = [
        [x + 2.] for x in range(0, strategy.num_replicas_in_sync)
    ]
    self.assertAllEqual(
        expected_result,
        strategy.experimental_local_results(train_step(next(input_iterator))))

  # TODO(b/145574622): Remove this test once it is re-enabled in values_test.py.
  def test_all_reduce_on_sync_on_read_variable(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync, output_type=dtypes.float32).batch(
            strategy.num_replicas_in_sync, drop_remainder=True)
    input_iterator = iter(strategy.experimental_distribute_dataset(dataset))

    with strategy.scope():
      w = variables.Variable(
          (0.,),
          shape=(1,),
          trainable=False,
          synchronization=variables.VariableSynchronization.ON_READ,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA)

    @def_function.function
    def run(iterator):

      def computation(x):
        w.assign(x + w)
        return w

      def all_reduce(x):
        ctx = distribution_strategy_context.get_replica_context()
        return ctx.all_reduce("SUM", w) + x

      outputs = strategy.run(computation, args=(next(iterator),))
      outputs2 = strategy.experimental_local_results(
          strategy.run(all_reduce, args=(outputs,)))
      return outputs2

    data = range(0, strategy.num_replicas_in_sync)
    data_sum = sum(data)
    expected_result = [
        [x + data_sum] for x in range(0, strategy.num_replicas_in_sync)
    ]
    self.assertAllEqual(expected_result, run(input_iterator))
    self.assertAllEqual((0.,), w.read_value())

  def test_run_output_on_device(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    def computation(x):
      return math_ops.square(x)

    @def_function.function
    def train_step():
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=(2,)))
      return outputs

    results = train_step()
    self.assertAllEqual([4., 4.], results)
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                        results[0].backing_device)
    self.assertAllEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                        results[1].backing_device)

  def test_run_passing_and_returning_nones(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def computation(x):
        return x

      # Note that this input None is nested.
      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([1, [2, None]],)))
      return outputs

    results = train_step()

    self.assertAllEqual(1, results[0][0])
    self.assertAllEqual(2, results[0][1][0])
    self.assertIsNone(results[0][1][1])

  def test_run_passing_and_returning_empty_list(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def computation(x):
        return x

      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=([],)))
      return outputs

    self.assertEqual([], train_step()[0])

  def test_run_passing_and_returning_empty_dict(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    @def_function.function
    def train_step():

      def computation(x):
        return x

      outputs = strategy.experimental_local_results(
          strategy.run(computation, args=({},)))
      return outputs

    self.assertEqual({}, train_step()[0])

  def test_composite_input_output(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    with strategy.scope():
      table = variables.Variable(
          initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        # Assumes dense_shape is (2, *)
        looked_up = array_ops.gather(table, sparse.values)
        segment_sum = math_ops.unsorted_segment_sum(
            looked_up, sparse.indices[:, 0], 2)
        return sparse, segment_sum

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(_):
        return sparse_tensor.SparseTensor(
            indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                       dtype=dtypes.int64),
            values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
            dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    sparse, result = sparse_lookup(dataset)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(sparse[replica], sparse_tensor.SparseTensor)
      self.assertAllEqual(sparse[replica].indices, [[0, 0], [1, 0], [1, 1]])
      self.assertAllEqual(sparse[replica].values, [0, 0, 1])
      self.assertAllEqual(sparse[replica].dense_shape, [2, 2])
      self.assertAllEqual(result[replica], [[0.0, 1.0], [3.0, 8.0]])

  def test_composite_input_non_flat_output(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    with strategy.scope():
      table = variables.Variable(
          initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        # Assumes dense_shape is (2, *)
        looked_up = array_ops.gather(table, sparse.values)
        segment_sum = math_ops.unsorted_segment_sum(
            looked_up, sparse.indices[:, 0], 2)
        return {"sparse": sparse, "segment_sum": segment_sum}

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(_):
        return sparse_tensor.SparseTensor(
            indices=array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                       dtype=dtypes.int64),
            values=array_ops.constant([0, 0, 1], dtype=dtypes.int32),
            dense_shape=array_ops.constant([2, 2], dtype=dtypes.int64))

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            distribute_lib.InputOptions(experimental_fetch_to_device=False)))

    output = sparse_lookup(dataset)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(output["sparse"][replica],
                            sparse_tensor.SparseTensor)
      self.assertAllEqual(output["sparse"][replica].indices,
                          [[0, 0], [1, 0], [1, 1]])
      self.assertAllEqual(output["sparse"][replica].values, [0, 0, 1])
      self.assertAllEqual(output["sparse"][replica].dense_shape, [2, 2])
      self.assertAllEqual(output["segment_sum"][replica],
                          [[0.0, 1.0], [3.0, 8.0]])

  def test_composite_input_dynamic_shapes_outside_compilation(
      self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    if strategy.num_replicas_in_sync != 2:
      self.skipTest("Test assumes two replicas.")

    table = variables.Variable(
        initial_value=[[0.0, 1.0], [3.0, 7.0]], dtype=dtypes.float32)

    @def_function.function
    def sparse_lookup(iterator):

      def tpu_function(sparse):
        lookup = tpu.outside_compilation(
            embedding_ops.safe_embedding_lookup_sparse, table, sparse)
        return math_ops.reduce_sum(lookup, axis=0)

      return strategy.experimental_local_results(
          strategy.run(tpu_function, args=(next(iterator),)))

    def dataset_fn(_):
      dataset = dataset_ops.Dataset.range(2)

      def make_sparse(i):
        indices = array_ops.constant([[0, 0], [1, 0], [1, 1]],
                                     dtype=dtypes.int64)[0:2 + i]
        values = array_ops.constant([0, 0, 1], dtype=dtypes.int32)[0:2 + i]
        shape = [
            array_ops.constant([2], dtype=dtypes.int64),
            array_ops.expand_dims(1 + i, axis=0)
        ]
        dense_shape = array_ops.concat(shape, axis=0)
        return sparse_tensor.SparseTensor(
            indices=indices, values=values, dense_shape=dense_shape)

      return dataset.map(make_sparse)

    dataset = iter(
        strategy.distribute_datasets_from_function(
            dataset_fn,
            options=distribute_lib.InputOptions(
                experimental_fetch_to_device=False)))

    result = sparse_lookup(dataset)
    self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])

  def test_composite_input_with_non_flat_components(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)

    class TestCompositeTypeSpec(type_spec.TypeSpec):

      def __init__(self, component_type_spec):
        self._component_type_spec = component_type_spec

      @property
      def value_type(self):
        return TestComposite

      def _to_components(self, value):
        return value.values

      def _from_components(self, components):
        return TestComposite(components[0], components[1][0], components[1][1])

      @property
      def _component_specs(self):
        return [self._component_type_spec,
                [self._component_type_spec, self._component_type_spec]]

      def _serialize(self):
        return (self._component_type_spec,)

    class TestComposite(composite_tensor.CompositeTensor):

      def __init__(self, value1, value2, value3):
        self.values = [value1, [value2, value3]]

      @property
      def _type_spec(self):
        return TestCompositeTypeSpec(
            tensor_spec.TensorSpec.from_tensor(self.values[0]))

      def _shape_invariant_to_type_spec(self, shape):
        return [shape, [shape, shape]]

    @def_function.function
    def test_fn(test_composite):

      def tpu_function(composite):
        return (composite,
                composite.values[0] + (
                    composite.values[1][0] + composite.values[1][1])/2)

      return nest.map_structure(
          strategy.experimental_local_results,
          strategy.run(tpu_function, args=(test_composite,)))

    a = array_ops.constant([0.1])
    b = array_ops.constant([1.2])
    c = array_ops.constant([-0.4])
    test_composite = TestComposite(a, b, c)

    composite, result = test_fn(test_composite)

    # All replicas return identical results.
    for replica in range(strategy.num_replicas_in_sync):
      self.assertIsInstance(composite[replica], TestComposite)
      self.assertAllEqual(composite[replica].values[0], a)
      self.assertAllEqual(composite[replica].values[1][0], b)
      self.assertAllEqual(composite[replica].values[1][1], c)
      self.assertAllEqual(result[replica], array_ops.constant([0.50000006]))

  def test_per_device_tracing_of_mirrored_variables(self, enable_packed_var):
    # Define trace_count as a list to avoid Python scoping errors.
    trace_count = [0]

    strategy = get_tpu_strategy(enable_packed_var)
    with strategy.scope():
      variable = variables.Variable(0.0)

    @def_function.function
    def add_one():
      trace_count[0] = trace_count[0] + 1
      return math_ops.add(variable, constant_op.constant(1.0))

    @def_function.function
    def update_variable():
      for device in set(strategy.extended.worker_devices):
        with ops.device(device):
          add_one()

    with strategy.scope():
      update_variable.get_concrete_function()
      self.assertLen(strategy.extended.worker_devices, trace_count[0])

  def test_tpu_cancellation_does_not_close_chips(self, enable_packed_var):
    if not FLAGS.tpu_use_tfrt:
      self.skipTest(
          "`tpu_cancellation_closes_chip` only applies to TFRT TPU Runtime.")
    strategy = get_tpu_strategy(enable_packed_var)
    num_replicas = strategy.num_replicas_in_sync
    with strategy.scope():
      x = random_ops.random_normal((10240, 10240))
      y = random_ops.random_normal((10240, 10240))

      v = variables.Variable(array_ops.identity(x))
      dist_dataset = strategy.experimental_distribute_dataset(
          dataset_ops.Dataset.from_tensors(y).repeat(num_replicas).batch(
              num_replicas))
      dist_iterator = iter(dist_dataset)

      @def_function.function
      def train_steps(v, iterator, steps):

        def step_fn(inputs):
          for val in inputs:
            v.assign(math_ops.matmul(v, val))

        for _ in math_ops.range(steps):
          strategy.run(step_fn, args=(next(iterator),))

      with self.assertRaises(errors.OutOfRangeError):
        # The iterator has num_replicas/num_replicas = 1 step only.
        train_steps(v, dist_iterator, 2)

      # If TPU chips are not closed we can run the function on TPU again.
      w = variables.Variable(array_ops.identity(x))
      dist_dataset = strategy.experimental_distribute_dataset(
          dataset_ops.Dataset.from_tensors(y).repeat(num_replicas).batch(
              num_replicas))
      dist_iterator = iter(dist_dataset)
      train_steps(w, dist_iterator, 1)

  def test_tpu_hardware_feature(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    self.assertIsInstance(
        strategy.extended.tpu_hardware_feature.embedding_feature,
        tpu_hardware_feature.HardwareFeature.EmbeddingFeature)

  def test_get_tpu_cluster_resolver(self, enable_packed_var):
    strategy = get_tpu_strategy(enable_packed_var)
    self.assertIsNotNone(strategy.cluster_resolver)


@test_util.with_eager_op_as_function
class TPUStrategyDataPrefetchTest(test.TestCase):

  def test_prefetch_to_device_default(self):
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    # Check default, should prefetch to TPU.
    dataset_item = next(iter(strategy.experimental_distribute_dataset(dataset)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "TPU")

  def test_prefetch_to_device_tpu(self):
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=True)
    dataset_item = next(iter(strategy.experimental_distribute_dataset(
        dataset, options=input_options)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "TPU")

  def test_prefetch_to_device_cpu(self):
    strategy = get_tpu_strategy()
    dataset = dataset_ops.Dataset.range(
        strategy.num_replicas_in_sync * 2,
        output_type=dtypes.float32).batch(strategy.num_replicas_in_sync)

    # Should be CPU when prefetch_to_device is False.
    input_options = distribute_lib.InputOptions(
        experimental_fetch_to_device=False)
    dataset_item = next(iter(strategy.experimental_distribute_dataset(
        dataset, options=input_options)))
    dataset_location = tf_device.DeviceSpec.from_string(
        dataset_item.values[0].device)
    self.assertEqual(dataset_location.device_type, "CPU")

  def test_prefetch_to_device_sparse_dataset(self):
    strategy = get_tpu_strategy()
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                   values=[1, 2, 3],
                                   dense_shape=[2, 2]))
    dataset = dataset.repeat()
    dataset = dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_ragged_dataset(self):
    strategy = get_tpu_strategy()
    # Values here aren't important.
    dataset = dataset_ops.Dataset.from_tensors(
        ragged_tensor.RaggedTensor.from_row_splits(
            values=[1, 2, 3],
            row_splits=[0, 2, 3]))
    dataset = dataset.repeat()
    dataset = dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.experimental_distribute_dataset(dataset))

  def test_prefetch_to_device_sparse_dataset_fn(self):
    strategy = get_tpu_strategy()
    def dataset_fn(ctx):
      del ctx
      # Values here aren't important.
      dataset = dataset_ops.Dataset.from_tensors(
          sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                                     values=[1, 2, 3],
                                     dense_shape=[2, 2]))
      dataset = dataset.repeat()
      return dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_prefetch_to_device_ragged_dataset_fn(self):
    strategy = get_tpu_strategy()
    def dataset_fn(ctx):
      del ctx
      # Values here aren't important.
      dataset = dataset_ops.Dataset.from_tensors(
          ragged_tensor.RaggedTensor.from_row_splits(
              values=[1, 2, 3],
              row_splits=[0, 2, 3]))
      dataset = dataset.repeat()
      return dataset.batch(strategy.num_replicas_in_sync)

    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
      iter(strategy.distribute_datasets_from_function(dataset_fn))

  def test_create_iterator_on_device(self):

    @def_function.function
    def create_iter():
      with ops.device("/device:TPU:0"):
        return gen_dataset_ops.anonymous_iterator_v3(
            output_types=[dtypes.float32], output_shapes=[[]])

    create_iter()


@test_util.with_eager_op_as_function
class TPUStrategyDistributionTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase):

  def test_update_config_proto(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    strategy = tpu_lib.TPUStrategyV2(resolver)

    config_proto = config_pb2.ConfigProto()
    cluster_spec = server_lib.ClusterSpec({"worker": ["fake1", "fake2"]})
    with test.mock.patch.object(
        resolver, "cluster_spec", return_value=cluster_spec):
      new_config = strategy.update_config_proto(config_proto)

    # Verify cluster_def.
    self.assertProtoEquals(cluster_spec.as_cluster_def(),
                           new_config.cluster_def)

    # Verify isolate_session_state.
    self.assertTrue(new_config.isolate_session_state)

  def test_make_input_fn_iterable(self):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    distribution = get_tpu_strategy()
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    self._test_input_fn_iterable(distribution, input_fn, expected_values)

  def test_make_input_fn_iterator(self):
    dataset_fn = lambda: dataset_ops.Dataset.range(10)
    expected_values = [[i, i+1] for i in range(0, 10, 2)]
    distribution = get_tpu_strategy()
    input_fn = self._input_fn_to_test_input_context(
        dataset_fn,
        expected_num_replicas_in_sync=2,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)
    iterator = distribution.make_input_fn_iterator(input_fn)
    self._test_input_fn_iterator(
        iterator,
        distribution.extended.worker_devices,
        expected_values)

  def test_num_replicas_in_sync(self):
    strategy = get_tpu_strategy()
    self.assertEqual(2, strategy.num_replicas_in_sync)

  def test_call_and_merge_exceptions(self):
    strategy = get_tpu_strategy()
    self._test_call_and_merge_exceptions(strategy)

  def test_numpy_dataset(self):
    strategy = get_tpu_strategy()
    self._test_numpy_dataset(strategy, run_in_function=True)

  def test_global_step_update(self):
    strategy = get_tpu_strategy()
    self._test_global_step_update(strategy)

  def test_run(self):
    strategy = get_tpu_strategy()
    self._test_run(strategy, run_in_function=True)

  def test_summary_for_replica_zero_only(self):
    strategy = get_tpu_strategy()
    self._test_summary_for_replica_zero_only(strategy)

  def test_all_reduce_sum(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum(strategy, run_in_function=True)

  def test_all_reduce_sum_gradients(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradients(strategy, run_in_function=True)

  def test_all_reduce_sum_gradient_tape(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_sum_gradient_tape(strategy, run_in_function=True)

  def test_all_reduce_mean(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean(strategy, run_in_function=True)

  def test_all_reduce_mean_gradients(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradients(strategy, run_in_function=True)

  def test_all_reduce_mean_gradient_tape(self):
    strategy = get_tpu_strategy()
    self._test_all_reduce_mean_gradient_tape(strategy, run_in_function=True)

  def test_reduce(self):
    strategy = get_tpu_strategy()

    inputs = strategy.make_input_fn_iterator(
        lambda _: dataset_ops.Dataset.from_tensor_slices([2., 3.]))

    self.evaluate(inputs.initialize())
    per_replica_outputs = strategy.run(
        def_function.function(math_ops.square), args=(next(inputs),))

    with strategy.scope():
      mean = strategy.reduce(reduce_util.ReduceOp.MEAN, per_replica_outputs,
                             axis=None)
    self.assertEqual(6.5, self.evaluate(mean))

  def test_constraint(self):
    strategy = get_tpu_strategy()

    with strategy.scope():
      variable = variables.Variable(initial_value=2.,
                                    constraint=lambda x: 0. * x + 1.)
      self.assertEqual(variable.value().numpy(), 2)

      @def_function.function
      def update_variable():
        variable.assign_add(1)
        variable.assign(variable.constraint(variable))

      update_variable()
      self.assertEqual(variable.value().numpy(), 1)

  def test_trainable_variables(self):
    strategy = get_tpu_strategy()
    self._test_trainable_variable(strategy)


@test_util.with_eager_op_as_function
class DeviceAssignmentTest(test.TestCase):

  def test_core_assignment(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
    self.assertEqual(1, device_assignment.num_cores_per_replica)
    self.assertEqual(1, device_assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", device_assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", device_assignment.host_device())

  def test_device_assignment_strategy_properties(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 0]]])
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(strategy.extended.num_hosts, 1)
    self.assertEqual(strategy.num_replicas_in_sync, 1)
    self.assertEqual(strategy.extended.num_replicas_per_host, 1)  # pylint: disable=protected-access

  def test_device_assignment_constants(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
    self.assertAllEqual([[[0, 0, 0, 0]]], device_assignment.core_assignment)
    self.assertEqual(1, device_assignment.num_cores_per_replica)
    self.assertEqual(1, device_assignment.num_replicas)
    self.assertEqual("/task:0/device:TPU:0", device_assignment.tpu_device())
    self.assertEqual("/task:0/device:CPU:0", device_assignment.host_device())

  def test_variables_mismatched_device_assignment(self):
    resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(resolver)
    topology = tpu_strategy_util.initialize_tpu_system(resolver)

    strategy0 = tpu_lib.TPUStrategyV2(resolver)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:0",
         "/job:localhost/replica:0/task:0/device:TPU:1"),
        strategy0.extended.worker_devices)

    with strategy0.scope():
      v = variables.Variable(1.)

    v1_assign_op = strategy0.experimental_local_results(v)[1].assign(42.)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(v1_assign_op)
      self.assertAllEqual([1., 42.],
                          self.evaluate(
                              strategy0.experimental_local_results(v)))

    # Second strategy has devices reversed relative to the first.
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology, core_assignment=[[[0, 0, 0, 1]], [[0, 0, 0, 0]]])
    strategy1 = tpu_lib.TPUStrategyV2(
        resolver,
        experimental_device_assignment=device_assignment)
    self.assertEqual(
        ("/job:localhost/replica:0/task:0/device:TPU:1",
         "/job:localhost/replica:0/task:0/device:TPU:0"),
        strategy1.extended.worker_devices)

    v_read = strategy1.run(def_function.function(v.read_value))

    with self.cached_session():
      self.assertAllEqual([42., 1.],
                          self.evaluate(
                              strategy0.experimental_local_results(v_read)))


if __name__ == "__main__":
  test.main()