xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/device_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Device-related support functions."""
16
17
18
19from tensorflow.python.eager import context
20from tensorflow.python.framework import config
21from tensorflow.python.framework import device as tf_device
22from tensorflow.python.framework import ops
23
24
def canonicalize(d, default=None):
  """Return the fully specified device string for `d`.

  Components that `d` leaves unset are filled in from the following sources,
  listed from lowest to highest precedence:
    1. the base spec '/replica:0/task:0/device:CPU:0';
    2. when executing eagerly outside functions, the first logical CPU device
       if it names a job, otherwise just the job "localhost";
    3. the `default` device string, if one is given.
  Anything `d` itself specifies always wins. For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Args:
    d: a device string or tf.config.LogicalDevice.
    default: optional device string supplying components missing from `d`.

  Returns:
    The canonicalized device string.

  Raises:
    AssertionError: if the device type parsed from `d` is not all upper-case.
  """
  # A LogicalDevice is identified by its name; anything else is parsed as-is.
  name = d.name if isinstance(d, context.LogicalDevice) else d
  requested = tf_device.DeviceSpec.from_string(name)

  requested_type = requested.device_type
  assert requested_type is None or requested_type == requested_type.upper(), (
      "Device type '%s' must be all-caps." % (requested_type,))

  # Lowest-precedence values for every field except the job.
  merged = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)

  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    # TODO(b/151452748): Using list_logical_devices is not always safe since it
    # may return remote devices as well, but we're already doing this elsewhere.
    first_cpu = config.list_logical_devices("CPU")[0]
    host_cpu = tf_device.DeviceSpec.from_string(first_cpu.name)
    # Without a job on the host CPU, eager execution defaults to "localhost".
    merged = (merged.make_merged_spec(host_cpu) if host_cpu.job
              else merged.replace(job="localhost"))

  if default:
    # Values from the caller-supplied default override the ones chosen above.
    merged = merged.make_merged_spec(tf_device.DeviceSpec.from_string(default))

  # Merge `d` last so that its values take precedence over every default.
  return merged.make_merged_spec(requested).to_string()
75
76
def canonicalize_without_job_and_task(d):
  """Canonicalize `d`, then strip its job and task components.

  The device is first canonicalized with `canonicalize(d)`. The job and task
  of the result are then cleared and its replica is set to 0. This is most
  useful for parameter server strategy, where the device strings are generated
  on the chief but executed on workers.

  For example:
    If d = '/cpu:0', it returns '/replica:0/device:CPU:0'.
    If d = '/gpu:0', it returns '/replica:0/device:GPU:0'.

  Args:
    d: a device string or tf.config.LogicalDevice.

  Returns:
    A partially canonicalized device string without job and task.
  """
  full_spec = tf_device.DeviceSpec.from_string(canonicalize(d))
  return full_spec.replace(job=None, task=None, replica=0).to_string()
104
105
def resolve(d):
  """Canonicalize `d`, using the current device to fill in missing parts."""
  current_device = current()
  return canonicalize(d, default=current_device)
109
110
class _FakeNodeDef(object):
  """Minimal NodeDef stand-in held by `_FakeOperation`.

  Only the `op` and `name` attributes exist; both start out as empty strings.
  """

  __slots__ = ["op", "name"]

  def __init__(self):
    self.op = self.name = ""
119
120
class _FakeOperation(object):
  """Stand-in for an Operation that can be handed to device functions.

  It carries the `device`, `type`, `name` and `node_def` attributes and the two
  device setters below. All string attributes start out empty.
  """

  def __init__(self):
    self.device = self.type = self.name = ""
    self.node_def = _FakeNodeDef()

  def _set_device(self, device):
    """Records the string form of `device` as this op's device."""
    device_str = ops._device_string(device)  # pylint: disable=protected-access
    self.device = device_str

  def _set_device_from_string(self, device_str):
    """Records `device_str` unchanged as this op's device."""
    self.device = device_str
135
136
def current():
  """Return the current device as a string (not canonicalized)."""
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if ops.executing_eagerly_outside_functions():
    return context.context().device_name

  # Otherwise let the default graph apply its device functions to a throwaway
  # op and read back the device that was assigned to it.
  probe_op = _FakeOperation()
  ops.get_default_graph()._apply_device_functions(probe_op)  # pylint: disable=protected-access
  return probe_op.device
147
148
def get_host_for_device(device):
  """Return the CPU:0 device string sharing `device`'s job/replica/task."""
  parsed = tf_device.DeviceSpec.from_string(device)
  host_spec = tf_device.DeviceSpec(
      job=parsed.job,
      replica=parsed.replica,
      task=parsed.task,
      device_type="CPU",
      device_index=0)
  return host_spec.to_string()
155
156
def local_devices_from_num_gpus(num_gpus):
  """Return a tuple of local GPU device strings, or the CPU if there are none.

  Args:
    num_gpus: number of local GPUs to enumerate.

  Returns:
    A tuple with one '/device:GPU:<i>' string for each i in range(num_gpus),
    or the one-element tuple ('/device:CPU:0',) when that range is empty.
  """
  gpu_devices = tuple("/device:GPU:%d" % index for index in range(num_gpus))
  if gpu_devices:
    return gpu_devices
  return ("/device:CPU:0",)
161