/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.h"

#include <optional>
#include <string>

#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Computes a local reduce followed by an EmitAllReduce. This performs a global
// reduction; the output will have global shape 1 on the reduced axes if
// keep_dims is true, otherwise those axes are removed.
// Assumes builder's insertion point is after input.
StatusOr<mlir::Value> ComputeGlobalReduce(
    mlir::OpBuilder& builder, const mlir::Value& input,
    const Layout& input_layout, const absl::flat_hash_set<int>& reduced_dims,
    absl::string_view reduce_op, bool keep_dims) {
  const Layout reduction_layout =
      input_layout.GetLayoutWithReducedDims(reduced_dims,
                                            /*keep_dims=*/true);
  std::vector<int32> reduce_dim_array(reduced_dims.begin(), reduced_dims.end());
  const mlir::Value reduction_indices =
      IntConst(builder, input.getLoc(), reduce_dim_array);
  mlir::Operation* local_reduce;

  // First compute a local reduce
  if (reduce_op == kReduceOpAdd) {
    local_reduce = builder.create<mlir::TF::SumOp>(
        input.getLoc(), input, reduction_indices,
        /*keep_dims=*/builder.getBoolAttr(true));
  } else if (reduce_op == kReduceOpMax) {
    local_reduce = builder.create<mlir::TF::MaxOp>(
        input.getLoc(), input, reduction_indices,
        /*keep_dims=*/builder.getBoolAttr(true));
  } else {
    return errors::Unimplemented("reduction ", reduce_op, " not implemented");
  }

  // Then an all reduce.
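  // Only the mesh dimensions that actually shard one of the reduced tensor
  // dimensions take part in the all-reduce; unsharded dims need no
  // cross-device communication.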
  absl::flat_hash_set<std::string> reduced_sharding_specs;
  for (const int dim : reduced_dims)
    if (Layout::IsShardedSpec(input_layout.dim(dim)))
      reduced_sharding_specs.emplace(input_layout.sharding_spec(dim));
  TF_ASSIGN_OR_RETURN(
      mlir::Operation * global_reduce,
      EmitAllReduce(builder, reduction_layout, reduced_sharding_specs,
                    local_reduce, reduce_op));

  if (!keep_dims) {
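    // The local reduce above was emitted with keep_dims=true, so to honor
    // keep_dims=false we squeeze out the reduced axes here.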
    mlir::RankedTensorType output_type =
        global_reduce->getResult(0)
            .getType()
            .dyn_cast<mlir::RankedTensorType>();
    if (!output_type)
      return errors::Internal(
          "output of EmitAllReduce is not a RankedTensorType");
    std::vector<int64_t> new_shape;
    for (int i = 0; i < output_type.getRank(); ++i)
      if (!reduced_dims.contains(i))
        new_shape.emplace_back(output_type.getDimSize(i));
    mlir::RankedTensorType new_type =
        mlir::RankedTensorType::get(new_shape, output_type.getElementType());
    // Upcast the dimensions to int64_t as SqueezeOp requires this for its
    // dimension attribute type. Everything else is OK with int32_t dimensions.
    std::vector<int64_t> reduce_dim_array_64(reduced_dims.begin(),
                                             reduced_dims.end());
    global_reduce = builder.create<mlir::TF::SqueezeOp>(
        input.getLoc(), new_type, global_reduce->getResult(0),
        builder.getI64ArrayAttr(reduce_dim_array_64));
  }
  return global_reduce->getResult(0);
}

// Takes sharded logits and computes both the shifted exponentiation of the
// logits and its sum. Assumes that builder's insertion point is after logits.
Status ComputeExpAndSum(mlir::OpBuilder& builder, const mlir::Value& logits,
                        const Layout& logits_layout,
                        mlir::Value& shifted_logits,
                        mlir::Value& exp_of_shifted_logits,
                        mlir::Value& sum_of_exp) {
  auto loc = logits.getLoc();

  if (logits_layout.rank() == 0)
    return errors::Unimplemented("softmax not supported for rank 0 tensors.");

  const int64 class_dimension = logits_layout.rank() - 1;

  // Softmax is exp(input)/sum(exp(input)) and LogSoftmax is
  // logits - log(sum(exp(input))) where the sum takes place on the
  // last axis.
  // For numerical stability, we shift the logits by the max (along
  // the last axis) before doing the above calculation.

  // Construct the max.
  TF_ASSIGN_OR_RETURN(
      const mlir::Value max_logits,
      ComputeGlobalReduce(builder, logits, logits_layout, {class_dimension},
                          kReduceOpMax, /*keep_dims=*/true));

  // Subtract max from local copy of logits.
  shifted_logits =
      builder.create<mlir::TF::SubOp>(loc, logits, max_logits).getResult();
  exp_of_shifted_logits =
      builder.create<mlir::TF::ExpOp>(loc, shifted_logits).getResult();

  // Sum the exponential.
  TF_ASSIGN_OR_RETURN(
      sum_of_exp,
      ComputeGlobalReduce(builder, exp_of_shifted_logits, logits_layout,
                          {class_dimension}, kReduceOpAdd,
                          /*keep_dims=*/true));
  return OkStatus();
}

// Computes softmax from its components. Assumes that builder's insertion point
// is after sum_of_exp and exp_of_shifted_logits.
mlir::Value ComputeSoftmax(mlir::OpBuilder& builder,
                           const mlir::Value& exp_of_shifted_logits,
                           const mlir::Value& sum_of_exp) {
  // For Softmax, we compute exp(shifted_logits)/sum(exp(shifted_logits))
  auto softmax = builder.create<mlir::TF::DivOp>(
      exp_of_shifted_logits.getLoc(), exp_of_shifted_logits, sum_of_exp);
  return softmax.getResult();
}

// Computes log softmax from its components. Assumes that builder's insertion
// point is after shifted_logits and sum_of_exp.
mlir::Value ComputeLogSoftmax(mlir::OpBuilder& builder,
                              const mlir::Value& shifted_logits,
                              const mlir::Value& sum_of_exp) {
  // For LogSoftmax, we compute shifted_logits - log(sum(exp(shifted_logits)))
  auto log_of_sum =
      builder.create<mlir::TF::LogOp>(shifted_logits.getLoc(), sum_of_exp);
  auto log_softmax = builder.create<mlir::TF::SubOp>(
      shifted_logits.getLoc(), shifted_logits, log_of_sum.getResult());
  return log_softmax.getResult();
}

// Computes the softmax of the input along the last axis, assuming that the
// input is sharded along that axis.
StatusOr<mlir::Value> ComputeShardedSoftmax(mlir::OpBuilder& builder,
                                            const mlir::Value& logits,
                                            const Layout& logits_layout,
                                            bool log_softmax) {
  mlir::Value shifted_logits;
  mlir::Value exp_of_shifted_logits;
  mlir::Value sum_of_exp;
  TF_RETURN_IF_ERROR(ComputeExpAndSum(builder, logits, logits_layout,
                                      shifted_logits, exp_of_shifted_logits,
                                      sum_of_exp));

  if (log_softmax) {
    return ComputeLogSoftmax(builder, shifted_logits, sum_of_exp);
  } else {
    return ComputeSoftmax(builder, exp_of_shifted_logits, sum_of_exp);
  }
}

// Creates a layout from specs which is
// 1) Left-truncated to match the size of global_shape.
// 2) Unsharded on every dimension where global_shape is 1.
StatusOr<Layout> GetBroadcastedLayout(llvm::ArrayRef<int64_t> global_shape,
                                      const std::vector<ShardingSpec>& specs,
                                      const Mesh& mesh) {
  std::vector<ShardingSpec> new_specs(global_shape.size());
  for (int i = 0; i < global_shape.size(); ++i) {
    if (global_shape[i] == 1)
      new_specs[i].set_sharding_spec(Layout::kUnshardedDim);
    else
      new_specs[i] = specs[i + specs.size() - global_shape.size()];
  }

  return Layout::GetLayout(new_specs, mesh);
}

// Gets a scalar floating point constant with the same element type as the
// input value. Assumes builder's insertion point is after input.
StatusOr<mlir::Value> GetFPConstOfType(mlir::OpBuilder& builder,
                                       const mlir::Value& input, float value) {
  if (mlir::TensorType type = input.getType().dyn_cast<mlir::TensorType>()) {
    return builder
        .create<mlir::TF::ConstOp>(
            input.getLoc(),
            mlir::DenseFPElementsAttr::get<float>(
                mlir::RankedTensorType::get({}, type.getElementType()),
                {value}))
        .output();
  } else {
    return errors::Unimplemented("non tensor type for labels is not supported");
  }
}

// Takes input, which has layout agreeing with the truncation of
// desired_layout, and runs OneHot on it to make it 2-dimensional.
// Assumes builder's insertion point is after input and desired_layout is rank
// 2.
//
// OneHot's element type matches that of features and the number of classes is
// derived from features' last dimension and the number of shards in the last
// dimension of desired_layout.
//
// TODO(bfontain): Extract and share with OneHotSPMDExpander
StatusOr<mlir::Value> ComputeOneHot(mlir::OpBuilder& builder,
                                    const mlir::Value& input,
                                    const mlir::Value& features,
                                    const Layout& desired_layout) {
  // Get the number of classes for this onehot. The number of classes is the
  // global size of the last dimension of features.
  mlir::RankedTensorType features_type =
      features.getType().dyn_cast<mlir::RankedTensorType>();
  if (!features_type)
    return errors::InvalidArgument(
        "feature input shape must be statically known");
  if (features_type.getRank() == 0)
    return errors::InvalidArgument(
        "expected feature input to have at least rank 1, but found rank 0");

  const int64_t local_classes = features_type.getShape().back();
  const int64_t classes =
      local_classes *
      desired_layout.num_shards_for_dim(desired_layout.sharding_specs().back());

  int64_t num_shards = desired_layout.num_shards_for_dim(desired_layout.dim(1));
  if (classes % num_shards)
    return errors::InvalidArgument("unable to shard onehot with size ", classes,
                                   " over dimension with ", num_shards,
                                   " shards");
  const mlir::Location& loc = input.getLoc();

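  // Each shard materializes only its local slice of the one-hot, so the local
  // depth is classes / num_shards (the even-division check above guarantees
  // this is exact).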
  mlir::Value depth = CreateIntScalarConst(classes / num_shards, builder, loc,
                                           /*use_int64=*/false);

  // TODO(bfontain): Extract this block (up to and including the SqueezeOp) to
  // a common function.
  mlir::tf_device::ClusterOp cluster =
      depth.getDefiningOp()->getParentOfType<mlir::tf_device::ClusterOp>();

  // `mesh_coordinates` is a tensor of size [1, mesh_size] where each
  // element in the tensor refers to the shard id for the specified mesh
  // dimension.
  TF_ASSIGN_OR_RETURN(mlir::Value mesh_coordinates,
                      GetMeshCoordinatesFromCluster(cluster));

  const int mesh_dim_index = desired_layout.mesh().GetMeshDimIndexWithName(
      desired_layout.sharding_spec(/*idx=*/1));

  // Slice out the [1, 1] for mesh_dim_index.
  mlir::Value shard_id =
      builder
          .create<mlir::TF::SliceOp>(
              loc, mlir::RankedTensorType::get({1, 1}, builder.getI32Type()),
              mesh_coordinates,
              IntConst(builder, input.getLoc(), {0, mesh_dim_index}),
              IntConst(builder, input.getLoc(), {1, 1}))
          .output();

  shard_id = builder
                 .create<mlir::TF::SqueezeOp>(
                     loc, mlir::RankedTensorType::get({}, builder.getI32Type()),
                     shard_id, builder.getI64ArrayAttr({0, 1}))
                 .output();

  // `indices` = `input` - `shard_id` * (classes / num_shards)
  mlir::Value id_offset =
      builder.create<mlir::TF::MulOp>(loc, shard_id, depth).z();

  // Note that the type of id_offset (int32) may not match the type of input.
  // So we insert a cast in this case.
  mlir::TensorType input_type = input.getType().dyn_cast<mlir::TensorType>();
  if (!input_type) return errors::InvalidArgument("input is not a TensorType");
  if (!input_type.getElementType().isInteger(32))
    id_offset =
        builder
            .create<mlir::TF::CastOp>(
                loc,
                mlir::RankedTensorType::get({}, input_type.getElementType()),
                id_offset)
            .y();

  mlir::Value indices =
      builder.create<mlir::TF::SubOp>(loc, input, id_offset).z();

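  // Labels owned by other shards map to indices outside [0, depth), and
  // OneHot emits off_value for out-of-range indices, so each shard one-hot
  // encodes exactly its own slice of the class dimension.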
  TF_ASSIGN_OR_RETURN(mlir::Value on_value,
                      GetFPConstOfType(builder, features, 1.0));
  TF_ASSIGN_OR_RETURN(mlir::Value off_value,
                      GetFPConstOfType(builder, features, 0.0));

  return builder
      .create<mlir::TF::OneHotOp>(input.getLoc(), indices, depth, on_value,
                                  off_value, builder.getI64IntegerAttr(1))
      .output();
}

}  // namespace

// Expander for Softmax and LogSoftmax ops.
StatusOr<mlir::Operation*> SoftmaxOpSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  TF_ASSIGN_OR_RETURN(auto logits_layout,
                      ExtractLayoutFromOperand(op->getOperand(0)));

  if (!logits_layout) {
    return errors::InvalidArgument("Failed during SPMD expansion of ",
                                   OpName(op),
                                   ". Layout of logits input must be known.");
  }

  // (Log)Softmax's logits are a rank >= 1 tensor. We reduce over the last
  // dimension. If this is replicated, we don't need any cross-replica
  // operations and can just emit the op as is.
  if (logits_layout->IsLastDimReplicated())
    return InferSPMDExpandedLocalShape(op);

  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfter(op);

  TF_ASSIGN_OR_RETURN(
      const mlir::Value new_softmax,
      ComputeShardedSoftmax(builder, op->getOperand(0), *logits_layout,
                            mlir::isa<mlir::TF::LogSoftmaxOp>(op)));

  op->getOpResult(0).replaceAllUsesWith(new_softmax);
  op->erase();
  return new_softmax.getDefiningOp();
}

StatusOr<llvm::DenseMap<int, Layout>>
SoftmaxOpSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  // We want to use the same layout for the output.
  return input_layouts;
}

StatusOr<llvm::DenseMap<int, Layout>>
SoftmaxOpSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  // We want to use the same layout for the input.
  return output_layouts;
}

// Takes the input and output layouts and
// 1) Selects a batch and class sharding from the layouts
// 2) Applies relayout to the inputs
// 3) Sets the new features and labels layouts, taking broadcasting into
//    account.
// 4) Returns the full layout for backprop/loss.
StatusOr<Layout> SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs(
    mlir::Operation* op, bool is_sparse, const Layout& features_layout,
    const Layout& labels_layout, const Layout& loss_layout,
    const Layout& backprop_layout, Layout& new_features_layout,
    Layout& new_labels_layout) {
  // This layout represents the 'internal layout' that the softmax will be
  // operating on. Inputs will be relayout'ed to this layout and outputs will be
  // relayout'ed from this layout to their desired layout.
  std::vector<ShardingSpec> internal_layout(2);
  internal_layout[0].set_sharding_spec(Layout::kUnshardedDim);
  internal_layout[1].set_sharding_spec(Layout::kUnshardedDim);

  // Choose an internal layout, ideally this layout would be chosen so that
  // the relayout costs for the inputs (from features_layout/labels_layout to
  // internal_layout) and the outputs (from internal_layout to
  // loss_layout/backprop_layout) are minimized, but we will do something more
  // naive for now.

  // Pick a batch sharding, first from features, then labels, loss and backprop.
  // Due to possible broadcasting on features and labels, they will only
  // have a batch dim if they are rank 2.
  if (features_layout.rank() == 2) internal_layout[0] = features_layout.dim(0);
  if (((labels_layout.rank() == 2) ||
       (is_sparse && labels_layout.rank() == 1)) &&
      Layout::IsUnshardedSpec(internal_layout[0]))
    internal_layout[0] = labels_layout.dim(0);
  if (Layout::IsUnshardedSpec(internal_layout[0]))
    internal_layout[0] = loss_layout.dim(0);
  if (Layout::IsUnshardedSpec(internal_layout[0]))
    internal_layout[0] = backprop_layout.dim(0);

  // Pick a class sharding, first from features, then labels and backprop.
  // The class dim for features and labels is always the last dim if it exists.
  // Note that loss and backprop have fixed ranks 1 and 2 respectively, whereas
  // the ranks of features and labels may involve broadcasting.
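  // The checks against internal_layout[0] below ensure the class dimension
  // does not reuse the mesh dimension already chosen for the batch dimension,
  // since a mesh dimension may shard at most one tensor dimension.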
  if (features_layout.rank() > 0 &&
      (internal_layout[0].sharding_spec() !=
       features_layout.sharding_spec(features_layout.rank() - 1)))
    internal_layout[1] = features_layout.dim(features_layout.rank() - 1);
  if (!is_sparse && labels_layout.rank() > 0 &&
      Layout::IsUnshardedSpec(internal_layout[1]) &&
      (internal_layout[0].sharding_spec() !=
       labels_layout.sharding_spec(labels_layout.rank() - 1)))
    internal_layout[1] = labels_layout.dim(labels_layout.rank() - 1);
  if (Layout::IsUnshardedSpec(internal_layout[1]) &&
      (internal_layout[0].sharding_spec() != backprop_layout.sharding_spec(1)))
    internal_layout[1] = backprop_layout.dim(1);

  TF_ASSIGN_OR_RETURN(
      llvm::ArrayRef<int64_t> features_global_shape,
      GetGlobalShapeOfValueFromDTensorLayout(op->getOperand(0)));

  // At this point we need to compute the new layout of features and labels.
  // Broadcasting makes this more complicated: First we truncate the correct
  // rank and then set any dimensions where the global shape is size 1 to
  // unsharded.
  TF_ASSIGN_OR_RETURN(
      new_features_layout,
      GetBroadcastedLayout(features_global_shape, internal_layout,
                           features_layout.mesh()));

  TF_ASSIGN_OR_RETURN(
      const mlir::Value new_features,
      EmitRelayout(op->getOperand(0), features_layout, new_features_layout));

  op->setOperand(0, new_features);

  TF_ASSIGN_OR_RETURN(
      llvm::ArrayRef<int64_t> labels_global_shape,
      GetGlobalShapeOfValueFromDTensorLayout(op->getOperand(1)));

  if (is_sparse) {
    // If we are sparse, then the only possible dimension is the batch_dim.
    std::vector<ShardingSpec> sparse_specs = {internal_layout[0]};
    TF_ASSIGN_OR_RETURN(new_labels_layout,
                        GetBroadcastedLayout(labels_global_shape, sparse_specs,
                                             labels_layout.mesh()));
  } else {
    TF_ASSIGN_OR_RETURN(
        new_labels_layout,
        GetBroadcastedLayout(labels_global_shape, internal_layout,
                             labels_layout.mesh()));
  }

  TF_ASSIGN_OR_RETURN(
      const mlir::Value new_labels,
      EmitRelayout(op->getOperand(1), labels_layout, new_labels_layout));

  op->setOperand(1, new_labels);

  return Layout::GetLayout(internal_layout, features_layout.mesh());
}

// Takes the given loss and backprop values, relayouts them to the required
// layouts, and passes them through an IdentityN op.
// This assumes that the inputs have local shapes in their types.
StatusOr<mlir::Operation*> SoftmaxLossOpSPMDExpander::MaybeRelayoutOutputs(
    mlir::Operation* op, const mlir::Value& loss, const mlir::Value& backprop,
    const Layout& output_layout, const Layout& loss_layout,
    const Layout& backprop_layout) {
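  // The loss output is rank 1 (batch only), so its current layout is the
  // batch portion of output_layout; backprop keeps the full rank-2 layout.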
  const Layout current_loss_layout = output_layout.Truncate(/*split_point=*/1);
  const Layout& current_backprop_layout = output_layout;

  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
  TF_ASSIGN_OR_RETURN(
      const mlir::Value new_loss,
      EmitRelayout(loss, current_loss_layout, loss_layout, &newly_created_ops));

  TF_ASSIGN_OR_RETURN(const mlir::Value new_backprop,
                      EmitRelayout(backprop, current_backprop_layout,
                                   backprop_layout, &newly_created_ops));

  mlir::OpBuilder builder(loss.getContext());

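  // Place the IdentityN after whichever relayouted value is defined later in
  // the block so that both of its operands dominate it.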
  if (new_loss.getDefiningOp()->isBeforeInBlock(new_backprop.getDefiningOp()))
    builder.setInsertionPointAfterValue(new_backprop);
  else
    builder.setInsertionPointAfterValue(new_loss);

  llvm::SmallVector<mlir::Type, 4> types = {new_loss.getType(),
                                            new_backprop.getType()};
  llvm::SmallVector<mlir::Value, 4> values = {new_loss, new_backprop};

  mlir::TF::IdentityNOp identity_op =
      builder.create<mlir::TF::IdentityNOp>(loss.getLoc(), types, values);

  newly_created_ops.insert(identity_op);

  op->getResult(0).replaceAllUsesExcept(identity_op.getResult(0),
                                        newly_created_ops);
  op->getResult(1).replaceAllUsesExcept(identity_op.getResult(1),
                                        newly_created_ops);

  // If the op we are expanding isn't being used any more, erase it from the
  // program.
  if (op->getResult(0).use_empty() && op->getResult(1).use_empty()) op->erase();

  return identity_op.getOperation();
}

StatusOr<mlir::Operation*> SoftmaxLossOpSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  if (!mlir::isa<mlir::TF::SoftmaxCrossEntropyWithLogitsOp>(op) &&
      !mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op))
    return errors::InvalidArgument(
        "unsupported op in SoftmaxLossOpSPMDExpander");

  TF_ASSIGN_OR_RETURN(const Layout& features_layout,
                      ExtractRequiredLayoutFromOperand(op->getOperand(0)));
  TF_ASSIGN_OR_RETURN(const Layout& labels_layout,
                      ExtractRequiredLayoutFromOperand(op->getOperand(1)));
  TF_ASSIGN_OR_RETURN(const std::vector<Layout>& output_layouts,
                      ExtractRequiredLayoutFromOp(op));

  const bool is_sparse =
      mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op);

  Layout new_features_layout;
  Layout new_labels_layout;

  TF_ASSIGN_OR_RETURN(
      const Layout internal_layout,
      MaybeRelayoutInputs(op, is_sparse, features_layout, labels_layout,
                          output_layouts[0], output_layouts[1],
                          new_features_layout, new_labels_layout));

  assert(internal_layout.rank() == 2);

  // If the class dim is unsharded, we can emit a local op.
  if (Layout::IsUnshardedSpec(internal_layout.dim(1))) {
    op = InferSPMDExpandedLocalShape(op);
    return MaybeRelayoutOutputs(op, op->getResult(0), op->getResult(1),
                                internal_layout, output_layouts[0],
                                output_layouts[1]);
  }

  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfter(op);

  mlir::Value features = op->getOperand(0);
  mlir::Value labels = op->getOperand(1);
  if (is_sparse) {
    // SparseSoftmaxCrossEntropyWithLogits(features, labels) can be rewritten
    // as SoftmaxCrossEntropyWithLogits(features, OneHot(labels)).
    // Note that this is what is done in the XLA kernel for this op.
    TF_ASSIGN_OR_RETURN(
        labels, ComputeOneHot(builder, labels, features, internal_layout));
  }

  if (features_layout.rank() == 0)
    return errors::Unimplemented(
        "scalar valued features are not currently supported");

  // SoftmaxCrossEntropyWithLogitsOp is the same as:
  //   loss = -tf.reduce_sum(labels*tf.LogSoftmax(features), class_dim)
  //   backprop = tf.Softmax(features) - labels

  mlir::Value shifted_logits;
  mlir::Value exp_of_shifted_logits;
  mlir::Value sum_of_exp;

  // Note that it's possible that features has shape [x, 1] and is broadcasted
  // to match labels. In this case we are doing a bunch of extra work, since
  // softmax is 1 and log_softmax is 0.
  TF_RETURN_IF_ERROR(ComputeExpAndSum(builder, features, new_features_layout,
                                      shifted_logits, exp_of_shifted_logits,
                                      sum_of_exp));

  const mlir::Value log_softmax =
      ComputeLogSoftmax(builder, shifted_logits, sum_of_exp);
  const mlir::Value softmax =
      ComputeSoftmax(builder, exp_of_shifted_logits, sum_of_exp);

  // Mimic the XLA kernel, which uses where/select to ensure that the product
  // term is zero when labels are zero.
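  // Without the select, a zero label multiplied by log_softmax = -inf would
  // produce NaN instead of contributing zero to the loss.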
  TF_ASSIGN_OR_RETURN(const mlir::Value features_zero,
                      GetFPConstOfType(builder, features, 0.0));
  TF_ASSIGN_OR_RETURN(const mlir::Value labels_zero,
                      GetFPConstOfType(builder, labels, 0.0));

  const mlir::Value is_labels_zero =
      builder
          .create<mlir::TF::EqualOp>(op->getLoc(), labels, labels_zero,
                                     builder.getBoolAttr(true))
          .z();
  const mlir::Value safe_softmax =
      builder
          .create<mlir::TF::SelectV2Op>(op->getLoc(), is_labels_zero,
                                        features_zero, log_softmax)
          .output();
  const mlir::Value prod =
      builder.create<mlir::TF::MulOp>(op->getLoc(), labels, safe_softmax).z();

  // Compute the reduce sum
  TF_ASSIGN_OR_RETURN(
      mlir::Value positive_loss,
      ComputeGlobalReduce(builder, prod, internal_layout, /*reduced_dims=*/{1},
                          kReduceOpAdd, /*keep_dims=*/false));

  builder.setInsertionPointAfterValue(positive_loss);
  mlir::Value loss =
      builder.create<mlir::TF::NegOp>(op->getLoc(), positive_loss).y();

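  // backprop = softmax(features) - labels; this is an elementwise op, so no
  // further cross-device communication is needed here.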
  mlir::Value backprop =
      builder.create<mlir::TF::SubOp>(op->getLoc(), softmax, labels);

  return MaybeRelayoutOutputs(op, loss, backprop, internal_layout,
                              output_layouts[0], output_layouts[1]);
}

StatusOr<llvm::DenseMap<int, Layout>>
SoftmaxLossOpSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
  const bool is_sparse =
      mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op);

  // loss is sum(-labels * logsoftmax(features)), so the layout is batch
  // sharded if labels and features are batch sharded on the same mesh dim or
  // if one is replicated.
  // backprop is softmax(features) - labels

  absl::optional<Layout> features_layout;
  if (input_layouts.find(0) != input_layouts.end())
    features_layout.emplace(input_layouts.lookup(0));
  absl::optional<Layout> labels_layout;
  if (input_layouts.find(1) != input_layouts.end())
    labels_layout.emplace(input_layouts.lookup(1));

  // We need to compute shardings for two dimensions: batch and class.
  std::vector<ShardingSpec> layout_specs(2);
  layout_specs[0].set_sharding_spec(Layout::kUnshardedDim);
  layout_specs[1].set_sharding_spec(Layout::kUnshardedDim);

  // First pick the batch dimension: set it to the batch dimension of features
  // if it exists, otherwise to the batch dimension of labels.
  if (features_layout && (features_layout->rank() == 2))
    layout_specs[0] = features_layout->dim(0);
  if (labels_layout &&
      (labels_layout->rank() == 2 ||
       (is_sparse && labels_layout->rank() == 1)) &&
      Layout::IsUnshardedSpec(layout_specs[0]))
    layout_specs[0] = labels_layout->dim(0);

  // The class dim for features and labels is always the last dim if it
  // exists.
  if (features_layout && (features_layout->rank() > 0) &&
      (layout_specs[0].sharding_spec() !=
       features_layout->sharding_spec(features_layout->rank() - 1)))
    layout_specs[1] = features_layout->dim(features_layout->rank() - 1);
  if (!is_sparse && labels_layout && (labels_layout->rank() > 0) &&
      Layout::IsUnshardedSpec(layout_specs[1]) &&
      (layout_specs[0].sharding_spec() !=
       labels_layout->sharding_spec(labels_layout->rank() - 1)))
    layout_specs[1] = labels_layout->dim(labels_layout->rank() - 1);

  TF_ASSIGN_OR_RETURN(const Layout backprop_layout,
                      Layout::GetLayout(layout_specs, mesh));
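  // The loss output is rank 1 (batch only), so it takes just the batch
  // portion of the backprop layout.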
  const Layout loss_layout = backprop_layout.Truncate(/*split_point=*/1);

  return llvm::DenseMap<int, Layout>({{0, loss_layout}, {1, backprop_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>>
SoftmaxLossOpSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
  const bool is_sparse =
      mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op);

  absl::optional<Layout> loss_layout;
  if (output_layouts.find(0) != output_layouts.end())
    loss_layout.emplace(output_layouts.lookup(0));
  absl::optional<Layout> backprop_layout;
  if (output_layouts.find(1) != output_layouts.end())
    backprop_layout.emplace(output_layouts.lookup(1));

  // We need to compute two possible shardings:
  // One for the batch dimension and one for the class dimension.
  std::vector<ShardingSpec> layout_specs(2);
  layout_specs[0].set_sharding_spec(Layout::kUnshardedDim);
  layout_specs[1].set_sharding_spec(Layout::kUnshardedDim);

  // Respect the loss layout if it is set, otherwise use the backprop
  // layout for the batch_dim.
  if (loss_layout) layout_specs[0] = loss_layout->dim(0);
  if (backprop_layout && Layout::IsUnshardedSpec(layout_specs[0]))
    layout_specs[0] = backprop_layout->dim(0);

  // Only backprop has class dim so use that if it is available.
  if (backprop_layout &&
      backprop_layout->sharding_spec(1) != layout_specs[0].sharding_spec())
    layout_specs[1] = backprop_layout->dim(1);

  TF_ASSIGN_OR_RETURN(const auto features_shape,
                      GetShapeOfValue(op->getOperand(0)));
  TF_ASSIGN_OR_RETURN(const auto labels_shape,
                      GetShapeOfValue(op->getOperand(1)));
  TF_ASSIGN_OR_RETURN(const Layout features_layout,
                      GetBroadcastedLayout(features_shape, layout_specs, mesh));
  if (is_sparse) {
    // Drop the class sharding as the labels don't have class dimension in
    // the sparse version.
    layout_specs.resize(1);
  }
  TF_ASSIGN_OR_RETURN(const Layout labels_layout,
                      GetBroadcastedLayout(labels_shape, layout_specs, mesh));

  return llvm::DenseMap<int, Layout>(
      {{0, features_layout}, {1, labels_layout}});
}

}  // namespace dtensor
}  // namespace tensorflow