xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/python/layout.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python definitions for `Mesh` and `Layout`."""

import collections
import itertools
from typing import List, Dict, Optional

import numpy as np

from tensorflow.dtensor.proto import layout_pb2
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

# UNSHARDED indicates a tensor dimension is not sharded over any mesh dimension.
UNSHARDED = 'unsharded'
MATCH = 'match'

tf_export('experimental.dtensor.UNSHARDED', v1=[]).export_constant(
    __name__, 'UNSHARDED')
tf_export('experimental.dtensor.MATCH', v1=[]).export_constant(
    __name__, 'MATCH')

MeshDimension = collections.namedtuple('MeshDimension', ['name', 'size'])


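# For a mesh of shape [a, b, c, d] this returns [b*c*d, c*d, d, 1], i.e. the
# row-major strides of the device-id array. For example, dimension sizes
# [2, 3, 4] yield strides [12, 4, 1].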
def _compute_mesh_strides(mesh_dims: List[MeshDimension]) -> List[int]:
  strides = [1]
  for idx, dim in enumerate(reversed(mesh_dims[1:])):
    strides.append(strides[idx] * dim.size)
  strides.reverse()
  return strides


@tf_export('experimental.dtensor.Mesh', v1=[])
class Mesh(object):
  """Represents a Mesh configuration over a list of Mesh Dimensions.

  A mesh consists of named dimensions with sizes, which describe how a set of
  devices are arranged. Defining tensor layouts in terms of mesh dimensions
  allows us to efficiently determine the communication required when computing
  an operation with tensors of different layouts.

  A mesh provides information not only about the placement of the tensors but
  also the topology of the underlying devices. For example, we can group 8 TPUs
  as a 1-D array for data parallelism or a `2x4` grid for (2-way) data
  parallelism and (4-way) model parallelism.

  Note: the utilities `dtensor.create_mesh` and
  `dtensor.create_distributed_mesh` provide a simpler API to create meshes for
  single- or multi-client use cases.
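
  As a minimal sketch of direct construction (device names here are
  illustrative; real programs usually call `dtensor.create_mesh` instead), a
  `2x4` mesh over eight local CPU devices of a single client could look like:

  ```python
  mesh = Mesh(
      dim_names=['x', 'y'],
      global_device_ids=np.arange(8).reshape((2, 4)),
      local_device_ids=list(range(8)),
      local_devices=[
          tf_device.DeviceSpec.from_string(
              f'/job:localhost/task:0/device:CPU:{i}') for i in range(8)
      ])
  ```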
66  """
67
68  _dim_dict: Dict[str, MeshDimension]
69  _dim_names: List[str]
70  _local_device_ids: List[int]
71  _global_device_ids: np.ndarray
72  _name: str
73  _local_devices = List[tf_device.DeviceSpec]
74  _global_devices = Optional[List[tf_device.DeviceSpec]]
75  _device_type: str
76
  def __init__(self,
               dim_names: List[str],
               global_device_ids: np.ndarray,
               local_device_ids: List[int],
               local_devices: List[tf_device.DeviceSpec],
               mesh_name: str = '',
               global_devices: Optional[List[tf_device.DeviceSpec]] = None):
    """Builds a Mesh.

    The `dim_names` and `global_device_ids` arguments describe the dimension
    names and shape for the mesh.

    For example,

    ```python
      dim_names = ('x', 'y'),
      global_device_ids = [[0, 1],
                           [2, 3],
                           [4, 5]]
    ```

    defines a 2D mesh of shape 3x2. A reduction over the 'x' dimension will
    reduce across columns (0, 2, 4) and (1, 3, 5), and a reduction over the 'y'
    dimension reduces across rows.

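    For example, in a hypothetical two-client setup over this mesh, client 0
    could pass `local_device_ids = [0, 1, 2]` together with its three matching
    `local_devices`, while client 1 passes `[3, 4, 5]`; between them the two
    clients cover every global device ID exactly once.
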
    Note: the utilities `dtensor.create_mesh` and
    `dtensor.create_distributed_mesh` provide a simpler API to create meshes for
    single- or multi-client use cases.

    Args:
      dim_names: A list of strings indicating dimension names.
      global_device_ids: An ndarray of global device IDs is used to compose
        DeviceSpecs describing the mesh. The shape of this array determines the
        size of each mesh dimension. Values in this array should increment
        sequentially from 0. This argument is the same for every DTensor client.
      local_device_ids: A list of local device IDs equal to a subset of values
        in global_device_ids. They indicate the position of local devices in the
        global mesh. Different DTensor clients must contain distinct
        local_device_ids contents. All local_device_ids from all DTensor clients
        must cover every element in global_device_ids.
      local_devices: The list of devices hosted locally. The elements correspond
        1:1 to those of local_device_ids.
      mesh_name: The name of the mesh. Currently, this is rarely used, and is
        mostly used to indicate whether it is a CPU, GPU, or TPU-based mesh.
      global_devices (optional): The list of global devices. Set when multiple
        device meshes are in use.
    """
    # Check if input args are valid.
    if not isinstance(global_device_ids, np.ndarray):
      raise ValueError('Variable global_device_ids must be an ndarray.')
    if global_device_ids.size == 0:
      raise ValueError('Variable global_device_ids must be non-empty.')
    flat_global_device_ids = global_device_ids.flatten()
    # global_device_ids are expected to be consecutive numbers.
    # LINT.IfChange
    distance = flat_global_device_ids[0]
    if any(
        (gid - i != distance) for i, gid in enumerate(flat_global_device_ids)):
      raise ValueError('global_device_ids must sequentially increase: %s' %
                       global_device_ids)
    # LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_device.cc)

    if len(dim_names) != global_device_ids.ndim:
      raise ValueError(
          'Number of mesh dimensions does not match number of dimension names.')

    if not isinstance(local_device_ids, list):
      raise ValueError('Variable local_device_ids must be a list of integers.')

    if not isinstance(local_devices, list):
      raise ValueError('Variable local_devices must be a list of DeviceSpecs.')

    if global_devices and not isinstance(global_devices, list):
      raise ValueError('Variable global_devices must be a list of DeviceSpecs.')

    if not local_devices and not global_devices:
      raise ValueError('Empty list of devices not allowed.')

    local_devices_set = set(local_devices)
    local_device_only_contains_host_cpu = (
        len(local_devices_set) == 1 and
        list(local_devices_set)[0].device_type == 'CPU')
    if not local_device_only_contains_host_cpu and len(local_devices) != len(
        local_devices_set):
      raise ValueError('Duplicate devices found in mesh specification %s.' %
                       [d for d in local_devices if local_devices.count(d) > 1])

    if len(local_device_ids) != len(local_devices):
      raise ValueError(
          'Variable local_device_ids does not have same size as local_devices.')

    if len(local_device_ids) > len(np.ravel(global_device_ids)):
      raise ValueError('Cannot have more local than global device IDs.')

    device_types = set([device.device_type for device in local_devices])
    if not device_types:
      device_types = set([device.device_type for device in global_devices])
    if None in device_types:
      raise ValueError('device_type is required')
    if len(device_types) > 1:
      raise ValueError('Devices containing multiple device_types : %s' %
                       device_types)

    # Set object's state.
    self._device_type = device_types.pop()
    self._dim_names = dim_names
    self._dim_dict = {
        dim_name: MeshDimension(dim_name, global_device_ids.shape[i])
        for i, dim_name in enumerate(dim_names)
    }
    self._global_device_ids = global_device_ids
    self._local_device_ids = local_device_ids
    self._local_devices = local_devices
    self._global_devices = global_devices
    self._name = mesh_name
    self._strides = _compute_mesh_strides(
        [self._dim_dict[dim] for dim in self._dim_names])

  def __contains__(self, dim_name: str) -> bool:
    return self.contains_dim(dim_name)

  def __eq__(self, other):
    if not isinstance(other, type(self)) and not isinstance(self, type(other)):
      raise ValueError('comparing with type : {0} but expecting : {1}'.format(
          type(other), type(self)))
    return self.as_proto().SerializeToString() == other.as_proto(
    ).SerializeToString()

  def __getitem__(self, dim_name: str) -> MeshDimension:
    if dim_name not in self._dim_dict:
      raise KeyError(
          f'Dimension {dim_name} not defined in mesh: {self._dim_dict.keys()}')
    return self._dim_dict[dim_name]

  # TODO(b/168730933): Define a nicer mesh ID.
  def __hash__(self):
    return hash(self.as_proto().SerializeToString(deterministic=True))

  def __repr__(self) -> str:
    dims = [tuple(self[dim_name]) for dim_name in self.dim_names]
    return (
        f'<Mesh object with dims={dims}, device_type="{self.device_type()}", '
        f'num_local_devices={self.num_local_devices()}, '
        f'size={self.size}>')

  def as_proto(self) -> layout_pb2.MeshProto:
    """Returns mesh protobuffer."""

    mesh_proto = layout_pb2.MeshProto()

    mesh_proto.name = self._name

    for i, mesh_dimension in enumerate(self._dim_names):
      dim = mesh_proto.mesh_dimensions.add()
      dim.name = mesh_dimension
      dim.size = self._global_device_ids.shape[i]

    for d in np.ravel(self._global_device_ids):
      mesh_proto.global_device_ids.append(d)

    for d in self._local_device_ids:
      mesh_proto.local_device_ids.append(d)

    for d in self._local_devices:
      mesh_proto.local_devices.append(d.to_string())

    if self._global_devices:
      for d in self._global_devices:
        mesh_proto.global_devices.append(d.to_string())

    return mesh_proto

  def contains_dim(self, dim_name: str) -> bool:
    """Returns True if a Mesh contains the given dimension name."""
    return dim_name in self._dim_dict

  def coords(self, device_idx: int) -> ops.Tensor:
    """Converts the device index into a tensor of mesh coordinates."""
    strides = ops.convert_to_tensor(self.strides)
    shape = ops.convert_to_tensor(self.shape())
    return (device_idx // strides) % shape

  def device_type(self) -> str:
    """Returns the device_type of a Mesh."""
    return self._device_type

  @property
  def dim_names(self) -> List[str]:
    return self._dim_names

  def dim_size(self, dim_name: str) -> int:
    """Returns the size of a dimension."""
    if dim_name not in self._dim_dict.keys():
      raise ValueError(('"{dim_name}" not a dimension name in current mesh. ' +
                        'Dimension names: {dim_names}.').format(
                            dim_name=dim_name,
                            dim_names=list(self._dim_dict.keys())))
    return self._dim_dict[dim_name].size

  @staticmethod
  def from_proto(proto: layout_pb2.MeshProto) -> 'Mesh':
    """Construct a mesh instance from input `proto`."""
    shape = [dim.size for dim in proto.mesh_dimensions]

    # Convert global_device_ids list back into array form
    global_device_ids = [int(d) for d in proto.global_device_ids]
    global_device_ids = np.asarray(global_device_ids).reshape(shape)

    # Construct local_device_ids list
    local_device_ids = [int(d) for d in proto.local_device_ids]

    # Convert local devices list back to array form
    local_devices = [
        tf_device.DeviceSpec.from_string(d) for d in proto.local_devices
    ]

    # Convert global devices list back to array form
    global_devices = [
        tf_device.DeviceSpec.from_string(d) for d in proto.global_devices
    ]

    name = proto.name
    dims = [dim.name for dim in proto.mesh_dimensions]
    return Mesh(dims, global_device_ids, local_device_ids, local_devices, name,
                global_devices)

  @staticmethod
  def from_string(mesh_str: str) -> 'Mesh':
    """Construct a mesh instance from a string representation."""
    # Separate elements of mesh.
    mesh_parts = mesh_str.split('|')
    global_dev_str = None
    if len(mesh_parts) == 5:
      name, mesh_dim_strs, global_id_str, local_id_str, dev_str = mesh_parts
    elif len(mesh_parts) == 6:
      (name, mesh_dim_strs, global_id_str, local_id_str, dev_str,
       global_dev_str) = mesh_parts
    else:
      raise ValueError('Invalid mesh string : %s' % mesh_str)

    # Load mesh proto.
    mesh_proto = layout_pb2.MeshProto()
    mesh_proto.name = name

    for mesh_dim_str in mesh_dim_strs.split(','):
      name, size_str = mesh_dim_str.split('=')
      dim = mesh_proto.mesh_dimensions.add()
      dim.name = name
      dim.size = int(size_str)

    for global_id in global_id_str.split(','):
      mesh_proto.global_device_ids.append(int(global_id))

    if local_id_str:
      for local_id in local_id_str.split(','):
        mesh_proto.local_device_ids.append(int(local_id))

    if dev_str:
      for dev in dev_str.split(','):
        mesh_proto.local_devices.append(dev)

    if global_dev_str:
      for dev in global_dev_str.split(','):
        mesh_proto.global_devices.append(dev)

    return Mesh.from_proto(mesh_proto)

  def host_mesh(self):
    """Returns the 1-1 mapped host mesh."""
    if self.device_type().upper() == 'CPU':
      return self

    v_cpus_counts = len(tf_config.list_logical_devices('CPU'))
    if v_cpus_counts < len(self._local_devices):
      raise ValueError('Must have at least {0} virtual CPUs for mesh : {1}, '
                       'but got : {2} virtual CPUs.'.format(
                           len(self._local_devices), self.to_string(),
                           v_cpus_counts))
    device_array = np.asarray([
        spec.replace(device_type='CPU') for spec in self._local_devices
    ]).reshape((len(self._local_devices), 1))
    global_devices = None
    if self._global_devices:
      global_devices = [
          spec.replace(device_type='CPU') for spec in self._global_devices
      ]
    h_mesh = Mesh(
        self._dim_names,
        self._global_device_ids,
        self.local_device_ids(),
        np.ravel(device_array).tolist(),
        global_devices=global_devices)
    return h_mesh

  def is_remote(self) -> bool:
    """Returns True if a Mesh contains only remote devices."""
    return not self._local_device_ids and self._global_device_ids.size > 0

  def local_device_ids(self) -> List[int]:
    """Returns a list of local device IDs."""
    return self._local_device_ids

  def local_device_locations(self) -> List[Dict[str, int]]:
    """Returns a list of local device locations.

    A device location is a dictionary from dimension names to indices on those
    dimensions.
    """
    mapping = self.unravel_index()
    return [mapping[device_id] for device_id in self.local_device_ids()]

  def local_devices(self) -> List[str]:
    """Returns a list of local device specs represented as strings."""
    return [d.to_string() for d in self._local_devices]

  def min_global_device_id(self) -> int:
    """Returns the minimum global device ID."""
    # global_device_ids sequentially increases.
    return self._global_device_ids.flatten()[0]

  @property
  def name(self) -> str:
    return self._name

  def num_local_devices(self) -> int:
    """Returns the number of local devices."""
    return len(self._local_devices)

  def shape(self) -> List[int]:
    """Returns the shape of the mesh."""
    return [self.dim_size(dim) for dim in self._dim_names]

  @property
  def size(self) -> int:
    return len(np.ravel(self._global_device_ids))

  @property
  def strides(self) -> List[int]:
    """Returns the strides array for this mesh.

    If the mesh shape is `[a, b, c, d]`, then the strides array can be computed
    as `[b*c*d, c*d, d, 1]`. This array can be useful in computing local device
    offsets given a device ID. Using the same example, the device coordinates of
    the mesh can be computed as:

    ```
    [(device_id / (b*c*d)) % a,
     (device_id / (c*d))   % b,
     (device_id / (d))     % c,
     (device_id)           % d]
    ```

    This is the same as `(device_id // mesh.strides) % mesh.shape`.

    Returns:
      The mesh strides as a list of integers.
    """
    return self._strides

  def to_string(self) -> str:
    """Returns string representation of Mesh."""

    # Get proto representation
    mesh_proto = self.as_proto()
    # Separate individual elements with ','.
    name = mesh_proto.name
    dim_str = ','.join(
        dim.name + '=' + str(dim.size) for dim in mesh_proto.mesh_dimensions)
    global_ids = ','.join(str(id) for id in mesh_proto.global_device_ids)
    local_ids = ','.join(str(id) for id in mesh_proto.local_device_ids)
    devices = ','.join(dev for dev in mesh_proto.local_devices)
    components = [name, dim_str, global_ids, local_ids, devices]
    if mesh_proto.global_devices:
      global_devices = ','.join(dev for dev in mesh_proto.global_devices)
      components.append(global_devices)
    # Separate mesh components with '|'.
    return '|'.join(components)

  def unravel_index(self):
    """Returns a dictionary from device ID to {dim_name: dim_index}.

    For example, for a 3x2 mesh, this returns:

    ```
      { 0: {'x': 0, 'y': 0},
        1: {'x': 0, 'y': 1},
        2: {'x': 1, 'y': 0},
        3: {'x': 1, 'y': 1},
        4: {'x': 2, 'y': 0},
        5: {'x': 2, 'y': 1} }
    ```
    """
    idx_ranges = [
        range(self.dim_size(dim_name)) for dim_name in self._dim_names
    ]
    mesh_pos = itertools.product(*idx_ranges)
    mapping = {}
    for device_id, device_pos in enumerate(mesh_pos):
      device_loc = {}
      for dim_name, dim_index in zip(self._dim_names, device_pos):
        device_loc[dim_name] = dim_index
      mapping[device_id] = device_loc
    return mapping


# TODO(hthu): Consider making this class immutable.
@tf_export('experimental.dtensor.Layout', v1=[])
class Layout(object):
  """Represents the layout information of a DTensor.

  A layout describes how a distributed tensor is partitioned across a mesh (and
  thus across devices). For each axis of the tensor, the corresponding
  sharding spec indicates which dimension of the mesh it is sharded over. The
  special sharding spec `UNSHARDED` indicates that the axis is replicated on
  all the devices of that mesh.

  For example, let's consider a 1-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
  ```

  This mesh arranges 6 TPU devices into a 1-D array. `Layout([UNSHARDED], mesh)`
  is a layout for a rank-1 tensor which is replicated on the 6 devices.

  For another example, let's consider a 2-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
       [("x", 3), ("y", 2)])
  ```

  This mesh arranges 6 TPU devices into a `3x2` 2-D array.
  `Layout(["x", UNSHARDED], mesh)` is a layout for a rank-2 tensor whose first
  axis is sharded on mesh dimension "x" and the second axis is replicated. If we
  place `np.arange(6).reshape((3, 2))` using this layout, the individual
  component tensors would look like:

  ```
  Device  |  Component
   TPU:0     [[0, 1]]
   TPU:1     [[0, 1]]
   TPU:2     [[2, 3]]
   TPU:3     [[2, 3]]
   TPU:4     [[4, 5]]
   TPU:5     [[4, 5]]
  ```
  """

  def __init__(self, sharding_specs: List[str], mesh: Mesh):
    """Builds a Layout from a list of dimension names and a Mesh.

    Args:
      sharding_specs: List of sharding specifications, each corresponding to a
        tensor axis. Each specification (dim_sharding) can either be a mesh
        dimension or the special value UNSHARDED.
      mesh: A mesh configuration for the Tensor.

    Returns:
      A valid Layout built with given layout & mesh.
    """
    # Validate mesh
    if not isinstance(mesh, Mesh):
      raise ValueError('mesh is not a valid Mesh object.')

    # Validate sharding spec
    for _, dim_sharding in enumerate(sharding_specs):
      # If special value no need to check for uniqueness, just skip.
      if dim_sharding == UNSHARDED or dim_sharding == MATCH:
        continue
      # Check dim_sharding is unique.
      if sharding_specs.count(dim_sharding) > 1:
        raise ValueError(
            ('Mesh dimension {mesh_dim} was repeated in sharding ' +
             'specification {sharding_specs}. Mesh dimensions must be unique ' +
             'in a layout.').format(
                 mesh_dim=dim_sharding, sharding_specs=sharding_specs))
      # Check dim_sharding is mesh dimension.
      if dim_sharding not in mesh:
        raise ValueError(
            ('{dim_sharding}: A dimension sharding must either be a ' +
             'valid mesh dimension or UNSHARDED.').format(
                 dim_sharding=dim_sharding))

    # Set object's state
    self.sharding_specs = sharding_specs
    self.rank = len(sharding_specs)
    self.mesh = mesh
    self.shape = [self.num_shards(i) for i in range(self.rank)]

  def __eq__(self, other) -> bool:
    return self.serialized_string() == other.serialized_string()

  def __repr__(self) -> str:
    return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})'

  def as_proto(self) -> layout_pb2.LayoutProto:
    """Create a proto representation of a layout."""
    layout_proto = layout_pb2.LayoutProto()

    for dim_sharding in self.sharding_specs:
      tensor_dim = layout_proto.sharding_specs.add()
      tensor_dim.sharding_spec = dim_sharding

    layout_proto.mesh_config.CopyFrom(self.mesh_proto())

    return layout_proto

  @staticmethod
  def batch_sharded(mesh: Mesh, batch_dim: str, rank: int) -> 'Layout':
    """Returns a layout sharded on batch dimension."""
    return Layout([batch_dim] + [UNSHARDED] * (rank - 1), mesh)

  def delete(self, dims: List[int]) -> 'Layout':
    """Returns the layout with the given dimensions deleted."""
    if not isinstance(dims, list):
      dims = [dims]
    new_specs = [
        spec for i, spec in enumerate(self.sharding_specs) if i not in dims
    ]
    return Layout(new_specs, self.mesh)

  @staticmethod
  def from_str(layout_str: bytes) -> 'Layout':
    """Creates an instance from a serialized Protobuf binary string."""
    layout_proto = layout_pb2.LayoutProto()
    layout_proto.ParseFromString(layout_str)
    sharding_specs = [
        sharding_spec.sharding_spec
        for sharding_spec in layout_proto.sharding_specs
    ]
    mesh = Mesh.from_proto(layout_proto.mesh_config)
    return Layout(sharding_specs, mesh)

  @staticmethod
  def from_string(layout_str: str) -> 'Layout':
    """Creates an instance from a human-readable string."""
    layout_parts = layout_str.split(' ')
    if len(layout_parts) != 2:
      raise ValueError(
          'layout string must contain two parts: specs and mesh. But got {}.'
          .format(layout_str))

    sharding_specs_str = layout_parts[0].replace('sharding_specs:', '')
    mesh_str = layout_parts[1].replace('mesh:', '')

    sharding_specs = sharding_specs_str.split(',')[:-1]

    mesh = Mesh.from_string(mesh_str)
    layout = Layout(sharding_specs, mesh)
    return layout

  @staticmethod
  def inner_sharded(mesh: Mesh, inner_dim: str, rank: int) -> 'Layout':
    """Returns a layout sharded on inner dimension."""
    return Layout([UNSHARDED] * (rank - 1) + [inner_dim], mesh)

  def is_fully_replicated(self) -> bool:
    """Returns True if all tensor axes are replicated."""
    return all([self.num_shards(i) == 1 for i in range(self.rank)])

  def mesh_proto(self) -> layout_pb2.MeshProto:
    """Returns the underlying mesh in Protobuf format."""
    return self.mesh.as_proto()

  def num_shards(self, idx: int) -> int:
    """Returns the number of shards for tensor dimension `idx`."""
    dim_sharding = self.sharding_specs[idx]
    if dim_sharding == UNSHARDED:
      return 1
    if dim_sharding == MATCH:
      return -1
    return self.mesh.dim_size(dim_sharding)

  def offset_to_shard(self):
    """Mapping from offset in a flattened list to shard index."""
    unravel_index = self.mesh.unravel_index()
    locations = [None] * self.mesh.size
    for offset, mesh_loc in unravel_index.items():
      loc = []
      for dim_sharding in self.sharding_specs:
        if dim_sharding == UNSHARDED:
          loc.append(0)
        else:
          loc.append(mesh_loc[dim_sharding])
      locations[offset] = tuple(loc)
    return locations

  def offset_tuple_to_global_index(self, offset_tuple):
    """Mapping from offset to index in global tensor."""
    index = 0
    for i, o in enumerate(offset_tuple):
      m = 1
      for x in range(i + 1, self.rank):
        m = m * self.num_shards(x)
      index = index + m * o
    return index

  @staticmethod
  def replicated(mesh: Mesh, rank: int) -> 'Layout':
    """Returns a replicated layout of rank `rank`."""
    return Layout([UNSHARDED] * rank, mesh)

  def serialized_string(self) -> bytes:
    """Returns a serialized Protobuf binary string representation."""
    return self.as_proto().SerializeToString()

  # A layout with no sharding specs is acceptable, therefore we only check the
  # mesh.
  def to_string(self) -> str:
    """Returns a human-readable string representation."""
    sharding_spec_str = 'sharding_specs:'
    # Add comma after each instruction.
    for spec in self.sharding_specs:
      sharding_spec_str += spec + ','

    mesh_str = 'mesh:' + self.mesh.to_string()
    return sharding_spec_str + ' ' + mesh_str

  def unravel(self, unpacked_tensors: List[np.ndarray]) -> np.ndarray:
    """Convert a flattened list of shards into a sharded array."""
    unravelled = np.ndarray([self.num_shards(i) for i in range(self.rank)],
                            dtype=object)
    for offset, loc in enumerate(self.offset_to_shard()):
      unravelled[loc] = unpacked_tensors[offset]
    return unravelled
704