1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for MirroredStrategy running against a remote single worker in eager mode."""
16
17from tensorflow.python.distribute import combinations
18from tensorflow.python.distribute import mirrored_strategy
19from tensorflow.python.distribute import multi_worker_test_base
20from tensorflow.python.distribute import strategy_test_lib
21from tensorflow.python.eager import context
22from tensorflow.python.eager import test
23
24
def get_gpus():
  """Returns the names of logical GPU devices whose name contains "job".

  Queries the current eager context for logical GPU devices and keeps only
  those whose device name includes the substring "job". Presumably this
  selects fully-qualified (job-scoped) device names such as the remote
  worker's GPUs; TODO confirm against the device naming used by the test
  base class.

  Returns:
    A list of device-name strings, in the order the context reports them.
  """
  logical_gpus = context.context().list_logical_devices("GPU")
  return [device.name for device in logical_gpus if "job" in device.name]
32
33
@combinations.generate(
    combinations.combine(
        distribution=[
            combinations.NamedDistribution(
                "Mirrored",
                # Wrapped in a lambda so get_gpus() runs only when the
                # framework builds the strategy, not at module import time.
                lambda: mirrored_strategy.MirroredStrategy(get_gpus()),
                required_gpus=1)
        ],
        mode=["eager"]))
class RemoteSingleWorkerMirroredStrategyEager(
    multi_worker_test_base.SingleWorkerTestBaseEager,
    strategy_test_lib.RemoteSingleWorkerMirroredStrategyBase):
  """Eager-mode MirroredStrategy tests against a remote single worker.

  Every test method receives the strategy built by the
  `combinations.generate` decorator above. It forwards that strategy to the
  matching `_test*` helper inherited from the base classes.
  """

  def _get_num_gpus(self):
    """Returns how many GPU device names `get_gpus()` reports.

    Presumably consulted by the base-class helpers to size their
    expectations; verify against `strategy_test_lib`.
    """
    gpu_names = get_gpus()
    return len(gpu_names)

  def testNumReplicasInSync(self, distribution):
    """Forwards to the inherited `_testNumReplicasInSync` helper."""
    self._testNumReplicasInSync(distribution)

  def testMinimizeLoss(self, distribution):
    """Forwards to the inherited `_testMinimizeLoss` helper."""
    self._testMinimizeLoss(distribution)

  def testDeviceScope(self, distribution):
    """Forwards to the inherited `_testDeviceScope` helper."""
    self._testDeviceScope(distribution)

  def testMakeInputFnIteratorWithDataset(self, distribution):
    """Forwards to the inherited `_testMakeInputFnIteratorWithDataset`."""
    self._testMakeInputFnIteratorWithDataset(distribution)

  def testMakeInputFnIteratorWithCallable(self, distribution):
    """Forwards to the inherited `_testMakeInputFnIteratorWithCallable`."""
    self._testMakeInputFnIteratorWithCallable(distribution)
65
66
if __name__ == "__main__":
  # Script entry point: hand control to TensorFlow's test runner
  # (tensorflow.python.eager.test.main) to run the tests in this module.
  test.main()
69