1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3# Copyright (c) Meta Platforms, Inc. and affiliates 4import inspect 5import warnings 6from typing import Any, Callable, cast, Optional, Sequence, Tuple 7 8import torch 9import torch.distributed.tensor._dispatch as op_dispatch 10import torch.distributed.tensor._random as random 11import torch.nn as nn 12from torch.distributed.device_mesh import _mesh_resources, DeviceMesh 13from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast 14from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta 15from torch.distributed.tensor._random import ( 16 is_rng_supported_mesh, 17 OffsetBasedRNGTracker, 18) 19from torch.distributed.tensor._redistribute import ( 20 Redistribute, 21 redistribute_local_tensor, 22) 23from torch.distributed.tensor._utils import ( 24 compute_global_tensor_info, 25 compute_local_shape, 26 normalize_to_torch_size, 27) 28from torch.distributed.tensor.placement_types import ( 29 Partial, 30 Placement, 31 Replicate, 32 Shard, 33) 34 35 36__all__ = [ 37 "DTensor", 38 "distribute_tensor", 39 "distribute_module", 40 "ones", 41 "empty", 42 "full", 43 "rand", 44 "randn", 45 "zeros", 46] 47 48aten = torch.ops.aten 49 50 51# NOTE [Autograd interaction between torch.Tensor] 52# 53# The autograd functions defined below are being used by the public 54# facing APIs (i.e. from_local, to_local) to ensure DTensor to work 55# together with torch.Tensor within the autograd engine. This 56# allows DTensor to only exist on part of the module hierarchy. 57# 58# As an example, we have the a module that consists of submodules 59# A, B, and C, the execution flow would be like: 60# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) 61# 62# Suppose I only want to make Module B be a sharded module with 63# DTensor params, the following forward/backward should work: 64# 65# input(torch.Tensor) -> Module A 66# -> DTensor input (from_local) -> Sharded Module B -> DTensor output 67# -> torch.Tensor output (to_local) -> Module C 68# 69# So from_local/to_local must be Autograd functions. 70# 71class _ToTorchTensor(torch.autograd.Function): 72 @staticmethod 73 def forward( # type: ignore[override] 74 ctx, 75 input: "DTensor", 76 grad_placements: Optional[Sequence[Placement]], 77 ): 78 ctx.dtensor_spec = input._spec 79 ctx.grad_placements = grad_placements 80 local_tensor = input._local_tensor 81 82 # We need to return a fresh Tensor object there as autograd metadata 83 # will be inplaced into it. So we don't want to pollute the Tensor 84 # object stored in the _local_tensor of this DTensor. 85 return local_tensor.view_as(local_tensor) 86 87 @staticmethod 88 def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] 89 dtensor_spec = ctx.dtensor_spec 90 mesh = dtensor_spec.mesh 91 grad_placements = ctx.grad_placements 92 dtensor_meta = dtensor_spec.tensor_meta 93 94 _, tensor_stride = compute_global_tensor_info( 95 grad_output, mesh, dtensor_spec.placements 96 ) 97 tensor_stride = tuple(tensor_stride) 98 grad_placements = grad_placements or dtensor_spec.placements 99 grad_spec = DTensorSpec( 100 mesh, 101 grad_placements, 102 tensor_meta=TensorMeta( 103 shape=dtensor_meta.shape, 104 stride=tensor_stride, 105 dtype=dtensor_meta.dtype, 106 ), 107 ) 108 109 return ( 110 DTensor( 111 grad_output, 112 grad_spec, 113 requires_grad=grad_output.requires_grad, 114 ), 115 None, 116 ) 117 118 119class _FromTorchTensor(torch.autograd.Function): 120 @staticmethod 121 def forward( # type: ignore[override] 122 ctx, # pyre-ignore[2]: Parameter must be annotated. 123 input: torch.Tensor, 124 device_mesh: DeviceMesh, 125 placements: Tuple[Placement, ...], 126 run_check: bool, 127 shape: Optional[torch.Size] = None, 128 stride: Optional[Tuple[int, ...]] = None, 129 ) -> "DTensor": 130 ctx.previous_placement = placements 131 ctx.previous_device_mesh = device_mesh 132 133 if shape and stride: 134 tensor_shape, tensor_stride = shape, stride 135 elif not shape and not stride: 136 # if it's not by default run_check, we assume user is certain that each 137 # rank has the same tensor shape, and we just use that to calculate the 138 # global shape 139 global_shape, global_stride = compute_global_tensor_info( 140 input, device_mesh, placements 141 ) 142 tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) 143 else: 144 raise RuntimeError( 145 f"Found shape:{shape}, stride:{stride}.", 146 "Please pass both shape and stride at the same time.", 147 ) 148 149 if device_mesh.get_coordinate() is None: 150 # if the global rank is not participating in the device mesh, we 151 # simply set the local tensor to an empty tensor 152 input = input.new_empty(0, requires_grad=input.requires_grad) 153 elif run_check: 154 # TODO: support uneven sharding when global shape/stride not passed, by 155 # building the global TensorMeta during check_tensor_meta 156 check_shape_stride = not shape and not stride 157 check_tensor_meta(input, check_shape_stride=check_shape_stride) 158 # TODO: See if we need to make this run_check logic 159 # have a corresponding backward. 160 for idx, placement in enumerate(placements): 161 if placement.is_replicate(): 162 # broadcast rank 0 tensor to all ranks 163 # only broadcast if run_check is True 164 input = input.contiguous() 165 mesh_broadcast(input, device_mesh, mesh_dim=idx) 166 167 dist_spec = DTensorSpec( 168 device_mesh, 169 placements, 170 tensor_meta=TensorMeta( 171 tensor_shape, 172 tensor_stride, 173 input.dtype, 174 ), 175 ) 176 177 # We want a fresh Tensor object that shares memory with the input tensor 178 dist_tensor = DTensor( 179 input.view_as(input), 180 dist_spec, 181 # requires_grad of the dist tensor depends on if input 182 # requires_grad or not 183 requires_grad=input.requires_grad, 184 ) 185 return dist_tensor 186 187 @staticmethod 188 def backward(ctx, grad_output: "DTensor"): # type: ignore[override] 189 previous_placement = ctx.previous_placement 190 previous_device_mesh = ctx.previous_device_mesh 191 192 # reshard to the placement when creating DistributedTensor 193 # so that the gradient layout matches, and we could return 194 # local gradients directly 195 if grad_output.placements != previous_placement: 196 current_spec = grad_output._spec 197 target_spec = DTensorSpec( 198 previous_device_mesh, 199 previous_placement, 200 tensor_meta=grad_output._spec.tensor_meta, 201 ) 202 local_tensor = grad_output._local_tensor 203 output = redistribute_local_tensor( 204 local_tensor, current_spec, target_spec, is_backward=True 205 ) 206 # TODO: return the redistributed local tensor directly without 207 # differentiable backward. see if this make sense for all cases. 208 return output, None, None, None, None, None 209 210 # TODO: backward is also differentiable now, add a test 211 # to test higher level gradients. 212 return grad_output.to_local(), None, None, None, None, None 213 214 215class DTensor(torch.Tensor): 216 """ 217 ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like 218 abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding 219 layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: 220 221 * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension 222 * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension 223 * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension 224 225 When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue 226 communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the 227 placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. 228 229 To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` 230 requires every Tensor argument of the operator be DTensor. 231 232 """ 233 234 _local_tensor: torch.Tensor 235 _spec: DTensorSpec 236 __slots__ = ["_local_tensor", "_spec"] 237 238 # _op_dispatcher instance as a class attribute to handle runtime dispatching logic 239 _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() 240 241 @staticmethod 242 @torch._disable_dynamo 243 def __new__( 244 cls, 245 local_tensor: torch.Tensor, 246 spec: DTensorSpec, 247 *, 248 requires_grad: bool, 249 ) -> "DTensor": 250 """ 251 Construct a DTensor from a local tensor, device mesh, and placement and 252 other tensor properties (i.e. shape, requires_grad, strides, etc). 253 254 .. note:: This is not a public API and it's only supposed to be used by the 255 operator implementations and internals. If you want to construct a 256 DTensor from a local tensor, consider using ``DTensor.from_local``, if 257 you want to construct a DTensor from a "global" tensor (where you 258 already have tensor initialized and want to shard this tensor), 259 consider using ``distribute_tensor``. 260 """ 261 if local_tensor.requires_grad and not requires_grad: 262 warnings.warn( 263 "To construct DTensor from torch.Tensor, it's recommended to " 264 "use local_tensor.detach() and make requires_grad consistent." 265 ) 266 267 # new method instruct wrapper tensor from local_tensor and add 268 # placement spec, it does not do actual distribution 269 assert spec.tensor_meta is not None, "TensorMeta should not be None!" 270 r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 271 cls, 272 spec.tensor_meta.shape, 273 strides=spec.tensor_meta.stride, 274 dtype=local_tensor.dtype, 275 device=local_tensor.device, 276 layout=local_tensor.layout, 277 requires_grad=requires_grad, 278 ) 279 280 r._spec = spec 281 r._local_tensor = local_tensor 282 return r 283 284 # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. 285 # pyre-fixme[3]: Return type must be annotated. 286 def __repr__(self): 287 # TODO: consider all_gather the local tensors for better debugging 288 return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" 289 290 def __tensor_flatten__(self): 291 """ 292 protocol to inform how to flatten a DTensor to local tensor 293 for PT2 tracing 294 """ 295 return ["_local_tensor"], (self._spec, self.requires_grad) 296 297 @staticmethod 298 def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): 299 assert ( 300 flatten_spec is not None 301 ), "Expecting spec to be not None from `__tensor_flatten__` return value!" 302 local_tensor = inner_tensors["_local_tensor"] 303 spec, requires_grad = flatten_spec 304 unflatten_tensor_meta = TensorMeta( 305 shape=outer_size, 306 stride=outer_stride, 307 dtype=spec.tensor_meta.dtype, 308 ) 309 unflatten_spec = DTensorSpec( 310 spec.mesh, 311 spec.placements, 312 tensor_meta=unflatten_tensor_meta, 313 ) 314 return DTensor( 315 local_tensor, 316 unflatten_spec, 317 requires_grad=requires_grad, 318 ) 319 320 def __coerce_tangent_metadata__(self): 321 if not any(isinstance(p, Partial) for p in self.placements): 322 return self 323 placements = [ 324 Replicate() if isinstance(p, Partial) else p for p in self.placements 325 ] 326 return self.redistribute(device_mesh=self.device_mesh, placements=placements) 327 328 def __coerce_same_metadata_as_tangent__(self, flatten_spec): 329 (spec, _) = flatten_spec # Result of tensor_flatten() 330 return self.redistribute( 331 device_mesh=self.device_mesh, 332 placements=spec.placements, 333 ) 334 335 @classmethod 336 @torch._disable_dynamo 337 # pyre-fixme[3]: Return type must be annotated. 338 # pyre-fixme[2]: Parameter must be annotated. 339 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 340 return DTensor._op_dispatcher.dispatch( 341 func, 342 args, 343 kwargs or {}, 344 ) 345 346 @staticmethod 347 def from_local( 348 local_tensor: torch.Tensor, 349 device_mesh: Optional[DeviceMesh] = None, 350 placements: Optional[Sequence[Placement]] = None, 351 *, 352 run_check: bool = False, 353 shape: Optional[torch.Size] = None, 354 stride: Optional[Tuple[int, ...]] = None, 355 ) -> "DTensor": 356 """ 357 Create a :class:`DTensor` from a local torch.Tensor on each rank 358 according to the ``device_mesh`` and ``placements`` specified. 359 360 Args: 361 local_tensor (torch.Tensor): local torch.Tensor on each rank. 362 device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the 363 tensor, if not specified, must be called under a DeviceMesh 364 context manager, default: None 365 placements (List[:class:`Placement`], optional): the placements that 366 describes how to place the local torch.Tensor on DeviceMesh, must 367 have the same number of elements as ``device_mesh.ndim``. 368 369 Keyword args: 370 run_check (bool, optional): at a cost of extra communications, perform 371 sanity check across ranks to check each local tensor's meta information 372 to ensure correctness. If have :class:`Replicate` in ``placements``, the 373 data on first rank of the device mesh dimension will be broadcasted 374 to other ranks. default: False 375 shape (torch.Size, optional): A List of int which specifies the size of 376 DTensor which build on top of `local_tensor`. Note this needs to be 377 provided if the shape of ``local_tensor`` are different across the ranks. 378 If not provided, ``shape`` will be computed assuming the given distributed 379 tensor is evenly sharded across ranks. default: None 380 stride (tuple, optional): A List of int which specifies the stride of DTensor. 381 If not provided, ``stride`` will be computed assuming the given distributed 382 tensor is evenly sharded across ranks. default: None 383 384 Returns: 385 A :class:`DTensor` object 386 387 .. note:: When ``run_check=False``, it is the user's responsibility to ensure the 388 local tensor passed in is correct across ranks (i.e. the tensor is sharded for 389 the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). 390 If not, the behavior of the created DTensor is undefined. 391 392 .. note:: ``from_local`` is differentiable, the `requires_grad` of the created 393 `DTensor` object will depend on if `local_tensor` requires_grad or not. 394 """ 395 # if same shape/dtype, no need to run_check, if not, must allgather 396 # the metadatas to check the size/dtype across ranks 397 # There should be no data communication unless there's replication 398 # strategy, where we broadcast the replication from the first rank 399 # in the mesh dimension 400 device_mesh = device_mesh or _mesh_resources.get_current_mesh() 401 device_type = device_mesh.device_type 402 403 # convert the local tensor to desired device base on device mesh's device_type 404 if device_type != local_tensor.device.type and not local_tensor.is_meta: 405 local_tensor = local_tensor.to(device_type) 406 407 # set default placements to replicated if not specified 408 if placements is None: 409 placements = [Replicate() for _ in range(device_mesh.ndim)] 410 else: 411 placements = list(placements) 412 for idx, placement in enumerate(placements): 413 # normalize shard dim to be positive 414 if placement.is_shard(): 415 placement = cast(Shard, placement) 416 if placement.dim < 0: 417 placements[idx] = Shard(placement.dim + local_tensor.ndim) 418 419 # `from_local` is differentiable, and the gradient of the dist tensor this function 420 # created should flow back the gradients to the local_tensor, so we call an autograd 421 # function to construct the dist tensor instead. 422 return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func 423 local_tensor, 424 device_mesh, 425 tuple(placements), 426 run_check, 427 shape, 428 stride, 429 ) 430 431 def to_local( 432 self, *, grad_placements: Optional[Sequence[Placement]] = None 433 ) -> torch.Tensor: 434 """ 435 Get the local tensor of this DTensor on its current rank. For sharding it returns 436 a local shard of the logical tensor view, for replication it returns the replica on 437 its current rank. 438 439 Keyword args: 440 grad_placements (List[:class:`Placement`], optional): the placements describes 441 the future layout of any gradient layout of the Tensor returned from this 442 function. 443 `to_local` converts DTensor to local tensor and the returned local tensor 444 might not be used as the original DTensor layout later in the code. This 445 argument is the hint that user can give to autograd in case the gradient 446 layout of the returned tensor does not match the original DTensor layout. 447 If not specified, we will assume the gradient layout remains the same 448 as the original DTensor and use that for gradient computation. 449 450 Returns: 451 A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the 452 local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, 453 it means the local tensor is not ready yet (i.e. communication is not finished). In this 454 case, user needs to call ``wait`` to wait the local tensor to be ready. 455 456 .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned 457 will depend on if the `DTensor` requires_grad or not. 458 """ 459 if not torch.is_grad_enabled(): 460 return self._local_tensor 461 462 if grad_placements is not None and not isinstance(grad_placements, tuple): 463 grad_placements = tuple(grad_placements) 464 return _ToTorchTensor.apply( 465 self, grad_placements 466 ) # pyre-ignore[16]: autograd func 467 468 def redistribute( 469 self, 470 device_mesh: Optional[DeviceMesh] = None, 471 placements: Optional[Sequence[Placement]] = None, 472 *, 473 async_op: bool = False, 474 ) -> "DTensor": 475 """ 476 ``redistribute`` performs necessary collective operations that redistribute the current 477 DTensor from its current placements to a new placements, or from is current DeviceMesh 478 to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by 479 specifying a Replicate placement for each dimension of the DeviceMesh. 480 481 When redistributing from current to the new placements on one device mesh dimension, we 482 will perform the following operations including communication collective or local operation: 483 484 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` 485 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` 486 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) 487 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` 488 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` 489 490 491 ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors 492 that are created either on 1-D or N-D DeviceMesh. 493 494 Args: 495 device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the 496 DTensor. If not specified, it would use the current DTensor's DeviceMesh. 497 default: None 498 placements (List[:class:`Placement`], optional): the new placements that 499 describes how to place the DTensor into the DeviceMesh, must 500 have the same number of elements as ``device_mesh.ndim``. 501 default: replicate on all mesh dimensions 502 503 Keyword args: 504 async_op (bool, optional): whether to perform the DTensor redistribute operation 505 asynchronously or not. Default: False 506 507 Returns: 508 A :class:`DTensor` object 509 510 .. note:: ``redistribute`` is differentiable, which means user do not need to worry about 511 the backward formula of the redistribute operation. 512 513 .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, 514 Please file an issue if you need to redistribute DTensor to different DeviceMesh. 515 """ 516 # NOTE: This redistribute API currently only supports out 517 # of place redistribution, i.e. it always create a new 518 # DTensor object and leave the original one unchanged. 519 520 # if device_mesh is not specified, use the current device_mesh 521 device_mesh = device_mesh or self.device_mesh 522 # raise error if new placements not specified 523 if placements is None: 524 raise RuntimeError("placements is needed for redistribute!") 525 526 placements = list(placements) 527 for i, placement in enumerate(placements): 528 if placement.is_partial(): 529 raise RuntimeError( 530 "Can not redistribute to Partial, redistributing to Partial is for internal use only!" 531 ) 532 elif isinstance(placement, Shard) and placement.dim < 0: 533 # normalize shard dim to be positive 534 placements[i] = Shard(placement.dim + self.ndim) 535 placements = tuple(placements) 536 537 # pyre-fixme[16]: `Redistribute` has no attribute `apply`. 538 return Redistribute.apply(self, device_mesh, placements, async_op) 539 540 def full_tensor( 541 self, *, grad_placements: Optional[Sequence[Placement]] = None 542 ) -> torch.Tensor: 543 """ 544 Return the full tensor of this DTensor. It will perform necessary collectives 545 to gather the local tensors from other ranks in its DeviceMesh and concatenate 546 them together. It's a syntatic sugar of the following code: 547 548 ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` 549 550 Keyword args: 551 grad_placements (List[:class:`Placement`], optional): the placements describes 552 the future layout of any gradient layout of the full Tensor returned from this 553 function. 554 `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor 555 might not be used as the original replicated DTensor layout later in the code. This 556 argument is the hint that user can give to autograd in case the gradient 557 layout of the returned tensor does not match the original replicated DTensor layout. 558 If not specified, we will assume the gradient layout of the full tensor be replicated. 559 560 Returns: 561 A :class:`torch.Tensor` object that represents the full tensor of this DTensor. 562 563 .. note:: ``full_tensor`` is differentiable. 564 """ 565 566 redist_res = self.redistribute( 567 placements=[Replicate()] * self.device_mesh.ndim, async_op=False 568 ) 569 return _ToTorchTensor.apply(redist_res, grad_placements) 570 571 @property 572 def device_mesh(self) -> DeviceMesh: 573 """ 574 The :class:`DeviceMesh` attribute that associates with this DTensor object. 575 576 .. note:: ``device_mesh`` is a read-only property, it can not be set. 577 """ 578 return self._spec.mesh 579 580 @property 581 def placements(self) -> Tuple[Placement, ...]: 582 """ 583 The placements attribute of this DTensor that describes the layout of this 584 DTensor on the its DeviceMesh. 585 586 .. note:: ``placements`` is a read-only property, it can not be set. 587 """ 588 return self._spec.placements 589 590 def __create_write_items__(self, fqn: str, object: Any): 591 from torch.distributed.checkpoint.planner_helpers import ( 592 _create_write_items_for_dtensor, 593 ) 594 595 if hasattr(self._local_tensor, "__create_write_items__"): 596 return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] 597 elif isinstance(self._local_tensor, torch.Tensor): 598 return [_create_write_items_for_dtensor(fqn, object)] 599 else: 600 raise RuntimeError("Unsupported tensor type!") 601 602 def __create_chunk_list__(self): 603 from torch.distributed.checkpoint.planner_helpers import ( 604 _create_chunk_from_dtensor, 605 ) 606 607 if hasattr(self._local_tensor, "__create_chunk_list__"): 608 return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] 609 elif isinstance(self._local_tensor, torch.Tensor): 610 return [_create_chunk_from_dtensor(self)] 611 else: 612 raise RuntimeError("Unsupported tensor type!") 613 614 def __get_tensor_shard__(self, index): 615 if hasattr(self._local_tensor, "__get_tensor_shard__"): 616 return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] 617 elif isinstance(self._local_tensor, torch.Tensor): 618 return self.to_local() 619 else: 620 raise RuntimeError("Unsupported tensor type!") 621 622 623def distribute_tensor( 624 tensor: torch.Tensor, 625 device_mesh: Optional[DeviceMesh] = None, 626 placements: Optional[Sequence[Placement]] = None, 627) -> DTensor: 628 """ 629 Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according 630 to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the 631 same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use 632 the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to perserve 633 the single-device semantic. If you want to construct a DTensor in the middle of the Autograd 634 computation, please use :meth:`DTensor.from_local` instead. 635 636 Args: 637 tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you 638 want to shard a tensor on a dimension that is not evenly divisible by 639 the number of devices in that mesh dimension, we use ``torch.chunk`` 640 semantic to shard the tensor and scatter the shards. The uneven sharding 641 behavior is experimental and subject to change. 642 device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the 643 tensor, if not specified, must be called under a DeviceMesh context 644 manager, default: None 645 placements (List[:class:`Placement`], optional): the placements that 646 describes how to place the tensor on DeviceMesh, must have the same 647 number of elements as ``device_mesh.ndim``. If not specified, we will 648 by default replicate the tensor across the ``device_mesh`` from the 649 first rank of each dimension of the `device_mesh`. 650 651 Returns: 652 A :class:`DTensor` or ``XLAShardedTensor`` object. 653 654 .. note:: 655 When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` 656 return `XLAShardedTensor` instead. see `this issue <https://github.com/pytorch/pytorch/issues/92909>`__ 657 for more details. The XLA integration is experimental and subject to change. 658 """ 659 660 torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") 661 662 # get default device mesh if there's nothing specified 663 device_mesh = device_mesh or _mesh_resources.get_current_mesh() 664 device_type = device_mesh.device_type 665 if device_type == "xla": 666 try: 667 # call PyTorch/XLA SPMD for `xla` backend type device mesh. 668 # This returns XLAShardedTensor 669 from torch_xla.distributed.spmd import ( # type:ignore[import] 670 xla_distribute_tensor, 671 ) 672 673 return xla_distribute_tensor( 674 tensor, device_mesh, placements 675 ) # type:ignore[return-value] 676 except ImportError as e: 677 msg = "To use DTensor API with xla, you must install the torch_xla package!" 678 raise ImportError(msg) from e 679 680 # instantiate a RNG tracker if haven't. By default DTensor uses an 681 # OffsetBasedRNGTracker to perform random operators. 682 # TODO: the value assignment to global variable is not the ideal solution 683 # we can replace it in future. 684 if not random._rng_tracker and is_rng_supported_mesh(device_mesh): 685 random._rng_tracker = OffsetBasedRNGTracker(device_type) 686 687 if not tensor.is_leaf: 688 raise RuntimeError( 689 "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" 690 ) 691 692 # convert tensor to the corresponding device type if it's not in that device type 693 if device_type != tensor.device.type and not tensor.is_meta: 694 tensor = tensor.to(device_type) 695 696 # set default placements to replicated if not specified 697 if placements is None: 698 placements = [Replicate() for _ in range(device_mesh.ndim)] 699 700 if len(placements) != device_mesh.ndim: 701 raise ValueError( 702 f"`placements` must have the same length as `device_mesh.ndim`! " 703 f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." 704 ) 705 if isinstance(tensor, DTensor): 706 # if the tensor is already a DTensor, we need to check: 707 # 1. if the we can further shard this DTensor if the two device mesh belong to 708 # the same parenet mesh and further sharding is possible. 709 # 2. check if device mesh and placements are the same 710 if tensor.device_mesh != device_mesh: 711 raise ValueError( 712 f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " 713 f"to a different device mesh {device_mesh}." 714 ) 715 if tensor.placements != tuple(placements): 716 raise ValueError( 717 f"Cannot distribute a DTensor with placements {tensor.placements} " 718 f"to a different placements {placements}. do you want to call " 719 f"`redistribute` instead?" 720 ) 721 return tensor 722 723 local_tensor = tensor.detach() 724 725 # TODO(xilun): address sharding order 726 # distribute the tensor according to the placements. 727 placements = list(placements) 728 for idx, placement in enumerate(placements): 729 if placement.is_shard(): 730 placement = cast(Shard, placement) 731 if placement.dim < 0: 732 # normalize shard placement dim 733 placement = Shard(placement.dim + tensor.ndim) 734 placements[idx] = placement 735 local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx) 736 elif placement.is_replicate(): 737 placement = cast(Replicate, placement) 738 local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx) 739 else: 740 raise RuntimeError( 741 f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" 742 ) 743 placements = tuple(placements) 744 745 assert local_tensor is not None, "distributing a tensor should not be None" 746 # detach the local tensor passed to DTensor since after the construction 747 # of DTensor, autograd would work on top of DTensor instead of local tensor 748 spec = DTensorSpec( 749 mesh=device_mesh, 750 placements=placements, 751 tensor_meta=TensorMeta( 752 shape=tensor.size(), 753 stride=tensor.stride(), 754 dtype=tensor.dtype, 755 ), 756 ) 757 return DTensor( 758 local_tensor.requires_grad_(tensor.requires_grad), 759 spec, 760 requires_grad=tensor.requires_grad, 761 ) 762 763 764def distribute_module( 765 module: nn.Module, 766 device_mesh: Optional[DeviceMesh] = None, 767 partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, 768 input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, 769 output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, 770) -> nn.Module: 771 """ 772 This function expose three functions to control the parameters/inputs/outputs of the module: 773 774 1. To perform sharding on the module before runtime execution by specifying the 775 ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` 776 parameters according to the `partition_fn` specified). 777 2. To control the inputs or outputs of the module during runtime execution by 778 specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to 779 :class:`DTensor`, convert the output back to ``torch.Tensor``) 780 781 Args: 782 module (:class:`nn.Module`): user module to be partitioned. 783 device_mesh (:class:`DeviceMesh`): the device mesh to place the module. 784 partition_fn (Callable): the function to partition parameters (i.e. shard certain 785 parameters across the ``device_mesh``). If ``partition_fn`` is not specified, 786 by default we replicate all module parameters of ``module`` across the mesh. 787 input_fn (Callable): specify the input distribution, i.e. could control how the 788 input of the module is sharded. ``input_fn`` will be installed as a module 789 ``forward_pre_hook`` (pre forward hook). 790 output_fn (Callable): specify the output distribution, i.e. could control how the 791 output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be 792 installed as a module ``forward_hook`` (post forward hook). 793 794 Returns: 795 A module that contains parameters/buffers that are all ``DTensor`` s. 796 797 .. note:: 798 When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` 799 return nn.Module with PyTorch/XLA SPMD annotated parameters. See 800 `this issue <https://github.com/pytorch/pytorch/issues/92909>`__ 801 for more details. The XLA integration is experimental and subject to change. 802 803 """ 804 805 torch._C._log_api_usage_once("torch.dtensor.distribute_module") 806 807 device_mesh = device_mesh or _mesh_resources.get_current_mesh() 808 device_type = device_mesh.device_type 809 if device_type == "xla": 810 try: 811 # This function annotates all module parameters for auto-partitioning with 812 # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters 813 # according to the `partition_fn` specified. 814 from torch_xla.distributed.spmd import ( # type:ignore[import] 815 xla_distribute_module, 816 ) 817 818 return xla_distribute_module( 819 module, device_mesh, partition_fn, input_fn, output_fn 820 ) # type:ignore[return-value] 821 except ImportError as e: 822 msg = "To use DTensor API with xla, you must install the torch_xla package!" 823 raise ImportError(msg) from e 824 825 def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: 826 # This function loop over the immediate module parameters and 827 # buffers, replicate all non DTensor params/buffers to DTensor 828 # parameters/buffers, if they have not been partitioned in the 829 # partition_fn, we can't easily use `module._apply` here 830 # because we don't know what happened inside partition_fn as 831 # user could do anything, i.e. install hooks, and we want to 832 # preserve those. 833 full_replicate = [Replicate()] * mesh.ndim 834 for key, param in m._parameters.items(): 835 if param is not None and not isinstance(param, DTensor): 836 m.register_parameter( 837 key, 838 nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), 839 ) 840 for key, buffer in m._buffers.items(): 841 if buffer is not None and not isinstance(buffer, DTensor): 842 m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) 843 844 if partition_fn is None: 845 # if partition_fn not specified, we by default replicate 846 # all module params/buffers 847 for name, submod in module.named_modules(): 848 replicate_module_params_buffers(submod, device_mesh) 849 else: 850 # apply partition_fun to submodules 851 for name, submod in module.named_modules(): 852 partition_fn(name, submod, device_mesh) 853 replicate_module_params_buffers(submod, device_mesh) 854 855 # register input_fn as module forward pre hook 856 if input_fn is not None: 857 # check the input_fn signature 858 num_args = len(inspect.signature(input_fn).parameters) 859 if num_args == 2: 860 # input_fn only takes in inputs and device mesh 861 warnings.warn( 862 "Deprecating input_fn that takes two arguments (inputs, device_mesh), " 863 "please use input_fn that takes in (module, inputs, device_mesh) instead!", 864 FutureWarning, 865 stacklevel=2, 866 ) 867 module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg] 868 elif num_args == 3: 869 # input_fn takes in module, inputs, device mesh 870 module.register_forward_pre_hook( 871 lambda mod, inputs: input_fn(mod, inputs, device_mesh) 872 ) 873 else: 874 raise ValueError( 875 f"input_fn should take in 3 arguments, but got {num_args} arguments!" 876 ) 877 # register output_fn as module forward hook 878 if output_fn is not None: 879 num_args = len(inspect.signature(output_fn).parameters) 880 if num_args == 2: 881 # output_fn only takes in outputs and device mesh 882 warnings.warn( 883 "Deprecating output_fn that takes two arguments (inputs, device_mesh), " 884 "please use output_fn that takes in (module, inputs, device_mesh) instead!", 885 FutureWarning, 886 stacklevel=2, 887 ) 888 module.register_forward_hook( 889 lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] 890 ) 891 elif num_args == 3: 892 module.register_forward_hook( 893 lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) 894 ) 895 else: 896 raise ValueError( 897 f"output_fn should take in 3 arguments, but got {num_args} arguments!" 898 ) 899 900 return module 901 902 903# Below are tensor factory function APIs, which are used to create a DTensor directly. We need 904# to make separate factory function APIs because tensor subclass could not override the tensor 905# factory methods, and we need user to call the factory functions with user intended device_mesh 906# and placements to create a proper DTensor. 907 908 909def _dtensor_init_helper( # type: ignore[no-untyped-def] 910 init_op, 911 size: torch.Size, 912 device_mesh: Optional[DeviceMesh] = None, 913 placements: Optional[Sequence[Placement]] = None, 914 **kwargs, 915) -> DTensor: 916 # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta 917 918 # if device_mesh is None, use the one from mesh resources 919 device_mesh = device_mesh or _mesh_resources.get_current_mesh() 920 kwargs["device"] = device_mesh.device_type 921 922 # set default placements to replicated if not specified 923 placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) 924 925 # check device_mesh againts placements 926 assert device_mesh.ndim == len( 927 placements 928 ), "mesh dimension does not match the length of placements" 929 930 assert kwargs["layout"] == torch.strided, "layout value not supported!" 931 torch_stride = torch._prims_common.make_contiguous_strides_for(size) 932 933 # get local tensor shape 934 local_shape = compute_local_shape(size, device_mesh, placements) 935 # initialize the local tensor 936 if init_op == torch.full: 937 fill_value = kwargs.pop("fill_value", 0) 938 local_tensor = init_op(local_shape, fill_value, **kwargs) 939 elif init_op == torch.rand or init_op == torch.randn: 940 # this tensor meta is not used except `shape` 941 dtype = kwargs.get("dtype", torch.get_default_dtype()) 942 943 tensor_meta = TensorMeta(size, (0,), dtype) 944 spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) 945 946 if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: 947 random._rng_tracker = random.OffsetBasedRNGTracker() 948 949 assert random._rng_tracker is not None 950 with random._rng_tracker._distribute_region(spec): 951 local_tensor = init_op(local_shape, **kwargs) 952 else: 953 local_tensor = init_op(local_shape, **kwargs) 954 955 spec = DTensorSpec( 956 device_mesh, 957 tuple(placements), 958 tensor_meta=TensorMeta( 959 size, 960 torch_stride, 961 local_tensor.dtype, 962 ), 963 ) 964 965 return DTensor( 966 local_tensor, 967 spec, 968 requires_grad=kwargs["requires_grad"], 969 ) 970 971 972def ones( # type: ignore[no-untyped-def] 973 *size, 974 dtype: Optional[torch.dtype] = None, 975 layout: torch.layout = torch.strided, 976 requires_grad: bool = False, 977 device_mesh: Optional[DeviceMesh] = None, 978 placements: Optional[Sequence[Placement]] = None, 979) -> DTensor: 980 """ 981 Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined 982 by the variable argument ``size``. 983 984 Args: 985 size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. 986 Can be a variable number of arguments or a collection like a list or tuple. 987 E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) 988 989 Keyword args: 990 dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 991 Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 992 layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. 993 Default: ``torch.strided``. 994 requires_grad (bool, optional): If autograd should record operations on the 995 returned :class:`DTensor`. Default: ``False``. 996 device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks 997 placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` 998 999 Returns: 1000 A :class:`DTensor` object on each rank 1001 """ 1002 torch_size = normalize_to_torch_size(size) 1003 1004 return _dtensor_init_helper( 1005 torch.ones, 1006 torch_size, 1007 dtype=dtype, 1008 layout=layout, 1009 requires_grad=requires_grad, 1010 device_mesh=device_mesh, 1011 placements=placements, 1012 ) 1013 1014 1015def empty( # type: ignore[no-untyped-def] 1016 *size, 1017 dtype: Optional[torch.dtype] = None, 1018 layout: torch.layout = torch.strided, 1019 requires_grad: bool = False, 1020 device_mesh: Optional[DeviceMesh] = None, 1021 placements: Optional[Sequence[Placement]] = None, 1022) -> DTensor: 1023 """ 1024 Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` 1025 is defined by the variable argument ``size``. 1026 1027 Args: 1028 size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. 1029 Can be a variable number of arguments or a collection like a list or tuple. 1030 E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) 1031 1032 Keyword args: 1033 dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 1034 Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ 1035 layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. 1036 Default: ``torch.strided``. 1037 requires_grad (bool, optional): If autograd should record operations on the 1038 returned :class:`DTensor`. Default: ``False``. 1039 device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks 1040 placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` 1041 1042 Returns: 1043 A :class:`DTensor` object on each rank 1044 """ 1045 torch_size = normalize_to_torch_size(size) 1046 1047 return _dtensor_init_helper( 1048 torch.empty, 1049 torch_size, 1050 dtype=dtype, 1051 layout=layout, 1052 requires_grad=requires_grad, 1053 device_mesh=device_mesh, 1054 placements=placements, 1055 ) 1056 1057 1058def full( # type: ignore[no-untyped-def] 1059 size, 1060 fill_value, 1061 *, 1062 dtype: Optional[torch.dtype] = None, 1063 layout: torch.layout = torch.strided, 1064 requires_grad: bool = False, 1065 device_mesh: Optional[DeviceMesh] = None, 1066 placements: Optional[Sequence[Placement]] = None, 1067) -> DTensor: 1068 """ 1069 Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and 1070 ``placements``, with the shape defined by the argument ``size``. 1071 1072 Args: 1073 size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. 1074 Can be a variable number of arguments or a collection like a list or tuple. 1075 E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) 1076 fill_value(Scalar): the value to fill the output tensor with. 1077 1078 Keyword args: 1079 dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 1080 Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 1081 layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. 1082 Default: ``torch.strided``. 1083 requires_grad (bool, optional): If autograd should record operations on the 1084 returned :class:`DTensor`. Default: ``False``. 1085 device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. 1086 placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` 1087 1088 Returns: 1089 A :class:`DTensor` object on each rank 1090 """ 1091 torch_size = normalize_to_torch_size(size) 1092 1093 return _dtensor_init_helper( 1094 torch.full, 1095 torch_size, 1096 fill_value=fill_value, 1097 dtype=dtype, 1098 layout=layout, 1099 requires_grad=requires_grad, 1100 device_mesh=device_mesh, 1101 placements=placements, 1102 ) 1103 1104 1105def rand( # type: ignore[no-untyped-def] 1106 *size, 1107 requires_grad: bool = False, 1108 dtype: Optional[torch.dtype] = None, 1109 layout: torch.layout = torch.strided, 1110 device_mesh: Optional[DeviceMesh] = None, 1111 placements: Optional[Sequence[Placement]] = None, 1112) -> DTensor: 1113 """ 1114 Returns a :class:`DTensor` filled with random numbers from a uniform distribution 1115 on the interval ``[0, 1)``. The shape of the tensor is defined by the variable 1116 argument ``size``. 1117 1118 Args: 1119 size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. 1120 Can be a variable number of arguments or a collection like a list or tuple. 1121 E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) 1122 1123 Keyword args: 1124 dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 1125 Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 1126 layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. 1127 Default: ``torch.strided``. 1128 requires_grad (bool, optional): If autograd should record operations on the 1129 returned :class:`DTensor`. Default: ``False``. 1130 device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. 1131 placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` 1132 1133 Returns: 1134 A :class:`DTensor` object on each rank 1135 """ 1136 torch_size = normalize_to_torch_size(size) 1137 1138 return _dtensor_init_helper( 1139 torch.rand, 1140 torch_size, 1141 dtype=dtype, 1142 layout=layout, 1143 requires_grad=requires_grad, 1144 device_mesh=device_mesh, 1145 placements=placements, 1146 ) 1147 1148 1149def randn( # type: ignore[no-untyped-def] 1150 *size, 1151 requires_grad: bool = False, 1152 dtype: Optional[torch.dtype] = None, 1153 layout: torch.layout = torch.strided, 1154 device_mesh: Optional[DeviceMesh] = None, 1155 placements: Optional[Sequence[Placement]] = None, 1156) -> DTensor: 1157 """ 1158 Returns a :class:`DTensor` filled with random numbers from a normal distribution 1159 with mean 0 and variance 1. The shape of the tensor is defined by the variable 1160 argument ``size``. 1161 1162 Args: 1163 size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. 1164 Can be a variable number of arguments or a collection like a list or tuple. 1165 E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) 1166 1167 Keyword args: 1168 dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 1169 Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 1170 layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. 1171 Default: ``torch.strided``. 1172 requires_grad (bool, optional): If autograd should record operations on the 1173 returned :class:`DTensor`. Default: ``False``. 1174 device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. 1175 placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` 1176 1177 Returns: 1178 A :class:`DTensor` object on each rank 1179 """ 1180 torch_size = normalize_to_torch_size(size) 1181 1182 return _dtensor_init_helper( 1183 torch.randn, 1184 torch_size, 1185 dtype=dtype, 1186 layout=layout, 1187 requires_grad=requires_grad, 1188 device_mesh=device_mesh, 1189 placements=placements, 1190 ) 1191 1192 1193def zeros( # type: ignore[no-untyped-def] 1194 *size, 1195 requires_grad: bool = False, 1196 dtype: Optional[torch.dtype] = None, 1197 layout: torch.layout = torch.strided, 1198 device_mesh: Optional[DeviceMesh] = None, 1199 placements: Optional[Sequence[Placement]] = None, 1200) -> DTensor: 1201 """ 1202 Returns a :class:`DTensor` filled with the scalar value 0. 1203 1204 Args: 1205 size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. 1206 Can be a variable number of arguments or a collection like a list or tuple. 1207 E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) 1208 Keyword args: 1209 requires_grad (bool, optional): If autograd should record operations on the 1210 returned :class:`DTensor`. Default: ``False``. 1211 dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. 1212 Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). 1213 layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. 1214 Default: ``torch.strided``. 1215 device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks 1216 placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` 1217 1218 Returns: 1219 A :class:`DTensor` object on each rank 1220 """ 1221 torch_size = normalize_to_torch_size(size) 1222 1223 return _dtensor_init_helper( 1224 torch.zeros, 1225 torch_size, 1226 dtype=dtype, 1227 layout=layout, 1228 requires_grad=requires_grad, 1229 device_mesh=device_mesh, 1230 placements=placements, 1231 ) 1232