Home
last modified time | relevance | path

Searched defs:sharding (Results 1 – 25 of 62) sorted by relevance

123

/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/
H A Dhlo_sharding_test.cc58 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local
73 HloSharding sharding = HloSharding::AssignDevice(5); in TEST_F() local
110 HloSharding sharding = HloSharding::FromProto(proto).value(); in TEST_F() local
117 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 0, 2, 3})); in TEST_F() local
124 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 1, 2, 3})); in TEST_F() local
132 HloSharding sharding = HloSharding::Tile(MakeArray({2, 2}, {0, 3, 2, 1})); in TEST_F() local
156 HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), in TEST_F() local
193 HloSharding sharding = in TEST_F() local
281 HloSharding sharding = HloSharding::Replicate(); in TEST_F() local
289 HloSharding sharding = HloSharding::Replicate(std::get<0>(GetParam())); in TEST_P() local
[all …]
H A Dhlo_sharding_util_test.cc54 HloSharding sharding = HloSharding::AssignDevice(7); in TEST() local
64 HloSharding sharding = HloSharding::Tile(Array3D<int64_t>({{{0}, {1}}})); in TEST() local
127 HloSharding sharding = HloSharding::Tile(sharding_array); in TEST() local
161 HloSharding sharding = HloSharding::Tile(Array3D<int64_t>({{{0}, {1}}})); in TEST() local
170 HloSharding sharding = HloSharding::Tile(Array3D<int64_t>({{{0}, {1}}})); in TEST() local
177 HloSharding sharding = HloSharding::Tile(Array2D<int64_t>({{0, 1}, {2, 3}})); in TEST() local
184 HloSharding sharding = HloSharding::Tile(Array2D<int64_t>({{0, 1}, {2, 3}})); in TEST() local
191 HloSharding sharding = in TEST() local
201 HloSharding sharding = in TEST() local
210 HloSharding sharding = in TEST() local
[all …]
H A Dhlo_sharding_metadata.cc57 const HloSharding& sharding) { in SetSingleSharding()
124 const HloSharding& sharding) { in FixupPassThroughDomainLinks()
147 std::shared_ptr<const HloSharding> sharding) { in CloneShardingForDomain()
156 const HloSharding& sharding) { in ApplyDomainSingleSharding()
345 const HloSharding& sharding) { in ApplyDomainSharding()
389 std::shared_ptr<const HloSharding> sharding; in ExtractOriginalCommonSharding() local
411 std::unique_ptr<HloSharding> sharding; in Clone() local
452 const HloSharding* sharding = sharding_metadata->sharding(); in NormalizeShardingDomain() local
H A Dhlo_sharding_util.cc158 const HloSharding& sharding, int64_t manual_dim) { in MergeShardingIfCompatible()
324 HloSharding TransposeSharding(const HloSharding& sharding, in TransposeSharding()
357 const HloSharding& sharding) { in ReshapeSharding()
493 HloSharding ReverseSharding(const HloSharding& sharding, in ReverseSharding()
517 HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64_t dim, in ReshapeToTileDimension()
1212 const HloSharding& sharding, in DevicesForShardingInternal()
1243 const HloSharding& sharding, absl::Span<const int64_t> available_devices) { in DevicesForSharding()
1260 const HloSharding& sharding, absl::Span<const int64_t> dims_to_replicate) { in PartiallyReplicateTiledShardingOnDims()
1311 const HloSharding& sharding, absl::Span<const int64_t> dims_to_keep) { in PartiallyReplicateTiledShardingOnAllDimsExcept()
1326 HloSharding ReplicateAllDataDims(const HloSharding& sharding, in ReplicateAllDataDims()
[all …]
H A Dsharding_propagation.cc59 bool IsSpatiallyPartitioned(const HloSharding& sharding) { in IsSpatiallyPartitioned()
75 bool MaybeImproveInstructionSharding(HloSharding sharding, in MaybeImproveInstructionSharding()
545 const auto& sharding = inst->sharding(); in InferConvolutionShardingFromOperands() local
681 auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( in InferDotOperandSharding() local
1067 HloSharding sharding = annotate_op->sharding(); in InferUnspecifiedDimsFromOperand() local
1120 HloSharding sharding = annotate_op->sharding(); in InferUnspecifiedDimsFromOneUser() local
1196 HloSharding SetCSEPreventionSharding(const HloSharding& sharding) { in SetCSEPreventionSharding()
1203 bool IsCSEPreventionSharding(const HloSharding& sharding) { in IsCSEPreventionSharding()
1261 const HloSharding& sharding = instruction->sharding(); in ProcessShardingInstruction() local
1292 const auto& sharding = sharding_metadata->sharding(); in NormalizeDomain() local
[all …]
H A Dhlo_sharding.cc199 HloSharding sharding = PartialTile(tiles, metadata); in Subgroup() local
260 for (auto& sharding : shardings) { in Tuple() local
272 const HloSharding& sharding) { in SingleTuple()
282 const HloSharding& sharding) { in Single()
844 auto assign_metadata = [&](HloSharding& sharding) { in WithMetadata()
850 HloSharding sharding = *this; in WithMetadata() local
862 HloSharding sharding = *this; in WithoutMetadata() local
870 std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { in operator <<()
H A Dhlo_sharding_metadata.h34 explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding) in ShardingMetadata()
56 const HloSharding* sharding() const { return sharding_.get(); } in sharding() function
90 std::shared_ptr<const HloSharding> sharding; member
H A Dhlo_matchers.h144 explicit HloShardingMatcher(const std::optional<HloSharding>& sharding) in HloShardingMatcher()
456 const HloSharding& sharding) { in Sharding()
462 absl::string_view sharding) { in Sharding()
H A Dbatchnorm_expander.cc273 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormTraining() local
364 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormInference() local
538 const HloSharding& sharding = batch_norm->sharding(); in HandleBatchNormGrad() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
H A Dtpu_sharding_identification_pass.cc91 xla::OpSharding sharding; in VerifySharding() local
117 StringRef sharding = std::get<0>(sharding_and_arg); in VerifyShardings() local
124 StringRef sharding = std::get<0>(sharding_and_retval); in VerifyShardings() local
152 if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner)) in GetXlaShardingFromArg() local
312 if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def)) in GetXlaShardingFromRetval() local
315 if (auto sharding = def->getAttrOfType<StringAttr>("_XlaSharding")) { in GetXlaShardingFromRetval() local
440 xla::OpSharding sharding; in IdentifyXlaShardingForTPUComputation() local
474 StringRef sharding = std::get<0>(sharding_and_arg); in IdentifyXlaShardingForTPUComputation() local
483 StringRef sharding = std::get<0>(sharding_and_retval); in IdentifyXlaShardingForTPUComputation() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
H A Dxla_sharding_util.cc197 bool UnsupportedPartitionedShardingType(xla::OpSharding::Type sharding) { in UnsupportedPartitionedShardingType()
234 xla::OpSharding sharding; in ExtractInputsForLogicalDevices() local
329 xla::OpSharding sharding; in ParseAndValidateOutputSharding() local
359 const xla::OpSharding& sharding) { in IsAssignedToLogicalDevice()
397 const xla::OpSharding& sharding = in GetTileShardedOutputsToMerge() local
438 const xla::OpSharding& sharding = in HandleTileShardedOutputs() local
657 const auto& sharding = arg_and_idx.value().sharding(); in GetMetadataArgumentMapping() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/
H A Dsharding_util_test.cc28 [](std::optional<xla::OpSharding> sharding) -> int64 { in TEST()
77 auto check_metadata = [](const xla::OpSharding& sharding) { in TEST_P()
90 auto& sharding = status_or_sharding.ValueOrDie(); in TEST_P() local
129 xla::OpSharding sharding; in CreateTupleSharding() local
H A Dsharding_util.cc38 void AssignOpMetadataToSharding(xla::OpSharding& sharding, in AssignOpMetadataToSharding()
81 auto sharding = xla::sharding_builder::AssignDevice(core); in ParseShardingFromDevice() local
159 xla::OpSharding sharding; in GetShardingFromNodeDefInternal() local
H A Dlayout_util.cc38 const std::optional<xla::HloSharding>& sharding, bool use_fast_memory, in RewriteLayoutWithShardedShape()
85 std::optional<xla::OpSharding> sharding, bool fast_mem) { in ReshapeWithCorrectRepresentationAndSharding()
H A Dxla_compiler.cc194 const std::optional<xla::OpSharding>& sharding) { in BuildComputation()
229 std::optional<xla::OpSharding> sharding = in BuildComputation() local
330 auto sharding = it == arg_shardings.end() in BuildComputation() local
1054 std::optional<xla::HloSharding> sharding; in BuildArguments() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/spmd/
H A Dspmd_partitioner.cc204 const HloSharding& sharding, absl::Span<const int64_t> replication_dims) { in GetPartitionGroupsForReplication()
228 const HloSharding& sharding, const Window& window) { in GetShardingReplicatedOnWindowedDimension()
581 const HloSharding& sharding = hlo_->sharding(); in PadWithValueHlo() local
1141 const HloSharding sharding = hlo_->sharding(); in Replicate() local
1453 const HloSharding& sharding = hlo_->sharding(); in Broadcast() local
2062 HloSharding sharding = hlo->sharding().HasUniqueDevice() in DefaultAction() local
2091 const HloSharding& sharding) { in Preprocess()
2122 [](const HloSharding& sharding) { return sharding.IsManual(); })); in Preprocess()
2140 [](const HloSharding& sharding) { in Preprocess()
2147 nullptr) -> StatusOr<GroupedSharding> { in Preprocess()
[all …]
H A Dspmd_partitioner_util.cc50 bool HasReplicatedSharding(const HloSharding& sharding) { in HasReplicatedSharding()
74 bool EvenlyPartitions(const Shape& shape, const HloSharding& sharding) { in EvenlyPartitions()
95 Shape MakePartitionedShape(const Shape& shape, const HloSharding& sharding) { in MakePartitionedShape()
116 const HloSharding& sharding, in MakeNonPaddedShapeForGivenPartition()
156 const Shape& shape, const HloSharding& sharding, in MakePartitionOffsets()
191 const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { in MakeTiledPartitionOrdinals()
202 const HloSharding& sharding) { in GetPaddedShapeForUnevenPartitioning()
545 std::optional<int64_t> UniqueTiledDim(const HloSharding& sharding) { in UniqueTiledDim()
1189 const HloSharding& sharding = sort->operand(0)->sharding(); in GetKValueInTopKWhenPartitionSortDim() local
1238 int64_t ShardCountAtDim(const HloSharding& sharding, int64_t dim) { in ShardCountAtDim()
[all …]
H A Dfft_handler.cc51 HloInstruction* hlo, int64_t num_partitions, const HloSharding& sharding, in PadEachPartitionWithHaloExchange()
226 HloInstruction* hlo, const HloSharding& sharding, in GetFinalFftUsingCollectivePermute()
/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
H A Ddistributed_tpu_rewrite_pass.cc608 const xla::OpSharding& sharding, std::map<int, int>* split_dimension_map) { in GetDimensionIndicesAndNumSplitsFromSharding()
803 const xla::OpSharding& sharding, int orig_arg_num, DataType dtype, in CreateOrGetSplitNodesForInputSharding()
924 const xla::OpSharding& sharding, in CreateXlaSplitOp()
979 const xla::OpSharding& sharding, Graph* graph) { in ShardInputWithXlaSplitOp()
1007 const xla::OpSharding& sharding, const int replica_id, in CreateOrGetXlaSplitNodeForShardedPerReplicaArg()
1034 const xla::OpSharding& sharding, const int num_replicas, in CreateOrGetXlaSplitNodeForDistributedArg()
1064 const xla::OpSharding& sharding, const int num_replicas, in CreateOrGetXlaSplitNodeForVariableArg()
1233 const xla::OpSharding& sharding, DataType dtype, in CreateConcatNodesForRetval()
1284 const xla::OpSharding& sharding, const int replica_id, DataType dtype, in CreateXlaConcatNode()
1398 xla::OpSharding sharding; member
[all …]
/aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/
H A Dxla_builder.h214 void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } in SetSharding()
250 const std::optional<OpSharding>& sharding() const { return sharding_; } in sharding() function
1576 std::optional<OpSharding> sharding) in XlaScopedShardingAssignment()
1588 void SetSharding(const std::optional<OpSharding>& sharding) { in SetSharding()
/aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/
H A Dlayout_util.cc22 const std::optional<xla::HloSharding>& sharding, bool use_fast_memory, in RewriteLayoutWithShardedShape()
70 std::optional<xla::OpSharding> sharding, bool fast_mem) { in ReshapeWithCorrectRepresentationAndSharding()
H A Dmlir_hlo_to_hlo.cc458 llvm::StringRef sharding) { in CreateOpShardingFromStringRef()
469 auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr); in CreateOpShardingFromAttribute() local
530 [](const std::optional<xla::OpSharding>& sharding) { in AllOptionalShardingsAreSet()
543 if (auto sharding = in ExtractShardingsFromFunction() local
550 if (auto sharding = in ExtractShardingsFromFunction() local
2190 xla::OpSharding sharding; in Lower() local
2361 xla::OpSharding sharding; in SetEntryTupleShardings() local
/aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
H A Dspmd_manual_sharding_ops.cc50 xla::OpSharding sharding; in Compile() local
97 xla::OpSharding sharding; in Compile() local
/aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/xla/
H A Dinfeed_op.cc52 absl::optional<xla::OpSharding> sharding) { in UpdateInfeedLayout()
131 absl::optional<xla::OpSharding> sharding; in Compile() local
/aosp_15_r20/external/tensorflow/tensorflow/core/protobuf/tpu/
H A Dcompile_metadata.proto33 xla.OpSharding sharding = 4; field
76 xla.OpSharding sharding = 1; field

123