# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for remote execution."""

import os
import random
import time

from absl.testing import parameterized
import numpy as np
import portpicker

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute.cluster_resolver.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import server_lib
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util import compat


class SingleWorkerTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(SingleWorkerTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(SingleWorkerTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultiDeviceFunctionBasic(self):
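    """Runs a function that places ops on the local host and the worker."""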

    @def_function.function
    def basic(i):
      with ops.device('/job:localhost/replica:0/task:0/cpu:0'):
        a = constant_op.constant([2]) + i
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        b = constant_op.constant([1])

      return a + b

    self.assertAllEqual(basic(constant_op.constant([2])).numpy(), [5])
    self.assertAllEqual(basic(constant_op.constant([1])).numpy(), [4])

  def testMultiDeviceFunctionVariable(self):
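    """Captures a variable created on the remote worker inside a function."""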
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    # Add a sync point to avoid the out-of-order issue of eager async execution
    # (b/155789951).
    context.async_wait()

    @def_function.function
    def with_variable(i):
      return i + variable_b

    self.assertAllEqual(with_variable(constant_op.constant([2])).numpy(), [3])

  def testMultiDeviceFunctionRemoteOutput(self):
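    """Checks the backing devices of local and remote function outputs."""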
    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:0/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    rets = remote_output(constant_op.constant([1]))
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)
    self.assertEqual(rets[0].backing_device,
                     '/job:localhost/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')

  def testStreaming(self):
    """A mini stress test for streaming - issuing many RPCs back to back."""
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 2])
      y = array_ops.zeros([2, 2])
      num_iters = 200
      for _ in range(num_iters):
        y = x + y
        # Ask for y's shape after every 10 additions on average.
        # This exercises waiting for remote shape logic in TensorHandle.
        if random.randint(1, 10) == 1:
          _ = y.shape
    np.testing.assert_array_equal(
        [[num_iters, num_iters], [num_iters, num_iters]], y.numpy())

  def testTwoExecutors(self):
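    """Checks that an error on one executor does not leak into another."""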
    # Run an op on the main executor, which by default uses StreamingEnqueue to
    # schedule the op onto the remote async executor. The op produces an error
    # (division by zero), but the error is not caught immediately because of
    # streaming enqueue.
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      a = constant_op.constant(3)
      b = constant_op.constant(0)
      math_ops.div(a, b)

    # Run another op using a second executor that disables streaming enqueue,
    # which runs the op on the tf_compute thread pool in the remote worker.
    # Since the op does not run on the same remote async executor, it does not
    # carry back the error produced by the op above, even though this op
    # executes synchronously.
    with context.executor_scope(
        executor.new_executor(
            enable_async=False, enable_streaming_enqueue=False)):
      with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
        c = constant_op.constant(4)
        d = constant_op.constant(2)
        self.assertEqual(math_ops.div(c, d).numpy(), 2)

    # Sync on the context to force catching the error produced by the first op.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      context.async_wait()
    self.assertIn('division by zero', cm.exception.message)

  def testShapeError_OpByOp(self):
    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      x = array_ops.ones([2, 3])
      y = array_ops.zeros([2, 2])
      with self.assertRaises(errors.InvalidArgumentError) as cm:
        math_ops.matmul(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.message)

  def testShapeError_Function(self):

    @def_function.function
    def matmul_func(x, y):
      return math_ops.matmul(x, y)

    x = array_ops.ones([2, 3])
    y = array_ops.zeros([2, 2])

    with ops.device('job:worker/replica:0/task:0/device:CPU:0'):
      with self.assertRaises(ValueError) as cm:
        matmul_func(x, y)

    self.assertIn('Dimensions must be equal', cm.exception.args[0])

  def testClientVariable(self):
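    """Reads a client-side variable from a function run on the worker."""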
    var = variables.Variable(initial_value=0)

    @def_function.function
    def func():
      with ops.device('/job:localhost/task:0'):
        read = var.read_value()
      return read + 1

    with ops.device('/job:worker/task:0'):
      self.assertAllEqual(func(), 1)

  def testRemoteCall(self):
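    """Dispatches a concrete function to the worker via remote_call."""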

    @def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def _remote_fn(x):
      return constant_op.constant(1) + x

    remote_fn = _remote_fn.get_concrete_function()

    @def_function.function
    def func(x):
      return functional_ops.remote_call(
          args=[x],
          Tout=[dtypes.int32],
          f=remote_fn,
          target='/job:worker/task:0')

    with ops.device('/job:localhost/task:0'):
      self.assertAllEqual(func(constant_op.constant(1)), [2])

  def testOperationTimeout(self):
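    """Expects a deadline error when dequeuing from an empty remote queue."""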
    context._reset_context()
    context.context().operation_timeout_in_ms = 10
    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    with self.assertRaises(errors.DeadlineExceededError):
      with ops.device('/job:worker/replica:0/task:0'):
        f()
      # If streaming RPC is enabled, fetch remote errors before end of execution
      context.async_wait()


class RemoteAsyncTest(test.TestCase):

  def setUp(self):
    super(RemoteAsyncTest, self).setUp()

    workers, _ = test_util.create_local_cluster(1, 0)
    remote.connect_to_remote_host(workers[0].target)

  def tearDown(self):
    super(RemoteAsyncTest, self).tearDown()

    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def test_out_of_range_with_while_loop(self):
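    """Drains a two-batch remote dataset in a while loop; v ends at 4.0."""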

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    while True:
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
      except (errors.OutOfRangeError, errors.InternalError):
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_for_loop(self):
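    """Drains a two-batch remote dataset with a for loop, syncing at the end."""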

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    for i in range(num_steps):
      try:
        with ops.device('/job:worker/task:0'):
          train_step(iterator)
        if i == num_steps - 1:
          context.async_wait()
      except errors.OutOfRangeError:
        context.async_clear_error()
        break

    self.assertAllEqual(v.numpy(), 4.0)

  def test_out_of_range_with_async_scope(self):
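    """Drains a two-batch remote dataset inside context.async_scope()."""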

    with ops.device('/job:worker/task:0'):
      dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0])
      dataset = dataset.batch(1, drop_remainder=False)
      iterator = iter(dataset)
      v = variables.Variable(1.0)

    @def_function.function
    def train_step(iterator):
      i = next(iterator)
      v.assign_add(math_ops.reduce_mean(i))

    num_steps = 3
    try:
      with context.async_scope():
        for _ in range(num_steps):
          with ops.device('/job:worker/task:0'):
            train_step(iterator)
    except errors.OutOfRangeError:
      context.async_clear_error()

    self.assertAllEqual(v.numpy(), 4.0)


class MultiWorkersTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiWorkersTest, self).setUp()

    workers, _ = test_util.create_local_cluster(3, 0)
    remote.connect_to_remote_host(
        [workers[0].target, workers[1].target, workers[2].target])

  def tearDown(self):
    super(MultiWorkersTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testReturnRemoteArgument(self):
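    """Passes a tensor created on task 0 to a function running on task 1."""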

    @def_function.function
    def local_func(i):
      return i

    with ops.device('/job:worker/replica:0/task:0'):
      x = constant_op.constant([2, 1])

    with ops.device('/job:worker/replica:0/task:1'):
      self.assertAllEqual(local_func(x), [2, 1])

  def testMultiDeviceFunctionAmbiguousDevice(self):

    @def_function.function
    def ambiguous_device(i):
      with ops.device('/job:worker'):
        # There are multiple worker tasks, so an ambiguous-device error will be
        # raised.
        return i + constant_op.constant([2])

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      ambiguous_device(constant_op.constant([2])).numpy()

    self.assertIn('the output node must match exactly one device',
                  cm.exception.message)

  # Note that the following remote function cancellation tests only work with
  # non-streaming RPC. We disable streaming explicitly and restore this config
  # to its initial value at the end of each test case.
  def testCancelRemoteFunctionBeforeExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    c_mgr.start_cancel()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testCancelRemoteFunctionDuringExecution(self):
    remote_async_env_var = 'TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE'
    default_streaming = os.environ.get(remote_async_env_var)
    os.environ[remote_async_env_var] = str(False)

    q = data_flow_ops.FIFOQueue(1, dtypes.int32)

    @def_function.function
    def f():
      return q.dequeue()

    c_mgr = cancellation.CancellationManager()
    cancelable_func = c_mgr.get_cancelable_function(f.get_concrete_function())

    def cancel_thread():
      time.sleep(0.5)
      c_mgr.start_cancel()

    t = self.checkedThread(cancel_thread)
    t.start()
    with self.assertRaises(errors.CancelledError):
      with ops.device('/job:worker/replica:0/task:1'):
        cancelable_func()
    t.join()

    if default_streaming is None:
      del os.environ[remote_async_env_var]
    else:
      os.environ[remote_async_env_var] = default_streaming

  def testMultiDeviceFunctionOnLocalDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionExecutionOrderingWithPackedInput(self):
    shape = [2]
    with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
      # Send 20 remote requests to simulate heavy load on worker:2.
      unused_values = []
      for _ in range(20):
        unused_values.append(array_ops.zeros(shape))
      func_input = array_ops.zeros(shape)

    packed_input = ops.pack_eager_tensors([func_input])

    @def_function.function
    def func(packed_input):
      # When worker:2 receives the component function request, packed_input
      # should be ready on worker:2.
      with ops.device('/job:worker/replica:0/task:2/device:CPU:0'):
        ret = packed_input + constant_op.constant(1.0)
      return ret + constant_op.constant(1.0)

    # Run the function on worker:1.
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      self.assertAllEqual(func(packed_input).numpy(),
                          array_ops.ones(shape).numpy() * 2)

  def testMultiDeviceFunctionWithPackedVariable(self):
    with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
      var0 = resource_variable_ops.ResourceVariable(1.0)
    with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
      var1 = resource_variable_ops.ResourceVariable(2.0)

    packed_var = ops.pack_eager_tensors([var0.handle, var1.handle])
    self.assertEqual(packed_var.device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')
    self.assertEqual(packed_var.backing_device,
                     '/job:localhost/replica:0/task:0/device:COMPOSITE:0')

    @def_function.function
    def add_variables():
      with ops.device('/job:worker/replica:0/task:0/device:CPU:0'):
        read0 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)
      with ops.device('/job:worker/replica:0/task:1/device:CPU:0'):
        read1 = resource_variable_ops.read_variable_op(
            packed_var, dtype=dtypes.float32)

      return read0 + read1

    # Run the function on a remote device
    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(add_variables().numpy(), 3.0)

    # Run the function on a local worker
    self.assertAllEqual(add_variables().numpy(), 3.0)

  def testMultiDeviceFunctionOnRemoteDeviceWithWait(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable([1.0])

    @def_function.function
    def remote_function(i):
      x = array_ops.ones([1000, 1000])
      for _ in range(1, 1000):
        x = x * x
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    @def_function.function
    def remote_function2(i):
      variable_b.assign_add(i)
      a = 1.0 + variable_b
      return a

    # Runs the first function:
    # - on a remote device
    # - needs remote input
    # - has side effects
    # - runs much slower
    with ops.device('/job:worker/replica:0/task:0'):
      remote_function(constant_op.constant([2.0]))

    # Runs the second function:
    # - on a remote device
    # - has side effects
    # There should be a sync point here, so the next function is executed only
    # after the first function has completed.
    with ops.device('/job:worker/replica:0/task:2'):
      self.assertAllEqual(remote_function2(constant_op.constant([3.0])), [7.0])

  def testMultiDeviceFunctionOnRemoteDevice(self):
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):
      with ops.device('/job:worker/replica:0/task:0'):
        a = i + variable_b
      c = a + 1.0
      return c

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testMultiDeviceFunctionRemoteOutput(self):
    with ops.device('/job:worker/replica:0/task:1/cpu:0'):
      variable_b = variables.Variable(1)

    @def_function.function
    def remote_output(i):
      with ops.device('/job:worker/replica:0/task:1/cpu:0'):
        c = variable_b + 1
      return i + variable_b, c

    with ops.device('/job:worker/replica:0/task:0/cpu:0'):
      rets = remote_output(constant_op.constant([1]))
    self.assertEqual(rets[0].backing_device,
                     '/job:worker/replica:0/task:0/device:CPU:0')
    self.assertEqual(rets[1].backing_device,
                     '/job:worker/replica:0/task:1/device:CPU:0')
    self.assertAllEqual(rets[0].numpy(), [2])
    self.assertAllEqual(rets[1].numpy(), 2)

  def testMultiDeviceWhileLoopOnRemoteDevice(self):
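    """Runs a while loop whose body reads a variable on a different task."""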
    with ops.device('/job:worker/replica:0/task:1'):
      variable_b = variables.Variable(1.0)

    @def_function.function
    def remote_function(i):

      def body(i, _):
        with ops.device('/job:worker/replica:0/task:0'):
          a = i + variable_b
        return a + 1.0, 1

      return control_flow_ops.while_loop_v2(lambda _, d: d < 1, body, [i, 0])[0]

    with ops.device('/job:worker/replica:0/task:0'):
      self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

    if test_util.is_gpu_available():
      with ops.device('/job:worker/replica:0/task:0/device:GPU:0'):
        self.assertAllEqual(remote_function(constant_op.constant([1.0])), [3.0])

  def testSimpleParameterServer(self):
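    """Uses task 2 as a parameter server for functions run on tasks 0 and 1."""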

    with ops.device('/job:worker/task:2/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)


_GRPC_PREFIX = 'grpc://'


class MultiJobsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(MultiJobsTest, self).setUp()

    workers, ps = test_util.create_local_cluster(num_workers=2, num_ps=2)
    cluster = {
        'my_worker': [_strip_prefix(t.target, _GRPC_PREFIX) for t in workers],
        'my_ps': [_strip_prefix(t.target, _GRPC_PREFIX) for t in ps],
    }
    self._cluster = server_lib.ClusterSpec(cluster)
    self._cluster_resolver = SimpleClusterResolver(
        cluster_spec=self._cluster, master=ps[0].target)

  def tearDown(self):
    super(MultiJobsTest, self).tearDown()

    # Clear the current device scope to avoid polluting other test cases.
    ops.device(None).__enter__()
    # Reset the context to avoid polluting other test cases.
    context._reset_context()

  def testMultipleDeviceFoundCheck(self):
    remote.connect_to_cluster(self._cluster)

    @def_function.function
    def func():
      with ops.device('cpu:0'):
        # Multiple matching CPU:0 devices would be found, but the CPU:0 from
        # the parent device scope should be picked.
        x = test_ops.device_placement_op()
        y = string_ops.string_upper(x)
        packed_var_0 = array_ops.stack([x, y], 0)
        return packed_var_0

    with ops.device('/job:my_worker/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_worker/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_WORKER'), output[1])
    with ops.device('/job:my_ps/task:1'):
      output = self.evaluate(func())
      self.assertEqual(
          compat.as_bytes('/job:my_ps/replica:0/task:1/device:CPU:0'),
          output[0])
      self.assertIn(compat.as_bytes('/JOB:MY_PS'), output[1])

  def testSimpleParameterServer(self):
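    """Puts variables on my_ps and updates them from both my_worker tasks."""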
    remote.connect_to_cluster(self._cluster)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  def testResetClusterWithDifferentJobNames(self):
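    """Reconnects to a cluster whose job name differs from the previous one."""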
    addr = 'localhost:%s' % portpicker.pick_unused_port()
    cluster = server_lib.ClusterSpec({'localhost': [addr]})
    remote.connect_to_cluster(cluster, job_name='localhost')
    with ops.device('/job:localhost/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
      v1.assign_add(1)

    # Change the job name in the cluster from 'localhost' to 'worker'.
    addr = 'localhost:%s' % portpicker.pick_unused_port()
    cluster = server_lib.ClusterSpec({'worker': [addr]})
    remote.connect_to_cluster(cluster, job_name='worker')

    with ops.device('/job:worker/task:0/device:CPU:0'):
      v2 = variables.Variable(initial_value=0)
      v2.assign_add(1)

  # TODO(b/152224115): Re-enable this test.
  def DISABLED_testSimpleParameterServerWithDeviceFilters(self):
    cluster_device_filters = server_lib.ClusterDeviceFilters()
    for i in range(2):
      cluster_device_filters.set_device_filters('my_worker', i, ['/job:my_ps'])
      cluster_device_filters.set_device_filters('my_ps', i, ['/job:my_worker'])
    remote.connect_to_cluster(
        self._cluster, cluster_device_filters=cluster_device_filters)

    with ops.device('/job:my_ps/task:0/device:CPU:0'):
      v1 = variables.Variable(initial_value=0)
    with ops.device('/job:my_ps/task:1/device:CPU:0'):
      v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

    # The following remote call would fail because the ps nodes cannot see each
    # other due to the device filters.
    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:0/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:1/device:CPU:0 unknown device',
                  cm.exception.message)

    with self.assertRaises(errors.InvalidArgumentError) as cm:
      with ops.device('/job:my_ps/task:1/device:CPU:0'):
        worker_fn().numpy()
    self.assertIn('/job:my_ps/replica:0/task:0/device:CPU:0 unknown device',
                  cm.exception.message)

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 7)
    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 6)
    # Explicitly delete the variables to avoid triggering errors when they are
    # garbage collected in subsequent tests.
    del v1, v2

  def testConnectWithClusterResolver(self):
    remote.connect_to_cluster(self._cluster_resolver)

    v1 = variables.Variable(initial_value=0)
    v2 = variables.Variable(initial_value=10)

    @def_function.function
    def worker_fn():
      v1.assign_add(1)
      v2.assign_sub(2)
      return v1.read_value() + v2.read_value()

    with ops.device('/job:my_worker/task:0/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 9)

    with ops.device('/job:my_worker/task:1/device:CPU:0'):
      self.assertAllEqual(worker_fn(), 8)

  def testConnectToClusterTwiceOk(self):
    remote.connect_to_cluster(self._cluster_resolver)
    remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterOnMismatchedDevice(self):
    remote.connect_to_cluster(self._cluster_resolver)

    # enter into another device scope.
    ops.device('/job:my_worker/task:0/device:CPU:0').__enter__()

    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)

  def testConnectToClusterWithLocalMaster(self):
    local_resolver = SimpleClusterResolver(ClusterSpec({}), master='local')
    remote.connect_to_cluster(local_resolver)

  def testConnectToClusterInGraphModeWillFail(self):
    ops.disable_eager_execution()
    with self.assertRaises(ValueError):
      remote.connect_to_cluster(self._cluster_resolver)
    ops.enable_eager_execution()

  def testConnectToClusterWithoutLocalGpu(self):
    # Only remote workers have GPU devices
    context.context().set_visible_devices([], 'GPU')
    # Ensure that no default device is set in eager context
    remote.connect_to_cluster(self._cluster_resolver,
                              make_master_device_default=False)
    self.assertEmpty(context.get_device_name())

    v1 = variables.Variable(initial_value=0)
    v1.assign_add(1)
    self.assertAllEqual(v1.read_value(), 1)


def _strip_prefix(s, prefix):
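  """Returns `s` without a leading `prefix`, if present."""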
  return s[len(prefix):] if s.startswith(prefix) else s


if __name__ == '__main__':
  test.main()