# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import inspect
import warnings
from typing import Any, Callable, cast, Optional, Sequence, Tuple

import torch
import torch.distributed.tensor._dispatch as op_dispatch
import torch.distributed.tensor._random as random
import torch.nn as nn
from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._random import (
    is_rng_supported_mesh,
    OffsetBasedRNGTracker,
)
from torch.distributed.tensor._redistribute import (
    Redistribute,
    redistribute_local_tensor,
)
from torch.distributed.tensor._utils import (
    compute_global_tensor_info,
    compute_local_shape,
    normalize_to_torch_size,
)
from torch.distributed.tensor.placement_types import (
    Partial,
    Placement,
    Replicate,
    Shard,
)


__all__ = [
    "DTensor",
    "distribute_tensor",
    "distribute_module",
    "ones",
    "empty",
    "full",
    "rand",
    "randn",
    "zeros",
]

aten = torch.ops.aten


# NOTE [Autograd interaction between DTensor and torch.Tensor]
#
# The autograd functions defined below are used by the public facing
# APIs (i.e. from_local, to_local) to ensure that DTensor works
# together with torch.Tensor within the autograd engine. This
# allows DTensor to exist on only part of the module hierarchy.
#
# As an example, suppose we have a module that consists of submodules
# A, B, and C. The execution flow would be:
#  input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
#
# Suppose we only want to make Module B a sharded module with
# DTensor params, then the following forward/backward should work:
#
#  input(torch.Tensor) -> Module A
#       -> DTensor input (from_local) -> Sharded Module B -> DTensor output
#           -> torch.Tensor output (to_local) -> Module C
#
# So from_local/to_local must be Autograd functions.
#
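# A minimal sketch of this pattern (illustrative only; module and mesh names
# are hypothetical and assume a process group / DeviceMesh is already set up):
#
#   def forward(x: torch.Tensor) -> torch.Tensor:
#       x = module_a(x)                                   # plain torch.Tensor
#       dt = DTensor.from_local(x, mesh, [Replicate()])   # enter DTensor region
#       dt = sharded_module_b(dt)                         # DTensor compute
#       x = dt.to_local()                                 # leave DTensor region
#       return module_c(x)                                # plain torch.Tensor
#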
class _ToTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        input: "DTensor",
        grad_placements: Optional[Sequence[Placement]],
    ):
        ctx.dtensor_spec = input._spec
        ctx.grad_placements = grad_placements
        local_tensor = input._local_tensor

        # We need to return a fresh Tensor object here, as autograd metadata
        # will be written into it in place. We don't want to pollute the Tensor
        # object stored in the _local_tensor of this DTensor.
        return local_tensor.view_as(local_tensor)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        dtensor_spec = ctx.dtensor_spec
        mesh = dtensor_spec.mesh
        grad_placements = ctx.grad_placements
        dtensor_meta = dtensor_spec.tensor_meta

        _, tensor_stride = compute_global_tensor_info(
            grad_output, mesh, dtensor_spec.placements
        )
        tensor_stride = tuple(tensor_stride)
        grad_placements = grad_placements or dtensor_spec.placements
        grad_spec = DTensorSpec(
            mesh,
            grad_placements,
            tensor_meta=TensorMeta(
                shape=dtensor_meta.shape,
                stride=tensor_stride,
                dtype=dtensor_meta.dtype,
            ),
        )

        return (
            DTensor(
                grad_output,
                grad_spec,
                requires_grad=grad_output.requires_grad,
            ),
            None,
        )


class _FromTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Tuple[Placement, ...],
        run_check: bool,
        shape: Optional[torch.Size] = None,
        stride: Optional[Tuple[int, ...]] = None,
    ) -> "DTensor":
        ctx.previous_placement = placements
        ctx.previous_device_mesh = device_mesh

        if shape and stride:
            tensor_shape, tensor_stride = shape, stride
        elif not shape and not stride:
            # if shape/stride are not passed, we assume the user is certain that
            # each rank has the same local tensor shape, and we just use that to
            # calculate the global shape
            global_shape, global_stride = compute_global_tensor_info(
                input, device_mesh, placements
            )
            tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride)
        else:
            raise RuntimeError(
                f"Found shape:{shape}, stride:{stride}.",
                "Please pass both shape and stride at the same time.",
            )

        if device_mesh.get_coordinate() is None:
            # if the global rank is not participating in the device mesh, we
            # simply set the local tensor to an empty tensor
            input = input.new_empty(0, requires_grad=input.requires_grad)
        elif run_check:
            # TODO: support uneven sharding when global shape/stride not passed, by
            # building the global TensorMeta during check_tensor_meta
            check_shape_stride = not shape and not stride
            check_tensor_meta(input, check_shape_stride=check_shape_stride)
            # TODO: See if we need to make this run_check logic
            # have a corresponding backward.
            for idx, placement in enumerate(placements):
                if placement.is_replicate():
                    # broadcast rank 0 tensor to all ranks
                    # only broadcast if run_check is True
                    input = input.contiguous()
                    mesh_broadcast(input, device_mesh, mesh_dim=idx)

        dist_spec = DTensorSpec(
            device_mesh,
            placements,
            tensor_meta=TensorMeta(
                tensor_shape,
                tensor_stride,
                input.dtype,
            ),
        )

        # We want a fresh Tensor object that shares memory with the input tensor
        dist_tensor = DTensor(
            input.view_as(input),
            dist_spec,
            # requires_grad of the dist tensor depends on if input
            # requires_grad or not
            requires_grad=input.requires_grad,
        )
        return dist_tensor

    @staticmethod
    def backward(ctx, grad_output: "DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # reshard to the placements used when creating the DTensor, so that
        # the gradient layout matches and we can return the local gradient
        # directly
        if grad_output.placements != previous_placement:
            current_spec = grad_output._spec
            target_spec = DTensorSpec(
                previous_device_mesh,
                previous_placement,
                tensor_meta=grad_output._spec.tensor_meta,
            )
            local_tensor = grad_output._local_tensor
            output = redistribute_local_tensor(
                local_tensor, current_spec, target_spec, is_backward=True
            )
            # TODO: return the redistributed local tensor directly without
            # differentiable backward. see if this makes sense for all cases.
            return output, None, None, None, None, None

        # TODO: backward is also differentiable now, add a test
        # to test higher level gradients.
        return grad_output.to_local(), None, None, None, None, None


class DTensor(torch.Tensor):
    """
    ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides a single-device like
    abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding
    layout (DTensor Layout) through the :class:`DeviceMesh` and the following types of :class:`Placement`:

    * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension
    * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension
    * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension

    When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue
    communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the
    placements (DTensor Layout) properly (based on the operator semantics) and generate new ``DTensor`` outputs.

    To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor``
    requires every Tensor argument of the operator to be a DTensor.

    """

    _local_tensor: torch.Tensor
    _spec: DTensorSpec
    __slots__ = ["_local_tensor", "_spec"]

    # _op_dispatcher instance as a class attribute to handle runtime dispatching logic
    _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()
    @staticmethod
    @torch._disable_dynamo
    def __new__(
        cls,
        local_tensor: torch.Tensor,
        spec: DTensorSpec,
        *,
        requires_grad: bool,
    ) -> "DTensor":
        """
        Construct a DTensor from a local tensor, device mesh and placements, and
        other tensor properties (i.e. shape, requires_grad, strides, etc).

        .. note:: This is not a public API and it's only supposed to be used by the
            operator implementations and internals. If you want to construct a
            DTensor from a local tensor, consider using ``DTensor.from_local``; if
            you want to construct a DTensor from a "global" tensor (where you
            already have the tensor initialized and want to shard it),
            consider using ``distribute_tensor``.
        """
        if local_tensor.requires_grad and not requires_grad:
            warnings.warn(
                "To construct DTensor from torch.Tensor, it's recommended to "
                "use local_tensor.detach() and make requires_grad consistent."
            )

        # __new__ constructs the wrapper tensor from local_tensor and attaches
        # the placement spec; it does not perform any actual distribution
        assert spec.tensor_meta is not None, "TensorMeta should not be None!"
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            spec.tensor_meta.shape,
            strides=spec.tensor_meta.stride,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            layout=local_tensor.layout,
            requires_grad=requires_grad,
        )

        r._spec = spec
        r._local_tensor = local_tensor
        return r

    # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"

    def __tensor_flatten__(self):
        """
        protocol to inform how to flatten a DTensor to local tensor
        for PT2 tracing
        """
        return ["_local_tensor"], (self._spec, self.requires_grad)

    @staticmethod
    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
        assert (
            flatten_spec is not None
        ), "Expecting spec to be not None from `__tensor_flatten__` return value!"
        local_tensor = inner_tensors["_local_tensor"]
        spec, requires_grad = flatten_spec
        unflatten_tensor_meta = TensorMeta(
            shape=outer_size,
            stride=outer_stride,
            dtype=spec.tensor_meta.dtype,
        )
        unflatten_spec = DTensorSpec(
            spec.mesh,
            spec.placements,
            tensor_meta=unflatten_tensor_meta,
        )
        return DTensor(
            local_tensor,
            unflatten_spec,
            requires_grad=requires_grad,
        )

    def __coerce_tangent_metadata__(self):
        if not any(isinstance(p, Partial) for p in self.placements):
            return self
        placements = [
            Replicate() if isinstance(p, Partial) else p for p in self.placements
        ]
        return self.redistribute(device_mesh=self.device_mesh, placements=placements)

    def __coerce_same_metadata_as_tangent__(self, flatten_spec):
        (spec, _) = flatten_spec  # Result of tensor_flatten()
        return self.redistribute(
            device_mesh=self.device_mesh,
            placements=spec.placements,
        )

    @classmethod
    @torch._disable_dynamo
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        return DTensor._op_dispatcher.dispatch(
            func,
            args,
            kwargs or {},
        )

    @staticmethod
    def from_local(
        local_tensor: torch.Tensor,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        *,
        run_check: bool = False,
        shape: Optional[torch.Size] = None,
        stride: Optional[Tuple[int, ...]] = None,
    ) -> "DTensor":
        """
        Create a :class:`DTensor` from a local torch.Tensor on each rank
        according to the ``device_mesh`` and ``placements`` specified.

        Args:
            local_tensor (torch.Tensor): local torch.Tensor on each rank.
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                tensor, if not specified, must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the placements that
                describe how to place the local torch.Tensor on DeviceMesh, must
                have the same number of elements as ``device_mesh.ndim``.

        Keyword args:
            run_check (bool, optional): at the cost of extra communications, perform
                a sanity check across ranks on each local tensor's meta information
                to ensure correctness. If there is a :class:`Replicate` in ``placements``,
                the data on the first rank of the device mesh dimension will be
                broadcast to the other ranks. default: False
            shape (torch.Size, optional): a list of ints which specifies the size of the
                DTensor built on top of ``local_tensor``. Note this needs to be
                provided if the shape of ``local_tensor`` is different across ranks.
                If not provided, ``shape`` will be computed assuming the given distributed
                tensor is evenly sharded across ranks. default: None
            stride (tuple, optional): a tuple of ints which specifies the stride of the DTensor.
                If not provided, ``stride`` will be computed assuming the given distributed
                tensor is evenly sharded across ranks. default: None

        Returns:
            A :class:`DTensor` object

        .. note:: When ``run_check=False``, it is the user's responsibility to ensure the
            local tensor passed in is correct across ranks (i.e. the tensor is sharded for
            the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement).
            If not, the behavior of the created DTensor is undefined.

        .. note:: ``from_local`` is differentiable, the ``requires_grad`` of the created
            ``DTensor`` object will depend on whether ``local_tensor`` requires_grad or not.
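
        Example (illustrative sketch; assumes a 1-D mesh over 4 ranks and that the
        default process group is already initialized, e.g. via ``torchrun``)::

            >>> import torch
            >>> from torch.distributed.device_mesh import init_device_mesh
            >>> from torch.distributed.tensor import DTensor, Shard
            >>> mesh = init_device_mesh("cuda", (4,))
            >>> local_part = torch.randn(2, 8)  # each rank holds a [2, 8] shard
            >>> dt = DTensor.from_local(local_part, mesh, [Shard(0)])
            >>> dt.shape  # global shape is inferred as [8, 8] under even sharding
            torch.Size([8, 8])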
394        """
395        # if same shape/dtype, no need to run_check, if not, must allgather
396        # the metadatas to check the size/dtype across ranks
397        # There should be no data communication unless there's replication
398        # strategy, where we broadcast the replication from the first rank
399        # in the mesh dimension
400        device_mesh = device_mesh or _mesh_resources.get_current_mesh()
401        device_type = device_mesh.device_type
402
403        # convert the local tensor to desired device base on device mesh's device_type
404        if device_type != local_tensor.device.type and not local_tensor.is_meta:
405            local_tensor = local_tensor.to(device_type)
406
407        # set default placements to replicated if not specified
408        if placements is None:
409            placements = [Replicate() for _ in range(device_mesh.ndim)]
410        else:
411            placements = list(placements)
412            for idx, placement in enumerate(placements):
413                # normalize shard dim to be positive
414                if placement.is_shard():
415                    placement = cast(Shard, placement)
416                    if placement.dim < 0:
417                        placements[idx] = Shard(placement.dim + local_tensor.ndim)
418
419        # `from_local` is differentiable, and the gradient of the dist tensor this function
420        # created should flow back the gradients to the local_tensor, so we call an autograd
421        # function to construct the dist tensor instead.
422        return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
423            local_tensor,
424            device_mesh,
425            tuple(placements),
426            run_check,
427            shape,
428            stride,
429        )
430
    def to_local(
        self, *, grad_placements: Optional[Sequence[Placement]] = None
    ) -> torch.Tensor:
        """
        Get the local tensor of this DTensor on its current rank. For sharding it returns
        a local shard of the logical tensor view, for replication it returns the replica on
        its current rank.

        Keyword args:
            grad_placements (List[:class:`Placement`], optional): the placements that
                describe the future layout of any gradient of the Tensor returned from
                this function.
                ``to_local`` converts DTensor to a local tensor, and the returned local tensor
                might not be used with the original DTensor layout later in the code. This
                argument is the hint that the user can give to autograd in case the gradient
                layout of the returned tensor does not match the original DTensor layout.
                If not specified, we will assume the gradient layout remains the same
                as the original DTensor and use that for gradient computation.

        Returns:
            A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. It represents the
            local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned,
            it means the local tensor is not ready yet (i.e. communication is not finished). In this
            case, the user needs to call ``wait`` to wait for the local tensor to be ready.

        .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned
            will depend on whether the ``DTensor`` requires_grad or not.
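
        Example (illustrative sketch; continues the hypothetical 4-rank, 1-D mesh
        setup from :meth:`from_local`)::

            >>> local_shard = dt.to_local()   # plain torch.Tensor on this rank
            >>> local_shard.shape             # the [2, 8] shard, not the global [8, 8]
            torch.Size([2, 8])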
458        """
459        if not torch.is_grad_enabled():
460            return self._local_tensor
461
462        if grad_placements is not None and not isinstance(grad_placements, tuple):
463            grad_placements = tuple(grad_placements)
464        return _ToTorchTensor.apply(
465            self, grad_placements
466        )  # pyre-ignore[16]: autograd func
467
    def redistribute(
        self,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        *,
        async_op: bool = False,
    ) -> "DTensor":
        """
        ``redistribute`` performs the necessary collective operations that redistribute the current
        DTensor from its current placements to new placements, or from its current DeviceMesh
        to a new DeviceMesh. i.e. we can turn a sharded DTensor into a replicated DTensor by
        specifying a Replicate placement for each dimension of the DeviceMesh.

        When redistributing from the current to the new placements on one device mesh dimension, we
        will perform the following operations, including collective communication or local operations:

        1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather``
        2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all``
        3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``)
        4. ``Partial()`` -> ``Replicate()``: ``all_reduce``
        5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter``


        ``redistribute`` will correctly figure out the necessary redistribute steps for DTensors
        that are created on either a 1-D or an N-D DeviceMesh.

        Args:
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                DTensor. If not specified, it would use the current DTensor's DeviceMesh.
                default: None
            placements (List[:class:`Placement`], optional): the new placements that
                describe how to place the DTensor into the DeviceMesh, must
                have the same number of elements as ``device_mesh.ndim``.
                default: replicate on all mesh dimensions

        Keyword args:
            async_op (bool, optional): whether to perform the DTensor redistribute operation
                asynchronously or not. Default: False

        Returns:
            A :class:`DTensor` object

        .. note:: ``redistribute`` is differentiable, which means users do not need to worry about
            the backward formula of the redistribute operation.

        .. note:: ``redistribute`` currently only supports redistributing a DTensor on the same DeviceMesh.
            Please file an issue if you need to redistribute a DTensor to a different DeviceMesh.
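
        Example (illustrative sketch; assumes a hypothetical 1-D mesh and an existing
        sharded DTensor ``dt``)::

            >>> from torch.distributed.tensor import Replicate, Shard
            >>> replicated = dt.redistribute(placements=[Replicate()])  # all_gather under the hood
            >>> resharded = replicated.redistribute(placements=[Shard(1)])  # local chunking on dim 1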
515        """
516        # NOTE: This redistribute API currently only supports out
517        # of place redistribution, i.e. it always create a new
518        # DTensor object and leave the original one unchanged.
519
520        # if device_mesh is not specified, use the current device_mesh
521        device_mesh = device_mesh or self.device_mesh
522        # raise error if new placements not specified
523        if placements is None:
524            raise RuntimeError("placements is needed for redistribute!")
525
526        placements = list(placements)
527        for i, placement in enumerate(placements):
528            if placement.is_partial():
529                raise RuntimeError(
530                    "Can not redistribute to Partial, redistributing to Partial is for internal use only!"
531                )
532            elif isinstance(placement, Shard) and placement.dim < 0:
533                # normalize shard dim to be positive
534                placements[i] = Shard(placement.dim + self.ndim)
535        placements = tuple(placements)
536
537        # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
538        return Redistribute.apply(self, device_mesh, placements, async_op)
539
    def full_tensor(
        self, *, grad_placements: Optional[Sequence[Placement]] = None
    ) -> torch.Tensor:
        """
        Return the full tensor of this DTensor. It will perform necessary collectives
        to gather the local tensors from other ranks in its DeviceMesh and concatenate
        them together. It is syntactic sugar for the following code:

        ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()``

        Keyword args:
            grad_placements (List[:class:`Placement`], optional): the placements that
                describe the future layout of any gradient of the full Tensor returned from
                this function.
                ``full_tensor`` converts DTensor to a full torch.Tensor, and the returned torch.Tensor
                might not be used with the original replicated DTensor layout later in the code. This
                argument is the hint that the user can give to autograd in case the gradient
                layout of the returned tensor does not match the original replicated DTensor layout.
                If not specified, we will assume the gradient layout of the full tensor to be replicated.

        Returns:
            A :class:`torch.Tensor` object that represents the full tensor of this DTensor.

        .. note:: ``full_tensor`` is differentiable.
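
        Example (illustrative sketch; assumes a hypothetical sharded DTensor ``dt``
        with global shape ``[8, 8]``)::

            >>> global_tensor = dt.full_tensor()  # gathers shards from all ranks
            >>> global_tensor.shape
            torch.Size([8, 8])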
564        """
565
566        redist_res = self.redistribute(
567            placements=[Replicate()] * self.device_mesh.ndim, async_op=False
568        )
569        return _ToTorchTensor.apply(redist_res, grad_placements)
570
    @property
    def device_mesh(self) -> DeviceMesh:
        """
        The :class:`DeviceMesh` attribute that is associated with this DTensor object.

        .. note:: ``device_mesh`` is a read-only property, it can not be set.
        """
        return self._spec.mesh

    @property
    def placements(self) -> Tuple[Placement, ...]:
        """
        The placements attribute of this DTensor that describes the layout of this
        DTensor on its DeviceMesh.

        .. note:: ``placements`` is a read-only property, it can not be set.
        """
        return self._spec.placements

    def __create_write_items__(self, fqn: str, object: Any):
        from torch.distributed.checkpoint.planner_helpers import (
            _create_write_items_for_dtensor,
        )

        if hasattr(self._local_tensor, "__create_write_items__"):
            return self._local_tensor.__create_write_items__(fqn, object)  # type: ignore[attr-defined]
        elif isinstance(self._local_tensor, torch.Tensor):
            return [_create_write_items_for_dtensor(fqn, object)]
        else:
            raise RuntimeError("Unsupported tensor type!")

    def __create_chunk_list__(self):
        from torch.distributed.checkpoint.planner_helpers import (
            _create_chunk_from_dtensor,
        )

        if hasattr(self._local_tensor, "__create_chunk_list__"):
            return self._local_tensor.__create_chunk_list__()  # type: ignore[attr-defined]
        elif isinstance(self._local_tensor, torch.Tensor):
            return [_create_chunk_from_dtensor(self)]
        else:
            raise RuntimeError("Unsupported tensor type!")

    def __get_tensor_shard__(self, index):
        if hasattr(self._local_tensor, "__get_tensor_shard__"):
            return self._local_tensor.__get_tensor_shard__(index)  # type: ignore[attr-defined]
        elif isinstance(self._local_tensor, torch.Tensor):
            return self.to_local()
        else:
            raise RuntimeError("Unsupported tensor type!")


def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according
    to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the
    same (i.e. ``len(placements) == device_mesh.ndim``). The ``tensor`` to distribute is the logical
    or "global" tensor, and the API uses the ``tensor`` from the first rank of the DeviceMesh
    dimension as the source of truth to preserve the single-device semantic. If you want to
    construct a DTensor in the middle of the Autograd computation, please use
    :meth:`DTensor.from_local` instead.

    Args:
        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
            want to shard a tensor on a dimension that is not evenly divisible by
            the number of devices in that mesh dimension, we use ``torch.chunk``
            semantics to shard the tensor and scatter the shards. The uneven sharding
            behavior is experimental and subject to change.
        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
            tensor, if not specified, must be called under a DeviceMesh context
            manager, default: None
        placements (List[:class:`Placement`], optional): the placements that
            describe how to place the tensor on DeviceMesh, must have the same
            number of elements as ``device_mesh.ndim``. If not specified, we will
            by default replicate the tensor across the ``device_mesh`` from the
            first rank of each dimension of the ``device_mesh``.

    Returns:
        A :class:`DTensor` or ``XLAShardedTensor`` object.

    .. note::
        When initializing the DeviceMesh with the ``xla`` device_type, ``distribute_tensor``
        returns ``XLAShardedTensor`` instead. See `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
        for more details. The XLA integration is experimental and subject to change.
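
    Example (illustrative sketch; assumes a hypothetical 1-D mesh over 4 ranks with
    the default process group already initialized)::

        >>> import torch
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> from torch.distributed.tensor import distribute_tensor, Shard
        >>> mesh = init_device_mesh("cuda", (4,))
        >>> big_tensor = torch.randn(8, 8)   # same "global" tensor on every rank
        >>> dt = distribute_tensor(big_tensor, mesh, [Shard(0)])
        >>> dt.to_local().shape              # each rank keeps a [2, 8] shard
        torch.Size([2, 8])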
658    """
659
660    torch._C._log_api_usage_once("torch.dtensor.distribute_tensor")
661
662    # get default device mesh if there's nothing specified
663    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
664    device_type = device_mesh.device_type
665    if device_type == "xla":
666        try:
667            # call PyTorch/XLA SPMD for `xla` backend type device mesh.
668            # This returns XLAShardedTensor
669            from torch_xla.distributed.spmd import (  # type:ignore[import]
670                xla_distribute_tensor,
671            )
672
673            return xla_distribute_tensor(
674                tensor, device_mesh, placements
675            )  # type:ignore[return-value]
676        except ImportError as e:
677            msg = "To use DTensor API with xla, you must install the torch_xla package!"
678            raise ImportError(msg) from e
679
680    # instantiate a RNG tracker if haven't. By default DTensor uses an
681    # OffsetBasedRNGTracker to perform random operators.
682    # TODO: the value assignment to global variable is not the ideal solution
683    # we can replace it in future.
684    if not random._rng_tracker and is_rng_supported_mesh(device_mesh):
685        random._rng_tracker = OffsetBasedRNGTracker(device_type)
686
687    if not tensor.is_leaf:
688        raise RuntimeError(
689            "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!"
690        )
691
692    # convert tensor to the corresponding device type if it's not in that device type
693    if device_type != tensor.device.type and not tensor.is_meta:
694        tensor = tensor.to(device_type)
695
696    # set default placements to replicated if not specified
697    if placements is None:
698        placements = [Replicate() for _ in range(device_mesh.ndim)]
699
700    if len(placements) != device_mesh.ndim:
701        raise ValueError(
702            f"`placements` must have the same length as `device_mesh.ndim`! "
703            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
704        )
705    if isinstance(tensor, DTensor):
706        # if the tensor is already a DTensor, we need to check:
707        # 1. if the we can further shard this DTensor if the two device mesh belong to
708        #   the same parenet mesh and further sharding is possible.
709        # 2. check if device mesh and placements are the same
710        if tensor.device_mesh != device_mesh:
711            raise ValueError(
712                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
713                f"to a different device mesh {device_mesh}."
714            )
715        if tensor.placements != tuple(placements):
716            raise ValueError(
717                f"Cannot distribute a DTensor with placements {tensor.placements} "
718                f"to a different placements {placements}. do you want to call "
719                f"`redistribute` instead?"
720            )
721        return tensor
722
723    local_tensor = tensor.detach()
724
725    # TODO(xilun): address sharding order
726    # distribute the tensor according to the placements.
727    placements = list(placements)
728    for idx, placement in enumerate(placements):
729        if placement.is_shard():
730            placement = cast(Shard, placement)
731            if placement.dim < 0:
732                # normalize shard placement dim
733                placement = Shard(placement.dim + tensor.ndim)
734                placements[idx] = placement
735            local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
736        elif placement.is_replicate():
737            placement = cast(Replicate, placement)
738            local_tensor = placement._replicate_tensor(local_tensor, device_mesh, idx)
739        else:
740            raise RuntimeError(
741                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
742            )
743    placements = tuple(placements)
744
745    assert local_tensor is not None, "distributing a tensor should not be None"
746    # detach the local tensor passed to DTensor since after the construction
747    # of DTensor, autograd would work on top of DTensor instead of local tensor
748    spec = DTensorSpec(
749        mesh=device_mesh,
750        placements=placements,
751        tensor_meta=TensorMeta(
752            shape=tensor.size(),
753            stride=tensor.stride(),
754            dtype=tensor.dtype,
755        ),
756    )
757    return DTensor(
758        local_tensor.requires_grad_(tensor.requires_grad),
759        spec,
760        requires_grad=tensor.requires_grad,
761    )
762
763
def distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
    output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None,
) -> nn.Module:
    """
    This function exposes three functions to control the parameters/inputs/outputs of the module:

    1. To perform sharding on the module before runtime execution by specifying the
    ``partition_fn`` (i.e. allow the user to convert Module parameters to :class:`DTensor`
    parameters according to the ``partition_fn`` specified).
    2. To control the inputs or outputs of the module during runtime execution by
    specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
    :class:`DTensor`, convert the output back to ``torch.Tensor``)

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
        partition_fn (Callable): the function to partition parameters (i.e. shard certain
            parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
            by default we replicate all module parameters of ``module`` across the mesh.
        input_fn (Callable): specifies the input distribution, i.e. can control how the
            input of the module is sharded. ``input_fn`` will be installed as a module
            ``forward_pre_hook`` (pre forward hook).
        output_fn (Callable): specifies the output distribution, i.e. can control how the
            output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
            installed as a module ``forward_hook`` (post forward hook).

    Returns:
        A module that contains parameters/buffers that are all ``DTensor`` s.

    .. note::
        When initializing the DeviceMesh with the ``xla`` device_type, ``distribute_module``
        returns an nn.Module with PyTorch/XLA SPMD annotated parameters. See
        `this issue <https://github.com/pytorch/pytorch/issues/92909>`__
        for more details. The XLA integration is experimental and subject to change.

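    Example (illustrative sketch; ``MyModule`` and the column-wise partitioning of
    the submodule named ``"linear"`` are hypothetical, and a 1-D mesh is assumed)::

        >>> import torch
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> from torch.distributed.tensor import Shard, distribute_tensor, distribute_module
        >>> mesh = init_device_mesh("cuda", (4,))
        >>> def partition_fn(name, submod, mesh):
        ...     # shard only the hypothetical submodule named "linear" row-wise;
        ...     # all remaining parameters/buffers get replicated afterwards
        ...     if name == "linear":
        ...         submod.weight = torch.nn.Parameter(
        ...             distribute_tensor(submod.weight, mesh, [Shard(0)])
        ...         )
        >>> sharded_module = distribute_module(MyModule(), mesh, partition_fn)
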
803    """
804
805    torch._C._log_api_usage_once("torch.dtensor.distribute_module")
806
807    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
808    device_type = device_mesh.device_type
809    if device_type == "xla":
810        try:
811            # This function annotates all module parameters for auto-partitioning with
812            # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters
813            # according to the `partition_fn` specified.
814            from torch_xla.distributed.spmd import (  # type:ignore[import]
815                xla_distribute_module,
816            )
817
818            return xla_distribute_module(
819                module, device_mesh, partition_fn, input_fn, output_fn
820            )  # type:ignore[return-value]
821        except ImportError as e:
822            msg = "To use DTensor API with xla, you must install the torch_xla package!"
823            raise ImportError(msg) from e
824
825    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
826        # This function loop over the immediate module parameters and
827        # buffers, replicate all non DTensor params/buffers to DTensor
828        # parameters/buffers, if they have not been partitioned in the
829        # partition_fn, we can't easily use `module._apply` here
830        # because we don't know what happened inside partition_fn as
831        # user could do anything, i.e. install hooks, and we want to
832        # preserve those.
833        full_replicate = [Replicate()] * mesh.ndim
834        for key, param in m._parameters.items():
835            if param is not None and not isinstance(param, DTensor):
836                m.register_parameter(
837                    key,
838                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
839                )
840        for key, buffer in m._buffers.items():
841            if buffer is not None and not isinstance(buffer, DTensor):
842                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)
843
844    if partition_fn is None:
845        # if partition_fn not specified, we by default replicate
846        # all module params/buffers
847        for name, submod in module.named_modules():
848            replicate_module_params_buffers(submod, device_mesh)
849    else:
850        # apply partition_fun to submodules
851        for name, submod in module.named_modules():
852            partition_fn(name, submod, device_mesh)
853            replicate_module_params_buffers(submod, device_mesh)
854
855    # register input_fn as module forward pre hook
856    if input_fn is not None:
857        # check the input_fn signature
858        num_args = len(inspect.signature(input_fn).parameters)
859        if num_args == 2:
860            # input_fn only takes in inputs and device mesh
861            warnings.warn(
862                "Deprecating input_fn that takes two arguments (inputs, device_mesh), "
863                "please use input_fn that takes in (module, inputs, device_mesh) instead!",
864                FutureWarning,
865                stacklevel=2,
866            )
867            module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[call-arg]
868        elif num_args == 3:
869            # input_fn takes in module, inputs, device mesh
870            module.register_forward_pre_hook(
871                lambda mod, inputs: input_fn(mod, inputs, device_mesh)
872            )
873        else:
874            raise ValueError(
875                f"input_fn should take in 3 arguments, but got {num_args} arguments!"
876            )
877    # register output_fn as module forward hook
878    if output_fn is not None:
879        num_args = len(inspect.signature(output_fn).parameters)
880        if num_args == 2:
881            # output_fn only takes in outputs and device mesh
882            warnings.warn(
883                "Deprecating output_fn that takes two arguments (inputs, device_mesh), "
884                "please use output_fn that takes in (module, inputs, device_mesh) instead!",
885                FutureWarning,
886                stacklevel=2,
887            )
888            module.register_forward_hook(
889                lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[call-arg]
890            )
891        elif num_args == 3:
892            module.register_forward_hook(
893                lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
894            )
895        else:
896            raise ValueError(
897                f"output_fn should take in 3 arguments, but got {num_args} arguments!"
898            )
899
900    return module
901
902
# Below are tensor factory function APIs, which are used to create a DTensor directly. We need
# to make separate factory function APIs because tensor subclasses cannot override the tensor
# factory methods, and we need the user to call the factory functions with the intended
# device_mesh and placements to create a proper DTensor.


def _dtensor_init_helper(  # type: ignore[no-untyped-def]
    init_op,
    size: torch.Size,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
    **kwargs,
) -> DTensor:
    # from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta

    # if device_mesh is None, use the one from mesh resources
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
    kwargs["device"] = device_mesh.device_type

    # set default placements to replicated if not specified
    placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim))

    # check device_mesh against placements
    assert device_mesh.ndim == len(
        placements
    ), "mesh dimension does not match the length of placements"

    assert kwargs["layout"] == torch.strided, "layout value not supported!"
    torch_stride = torch._prims_common.make_contiguous_strides_for(size)

    # get local tensor shape
    local_shape = compute_local_shape(size, device_mesh, placements)
    # initialize the local tensor
    if init_op == torch.full:
        fill_value = kwargs.pop("fill_value", 0)
        local_tensor = init_op(local_shape, fill_value, **kwargs)
    elif init_op == torch.rand or init_op == torch.randn:
        # this tensor meta is not used except for `shape`
        dtype = kwargs.get("dtype", torch.get_default_dtype())

        tensor_meta = TensorMeta(size, (0,), dtype)
        spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta)

        if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker:
            random._rng_tracker = random.OffsetBasedRNGTracker()

        assert random._rng_tracker is not None
        with random._rng_tracker._distribute_region(spec):
            local_tensor = init_op(local_shape, **kwargs)
    else:
        local_tensor = init_op(local_shape, **kwargs)

    spec = DTensorSpec(
        device_mesh,
        tuple(placements),
        tensor_meta=TensorMeta(
            size,
            torch_stride,
            local_tensor.dtype,
        ),
    )

    return DTensor(
        local_tensor,
        spec,
        requires_grad=kwargs["requires_grad"],
    )


def ones(  # type: ignore[no-untyped-def]
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined
    by the variable argument ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..))

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
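
    Example (illustrative sketch; the other factory functions such as ``zeros``/``empty``
    follow the same pattern; assumes a hypothetical 1-D mesh over 4 ranks)::

        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> from torch.distributed.tensor import ones, Shard
        >>> mesh = init_device_mesh("cuda", (4,))
        >>> dt = ones(8, 8, device_mesh=mesh, placements=[Shard(0)])
        >>> dt.to_local().shape  # only the local [2, 8] shard is materialized
        torch.Size([2, 8])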
1001    """
1002    torch_size = normalize_to_torch_size(size)
1003
1004    return _dtensor_init_helper(
1005        torch.ones,
1006        torch_size,
1007        dtype=dtype,
1008        layout=layout,
1009        requires_grad=requires_grad,
1010        device_mesh=device_mesh,
1011        placements=placements,
1012    )
1013
1014
def empty(  # type: ignore[no-untyped-def]
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor`
    is defined by the variable argument ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..))

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    torch_size = normalize_to_torch_size(size)

    return _dtensor_init_helper(
        torch.empty,
        torch_size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )


def full(  # type: ignore[no-untyped-def]
    size,
    fill_value,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and
    ``placements``, with the shape defined by the argument ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a collection like a list or tuple.
            E.g.: full((1,2,3..), fill_value) or full([1,2,3..], fill_value)
        fill_value(Scalar): the value to fill the output tensor with.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    torch_size = normalize_to_torch_size(size)

    return _dtensor_init_helper(
        torch.full,
        torch_size,
        fill_value=fill_value,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )


def rand(  # type: ignore[no-untyped-def]
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with random numbers from a uniform distribution
    on the interval ``[0, 1)``. The shape of the tensor is defined by the variable
    argument ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: rand(1,2,3..) or rand([1,2,3..]) or rand((1,2,3..))

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
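
    Example (illustrative sketch; assumes a hypothetical 1-D mesh over 4 ranks; DTensor
    installs an offset-based RNG tracker so each shard draws from a coordinated random
    stream, as if sampled from the global tensor)::

        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> from torch.distributed.tensor import rand, Shard
        >>> mesh = init_device_mesh("cuda", (4,))
        >>> dt = rand(8, 8, device_mesh=mesh, placements=[Shard(0)])  # each rank fills its own [2, 8] shard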
1135    """
1136    torch_size = normalize_to_torch_size(size)
1137
1138    return _dtensor_init_helper(
1139        torch.rand,
1140        torch_size,
1141        dtype=dtype,
1142        layout=layout,
1143        requires_grad=requires_grad,
1144        device_mesh=device_mesh,
1145        placements=placements,
1146    )
1147
1148
def randn(  # type: ignore[no-untyped-def]
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with random numbers from a normal distribution
    with mean 0 and variance 1. The shape of the tensor is defined by the variable
    argument ``size``.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: randn(1,2,3..) or randn([1,2,3..]) or randn((1,2,3..))

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned DTensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks.
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    torch_size = normalize_to_torch_size(size)

    return _dtensor_init_helper(
        torch.randn,
        torch_size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )


def zeros(  # type: ignore[no-untyped-def]
    *size,
    requires_grad: bool = False,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Returns a :class:`DTensor` filled with the scalar value 0.

    Args:
        size (int...): a sequence of integers defining the shape of the output :class:`DTensor`.
            Can be a variable number of arguments or a collection like a list or tuple.
            E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..))

    Keyword args:
        requires_grad (bool, optional): If autograd should record operations on the
            returned :class:`DTensor`. Default: ``False``.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`.
            Default: ``torch.strided``.
        device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks
        placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate``

    Returns:
        A :class:`DTensor` object on each rank
    """
    torch_size = normalize_to_torch_size(size)

    return _dtensor_init_helper(
        torch.zeros,
        torch_size,
        dtype=dtype,
        layout=layout,
        requires_grad=requires_grad,
        device_mesh=device_mesh,
        placements=placements,
    )