/aosp_15_r20/external/pytorch/torch/distributed/tensor/parallel/ |
H A D | style.py | 89 def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): argument 102 def _partition_linear_fn(self, name, module, device_mesh): argument 112 def _partition_embedding_fn(self, name, module, device_mesh): argument 121 def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): argument 191 def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): argument 200 def _partition_linear_fn(self, name, module, device_mesh): argument 212 def _partition_embedding_fn(self, name, module, device_mesh): argument 221 def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): argument 310 def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): argument 320 def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): argument [all …]
|
/aosp_15_r20/external/pytorch/test/distributed/_tensor/ |
H A D | test_api.py | 124 def shard_fn(name, module, device_mesh): argument 146 def replicate_fn(name, module, device_mesh): argument 163 def shard_fn(name, module, device_mesh): argument 188 def input_fn(mod, inputs, device_mesh): argument 191 def output_fn(mod, outputs, device_mesh): argument 210 def replicate_input_fn(mod, inputs, device_mesh): argument 233 def input_fn(inputs, device_mesh): argument 236 def output_fn(outputs, device_mesh): argument 302 def shard_fn(name, module, device_mesh): argument
|
H A D | test_embedding_ops.py | 33 def _apply_sharding(self, embedding_mod, shard_dim, device_mesh): argument 34 def shard_embedding_fn(name, module, device_mesh): argument 48 device_mesh, argument
|
H A D | test_optimizers.py | 24 def shard_fn(name, module, device_mesh): argument 38 def input_fn(mod, inputs, device_mesh): argument 45 def output_fn(mod, outputs, device_mesh): argument
|
H A D | test_math_ops.py | 286 def _replicate_fn(name, module, device_mesh): argument 351 def _replicate_fn(name, module, device_mesh): argument
|
/aosp_15_r20/external/pytorch/test/distributed/_tensor/experimental/ |
H A D | test_local_map.py | 31 def equal_allgather_forward(device_mesh, X, Y): argument 37 def mm_all_gather_forward(device_mesh, A, B): argument 46 def mm_allreduce_forward(device_mesh, A, B): argument 56 def mm_allreduce_forward_decorated(device_mesh, A, B): argument
|
/aosp_15_r20/external/pytorch/test/distributed/fsdp/ |
H A D | test_fsdp_tp_integration.py | 73 def distribute_rmsnorm(module, device_mesh): argument 74 def prepare_input_fn(mod, inputs, device_mesh): argument 78 def prepare_output_fn(mod, outputs, device_mesh): argument
|
H A D | test_hsdp_dtensor_state_dict.py | 51 def _create_model(self, device_mesh=None): argument 272 def __init__(self, device_mesh): argument
|
H A D | test_fsdp_dtensor_state_dict.py | 64 def _create_model(self, is_even_sharded_model, device_mesh=None): argument
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/ |
H A D | xla_sharding.py | 507 def mesh_split_sharding(device_mesh, argument 559 device_mesh, argument
|
/aosp_15_r20/external/pytorch/torch/distributed/tensor/ |
H A D | _dtensor_spec.py | 130 def device_mesh(self) -> DeviceMesh: member in DTensorSpec
|
H A D | _api.py | 572 def device_mesh(self) -> DeviceMesh: member in DTensor
|
/aosp_15_r20/external/pytorch/test/distributed/tensor/parallel/ |
H A D | test_tp_examples.py | 160 def _test_mlp_inference(self, device_mesh): argument
|
/aosp_15_r20/external/pytorch/torch/distributed/ |
H A D | device_mesh.py | 287 self, device_mesh, mesh_dim_names argument
|
/aosp_15_r20/external/pytorch/torch/nn/parallel/ |
H A D | distributed.py | 644 device_mesh=None, argument
|