# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python definitions for `Mesh` and `Layout`."""

import collections
import itertools
from typing import List, Dict, Optional

import numpy as np

from tensorflow.dtensor.proto import layout_pb2
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.util.tf_export import tf_export

# UNSHARDED indicates a tensor dimension is not sharded over any mesh
# dimension.
UNSHARDED = 'unsharded'
MATCH = 'match'

tf_export('experimental.dtensor.UNSHARDED', v1=[]).export_constant(
    __name__, 'UNSHARDED')
tf_export('experimental.dtensor.MATCH', v1=[]).export_constant(
    __name__, 'MATCH')

MeshDimension = collections.namedtuple('MeshDimension', ['name', 'size'])


def _compute_mesh_strides(mesh_dims: List[MeshDimension]) -> List[int]:
  strides = [1]
  for idx, dim in enumerate(reversed(mesh_dims[1:])):
    strides.append(strides[idx] * dim.size)
  strides.reverse()
  return strides


@tf_export('experimental.dtensor.Mesh', v1=[])
class Mesh(object):
  """Represents a Mesh configuration over a list of Mesh dimensions.

  A mesh consists of named dimensions with sizes, which describe how a set of
  devices are arranged. Defining tensor layouts in terms of mesh dimensions
  allows us to efficiently determine the communication required when computing
  an operation with tensors of different layouts.

  A mesh provides information not only about the placement of the tensors but
  also the topology of the underlying devices. For example, we can group 8 TPUs
  as a 1-D array for data parallelism or a `2x4` grid for (2-way) data
  parallelism and (4-way) model parallelism.

  Note: the utilities `dtensor.create_mesh` and
  `dtensor.create_distributed_mesh` provide a simpler API to create meshes for
  single- or multi-client use cases.
  """

  _dim_dict: Dict[str, MeshDimension]
  _dim_names: List[str]
  _local_device_ids: List[int]
  _global_device_ids: np.ndarray
  _name: str
  _local_devices: List[tf_device.DeviceSpec]
  _global_devices: Optional[List[tf_device.DeviceSpec]]
  _device_type: str

  def __init__(self,
               dim_names: List[str],
               global_device_ids: np.ndarray,
               local_device_ids: List[int],
               local_devices: List[tf_device.DeviceSpec],
               mesh_name: str = '',
               global_devices: Optional[List[tf_device.DeviceSpec]] = None):
    """Builds a Mesh.

    The `dim_names` and `global_device_ids` arguments describe the dimension
    names and shape for the mesh.

    For example,

    ```python
      dim_names = ('x', 'y'),
      global_device_ids = [[0, 1],
                           [2, 3],
                           [4, 5]]
    ```

    defines a 2D mesh of shape 3x2. A reduction over the 'x' dimension reduces
    over the devices in each column, i.e. (0, 2, 4) and (1, 3, 5), and a
    reduction over the 'y' dimension reduces over the devices in each row,
    i.e. (0, 1), (2, 3) and (4, 5).
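
    For illustration, a single-client mesh of this shape over six local CPU
    devices could be constructed roughly along these lines (the device strings
    below are hypothetical):

    ```python
      mesh = Mesh(
          dim_names=['x', 'y'],
          global_device_ids=np.array([[0, 1], [2, 3], [4, 5]]),
          local_device_ids=list(range(6)),
          local_devices=[
              tf.DeviceSpec.from_string('/job:localhost/device:CPU:%d' % i)
              for i in range(6)
          ])
    ```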

    Note: the utilities `dtensor.create_mesh` and
    `dtensor.create_distributed_mesh` provide a simpler API to create meshes
    for single- or multi-client use cases.

    Args:
      dim_names: A list of strings indicating dimension names.
      global_device_ids: An ndarray of global device IDs used to compose
        DeviceSpecs describing the mesh. The shape of this array determines the
        size of each mesh dimension. Values in this array should increment
        sequentially from 0. This argument is the same for every DTensor
        client.
      local_device_ids: A list of local device IDs equal to a subset of values
        in global_device_ids. They indicate the position of local devices in
        the global mesh. Different DTensor clients must contain distinct
        local_device_ids contents. All local_device_ids from all DTensor
        clients must cover every element in global_device_ids.
      local_devices: The list of devices hosted locally. The elements
        correspond 1:1 to those of local_device_ids.
      mesh_name: The name of the mesh. Currently, this is rarely used, and is
        mostly used to indicate whether it is a CPU, GPU, or TPU-based mesh.
      global_devices (optional): The list of global devices. Set when multiple
        device meshes are in use.
    """
    # Check if input args are valid.
    if not isinstance(global_device_ids, np.ndarray):
      raise ValueError('Variable global_device_ids must be an ndarray.')
    if global_device_ids.size == 0:
      raise ValueError('Variable global_device_ids must be non-empty.')
    flat_global_device_ids = global_device_ids.flatten()
    # global_device_ids are expected to be consecutive numbers.
    # LINT.IfChange
    distance = flat_global_device_ids[0]
    if any(
        (gid - i != distance) for i, gid in enumerate(flat_global_device_ids)):
      raise ValueError('global_device_ids must sequentially increase: %s' %
                       global_device_ids)
    # LINT.ThenChange(//tensorflow/dtensor/cc/dtensor_device.cc)

    if len(dim_names) != global_device_ids.ndim:
      raise ValueError(
          'Number of mesh dimensions does not match number of dimension names.')

    if not isinstance(local_device_ids, list):
      raise ValueError('Variable local_device_ids must be a list of integers.')

    if not isinstance(local_devices, list):
      raise ValueError('Variable local_devices must be a list of DeviceSpecs.')

    if global_devices and not isinstance(global_devices, list):
      raise ValueError('Variable global_devices must be a list of DeviceSpecs.')

    if not local_devices and not global_devices:
      raise ValueError('Empty list of devices not allowed.')

    local_devices_set = set(local_devices)
    local_device_only_contains_host_cpu = (
        len(local_devices_set) == 1 and
        list(local_devices_set)[0].device_type == 'CPU')
    if not local_device_only_contains_host_cpu and len(local_devices) != len(
        local_devices_set):
      raise ValueError('Duplicate devices found in mesh specification %s.'
                       % [d for d in local_devices
                          if local_devices.count(d) > 1])

    if len(local_device_ids) != len(local_devices):
      raise ValueError(
          'Variable local_device_ids does not have same size as local_devices.')

    if len(local_device_ids) > len(np.ravel(global_device_ids)):
      raise ValueError('Cannot have more local than global device IDs.')

    device_types = set([device.device_type for device in local_devices])
    if not device_types:
      device_types = set([device.device_type for device in global_devices])
    if None in device_types:
      raise ValueError('device_type is required')
    if len(device_types) > 1:
      raise ValueError('Devices containing multiple device_types: %s' %
                       device_types)

    # Set object's state.
    self._device_type = device_types.pop()
    self._dim_names = dim_names
    self._dim_dict = {
        dim_name: MeshDimension(dim_name, global_device_ids.shape[i])
        for i, dim_name in enumerate(dim_names)
    }
    self._global_device_ids = global_device_ids
    self._local_device_ids = local_device_ids
    self._local_devices = local_devices
    self._global_devices = global_devices
    self._name = mesh_name
    self._strides = _compute_mesh_strides(
        [self._dim_dict[dim] for dim in self._dim_names])

  def __contains__(self, dim_name: str) -> bool:
    return self.contains_dim(dim_name)

  def __eq__(self, other):
    if not isinstance(other, type(self)) and not isinstance(self, type(other)):
      raise ValueError('Comparing with type: {0}, but expecting: {1}.'.format(
          type(other), type(self)))
    return self.as_proto().SerializeToString() == other.as_proto(
    ).SerializeToString()

  def __getitem__(self, dim_name: str) -> MeshDimension:
    if dim_name not in self._dim_dict:
      raise KeyError(
          f'Dimension {dim_name} not defined in mesh: {self._dim_dict.keys()}')
    return self._dim_dict[dim_name]
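
  # Illustrative (hypothetical) lookups on the 3x2 mesh from the `__init__`
  # docstring above, with dims 'x' and 'y':
  #   mesh['x']    -> MeshDimension(name='x', size=3)
  #   'y' in mesh  -> True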

  # TODO(b/168730933): Define a nicer mesh ID.
  def __hash__(self):
    return hash(self.as_proto().SerializeToString(deterministic=True))

  def __repr__(self) -> str:
    dims = [tuple(self[dim_name]) for dim_name in self.dim_names]
    return (
        f'<Mesh object with dims={dims}, device_type="{self.device_type()}", '
        f'num_local_devices={self.num_local_devices()}, '
        f'size={self.size}>')

  def as_proto(self) -> layout_pb2.MeshProto:
    """Returns the mesh protobuf message."""

    mesh_proto = layout_pb2.MeshProto()

    mesh_proto.name = self._name

    for i, mesh_dimension in enumerate(self._dim_names):
      dim = mesh_proto.mesh_dimensions.add()
      dim.name = mesh_dimension
      dim.size = self._global_device_ids.shape[i]

    for d in np.ravel(self._global_device_ids):
      mesh_proto.global_device_ids.append(d)

    for d in self._local_device_ids:
      mesh_proto.local_device_ids.append(d)

    for d in self._local_devices:
      mesh_proto.local_devices.append(d.to_string())

    if self._global_devices:
      for d in self._global_devices:
        mesh_proto.global_devices.append(d.to_string())

    return mesh_proto

  def contains_dim(self, dim_name: str) -> bool:
    """Returns True if a Mesh contains the given dimension name."""
    return dim_name in self._dim_dict

  def coords(self, device_idx: int) -> ops.Tensor:
    """Converts the device index into a tensor of mesh coordinates."""
    strides = ops.convert_to_tensor(self.strides)
    shape = ops.convert_to_tensor(self.shape())
    return (device_idx // strides) % shape

  def device_type(self) -> str:
    """Returns the device_type of a Mesh."""
    return self._device_type

  @property
  def dim_names(self) -> List[str]:
    return self._dim_names

  def dim_size(self, dim_name: str) -> int:
    """Returns the size of a dimension."""
    if dim_name not in self._dim_dict.keys():
      raise ValueError(('"{dim_name}" not a dimension name in current mesh. ' +
                        'Dimension names: {dim_names}.').format(
                            dim_name=dim_name,
                            dim_names=list(self._dim_dict.keys())))
    return self._dim_dict[dim_name].size

  @staticmethod
  def from_proto(proto: layout_pb2.MeshProto) -> 'Mesh':
    """Constructs a mesh instance from input `proto`."""
    shape = [dim.size for dim in proto.mesh_dimensions]

    # Convert the flat global_device_ids list back into array form.
    global_device_ids = [int(d) for d in proto.global_device_ids]
    global_device_ids = np.asarray(global_device_ids).reshape(shape)

    # Construct the local_device_ids list.
    local_device_ids = [int(d) for d in proto.local_device_ids]

    # Convert the local devices list back to DeviceSpec form.
    local_devices = [
        tf_device.DeviceSpec.from_string(d) for d in proto.local_devices
    ]

    # Convert the global devices list back to DeviceSpec form.
    global_devices = [
        tf_device.DeviceSpec.from_string(d) for d in proto.global_devices
    ]

    name = proto.name
    dims = [dim.name for dim in proto.mesh_dimensions]
    return Mesh(dims, global_device_ids, local_device_ids, local_devices, name,
                global_devices)

  @staticmethod
  def from_string(mesh_str: str) -> 'Mesh':
    """Constructs a mesh instance from its string representation `mesh_str`."""
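    # An illustrative sketch of the expected format (the inverse of
    # `Mesh.to_string`); the device strings below are hypothetical:
    #   'name|x=2,y=2|0,1,2,3|0,1,2,3|/job:localhost/task:0/device:CPU:0,...'
    # i.e. '<name>|<dim>=<size>,...|<global ids>|<local ids>|<local devices>'
    # with an optional trailing '|<global devices>' component.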
    # Separate elements of mesh.
    mesh_parts = mesh_str.split('|')
    global_dev_str = None
    if len(mesh_parts) == 5:
      name, mesh_dim_strs, global_id_str, local_id_str, dev_str = mesh_parts
    elif len(mesh_parts) == 6:
      (name, mesh_dim_strs, global_id_str, local_id_str, dev_str,
       global_dev_str) = mesh_parts
    else:
      raise ValueError('Invalid mesh string: %s' % mesh_str)

    # Load mesh proto.
    mesh_proto = layout_pb2.MeshProto()
    mesh_proto.name = name

    for mesh_dim_str in mesh_dim_strs.split(','):
      name, size_str = mesh_dim_str.split('=')
      dim = mesh_proto.mesh_dimensions.add()
      dim.name = name
      dim.size = int(size_str)

    for global_id in global_id_str.split(','):
      mesh_proto.global_device_ids.append(int(global_id))

    if local_id_str:
      for local_id in local_id_str.split(','):
        mesh_proto.local_device_ids.append(int(local_id))

    if dev_str:
      for dev in dev_str.split(','):
        mesh_proto.local_devices.append(dev)

    if global_dev_str:
      for dev in global_dev_str.split(','):
        mesh_proto.global_devices.append(dev)

    return Mesh.from_proto(mesh_proto)

  def host_mesh(self):
    """Returns the 1:1 mapped host CPU mesh."""
    if self.device_type().upper() == 'CPU':
      return self

    v_cpus_counts = len(tf_config.list_logical_devices('CPU'))
    if v_cpus_counts < len(self._local_devices):
      raise ValueError('Must have at least {0} virtual CPUs for mesh: {1}, '
                       'but got: {2} virtual CPUs.'.format(
                           len(self._local_devices), self.to_string(),
                           v_cpus_counts))
    device_array = np.asarray([
        spec.replace(device_type='CPU') for spec in self._local_devices
    ]).reshape((len(self._local_devices), 1))
    global_devices = None
    if self._global_devices:
      global_devices = [
          spec.replace(device_type='CPU') for spec in self._global_devices
      ]
    h_mesh = Mesh(
        self._dim_names,
        self._global_device_ids,
        self.local_device_ids(),
        np.ravel(device_array).tolist(),
        global_devices=global_devices)
    return h_mesh

  def is_remote(self) -> bool:
    """Returns True if a Mesh contains only remote devices."""
    return not self._local_device_ids and self._global_device_ids.size > 0

  def local_device_ids(self) -> List[int]:
    """Returns a list of local device IDs."""
    return self._local_device_ids

  def local_device_locations(self) -> List[Dict[str, int]]:
    """Returns a list of local device locations.

    A device location is a dictionary from dimension names to indices on those
    dimensions.
    """
    mapping = self.unravel_index()
    return [mapping[device_id] for device_id in self.local_device_ids()]

  def local_devices(self) -> List[str]:
    """Returns a list of local device specs represented as strings."""
    return [d.to_string() for d in self._local_devices]

  def min_global_device_id(self) -> int:
    """Returns the minimum global device ID."""
    # global_device_ids increase sequentially, so the first entry is the
    # minimum.
    return self._global_device_ids.flatten()[0]

  @property
  def name(self) -> str:
    return self._name

  def num_local_devices(self) -> int:
    """Returns the number of local devices."""
    return len(self._local_devices)

  def shape(self) -> List[int]:
    """Returns the shape of the mesh."""
    return [self.dim_size(dim) for dim in self._dim_names]

  @property
  def size(self) -> int:
    return len(np.ravel(self._global_device_ids))

  @property
  def strides(self) -> List[int]:
    """Returns the strides array for this mesh.

    If the mesh shape is `[a, b, c, d]`, then the strides array can be computed
    as `[b*c*d, c*d, d, 1]`. This array can be useful in computing local device
    offsets given a device ID. Using the same example, the device coordinates
    of the mesh can be computed as:

    ```
    [(device_id / (b*c*d)) % a,
     (device_id / (c*d)) % b,
     (device_id / (d)) % c,
     (device_id) % d]
    ```

    This is the same as `(device_id // mesh.strides) % mesh.shape`.

    Returns:
      The mesh strides as a list of integers, one per mesh dimension.
    """
    return self._strides

  def to_string(self) -> str:
    """Returns the string representation of the Mesh."""

    # Get proto representation.
    mesh_proto = self.as_proto()
    # Separate individual elements with ','.
    name = mesh_proto.name
    dim_str = ','.join(
        dim.name + '=' + str(dim.size) for dim in mesh_proto.mesh_dimensions)
    global_ids = ','.join(str(id) for id in mesh_proto.global_device_ids)
    local_ids = ','.join(str(id) for id in mesh_proto.local_device_ids)
    devices = ','.join(dev for dev in mesh_proto.local_devices)
    components = [name, dim_str, global_ids, local_ids, devices]
    if mesh_proto.global_devices:
      global_devices = ','.join(dev for dev in mesh_proto.global_devices)
      components.append(global_devices)
    # Separate mesh components with '|'.
    return '|'.join(components)

  def unravel_index(self):
    """Returns a dictionary from device ID to {dim_name: dim_index}.

    For example, for a 3x2 mesh, return this:

    ```
    { 0: {'x': 0, 'y': 0},
      1: {'x': 0, 'y': 1},
      2: {'x': 1, 'y': 0},
      3: {'x': 1, 'y': 1},
      4: {'x': 2, 'y': 0},
      5: {'x': 2, 'y': 1} }
    ```
    """
    idx_ranges = [
        range(self.dim_size(dim_name)) for dim_name in self._dim_names
    ]
    mesh_pos = itertools.product(*idx_ranges)
    mapping = {}
    for device_id, device_pos in enumerate(mesh_pos):
      device_loc = {}
      for dim_name, dim_index in zip(self._dim_names, device_pos):
        device_loc[dim_name] = dim_index
      mapping[device_id] = device_loc
    return mapping


# TODO(hthu): Consider making this class immutable.
@tf_export('experimental.dtensor.Layout', v1=[])
class Layout(object):
  """Represents the layout information of a DTensor.

  A layout describes how a distributed tensor is partitioned across a mesh
  (and thus across devices). For each axis of the tensor, the corresponding
  sharding spec indicates which dimension of the mesh it is sharded over. The
  special sharding spec `UNSHARDED` indicates that the axis is replicated on
  all the devices of that mesh.

  For example, let's consider a 1-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"], [("x", 6)])
  ```

  This mesh arranges 6 TPU devices into a 1-D array. `Layout([UNSHARDED], mesh)`
  is a layout for a rank-1 tensor that is replicated on the 6 devices.
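
  By contrast, `Layout(["x"], mesh)` is a layout for a rank-1 tensor that is
  sharded over the "x" dimension; roughly sketched, placing `np.arange(6)`
  with that layout would leave one component on each device:

  ```
  Device | Component
   TPU:0    [0]
   TPU:1    [1]
   TPU:2    [2]
   TPU:3    [3]
   TPU:4    [4]
   TPU:5    [5]
  ```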

  For another example, let's consider a 2-D mesh:

  ```
  Mesh(["TPU:0", "TPU:1", "TPU:2", "TPU:3", "TPU:4", "TPU:5"],
       [("x", 3), ("y", 2)])
  ```

  This mesh arranges 6 TPU devices into a `3x2` 2-D array.
  `Layout(["x", UNSHARDED], mesh)` is a layout for a rank-2 tensor whose first
  axis is sharded on mesh dimension "x" and whose second axis is replicated.
  If we place `np.arange(6).reshape((3, 2))` using this layout, the individual
  component tensors would look like:

  ```
  Device | Component
   TPU:0    [[0, 1]]
   TPU:1    [[0, 1]]
   TPU:2    [[2, 3]]
   TPU:3    [[2, 3]]
   TPU:4    [[4, 5]]
   TPU:5    [[4, 5]]
  ```
  """

  def __init__(self, sharding_specs: List[str], mesh: Mesh):
    """Builds a Layout from a list of dimension names and a Mesh.

    Args:
      sharding_specs: List of sharding specifications, each corresponding to a
        tensor axis. Each specification (dim_sharding) can either be a mesh
        dimension or the special value UNSHARDED.
      mesh: A mesh configuration for the Tensor.

    Returns:
      A valid Layout built with the given sharding specs and mesh.
    """
    # Validate mesh.
    if not isinstance(mesh, Mesh):
      raise ValueError('mesh is not a valid Mesh object.')

    # Validate sharding specs.
    for dim_sharding in sharding_specs:
      # The special values UNSHARDED and MATCH may repeat, so skip the
      # uniqueness check for them.
      if dim_sharding == UNSHARDED or dim_sharding == MATCH:
        continue
      # Check that dim_sharding is unique.
      if sharding_specs.count(dim_sharding) > 1:
        raise ValueError(
            ('Mesh dimension {mesh_dim} was repeated in sharding ' +
             'specification {sharding_specs}. Mesh dimensions must be unique '
             + 'in a layout.').format(
                 mesh_dim=dim_sharding, sharding_specs=sharding_specs))
      # Check that dim_sharding is a mesh dimension.
      if dim_sharding not in mesh:
        raise ValueError(
            ('{dim_sharding}: A dimension sharding must either be a ' +
             'valid mesh dimension or UNSHARDED.').format(
                 dim_sharding=dim_sharding))

    # Set object's state.
    self.sharding_specs = sharding_specs
    self.rank = len(sharding_specs)
    self.mesh = mesh
    self.shape = [self.num_shards(i) for i in range(self.rank)]

  def __eq__(self, other) -> bool:
    return self.serialized_string() == other.serialized_string()

  def __repr__(self) -> str:
    return f'Layout(sharding_specs={self.sharding_specs}, mesh={self.mesh})'

  def as_proto(self) -> layout_pb2.LayoutProto:
    """Creates a proto representation of the layout."""
    layout_proto = layout_pb2.LayoutProto()

    for dim_sharding in self.sharding_specs:
      tensor_dim = layout_proto.sharding_specs.add()
      tensor_dim.sharding_spec = dim_sharding

    layout_proto.mesh_config.CopyFrom(self.mesh_proto())

    return layout_proto

  @staticmethod
  def batch_sharded(mesh: Mesh, batch_dim: str, rank: int) -> 'Layout':
    """Returns a layout sharded on the batch dimension."""
    return Layout([batch_dim] + [UNSHARDED] * (rank - 1), mesh)

  def delete(self, dims: List[int]) -> 'Layout':
    """Returns the layout with the given dimensions deleted."""
    if not isinstance(dims, list):
      dims = [dims]
    new_specs = [
        spec for i, spec in enumerate(self.sharding_specs) if i not in dims
    ]
    return Layout(new_specs, self.mesh)

  @staticmethod
  def from_str(layout_str: bytes) -> 'Layout':
    """Creates an instance from a serialized Protobuf binary string."""
    layout_proto = layout_pb2.LayoutProto()
    layout_proto.ParseFromString(layout_str)
    sharding_specs = [
        sharding_spec.sharding_spec
        for sharding_spec in layout_proto.sharding_specs
    ]
    mesh = Mesh.from_proto(layout_proto.mesh_config)
    return Layout(sharding_specs, mesh)

  @staticmethod
  def from_string(layout_str: str) -> 'Layout':
    """Creates an instance from a human-readable string."""
    layout_parts = layout_str.split(' ')
    if len(layout_parts) != 2:
      raise ValueError(
          'Layout string must contain two parts: specs and mesh. But got {}.'
          .format(layout_str))

    sharding_specs_str = layout_parts[0].replace('sharding_specs:', '')
    mesh_str = layout_parts[1].replace('mesh:', '')

    # `to_string` appends a trailing ',' after every sharding spec, so the
    # last (empty) element of the split is dropped.
    sharding_specs = sharding_specs_str.split(',')[:-1]

    mesh = Mesh.from_string(mesh_str)
    layout = Layout(sharding_specs, mesh)
    return layout

  @staticmethod
  def inner_sharded(mesh: Mesh, inner_dim: str, rank: int) -> 'Layout':
    """Returns a layout sharded on the innermost dimension."""
    return Layout([UNSHARDED] * (rank - 1) + [inner_dim], mesh)

  def is_fully_replicated(self) -> bool:
    """Returns True if all tensor axes are replicated."""
    return all(self.num_shards(i) == 1 for i in range(self.rank))

  def mesh_proto(self) -> layout_pb2.MeshProto:
    """Returns the underlying mesh in Protobuf format."""
    return self.mesh.as_proto()

  def num_shards(self, idx: int) -> int:
    """Returns the number of shards for tensor dimension `idx`."""
    dim_sharding = self.sharding_specs[idx]
    if dim_sharding == UNSHARDED:
      return 1
    if dim_sharding == MATCH:
      return -1
    return self.mesh.dim_size(dim_sharding)

  def offset_to_shard(self):
    """Returns the mapping from flat device offsets to shard indices."""
    unravel_index = self.mesh.unravel_index()
    locations = [None] * self.mesh.size
    for offset, mesh_loc in unravel_index.items():
      loc = []
      for dim_sharding in self.sharding_specs:
        if dim_sharding == UNSHARDED:
          loc.append(0)
        else:
          loc.append(mesh_loc[dim_sharding])
      locations[offset] = tuple(loc)
    return locations

  def offset_tuple_to_global_index(self, offset_tuple):
    """Maps a shard offset tuple to its index in the global tensor."""
    index = 0
    for i, o in enumerate(offset_tuple):
      m = 1
      for x in range(i + 1, self.rank):
        m = m * self.num_shards(x)
      index = index + m * o
    return index

  @staticmethod
  def replicated(mesh: Mesh, rank: int) -> 'Layout':
    """Returns a replicated layout of rank `rank`."""
    return Layout([UNSHARDED] * rank, mesh)

  def serialized_string(self) -> bytes:
    """Returns a serialized Protobuf binary string representation."""
    return self.as_proto().SerializeToString()

  # A layout with no sharding specs is acceptable, therefore we only check the
  # mesh.
  def to_string(self) -> str:
    """Returns a human-readable string representation."""
    sharding_spec_str = 'sharding_specs:'
    # Add a comma after each sharding spec.
    for spec in self.sharding_specs:
      sharding_spec_str += spec + ','

    mesh_str = 'mesh:' + self.mesh.to_string()
    return sharding_spec_str + ' ' + mesh_str

  def unravel(self, unpacked_tensors: List[np.ndarray]) -> np.ndarray:
    """Converts a flattened list of shards into a sharded array."""
    unravelled = np.ndarray([self.num_shards(i) for i in range(self.rank)],
                            dtype=object)
    for offset, loc in enumerate(self.offset_to_shard()):
      unravelled[loc] = unpacked_tensors[offset]
    return unravelled
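

# Illustrative shorthands for common layouts (a rough sketch; `mesh` is
# assumed to be a 2-D mesh with dimensions 'batch' and 'model'):
#   Layout.replicated(mesh, rank=2)              # [UNSHARDED, UNSHARDED]
#   Layout.batch_sharded(mesh, 'batch', rank=2)  # ['batch', UNSHARDED]
#   Layout.inner_sharded(mesh, 'model', rank=2)  # [UNSHARDED, 'model']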