
Searched defs: DeviceMesh (results 1 – 25 of 51), sorted by relevance


/aosp_15_r20/external/pytorch/torch/distributed/tensor/_ops/
_math_ops.py 88 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
113 mesh: DeviceMesh,
123 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
255 mesh: DeviceMesh,
332 def linear_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
358 def var_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
377 def vector_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
400 def foreach_norm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> TupleStrategy:
437 def linalg_replicate_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
469 def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
[all …]
_matrix_ops.py 33 def transpose_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
58 mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
88 mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema
142 def mm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
147 def addmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
152 def bmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
157 def baddmm_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
165 mesh: DeviceMesh, op_schema: OpSchema
248 mesh: DeviceMesh, op_schema: OpSchema
320 def constant_pad_nd_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
[all …]
_tensor_ops.py 42 def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
83 def equal_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
141 def create_like_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
182 def new_factory_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
226 def gen_bucketize_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
244 def gen_slice_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
314 def gen_slice_scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
359 def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
369 def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
392 def gather_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
[all …]
_pointwise_ops.py 420 mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
454 mesh: DeviceMesh,
519 def linear_pointwise_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
601 mesh: DeviceMesh, op_schema: OpSchema, linearity: bool = False
656 mesh: DeviceMesh, op_schema: OpSchema
_embedding_ops.py 88 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
116 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
135 mesh: DeviceMesh,
190 def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
239 mesh: DeviceMesh, op_schema: OpSchema
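
Every hit in the _ops/ group above shares the same shape: a strategy function that takes a DeviceMesh plus an OpSchema and returns an OpStrategy (or TupleStrategy) enumerating legal input/output shardings, which the modules attach to an ATen op via register_op_strategy. The sketch below is written for this note rather than quoted from the files; the internal import paths and the DTensorSpec/PlacementStrategy constructors are assumptions that should match this snapshot of the tree.

    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor import Replicate
    from torch.distributed.tensor._dtensor_spec import DTensorSpec
    from torch.distributed.tensor._op_schema import (
        OpSchema,
        OpStrategy,
        PlacementStrategy,
        StrategyType,
    )

    def replicate_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
        # Offer a single candidate: the (assumed single) tensor operand replicated on
        # every mesh dimension, producing a replicated output. The real strategies
        # listed above (mm_strategy, softmax_strategy, ...) enumerate sharded
        # candidates and attach redistribute costs instead.
        spec = DTensorSpec(mesh=mesh, placements=(Replicate(),) * mesh.ndim)
        return OpStrategy([PlacementStrategy(output_specs=spec, input_specs=[spec])])

The body only illustrates the return type; in the indexed modules such functions are wired into the sharding propagator with the @register_op_strategy(aten_op) decorator.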
/aosp_15_r20/external/pytorch/torch/distributed/tensor/
_api.py 124 device_mesh: DeviceMesh,
349 device_mesh: Optional[DeviceMesh] = None,
470 device_mesh: Optional[DeviceMesh] = None,
625 device_mesh: Optional[DeviceMesh] = None,
766 device_mesh: Optional[DeviceMesh] = None,
825 def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
912 device_mesh: Optional[DeviceMesh] = None,
977 device_mesh: Optional[DeviceMesh] = None,
1020 device_mesh: Optional[DeviceMesh] = None,
1065 device_mesh: Optional[DeviceMesh] = None,
[all …]
placement_types.py 157 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
186 mesh: DeviceMesh,
221 mesh: DeviceMesh,
257 mesh: DeviceMesh,
277 mesh: DeviceMesh,
487 mesh: DeviceMesh,
556 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
595 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
606 mesh: DeviceMesh,
616 self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
_utils.py 19 global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
50 global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
199 tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement]
279 global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]
_collective_utils.py 77 mesh: DeviceMesh,
134 mesh: DeviceMesh,
245 def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo":
_random.py 25 def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool:
51 def manual_seed(seed: int, device_mesh: DeviceMesh) -> None:
350 tp_mesh: DeviceMesh,
_dispatch.py 447 mesh: "DeviceMesh",
480 mesh: "DeviceMesh",
_sharding_prop.py 97 strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
475 mesh: DeviceMesh,
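
For orientation, this group implements the public DTensor surface that consumes a DeviceMesh: _api.py holds distribute_tensor/distribute_module, placement_types.py the Shard/Replicate/Partial placements, and _dispatch.py/_sharding_prop.py the propagation machinery. A short usage sketch (mine, not quoted from these files), assuming a torchrun launch with four CUDA devices and that this snapshot exports these names from torch.distributed.tensor:

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor import Shard, distribute_tensor

    mesh = init_device_mesh("cuda", (4,))              # 1-D mesh over four ranks
    weight = torch.randn(8192, 1024)
    dweight = distribute_tensor(weight, mesh, placements=[Shard(0)])  # shard rows across the mesh
    print(dweight.placements)        # (Shard(dim=0),)
    print(dweight.to_local().shape)  # torch.Size([2048, 1024]) on each rank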
/aosp_15_r20/external/pytorch/torch/distributed/tensor/experimental/
_attention.py 121 mesh: DeviceMesh,
147 mesh: DeviceMesh,
201 mesh: DeviceMesh,
360 mesh: DeviceMesh,
458 mesh: DeviceMesh,
499 mesh: DeviceMesh,
549 device_mesh: DeviceMesh,
649 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
670 device_mesh: DeviceMesh,
704 device_mesh: DeviceMesh,
[all …]
_tp_transform.py 132 mesh: DeviceMesh,
184 mesh: DeviceMesh,
267 mesh: DeviceMesh,
309 mesh: DeviceMesh,
525 mesh: DeviceMesh,
/aosp_15_r20/external/pytorch/torch/distributed/tensor/parallel/
style.py 31 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
128 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
230 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
300 def _replicate_module_fn(self, name: str, module: nn.Module, device_mesh: DeviceMesh):
323 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
406 mesh: DeviceMesh,
452 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
539 def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
fsdp.py 223 device_mesh: DeviceMesh,
299 parent_mesh: Optional[DeviceMesh],
373 device_mesh: DeviceMesh,
386 parent_mesh: Optional[DeviceMesh],
input_reshard.py 16 tp_device_mesh: DeviceMesh,
67 mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor
87 mesh: DeviceMesh, input_reshard_dim: int, x: Any
loss.py 93 tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
206 mesh: DeviceMesh,
356 mesh: DeviceMesh,
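
The parallel styles in this group receive their DeviceMesh through the _apply(module, device_mesh) hooks listed above; in user code they are driven by parallelize_module. A hedged sketch with an illustrative two-layer MLP on a 1-D tensor-parallel mesh (the "0"/"2" keys are the nn.Sequential submodule names; the model and mesh shape are assumptions for this note):

    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
    mlp = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
    mlp = parallelize_module(
        mlp,
        tp_mesh,
        {"0": ColwiseParallel(),   # column-shard the up-projection
         "2": RowwiseParallel()},  # row-shard the down-projection
    )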
/aosp_15_r20/external/pytorch/torch/distributed/
device_mesh.py 84 device_mesh: "DeviceMesh",
163 self, device_mesh: "DeviceMesh", mesh_dim_name: Optional[str] = None
228 def get_root_mesh(self, device_mesh: "DeviceMesh") -> "DeviceMesh":
235 def get_root_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
262 self, device_mesh: "DeviceMesh", mesh_dim_name: str
339 self, device_mesh: "DeviceMesh", mesh_dim_name: str
378 class DeviceMesh:
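
device_mesh.py line 378 is the definition of the DeviceMesh class itself; every other result in this listing consumes it. A minimal construction sketch (the usage is written for this note and assumes eight ranks launched with torchrun):

    import torch
    from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

    # Explicit rank layout: a 2 x 4 mesh over ranks 0..7.
    mesh = DeviceMesh("cuda", torch.arange(8).reshape(2, 4))

    # The same mesh via the helper, with named dimensions so submeshes can be sliced out.
    mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    dp_mesh = mesh_2d["dp"]  # 1-D submesh of size 2 (data parallel)
    tp_mesh = mesh_2d["tp"]  # 1-D submesh of size 4 (tensor parallel)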
/aosp_15_r20/external/pytorch/torch/distributed/fsdp/
_fsdp_extensions.py 58 device_mesh: DeviceMesh,
78 parent_mesh: Optional[DeviceMesh],
142 device_mesh: DeviceMesh,
171 parent_mesh: Optional[DeviceMesh],
_init_utils.py 109 device_mesh: Optional[DeviceMesh] = None,
159 device_mesh: DeviceMesh,
209 def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool:
525 def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
_shard_utils.py 95 device_mesh: DeviceMesh,
119 root_mesh: Optional[DeviceMesh],
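
The fsdp/ entries above are where a DeviceMesh handed to the FSDP wrapper is consumed and validated (see _is_valid_hybrid_shard_device_mesh in _init_utils.py). A hedged example of the corresponding user-facing call, with an illustrative model and a 2 x 4 mesh for hybrid sharding:

    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

    # Outer mesh dim replicates, inner dim shards (hybrid sharding across 2 x 4 ranks).
    mesh_2d = init_device_mesh("cuda", (2, 4))
    model = nn.Linear(1024, 1024).cuda()
    model = FSDP(
        model,
        device_mesh=mesh_2d,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    )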
/aosp_15_r20/external/pytorch/test/distributed/_composable/fsdp/
test_fully_shard_training.py 675 mesh: DeviceMesh,
904 global_mesh: DeviceMesh,
997 def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
1051 def parallelize(_model: Transformer, mesh: DeviceMesh, use_seq_parallel: bool):
1155 global_mesh: DeviceMesh,
1225 global_mesh: DeviceMesh,
/aosp_15_r20/external/pytorch/test/distributed/_tensor/
test_pointwise_ops.py 30 device_mesh: DeviceMesh,
73 device_mesh: DeviceMesh,
115 device_mesh: DeviceMesh,
/aosp_15_r20/external/pytorch/torch/testing/_internal/distributed/_tensor/
common_dtensor.py 210 …module: "Transformer", device_mesh: DeviceMesh, use_seq_parallel: bool, local_output_for_attn: boo…
342 def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
406 mesh: DeviceMesh,
521 self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
