/aosp_15_r20/external/pytorch/torch/distributed/tensor/_ops/ |
H A D | _math_ops.py | 88 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 113 mesh: DeviceMesh, 123 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 255 mesh: DeviceMesh, 332 def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 358 def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 377 def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 400 def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy: 437 def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 469 def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: [all …]
|
H A D | _matrix_ops.py | 33 def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 58 mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema 88 mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema 142 def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 147 def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 152 def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 157 def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: 165 mesh: DeviceMesh, op_schema: OpSchema 248 mesh: DeviceMesh, op_schema: OpSchema 320 def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: [all …]
|
H A D | _tensor_ops.py | 42 def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 83 def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 141 def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 182 def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 226 def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 244 def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 314 def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 359 def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 369 def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 392 def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: [all …]
|
H A D | _pointwise_ops.py | 420 mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False 454 mesh: DeviceMesh, 519 def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 601 mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False 656 mesh: DeviceMesh, op_schema: OpSchema
|
H A D | _embedding_ops.py | 88 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 116 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 135 mesh: DeviceMesh, 190 def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: 239 mesh: DeviceMesh, op_schema: OpSchema
|
/aosp_15_r20/external/pytorch/torch/distributed/tensor/ |
H A D | _api.py | 124 device_mesh: DeviceMesh, 349 device_mesh: Optional[DeviceMesh] = None, 470 device_mesh: Optional[DeviceMesh] = None, 625 device_mesh: Optional[DeviceMesh] = None, 766 device_mesh: Optional[DeviceMesh] = None, 825 def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: 912 device_mesh: Optional[DeviceMesh] = None, 977 device_mesh: Optional[DeviceMesh] = None, 1020 device_mesh: Optional[DeviceMesh] = None, 1065 device_mesh: Optional[DeviceMesh] = None, [all …]
|
H A D | placement_types.py | 157 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 186 mesh: DeviceMesh, 221 mesh: DeviceMesh, 257 mesh: DeviceMesh, 277 mesh: DeviceMesh, 487 mesh: DeviceMesh, 556 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 595 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int 606 mesh: DeviceMesh, 616 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
|
H A D | _utils.py | 19 global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] 50 global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] 199 tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] 279 global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
|
H A D | _collective_utils.py | 77 mesh: DeviceMesh, 134 mesh: DeviceMesh, 245 def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
|
H A D | _random.py | 25 def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool: 51 def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: 350 tp_mesh: DeviceMesh,
|
H A D | _dispatch.py | 447 mesh: "DeviceMesh", 480 mesh: "DeviceMesh",
|
H A D | _sharding_prop.py | 97 strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType], 475 mesh: DeviceMesh,
|
/aosp_15_r20/external/pytorch/torch/distributed/tensor/experimental/ |
H A D | _attention.py | 121 mesh: DeviceMesh, 147 mesh: DeviceMesh, 201 mesh: DeviceMesh, 360 mesh: DeviceMesh, 458 mesh: DeviceMesh, 499 mesh: DeviceMesh, 549 device_mesh: DeviceMesh, 649 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 670 device_mesh: DeviceMesh, 704 device_mesh: DeviceMesh, [all …]
|
H A D | _tp_transform.py | 132 mesh: DeviceMesh, 184 mesh: DeviceMesh, 267 mesh: DeviceMesh, 309 mesh: DeviceMesh, 525 mesh: DeviceMesh,
|
/aosp_15_r20/external/pytorch/torch/distributed/tensor/parallel/ |
H A D | style.py | 31 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 128 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 230 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 300 def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh): 323 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 406 mesh: DeviceMesh, 452 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: 539 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
|
H A D | fsdp.py | 223 device_mesh: DeviceMesh, 299 parent_mesh: Optional[DeviceMesh], 373 device_mesh: DeviceMesh, 386 parent_mesh: Optional[DeviceMesh],
|
H A D | input_reshard.py | 16 tp_device_mesh: DeviceMesh, 67 mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor 87 mesh: DeviceMesh, input_reshard_dim: int, x: Any
|
H A D | loss.py | 93 tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh 206 mesh: DeviceMesh, 356 mesh: DeviceMesh,
|
/aosp_15_r20/external/pytorch/torch/distributed/ |
H A D | device_mesh.py | 84 device_mesh: "DeviceMesh", 163 self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None 228 def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh": 235 def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: 262 self, device_mesh: "DeviceMesh", mesh_dim_name: str 339 self, device_mesh: "DeviceMesh", mesh_dim_name: str 378 class DeviceMesh: class
|
/aosp_15_r20/external/pytorch/torch/distributed/fsdp/ |
H A D | _fsdp_extensions.py | 58 device_mesh: DeviceMesh, 78 parent_mesh: Optional[DeviceMesh], 142 device_mesh: DeviceMesh, 171 parent_mesh: Optional[DeviceMesh],
|
H A D | _init_utils.py | 109 device_mesh: Optional[DeviceMesh] = None, 159 device_mesh: DeviceMesh, 209 def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool: 525 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
|
H A D | _shard_utils.py | 95 device_mesh: DeviceMesh, 119 root_mesh: Optional[DeviceMesh],
|
/aosp_15_r20/external/pytorch/test/distributed/_composable/fsdp/ |
H A D | test_fully_shard_training.py | 675 mesh: DeviceMesh, 904 global_mesh: DeviceMesh, 997 def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool): 1051 def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool): 1155 global_mesh: DeviceMesh, 1225 global_mesh: DeviceMesh,
|
/aosp_15_r20/external/pytorch/test/distributed/_tensor/ |
H A D | test_pointwise_ops.py | 30 device_mesh: DeviceMesh, 73 device_mesh: DeviceMesh, 115 device_mesh: DeviceMesh,
|
/aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/_tensor/ |
H A D | common_dtensor.py | 210 …module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool, local_output_for_attn: boo… 342 def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: 406 mesh: DeviceMesh, 521 self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
|