Home
last modified time | relevance | path

Searched defs:arg_sharding (Results 1 – 4 of 4) sorted by relevance

/aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/
H A Dxla_compiler.cc852 const std::optional<xla::HloSharding>& arg_sharding, in XLAShapeForArgument()
1053 auto arg_sharding = arg_shardings.find((*input_to_args)[i]); in BuildArguments() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
H A Dtpu_sharding_identification_pass.cc219 auto arg_sharding = GetXlaShardingFromArg(arg); in IdentifyXlaShardingForComputationInputs() local
/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
H A Ddistributed_tpu_rewrite_pass.cc2132 std::vector<xla::OpSharding>* arg_sharding, std::vector<bool>* arg_fast_mem, in AssignArgsAndRetvalsToCores()
2599 const std::vector<xla::OpSharding>& arg_sharding, in BuildCompileNode()
4760 std::vector<xla::OpSharding> arg_sharding; in RewriteTPUReplicateNode() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/
H A Dmlir_hlo_to_hlo.cc2363 for (const auto& arg_sharding : llvm::enumerate(arg_shardings)) { in SetEntryTupleShardings() local