# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device-related support functions."""



from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops


def canonicalize(d, default=None):
  """Canonicalize device string.

  If d has missing components, the rest would be deduced from the `default`
  argument or from '/replica:0/task:0/device:CPU:0'. For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Note: When executing eagerly outside functions, the built-in defaults are
  first merged with the first logical CPU device if that device names a job;
  otherwise "job:localhost" is used as the default job.

  Precedence, lowest to highest: built-in defaults, eager host defaults,
  `default`, then `d` itself.

  Args:
    d: a device string or tf.config.LogicalDevice
    default: a string for default device if d doesn't have all components.

  Returns:
    a canonicalized device string.

  Raises:
    AssertionError: if the device type parsed from `d` is not all-caps.
  """
  if isinstance(d, context.LogicalDevice):
    d = tf_device.DeviceSpec.from_string(d.name)
  else:
    d = tf_device.DeviceSpec.from_string(d)

  # Raised explicitly (instead of using an `assert` statement) so that the
  # check is not stripped when Python runs with -O. The exception type and
  # message are the same as the assert statement would produce.
  if d.device_type is not None and d.device_type != d.device_type.upper():
    raise AssertionError(
        "Device type '%s' must be all-caps." % (d.device_type,))

  # Fill in missing device fields using defaults.
  result = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    # TODO(b/151452748): Using list_logical_devices is not always safe since it
    # may return remote devices as well, but we're already doing this elsewhere.
    host_cpu = tf_device.DeviceSpec.from_string(
        config.list_logical_devices("CPU")[0].name)
    if host_cpu.job:
      result = result.make_merged_spec(host_cpu)
    else:
      # The default job is localhost if eager execution is enabled
      result = result.replace(job="localhost")
  if default:
    # Overrides any defaults with values from the default device if given.
    result = result.make_merged_spec(
        tf_device.DeviceSpec.from_string(default))

  # Apply `d` last, so that its values take precedence over the defaults.
  result = result.make_merged_spec(d)
  return result.to_string()


def canonicalize_without_job_and_task(d):
  """Partially canonicalize device string.

  This returns device string from `d` without including job and task.
  This is most useful for parameter server strategy where the device strings are
  generated on the chief, but executed on workers.

  For example:
    If d = '/cpu:0', it returns
      '/replica:0/device:CPU:0'.
    If d = '/gpu:0', it returns
      '/replica:0/device:GPU:0'.

  Note: `d` is first fully canonicalized with `canonicalize(d)`; the job and
  task of that result are then cleared and the replica is set to 0, whatever
  values `d` itself specified for them.

  Args:
    d: a device string or tf.config.LogicalDevice

  Returns:
    a partially canonicalized device string.
  """
  canonicalized_device = canonicalize(d)
  spec = tf_device.DeviceSpec.from_string(canonicalized_device)
  spec = spec.replace(job=None, task=None, replica=0)
  return spec.to_string()


def resolve(d):
  """Canonicalize `d` with current device as default."""
  return canonicalize(d, default=current())


class _FakeNodeDef(object):
  """A fake NodeDef for _FakeOperation."""

  __slots__ = ["op", "name"]

  def __init__(self):
    self.op = ""
    self.name = ""


class _FakeOperation(object):
  """A fake Operation object to pass to device functions.

  `current()` hands an instance to the default graph's device functions and
  then reads back the `device` attribute they set.
  """

  def __init__(self):
    self.device = ""
    self.type = ""
    self.name = ""
    self.node_def = _FakeNodeDef()

  def _set_device(self, device):
    """Stores `device` after converting it with `ops._device_string`."""
    self.device = ops._device_string(device)  # pylint: disable=protected-access

  def _set_device_from_string(self, device_str):
    """Stores `device_str` as the device without any conversion."""
    self.device = device_str


def current():
  """Return a string (not canonicalized) for the current device."""
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if ops.executing_eagerly_outside_functions():
    d = context.context().device_name
  else:
    # In graph mode, let the default graph's device functions place a fake op
    # and report the device they assigned to it.
    op = _FakeOperation()
    ops.get_default_graph()._apply_device_functions(op)  # pylint: disable=protected-access
    d = op.device
  return d


def get_host_for_device(device):
  """Returns the corresponding host device for the given device.

  Args:
    d: not used; see `device`.
    device: a device string.

  Returns:
    a device string with the same job, replica and task as `device`, but with
    the device set to CPU:0.
  """
  spec = tf_device.DeviceSpec.from_string(device)
  return tf_device.DeviceSpec(
      job=spec.job, replica=spec.replica, task=spec.task,
      device_type="CPU", device_index=0).to_string()


def local_devices_from_num_gpus(num_gpus):
  """Returns device strings for local GPUs or CPU.

  Args:
    num_gpus: the number of GPU device strings to generate.

  Returns:
    a tuple of "/device:GPU:<i>" strings for i in range(num_gpus), or
    ("/device:CPU:0",) when that range is empty.
  """
  return (tuple("/device:GPU:%d" % i for i in range(num_gpus)) or
          ("/device:CPU:0",))