
Searched defs:parallel_dims (Results 1 – 4 of 4) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
hlo_sharding_util.cc
  1013  const GatherParallelDims& parallel_dims) {    in GatherParallelDataOperandSharding()
  1088  auto parallel_dims = GetGatherBatchParallelDims(hlo);    in GatherDataOperandShardingFromOutput() (local)
  1581  const HloInstruction& gather, const GatherParallelDims& parallel_dims) {    in GatherOutputAlignedOperandParallelDims()
sharding_propagation.cc
  453  const hlo_sharding_util::GatherParallelDims& parallel_dims,    in InferGatherParallelShardingFromOperands()
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/
gather_scatter_handler.cc
  495  if (std::optional<hlo_sharding_util::GatherParallelDims> parallel_dims =    in PartitionGatherIndexParallelDimensions() (local)
spmd_partitioner_util.cc
  1708  const hlo_sharding_util::GatherParallelDims& parallel_dims) {    in GatherOperandsShardedAcrossParallelDims()