Home
last modified time | relevance | path

Searched defs:device_mesh (Results 1 – 15 of 15) sorted by relevance

/aosp_15_r20/external/pytorch/torch/distributed/tensor/parallel/
H A Dstyle.py89 def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): argument
102 def _partition_linear_fn(self, name, module, device_mesh): argument
112 def _partition_embedding_fn(self, name, module, device_mesh): argument
121 def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): argument
191 def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): argument
200 def _partition_linear_fn(self, name, module, device_mesh): argument
212 def _partition_embedding_fn(self, name, module, device_mesh): argument
221 def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): argument
310 def _prepare_input_fn(sequence_dim, mod, inputs, device_mesh): argument
320 def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): argument
[all …]
/aosp_15_r20/external/pytorch/test/distributed/_tensor/
H A Dtest_api.py124 def shard_fn(name, module, device_mesh): argument
146 def replicate_fn(name, module, device_mesh): argument
163 def shard_fn(name, module, device_mesh): argument
188 def input_fn(mod, inputs, device_mesh): argument
191 def output_fn(mod, outputs, device_mesh): argument
210 def replicate_input_fn(mod, inputs, device_mesh): argument
233 def input_fn(inputs, device_mesh): argument
236 def output_fn(outputs, device_mesh): argument
302 def shard_fn(name, module, device_mesh): argument
H A Dtest_embedding_ops.py33 def _apply_sharding(self, embedding_mod, shard_dim, device_mesh): argument
34 def shard_embedding_fn(name, module, device_mesh): argument
48 device_mesh, argument
H A Dtest_optimizers.py24 def shard_fn(name, module, device_mesh): argument
38 def input_fn(mod, inputs, device_mesh): argument
45 def output_fn(mod, outputs, device_mesh): argument
H A Dtest_math_ops.py286 def _replicate_fn(name, module, device_mesh): argument
351 def _replicate_fn(name, module, device_mesh): argument
/aosp_15_r20/external/pytorch/test/distributed/_tensor/experimental/
H A Dtest_local_map.py31 def equal_allgather_forward(device_mesh, X, Y): argument
37 def mm_all_gather_forward(device_mesh, A, B): argument
46 def mm_allreduce_forward(device_mesh, A, B): argument
56 def mm_allreduce_forward_decorated(device_mesh, A, B): argument
/aosp_15_r20/external/pytorch/test/distributed/fsdp/
H A Dtest_fsdp_tp_integration.py73 def distribute_rmsnorm(module, device_mesh): argument
74 def prepare_input_fn(mod, inputs, device_mesh): argument
78 def prepare_output_fn(mod, outputs, device_mesh): argument
H A Dtest_hsdp_dtensor_state_dict.py51 def _create_model(self, device_mesh=None): argument
272 def __init__(self, device_mesh): argument
H A Dtest_fsdp_dtensor_state_dict.py64 def _create_model(self, is_even_sharded_model, device_mesh=None): argument
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
H A Dxla_sharding.py507 def mesh_split_sharding(device_mesh, argument
559 device_mesh, argument
/aosp_15_r20/external/pytorch/torch/distributed/tensor/
H A D_dtensor_spec.py130 def device_mesh(self) -> DeviceMesh: member in DTensorSpec
H A D_api.py572 def device_mesh(self) -> DeviceMesh: member in DTensor
/aosp_15_r20/external/pytorch/test/distributed/tensor/parallel/
H A Dtest_tp_examples.py160 def _test_mlp_inference(self, device_mesh): argument
/aosp_15_r20/external/pytorch/torch/distributed/
H A Ddevice_mesh.py287 self, device_mesh, mesh_dim_names argument
/aosp_15_r20/external/pytorch/torch/nn/parallel/
H A Ddistributed.py644 device_mesh=None, argument