xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.h"
17 
18 #include <optional>
19 #include <string>
20 
21 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
23 #include "tensorflow/dtensor/cc/tensor_layout.h"
24 #include "tensorflow/dtensor/mlir/collectives.h"
25 #include "tensorflow/dtensor/mlir/layout_parsing.h"
26 #include "tensorflow/dtensor/mlir/op_utils.h"
27 #include "tensorflow/dtensor/mlir/shape_utils.h"
28 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
29 #include "tensorflow/dtensor/mlir/value_utils.h"
30 
31 namespace tensorflow {
32 namespace dtensor {
33 namespace {
34 
35 // Computes a local reduce followed by an EmitAllReduce. This performs a global
36 // reduction, output will have global shape 1 on the reduced axes if keep dims
37 // is true otherwise the axes will be removed.
38 // Assumes builder's insertion point is after input.
ComputeGlobalReduce(mlir::OpBuilder & builder,const mlir::Value & input,const Layout & input_layout,const absl::flat_hash_set<int> & reduced_dims,absl::string_view reduce_op,bool keep_dims)39 StatusOr<mlir::Value> ComputeGlobalReduce(
40     mlir::OpBuilder& builder, const mlir::Value& input,
41     const Layout& input_layout, const absl::flat_hash_set<int>& reduced_dims,
42     absl::string_view reduce_op, bool keep_dims) {
43   const Layout reduction_layout =
44       input_layout.GetLayoutWithReducedDims(reduced_dims,
45                                             /*keep_dims=*/true);
46   std::vector<int32> reduce_dim_array(reduced_dims.begin(), reduced_dims.end());
47   const mlir::Value reduction_indices =
48       IntConst(builder, input.getLoc(), reduce_dim_array);
49   mlir::Operation* local_reduce;
50 
51   // First compute a local reduce
52   if (reduce_op == kReduceOpAdd) {
53     local_reduce = builder.create<mlir::TF::SumOp>(
54         input.getLoc(), input, reduction_indices,
55         /*keep_dims=*/builder.getBoolAttr(true));
56   } else if (reduce_op == kReduceOpMax) {
57     local_reduce = builder.create<mlir::TF::MaxOp>(
58         input.getLoc(), input, reduction_indices,
59         /*keep_dims=*/builder.getBoolAttr(true));
60   } else {
61     return errors::Unimplemented("reduction ", reduce_op, " not implemented");
62   }
63 
64   // Then an all reduce.
65   absl::flat_hash_set<std::string> reduced_sharding_specs;
66   for (const int dim : reduced_dims)
67     if (Layout::IsShardedSpec(input_layout.dim(dim)))
68       reduced_sharding_specs.emplace(input_layout.sharding_spec(dim));
69   TF_ASSIGN_OR_RETURN(
70       mlir::Operation * global_reduce,
71       EmitAllReduce(builder, reduction_layout, reduced_sharding_specs,
72                     local_reduce, reduce_op));
73 
74   if (!keep_dims) {
75     mlir::RankedTensorType output_type =
76         global_reduce->getResult(0)
77             .getType()
78             .dyn_cast<mlir::RankedTensorType>();
79     if (!output_type)
80       return errors::Internal(
81           "output of EmitAllReduce is not a RankedTensorType");
82     std::vector<int64_t> new_shape;
83     for (int i = 0; i < output_type.getRank(); ++i)
84       if (!reduced_dims.contains(i))
85         new_shape.emplace_back(output_type.getDimSize(i));
86     mlir::RankedTensorType new_type =
87         mlir::RankedTensorType::get(new_shape, output_type.getElementType());
88     // Upcast the dimensions to int64_t as SqueezeOp requires this for its
89     // dimension attribute type. Everything else is OK with int32_t dimensions.
90     std::vector<int64_t> reduce_dim_array_64(reduced_dims.begin(),
91                                              reduced_dims.end());
92     global_reduce = builder.create<mlir::TF::SqueezeOp>(
93         input.getLoc(), new_type, global_reduce->getResult(0),
94         builder.getI64ArrayAttr(reduce_dim_array_64));
95   }
96   return global_reduce->getResult(0);
97 }
98 
99 // Takes a sharded logits and compute both the shifted exponentiation of the
100 // logits and its sum. Assumes that builder's insertion point is after logits.
ComputeExpAndSum(mlir::OpBuilder & builder,const mlir::Value & logits,const Layout & logits_layout,mlir::Value & shifted_logits,mlir::Value & exp_of_shifted_logits,mlir::Value & sum_of_exp)101 Status ComputeExpAndSum(mlir::OpBuilder& builder, const mlir::Value& logits,
102                         const Layout& logits_layout,
103                         mlir::Value& shifted_logits,
104                         mlir::Value& exp_of_shifted_logits,
105                         mlir::Value& sum_of_exp) {
106   auto loc = logits.getLoc();
107 
108   if (logits_layout.rank() == 0)
109     return errors::Unimplemented("softmax not supported for rank 0 tensors.");
110 
111   const int64 class_dimension = logits_layout.rank() - 1;
112 
113   // Softmax is exp(input)/sum(exp(input)) and LogSoftmax is
114   // logits - log(sum(exp(input)) where the sum takes place on the
115   // last axis.
116   // For numerical stability, we shift the logits by the max (along
117   // the last axis) before doing the above calculation.
118 
119   // Construct the max.
120   TF_ASSIGN_OR_RETURN(
121       const mlir::Value max_logits,
122       ComputeGlobalReduce(builder, logits, logits_layout, {class_dimension},
123                           kReduceOpMax, /*keep_dims=*/true));
124 
125   // Subtract max from local copy of logits.
126   shifted_logits =
127       builder.create<mlir::TF::SubOp>(loc, logits, max_logits).getResult();
128   exp_of_shifted_logits =
129       builder.create<mlir::TF::ExpOp>(loc, shifted_logits).getResult();
130 
131   // Sum the exponential.
132   TF_ASSIGN_OR_RETURN(
133       sum_of_exp,
134       ComputeGlobalReduce(builder, exp_of_shifted_logits, logits_layout,
135                           {class_dimension}, kReduceOpAdd,
136                           /*keep_dims=*/true));
137   return OkStatus();
138 }
139 
140 // Computes softmax from its components. Assumes that builder's insertion point
141 // is after sum_of_exp and exp_of_shifted_logits.
ComputeSoftmax(mlir::OpBuilder & builder,const mlir::Value & exp_of_shifted_logits,const mlir::Value & sum_of_exp)142 mlir::Value ComputeSoftmax(mlir::OpBuilder& builder,
143                            const mlir::Value& exp_of_shifted_logits,
144                            const mlir::Value& sum_of_exp) {
145   // For Softmax, we compute exp(shifted_logits)/sum(exp(shifted_logits))
146   auto softmax = builder.create<mlir::TF::DivOp>(
147       exp_of_shifted_logits.getLoc(), exp_of_shifted_logits, sum_of_exp);
148   return softmax.getResult();
149 }
150 
151 // Computes softmax from its components. Assumes that builder's insertion point
152 // is after shifted_logits and sum_of_exp.
ComputeLogSoftmax(mlir::OpBuilder & builder,const mlir::Value & shifted_logits,const mlir::Value & sum_of_exp)153 mlir::Value ComputeLogSoftmax(mlir::OpBuilder& builder,
154                               const mlir::Value& shifted_logits,
155                               const mlir::Value& sum_of_exp) {
156   // For LogSoftmax, we compute shifted_logs - log(sum(exp(shifted_logits)))
157   auto log_of_sum =
158       builder.create<mlir::TF::LogOp>(shifted_logits.getLoc(), sum_of_exp);
159   auto log_softmax = builder.create<mlir::TF::SubOp>(
160       shifted_logits.getLoc(), shifted_logits, log_of_sum.getResult());
161   return log_softmax.getResult();
162 }
163 
164 // Computes the softmax of the input along the last axis, assuming that the
165 // input is sharded along that axis.
ComputeShardedSoftmax(mlir::OpBuilder & builder,const mlir::Value & logits,const Layout & logits_layout,bool log_softmax)166 StatusOr<mlir::Value> ComputeShardedSoftmax(mlir::OpBuilder& builder,
167                                             const mlir::Value& logits,
168                                             const Layout& logits_layout,
169                                             bool log_softmax) {
170   mlir::Value shifted_logits;
171   mlir::Value exp_of_shifted_logits;
172   mlir::Value sum_of_exp;
173   TF_RETURN_IF_ERROR(ComputeExpAndSum(builder, logits, logits_layout,
174                                       shifted_logits, exp_of_shifted_logits,
175                                       sum_of_exp));
176 
177   if (log_softmax) {
178     return ComputeLogSoftmax(builder, shifted_logits, sum_of_exp);
179   } else {
180     return ComputeSoftmax(builder, exp_of_shifted_logits, sum_of_exp);
181   }
182 }
183 
184 // Creates a layout from specs which is
185 // 1) Left truncated to match the size of global_shape.
186 // 2) Has unsharded dimensions where ever global_shape is 1.
GetBroadcastedLayout(llvm::ArrayRef<int64_t> global_shape,const std::vector<ShardingSpec> & specs,const Mesh & mesh)187 StatusOr<Layout> GetBroadcastedLayout(llvm::ArrayRef<int64_t> global_shape,
188                                       const std::vector<ShardingSpec>& specs,
189                                       const Mesh& mesh) {
190   std::vector<ShardingSpec> new_specs(global_shape.size());
191   for (int i = 0; i < global_shape.size(); ++i) {
192     if (global_shape[i] == 1)
193       new_specs[i].set_sharding_spec(Layout::kUnshardedDim);
194     else
195       new_specs[i] = specs[i + specs.size() - global_shape.size()];
196   }
197 
198   return Layout::GetLayout(new_specs, mesh);
199 }
200 
201 // Gets a scalar floating point constant with the same element type as the input
202 // value. Assumes builder's insertion point is after input.
GetFPConstOfType(mlir::OpBuilder & builder,const mlir::Value & input,float value)203 StatusOr<mlir::Value> GetFPConstOfType(mlir::OpBuilder& builder,
204                                        const mlir::Value& input, float value) {
205   if (mlir::TensorType type = input.getType().dyn_cast<mlir::TensorType>()) {
206     return builder
207         .create<mlir::TF::ConstOp>(
208             input.getLoc(),
209             mlir::DenseFPElementsAttr::get<float>(
210                 mlir::RankedTensorType::get({}, type.getElementType()),
211                 {value}))
212         .output();
213   } else {
214     return errors::Unimplemented("non tensor type for labels is not supported");
215   }
216 }
217 
218 // Takes `input`, which has a layout agreeing with the truncation of
219 // desired_layout, and runs OneHot on it to make it 2 dimensional. Only the
220 // classes owned by this device's shard of the class dimension are encoded.
221 // Assumes builder's insertion point is after input and desired_layout is rank 2.
222 //
223 // OneHot's element type matches that of features, and the number of classes
224 // is derived from features' last dimension and the number of shards in the
225 // last dimension of desired layout.
226 //
227 // TODO(bfontain): Extract and share with OneHotSPMDExpander
// Returns InvalidArgument if features is not ranked or is rank 0, or if input
// is not a TensorType.
// NOTE(review): indices outside this shard's [0, classes/num_shards) range
// presumably yield all-off_value rows (tf.OneHot semantics) — confirm.
ComputeOneHot(mlir::OpBuilder & builder,const mlir::Value & input,const mlir::Value & features,const Layout & desired_layout)228 StatusOr<mlir::Value> ComputeOneHot(mlir::OpBuilder& builder,
229                                     const mlir::Value& input,
230                                     const mlir::Value& features,
231                                     const Layout& desired_layout) {
232   // Get the number of classes for this onehot. The number of classes is the
233   // global size of the last dimension of features.
234   mlir::RankedTensorType features_type =
235       features.getType().dyn_cast<mlir::RankedTensorType>();
236   if (!features_type)
237     return errors::InvalidArgument(
238         "feature input shape must be statically known");
239   if (features_type.getRank() == 0)
240     return errors::InvalidArgument(
241         "expected feature input to have at least rank 1, but found rank 0");
242 
  // NOTE(review): features presumably carries its local (per-device) shape
  // here, so the global class count is the local size times the number of
  // shards of the class dimension. A dynamic last dim is not checked for.
243   const int64_t local_classes = features_type.getShape().back();
244   const int64_t classes =
245       local_classes *
246       desired_layout.num_shards_for_dim(desired_layout.sharding_specs().back());
247 
  // For a rank-2 desired_layout, dim(1) and sharding_specs().back() name the
  // same dimension, so `classes` is a multiple of `num_shards` by construction
  // and this check cannot fail.
248   int64_t num_shards = desired_layout.num_shards_for_dim(desired_layout.dim(1));
249   if (classes % num_shards)
250     return errors::InvalidArgument("unable to shard onehot with size ", classes,
251                                    " over dimension with ", num_shards,
252                                    " shards");
253   const mlir::Location& loc = input.getLoc();
254 
  // `depth` is the number of classes owned by each shard.
255   mlir::Value depth = CreateIntScalarConst(classes / num_shards, builder, loc,
256                                            /*use_int64=*/false);
257 
258   // TODO(bfontain): Extract this block (upto and including the SqueezeOp) to
259   // a common function.
  // NOTE(review): `cluster` is not null-checked before being passed to
  // GetMeshCoordinatesFromCluster — confirm callers always run inside a
  // tf_device.cluster.
260   mlir::tf_device::ClusterOp cluster =
261       depth.getDefiningOp()->getParentOfType<mlir::tf_device::ClusterOp>();
262 
263   // `mesh_coordinates` is tensor of size [1, mesh_size] where each
264   // element in the tensor refers to shard id for the specified mesh
265   // dimension.
266   TF_ASSIGN_OR_RETURN(mlir::Value mesh_coordinates,
267                       GetMeshCoordinatesFromCluster(cluster));
268 
  // Index, within the mesh, of the mesh dimension sharding the class axis.
269   const int mesh_dim_index = desired_layout.mesh().GetMeshDimIndexWithName(
270       desired_layout.sharding_spec(/*idx=*/1));
271 
272   // Slice out the [1,1] for mesh_dim_index.
273   mlir::Value shard_id =
274       builder
275           .create<mlir::TF::SliceOp>(
276               loc, mlir::RankedTensorType::get({1, 1}, builder.getI32Type()),
277               mesh_coordinates,
278               IntConst(builder, input.getLoc(), {0, mesh_dim_index}),
279               IntConst(builder, input.getLoc(), {1, 1}))
280           .output();
281 
  // Squeeze both size-1 axes to get a scalar int32 shard id.
282   shard_id = builder
283                  .create<mlir::TF::SqueezeOp>(
284                      loc, mlir::RankedTensorType::get({}, builder.getI32Type()),
285                      shard_id, builder.getI64ArrayAttr({0, 1}))
286                  .output();
287 
288   // `new_indices` = `input` - `shard_id` * (classes/num_shards)
289   mlir::Value id_offset =
290       builder.create<mlir::TF::MulOp>(loc, shard_id, depth).z();
291 
292   // Note that the type of id_offset (int32) may not match the type of input.
293   // So we insert a cast in this case.
294   mlir::TensorType input_type = input.getType().dyn_cast<mlir::TensorType>();
295   if (!input_type) return errors::InvalidArgument("input is not a TensorType");
296   if (!input_type.getElementType().isInteger(32))
297     id_offset =
298         builder
299             .create<mlir::TF::CastOp>(
300                 loc,
301                 mlir::RankedTensorType::get({}, input_type.getElementType()),
302                 id_offset)
303             .y();
304 
  // Class indices relative to the first class owned by this shard.
305   mlir::Value indices =
306       builder.create<mlir::TF::SubOp>(loc, input, id_offset).z();
307 
308   TF_ASSIGN_OR_RETURN(mlir::Value on_value,
309                       GetFPConstOfType(builder, features, 1.0));
310   TF_ASSIGN_OR_RETURN(mlir::Value off_value,
311                       GetFPConstOfType(builder, features, 0.0));
312 
  // OneHot with axis=1; presumably `input` is rank 1 (batch) here, giving a
  // [batch, classes/num_shards] result — confirm against callers.
313   return builder
314       .create<mlir::TF::OneHotOp>(input.getLoc(), indices, depth, on_value,
315                                   off_value, builder.getI64IntegerAttr(1))
316       .output();
317 }
318 
319 }  // namespace
320 
321 // Expander for Softmax and LogSoftmax ops.
ExpandOp(mlir::Operation * op)322 StatusOr<mlir::Operation*> SoftmaxOpSPMDExpander::ExpandOp(
323     mlir::Operation* op) {
324   TF_ASSIGN_OR_RETURN(auto logits_layout,
325                       ExtractLayoutFromOperand(op->getOperand(0)));
326 
327   if (!logits_layout) {
328     return errors::InvalidArgument("Failed during SPMD expansion of ",
329                                    OpName(op),
330                                    ". Layout of logits input must be known.");
331   }
332 
333   // (Log)Softmax's logits are a rank >= 1 tensor. We reduce over the last
334   // dimension. If this is replicated, we don't need any cross-replica
335   // operations and can just emit the op as is.
336   if (logits_layout->IsLastDimReplicated())
337     return InferSPMDExpandedLocalShape(op);
338 
339   mlir::OpBuilder builder(op);
340   builder.setInsertionPointAfter(op);
341 
342   TF_ASSIGN_OR_RETURN(
343       const mlir::Value new_softmax,
344       ComputeShardedSoftmax(builder, op->getOperand(0), *logits_layout,
345                             mlir::isa<mlir::TF::LogSoftmaxOp>(op)));
346 
347   op->getOpResult(0).replaceAllUsesWith(new_softmax);
348   op->erase();
349   return new_softmax.getDefiningOp();
350 }
351 
352 StatusOr<llvm::DenseMap<int, Layout>>
ComputeLayoutForward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & input_layouts)353 SoftmaxOpSPMDExpander::ComputeLayoutForward(
354     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
355   // We want to use the same layout for the output.
356   return input_layouts;
357 }
358 
359 StatusOr<llvm::DenseMap<int, Layout>>
ComputeLayoutBackward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & output_layouts)360 SoftmaxOpSPMDExpander::ComputeLayoutBackward(
361     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
362   // We want to use the same layout for the input.
363   return output_layouts;
364 }
365 
366 // Takes the input and output layouts of `op` and:
367 // 1) Selects a batch and a class sharding from the layouts.
368 // 2) Relayouts the features and labels operands of `op` in place.
369 // 3) Sets new_features_layout/new_labels_layout, taking broadcasting into account.
370 // 4) Returns the full rank-2 [batch, class] layout for backprop/loss.
// A candidate class sharding is skipped when it equals the batch sharding,
// presumably because one mesh dimension cannot shard both axes.
MaybeRelayoutInputs(mlir::Operation * op,bool is_sparse,const Layout & features_layout,const Layout & labels_layout,const Layout & loss_layout,const Layout & backprop_layout,Layout & new_features_layout,Layout & new_labels_layout)371 StatusOr<Layout> SoftmaxLossOpSPMDExpander::MaybeRelayoutInputs(
372     mlir::Operation* op, bool is_sparse, const Layout& features_layout,
373     const Layout& labels_layout, const Layout& loss_layout,
374     const Layout& backprop_layout, Layout& new_features_layout,
375     Layout& new_labels_layout) {
376   // This layout represents the 'internal layout' that the softmax will be
377   // operating on. Inputs will be relayouted to this layout, and outputs will
378   // be relayouted from this layout to their desired layouts.
379   std::vector<ShardingSpec> internal_layout(2);
380   internal_layout[0].set_sharding_spec(Layout::kUnshardedDim);
381   internal_layout[1].set_sharding_spec(Layout::kUnshardedDim);
382 
383   // Choose an internal layout. Ideally this layout would be chosen so that
384   // the relayout costs for the inputs (from features_layout/labels_layout to
385   // internal_layout) and the outputs (from internal_layout to
386   // loss_layout/backprop_layout) are minimized, but we do something more
387   // naive for now.
388 
389   // Pick a batch sharding, first from features, then labels, loss and backprop.
390   // Due to possible broadcasting on features and labels, they will only
391   // have a batch dim if they are rank 2 (or rank 1 for sparse labels).
392   if (features_layout.rank() == 2) internal_layout[0] = features_layout.dim(0);
393   if (((labels_layout.rank() == 2) ||
394        (is_sparse && labels_layout.rank() == 1)) &&
395       Layout::IsUnshardedSpec(internal_layout[0]))
396     internal_layout[0] = labels_layout.dim(0);
397   if (Layout::IsUnshardedSpec(internal_layout[0]))
398     internal_layout[0] = loss_layout.dim(0);
399   if (Layout::IsUnshardedSpec(internal_layout[0]))
400     internal_layout[0] = backprop_layout.dim(0);
401 
402   // Pick a class sharding, first from features, then labels and backprop.
403   // The class dim for features and labels is always the last dim if it exists.
404   // Note that loss and backprop have fixed ranks 1 and 2 respectively, whereas
405   // the ranks of features and labels may involve broadcasting.
406   if (features_layout.rank() > 0 &&
407       (internal_layout[0].sharding_spec() !=
408        features_layout.sharding_spec(features_layout.rank() - 1)))
409     internal_layout[1] = features_layout.dim(features_layout.rank() - 1);
410   if (!is_sparse && labels_layout.rank() > 0 &&
411       Layout::IsUnshardedSpec(internal_layout[1]) &&
412       (internal_layout[0].sharding_spec() !=
413        labels_layout.sharding_spec(labels_layout.rank() - 1)))
414     internal_layout[1] = labels_layout.dim(labels_layout.rank() - 1);
415   if (Layout::IsUnshardedSpec(internal_layout[1]) &&
416       (internal_layout[0].sharding_spec() != backprop_layout.sharding_spec(1)))
417     internal_layout[1] = backprop_layout.dim(1);
418 
  // The global shape tells GetBroadcastedLayout which dims have size 1.
419   TF_ASSIGN_OR_RETURN(
420       llvm::ArrayRef<int64_t> features_global_shape,
421       GetGlobalShapeOfValueFromDTensorLayout(op->getOperand(0)));
422 
423   // At this point we need to compute the new layout of features and labels.
424   // Broadcasting makes this more complicated: First we truncate to the correct
425   // rank and then set any dimensions where the global shape is size 1 to
426   // unsharded.
427   TF_ASSIGN_OR_RETURN(
428       new_features_layout,
429       GetBroadcastedLayout(features_global_shape, internal_layout,
430                            features_layout.mesh()));
431 
432   TF_ASSIGN_OR_RETURN(
433       const mlir::Value new_features,
434       EmitRelayout(op->getOperand(0), features_layout, new_features_layout));
435 
  // NOTE(review): operand 0 is replaced here, before the labels are processed;
  // if a later step returns an error, `op` has already been modified.
436   op->setOperand(0, new_features);
437 
438   TF_ASSIGN_OR_RETURN(
439       llvm::ArrayRef<int64_t> labels_global_shape,
440       GetGlobalShapeOfValueFromDTensorLayout(op->getOperand(1)));
441 
442   if (is_sparse) {
443     // If we are sparse, then the only possible dimension is the batch_dim.
444     std::vector<ShardingSpec> sparse_specs = {internal_layout[0]};
445     TF_ASSIGN_OR_RETURN(new_labels_layout,
446                         GetBroadcastedLayout(labels_global_shape, sparse_specs,
447                                              labels_layout.mesh()));
448   } else {
449     TF_ASSIGN_OR_RETURN(
450         new_labels_layout,
451         GetBroadcastedLayout(labels_global_shape, internal_layout,
452                              labels_layout.mesh()));
453   }
454 
455   TF_ASSIGN_OR_RETURN(
456       const mlir::Value new_labels,
457       EmitRelayout(op->getOperand(1), labels_layout, new_labels_layout));
458 
459   op->setOperand(1, new_labels);
460 
461   return Layout::GetLayout(internal_layout, features_layout.mesh());
462 }
463 
464 // Takes the given loss, backprop values and relayouts them out to the required
465 // layouts and pass them through an IdentityN op.
466 // This assumes that the input have local shape in their type.
MaybeRelayoutOutputs(mlir::Operation * op,const mlir::Value & loss,const mlir::Value & backprop,const Layout & output_layout,const Layout & loss_layout,const Layout & backprop_layout)467 StatusOr<mlir::Operation*> SoftmaxLossOpSPMDExpander::MaybeRelayoutOutputs(
468     mlir::Operation* op, const mlir::Value& loss, const mlir::Value& backprop,
469     const Layout& output_layout, const Layout& loss_layout,
470     const Layout& backprop_layout) {
471   const Layout current_loss_layout = output_layout.Truncate(/*split_point=*/1);
472   const Layout& current_backprop_layout = output_layout;
473 
474   llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
475   TF_ASSIGN_OR_RETURN(
476       const mlir::Value new_loss,
477       EmitRelayout(loss, current_loss_layout, loss_layout, &newly_created_ops));
478 
479   TF_ASSIGN_OR_RETURN(const mlir::Value new_backprop,
480                       EmitRelayout(backprop, current_backprop_layout,
481                                    backprop_layout, &newly_created_ops));
482 
483   mlir::OpBuilder builder(loss.getContext());
484 
485   if (new_loss.getDefiningOp()->isBeforeInBlock(new_backprop.getDefiningOp()))
486     builder.setInsertionPointAfterValue(new_backprop);
487   else
488     builder.setInsertionPointAfterValue(new_loss);
489 
490   llvm::SmallVector<mlir::Type, 4> types = {new_loss.getType(),
491                                             new_backprop.getType()};
492   llvm::SmallVector<mlir::Value, 4> values = {new_loss, new_backprop};
493 
494   mlir::TF::IdentityNOp identity_op =
495       builder.create<mlir::TF::IdentityNOp>(loss.getLoc(), types, values);
496 
497   newly_created_ops.insert(identity_op);
498 
499   op->getResult(0).replaceAllUsesExcept(identity_op.getResult(0),
500                                         newly_created_ops);
501   op->getResult(1).replaceAllUsesExcept(identity_op.getResult(1),
502                                         newly_created_ops);
503 
504   // If the op we are expanding isn't being used any more, erase it from the
505   // program.
506   if (op->getResult(0).use_empty() && op->getResult(1).use_empty()) op->erase();
507 
508   return identity_op.getOperation();
509 }
510 
ExpandOp(mlir::Operation * op)511 StatusOr<mlir::Operation*> SoftmaxLossOpSPMDExpander::ExpandOp(
512     mlir::Operation* op) {
513   if (!mlir::isa<mlir::TF::SoftmaxCrossEntropyWithLogitsOp>(op) &&
514       !mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op))
515     return errors::InvalidArgument(
516         "unsupported op for in SoftmaxLossOpSPMDExpander");
517 
518   TF_ASSIGN_OR_RETURN(const Layout& features_layout,
519                       ExtractRequiredLayoutFromOperand(op->getOperand(0)));
520   TF_ASSIGN_OR_RETURN(const Layout& labels_layout,
521                       ExtractRequiredLayoutFromOperand(op->getOperand(1)));
522   TF_ASSIGN_OR_RETURN(const std::vector<Layout>& output_layouts,
523                       ExtractRequiredLayoutFromOp(op));
524 
525   const bool is_sparse =
526       mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op);
527 
528   Layout new_features_layout;
529   Layout new_labels_layout;
530 
531   TF_ASSIGN_OR_RETURN(
532       const Layout internal_layout,
533       MaybeRelayoutInputs(op, is_sparse, features_layout, labels_layout,
534                           output_layouts[0], output_layouts[1],
535                           new_features_layout, new_labels_layout));
536 
537   assert(internal_layout.rank() == 2);
538 
539   // If the class dim is unshared, we can emit a local op.
540   if (Layout::IsUnshardedSpec(internal_layout.dim(1))) {
541     op = InferSPMDExpandedLocalShape(op);
542     return MaybeRelayoutOutputs(op, op->getResult(0), op->getResult(1),
543                                 internal_layout, output_layouts[0],
544                                 output_layouts[1]);
545   }
546 
547   mlir::OpBuilder builder(op);
548   builder.setInsertionPointAfter(op);
549 
550   mlir::Value features = op->getOperand(0);
551   mlir::Value labels = op->getOperand(1);
552   if (is_sparse) {
553     // SparseSoftmaxCrossEntropyWithLogits(features, labels) can be rewritten
554     // as SoftmaxCrossEntropyWithLogits(features, OneHot(labels)).
555     // Note that this is what is done in the XLA kernel for this op.
556     TF_ASSIGN_OR_RETURN(
557         labels, ComputeOneHot(builder, labels, features, internal_layout));
558   }
559 
560   if (features_layout.rank() == 0)
561     return errors::Unimplemented(
562         "scalar values features is not currently supported");
563 
564   // SoftmaxCrossEntropyWithLogitsOp is the same as:
565   // loss = -tf.reduce_sum(labels*tf.LogSoftmax(features), class_dim)
566   // backprop = tf.Softmax(features) - labels
567 
568   mlir::Value shifted_logits;
569   mlir::Value exp_of_shifted_logits;
570   mlir::Value sum_of_exp;
571 
572   // Note that its possible that features is shape [x, 1] and is broadcasted
573   // to match labels. In this case we are doing a bunch of extra work, since
574   // softmax is 1 and log_softmax is 0.
575   TF_RETURN_IF_ERROR(ComputeExpAndSum(builder, features, new_features_layout,
576                                       shifted_logits, exp_of_shifted_logits,
577                                       sum_of_exp));
578 
579   const mlir::Value log_softmax =
580       ComputeLogSoftmax(builder, shifted_logits, sum_of_exp);
581   const mlir::Value softmax =
582       ComputeSoftmax(builder, exp_of_shifted_logits, sum_of_exp);
583 
584   // Mimic the XLA, which uses where/select to ensure that sub is zero when
585   // labels are zero.
586   TF_ASSIGN_OR_RETURN(const mlir::Value features_zero,
587                       GetFPConstOfType(builder, features, 0.0));
588   TF_ASSIGN_OR_RETURN(const mlir::Value labels_zero,
589                       GetFPConstOfType(builder, labels, 0.0));
590 
591   const mlir::Value is_labels_zero =
592       builder
593           .create<mlir::TF::EqualOp>(op->getLoc(), labels, labels_zero,
594                                      builder.getBoolAttr(true))
595           .z();
596   const mlir::Value safe_softmax =
597       builder
598           .create<mlir::TF::SelectV2Op>(op->getLoc(), is_labels_zero,
599                                         features_zero, log_softmax)
600           .output();
601   const mlir::Value prod =
602       builder.create<mlir::TF::MulOp>(op->getLoc(), labels, safe_softmax).z();
603 
604   // Compute the reduce sum
605   TF_ASSIGN_OR_RETURN(
606       mlir::Value positive_loss,
607       ComputeGlobalReduce(builder, prod, internal_layout, /*reduced_dims=*/{1},
608                           kReduceOpAdd, /*keep_dims=*/false));
609 
610   builder.setInsertionPointAfterValue(positive_loss);
611   mlir::Value loss =
612       builder.create<mlir::TF::NegOp>(op->getLoc(), positive_loss).y();
613 
614   mlir::Value backprop =
615       builder.create<mlir::TF::SubOp>(op->getLoc(), softmax, labels);
616 
617   return MaybeRelayoutOutputs(op, loss, backprop, internal_layout,
618                               output_layouts[0], output_layouts[1]);
619 }
620 
621 StatusOr<llvm::DenseMap<int, Layout>>
ComputeLayoutForward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & input_layouts)622 SoftmaxLossOpSPMDExpander::ComputeLayoutForward(
623     mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
624   TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
625   const bool is_sparse =
626       mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op);
627 
628   // loss is sum(-labels * logsoftmax(features)), so the layout is batch
629   // sharded if labels and features are batch sharded on the same mesh dim or
630   // if one is replicated.
631   // backprop is softmax(features) - labels
632 
633   absl::optional<Layout> features_layout;
634   if (input_layouts.find(0) != input_layouts.end())
635     features_layout.emplace(input_layouts.lookup(0));
636   absl::optional<Layout> labels_layout;
637   if (input_layouts.find(1) != input_layouts.end())
638     labels_layout.emplace(input_layouts.lookup(1));
639 
640   // We need to compute shardings for two dimensions: batch and class.
641   std::vector<ShardingSpec> layout_specs(2);
642   layout_specs[0].set_sharding_spec(Layout::kUnshardedDim);
643   layout_specs[1].set_sharding_spec(Layout::kUnshardedDim);
644 
645   // First pick the batch dimension, set it to the batch dimension of features
646   // if it exists otherwise to the batch dimesion of labels.
647   if (features_layout && (features_layout->rank() == 2))
648     layout_specs[0] = features_layout->dim(0);
649   if (labels_layout &&
650       (labels_layout->rank() == 2 ||
651        (is_sparse && labels_layout->rank() == 1)) &&
652       Layout::IsUnshardedSpec(layout_specs[0]))
653     layout_specs[0] = labels_layout->dim(0);
654 
655   // The class dim for features and labels is always the last dim if it
656   // exists.
657   if (features_layout && (features_layout->rank() > 0) &&
658       (layout_specs[0].sharding_spec() !=
659        features_layout->sharding_spec(features_layout->rank() - 1)))
660     layout_specs[1] = features_layout->dim(features_layout->rank() - 1);
661   if (!is_sparse && labels_layout && (labels_layout->rank() > 0) &&
662       Layout::IsUnshardedSpec(layout_specs[1]) &&
663       (layout_specs[0].sharding_spec() !=
664        labels_layout->sharding_spec(labels_layout->rank() - 1)))
665     layout_specs[1] = labels_layout->dim(labels_layout->rank() - 1);
666 
667   TF_ASSIGN_OR_RETURN(const Layout backprop_layout,
668                       Layout::GetLayout(layout_specs, mesh));
669   const Layout loss_layout = backprop_layout.Truncate(/*split_point=*/1);
670 
671   return llvm::DenseMap<int, Layout>({{0, loss_layout}, {1, backprop_layout}});
672 }
673 
674 StatusOr<llvm::DenseMap<int, Layout>>
ComputeLayoutBackward(mlir::Operation * op,const llvm::DenseMap<int,Layout> & output_layouts)675 SoftmaxLossOpSPMDExpander::ComputeLayoutBackward(
676     mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
677   TF_ASSIGN_OR_RETURN(const Mesh mesh, ExtractDeviceMeshEnclosingCluster(op));
678   const bool is_sparse =
679       mlir::isa<mlir::TF::SparseSoftmaxCrossEntropyWithLogitsOp>(op);
680 
681   absl::optional<Layout> loss_layout;
682   if (output_layouts.find(0) != output_layouts.end())
683     loss_layout.emplace(output_layouts.lookup(0));
684   absl::optional<Layout> backprop_layout;
685   if (output_layouts.find(1) != output_layouts.end())
686     backprop_layout.emplace(output_layouts.lookup(1));
687 
688   // We need to compute two possible shardings:
689   // One for the batch dimension and one for the class dimension.
690   std::vector<ShardingSpec> layout_specs(2);
691   layout_specs[0].set_sharding_spec(Layout::kUnshardedDim);
692   layout_specs[1].set_sharding_spec(Layout::kUnshardedDim);
693 
694   // Respect the loss layout if it is set, otherwise use the backprop
695   // layout for the batch_dim.
696   if (loss_layout) layout_specs[0] = loss_layout->dim(0);
697   if (backprop_layout && Layout::IsUnshardedSpec(layout_specs[0]))
698     layout_specs[0] = backprop_layout->dim(0);
699 
700   // Only backprop has class dim so use that if it is available.
701   if (backprop_layout &&
702       backprop_layout->sharding_spec(1) != layout_specs[0].sharding_spec())
703     layout_specs[1] = backprop_layout->dim(1);
704 
705   TF_ASSIGN_OR_RETURN(const auto features_shape,
706                       GetShapeOfValue(op->getOperand(0)));
707   TF_ASSIGN_OR_RETURN(const auto labels_shape,
708                       GetShapeOfValue(op->getOperand(1)));
709   TF_ASSIGN_OR_RETURN(const Layout features_layout,
710                       GetBroadcastedLayout(features_shape, layout_specs, mesh));
711   if (is_sparse) {
712     // Drop the class sharding as the labels don't have class dimension in
713     // the sparse version.
714     layout_specs.resize(1);
715   }
716   TF_ASSIGN_OR_RETURN(const Layout labels_layout,
717                       GetBroadcastedLayout(labels_shape, layout_specs, mesh));
718 
719   return llvm::DenseMap<int, Layout>(
720       {{0, features_layout}, {1, labels_layout}});
721 }
722 
723 }  // namespace dtensor
724 }  // namespace tensorflow
725