xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/python/dtensor_device.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Propagates information about tensor layouts across operations."""
16
17import contextlib
18import logging
19import os
20import threading
21from typing import Any, List, Sequence, Set
22
23import numpy as np
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.dtensor.python import gen_dtensor_ops
27from tensorflow.dtensor.python import layout as layout_lib
28from tensorflow.python import _pywrap_dtensor_device
29from tensorflow.python.eager import context
30from tensorflow.python.eager import core
31from tensorflow.python.framework import device as tf_device
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.ops import resource_variable_ops
36
# Environment variables that describe the multi-client DTensor cluster; they
# are read lazily by DTensorDevice._client_id/_num_clients/_job_name below.
_DT_CLIENT_ID = "DTENSOR_CLIENT_ID"
_DT_NUM_CLIENTS = "DTENSOR_NUM_CLIENTS"
_DT_JOB_NAME = "DTENSOR_JOB_NAME"

# TODO(allenl): Allow something other than "CUSTOM" so we don't need device
# numbering hacks to avoid collisions between parallel devices and dtensor
# devices.
# Process-wide counter (guarded by the lock) used to give each DTensorDevice
# instance a unique "device:CUSTOM:<n>" name.
_next_device_number = 0
_next_device_number_lock = threading.Lock()
46
47
48class DTensorDevice(object):
49  """Wraps a custom device which attempts to propagate tensor layouts."""
50
51  def __init__(self, meshes: List[layout_lib.Mesh], is_async=True):
52    """Create a new DTensorDevice which executes ops on `underlying_device`.
53
54    Args:
55      meshes: A list of `Mesh` objects indicating groups of devices to execute
56        on. These may also be registered lazily.
57      is_async: Indicates whether DTensor operations on this client will return
58        immediately (with "non-ready" handles) or block until executed. This is
59        on by default and is exposed as an option for ease of debugging.
60    """
61    if any(not isinstance(mesh, layout_lib.Mesh) for mesh in meshes):
62      raise TypeError(
63          "Expected a flat list of Mesh objects, got {}".format(meshes))
64    global _next_device_number
65    ctx = context.context()
66    with _next_device_number_lock:
67      self.name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
68                                               _next_device_number)
69      _next_device_number += 1
70    device, device_info = _pywrap_dtensor_device.Allocate(self.name)
71    context.register_custom_device(device, self.name, device_info)
72
73    self._device_info = device_info
74    self._current_output_layout = None
75    self._current_default_mesh = None
76    self._is_async = is_async
77    self._meshes = set()
78    self._mesh_lock = threading.Lock()
79    for mesh in meshes:
80      self._register_mesh(mesh)
81
82# FIXME(b/241819185): Reuse the logic in api.py.
83# LINT.IfChange
84  def _num_clients(self):
85    """Returns number of clients in current DTensor cluster."""
86    # If missing, 1 is a good default.
87    return int(os.environ.get(_DT_NUM_CLIENTS, "1"))
88# LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_utils.cc)
89
90# LINT.IfChange
91  def _client_id(self):
92    """Returns current client ID (int) in current DTensor cluster."""
93    return int(os.environ.get(_DT_CLIENT_ID, "0"))
94# LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_utils.cc)
95
96  def _job_name(self):
97    """Returns the DTensor Borg job name."""
98    # If missing, the program is likely running locally or on Forge.
99    return os.environ.get(_DT_JOB_NAME,
100                          "localhost" if self._num_clients() == 1 else "worker")
101
102  def _full_job_name(self):
103    """Returns the fully qualified TF job name for this task."""
104    return self._job_name() + "/replica:0/task:" + str(self._client_id())
105
106  def _create_host_array(self, shape, host_id):
107    """Returns ID and device lists that can be used to create a host mesh."""
108    num_global_devices = np.prod(shape)
109    global_device_ids = np.arange(num_global_devices).reshape(shape)
110    local_device_list = [
111        tf_device.DeviceSpec(
112            job=self._full_job_name(), device_type="CPU", device_index=0)
113    ]
114    num_local_devices = len(local_device_list)
115    local_device_ids = [
116        x + host_id * num_local_devices for x in range(num_local_devices)
117    ]
118    return global_device_ids, local_device_ids, local_device_list
119
120  def _create_embedding_host_mesh(self, tpu_mesh: layout_lib.Mesh):
121    """Returns Embedding host mesh for each client."""
122    if tpu_mesh.device_type().upper() != "TPU":
123      raise ValueError("Must pass input of a tpu mesh.")
124
125    # Global device ids are global host ids, while local device ids contains
126    # local host id.
127
128    ts_local_device_ids = []
129    ts_local_devices = []
130    for local_device_str in tpu_mesh.local_devices():
131      # We only need to keep TPU:0 for each client.
132      if not local_device_str.endswith("TPU:0"):
133        continue
134
135      device_spec = tf_device.DeviceSpec.from_string(local_device_str)
136      ts_local_device_ids.append(device_spec.task)
137      ts_local_devices.append(device_spec.replace(device_type="CPU"))
138
139    if not ts_local_device_ids or not ts_local_device_ids:
140      logging.info(
141          "Cannot create tpu system mesh as %s has no `TPU:0` local device "
142          "found", tpu_mesh.to_string())
143      return None
144
145    ts_global_device_ids = np.arange(self._num_clients())
146    # TODO(zhonglinhan): parse global device specs as input when not None.
147    return layout_lib.Mesh(
148        dim_names=[tpu_mesh.dim_names[0]],  # 1D mesh.
149        global_device_ids=ts_global_device_ids,
150        local_device_ids=ts_local_device_ids,
151        local_devices=ts_local_devices)
152
153  def _register_mesh(self, mesh: layout_lib.Mesh):
154    """Idempotently register `mesh` with the dtensor device."""
155    with self._mesh_lock:
156      if mesh not in self._meshes:
157        _pywrap_dtensor_device.AddMesh(self._device_info, mesh.to_string(),
158                                       self._is_async, False)
159        self._meshes.add(mesh)
160        if mesh.device_type().upper() == "TPU":
161          logging.info(
162              "Registering virtual 1:1 mapped host mesh %s for mesh %s",
163              mesh.host_mesh().to_string(), mesh.to_string())
164          _pywrap_dtensor_device.AddMesh(self._device_info,
165                                         mesh.host_mesh().to_string(),
166                                         self._is_async, True)
167          self._meshes.add(mesh.host_mesh())
168          embedding_host_mesh = self._create_embedding_host_mesh(mesh)
169          if embedding_host_mesh:
170            logging.info(
171                "Registering embedding host mesh %s on each client for mesh %s",
172                embedding_host_mesh.to_string(), mesh.to_string())
173            _pywrap_dtensor_device.AddMesh(self._device_info,
174                                           embedding_host_mesh.to_string(),
175                                           self._is_async, False)
176            self._meshes.add(embedding_host_mesh)
177
  @property
  def meshes(self) -> Set[layout_lib.Mesh]:
    """The set of meshes registered so far, including derived host meshes."""
    return self._meshes
181
182  def copy_to_mesh(self, tensor, new_layout, source_layout=None) -> ops.Tensor:
183    """Copy `tensor` to `device` with the given layout."""
184    self._register_mesh(new_layout.mesh)
185    with ops.device(self.name):
186      return gen_dtensor_ops.copy_to_mesh(
187          tensor,
188          layout=new_layout.to_string(),
189          source_layout=source_layout.to_string() if source_layout else "")
190
191  def pack(self, tensors: Sequence[Any], layout: layout_lib.Layout) -> Any:
192    """Packs tensors into a DTensor handle on this DTensor device.
193
194    Packing and unpacking are inverse operations:
195
196    ```
197    * unpack(pack(tensors)) == tensors
198    * pack(unpack(dtensor)) == dtensor
199    ```
200
201    Refer to `dtensor.pack` for more information.
202
203    Args:
204      tensors: The list of tensors to pack into a DTensor.
205      layout: The layout of the DTensor to be created.
206
207    Returns:
208      A DTensor created from the individual component tensors.
209
210    Raises:
211      RuntimeError: When not called eagerly.
212    """
213    if not context.executing_eagerly():
214      raise RuntimeError("Pack must be called eagerly.")
215    if any(
216        issubclass(type(t), resource_variable_ops.BaseResourceVariable)
217        for t in tensors):
218      raise TypeError(
219          "Received Variable input to Pack, Variable is not supported.")
220    self._register_mesh(layout.mesh)
221    with ops.device(self.name):
222      if all(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
223        if not all(t.shape == tensors[0].shape for t in tensors):
224          raise TypeError("All input SparseTensors to Pack must be same shape.")
225        is_sparse = True
226        tensors = [t.indices for t in tensors] + [t.values for t in tensors] + [
227            ops.convert_to_tensor(t.shape, dtype=dtypes.int64) for t in tensors
228        ]
229      elif any(isinstance(t, sparse_tensor.SparseTensor) for t in tensors):
230        raise TypeError("Cannot Pack SparseTensors with Tensors.")
231      else:
232        is_sparse = False
233      try:
234        return _pywrap_dtensor_device.Pack(
235            context.context()._handle,  # pylint: disable=protected-access
236            tensors,
237            layout.to_string(),
238            self._device_info,
239            is_sparse)
240      except core._NotOkStatusException as e:  # pylint: disable=protected-access
241        raise core._status_to_exception(e) from None  # pylint: disable=protected-access
242
243  def unpack(self, dtensor: Any) -> Sequence[Any]:
244    """Unpacks a DTensor handle on this DTensor device.
245
246    Packing and unpacking are inverse operations:
247
248    ```
249    * unpack(pack(tensors)) == tensors
250    * pack(unpack(dtensor)) == dtensor
251    ```
252
253    Refer to `dtensor.unpack` for more information.
254
255    Args:
256      dtensor: The DTensor to unpack.
257
258    Returns:
259      The raw underlying tensor components of the DTensor.
260
261    Raises:
262      RuntimeError: When not called eagerly.
263    """
264    if not context.executing_eagerly():
265      raise RuntimeError("Unpack must be called eagerly.")
266    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
267      raise TypeError(
268          "Received Variable input to unpack, Variable is not supported.")
269    try:
270      tensors = _pywrap_dtensor_device.Unpack(
271          context.context()._handle,  # pylint: disable=protected-access
272          dtensor,
273          self._device_info)
274    except core._NotOkStatusException as e:  # pylint: disable=protected-access
275      raise core._status_to_exception(e) from None  # pylint: disable=protected-access
276
277    is_sparse = _pywrap_dtensor_device.IsSparseDTensor(
278        context.context()._handle,  # pylint: disable=protected-access.
279        dtensor,
280        self._device_info)
281    if is_sparse:
282      result = []
283      for i in range(len(tensors) // 3):
284        result.append(
285            sparse_tensor.SparseTensor(tensors[i],
286                                       tensors[i + len(tensors) // 3],
287                                       tensors[i + 2 * len(tensors) // 3]))
288      return result
289    else:
290      return tensors
291
292  def fetch_layout(self, dtensor: Any) -> layout_lib.Layout:
293    """Fetches the layout of the DTensor.
294
295    Args:
296      dtensor: The DTensor whose layout is to be fetched.
297
298    Returns:
299      The `Layout` of this DTensor.
300
301    Raises:
302      RuntimeError: When not called eagerly.
303    """
304    if not context.executing_eagerly():
305      raise RuntimeError("FetchLayout must be called eagerly.")
306    if issubclass(type(dtensor), resource_variable_ops.BaseResourceVariable):
307      dtensor = dtensor.read_value()
308    try:
309      layout_string = _pywrap_dtensor_device.FetchLayout(
310          context.context()._handle,  # pylint: disable=protected-access
311          dtensor,
312          self._device_info)
313    except core._NotOkStatusException as e:  # pylint: disable=protected-access
314      raise core._status_to_exception(e) from None  # pylint: disable=protected-access
315    return layout_lib.Layout.from_string(layout_string)
316
  def set_same_shape_policy(self, enabled):
    """Guess layouts using the layouts of other tensors with the same shape.

    This is the default behavior, and is quite safe. The `default_layout` scope
    overrides shape-based guesses.

    Args:
      enabled: A boolean indicating whether to use the policy.
    """
    # Forward the flag to the C++ device implementation.
    _pywrap_dtensor_device.SetSameShapePolicy(self._device_info, enabled)
327
  def set_tpu_core_ids(self, mesh_name, tpu_core_ids):
    """Sets the singleton global device ID-to-physical core ID map.

    Args:
      mesh_name: The name of a mesh. If empty, set the default mapping.
      tpu_core_ids: TPU core IDs sorted by TF task/device ordinal.
    """
    # Forward the mapping to the C++ device implementation.
    _pywrap_dtensor_device.SetTPUCoreIDs(self._device_info, mesh_name,
                                         tpu_core_ids)
337
  def clear_tpu_core_ids(self):
    """Clears the device ID-to-physical core ID map set by `set_tpu_core_ids`."""
    _pywrap_dtensor_device.ClearTPUCoreIDs(self._device_info)
340
341  def tpu_core_ids_to_locations(self, tpu_core_ids):
342    """Translates TPU core IDs to TPU core locations.
343
344    Args:
345      tpu_core_ids: A list of TPU core IDs. Each one is an unsigned integer.
346
347    Returns:
348      A list of corresponding TPU core locations.
349    """
350    return _pywrap_dtensor_device.TPUCoreIDsToLocations(
351        context.context()._handle,  # pylint: disable=protected-access
352        self._device_info,
353        tpu_core_ids)
354
355  def tpu_core_locations_to_ids(self, tpu_core_locations):
356    """Translates TPU core locations to TPU core IDs.
357
358    Args:
359      tpu_core_locations: A list of TPU core locations. Each one is a list of
360        four unsigned integers, [x, y, z, core].
361
362    Returns:
363      A list of corresponding TPU core IDs.
364    """
365    return _pywrap_dtensor_device.TPUCoreLocationsToIDs(
366        context.context()._handle,  # pylint: disable=protected-access
367        self._device_info,
368        tpu_core_locations)
369
370  def _get_function_cache_hit_and_miss_count(self):
371    """Returns the number of cache hit and miss for function compilation.
372
373    Returns:
374      A dictionary keyed with miss and hit, corresponding to the cache hit and
375      miss count.
376    """
377    return _pywrap_dtensor_device.GetFunctionCacheHitAndMissCount(
378        context.context()._handle,  # pylint: disable=protected-access,
379        self._device_info)
380
381  @contextlib.contextmanager
382  def _experimental_default_mesh(self, mesh: layout_lib.Mesh):
383    """Sets a default mesh for all ops in the scope.
384
385    Note: This is an internal helper method, which is not user facing api.
386
387    Useful for requesting a specific mesh for ops which would have no inferred
388    layout, e.g. tf.zeros.
389
390    Args:
391      mesh: A Mesh to be used for ops without Mesh.
392
393    Yields:
394      Nothing.
395    """
396    previous_default = self._current_default_mesh
397    self._register_mesh(mesh)
398    _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
399        self._device_info,
400        mesh.to_string().encode("utf-8"))
401    self._current_default_mesh = mesh
402    yield
403    _pywrap_dtensor_device.ExperimentalClearDefaultMesh(self._device_info)
404    if previous_default:
405      _pywrap_dtensor_device.ExperimentalSetDefaultMesh(
406          self._device_info,
407          previous_default.to_string().encode("utf-8"))
408    self._current_default_mesh = previous_default
409
  @contextlib.contextmanager
  def _default_layout(self, layout: layout_lib.Layout):
    """Sets a default output layout for all ops in the scope.

    Note: This is an internal helper method, which is not user facing api.

    Useful for requesting a specific layout for ops which would have no inferred
    layout, e.g. tf.zeros.

    Caveats:

    - Currently only affects the first output of an op. For Op with multiple
      outputs, this does not support yet.

    - All Ops in the scope will be attached with the same layout. This might not
      be valid as the rank is different. The current suggestion is: Try to wrap
      the raw op wheneven possible.

    Args:
      layout: A Layout for the outputs of all operations in this scope.

    Yields:
      Nothing.
    """
    # Pre-declare so the finally-block can safely inspect these even if an
    # exception fires before they are assigned inside the try.
    previous_default = None
    previous_graph_size = None
    graph = None

    self._register_mesh(layout.mesh)
    try:
      # Save the enclosing scope's layout (supports nesting) and install the
      # new one on the C++ device. The layout is kept as UTF-8 bytes because
      # that is what the pywrap API and the op attr below expect.
      previous_default = self._current_output_layout
      self._current_output_layout = layout.to_string().encode("utf-8")
      _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
          self._device_info, self._current_output_layout)
      if context.executing_eagerly():
        with ops.device(self.name):
          yield
      else:
        # Custom devices currently don't affect graph building, so we need a
        # separate way to indicate layouts.
        #
        # TODO(allenl): Remove this case once the DTensor device is active
        # during tracing.
        graph = ops.get_default_graph()
        previous_graph_size = len(graph.get_operations())
        yield
    finally:
      if graph is not None:
        # Tag operations added under this scope
        for operation in graph.get_operations()[previous_graph_size:]:
          # Set layout directly on the Op itself.
          operation._set_attr(  # pylint: disable=protected-access
              "_layout",
              attr_value_pb2.AttrValue(
                  list=attr_value_pb2.AttrValue.ListValue(
                      s=[self._current_output_layout])))
          operation._set_attr(  # pylint: disable=protected-access
              "_mesh",
              attr_value_pb2.AttrValue(
                  s=layout.mesh.to_string().encode("utf-8")))

      # Restore the enclosing scope's layout (or clear it at the outermost
      # scope) even when the body raised.
      self._current_output_layout = previous_default
      if self._current_output_layout is None:
        _pywrap_dtensor_device.ExperimentalClearDefaultLayout(self._device_info)
      else:
        _pywrap_dtensor_device.ExperimentalSetDefaultLayout(
            self._device_info, self._current_output_layout.decode("utf-8"))
477