xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/device_setter_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for device function for replicated training."""
16
17from tensorflow.python.framework import ops
18from tensorflow.python.framework import test_util
19from tensorflow.python.ops import resource_variable_ops
20from tensorflow.python.ops import variables
21from tensorflow.python.platform import test
22from tensorflow.python.training import device_setter
23from tensorflow.python.training import server_lib
24
25
class DeviceSetterTest(test.TestCase):
  """Placement checks for `device_setter.replica_device_setter`.

  Every test creates variables (and an add op over them) inside an
  `ops.device` scope driven by the replica device setter, then compares the
  recorded device strings against the expected placement.
  """

  # Shared two-PS / three-worker cluster used by most of the tests below.
  _cluster_spec = server_lib.ClusterSpec({
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"],
  })

  def _assertVariablePlacement(self, expected_device, variable):
    """Checks the device of `variable` and of its `initializer`.

    Args:
      expected_device: Device string both placements are compared against.
      variable: Object exposing `.device` and `.initializer.device`.
    """
    self.assertDeviceEqual(expected_device, variable.device)
    self.assertDeviceEqual(expected_device, variable.initializer.device)

  @test_util.run_deprecated_v1
  def testCPUOverride(self):
    """An explicit "/cpu:0" scope is combined with the setter's job/task."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      with ops.device("/cpu:0"):
        cpu_var = variables.Variable([1, 2])
      plain_var = variables.Variable([2, 1])
      with ops.device("/cpu:0"):
        total = cpu_var + plain_var
      self._assertVariablePlacement("/job:ps/task:0/cpu:0", cpu_var)
      self._assertVariablePlacement("/job:ps/task:1", plain_var)
      self.assertDeviceEqual("/job:worker/cpu:0", total.device)

  @test_util.run_deprecated_v1
  def testResource(self):
    """A ResourceVariable is expected on the first PS task."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      resource_var = resource_variable_ops.ResourceVariable([1, 2])
      self.assertDeviceEqual("/job:ps/task:0", resource_var.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecClass(self):
    """Cluster given as a `ClusterSpec`: variables alternate over PS tasks."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
      self._assertVariablePlacement("/job:ps/task:0", first_var)
      self._assertVariablePlacement("/job:ps/task:1", second_var)
      self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksPinVariableToJob(self):
    """Variables pinned to a foreign job keep it; a "/job:ps" pin gets a task."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      auto_var = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        moon_var = variables.Variable([2, 1])
        # An explicit PS job still has its task filled in by the setter.
        with ops.device("/job:ps"):
          ps_var = variables.Variable([0, 1])
      total = auto_var + moon_var + ps_var
      self._assertVariablePlacement("/job:ps/task:0", auto_var)
      self._assertVariablePlacement("/job:moon", moon_var)
      self._assertVariablePlacement("/job:ps/task:1", ps_var)
      self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksUseCpuForPS(self):
    """With `ps_device="/cpu:0"` variables carry the cpu:0 constraint."""
    setter = device_setter.replica_device_setter(
        ps_tasks=1, ps_device="/cpu:0")
    with ops.device(setter):
      auto_var = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        moon_var = variables.Variable([2, 1])
      total = auto_var + moon_var
      self._assertVariablePlacement("/cpu:0", auto_var)
      self._assertVariablePlacement("/job:moon/cpu:0", moon_var)
      self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksNoMerging(self):
    """With `merge_devices=False` an explicit "/job:ps" scope gets no task."""
    setter = device_setter.replica_device_setter(
        cluster=self._cluster_spec, merge_devices=False)
    with ops.device(setter):
      auto_var = variables.Variable([1, 2])
      # No task is assigned here because merging is switched off.
      with ops.device("/job:ps"):
        pinned_var = variables.Variable([2, 1])
      total = auto_var + pinned_var
      self._assertVariablePlacement("/job:ps/task:0", auto_var)
      self._assertVariablePlacement("/job:ps", pinned_var)
      self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecDict(self):
    """Cluster given as the dict form of the shared `ClusterSpec`."""
    cluster_dict = self._cluster_spec.as_dict()
    setter = device_setter.replica_device_setter(cluster=cluster_dict)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
      self._assertVariablePlacement("/job:ps/task:0", first_var)
      self._assertVariablePlacement("/job:ps/task:1", second_var)
      self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterDef(self):
    """Cluster given as the `as_cluster_def()` form of the shared spec."""
    cluster_def = self._cluster_spec.as_cluster_def()
    setter = device_setter.replica_device_setter(cluster=cluster_def)
    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
      self._assertVariablePlacement("/job:ps/task:0", first_var)
      self._assertVariablePlacement("/job:ps/task:1", second_var)
      self.assertDeviceEqual("/job:worker", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithDevice(self):
    """Custom `ps_device`/`worker_device` job names are honoured."""
    sun_moon_cluster = server_lib.ClusterSpec({
        "sun": ["sun0:2222", "sun1:2222", "sun2:2222"],
        "moon": ["moon0:2222", "moon1:2222"],
    })
    setter = device_setter.replica_device_setter(
        ps_device="/job:moon",
        worker_device="/job:sun",
        cluster=sun_moon_cluster.as_cluster_def())

    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
      self._assertVariablePlacement("/job:moon/task:0", first_var)
      self._assertVariablePlacement("/job:moon/task:1", second_var)
      self.assertDeviceEqual("/job:sun", total.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithCPUConstraint(self):
    """A "/cpu:0" suffix on `ps_device` is kept alongside the assigned task."""
    sun_moon_cluster = server_lib.ClusterSpec({
        "sun": ["sun0:2222", "sun1:2222", "sun2:2222"],
        "moon": ["moon0:2222", "moon1:2222"],
    })
    setter = device_setter.replica_device_setter(
        ps_device="/job:moon/cpu:0",
        worker_device="/job:sun",
        cluster=sun_moon_cluster.as_cluster_def())

    with ops.device(setter):
      first_var = variables.Variable([1, 2])
      second_var = variables.Variable([2, 1])
      total = first_var + second_var
      self._assertVariablePlacement("/job:moon/task:0/cpu:0", first_var)
      self._assertVariablePlacement("/job:moon/task:1/cpu:0", second_var)
      self.assertDeviceEqual("/job:sun", total.device)
184
185
if __name__ == "__main__":
  # When executed as a script, hand control to `test.main()` from
  # tensorflow.python.platform so the test cases defined above are run.
  test.main()
188