# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for device function for replicated training."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_setter
from tensorflow.python.training import server_lib


def _sun_moon_cluster_spec():
  """Returns a fresh cluster with a 3-task "sun" job and a 2-task "moon" job."""
  return server_lib.ClusterSpec({
      "sun": ["sun0:2222", "sun1:2222", "sun2:2222"],
      "moon": ["moon0:2222", "moon1:2222"]
  })


class DeviceSetterTest(test.TestCase):
  """Checks the devices `replica_device_setter` assigns to variables and ops."""

  # Shared cluster: two parameter-server tasks and three worker tasks.
  _cluster_spec = server_lib.ClusterSpec({
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]
  })

  def _assertVariableDevice(self, expected_device, var):
    """Asserts that `var` and its initializer op both sit on `expected_device`."""
    self.assertDeviceEqual(expected_device, var.device)
    self.assertDeviceEqual(expected_device, var.initializer.device)

  @test_util.run_deprecated_v1
  def testCPUOverride(self):
    """An explicit "/cpu:0" scope is merged into the setter's placement."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      with ops.device("/cpu:0"):
        v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      with ops.device("/cpu:0"):
        a = v + w
      self._assertVariableDevice("/job:ps/task:0/cpu:0", v)
      self._assertVariableDevice("/job:ps/task:1", w)
      self.assertDeviceEqual("/job:worker/cpu:0", a.device)

  @test_util.run_deprecated_v1
  def testResource(self):
    """A ResourceVariable is assigned to the first PS task."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      v = resource_variable_ops.ResourceVariable([1, 2])
      self.assertDeviceEqual("/job:ps/task:0", v.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecClass(self):
    """With a ClusterSpec, consecutive variables go to PS tasks 0 then 1."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:ps/task:1", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksPinVariableToJob(self):
    """A variable pinned to a non-PS job keeps that job, with no task added."""
    setter = device_setter.replica_device_setter(cluster=self._cluster_spec)
    with ops.device(setter):
      v = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        w = variables.Variable([2, 1])
      with ops.device("/job:ps"):  # Explicit PS job will get task set.
        x = variables.Variable([0, 1])
      a = v + w + x
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:moon", w)
      self._assertVariableDevice("/job:ps/task:1", x)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksUseCpuForPS(self):
    """With ps_device="/cpu:0" and no cluster, variables land on the CPU."""
    setter = device_setter.replica_device_setter(
        ps_tasks=1, ps_device="/cpu:0")
    with ops.device(setter):
      v = variables.Variable([1, 2])
      with ops.device("/job:moon"):
        w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/cpu:0", v)
      self._assertVariableDevice("/job:moon/cpu:0", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksNoMerging(self):
    """With merge_devices=False, an explicit "/job:ps" scope gets no task."""
    setter = device_setter.replica_device_setter(
        cluster=self._cluster_spec, merge_devices=False)
    with ops.device(setter):
      v = variables.Variable([1, 2])
      with ops.device("/job:ps"):  # Won't assign task when merge_devices=False.
        w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:ps", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterSpecDict(self):
    """The cluster may also be passed as a plain dict."""
    cluster_dict = self._cluster_spec.as_dict()
    setter = device_setter.replica_device_setter(cluster=cluster_dict)
    with ops.device(setter):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:ps/task:1", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithClusterDef(self):
    """The cluster may also be passed as a ClusterDef proto."""
    cluster_def = self._cluster_spec.as_cluster_def()
    setter = device_setter.replica_device_setter(cluster=cluster_def)
    with ops.device(setter):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:ps/task:0", v)
      self._assertVariableDevice("/job:ps/task:1", w)
      self.assertDeviceEqual("/job:worker", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithDevice(self):
    """Custom ps_device/worker_device jobs are used instead of ps/worker."""
    cluster_spec = _sun_moon_cluster_spec()

    setter = device_setter.replica_device_setter(
        ps_device="/job:moon",
        worker_device="/job:sun",
        cluster=cluster_spec.as_cluster_def())
    with ops.device(setter):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:moon/task:0", v)
      self._assertVariableDevice("/job:moon/task:1", w)
      self.assertDeviceEqual("/job:sun", a.device)

  @test_util.run_deprecated_v1
  def testPS2TasksWithCPUConstraint(self):
    """A "/cpu:0" part of ps_device is kept next to the assigned task."""
    cluster_spec = _sun_moon_cluster_spec()

    setter = device_setter.replica_device_setter(
        ps_device="/job:moon/cpu:0",
        worker_device="/job:sun",
        cluster=cluster_spec.as_cluster_def())
    with ops.device(setter):
      v = variables.Variable([1, 2])
      w = variables.Variable([2, 1])
      a = v + w
      self._assertVariableDevice("/job:moon/task:0/cpu:0", v)
      self._assertVariableDevice("/job:moon/task:1/cpu:0", w)
      self.assertDeviceEqual("/job:sun", a.device)


if __name__ == "__main__":
  test.main()