/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/ |
H A D | hlo_sharding_test.cc | 58 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local 73 HloSharding sharding = HloSharding::AssignDevice(5); in TEST_F() local 110 HloSharding sharding = HloSharding::FromProto(proto).value(); in TEST_F() local 117 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); in TEST_F() local 124 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); in TEST_F() local 132 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); in TEST_F() local 156 HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), in TEST_F() local 193 HloSharding sharding = in TEST_F() local 281 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local 289 HloSharding sharding = HloSharding::Replicate(std::get<0>(GetParam())); in TEST_P() local [all …]
|
H A D | hlo_sharding_util_test.cc | 54 HloSharding sharding = HloSharding::AssignDevice(7); in TEST() local 64 HloSharding sharding = HloSharding::Tile(Array3D<int64_t>({{{0}, {1}}})); in TEST() local 127 HloSharding sharding = HloSharding::Tile(sharding_array); in TEST() local 161 HloSharding sharding = HloSharding::Tile(Array3D<int64_t>({{{0}, {1}}})); in TEST() local 170 HloSharding sharding = HloSharding::Tile(Array3D<int64_t>({{{0}, {1}}})); in TEST() local 177 HloSharding sharding = HloSharding::Tile(Array2D<int64_t>({{0, 1}, {2, 3}})); in TEST() local 184 HloSharding sharding = HloSharding::Tile(Array2D<int64_t>({{0, 1}, {2, 3}})); in TEST() local 191 HloSharding sharding = in TEST() local 201 HloSharding sharding = in TEST() local 210 HloSharding sharding = in TEST() local [all …]
|
H A D | hlo_sharding_metadata.cc | 57 const HloSharding& sharding) { in SetSingleSharding() 124 const HloSharding& sharding) { in FixupPassThroughDomainLinks() 147 std::shared_ptr<const HloSharding> sharding) { in CloneShardingForDomain() 156 const HloSharding& sharding) { in ApplyDomainSingleSharding() 345 const HloSharding& sharding) { in ApplyDomainSharding() 389 std::shared_ptr<const HloSharding> sharding; in ExtractOriginalCommonSharding() local 411 std::unique_ptr<HloSharding> sharding; in Clone() local 452 const HloSharding* sharding = sharding_metadata->sharding(); in NormalizeShardingDomain() local
|
H A D | hlo_sharding_util.cc | 158 const HloSharding& sharding, int64_t manual_dim) { in MergeShardingIfCompatible() 324 HloSharding TransposeSharding(const HloSharding& sharding, in TransposeSharding() 357 const HloSharding& sharding) { in ReshapeSharding() 493 HloSharding ReverseSharding(const HloSharding& sharding, in ReverseSharding() 517 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim, in ReshapeToTileDimension() 1212 const HloSharding& sharding, in DevicesForShardingInternal() 1243 const HloSharding& sharding, absl::Span<const int64_t> available_devices) { in DevicesForSharding() 1260 const HloSharding& sharding, absl::Span<const int64_t> dims_to_replicate) { in PartiallyReplicateTiledShardingOnDims() 1311 const HloSharding& sharding, absl::Span<const int64_t> dims_to_keep) { in PartiallyReplicateTiledShardingOnAllDimsExcept() 1326 HloSharding ReplicateAllDataDims(const HloSharding& sharding, in ReplicateAllDataDims() [all …]
|
H A D | sharding_propagation.cc | 59 bool IsSpatiallyPartitioned(const HloSharding& sharding) { in IsSpatiallyPartitioned() 75 bool MaybeImproveInstructionSharding(HloSharding sharding, in MaybeImproveInstructionSharding() 545 const auto& sharding = inst->sharding(); in InferConvolutionShardingFromOperands() local 681 auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( in InferDotOperandSharding() local 1067 HloSharding sharding = annotate_op->sharding(); in InferUnspecifiedDimsFromOperand() local 1120 HloSharding sharding = annotate_op->sharding(); in InferUnspecifiedDimsFromOneUser() local 1196 HloSharding SetCSEPreventionSharding(const HloSharding& sharding) { in SetCSEPreventionSharding() 1203 bool IsCSEPreventionSharding(const HloSharding& sharding) { in IsCSEPreventionSharding() 1261 const HloSharding& sharding = instruction->sharding(); in ProcessShardingInstruction() local 1292 const auto& sharding = sharding_metadata->sharding(); in NormalizeDomain() local [all …]
|
H A D | hlo_sharding.cc | 199 HloSharding sharding = PartialTile(tiles, metadata); in Subgroup() local 260 for (auto& sharding : shardings) { in Tuple() local 272 const HloSharding& sharding) { in SingleTuple() 282 const HloSharding& sharding) { in Single() 844 auto assign_metadata = [&](HloSharding& sharding) { in WithMetadata() 850 HloSharding sharding = *this; in WithMetadata() local 862 HloSharding sharding = *this; in WithoutMetadata() local 870 std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { in operator <<()
|
H A D | hlo_sharding_metadata.h | 34 explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding) in ShardingMetadata() 56 const HloSharding* sharding() const { return sharding_.get(); } in sharding() function 90 std::shared_ptr<const HloSharding> sharding; member
|
H A D | hlo_matchers.h | 144 explicit HloShardingMatcher(const std::optional<HloSharding>& sharding) in HloShardingMatcher() 456 const HloSharding& sharding) { in Sharding() 462 absl::string_view sharding) { in Sharding()
|
H A D | batchnorm_expander.cc | 273 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormTraining() local 364 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormInference() local 538 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormGrad() local
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
H A D | tpu_sharding_identification_pass.cc | 91 xla::OpSharding sharding; in VerifySharding() local 117 StringRef sharding = std::get<0>(sharding_and_arg); in VerifyShardings() local 124 StringRef sharding = std::get<0>(sharding_and_retval); in VerifyShardings() local 152 if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner)) in GetXlaShardingFromArg() local 312 if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def)) in GetXlaShardingFromRetval() local 315 if (auto sharding = def->getAttrOfType<StringAttr>("_XlaSharding")) { in GetXlaShardingFromRetval() local 440 xla::OpSharding sharding; in IdentifyXlaShardingForTPUComputation() local 474 StringRef sharding = std::get<0>(sharding_and_arg); in IdentifyXlaShardingForTPUComputation() local 483 StringRef sharding = std::get<0>(sharding_and_retval); in IdentifyXlaShardingForTPUComputation() local
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/ |
H A D | xla_sharding_util.cc | 197 bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { in UnsupportedPartitionedShardingType() 234 xla::OpSharding sharding; in ExtractInputsForLogicalDevices() local 329 xla::OpSharding sharding; in ParseAndValidateOutputSharding() local 359 const xla::OpSharding& sharding) { in IsAssignedToLogicalDevice() 397 const xla::OpSharding& sharding = in GetTileShardedOutputsToMerge() local 438 const xla::OpSharding& sharding = in HandleTileShardedOutputs() local 657 const auto& sharding = arg_and_idx.value().sharding(); in GetMetadataArgumentMapping() local
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/ |
H A D | sharding_util_test.cc | 28 [](std::optional<xla::OpSharding> sharding) -> int64 { in TEST() 77 auto check_metadata = [](const xla::OpSharding& sharding) { in TEST_P() 90 auto& sharding = status_or_sharding.ValueOrDie(); in TEST_P() local 129 xla::OpSharding sharding; in CreateTupleSharding() local
|
H A D | sharding_util.cc | 38 void AssignOpMetadataToSharding(xla::OpSharding& sharding, in AssignOpMetadataToSharding() 81 auto sharding = xla::sharding_builder::AssignDevice(core); in ParseShardingFromDevice() local 159 xla::OpSharding sharding; in GetShardingFromNodeDefInternal() local
|
H A D | layout_util.cc | 38 const std::optional<xla::HloSharding>& sharding, bool use_fast_memory, in RewriteLayoutWithShardedShape() 85 std::optional<xla::OpSharding> sharding, bool fast_mem) { in ReshapeWithCorrectRepresentationAndSharding()
|
H A D | xla_compiler.cc | 194 const std::optional<xla::OpSharding>& sharding) { in BuildComputation() 229 std::optional<xla::OpSharding> sharding = in BuildComputation() local 330 auto sharding = it == arg_shardings.end() in BuildComputation() local 1054 std::optional<xla::HloSharding> sharding; in BuildArguments() local
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/ |
H A D | spmd_partitioner.cc | 204 const HloSharding& sharding, absl::Span<const int64_t> replication_dims) { in GetPartitionGroupsForReplication() 228 const HloSharding& sharding, const Window& window) { in GetShardingReplicatedOnWindowedDimension() 581 const HloSharding& sharding = hlo_->sharding(); in PadWithValueHlo() local 1141 const HloSharding sharding = hlo_->sharding(); in Replicate() local 1453 const HloSharding& sharding = hlo_->sharding(); in Broadcast() local 2062 HloSharding sharding = hlo->sharding().HasUniqueDevice() in DefaultAction() local 2091 const HloSharding& sharding) { in Preprocess() 2122 [](const HloSharding& sharding) { return sharding.IsManual(); })); in Preprocess() 2140 [](const HloSharding& sharding) { in Preprocess() 2147 nullptr) -> StatusOr<GroupedSharding> { in Preprocess() [all …]
|
H A D | spmd_partitioner_util.cc | 50 bool HasReplicatedSharding(const HloSharding& sharding) { in HasReplicatedSharding() 74 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { in EvenlyPartitions() 95 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { in MakePartitionedShape() 116 const HloSharding& sharding, in MakeNonPaddedShapeForGivenPartition() 156 const Shape& shape, const HloSharding& sharding, in MakePartitionOffsets() 191 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { in MakeTiledPartitionOrdinals() 202 const HloSharding& sharding) { in GetPaddedShapeForUnevenPartitioning() 545 std::optional<int64_t> UniqueTiledDim(const HloSharding& sharding) { in UniqueTiledDim() 1189 const HloSharding& sharding = sort->operand(0)->sharding(); in GetKValueInTopKWhenPartitionSortDim() local 1238 int64_t ShardCountAtDim(const HloSharding& sharding, int64_t dim) { in ShardCountAtDim() [all …]
|
H A D | fft_handler.cc | 51 HloInstruction* hlo, int64_t num_partitions, const HloSharding& sharding, in PadEachPartitionWithHaloExchange() 226 HloInstruction* hlo, const HloSharding& sharding, in GetFinalFftUsingCollectivePermute()
|
/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/graph_rewrite/ |
H A D | distributed_tpu_rewrite_pass.cc | 608 const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) { in GetDimensionIndicesAndNumSplitsFromSharding() 803 const xla::OpSharding& sharding, int orig_arg_num, DataType dtype, in CreateOrGetSplitNodesForInputSharding() 924 const xla::OpSharding& sharding, in CreateXlaSplitOp() 979 const xla::OpSharding& sharding, Graph* graph) { in ShardInputWithXlaSplitOp() 1007 const xla::OpSharding& sharding, const int replica_id, in CreateOrGetXlaSplitNodeForShardedPerReplicaArg() 1034 const xla::OpSharding& sharding, const int num_replicas, in CreateOrGetXlaSplitNodeForDistributedArg() 1064 const xla::OpSharding& sharding, const int num_replicas, in CreateOrGetXlaSplitNodeForVariableArg() 1233 const xla::OpSharding& sharding, DataType dtype, in CreateConcatNodesForRetval() 1284 const xla::OpSharding& sharding, const int replica_id, DataType dtype, in CreateXlaConcatNode() 1398 xla::OpSharding sharding; member [all …]
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/ |
H A D | xla_builder.h | 214 void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } in SetSharding() 250 const std::optional<OpSharding>& sharding() const { return sharding_; } in sharding() function 1576 std::optional<OpSharding> sharding) in XlaScopedShardingAssignment() 1588 void SetSharding(const std::optional<OpSharding>& sharding) { in SetSharding()
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/ |
H A D | layout_util.cc | 22 const std::optional<xla::HloSharding>& sharding, bool use_fast_memory, in RewriteLayoutWithShardedShape() 70 std::optional<xla::OpSharding> sharding, bool fast_mem) { in ReshapeWithCorrectRepresentationAndSharding()
|
H A D | mlir_hlo_to_hlo.cc | 458 llvm::StringRef sharding) { in CreateOpShardingFromStringRef() 469 auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr); in CreateOpShardingFromAttribute() local 530 [](const std::optional<xla::OpSharding>& sharding) { in AllOptionalShardingsAreSet() 543 if (auto sharding = in ExtractShardingsFromFunction() local 550 if (auto sharding = in ExtractShardingsFromFunction() local 2190 xla::OpSharding sharding; in Lower() local 2361 xla::OpSharding sharding; in SetEntryTupleShardings() local
|
/aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
H A D | spmd_manual_sharding_ops.cc | 50 xla::OpSharding sharding; in Compile() local 97 xla::OpSharding sharding; in Compile() local
|
/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/xla/ |
H A D | infeed_op.cc | 52 absl::optional<xla::OpSharding> sharding) { in UpdateInfeedLayout() 131 absl::optional<xla::OpSharding> sharding; in Compile() local
|
/aosp_15_r20/external/tensorflow/tensorflow/core/protobuf/tpu/ |
H A D | compile_metadata.proto | 33 xla.OpSharding sharding = 4; field 76 xla.OpSharding sharding = 1; field
|