# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""

import os
import random
import time

from absl.testing import parameterized
import numpy as np
import portpicker

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util import compat


class SingleWorkerTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(SingleWorkerTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultiDeviceFunctionBasic(self):

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])

      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

  def testMultiDeviceFunctionVariable(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    # Add a sync point to avoid the out-of-order issue of eager async
    # execution (b/155789951).
    context.async_wait()

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    rets = remote_output(constant_op.constant([1]))
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)
    self.assertEqual(rets[0].backing_device,
                     '/job:localhost/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Ask for y's shape after every 10 additions on average.
        # This exercises waiting for remote shape logic in TensorHandle.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())

  def testTwoExecutors(self):
    # Run an op on the main executor, which by default uses StreamingEnqueue
    # to schedule the op on the remote async executor. This op produces an
    # error (division by zero), but the error will not be caught immediately
    # because of streaming enqueue.
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      a = constant_op.constant(3)
      b = constant_op.constant(0)
      math_ops.div(a, b)

    # Run another op using another executor that disables streaming enqueue,
    # which runs the op on the tf_compute thread pool in the remote worker.
    # Since this op does not run on the same remote async executor, it will
    # not carry back the error produced by the op above, even though it is
    # executed synchronously.
    with context.executor_scope(
        executor.new_executor(
            enable_async=False, enable_streaming_enqueue=False)):
      with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
        c = constant_op.constant(4)
        d = constant_op.constant(2)
        self.assertEqual(math_ops.div(c, d).numpy(), 2)

    # Sync on the context to force catching the error produced by the first
    # op.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      context.async_wait()
    self.assertIn('division by zero', cm.exception.message)

  def testShapeError_OpByOp(self):
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 3])
      y = array_ops.zeros([2, 2])
      with self.assertRaises(errors.InvalidArgumentError) as cm:
        math_ops.matmul(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.message)

  def testShapeError_Function(self):

    @def_function.function
    def matmul_func(x, y):
      return math_ops.matmul(x, y)

    x = array_ops.ones([2, 3])
    y = array_ops.zeros([2, 2])

    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      with self.assertRaises(ValueError) as cm:
        matmul_func(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.args[0])

  def testClientVariable(self):
    var = variables.Variable(initial_value=0)

    @def_function.function
    def func():
      with ops.device('/job:localhost/task:0'):
        read = var.read_value()
      return read + 1

    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

  def testRemoteCall(self):

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def _remote_fn(x):
      return constant_op.constant(1) + x

    remote_fn = _remote_fn.get_concrete_function()

    @def_function.function
    def func(x):
      return functional_ops.remote_call(
          args=[x],
          Tout=[dtypes.int32],
          f=remote_fn,
          target='/job:worker/task:0')

    with ops.device('/job:localhost/task:0'):
      self.assertAllEqual(func(constant_op.constant(1)), [2])

  def testOperationTimeout(self):
    context._reset_context()
    context.context().operation_timeout_in_ms = 10
    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    with self.assertRaises(errors.DeadlineExceededError):
      with ops.device('/job:worker/replica:0/task:0'):
        f()
      # If streaming RPC is enabled, fetch remote errors before the end of
      # execution.
      context.async_wait()


class RemoteAsyncTest(test.TestCase):

  def setUp(self):
    super(RemoteAsyncTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(RemoteAsyncTest, self).tearDown()

    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def test_out_of_range_with_while_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    while True:
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
      except (errors.OutOfRangeError, errors.InternalError):
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_for_loop(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    for i in range(num_steps):
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
        if i == num_steps - 1:
          context.async_wait()
      except errors.OutOfRangeError:
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_async_scope(self):

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    try:
      with context.async_scope():
        for _ in range(num_steps):
          with ops.device('/job:worker/task:0'):
            train_step(iterator)
    except errors.OutOfRangeError:
      context.async_clear_error()

    self.assertAllEqual(v.numpy(), 4.0)


class MultiWorkersTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])

  def tearDown(self):
    super(MultiWorkersTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testReturnRemoteArgument(self):

    @def_function.function
    def local_func(i):
      return i

    with ops.device('/job:worker/replica:0/task:0'):
      x = constant_op.constant([2, 1])

    with ops.device('/job:worker/replica:0/task:1'):
      self.assertAllEqual(local_func(x), [2, 1])

  def testMultiDeviceFunctionAmbiguousDevice(self):

    @def_function.function
    def ambiguous_device(i):
      with ops.device('/job:worker'):
        # With multiple worker tasks, an ambiguous-device error will be
        # raised.
        return i + constant_op.constant([2])

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      ambiguous_device(constant_op.constant([2])).numpy()

    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  # Note that the following remote function cancellation tests only work when
  # streaming RPC is disabled. We need to disable streaming explicitly and
  # restore this config to its initial value at the end of each test case.
  def testCancelRemoteFunctionBeforeExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    c_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testCancelRemoteFunctionDuringExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    def cancel_thread():
      time.sleep(0.5)
      c_mgr.start_cancel()

    t = self.checkedThread(cancel_thread)
    t.start()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()
    t.join()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionExecutionOrderingWithPackedInput(self):
    shape = [2]
    with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
      # Send 20 remote requests to simulate heavy load on worker:2.
      unused_values = []
      for _ in range(20):
        unused_values.append(array_ops.zeros(shape))
      func_input = array_ops.zeros(shape)

    packed_input = ops.pack_eager_tensors([func_input])

    @def_function.function
    def func(packed_input):
      # When worker:2 receives the component function request, packed_input
      # should be ready on worker:2.
      with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
        ret = packed_input + constant_op.constant(1.0)
      return ret + constant_op.constant(1.0)

    # Run the function on worker:1.
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      self.assertAllEqual(func(packed_input).numpy(),
                          array_ops.ones(shape).numpy() * 2)

  def testMultiDeviceFunctionWithPackedVariable(self):
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      var0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      var1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
    self.assertEqual(packed_var.device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
    self.assertEqual(packed_var.backing_device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')

    @def_function.function
    def add_variables():
      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)
      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)

      return read0 + read1

    # Run the function on a remote device.
    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(add_variables().numpy(), 3.0)

    # Run the function on the local worker.
    self.assertAllEqual(add_variables().numpy(), 3.0)

  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])

    @def_function.function
    def remote_function(i):
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    @def_function.function
    def remote_function2(i):
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    # Runs the first function:
    # - on a remote device
    # - needs remote input
    # - has side effects
    # - runs much slower
    with ops.device('/job:worker/replica:0/task:0'):
      remote_function(constant_op.constant([2.0]))

    # Runs the second function:
    # - on a remote device
    # - has side effects
    # There should be a sync point here, so the next function is executed
    # only after the first function has completed.
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      rets = remote_output(constant_op.constant([1]))
    self.assertEqual(rets[0].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:1/device:CPU:0')
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)

  def testMultiDeviceWhileLoopOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testSimpleParameterServer(self):

    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
    cluster = {
        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
    }
    self._cluster = server_lib.ClusterSpec(cluster)
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)

  def tearDown(self):
    super(MultiJobsTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultipleDeviceFoundCheck(self):
    remote.connect_to_cluster(self._cluster)

    @def_function.function
    def func():
      with ops.device('cpu:0'):
        # Multiple matching CPU:0 devices would be found, but the CPU:0 from
        # the parent device scope should be picked.
        x = test_ops.device_placement_op()
        y = string_ops.string_upper(x)
        packed_var_0 = array_ops.stack([x, y], 0)
        return packed_var_0

    with ops.device('/job:my_worker/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
    with ops.device('/job:my_ps/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])

  def testSimpleParameterServer(self):
    remote.connect_to_cluster(self._cluster)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  def testResetClusterWithDifferentJobNames(self):
    addr = 'localhost:%s' % portpicker.pick_unused_port()
    cluster = server_lib.ClusterSpec({'localhost': [addr]})
    remote.connect_to_cluster(cluster, job_name='localhost')
    with ops.device('/job:localhost/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v1.assign_add(1)

    # Replace the job name 'localhost' with 'worker' in the cluster.
    addr = 'localhost:%s' % portpicker.pick_unused_port()
    cluster = server_lib.ClusterSpec({'worker': [addr]})
    remote.connect_to_cluster(cluster, job_name='worker')

    with ops.device('/job:worker/task:0/device:CPU:0'):
      v2 = variables.Variable(initial_value=0)
      v2.assign_add(1)

  # TODO(b/152224115): Re-enable this test.
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
    remote.connect_to_cluster(
        self._cluster, cluster_device_filters=cluster_device_filters)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

    # The following remote calls would fail because the ps nodes cannot see
    # each other due to the device filters.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:0/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                  cm.exception.message)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:1/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                  cm.exception.message)

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 7)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 6)
    # Explicitly delete variables to avoid triggering errors when they are
    # garbage collected in subsequent tests.
    del v1, v2

  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

    v1 = variables.Variable(initial_value=0)
    v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

    # Enter another device scope.
    ops.device('/job:my_worker/task:0/device:CPU:0').__enter__()

    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterWithLocalMaster(self):
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)
    ops.enable_eager_execution()

  def testConnectToClusterWithoutLocalGpu(self):
    # Only remote workers have GPU devices.
    context.context().set_visible_devices([], 'GPU')
    # Ensure that no default device is set in the eager context.
    remote.connect_to_cluster(self._cluster_resolver,
                              make_master_device_default=False)
    self.assertEmpty(context.get_device_name())

    v1 = variables.Variable(initial_value=0)
    v1.assign_add(1)
    self.assertAllEqual(v1.read_value(), 1)


def _strip_prefix(s, prefix):
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()