xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/device_spec.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class to represent a device."""
16
17from tensorflow.python.util.tf_export import tf_export
18from tensorflow.python import pywrap_tfe
19
20# EPU represents for TPU embedding for now. Subject to change in future.
21_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM", "EPU"})
22
23# ==============================================================================
24# == Global Implementation Details =============================================
25# ==============================================================================
26_STRING_TO_COMPONENTS_CACHE = {}
27_COMPONENTS_TO_STRING_CACHE = {}
28
29
30def _as_str_or_none(inp):
31  return None if inp is None else str(inp)
32
33
34def _as_int_or_none(inp):
35  return None if inp is None else int(inp)
36
37
38def _as_device_str_or_none(device_type):
39  # For backwards compatibility only, we support lowercase variants of
40  # cpu and gpu but turn them into uppercase here.
41  if device_type in ("cpu", "gpu"):
42    return device_type.upper()
43  return _as_str_or_none(device_type)
44
45
@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need of .to_string() method.
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self,
               job=None,
               replica=None,
               task=None,
               device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string.  Optional job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU")
      device_index: int.  Optional device index.  If left unspecified, device
        represents 'any' device_index.

    Raises:
      ValueError: if `replica`, `task` or `device_index` is given but cannot
        be converted with `int()`.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The string form and its hash are computed once, up front, from the
    # normalized components.
    self._as_string = self._components_to_string(
        job=self._job,
        replica=self._replica,
        task=self._task,
        device_type=self._device_type,
        device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id> or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id> as cpu and gpu are
         mutually exclusive. All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    **2.x behavior change**:

    In TensorFlow 1.x, this function mutates its own state and returns itself.
    In 2.x, DeviceSpecs are immutable, and this function will return a
      DeviceSpec which contains the spec.

    * Recommended:

      ```
      # my_spec and my_updated_spec are unrelated.
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will work in 1.x and 2.x (though deprecated in 2.x):

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = my_spec.parse_from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will NOT work in 2.x:

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
      with tf.device(my_spec):
        ...
      ```

    In general, `DeviceSpec.from_string` should completely replace
    `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
    completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id> or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id> as cpu and gpu are
         mutually exclusive. All entries are optional.

    Returns:
      The `DeviceSpec`.

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec.
    So for instance:
    ```
    first_spec = tf.DeviceSpec(job=0, device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:
    ```
    combined_spec = tf.DeviceSpec(job=0, device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` which combines `self` and `dev`
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:
    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor

    Returns:
      A DeviceSpec with the fields specified in kwargs overridden.
    """
    init_kwargs = {
        "job": self.job,
        "replica": self.replica,
        "task": self.task,
        "device_type": self.device_type,
        "device_index": self.device_index,
    }

    # Explicitly provided kwargs take precedence. Names the constructor does
    # not know are passed through and rejected by the constructor itself.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    """The job name as a string, or `None` if unspecified."""
    return self._job

  @property
  def replica(self):
    """The replica index as an int, or `None` if unspecified."""
    return self._replica

  @property
  def task(self):
    """The task index as an int, or `None` if unspecified."""
    return self._task

  @property
  def device_type(self):
    """The device type as a string, or `None` if unspecified."""
    return self._device_type

  @property
  def device_index(self):
    """The device index as an int, or `None` if unspecified."""
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination of DeviceSpecs gives priority to `dev`: each component is
    taken from `dev` unless it is `None` there, in which case the value from
    `self` is used.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _get_valid_device_types():
    """Returns the set of device type names accepted in short-form specs.

    The result is the built-in `_VALID_DEVICE_TYPES` plus the type of every
    device reported by `pywrap_tfe.TF_ListPluggablePhysicalDevices()`.
    """
    # Each reported name is decoded and its second ":"-separated field is
    # taken as the device type.
    pluggable_device_types = {
        device.decode().split(":")[1]
        for device in pywrap_tfe.TF_ListPluggablePhysicalDevices()
    }
    return pluggable_device_types | _VALID_DEVICE_TYPES

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Results are memoized in `_STRING_TO_COMPONENTS_CACHE`, keyed by the `spec`
    value exactly as passed in.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec`. Note that the result of this function
      must go through attribute setters of DeviceSpec, and should therefore NOT
      be used directly.

    Raises:
      ValueError: if `spec` names more than one device type, contains an
        unknown attribute, or has a device index that `int()` cannot parse.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    splits = [x.split(":") for x in spec.split("/")]
    valid_device_types = DeviceSpecV2._get_valid_device_types()
    for parts in splits:
      # `str.split` never returns an empty list, so `parts[0]` always exists.
      num_parts = len(parts)
      key = parts[0]
      # NOTE(taylorrobie): these will go through setters later.
      if num_parts == 2 and key == "job":
        job = parts[1]
      elif num_parts == 2 and key == "replica":
        replica = parts[1]
      elif num_parts == 2 and key == "task":
        task = parts[1]
      elif num_parts in (1, 2) and key.upper() in valid_device_types:
        # Short form: "<TYPE>" or "<TYPE>:<index>", case-insensitive.
        if device_type is not None:
          raise ValueError("Multiple device types are not allowed "
                           f"while parsing the device spec: {spec}.")
        device_type = key.upper()
        if num_parts == 2 and parts[1] != "*":
          device_index = int(parts[1])
      elif num_parts == 3 and key == "device":
        # Long form: "device:<type>:<index>". The type is kept as written.
        if device_type is not None:
          raise ValueError("Multiple device types are not allowed "
                           f"while parsing the device spec: {spec}.")
        device_type = parts[1]
        if parts[2] != "*":
          device_index = int(parts[2])
      elif key:
        # Segments with an empty first field (e.g. from a leading "/") are
        # skipped; anything else is unrecognized.
        raise ValueError(f"Unknown attribute '{key}' is encountered "
                         f"while parsing the device spec: '{spec}'.")

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Components that are `None` are omitted. When `device_type` is set but
    `device_index` is `None`, the index is rendered as "*". Results are
    memoized in `_COMPONENTS_TO_STRING_CACHE`.
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      device_index_string = "*"
      if device_index is not None:
        # Unlike the others, device_index is stored as an int.
        device_index_string = str(device_index)
      output.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks whether `other` is a DeviceSpec equal to the current instance.

    Two specs are compared through their string representations.

    Args:
      other: Another DeviceSpec

    Returns:
      Return `True` if `other` is an instance of this object's class and has
      the same string representation as the current instance.
      Return `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __hash__(self):
    return self._hash
413
414
@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  __doc__ = DeviceSpecV2.__doc__
  # All storage slots are inherited from DeviceSpecV2. Re-declaring the same
  # slot names here would allocate a second, shadowing set of slots on every
  # instance, so no additional slots are defined.
  __slots__ = ()

  # Each setter normalizes the new value the same way the constructor does,
  # then clears the cached string and hash so they are recomputed lazily.

  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._as_string, self._hash = None, None

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._as_string, self._hash = None, None

  def __hash__(self):
    # Recompute on demand: the setters reset the cached hash to None.
    if self._hash is None:
      self._hash = hash(self.to_string())
    return self._hash

  def to_string(self):
    # Recompute on demand: the setters reset the cached string to None.
    if self._as_string is None:
      self._as_string = self._components_to_string(
          job=self.job,
          replica=self.replica,
          task=self.task,
          device_type=self.device_type,
          device_index=self.device_index)
    return self._as_string

  def parse_from_string(self, spec):
    # Unlike DeviceSpecV2, this updates the instance in place (through the
    # setters above) and returns the same instance.
    (self.job, self.replica, self.task, self.device_type,
     self.device_index) = self._string_to_components(spec)

    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Components that are set on `dev` overwrite the corresponding components
    of this instance; components that are `None` on `dev` are left as is.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
          immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    (self.job, self.replica, self.task, self.device_type,
     self.device_index) = self._get_combined_properties(dev)

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__
480  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__
481