Home
last modified time | relevance | path

Searched defs:mesh_dim (Results 1 – 9 of 9) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/
H A Dtensor_layout.cc305 for (const auto& mesh_dim : dims()) { in dim_size() local
312 for (const auto& mesh_dim : dims()) dim_names.push_back(mesh_dim.name); in dim_size() local
322 for (const auto& mesh_dim : mesh_dims_) dim_sizes.push_back(mesh_dim.size); in dim_sizes() local
395 for (const auto& mesh_dim : dims()) in IsMeshDim() local
403 const auto mesh_dim = dim(i); in GetMeshDimIndexWithName() local
505 MeshDimension mesh_dim; in StrToMeshDimension() local
531 for (const MeshDimension& mesh_dim : mesh_dims) mesh_size *= mesh_dim.size; in GenerateMeshDevicesForTests() local
732 for (const MeshDimension& mesh_dim : layout->mesh().dims()) { in ReducedAbstractMesh() local
H A Dsave_restore_util.cc174 const std::string& mesh_dim = sharding_spec_strs[tensor_dim_index]; in SliceSpecOnDevice() local
/aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/
H A Dcollectives.cc338 const std::string& mesh_dim) { in GetMeshDimensionOffsetWithNeighbor()
375 for (const MeshDimension& mesh_dim : mesh_dimensions) { in CreateConstSrcTargetPair() local
393 const MeshDimension& mesh_dim = data.value(); in CreateConstSrcTargetPair() local
433 const std::string& mesh_dim, in EmitHaloExchange()
H A Dlayout_propagation_v2.cc472 const std::string& mesh_dim = layout.dim(i).sharding_spec(); in GetMostShardedLayout() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
H A Dtpu_rewrite_device_util_test.cc117 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim); in TopologyWithMeshShape() local
125 for (int mesh_dim : mesh_shape) topology_proto.add_mesh_shape(mesh_dim); in TopologyWithMeshShapeAndTasks() local
/aosp_15_r20/external/pytorch/torch/distributed/tensor/
H A D_collective_utils.py50 def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): argument
/aosp_15_r20/external/pytorch/torch/distributed/tensor/parallel/
H A Dloss.py127 def _log_softmax(x, dim, half_to_float, mesh, mesh_dim): argument
/aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/
H A Ddataparallel_spmd_expander.cc166 std::string mesh_dim = Layout::kUnshardedDim; in IntermediateBatchLayout() local
H A Ddtensor_op_spmd_expander.cc490 for (const auto& mesh_dim : recv_mesh.dims()) in ExpandOp() local