1# mypy: allow-untyped-defs
2from __future__ import annotations  # type: ignore[attr-defined]
3
4import copy
5import operator
6import threading
7import warnings
8import weakref
9from dataclasses import dataclass
10from functools import reduce
11from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
12from typing_extensions import deprecated
13
14import torch
15import torch.distributed as dist
16import torch.distributed._shard.sharding_spec as shard_spec
17from torch.distributed import distributed_c10d, rpc
18from torch.distributed._shard._utils import DEPRECATE_MSG
19from torch.distributed._shard.sharding_spec._internals import (
20    check_tensor,
21    validate_non_overlapping_shards_metadata,
22)
23from torch.distributed._shard.sharding_spec.api import (
24    _dispatch_custom_op,
25    _has_custom_op,
26)
27from torch.distributed.remote_device import _remote_device
28from torch.utils import _pytree as pytree
29
30from .metadata import ShardedTensorMetadata, TensorProperties
31from .reshard import reshard_local_shard, reshuffle_local_shard
32from .shard import Shard
33from .utils import (
34    _flatten_tensor_size,
35    _parse_and_validate_remote_device,
36    _validate_output_tensor_for_gather,
37    build_global_metadata,
38    build_metadata_from_local_shards,
39)
40
41
42if TYPE_CHECKING:
43    from torch.distributed._shard.metadata import ShardMetadata
44
45
46# Tracking for sharded tensor objects.
47_sharded_tensor_lock = threading.Lock()
48_sharded_tensor_current_id = 0
49_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {}
50
51# Default sharded ops
52_SHARDED_OPS: Dict[Callable, Callable] = {}
53
54# Customized user ops
55_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}
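# Note (added for clarity): handlers registered here are invoked as
# `op(types, args, kwargs, process_group)`; see `ShardedTensor.__torch_function__`
# below for the dispatch order.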
56
57
58def _register_remote_shards(
59    sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int
60):
61    with _sharded_tensor_lock:
62        if sharded_tensor_id not in _sharded_tensor_map:
63            raise RuntimeError(
64                f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}"
65            )
66
67        sharded_tensor = _sharded_tensor_map[sharded_tensor_id]()
68        if sharded_tensor is None:
69            raise RuntimeError("ShardedTensor weakref has been deallocated")
70        else:
71            sharded_tensor._register_remote_shards(rrefs, rpc_rank)
72
73
74class ShardedTensorBase(torch.Tensor):
75    _sharding_spec: shard_spec.ShardingSpec
76    _metadata: ShardedTensorMetadata
77    _local_shards: List[Shard]
78
79    def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
80        # Use __new__ to construct a wrapper tensor, for recording tensor
81        # properties and logging purposes.
82        torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")
83
84        # check sharding spec and build sharded tensor metadata
85        if not isinstance(sharding_spec, shard_spec.ShardingSpec):
86            raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}")
87
88        sizes = _flatten_tensor_size(size)
89        dtype = kwargs["dtype"]
90        layout = kwargs["layout"]
91        pin_memory = kwargs["pin_memory"]
92        requires_grad = kwargs["requires_grad"]
93
94        if dtype is None:
95            dtype = torch.get_default_dtype()
96
97        tensor_properties = TensorProperties(
98            dtype, layout, requires_grad, pin_memory=pin_memory
99        )
100        sharded_tensor_metadata = sharding_spec.build_metadata(
101            sizes, tensor_properties=tensor_properties
102        )
103
104        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
105            cls,
106            sizes,
107            dtype=dtype,
108            layout=layout,
109            pin_memory=pin_memory,
110            requires_grad=requires_grad,
111        )
112        # set sharding spec
113        r._sharding_spec = sharding_spec
114        # set metadata
115        r._metadata = sharded_tensor_metadata
116        # set local shards
117        r._local_shards = []
118        return r
119
120    def metadata(self) -> ShardedTensorMetadata:
121        """
122        Returns a :class:`ShardedTensorMetadata` object corresponding to the
123        metadata for the entire tensor.
124        """
125        return self._metadata
126
127    def local_shards(self) -> List[Shard]:
128        """
        Returns a list of :class:`Shard` objects corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
132        """
133        return self._local_shards
134
135    @classmethod
136    def _init_from_local_shards_and_global_metadata(
137        cls,
138        local_shards: List[Shard],
139        sharded_tensor_metadata: ShardedTensorMetadata,
140        sharding_spec=None,
141    ) -> ShardedTensorBase:
142        """
143        Initialize a ShardedTensorBase with local shards and a global
144        ShardedTensorMetadata built on each rank.
        Warning: This API is experimental and subject to change. It does
                 not perform cross-rank validations and relies fully on the user
                 for the correctness of ``sharded_tensor_metadata`` on each rank.
148        """
149        shards_metadata = sharded_tensor_metadata.shards_metadata
150        tensor_properties = sharded_tensor_metadata.tensor_properties
151
152        if len(shards_metadata) == 0:
153            raise ValueError("shards_metadata must not be empty!")
154
155        if tensor_properties.layout != torch.strided:
156            raise ValueError("Only torch.strided layout is currently supported")
157
158        if sharding_spec is None:
159            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
160        else:
161            spec = sharding_spec
162
163        sharded_tensor_base = ShardedTensorBase.__new__(
164            ShardedTensor,
165            spec,
166            sharded_tensor_metadata.size,
167            dtype=tensor_properties.dtype,
168            layout=tensor_properties.layout,
169            pin_memory=tensor_properties.pin_memory,
170            requires_grad=tensor_properties.requires_grad,
171        )
172
173        # check if shards_metadata have overlap shards
174        validate_non_overlapping_shards_metadata(shards_metadata)
175
176        # check if the shards_metadata is compatible with overall size of the sharded tensor.
177        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
178
179        # done validation, add local_shards
180        sharded_tensor_base._local_shards = local_shards
181        return sharded_tensor_base
182
183    @classmethod
184    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
185        raise RuntimeError(
186            f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} "
187            "but the there is no custom __torch_dispatch__ implementation for it."
188        )
189
190
191class ShardedTensor(ShardedTensorBase):
192    """
    ShardedTensor is a torch.Tensor subclass that represents Tensors sharded
    across multiple devices and multiple processes.

    ShardedTensor is initialized in an SPMD-like fashion, where each rank
    initializes the ShardedTensor. The ShardedTensor object on each rank
    then stores only the local shard for the Tensor and provides global
    metadata for all the shards.

    ShardedTensor doesn't provide any Tensor-like operations but is a wrapper
    providing the Tensor representing the local shard and the global metadata.
    Using these, users can build their custom distributed._sharded computations
    on top of this primitive. The local shards are all initialized using the
    create_op specified by tensor_init_params.create_op, e.g., torch.ones or
    torch.empty.
207
208    Args:
209        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
210            describing how to shard the Tensor.
211        size (int...): a sequence of integers defining the shape of the output
212            tensor. Can be a variable number of arguments or a collection like a list or tuple.
213
214    Keyword args:
215        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
216                Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
217        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
218            Default: ``torch.strided``.
219        requires_grad (bool, optional): If autograd should record operations on the
220            returned tensor. Default: ``False``.
221        pin_memory (bool, optional): If set, returned tensor would be allocated in
222            the pinned memory. Works only for CPU tensors. Default: ``False``.
223        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
224            returned Tensor. Default: ``torch.contiguous_format``.
225        init_rrefs (bool, optional): Whether or not to initialize
226            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
227            Need to initialize the RPC Framework if specified as ``True``.
228            Default: ``False``.
229
    .. note:: ShardedTensor uses collectives to do various operations, e.g. it
        uses all_gather to do cross-rank validations. For NCCL-based process
        groups, internal tensor representations of objects must be moved to the
        GPU device before communication takes place. In this case, the device
        used is given by ``torch.cuda.current_device()`` and it is the user's
        responsibility to ensure that this is set so that each rank has an
        individual GPU, via ``torch.cuda.set_device()``.
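    Example (a minimal sketch; assumes a 2-rank job where ``init_process_group``
    has already been called with one GPU per rank, and uses the ``empty`` factory
    from ``torch.distributed._shard.sharded_tensor``):

        >>> # xdoctest: +SKIP
        >>> from torch.distributed._shard import sharded_tensor
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> st = sharded_tensor.empty(spec, 10, 20)
        >>> st.local_shards()[0].tensor.shape  # each rank holds one [5, 20] shard
        torch.Size([5, 20])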
237
238    """
239
240    def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
241        self = super().__new__(cls, sharding_spec, *size, **kwargs)
242        return self
243
244    def __init__(
245        self,
246        sharding_spec: shard_spec.ShardingSpec,
247        *size,
248        dtype=None,
249        layout=torch.strided,
250        requires_grad=False,
251        pin_memory=False,
252        memory_format=torch.contiguous_format,
253        process_group=None,
254        init_rrefs=False,
255    ):
256        # prepare initialization, initialize fields like
257        # _process_group, _local_shards, etc.
258        self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
259
260        if layout != torch.strided:
261            raise ValueError("Only torch.strided layout is currently supported")
262
263        if memory_format != torch.contiguous_format:
264            raise ValueError(
265                "Only torch.contiguous_format memory_format is currently supported"
266            )
267
268        self._metadata.tensor_properties.memory_format = memory_format
269
270        current_rank = dist.get_rank()  # global rank
271
272        for shard_metadata in self._metadata.shards_metadata:
273            rank, device = _parse_and_validate_remote_device(
274                self._process_group, shard_metadata.placement
275            )
276            if rank == current_rank:
277                local_tensor = _create_tensor_from_params(
278                    shard_metadata.shard_sizes,
279                    local_device=device,
280                    tensor_properties=self._metadata.tensor_properties,
281                )
282                self._local_shards.append(Shard(local_tensor, shard_metadata))
283
284        # do post initialization (i.e. register sharded_tensor_id, initialize_rpc)
285        self._post_init()
286
287    def _prepare_init(self, process_group=None, init_rrefs=False):
288        self._init_rrefs = init_rrefs
289        self._sharded_tensor_id = None
290
291        self._process_group = self._normalize_pg(process_group)
292        self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
293
294    def _post_init(self):
295        # Initialize RPC if available.
296        if self._init_rrefs:
297            with _sharded_tensor_lock:
298                global _sharded_tensor_current_id, _sharded_tensor_map
299                self._sharded_tensor_id = _sharded_tensor_current_id
300                _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self)
301                _sharded_tensor_current_id += 1
302
303            if not rpc._is_current_rpc_agent_set():
304                raise RuntimeError(
305                    "RPC Framework needs to be initialized using"
306                    " torch.distributed.rpc.init_rpc if init_rrefs is set to True"
307                )
308            self._init_rpc()
309
310    def __del__(self):
311        # Clean up the global map.
312        with _sharded_tensor_lock:
313            global _sharded_tensor_current_id, _sharded_tensor_map
314            if (
315                hasattr(self, "_sharded_tensor_id")
316                and self._sharded_tensor_id in _sharded_tensor_map
317            ):
318                _sharded_tensor_map.pop(self._sharded_tensor_id)  # type: ignore[call-overload]
319
320    def _init_rpc(self):
321        # Validate PG and RPC ranks match.
322        pg_rank = dist.get_rank()
323        rpc_rank = rpc.get_worker_info().id
324        if pg_rank != rpc_rank:
325            raise ValueError(
326                f"Default ProcessGroup and RPC ranks must be "
327                f"the same for ShardedTensor, found process group rank: "
328                f"{pg_rank} and RPC rank: {rpc_rank}"
329            )
330
331        self._remote_shards = {}
332
333        # Gather all the sharded tensor ids.
334        worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
335        rank_to_name = {}
336        name_to_rank = {}
337
338        for worker_info in worker_infos:
339            rank_to_name[worker_info.id] = worker_info.name
340            name_to_rank[worker_info.name] = worker_info.id
341
342        all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id)
343
344        # Share the local shards to the entire world.
345        futs = []
346        rpc_rank = rpc.get_worker_info().id
347        for rank in range(dist.get_world_size()):
348            # Skip self.
349            if rank == dist.get_rank():
350                continue
351
352            if len(self.local_shards()) != 0:
353                rrefs: List[rpc.RRef[Shard]] = [
354                    rpc.RRef(shard) for shard in self.local_shards()
355                ]
356                fut = rpc.rpc_async(
357                    rank,
358                    _register_remote_shards,
359                    args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank),
360                )
361                futs.append(fut)
362
363        torch.futures.wait_all(futs)
364
365        # Barrier for all RPCs to finish on all ranks.
366        rpc.api._all_gather(None)
367
368    def _get_preferred_device(self) -> torch.device:
369        """
370        Return the preferred device to be used when creating tensors for collectives.
        This method takes into account the associated process group.
372        """
373        if dist.get_backend(self._process_group) == dist.Backend.NCCL:
374            return torch.device(torch.cuda.current_device())
375        return torch.device("cpu")
376
377    def gather(  # type: ignore[override]
378        self,
379        dst: int = 0,
380        out: Optional[torch.Tensor] = None,
381        enforce_dtype: bool = False,
382        dtype: Optional[torch.dtype] = None,
383    ) -> None:
384        """
385        Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
386        sharded tensor.
387
388        The API needs to be called on all ranks in SPMD fashion. All ranks should have
389        the same ``dst``. ``out`` should be a tensor of the same size as the overall
390        size of the sharded tensor on ``dst`` and ``None`` on all other ranks.
391
        Args:
            dst (int): The rank where the full tensor is constructed.
                Default: 0
            out (:class:`torch.Tensor`, optional): The output full tensor.
                Must be provided ONLY on the ``dst`` rank.
                Default: ``None``
            enforce_dtype (bool): Deprecated, please use ``dtype`` instead. Force the
                gathered tensors to be the same type as input and output.
            dtype (torch.dtype): Force the gathered tensors to be this dtype.
                Default: ``None``
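        Example (a minimal sketch; assumes the process group is already initialized
        and ``st`` is an existing :class:`ShardedTensor`):

            >>> # xdoctest: +SKIP
            >>> dst = 0
            >>> full = None
            >>> if dist.get_rank() == dst:
            ...     full = torch.empty(st.size(), device=torch.cuda.current_device())
            >>> st.gather(dst=dst, out=full)  # on rank 0, `full` now holds the whole tensor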
402        """
403
404        def shard_size(shard_md):
405            return reduce(operator.mul, shard_md.shard_sizes)  # type: ignore[attr-defined]
406
407        if enforce_dtype:
408            warnings.warn(
409                "`enforce_dtype` is deprecated. Please use `dtype` instead.",
410                FutureWarning,
411                stacklevel=2,
412            )
413
414        rank = dist.get_rank(self._process_group)
415        full_size = self.metadata().size
416        _validate_output_tensor_for_gather(rank, dst, full_size, out)
417
418        local_shards = self.local_shards()
419        world_size = dist.get_world_size(self._process_group)
420        rank_sizes = [0 for _ in range(world_size)]
421        max_rank_size = 0
422        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
423        # collect sizes
424        for shard_md in self.metadata().shards_metadata:
425            shard_rank = cast(_remote_device, shard_md.placement).rank()
426            assert shard_rank is not None
427
428            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
429            rank_sizes[shard_rank] += shard_size(shard_md)
430            max_rank_size = max(max_rank_size, rank_sizes[shard_rank])
431
432        gather_list: Optional[List[torch.Tensor]]
433        if rank == dst:
434            assert out is not None
435            if enforce_dtype:
436                # enforce_dtype is deprecated.  Do it for backward compatibility.
437                dtype = out.dtype
438            # TODO make it as a view of out tensor
439            gather_list = [
440                torch.empty((max_rank_size,), device=out.device, dtype=dtype)
441                for _ in range(world_size)
442            ]
443        else:
444            gather_list = None
445
446        with torch.no_grad():
447            if enforce_dtype and len(local_shards) > 0:
448                # enforce_dtype is deprecated.  Do it for backward compatibility.
449                dtype = local_shards[0].tensor.dtype
450            data = torch.empty(
451                max_rank_size, device=self._get_preferred_device(), dtype=dtype
452            )
453
454            for shard in local_shards:
455                src = shard.tensor.flatten()
456                if src.nelement() == 0:
457                    warnings.warn(
458                        "Gathering a tensor with zero elements on rank " + str(rank)
459                    )
460                    return
461                shard_offset = shard_placement[shard.metadata][1]
462                data[shard_offset : shard_offset + src.numel()].copy_(src)
463
464        dist.gather(
465            tensor=data,
466            gather_list=gather_list,
467            dst=dst,
468            group=self._process_group,
469        )
470        if rank != dst:
471            return
472        # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst
473        out = cast(torch.Tensor, out)
474        assert gather_list is not None
475
476        full_size = self.metadata().size
477        dims = len(full_size)
478        for shard_md in self.metadata().shards_metadata:
479            rank, rank_offset = shard_placement[shard_md]
480            tensor = gather_list[rank]
481            tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
482            tensor = tensor.view(shard_md.shard_sizes)
483
484            out_narrow_view = out
485            for dim in range(dims):
486                out_narrow_view = out_narrow_view.narrow(
487                    dim,
488                    shard_md.shard_offsets[dim],
489                    shard_md.shard_sizes[dim],
490                )
491
492            out_narrow_view.copy_(tensor)
493
494    def cpu(
495        self, memory_format=torch.preserve_format, process_group=None
496    ) -> ShardedTensor:
497        """
498        Returns a copy of this object in CPU memory.
499
500        If this ShardedTensor is already on CPU memory, then no copy is
501        performed and original object is returned.
502
        .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might
            need to be managed by a different type of ProcessGroup (i.e. ProcessGroupGloo);
            it is the user's responsibility to explicitly pass in a new process_group that
            is compatible with CPU.
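        Example (a minimal sketch; assumes ``st`` lives on GPU and a Gloo group is
        created by the caller for the CPU copy):

            >>> # xdoctest: +SKIP
            >>> gloo_pg = dist.new_group(backend="gloo")
            >>> st_cpu = st.cpu(process_group=gloo_pg)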
507        """
508        # TODO: make this a __torch_function__ op once ShardedTensor becomes a
509        # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402
510        if (
511            memory_format != torch.preserve_format
512            and memory_format != torch.contiguous_format
513        ):
514            raise RuntimeError(
515                "Only `torch.contiguous_format` or "
516                "`torch.preserve_format` is supported!"
517            )
518        all_on_cpu = True
519        for meta in self.metadata().shards_metadata:
520            all_on_cpu &= meta.placement.device().type == "cpu"  # type: ignore[union-attr]
521
522        # if every shard is already on CPU, return the original object
523        if all_on_cpu:
524            return self
525
526        # if not, returns a copy of this object on CPU
527        list_shards: List[Shard] = []
528        # move all local shards to cpu, and change metadata
529        for shard in self._local_shards:
530            cpu_tensor = shard.tensor.cpu(memory_format=memory_format)  # type: ignore[call-arg]
531            metadata = copy.deepcopy(shard.metadata)
532            metadata.placement._device = torch.device("cpu")  # type: ignore[union-attr]
533            list_shards.append(Shard(cpu_tensor, metadata))
534
535        st_meta = copy.deepcopy(self.metadata())
536        for meta in st_meta.shards_metadata:
537            if meta.placement.device().type != "cpu":  # type: ignore[union-attr]
538                meta.placement._device = torch.device("cpu")  # type: ignore[union-attr]
539
540        pg = self._process_group if process_group is None else process_group
541        st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata(
542            list_shards,
543            sharded_tensor_metadata=st_meta,
544            process_group=pg,
545            init_rrefs=self._init_rrefs,
546        )
547        return st_cpu
548
549    def cuda(
550        self,
551        device=None,
552        non_blocking=False,
553        memory_format=torch.preserve_format,
554        process_group=None,
555    ) -> ShardedTensor:
556        """
        Returns a copy of this object in CUDA memory. If the original ShardedTensor
        is on CPU, we will move the local shard to the current GPU device of each
        process in an SPMD fashion.
        If this ShardedTensor is already in CUDA memory and the local shards on each rank
        are already on the current device, we still return a new ShardedTensor object with
        new metadata, but no underlying data movement is performed.

        .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might
            need to be managed by a different type of ProcessGroup (i.e. ProcessGroupNCCL);
            it is the user's responsibility to explicitly pass in a new process_group that
            is compatible with GPU.
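        Example (a minimal sketch; assumes ``st`` lives on CPU and an NCCL group is
        created by the caller for the GPU copy):

            >>> # xdoctest: +SKIP
            >>> nccl_pg = dist.new_group(backend="nccl")
            >>> st_cuda = st.cuda(process_group=nccl_pg)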
567        """
568        if (
569            memory_format != torch.preserve_format
570            and memory_format != torch.contiguous_format
571        ):
572            raise RuntimeError(
573                "Only `torch.contiguous_format` or "
574                "`torch.preserve_format` is supported!"
575            )
576
577        if device is not None:
578            device = torch.device(device) if isinstance(device, str) else device
579            assert (
580                isinstance(device, torch.device)
581                and device.index == torch.cuda.current_device()
582            ), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""
583
584        current_device = torch.device(torch.cuda.current_device())
585        # returns a copy of ShardedTensor on CUDA current device
586        list_shards: List[Shard] = []
587        # move all local shards to current device, and change metadata
588        # if local shards already on the current device, there's no
589        # real data movement, only the metadata are copied.
590        for shard in self._local_shards:
591            cuda_tensor = shard.tensor.cuda(
592                device=current_device,
593                non_blocking=non_blocking,
594                memory_format=memory_format,
595            )  # type: ignore[call-arg]
596            metadata = copy.deepcopy(shard.metadata)
597            metadata.placement._device = current_device  # type: ignore[union-attr]
598
599            list_shards.append(Shard(cuda_tensor, metadata))
600
601        st_meta = copy.deepcopy(self.metadata())
602        for meta in st_meta.shards_metadata:
603            if meta.placement.device().type != "cuda":  # type: ignore[union-attr]
604                meta.placement._device = current_device  # type: ignore[union-attr]
605
606        pg = self._process_group if process_group is None else process_group
607        # we need to use `init_from_local_shards` to communicate between ranks
608        # and update the sharding spec/shards metadata.
609        st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata(
610            list_shards,
611            sharded_tensor_metadata=st_meta,
612            process_group=pg,
613            init_rrefs=self._init_rrefs,
614        )
615        return st_cuda
616
617    def to(self, *args, **kwargs) -> ShardedTensor:
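        """
        Return a ShardedTensor with the requested dtype and/or device applied to each
        local shard (only the current CUDA device or CPU is supported; use ``reshard``
        to change placements).

        A minimal usage sketch (assumes ``st`` is an existing ShardedTensor):

            >>> # xdoctest: +SKIP
            >>> st_fp16 = st.to(torch.float16)
            >>> st_copy = st.to(torch.cuda.current_device(), copy=True)
        """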
618        current_device: torch.device
619        if self._local_shards:
620            current_device = self._local_shards[0].tensor.device
621        elif self._process_group._get_backend_name() == "gloo":
622            current_device = torch.device("cpu")
623        else:
624            current_device = torch.device(torch.cuda.current_device())
625        current_dtype = self.dtype
626        device_to = current_device
627        dtype_to = current_dtype
628        if len(args) == 1:
629            if isinstance(args[0], torch.dtype):
630                dtype_to = args[0]
631            elif isinstance(args[0], torch.device):
632                device_to = args[0]
633            elif isinstance(args[0], (str, int)):
634                device_to = torch.device(args[0])
635            elif isinstance(args[0], torch.Tensor):
636                dtype_to = args[0].dtype
637                device_to = args[0].device
638            else:
                raise RuntimeError(f"ShardedTensor.to() received invalid arguments: {args}")
640        elif len(args) == 2:
641            device_to, dtype_to = args
642        else:
643            dtype_to = kwargs.get("dtype", current_dtype)
644            device_to = kwargs.get("device", current_device)
645
646        device_to = (
647            torch.device(device_to) if isinstance(device_to, (str, int)) else device_to
648        )
649
650        if device_to.type == "cuda":
            # if device_to is set to cuda, move to the current device even
            # if the user specifies a device index.
            current_idx = torch.cuda.current_device()
            if device_to.index != current_idx:
                warnings.warn(
                    "ShardedTensor.to only moves the tensor to its current device. "
                    "If you want to put it on a different device, use `reshard` instead."
                )
            device_to = torch.device(current_idx)
660
661        copy_tensor = kwargs.get("copy", False)
662        non_blocking = kwargs.get("non_blocking", False)
663        memory_format = kwargs.get("memory_format", torch.preserve_format)
664        process_group = kwargs.get("process_group", None)
665
666        if (
667            not copy_tensor
668            and dtype_to == current_dtype
669            and device_to == current_device
670        ):
671            # already have correct dtype and device, return itself
672            return self
673
        # return a copy of the ShardedTensor with the requested device/dtype
675        list_shards: List[Shard] = []
676
677        for shard in self._local_shards:
678            new_tensor = shard.tensor.to(  # type: ignore[call-overload]
679                device=device_to,
680                dtype=dtype_to,
681                non_blocking=non_blocking,
682                copy=copy_tensor,
683                memory_format=memory_format,
684            )
685            metadata = copy.deepcopy(shard.metadata)
686            if metadata.placement is not None:
687                metadata.placement._device = device_to
688            list_shards.append(Shard(new_tensor, metadata))
689
690        # update metadata
691        st_meta = copy.deepcopy(self.metadata())
692        st_meta.tensor_properties.dtype = dtype_to
693        for meta in st_meta.shards_metadata:
694            meta.placement._device = device_to  # type: ignore[union-attr]
695
696        pg = self._process_group if process_group is None else process_group
697        # we need to use `init_from_local_shards` to communicate between ranks
698        # and update the sharding spec/shards metadata.
699        st_to = ShardedTensor._init_from_local_shards_and_global_metadata(
700            list_shards,
701            sharded_tensor_metadata=st_meta,
702            process_group=pg,
703            init_rrefs=self._init_rrefs,
704        )
705        return st_to
706
707    @classmethod
708    def _normalize_pg(
709        cls, process_group: Optional[dist.ProcessGroup]
710    ) -> dist.ProcessGroup:
711        if process_group is not None:
712            return process_group
713        return distributed_c10d._get_default_group()
714
715    @classmethod
716    def _init_from_local_shards(
717        cls,
718        local_shards: List[Shard],
719        *global_size,
720        process_group=None,
721        init_rrefs=False,
722    ):
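        """
        Initialize a ShardedTensor from local shards plus the global size. The sharding
        spec and global metadata are built by all-gathering each rank's local
        ShardedTensorMetadata and validating them across ranks.

        A minimal sketch (assumes each rank has already built ``local_shards`` with
        ShardMetadata matching its slice of a ``[world_size * 2, 4]`` global tensor):

            >>> # xdoctest: +SKIP
            >>> st = ShardedTensor._init_from_local_shards(
            ...     local_shards, world_size * 2, 4, init_rrefs=False
            ... )
        """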
        # STEP 1: Validate the ShardMetadatas locally
724        process_group = cls._normalize_pg(process_group)
725        current_rank = dist.get_rank()  # intentional to get global rank
726        world_size = dist.get_world_size(process_group)
727
728        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
729        global_tensor_size = _flatten_tensor_size(global_size)
730
731        if len(local_shards) > 0:
732            local_sharded_tensor_metadata = build_metadata_from_local_shards(
733                local_shards, global_tensor_size, current_rank, process_group
734            )
735
736        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
737        # metadata by gathering local ShardedTensorMetadata
738        gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
739        if world_size > 1:
740            gathered_metadatas = [None for _ in range(world_size)]
741
742            dist.all_gather_object(
743                gathered_metadatas, local_sharded_tensor_metadata, group=process_group
744            )
745        else:
746            gathered_metadatas = [local_sharded_tensor_metadata]
747
748        global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
749        tensor_properties = global_sharded_tensor_metadata.tensor_properties
750
751        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
752        # prepare initialization
753        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
754            global_sharded_tensor_metadata.shards_metadata
755        )
756        sharded_tensor = cls.__new__(
757            cls,
758            spec,
759            global_sharded_tensor_metadata.size,
760            dtype=tensor_properties.dtype,
761            layout=tensor_properties.layout,
762            pin_memory=tensor_properties.pin_memory,
763            requires_grad=tensor_properties.requires_grad,
764        )
765        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
766
767        # attach local_shards to the ShardedTensor created
768        sharded_tensor._local_shards = local_shards
769
770        # run post initialization, i.e. map registration, rpc initialization
771        sharded_tensor._post_init()
772        return sharded_tensor
773
774    @classmethod
775    @deprecated(DEPRECATE_MSG, category=FutureWarning)
776    def _init_from_local_tensor(
777        cls,
778        local_tensor: torch.Tensor,
779        sharding_spec: shard_spec.ShardingSpec,
780        *global_size: Sequence[int],
781        process_group: Optional[dist.ProcessGroup] = None,
782        init_rrefs=False,
783    ) -> ShardedTensor:
784        """
785        Initialize a ShardedTensor given only one local tensor, global sharded tensor
786        size and sharding spec on each rank.
787
788        Args:
789            local_tensor (Tensor): Single tensor of local shard stored in each rank.
790            sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
791                The specification describing how to shard the Tensor.
792            global_size (Sequence[int]): Size of the sharded tensor.
793            process_group (ProcessGroup, optional): The process group to aggregate on.
794                Default: None
795            init_rrefs (bool, optional): Whether or not to initialize
796                :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
797                Need to initialize the RPC Framework if specified as ``True``.
798                Default: ``False``.
799
800        Returns:
801            A :class:`ShardedTensor` sharded based on the given sharding_spec with local
802                tensor stored in the current rank.
803
804        Examples:
805            >>> # xdoctest: +SKIP
806            >>> # All tensors below are of torch.int64 type.
807            >>> # We have 2 process groups, 2 ranks.
808            >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
            >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)
810            >>> local_tensor
811            tensor([[1, 2, 3, 4]]) # Rank 0
812            tensor([[3, 4, 5, 6]]) # Rank 1
813            >>> sharding_dim = 0
814            >>> sharding_spec = ChunkShardingSpec(
815                    dim=sharding_dim,
816                    placements=[
817                        "rank:0/cuda:0",
818                        "rank:1/cuda:1",
819                    ],
820                )
821            >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
822            >>> st
823            ShardedTensor(
824                ShardedTensorMetadata(
825                    shards_metadata=[
826                        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
827                        ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
828                    ],
829                    size=torch.Size([2, 4])
830            )
831            >>> st.local_tensor()
832            tensor([1, 2, 3, 4]) # Rank 0
833            tensor([3, 4, 5, 6]) # Rank 1
834
        Warning: This API is experimental and subject to change. It does not perform
                 full cross-rank validations; only the local shard on the current rank
                 is validated. We rely entirely on the user to ensure the local tensor
                 is sharded according to the sharding spec.
839        """
840        if not local_tensor.is_contiguous():
841            raise ValueError("local_tensor is not a contiguous Tensor.")
842
843        global_tensor_size = _flatten_tensor_size(global_size)
844        tensor_properties = TensorProperties(
845            dtype=local_tensor.dtype,
846            layout=local_tensor.layout,
847            requires_grad=local_tensor.requires_grad,
848            memory_format=torch.contiguous_format,
849            pin_memory=local_tensor.is_pinned(),
850        )
851        sharded_tensor_metadata = sharding_spec.build_metadata(
852            global_tensor_size, tensor_properties
853        )
854
855        process_group = cls._normalize_pg(process_group)
856        current_rank = dist.get_rank()  # intentional to get global rank
857
858        local_shards: List[Shard] = []
859        for shard_metadata in sharded_tensor_metadata.shards_metadata:
860            rank, device = _parse_and_validate_remote_device(
861                process_group, shard_metadata.placement
862            )
863            if rank == current_rank:
864                local_shards.append(Shard(local_tensor, shard_metadata))
865
        # TODO: figure out how the API should behave when some ranks have no shards
        # see https://github.com/pytorch/pytorch/issues/7313
868        return ShardedTensor._init_from_local_shards_and_global_metadata(
869            local_shards,
870            sharded_tensor_metadata,
871            process_group=process_group,
872            init_rrefs=init_rrefs,
873            sharding_spec=sharding_spec,
874        )
875
876    @classmethod
877    def _init_from_local_shards_and_global_metadata(  # type: ignore[override]
878        cls,
879        local_shards: List[Shard],
880        sharded_tensor_metadata: ShardedTensorMetadata,
881        process_group=None,
882        init_rrefs=False,
883        sharding_spec=None,
884    ) -> ShardedTensor:
885        """
886        Initialize a ShardedTensor with local shards and a global
887        ShardedTensorMetadata built on each rank.
888
        Warning: This API is experimental and subject to change. It does
                 not perform cross-rank validations and relies fully on the user
                 for the correctness of ``sharded_tensor_metadata`` on each rank.
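        A minimal sketch (assumes ``local_shards`` and ``st_metadata`` were built
        consistently on every rank, e.g. when restoring from a checkpoint):

            >>> # xdoctest: +SKIP
            >>> st = ShardedTensor._init_from_local_shards_and_global_metadata(
            ...     local_shards, sharded_tensor_metadata=st_metadata
            ... )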
892        """
893        process_group = cls._normalize_pg(process_group)
894        current_rank = dist.get_rank()  # intentional to get global rank
895
896        shards_metadata = sharded_tensor_metadata.shards_metadata
897
898        local_shard_metadatas = []
899
900        # collect local shard metadatas from the global sharded_tensor_metadata
901        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
902            rank, local_device = _parse_and_validate_remote_device(
903                process_group, shard_metadata.placement
904            )
905
906            if current_rank == rank:
907                local_shard_metadatas.append(shard_metadata)
908
909        if len(local_shards) != len(local_shard_metadatas):
910            raise RuntimeError(
911                f"Number of local shards ({len(local_shards)}) does not match number of local "
912                f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) "
913                f"on rank ({current_rank}) "
914            )
915
916        shards_metadata = sharded_tensor_metadata.shards_metadata
917        tensor_properties = sharded_tensor_metadata.tensor_properties
918
919        if len(shards_metadata) == 0:
920            raise ValueError("shards_metadata must not be empty!")
921
922        if tensor_properties.layout != torch.strided:
923            raise ValueError("Only torch.strided layout is currently supported")
924
925        if sharding_spec is None:
926            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
927        else:
928            spec = sharding_spec
929
930        sharded_tensor = ShardedTensor.__new__(
931            ShardedTensor,
932            spec,
933            sharded_tensor_metadata.size,
934            dtype=tensor_properties.dtype,
935            layout=tensor_properties.layout,
936            pin_memory=tensor_properties.pin_memory,
937            requires_grad=tensor_properties.requires_grad,
938        )
939
940        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
941            tensor_property_or_metadata = (
942                "tensor property" if is_property else "local ShardMetadata"
943            )
944            if expected != actual:
945                raise ValueError(
946                    f"Local shards' tensor {prop_name} property is incompatible with "
947                    f"{tensor_property_or_metadata} on rank {rank}: "
948                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
949                    f"local shard tensor {prop_name}={actual}."
950                )
951
952        for shard in local_shards:
953            shard_meta = shard.metadata
954            local_shard_tensor = shard.tensor
955            placement = shard_meta.placement
956            assert placement is not None, "Must specify placement for `Shard`!"
957            rank = placement.rank()
958            local_device = placement.device()
959
960            _raise_if_mismatch(
961                tensor_properties.layout,
962                local_shard_tensor.layout,
963                "layout",
964                rank,
965                True,
966            )
967            if not local_shard_tensor.is_contiguous():
968                raise ValueError(
969                    "Only torch.contiguous_format memory_format is currently supported"
970                )
971
972            _raise_if_mismatch(
973                shard_meta.shard_sizes,
974                list(local_shard_tensor.size()),
975                "size",
976                rank,
977            )
978            _raise_if_mismatch(
979                tensor_properties.pin_memory,
980                local_shard_tensor.is_pinned(),
981                "pin_memory",
982                rank,
983                True,
984            )
985            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
986            _raise_if_mismatch(
987                tensor_properties.dtype,
988                local_shard_tensor.dtype,
989                "dtype",
990                rank,
991                True,
992            )
993            _raise_if_mismatch(
994                tensor_properties.requires_grad,
995                local_shard_tensor.requires_grad,
996                "requires_grad",
997                rank,
998                True,
999            )
1000
1001        # check if shards_metadata have overlap shards
1002        validate_non_overlapping_shards_metadata(shards_metadata)
1003
1004        # check if the shards_metadata is compatible with overall size of the sharded tensor.
1005        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))
1006
1007        # done validation, add local_shards
1008        sharded_tensor._local_shards = local_shards
1009        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
1010
1011        # run post initialization, i.e. map registration, rpc initialization
1012        sharded_tensor._post_init()
1013        return sharded_tensor
1014
1015    def sharding_spec(self) -> shard_spec.ShardingSpec:
1016        """
1017        Returns the ShardingSpec for the tensor.
1018        """
1019        return self._sharding_spec
1020
1021    @deprecated(DEPRECATE_MSG, category=FutureWarning)
1022    def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
1023        """
1024        Reshard a sharded tensor given the ``resharding_spec``. For now, we only support
1025        single local shard.
1026
        If ``resharding_spec`` is the same as the original one, this becomes a no-op.
        If ``resharding_spec`` only shares the same sharding dim with the original one,
        we reshuffle the local shards across ranks directly.
        For more generic cases, we merge different shards across different ranks and split
        the local shards based on the ``resharding_spec`` via the `all_to_all` collective API.
1032
1033        Args:
1034            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
1035                specification describing how the tensor is sharded.
1036
1037        Returns:
1038            A :class:`ShardedTensor` object whose local shards are resharded.
1039
1040        Examples:
1041            >>> # xdoctest: +SKIP
1042            >>> # We have 2 process groups, 2 ranks.
1043            >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank
1044            >>> tensor = torch.stack([tensor, tensor])
1045            >>> tensor
1046            tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0
1047            tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1
1048            tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2
1049            tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3
1050            >>> sharding_dim = 0
1051            >>> spec = ChunkShardingSpec(
1052                    dim=sharding_dim,
1053                    placements=[
1054                        "rank:0/cuda:0",
1055                        "rank:1/cuda:1",
1056                        "rank:2/cuda:2",
1057                        "rank:3/cuda:3",
1058                    ],
1059                )
1060            >>> current_offsets = [0] * 2
1061            >>> current_offsets[0] = rank * 2
1062            >>> shard_metadata = ShardMetadata(
1063                    shard_offsets=copy.deepcopy(current_offsets),
1064                    shard_sizes=tensor.size(),
1065                    placement=spec.placements[rank],
1066                )
1067            >>> local_shards = [
1068                    Shard(
1069                        tensor=tensor,
1070                        metadata=shard_metadata,
1071                    )
1072                ]
1073            >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size())
1074            >>> sharding_dim = 1
1075            >>> resharding_spec = ChunkShardingSpec(
1076                    dim=sharding_dim,
1077                    placements=[
1078                        "rank:0/cuda:0",
1079                        "rank:1/cuda:1",
1080                        "rank:2/cuda:2",
1081                        "rank:3/cuda:3",
1082                    ],
1083                )
1084            >>> st.reshard(resharding_spec)
1085            >>> tensor = st.local_shards()[0].tensor
1086            >>> tensor
1087            tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0
1088            tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1
1089            tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2
1090            tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3
1091        """
1092        if not isinstance(
1093            resharding_spec, shard_spec.ChunkShardingSpec
1094        ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec):
1095            raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
1096        if len(self.local_shards()) != 1:
1097            raise NotImplementedError("Only single local shard supported for reshard.")
1098
1099        if self._sharding_spec.dim == resharding_spec.dim:  # type: ignore[attr-defined]
1100            if self._sharding_spec.placements == resharding_spec.placements:  # type: ignore[attr-defined]
1101                return self
1102            else:
1103                local_shards, shards_metadata = reshuffle_local_shard(
1104                    self.local_tensor(),
1105                    self.size(),  # type: ignore[arg-type]
1106                    self._sharding_spec,
1107                    resharding_spec,
1108                    self._process_group,
1109                )
1110        else:
1111            local_shards, shards_metadata = reshard_local_shard(
1112                self.local_tensor(),
1113                self.size(),  # type: ignore[arg-type]
1114                self._sharding_spec,
1115                resharding_spec,
1116                self._process_group,
1117            )
1118        self._local_shards = local_shards
1119        self._metadata.shards_metadata = shards_metadata
1120        self._sharding_spec = resharding_spec
1121        return self
1122
1123    def local_tensor(self) -> torch.Tensor:
1124        """
1125        Return local tensor for a sharded_tensor. For now we only support single local shard.
1126
1127        Returns:
1128            A :class:`torch.Tensor` of the local shard.
1129        """
1130        if len(self.local_shards()) != 1:
1131            raise NotImplementedError("Only single local shard is supported.")
1132        return self.local_shards()[0].tensor
1133
1134    @classmethod
1135    @deprecated(DEPRECATE_MSG, category=FutureWarning)
1136    def __torch_function__(cls, func, types, args=(), kwargs=None):
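        """
        Intercept ``torch.*`` calls that involve a ShardedTensor and route them, in
        order, to: (1) user-registered ops in ``_CUSTOM_SHARDED_OPS``, (2) custom ops
        attached to the sharding spec, (3) the default ``_SHARDED_OPS`` table.

        A hypothetical registration sketch (``my_linear`` is an illustrative name,
        not part of this module):

            >>> # xdoctest: +SKIP
            >>> def my_linear(types, args, kwargs, process_group):
            ...     ...  # custom sharded implementation
            >>> _CUSTOM_SHARDED_OPS[torch.nn.functional.linear] = my_linear
        """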
1137        def dispatch(st: ShardedTensor, func: Callable):
1138            # Dispatch to custom user provided op first if it exists.
1139            if func in _CUSTOM_SHARDED_OPS:
1140                return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)
1141
1142            # Dispatch to custom sharding spec op if it has one.
1143            if _has_custom_op(st._sharding_spec, func):
1144                return _dispatch_custom_op(
1145                    st._sharding_spec, func, types, args, kwargs, st._process_group
1146                )
1147
1148            if func in _SHARDED_OPS:
1149                return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
1150
1151            raise RuntimeError(
1152                f"torch function '{func.__name__}', with args: {args} and "
1153                f"kwargs: {kwargs} not supported for ShardedTensor!"
1154            )
1155
1156        # Find ShardedTensor instance to get process_group and sharding_spec.
1157        st_instance = None
1158
1159        def find_sharded_tensor(e):
1160            nonlocal st_instance
1161            if st_instance is None and isinstance(e, ShardedTensor):
1162                st_instance = e
1163
1164        pytree.tree_map_(find_sharded_tensor, args)
1165        pytree.tree_map_(find_sharded_tensor, kwargs)
1166
1167        if st_instance is not None:
1168            return dispatch(st_instance, func)
1169
1170        raise RuntimeError(
1171            f"torch function '{func.__name__}', with args: {args} and "
1172            f"kwargs: {kwargs} not supported for ShardedTensor!"
1173        )
1174
1175    def is_pinned(self) -> bool:  # type: ignore[override]
1176        """
1177        Returns True if the sharded tensor (each local shard) resides in pinned memory.
1178        """
1179        return self._metadata.tensor_properties.pin_memory
1180
1181    def _register_remote_shards(
1182        self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int
1183    ):
1184        self._remote_shards[rpc_rank] = remote_shards
1185
1186    def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]:
1187        """
        Returns a Dict[int, List[RRef]] with keys being the RPC rank and values
        being RRefs to shards on that rank. The RPC framework needs to be
        initialized for this functionality.

        Raises an exception if the ShardedTensor was created with ``init_rrefs=False``.
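        A minimal sketch (assumes the tensor was created with ``init_rrefs=True`` and
        the RPC framework is initialized):

            >>> # xdoctest: +SKIP
            >>> for rpc_rank, rrefs in st.remote_shards().items():
            ...     for rref in rrefs:
            ...         shard = rref.to_here()  # fetch the remote Shard locally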
1193        """
1194        if not self._init_rrefs:
1195            raise RuntimeError(
1196                "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available"
1197            )
1198        return self._remote_shards
1199
1200    def __hash__(self):
1201        return id(self)
1202
1203    def __repr__(self):
1204        return f"ShardedTensor({self._metadata})"
1205
1206    @dataclass
1207    class ProcessGroupState:
1208        """
        State used for serialization/deserialization of the process group.
1210        """
1211
1212        local_rank: int
1213        global_rank: int
1214        local_world_size: int
1215        global_world_size: int
1216
1217    def __getstate__(self):
1218        pg_state = ShardedTensor.ProcessGroupState(
1219            distributed_c10d.get_rank(self._process_group),
1220            distributed_c10d.get_rank(),
1221            distributed_c10d.get_world_size(self._process_group),
1222            distributed_c10d.get_world_size(),
1223        )
1224
1225        return (
1226            self._local_shards,
1227            self._metadata,
1228            pg_state,
1229            self._sharding_spec,
1230            self._init_rrefs,
1231        )
1232
1233    def __setstate__(self, state):
1234        self._sharded_tensor_id = None
1235        if not distributed_c10d.is_initialized():
1236            raise RuntimeError(
1237                "Need to initialize default process group using "
1238                '"init_process_group" before loading ShardedTensor'
1239            )
1240
1241        (
1242            self._local_shards,
1243            self._metadata,
1244            pg_state,
1245            self._sharding_spec,
1246            self._init_rrefs,
1247        ) = state
1248
1249        # Setup process group
1250        from torch.distributed._shard.api import _get_current_process_group
1251
1252        self._process_group = _get_current_process_group()
1253
1254        # Validate process group.
1255        local_rank = distributed_c10d.get_rank(self._process_group)
1256        if pg_state.local_rank != local_rank:
1257            raise RuntimeError(
1258                f"Local rank at save time was {pg_state.local_rank}, but at "
1259                f"load time was {local_rank}"
1260            )
1261
1262        global_rank = distributed_c10d.get_rank()
1263        if pg_state.global_rank != global_rank:
1264            raise RuntimeError(
1265                f"Global rank at save time was {pg_state.global_rank}, but at "
1266                f"load time was {global_rank}"
1267            )
1268
1269        local_world_size = distributed_c10d.get_world_size(self._process_group)
1270        if pg_state.local_world_size != local_world_size:
1271            raise RuntimeError(
1272                f"Local world size at save time was {pg_state.local_world_size}, "
1273                f"but at load time was {local_world_size}"
1274            )
1275
1276        global_world_size = distributed_c10d.get_world_size()
1277        if pg_state.global_world_size != global_world_size:
1278            raise RuntimeError(
1279                f"Global world size at save time was {pg_state.global_world_size}, "
1280                f"but at load time was {global_world_size}"
1281            )
1282
1283        self._post_init()
1284
1285
1286def _create_tensor_from_params(
1287    *size, local_device, tensor_properties: TensorProperties
1288):
1289    """Helper to construct tensor from size, device and common params."""
1290    dtype = tensor_properties.dtype
1291    layout = tensor_properties.layout
1292    requires_grad = tensor_properties.requires_grad
1293    memory_format = tensor_properties.memory_format
1294    pin_memory = tensor_properties.pin_memory
1295
1296    return torch.empty(
1297        *size,
1298        dtype=dtype,
1299        layout=layout,
1300        device=local_device,
1301        requires_grad=requires_grad,
1302        memory_format=memory_format,
1303        pin_memory=pin_memory,
1304    )
1305