# mypy: allow-untyped-defs
from __future__ import annotations  # type: ignore[attr-defined]

import copy
import operator
import threading
import warnings
import weakref
from dataclasses import dataclass
from functools import reduce
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING
from typing_extensions import deprecated

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed import distributed_c10d, rpc
from torch.distributed._shard._utils import DEPRECATE_MSG
from torch.distributed._shard.sharding_spec._internals import (
    check_tensor,
    validate_non_overlapping_shards_metadata,
)
from torch.distributed._shard.sharding_spec.api import (
    _dispatch_custom_op,
    _has_custom_op,
)
from torch.distributed.remote_device import _remote_device
from torch.utils import _pytree as pytree

from .metadata import ShardedTensorMetadata, TensorProperties
from .reshard import reshard_local_shard, reshuffle_local_shard
from .shard import Shard
from .utils import (
    _flatten_tensor_size,
    _parse_and_validate_remote_device,
    _validate_output_tensor_for_gather,
    build_global_metadata,
    build_metadata_from_local_shards,
)


if TYPE_CHECKING:
    from torch.distributed._shard.metadata import ShardMetadata


# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0
_sharded_tensor_map: Dict[int, weakref.ReferenceType[ShardedTensor]] = {}

# Default sharded ops
_SHARDED_OPS: Dict[Callable, Callable] = {}

# Customized user ops
_CUSTOM_SHARDED_OPS: Dict[Callable, Callable] = {}


def _register_remote_shards(
    sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]], rpc_rank: int
):
    with _sharded_tensor_lock:
        if sharded_tensor_id not in _sharded_tensor_map:
            raise RuntimeError(
                f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}"
            )

        sharded_tensor = _sharded_tensor_map[sharded_tensor_id]()
        if sharded_tensor is None:
            raise RuntimeError("ShardedTensor weakref has been deallocated")
        else:
            sharded_tensor._register_remote_shards(rrefs, rpc_rank)


class ShardedTensorBase(torch.Tensor):
    _sharding_spec: shard_spec.ShardingSpec
    _metadata: ShardedTensorMetadata
    _local_shards: List[Shard]

    def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
        # Use __new__ to construct a wrapper tensor, for recording tensor
        # properties and logging purposes.
        torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor")

        # check sharding spec and build sharded tensor metadata
        if not isinstance(sharding_spec, shard_spec.ShardingSpec):
            raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}")

        sizes = _flatten_tensor_size(size)
        dtype = kwargs["dtype"]
        layout = kwargs["layout"]
        pin_memory = kwargs["pin_memory"]
        requires_grad = kwargs["requires_grad"]

        if dtype is None:
            dtype = torch.get_default_dtype()

        tensor_properties = TensorProperties(
            dtype, layout, requires_grad, pin_memory=pin_memory
        )
        sharded_tensor_metadata = sharding_spec.build_metadata(
            sizes, tensor_properties=tensor_properties
        )

        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            sizes,
            dtype=dtype,
            layout=layout,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )
        # set sharding spec
        r._sharding_spec = sharding_spec
        # set metadata
        r._metadata = sharded_tensor_metadata
        # set local shards
        r._local_shards = []
        return r

    def metadata(self) -> ShardedTensorMetadata:
        """
        Returns a :class:`ShardedTensorMetadata` object corresponding to the
        metadata for the entire tensor.
        """
        return self._metadata

    def local_shards(self) -> List[Shard]:
        """
        Returns a list of :class:`Shard` corresponding to the
        local shards for this rank. Returns an empty list if the current rank
        does not host any shards for this Tensor.
        """
        return self._local_shards

    @classmethod
    def _init_from_local_shards_and_global_metadata(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        sharding_spec=None,
    ) -> ShardedTensorBase:
        """
        Initialize a ShardedTensorBase with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
            not do cross-rank validations and fully relies on the user
            for the correctness of sharded_tensor_metadata on each rank.
        """
        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError("Only torch.strided layout is currently supported")

        if sharding_spec is None:
            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
        else:
            spec = sharding_spec

        sharded_tensor_base = ShardedTensorBase.__new__(
            ShardedTensor,
            spec,
            sharded_tensor_metadata.size,
            dtype=tensor_properties.dtype,
            layout=tensor_properties.layout,
            pin_memory=tensor_properties.pin_memory,
            requires_grad=tensor_properties.requires_grad,
        )

        # check if shards_metadata has overlapping shards
        validate_non_overlapping_shards_metadata(shards_metadata)

        # check if the shards_metadata is compatible with overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # done validation, add local_shards
        sharded_tensor_base._local_shards = local_shards
        return sharded_tensor_base

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise RuntimeError(
            f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} "
            "but there is no custom __torch_dispatch__ implementation for it."
        )


class ShardedTensor(ShardedTensorBase):
    """
    ShardedTensor is a torch.Tensor subclass to represent Tensors that are sharded
    across multiple devices and multiple processes.

    ShardedTensor is initialized in an SPMD-like fashion where each rank
    initializes the ShardedTensor. The ShardedTensor object on each rank
    then only stores the local shard for the Tensor and provides global
    metadata for all the shards.

    ShardedTensor doesn't provide any Tensor-like operations but is a wrapper
    providing the Tensor representing the local shard and the global metadata.
    Using these, users can build their custom distributed._sharded computations
    on top of this primitive. The local shards are all initialized using the
    create_op specified by tensor_init_params.create_op, e.g., torch.ones, or
    torch.empty.

    Args:
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.
        size (int...): a sequence of integers defining the shape of the output
            tensor. Can be a variable number of arguments or a collection like a list or tuple.

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
            Default: ``torch.strided``.
        requires_grad (bool, optional): If autograd should record operations on the
            returned tensor. Default: ``False``.
        pin_memory (bool, optional): If set, returned tensor would be allocated in
            the pinned memory. Works only for CPU tensors. Default: ``False``.
        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
            returned Tensor. Default: ``torch.contiguous_format``.
        init_rrefs (bool, optional): Whether or not to initialize
            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
            Need to initialize the RPC Framework if specified as ``True``.
            Default: ``False``.

    .. note:: ShardedTensor uses collectives to do various operations, i.e. it
        uses all_gather to do cross-rank validations. For NCCL-based process
        groups, internal tensor representations of objects must be moved to the
        GPU device before communication takes place. In this case, the device
        used is given by ``torch.cuda.current_device()`` and it is the user's
        responsibility to ensure that this is set so that each rank has an
        individual GPU, via ``torch.cuda.set_device()``.
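
    Example:
        An illustrative sketch (not run as a doctest); it assumes two ranks, GPUs
        ``cuda:0`` and ``cuda:1``, and an already-initialized default process group.

        >>> # xdoctest: +SKIP
        >>> spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                ],
            )
        >>> # Run on every rank in SPMD fashion; each rank only allocates its own shard.
        >>> st = ShardedTensor(
                spec,
                4,
                4,
                dtype=torch.float32,
                layout=torch.strided,
                pin_memory=False,
                requires_grad=False,
            )
        >>> st.local_shards()[0].tensor.size()
        torch.Size([2, 4])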
    """

    def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs):
        self = super().__new__(cls, sharding_spec, *size, **kwargs)
        return self

    def __init__(
        self,
        sharding_spec: shard_spec.ShardingSpec,
        *size,
        dtype=None,
        layout=torch.strided,
        requires_grad=False,
        pin_memory=False,
        memory_format=torch.contiguous_format,
        process_group=None,
        init_rrefs=False,
    ):
        # prepare initialization, initialize fields like
        # _process_group, _local_shards, etc.
        self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        if layout != torch.strided:
            raise ValueError("Only torch.strided layout is currently supported")

        if memory_format != torch.contiguous_format:
            raise ValueError(
                "Only torch.contiguous_format memory_format is currently supported"
            )

        self._metadata.tensor_properties.memory_format = memory_format

        current_rank = dist.get_rank()  # global rank

        for shard_metadata in self._metadata.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                self._process_group, shard_metadata.placement
            )
            if rank == current_rank:
                local_tensor = _create_tensor_from_params(
                    shard_metadata.shard_sizes,
                    local_device=device,
                    tensor_properties=self._metadata.tensor_properties,
                )
                self._local_shards.append(Shard(local_tensor, shard_metadata))

        # do post initialization (i.e. register sharded_tensor_id, initialize_rpc)
        self._post_init()

    def _prepare_init(self, process_group=None, init_rrefs=False):
        self._init_rrefs = init_rrefs
        self._sharded_tensor_id = None

        self._process_group = self._normalize_pg(process_group)
        self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}

    def _post_init(self):
        # Initialize RPC if available.
        if self._init_rrefs:
            with _sharded_tensor_lock:
                global _sharded_tensor_current_id, _sharded_tensor_map
                self._sharded_tensor_id = _sharded_tensor_current_id
                _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self)
                _sharded_tensor_current_id += 1

            if not rpc._is_current_rpc_agent_set():
                raise RuntimeError(
                    "RPC Framework needs to be initialized using"
                    " torch.distributed.rpc.init_rpc if init_rrefs is set to True"
                )
            self._init_rpc()

    def __del__(self):
        # Clean up the global map.
        with _sharded_tensor_lock:
            global _sharded_tensor_current_id, _sharded_tensor_map
            if (
                hasattr(self, "_sharded_tensor_id")
                and self._sharded_tensor_id in _sharded_tensor_map
            ):
                _sharded_tensor_map.pop(self._sharded_tensor_id)  # type: ignore[call-overload]

    def _init_rpc(self):
        # Validate PG and RPC ranks match.
        pg_rank = dist.get_rank()
        rpc_rank = rpc.get_worker_info().id
        if pg_rank != rpc_rank:
            raise ValueError(
                f"Default ProcessGroup and RPC ranks must be "
                f"the same for ShardedTensor, found process group rank: "
                f"{pg_rank} and RPC rank: {rpc_rank}"
            )

        self._remote_shards = {}

        # Gather all the sharded tensor ids.
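        # Note: ``rpc.api._all_gather`` returns a mapping keyed by RPC worker
        # name, which is why the per-rank lookup below goes through
        # ``rank_to_name``.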
        worker_infos = rpc._get_current_rpc_agent().get_worker_infos()
        rank_to_name = {}
        name_to_rank = {}

        for worker_info in worker_infos:
            rank_to_name[worker_info.id] = worker_info.name
            name_to_rank[worker_info.name] = worker_info.id

        all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id)

        # Share the local shards to the entire world.
        futs = []
        rpc_rank = rpc.get_worker_info().id
        for rank in range(dist.get_world_size()):
            # Skip self.
            if rank == dist.get_rank():
                continue

            if len(self.local_shards()) != 0:
                rrefs: List[rpc.RRef[Shard]] = [
                    rpc.RRef(shard) for shard in self.local_shards()
                ]
                fut = rpc.rpc_async(
                    rank,
                    _register_remote_shards,
                    args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank),
                )
                futs.append(fut)

        torch.futures.wait_all(futs)

        # Barrier for all RPCs to finish on all ranks.
        rpc.api._all_gather(None)

    def _get_preferred_device(self) -> torch.device:
        """
        Return the preferred device to be used when creating tensors for collectives.
        This method takes into account the associated process group.
        """
        if dist.get_backend(self._process_group) == dist.Backend.NCCL:
            return torch.device(torch.cuda.current_device())
        return torch.device("cpu")

    def gather(  # type: ignore[override]
        self,
        dst: int = 0,
        out: Optional[torch.Tensor] = None,
        enforce_dtype: bool = False,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """
        Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the
        sharded tensor.

        The API needs to be called on all ranks in SPMD fashion. All ranks should have
        the same ``dst``. ``out`` should be a tensor of the same size as the overall
        size of the sharded tensor on ``dst`` and ``None`` on all other ranks.

        Args:
            dst (int): The rank where the full tensor is constructed.
                Default: 0
            out (:class:`torch.Tensor`, optional): The output full tensor.
                Must be provided ONLY on the ``dst`` rank.
                Default: ``None``
            enforce_dtype (bool): Deprecated, please use dtype instead. Force the
                gathered tensors to be the same type as input and output.
            dtype (torch.dtype): Force the gathered tensors to be this dtype.
                Default: ``None``
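
        Example:
            An illustrative sketch; it assumes two ranks, an NCCL process group,
            and an existing ShardedTensor ``st`` with global size ``[4, 4]``.

            >>> # xdoctest: +SKIP
            >>> # Allocate the full-size output on the destination rank only; for
            >>> # NCCL-based process groups it must live on the current GPU.
            >>> out = torch.empty(st.size(), dtype=st.dtype, device="cuda") if dist.get_rank() == 0 else None
            >>> st.gather(dst=0, out=out)
            >>> # ``out`` now holds the full tensor on rank 0 and stays ``None`` elsewhere.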
        """

        def shard_size(shard_md):
            return reduce(operator.mul, shard_md.shard_sizes)  # type: ignore[attr-defined]

        if enforce_dtype:
            warnings.warn(
                "`enforce_dtype` is deprecated. Please use `dtype` instead.",
                FutureWarning,
                stacklevel=2,
            )

        rank = dist.get_rank(self._process_group)
        full_size = self.metadata().size
        _validate_output_tensor_for_gather(rank, dst, full_size, out)

        local_shards = self.local_shards()
        world_size = dist.get_world_size(self._process_group)
        rank_sizes = [0 for _ in range(world_size)]
        max_rank_size = 0
        shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
        # collect sizes
        for shard_md in self.metadata().shards_metadata:
            shard_rank = cast(_remote_device, shard_md.placement).rank()
            assert shard_rank is not None

            shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
            rank_sizes[shard_rank] += shard_size(shard_md)
            max_rank_size = max(max_rank_size, rank_sizes[shard_rank])

        gather_list: Optional[List[torch.Tensor]]
        if rank == dst:
            assert out is not None
            if enforce_dtype:
                # enforce_dtype is deprecated. Do it for backward compatibility.
                dtype = out.dtype
            # TODO make it as a view of out tensor
            gather_list = [
                torch.empty((max_rank_size,), device=out.device, dtype=dtype)
                for _ in range(world_size)
            ]
        else:
            gather_list = None

        with torch.no_grad():
            if enforce_dtype and len(local_shards) > 0:
                # enforce_dtype is deprecated. Do it for backward compatibility.
                dtype = local_shards[0].tensor.dtype
            data = torch.empty(
                max_rank_size, device=self._get_preferred_device(), dtype=dtype
            )

            for shard in local_shards:
                src = shard.tensor.flatten()
                if src.nelement() == 0:
                    warnings.warn(
                        "Gathering a tensor with zero elements on rank " + str(rank)
                    )
                    return
                shard_offset = shard_placement[shard.metadata][1]
                data[shard_offset : shard_offset + src.numel()].copy_(src)

        dist.gather(
            tensor=data,
            gather_list=gather_list,
            dst=dst,
            group=self._process_group,
        )
        if rank != dst:
            return
        # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst
        out = cast(torch.Tensor, out)
        assert gather_list is not None

        full_size = self.metadata().size
        dims = len(full_size)
        for shard_md in self.metadata().shards_metadata:
            rank, rank_offset = shard_placement[shard_md]
            tensor = gather_list[rank]
            tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
            tensor = tensor.view(shard_md.shard_sizes)

            out_narrow_view = out
            for dim in range(dims):
                out_narrow_view = out_narrow_view.narrow(
                    dim,
                    shard_md.shard_offsets[dim],
                    shard_md.shard_sizes[dim],
                )

            out_narrow_view.copy_(tensor)

    def cpu(
        self, memory_format=torch.preserve_format, process_group=None
    ) -> ShardedTensor:
        """
        Returns a copy of this object in CPU memory.

        If this ShardedTensor is already on CPU memory, then no copy is
        performed and the original object is returned.

        .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might
            need to be managed by a different type of ProcessGroup (i.e. ProcessGroupGloo);
            it is the user's responsibility to explicitly pass in a new process_group that
            is compatible with CPU.
        """
        # TODO: make this a __torch_function__ op once ShardedTensor becomes a
        # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402
        if (
            memory_format != torch.preserve_format
            and memory_format != torch.contiguous_format
        ):
            raise RuntimeError(
                "Only `torch.contiguous_format` or "
                "`torch.preserve_format` is supported!"
            )
        all_on_cpu = True
        for meta in self.metadata().shards_metadata:
            all_on_cpu &= meta.placement.device().type == "cpu"  # type: ignore[union-attr]

        # if every shard is already on CPU, return the original object
        if all_on_cpu:
            return self

        # if not, returns a copy of this object on CPU
        list_shards: List[Shard] = []
        # move all local shards to cpu, and change metadata
        for shard in self._local_shards:
            cpu_tensor = shard.tensor.cpu(memory_format=memory_format)  # type: ignore[call-arg]
            metadata = copy.deepcopy(shard.metadata)
            metadata.placement._device = torch.device("cpu")  # type: ignore[union-attr]
            list_shards.append(Shard(cpu_tensor, metadata))

        st_meta = copy.deepcopy(self.metadata())
        for meta in st_meta.shards_metadata:
            if meta.placement.device().type != "cpu":  # type: ignore[union-attr]
                meta.placement._device = torch.device("cpu")  # type: ignore[union-attr]

        pg = self._process_group if process_group is None else process_group
        st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata(
            list_shards,
            sharded_tensor_metadata=st_meta,
            process_group=pg,
            init_rrefs=self._init_rrefs,
        )
        return st_cpu

    def cuda(
        self,
        device=None,
        non_blocking=False,
        memory_format=torch.preserve_format,
        process_group=None,
    ) -> ShardedTensor:
        """
        Returns a copy of this object in CUDA memory. If the original ShardedTensor
        is on CPU, we will move the local shard to the current GPU device of each
        process in a SPMD fashion.
        If this ShardedTensor is already on CUDA memory and local shards on each rank are
        already on the current device, we still return a new ShardedTensor object with new
        metadata, but no underlying data movement is performed.

        .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might
            need to be managed by a different type of ProcessGroup (i.e. ProcessGroupNCCL);
            it is the user's responsibility to explicitly pass in a new process_group that
            is compatible with GPU.
        """
        if (
            memory_format != torch.preserve_format
            and memory_format != torch.contiguous_format
        ):
            raise RuntimeError(
                "Only `torch.contiguous_format` or "
                "`torch.preserve_format` is supported!"
            )

        if device is not None:
            device = torch.device(device) if isinstance(device, str) else device
            assert (
                isinstance(device, torch.device)
                and device.index == torch.cuda.current_device()
            ), """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!"""

        current_device = torch.device(torch.cuda.current_device())
        # returns a copy of ShardedTensor on CUDA current device
        list_shards: List[Shard] = []
        # move all local shards to current device, and change metadata
        # if local shards already on the current device, there's no
        # real data movement, only the metadata are copied.
        for shard in self._local_shards:
            cuda_tensor = shard.tensor.cuda(
                device=current_device,
                non_blocking=non_blocking,
                memory_format=memory_format,
            )  # type: ignore[call-arg]
            metadata = copy.deepcopy(shard.metadata)
            metadata.placement._device = current_device  # type: ignore[union-attr]

            list_shards.append(Shard(cuda_tensor, metadata))

        st_meta = copy.deepcopy(self.metadata())
        for meta in st_meta.shards_metadata:
            if meta.placement.device().type != "cuda":  # type: ignore[union-attr]
                meta.placement._device = current_device  # type: ignore[union-attr]

        pg = self._process_group if process_group is None else process_group
        # we need to use `init_from_local_shards` to communicate between ranks
        # and update the sharding spec/shards metadata.
        st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata(
            list_shards,
            sharded_tensor_metadata=st_meta,
            process_group=pg,
            init_rrefs=self._init_rrefs,
        )
        return st_cuda

    def to(self, *args, **kwargs) -> ShardedTensor:
        current_device: torch.device
        if self._local_shards:
            current_device = self._local_shards[0].tensor.device
        elif self._process_group._get_backend_name() == "gloo":
            current_device = torch.device("cpu")
        else:
            current_device = torch.device(torch.cuda.current_device())
        current_dtype = self.dtype
        device_to = current_device
        dtype_to = current_dtype
        if len(args) == 1:
            if isinstance(args[0], torch.dtype):
                dtype_to = args[0]
            elif isinstance(args[0], torch.device):
                device_to = args[0]
            elif isinstance(args[0], (str, int)):
                device_to = torch.device(args[0])
            elif isinstance(args[0], torch.Tensor):
                dtype_to = args[0].dtype
                device_to = args[0].device
            else:
                raise RuntimeError(f"ShardedTensor.to() has invalid arguments: {args}")
        elif len(args) == 2:
            device_to, dtype_to = args
        else:
            dtype_to = kwargs.get("dtype", current_dtype)
            device_to = kwargs.get("device", current_device)

        device_to = (
            torch.device(device_to) if isinstance(device_to, (str, int)) else device_to
        )

        if device_to.type == "cuda":
            # if device_to is set to cuda, use the current device even
            # if the user specifies a device index.
            current_idx = torch.cuda.current_device()
            if device_to.index != current_idx:
                warnings.warn(
                    "ShardedTensor.to only moves the tensor to its current device. "
                    "If you want to move it to a different device, use `reshard` instead."
                )
                device_to = torch.device(current_idx)

        copy_tensor = kwargs.get("copy", False)
        non_blocking = kwargs.get("non_blocking", False)
        memory_format = kwargs.get("memory_format", torch.preserve_format)
        process_group = kwargs.get("process_group", None)

        if (
            not copy_tensor
            and dtype_to == current_dtype
            and device_to == current_device
        ):
            # already have correct dtype and device, return itself
            return self

        # returns a copy of ShardedTensor on the target device
        list_shards: List[Shard] = []

        for shard in self._local_shards:
            new_tensor = shard.tensor.to(  # type: ignore[call-overload]
                device=device_to,
                dtype=dtype_to,
                non_blocking=non_blocking,
                copy=copy_tensor,
                memory_format=memory_format,
            )
            metadata = copy.deepcopy(shard.metadata)
            if metadata.placement is not None:
                metadata.placement._device = device_to
            list_shards.append(Shard(new_tensor, metadata))

        # update metadata
        st_meta = copy.deepcopy(self.metadata())
        st_meta.tensor_properties.dtype = dtype_to
        for meta in st_meta.shards_metadata:
            meta.placement._device = device_to  # type: ignore[union-attr]

        pg = self._process_group if process_group is None else process_group
        # we need to use `init_from_local_shards` to communicate between ranks
        # and update the sharding spec/shards metadata.
        st_to = ShardedTensor._init_from_local_shards_and_global_metadata(
            list_shards,
            sharded_tensor_metadata=st_meta,
            process_group=pg,
            init_rrefs=self._init_rrefs,
        )
        return st_to

    @classmethod
    def _normalize_pg(
        cls, process_group: Optional[dist.ProcessGroup]
    ) -> dist.ProcessGroup:
        if process_group is not None:
            return process_group
        return distributed_c10d._get_default_group()

    @classmethod
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False,
    ):
        # STEP 1: Validate the ShardMetadatas locally
        process_group = cls._normalize_pg(process_group)
        current_rank = dist.get_rank()  # intentional to get global rank
        world_size = dist.get_world_size(process_group)

        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
        global_tensor_size = _flatten_tensor_size(global_size)

        if len(local_shards) > 0:
            local_sharded_tensor_metadata = build_metadata_from_local_shards(
                local_shards, global_tensor_size, current_rank, process_group
            )

        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
        # metadata by gathering local ShardedTensorMetadata
        gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
        if world_size > 1:
            gathered_metadatas = [None for _ in range(world_size)]

            dist.all_gather_object(
                gathered_metadatas, local_sharded_tensor_metadata, group=process_group
            )
        else:
            gathered_metadatas = [local_sharded_tensor_metadata]

        global_sharded_tensor_metadata = build_global_metadata(gathered_metadatas)
        tensor_properties = global_sharded_tensor_metadata.tensor_properties

        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
        # prepare initialization
        spec = shard_spec._infer_sharding_spec_from_shards_metadata(
            global_sharded_tensor_metadata.shards_metadata
        )
        sharded_tensor = cls.__new__(
            cls,
            spec,
            global_sharded_tensor_metadata.size,
            dtype=tensor_properties.dtype,
            layout=tensor_properties.layout,
            pin_memory=tensor_properties.pin_memory,
            requires_grad=tensor_properties.requires_grad,
        )
        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        # attach local_shards to the ShardedTensor created
        sharded_tensor._local_shards = local_shards

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor

    @classmethod
    @deprecated(DEPRECATE_MSG, category=FutureWarning)
    def _init_from_local_tensor(
        cls,
        local_tensor: torch.Tensor,
        sharding_spec: shard_spec.ShardingSpec,
        *global_size: Sequence[int],
        process_group: Optional[dist.ProcessGroup] = None,
        init_rrefs=False,
    ) -> ShardedTensor:
        """
        Initialize a ShardedTensor given only one local tensor, global sharded tensor
        size and sharding spec on each rank.

        Args:
            local_tensor (Tensor): Single tensor of local shard stored in each rank.
            sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how to shard the Tensor.
            global_size (Sequence[int]): Size of the sharded tensor.
            process_group (ProcessGroup, optional): The process group to aggregate on.
                Default: None
            init_rrefs (bool, optional): Whether or not to initialize
                :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
                Need to initialize the RPC Framework if specified as ``True``.
                Default: ``False``.

        Returns:
            A :class:`ShardedTensor` sharded based on the given sharding_spec with local
            tensor stored in the current rank.

        Examples:
            >>> # xdoctest: +SKIP
            >>> # All tensors below are of torch.int64 type.
            >>> # We have 2 process groups, 2 ranks.
            >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
            >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)
            >>> local_tensor
            tensor([[1, 2, 3, 4]]) # Rank 0
            tensor([[3, 4, 5, 6]]) # Rank 1
            >>> sharding_dim = 0
            >>> sharding_spec = ChunkShardingSpec(
                    dim=sharding_dim,
                    placements=[
                        "rank:0/cuda:0",
                        "rank:1/cuda:1",
                    ],
                )
            >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
            >>> st
            ShardedTensor(
                ShardedTensorMetadata(
                    shards_metadata=[
                        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
                        ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
                    ],
                    size=torch.Size([2, 4])
            )
            >>> st.local_tensor()
            tensor([1, 2, 3, 4]) # Rank 0
            tensor([3, 4, 5, 6]) # Rank 1

        Warning: This API is experimental and subject to change. It does not do full
            cross-rank validations; we only validate the local shard on the current
            rank and fully rely on the user to ensure the local tensor is sharded
            according to the sharding spec.
        """
        if not local_tensor.is_contiguous():
            raise ValueError("local_tensor is not a contiguous Tensor.")

        global_tensor_size = _flatten_tensor_size(global_size)
        tensor_properties = TensorProperties(
            dtype=local_tensor.dtype,
            layout=local_tensor.layout,
            requires_grad=local_tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=local_tensor.is_pinned(),
        )
        sharded_tensor_metadata = sharding_spec.build_metadata(
            global_tensor_size, tensor_properties
        )

        process_group = cls._normalize_pg(process_group)
        current_rank = dist.get_rank()  # intentional to get global rank

        local_shards: List[Shard] = []
        for shard_metadata in sharded_tensor_metadata.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                process_group, shard_metadata.placement
            )
            if rank == current_rank:
                local_shards.append(Shard(local_tensor, shard_metadata))

        # TODO: figure out how the API should behave when some ranks have no shard
        # see https://github.com/pytorch/pytorch/issues/7313
        return ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards,
            sharded_tensor_metadata,
            process_group=process_group,
            init_rrefs=init_rrefs,
            sharding_spec=sharding_spec,
        )

    @classmethod
    def _init_from_local_shards_and_global_metadata(  # type: ignore[override]
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
        sharding_spec=None,
    ) -> ShardedTensor:
        """
        Initialize a ShardedTensor with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
            not do cross-rank validations and fully relies on the user
            for the correctness of sharded_tensor_metadata on each rank.
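
        Example:
            An illustrative sketch; it assumes two ranks, GPUs ``cuda:0`` and
            ``cuda:1`` visible, and that every rank builds the identical global
            metadata.

            >>> # xdoctest: +SKIP
            >>> rank = dist.get_rank()
            >>> shards_metadata = [
                    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 2], placement="rank:0/cuda:0"),
                    ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 2], placement="rank:1/cuda:1"),
                ]
            >>> st_metadata = ShardedTensorMetadata(
                    shards_metadata=shards_metadata,
                    size=torch.Size([4, 2]),
                    tensor_properties=TensorProperties(
                        dtype=torch.float32,
                        layout=torch.strided,
                        requires_grad=False,
                        pin_memory=False,
                    ),
                )
            >>> local_shards = [
                    Shard(torch.zeros(2, 2, device=f"cuda:{rank}"), shards_metadata[rank])
                ]
            >>> st = ShardedTensor._init_from_local_shards_and_global_metadata(
                    local_shards, st_metadata
                )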
        """
        process_group = cls._normalize_pg(process_group)
        current_rank = dist.get_rank()  # intentional to get global rank

        shards_metadata = sharded_tensor_metadata.shards_metadata

        local_shard_metadatas = []

        # collect local shard metadatas from the global sharded_tensor_metadata
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(
                process_group, shard_metadata.placement
            )

            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        if len(local_shards) != len(local_shard_metadatas):
            raise RuntimeError(
                f"Number of local shards ({len(local_shards)}) does not match number of local "
                f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) "
                f"on rank ({current_rank})"
            )

        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError("Only torch.strided layout is currently supported")

        if sharding_spec is None:
            spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
        else:
            spec = sharding_spec

        sharded_tensor = ShardedTensor.__new__(
            ShardedTensor,
            spec,
            sharded_tensor_metadata.size,
            dtype=tensor_properties.dtype,
            layout=tensor_properties.layout,
            pin_memory=tensor_properties.pin_memory,
            requires_grad=tensor_properties.requires_grad,
        )

        def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False):
            tensor_property_or_metadata = (
                "tensor property" if is_property else "local ShardMetadata"
            )
            if expected != actual:
                raise ValueError(
                    f"Local shards' tensor {prop_name} property is incompatible with "
                    f"{tensor_property_or_metadata} on rank {rank}: "
                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
                    f"local shard tensor {prop_name}={actual}."
                )

        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            placement = shard_meta.placement
            assert placement is not None, "Must specify placement for `Shard`!"
            rank = placement.rank()
            local_device = placement.device()

            _raise_if_mismatch(
                tensor_properties.layout,
                local_shard_tensor.layout,
                "layout",
                rank,
                True,
            )
            if not local_shard_tensor.is_contiguous():
                raise ValueError(
                    "Only torch.contiguous_format memory_format is currently supported"
                )

            _raise_if_mismatch(
                shard_meta.shard_sizes,
                list(local_shard_tensor.size()),
                "size",
                rank,
            )
            _raise_if_mismatch(
                tensor_properties.pin_memory,
                local_shard_tensor.is_pinned(),
                "pin_memory",
                rank,
                True,
            )
            _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank)
            _raise_if_mismatch(
                tensor_properties.dtype,
                local_shard_tensor.dtype,
                "dtype",
                rank,
                True,
            )
            _raise_if_mismatch(
                tensor_properties.requires_grad,
                local_shard_tensor.requires_grad,
                "requires_grad",
                rank,
                True,
            )

        # check if shards_metadata has overlapping shards
        validate_non_overlapping_shards_metadata(shards_metadata)

        # check if the shards_metadata is compatible with overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # done validation, add local_shards
        sharded_tensor._local_shards = local_shards
        sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor

    def sharding_spec(self) -> shard_spec.ShardingSpec:
        """
        Returns the ShardingSpec for the tensor.
        """
        return self._sharding_spec

    @deprecated(DEPRECATE_MSG, category=FutureWarning)
    def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
        """
        Reshard a sharded tensor given the ``resharding_spec``. For now, we only support
        a single local shard.

        If ``resharding_spec`` is the same as the original one, this becomes a no-op.
        If ``resharding_spec`` shares only the same sharding dim with the original one,
        we swap local shards directly.
        For more generic cases, we merge different shards across different ranks and split
        the local shards based on the ``resharding_spec`` via the `all_to_all` collective API.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
                specification describing how the tensor is sharded.

        Returns:
            A :class:`ShardedTensor` object whose local shards are resharded.

        Examples:
            >>> # xdoctest: +SKIP
            >>> # We have 4 ranks.
            >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank
            >>> tensor = torch.stack([tensor, tensor])
            >>> tensor
            tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0
            tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1
            tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2
            tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3
            >>> sharding_dim = 0
            >>> spec = ChunkShardingSpec(
                    dim=sharding_dim,
                    placements=[
                        "rank:0/cuda:0",
                        "rank:1/cuda:1",
                        "rank:2/cuda:2",
                        "rank:3/cuda:3",
                    ],
                )
            >>> current_offsets = [0] * 2
            >>> current_offsets[0] = rank * 2
            >>> shard_metadata = ShardMetadata(
                    shard_offsets=copy.deepcopy(current_offsets),
                    shard_sizes=tensor.size(),
                    placement=spec.placements[rank],
                )
            >>> local_shards = [
                    Shard(
                        tensor=tensor,
                        metadata=shard_metadata,
                    )
                ]
            >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size())
            >>> sharding_dim = 1
            >>> resharding_spec = ChunkShardingSpec(
                    dim=sharding_dim,
                    placements=[
                        "rank:0/cuda:0",
                        "rank:1/cuda:1",
                        "rank:2/cuda:2",
                        "rank:3/cuda:3",
                    ],
                )
            >>> st.reshard(resharding_spec)
            >>> tensor = st.local_shards()[0].tensor
            >>> tensor
            tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0
            tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1
            tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2
            tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3
        """
        if not isinstance(
            resharding_spec, shard_spec.ChunkShardingSpec
        ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
        if len(self.local_shards()) != 1:
            raise NotImplementedError("Only single local shard supported for reshard.")

        if self._sharding_spec.dim == resharding_spec.dim:  # type: ignore[attr-defined]
            if self._sharding_spec.placements == resharding_spec.placements:  # type: ignore[attr-defined]
                return self
            else:
                local_shards, shards_metadata = reshuffle_local_shard(
                    self.local_tensor(),
                    self.size(),  # type: ignore[arg-type]
                    self._sharding_spec,
                    resharding_spec,
                    self._process_group,
                )
        else:
            local_shards, shards_metadata = reshard_local_shard(
                self.local_tensor(),
                self.size(),  # type: ignore[arg-type]
                self._sharding_spec,
                resharding_spec,
                self._process_group,
            )
        self._local_shards = local_shards
        self._metadata.shards_metadata = shards_metadata
        self._sharding_spec = resharding_spec
        return self

    def local_tensor(self) -> torch.Tensor:
        """
        Return the local tensor for a sharded_tensor. For now we only support a single local shard.

        Returns:
            A :class:`torch.Tensor` of the local shard.
        """
        if len(self.local_shards()) != 1:
            raise NotImplementedError("Only single local shard is supported.")
        return self.local_shards()[0].tensor

    @classmethod
    @deprecated(DEPRECATE_MSG, category=FutureWarning)
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        def dispatch(st: ShardedTensor, func: Callable):
            # Dispatch to custom user provided op first if it exists.
            if func in _CUSTOM_SHARDED_OPS:
                return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group)

            # Dispatch to custom sharding spec op if it has one.
            if _has_custom_op(st._sharding_spec, func):
                return _dispatch_custom_op(
                    st._sharding_spec, func, types, args, kwargs, st._process_group
                )

            if func in _SHARDED_OPS:
                return _SHARDED_OPS[func](types, args, kwargs, st._process_group)

            raise RuntimeError(
                f"torch function '{func.__name__}', with args: {args} and "
                f"kwargs: {kwargs} not supported for ShardedTensor!"
            )

        # Find ShardedTensor instance to get process_group and sharding_spec.
        st_instance = None

        def find_sharded_tensor(e):
            nonlocal st_instance
            if st_instance is None and isinstance(e, ShardedTensor):
                st_instance = e

        pytree.tree_map_(find_sharded_tensor, args)
        pytree.tree_map_(find_sharded_tensor, kwargs)

        if st_instance is not None:
            return dispatch(st_instance, func)

        raise RuntimeError(
            f"torch function '{func.__name__}', with args: {args} and "
            f"kwargs: {kwargs} not supported for ShardedTensor!"
        )

    def is_pinned(self) -> bool:  # type: ignore[override]
        """
        Returns True if the sharded tensor (each local shard) resides in pinned memory.
        """
        return self._metadata.tensor_properties.pin_memory

    def _register_remote_shards(
        self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int
    ):
        self._remote_shards[rpc_rank] = remote_shards

    def remote_shards(self) -> Dict[int, List[rpc.RRef[Shard]]]:
        """
        Returns a Dict[int, List[RRef]] with keys being the RPC rank and values
        being RRefs to shards on that rank. Need to initialize the
        RPC framework for this functionality.

        Raises an exception if ShardedTensor was created with ``init_rrefs=False``.
        """
        if not self._init_rrefs:
            raise RuntimeError(
                "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available"
            )
        return self._remote_shards

    def __hash__(self):
        return id(self)

    def __repr__(self):
        return f"ShardedTensor({self._metadata})"

    @dataclass
    class ProcessGroupState:
        """
        State for ser-de of process group
        """

        local_rank: int
        global_rank: int
        local_world_size: int
        global_world_size: int

    def __getstate__(self):
        pg_state = ShardedTensor.ProcessGroupState(
            distributed_c10d.get_rank(self._process_group),
            distributed_c10d.get_rank(),
            distributed_c10d.get_world_size(self._process_group),
            distributed_c10d.get_world_size(),
        )

        return (
            self._local_shards,
            self._metadata,
            pg_state,
            self._sharding_spec,
            self._init_rrefs,
        )

    def __setstate__(self, state):
        self._sharded_tensor_id = None
        if not distributed_c10d.is_initialized():
            raise RuntimeError(
                "Need to initialize default process group using "
                '"init_process_group" before loading ShardedTensor'
            )

        (
            self._local_shards,
            self._metadata,
            pg_state,
            self._sharding_spec,
            self._init_rrefs,
        ) = state

        # Setup process group
        from torch.distributed._shard.api import _get_current_process_group

        self._process_group = _get_current_process_group()

        # Validate process group.
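        # A deserialized ShardedTensor is only valid if the process group
        # topology at load time matches the one at save time, since the shard
        # placements in the metadata refer to ranks.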
        local_rank = distributed_c10d.get_rank(self._process_group)
        if pg_state.local_rank != local_rank:
            raise RuntimeError(
                f"Local rank at save time was {pg_state.local_rank}, but at "
                f"load time was {local_rank}"
            )

        global_rank = distributed_c10d.get_rank()
        if pg_state.global_rank != global_rank:
            raise RuntimeError(
                f"Global rank at save time was {pg_state.global_rank}, but at "
                f"load time was {global_rank}"
            )

        local_world_size = distributed_c10d.get_world_size(self._process_group)
        if pg_state.local_world_size != local_world_size:
            raise RuntimeError(
                f"Local world size at save time was {pg_state.local_world_size}, "
                f"but at load time was {local_world_size}"
            )

        global_world_size = distributed_c10d.get_world_size()
        if pg_state.global_world_size != global_world_size:
            raise RuntimeError(
                f"Global world size at save time was {pg_state.global_world_size}, "
                f"but at load time was {global_world_size}"
            )

        self._post_init()


def _create_tensor_from_params(
    *size, local_device, tensor_properties: TensorProperties
):
    """Helper to construct tensor from size, device and common params."""
    dtype = tensor_properties.dtype
    layout = tensor_properties.layout
    requires_grad = tensor_properties.requires_grad
    memory_format = tensor_properties.memory_format
    pin_memory = tensor_properties.pin_memory

    return torch.empty(
        *size,
        dtype=dtype,
        layout=layout,
        device=local_device,
        requires_grad=requires_grad,
        memory_format=memory_format,
        pin_memory=pin_memory,
    )
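
# Dispatch sketch (illustrative only, not executed): handlers stored in
# ``_CUSTOM_SHARDED_OPS`` / ``_SHARDED_OPS`` are looked up by the torch function
# object inside ``ShardedTensor.__torch_function__`` and invoked as
# ``handler(types, args, kwargs, process_group)``. A hypothetical registration
# could look like:
#
#     def _my_dim(types, args, kwargs, process_group):
#         st = args[0]  # the ShardedTensor the op was called on
#         return len(st.metadata().size)
#
#     _CUSTOM_SHARDED_OPS[torch.Tensor.dim] = _my_dim
#
# In practice, registrations should go through the registration helpers in
# ``torch.distributed._shard`` rather than by mutating these dicts directly.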