/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

namespace m = match;

// Unwraps broadcasts hunting for a constant.  If we find one, checks if the
// constant contains only the given value.
bool IsAll(const HloInstruction* op, int8_t value) {
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), value);
    case HloOpcode::kConstant:
      return op->literal().IsAll(value);
    default:
      return false;
  }
}
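
// Illustration (not part of the original source): IsAll(broadcast(constant(0)),
// 0) returns true, because the broadcast is unwrapped recursively and the
// underlying constant contains only zeros.  Any other opcode (e.g. a
// parameter) conservatively returns false.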

bool IsAll(const HloInstruction* op, const Literal& scalar) {
  CHECK(ShapeUtil::IsScalar(scalar.shape()));
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), scalar);
    case HloOpcode::kConstant:
      return op->literal().IsAll(scalar);
    default:
      return false;
  }
}

bool IsAnyOperandComplex(const HloInstruction* hlo) {
  for (auto operand : hlo->operands()) {
    if (ShapeUtil::ElementIsComplex(operand->shape())) {
      return true;
    }
  }
  return false;
}

bool IsPositive(const HloInstruction* hlo,
                const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kGetTupleElement: {
      const HloInstruction* gte_operand = hlo->operand(0);
      switch (gte_operand->opcode()) {
        case HloOpcode::kCustomCall: {
          const auto& target = gte_operand->custom_call_target();
          return target ==
                     options.get_cudnn_batchnorm_forward_training_metadata() &&
                 hlo->tuple_index() == 2;
        }
        default:
          return false;
      }
    }
    case HloOpcode::kPower:
    case HloOpcode::kAbs:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSqrt:
      return IsPositive(hlo->operand(0), options);

    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1) &&
             IsPositive(hlo->operand(0), options);
    }
    default:
      return false;
  }
}

std::optional<double> GetConstantValue(const HloInstruction* inst) {
  if (!ShapeUtil::IsEffectiveScalar(inst->shape())) {
    return std::nullopt;
  }
  switch (inst->shape().element_type()) {
    case F16:
      return static_cast<float>(inst->literal().GetFirstElement<half>());
    case BF16:
      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
    case F32:
      return inst->literal().GetFirstElement<float>();
    case F64:
      return inst->literal().GetFirstElement<double>();
    default:
      return std::nullopt;
  }
}

static bool IsScalarConstant(const HloInstruction* hlo,
                             const LiteralSlice& literal) {
  return hlo->opcode() == HloOpcode::kConstant &&
         ShapeUtil::IsEffectiveScalar(hlo->shape()) &&
         literal_comparison::Equal(hlo->literal(), literal).ok();
}

static bool IsScalarConstantZero(const HloInstruction* hlo) {
  return IsScalarConstant(hlo, LiteralUtil::Zero(hlo->shape().element_type()));
}

static bool IsScalarConstantNegInf(const HloInstruction* hlo) {
  return !primitive_util::IsComplexType(hlo->shape().element_type()) &&
         IsScalarConstant(hlo,
                          LiteralUtil::MinValue(hlo->shape().element_type()));
}

static bool IsScalarConstantInf(const HloInstruction* hlo) {
  return !primitive_util::IsComplexType(hlo->shape().element_type()) &&
         IsScalarConstant(hlo,
                          LiteralUtil::MaxValue(hlo->shape().element_type()));
}

bool IsNonNegative(const HloInstruction* hlo,
                   const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1);
    }
    case HloOpcode::kAbs: {
      return true;
    }
    case HloOpcode::kBroadcast: {
      return IsNonNegative(hlo->operand(0), options);
    }
    case HloOpcode::kConstant: {
      if (std::optional<double> value = GetConstantValue(hlo)) {
        return *value >= 0.0;
      }
      return false;
    }
    case HloOpcode::kMaximum: {
      return IsNonNegative(hlo->operand(0), options) ||
             IsNonNegative(hlo->operand(1), options);
    }
    case HloOpcode::kSelect: {
      return IsNonNegative(hlo->operand(1), options) &&
             IsNonNegative(hlo->operand(2), options);
    }
    default:
      return IsPositive(hlo, options);
  }
}
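
// Illustration (not part of the original source): multiply(x, x) is
// non-negative for any real x, abs(y) is always non-negative, and
// maximum(a, b) is non-negative as soon as either side is.  A bare parameter
// falls through to IsPositive and returns false.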

// Checks whether `op` is a floating-point constant or broadcast of a constant
// of the form +/- 2^k for some integer k positive, negative, or zero.  Such
// values are interesting because multiplying by a power of 2 just moves the
// exponent.
bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
  // Unwrap the broadcast if necessary.
  const HloInstruction* c;
  if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
      !Match(op, m::Broadcast(m::Constant(&c).WithShape(
                     m::Shape().IsEffectiveScalar())))) {
    return false;
  }
  auto val = [&]() -> std::optional<double> {
    switch (c->shape().element_type()) {
      case BF16:
        return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
      case F16:
        return static_cast<double>(c->literal().GetFirstElement<Eigen::half>());
      case F32:
        return c->literal().GetFirstElement<float>();
      case F64:
        return c->literal().GetFirstElement<double>();
      default:
        // Cowardly refuse to consider complex types.
        return std::nullopt;
    }
  }();
  if (!val) {
    return false;
  }

  int exp;
  double mantissa = std::frexp(*val, &exp);
  // frexp returns a value in the range (-1, -0.5] U [0.5, 1).  A return value
  // of +/-0.5 therefore indicates that the floating point value is a power of
  // 2.
  return mantissa == 0.5 || mantissa == -0.5;
}
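
// Worked example (illustrative, not from the original source):
// std::frexp(0.125, &exp) yields mantissa 0.5 with exp == -2, since
// 0.125 == 0.5 * 2^-2, so 0.125 is accepted.  std::frexp(3.0, &exp) yields
// mantissa 0.75 with exp == 2, so 3.0 is rejected.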

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
bool TransposeIsBitcast(const HloInstruction* transpose) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}
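
// Illustration (an assumed example about layouts, not from the original
// source): transposing f32[4,8]{1,0} with permutation {1,0} into
// f32[8,4]{0,1} swaps the logical dimensions but keeps the physical element
// order, so it is a bitcast; producing f32[8,4]{1,0} instead would physically
// reorder data and is not.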

// Recursive helper for the method below.
HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
    HloInstruction* instr, HloInstruction* operand,
    const AlgebraicSimplifierOptions& options) {
  // Can't replace a chain of copies and reshapes with bitcasts if the compiler
  // used a memory layout which isn't compatible.
  if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
    return operand;
  }

  // If the operand is a copy or reshape, try to see if the operand's operand
  // would produce a bitcast with the initial instruction.
  if (HloOpcode::kReshape == operand->opcode() ||
      HloOpcode::kCopy == operand->opcode()) {
    return BitcastingOperandOfReshapeOrCopyChainHelper(
        instr, operand->mutable_operand(0), options);
  }
  return nullptr;
}

// Returns an operand of a chain of reshapes and copies that is bit-wise
// identical to the first reshape or copy in the chain.
HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
    HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
  if (!options.is_layout_sensitive()) {
    return nullptr;
  }
  CHECK(HloOpcode::kReshape == instr->opcode() ||
        HloOpcode::kCopy == instr->opcode());
  return BitcastingOperandOfReshapeOrCopyChainHelper(
      instr, instr->mutable_operand(0), options);
}

bool IsUnstridedSlice(const HloInstruction* hlo) {
  return absl::c_all_of(hlo->slice_strides(),
                        [](int64_t stride) { return stride == 1; });
}

// Returns true if a pair of back-to-back converts can be eliminated.
bool IsConvertPairNoOp(const HloInstruction* convert) {
  //    [operand_convert]         [convert]
  // (src)->convert-(intermediate)->convert-(dest)
  const HloInstruction* operand_convert = convert->operand(0);
  if (operand_convert->opcode() != HloOpcode::kConvert) {
    return false;
  }
  const PrimitiveType src_type =
      operand_convert->operand(0)->shape().element_type();
  const PrimitiveType intermediate_type =
      operand_convert->shape().element_type();

  return src_type == convert->shape().element_type() &&
         primitive_util::CastPreservesValues(src_type, intermediate_type);
}
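
// Illustration (not from the original source): convert(convert(x: f32) -> f64)
// -> f32 is a no-op, because every f32 value survives the widening to f64 and
// the destination type matches the source.  By contrast, f32 -> f16 -> f32 is
// kept, since the narrowing cast can lose values.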

PrecisionConfig SwapOperandsInDotPrecisionConfig(PrecisionConfig config) {
  CHECK_EQ(config.operand_precision_size(), 2);
  std::swap(config.mutable_operand_precision()->at(0),
            config.mutable_operand_precision()->at(1));
  return config;
}

// Validates the tiling and padding assignments of the two bitcasted shapes;
// returns false if they would make the two shapes non-equivalent.
bool ValidateTilingOfBitcast(
    const Shape& bitcast_shape, const Shape& op_shape,
    const std::vector<std::vector<int64_t>>& operand_map) {
  if (op_shape.layout().tiles().empty() ||
      bitcast_shape.layout().tiles().empty()) {
    return true;
  }
  VLOG(2) << "op shape:" << op_shape.ToString(true) << "\n";
  VLOG(2) << "bitcast shape:" << bitcast_shape.ToString(true) << "\n";
  VLOG(2) << "operand_map size:" << operand_map.size() << "\n";
  auto op_tile = op_shape.layout().tiles(0);
  auto bitcast_tile = bitcast_shape.layout().tiles(0);
  int64_t num_of_tiled_dims = op_tile.dimensions().size(),
          tiled_dim_idx = num_of_tiled_dims - 1;
  if (bitcast_tile.dimensions().size() != num_of_tiled_dims) {
    return false;
  }
  for (auto op_dim : op_shape.layout().minor_to_major()) {
    VLOG(3) << "op_dim = " << op_dim << "\n";
    VLOG(3) << "tiled_dim_idx = " << tiled_dim_idx << "\n";
    VLOG(3) << "tiled_dim_size = " << op_tile.dimension(tiled_dim_idx) << ":"
            << bitcast_tile.dimension(tiled_dim_idx) << "\n";
    if (op_tile.dimensions()[tiled_dim_idx] !=
        bitcast_tile.dimensions()[tiled_dim_idx]) {
      VLOG(2) << "Abort b/c tiled dimension " << op_dim
              << " has different tiling sizes before and after bitcast.\n";
      return false;
    }
    if (operand_map.size() <= op_dim || operand_map[op_dim].empty()) {
      if (op_tile.dimensions()[tiled_dim_idx] != 1) {
        VLOG(2) << "Abort b/c unmapped tiled dimension " << op_dim
                << " does not have tile size 1.\n";
        return false;
      }
    } else if (bitcast_shape.dimensions_size() <= operand_map[op_dim][0]) {
      VLOG(2) << "Abort because the bitcasted dimensions are not aligned!\n";
      return false;
    } else if (bitcast_shape.dimensions(operand_map[op_dim][0]) <
               op_shape.dimensions(op_dim)) {
      if (operand_map[op_dim].size() == 1) {
        VLOG(2) << "Abort b/c a dimension (possibly padded) is shrunk to a "
                   "smaller size.\n";
        return false;
      }
      if (tiled_dim_idx > 0) {
        VLOG(2) << "Abort b/c a non-major tiled dimension is split.\n";
        return false;
      }
      if (bitcast_shape.dimensions(operand_map[op_dim][0]) %
                  op_tile.dimensions()[tiled_dim_idx] !=
              0 ||
          op_shape.dimensions(op_dim) %
                  bitcast_shape.dimensions(operand_map[op_dim][0]) !=
              0) {
        VLOG(2) << "Abort b/c tiled dimension " << op_dim
                << " has been split in bitcasted layout\n";
        return false;
      }
    } else if (bitcast_shape.dimensions(operand_map[op_dim][0]) >
               op_shape.dimensions(op_dim)) {
      if (tiled_dim_idx > 0) {
        VLOG(2) << "Abort b/c a non-major tiled dimension is combined.\n";
        return false;
      }
      if (bitcast_shape.dimensions(operand_map[op_dim][0]) %
                  op_shape.dimensions(op_dim) !=
              0 ||
          op_shape.dimensions(op_dim) % op_tile.dimensions()[tiled_dim_idx] !=
              0) {
        VLOG(2) << "Abort b/c tiled dimension " << op_dim
                << " has been combined in bitcasted layout\n";
        return false;
      }
    }
    if (--tiled_dim_idx < 0) {
      break;
    }
  }
  return true;
}

}  // namespace

void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
  ResetVisitStates();
  computation_ = computation;
}

bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
                                     const AlgebraicSimplifierOptions& options,
                                     AlgebraicSimplifier* simplifier) {
  ResetState(computation);
  TF_CHECK_OK(computation->Accept(this));
  return changed();
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
                                           const HloInstruction* rhs) const {
  return SameShape(lhs->shape(), rhs->shape());
}

bool AlgebraicSimplifierVisitor::SameShape(const Shape& lhs,
                                           const Shape& rhs) const {
  if (options_.is_layout_sensitive()) {
    return ShapeUtil::Equal(lhs, rhs);
  } else {
    return ShapeUtil::Compatible(lhs, rhs);
  }
}

namespace {

bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kMultiply:
    case HloOpcode::kTranspose:
    case HloOpcode::kReshape:
    case HloOpcode::kSelect:
      return true;
    default:
      return false;
  }
}

std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
                                                      float multiplier) {
  switch (target->shape().element_type()) {
    case BF16:
      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
          LiteralUtil::CreateR0<float>(multiplier)));
    case F32:
      return HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(multiplier));
    default:
      LOG(FATAL) << "Unsupported data type: " << target->shape().element_type();
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return OkStatus();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  const int64_t dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64_t lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64_t rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  HloInstruction* target = nullptr;
  // (current node, user, operand_index)
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64_t>> operands;
  std::vector<HloInstruction*> users;

  // Find which side of the dot has the smallest size:
  // operand 0, operand 1, or the output.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    HloInstruction* inst;
    HloInstruction* user;
    int64_t index;
    std::tie(inst, user, index) = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply.
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When a scalar multiply is found, save its scalar value...
      values.push_back(*GetConstantValue(multiplier));
      // ...and remove the scalar multiply op.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      inst = operand;
    }

    // Push the operands of inst.
    int64_t i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(*GetConstantValue(multiplier));

      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Only process instructions with a single user; otherwise moving a scalar
    // multiply to the operands would change the values seen by other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  if (values.empty()) {
    return OkStatus();
  }

  MarkAsChanged();

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new scalar constant for the combined multiplier.
  HloInstruction* new_const_inst =
      target->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = target->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      target->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}
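
// Sketch of the effect (illustrative HLO, not from the original source): with
// scalar multipliers on both dot operands,
//   dot(multiply(A, broadcast(2.0)), multiply(B, broadcast(3.0)))
// is rewritten so the combined constant 6.0 multiplies the smallest of
// {lhs, rhs, output}, e.g. multiply(dot(A, B), broadcast(6.0)) when the dot
// result is the smallest array.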

void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
                                                    HloInstruction* operand) {
  CHECK_EQ(1, instruction->operand_count());
  if (operand == nullptr) {
    operand = instruction->mutable_operand(0);
  }
  CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
           ShapeUtil::ElementsIn(operand->shape()));
  CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
           ShapeUtil::ByteSizeOf(operand->shape()));

  auto bitcast = instruction->AddInstruction(
      HloInstruction::CreateBitcast(instruction->shape(), operand));
  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}

// Replace the old instruction with the new one if they are compatible, i.e.,
// 1. they have the same shape, and
// 2. the replacement will not cause loss of sharding.
bool AlgebraicSimplifierVisitor::ReplaceInstructionIfCompatible(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  return ReplaceInstruction(old_instruction, new_instruction,
                            /*preserve_sharding=*/true)
      .ValueOrDie();
}

bool AlgebraicSimplifierVisitor::ReplaceInstructionIfCompatible(
    HloInstruction* old_instruction,
    absl::Span<HloInstruction* const> new_instructions) {
  if (new_instructions.size() == 1) {
    return ReplaceInstructionIfCompatible(old_instruction, new_instructions[0]);
  }
  CHECK(!new_instructions.empty());
  if (!old_instruction->shape().IsTuple() ||
      old_instruction->shape().tuple_shapes_size() != new_instructions.size()) {
    return false;
  }
  for (int i = 0, n = new_instructions.size(); i < n; ++i) {
    if (!SameShape(old_instruction->shape().tuple_shapes(i),
                   new_instructions[i]->shape())) {
      return false;
    }
  }
  return ReplaceInstruction(old_instruction, MaybeMakeTuple(new_instructions),
                            /*preserve_sharding=*/true)
      .ValueOrDie();
}

Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
  HloInstruction* abs_operand = abs->mutable_operand(0);
  VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
           << " Abs operand is: " << abs_operand->ToString();
  if (IsNonNegative(abs->operand(0), options_)) {
    return ReplaceInstruction(abs, abs_operand);
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(add, lhs)) {
    return OkStatus();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfCompatible(add, rhs)) {
    return OkStatus();
  }

  // Canonicalization: Put constants on the right.  This makes the
  // reassociation rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general.  For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) =>  A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(add, m::Add(m::Add(m::NonConstant(&a),
                               m::Broadcast(m::ConstantScalar(&c1))),
                        m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
        !ShapeUtil::IsScalar(add->shape())) {
      sum_of_constants = add->AddInstruction(
          HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }
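
  // Worked example (illustrative, not from the original source): after the
  // canonicalization above, (A + 1) + 2 matches with C1 = 1 and C2 = 2 and is
  // rewritten to A + (1 + 2); constant folding then reduces it to A + 3.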

  // Convert an add over the full shape into an add over a partial shape when
  // only a portion of the add is effective:
  //             zero (fullshape)   rhs (partialshape)
  // .           |                  |
  // . lhs .    dynamic_update_slice (fullshape)
  // . |         |
  // Add (fullshape)
  //
  // to:
  //              lhs
  //              |
  //             dynamic_slice (partialshape)   rhs (partialshape)
  // .           |                      |
  // . lhs .    add (partial_shape)+----+
  // . |         |
  // dynamic_update_slice (fullshape)
  //
  // This pattern is discovered in control flow V2 gradient updates.
  if (Match(add,
            m::Add(m::Op(&lhs),
                   m::Op(&rhs)
                       .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                       .WithOperand(
                           0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
    const Shape& partial_shape = rhs->operand(1)->shape();
    auto sliced_lhs = lhs->AddInstruction(HloInstruction::CreateDynamicSlice(
        partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
        partial_shape.dimensions()));

    auto add_partial = rhs->AddInstruction(
        HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
                                     sliced_lhs, rhs->mutable_operand(1)));

    auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
        lhs->shape(), lhs, add_partial,
        absl::MakeSpan(rhs->operands()).subspan(2));

    return ReplaceWithNewInstruction(add, std::move(dynamic_update_slice_full));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant or
  //    broadcast of scalar constant and is equal to +/- 2^k for some (possibly
  //    negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
  //
  //    Furthermore, if `enable_floats_are_real` is true, the simplification is
  //    done nonetheless. This might cause numerical differences even if there
  //    is no underflow or overflow.
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      // Make sure we would decrease the number of multiplies.
      (lhs->user_count() == 1 && rhs->user_count() == 1) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 lhs->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }
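
  // Worked example (illustrative, not from the original source): for floats,
  // A*4.0 + B*4.0 becomes (A+B)*4.0 because 4.0 == 2^2 passes
  // IsAllFpConstantPowerOf2, whereas A*3.0 + B*3.0 is left alone unless
  // enable_floats_are_real is set, since multiplying by 3.0 can round.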

  if (options_.is_layout_sensitive()) {
    return OkStatus();
  }

  HloInstruction* lhs_scatter_operand = nullptr;
  HloInstruction* rhs_scatter_operand = nullptr;
  HloInstruction* lhs_scatter_update = nullptr;
  HloInstruction* rhs_scatter_update = nullptr;
  HloInstruction* lhs_scatter_index = nullptr;
  HloInstruction* rhs_scatter_index = nullptr;
  bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand),
                                           m::Op(&lhs_scatter_index),
                                           m::Op(&lhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(lhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand),
                                           m::Op(&rhs_scatter_index),
                                           m::Op(&rhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(rhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  if (rhs_scatter && lhs_scatter) {
    const auto& lhs_dnums = lhs->scatter_dimension_numbers();
    const auto& rhs_dnums = rhs->scatter_dimension_numbers();
    std::optional<int64_t> index_concat_dimension;
    std::optional<int64_t> update_concat_dimension;
    // Don't try to combine scatters of different ranks.
    if (lhs_scatter_index->shape().rank() !=
        rhs_scatter_index->shape().rank()) {
      return OkStatus();
    }

    int64_t first_index_dim = lhs_scatter_index->shape().rank();
    int64_t first_update_dim = lhs_scatter_update->shape().rank();
    // Find a dimension where it is possible to concatenate the indices and
    // updates. This is the first and only non-equal dimension or the first
    // equally sized dimension.
    for (int64_t d = lhs_scatter_index->shape().rank() - 1,
                 update_dim = lhs_scatter_update->shape().rank() - 1;
         d >= 0; --d) {
      if (d == lhs_dnums.index_vector_dim()) {
        continue;
      }
      while (
          absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) {
        --update_dim;
      }
      if (lhs_scatter_index->shape().dimensions(d) ==
          rhs_scatter_index->shape().dimensions(d)) {
        first_index_dim = d;
        first_update_dim = update_dim--;
        continue;
      }
      // More than one dimension of unequal size was found, bail out.
      if (index_concat_dimension) {
        return OkStatus();
      }
      index_concat_dimension = d;
      update_concat_dimension = update_dim--;
    }
    if (!index_concat_dimension) {
      index_concat_dimension = first_index_dim;
      update_concat_dimension = first_update_dim;
    }

    // A scalar scatter will require additional reshapes of the index and
    // update.
    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
      return OkStatus();
    }
    const bool update_concat_is_cheap =
        ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
            ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
        ShapeUtil::ElementsIn(lhs->shape());
    if (!update_concat_is_cheap) {
      return OkStatus();
    }
    const bool same_dimension_numbers =
        lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() &&
        absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(),
                      rhs_dnums.scatter_dims_to_operand_dims()) &&
        absl::c_equal(lhs_dnums.inserted_window_dims(),
                      rhs_dnums.inserted_window_dims()) &&
        absl::c_equal(lhs_dnums.update_window_dims(),
                      rhs_dnums.update_window_dims());
    const bool index_concat_is_safe =
        !lhs->unique_indices() && !rhs->unique_indices() &&
        !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() &&
        !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted();

    Shape lhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(lhs_dnums.update_window_dims(), dim);
        },
        lhs_scatter_update->shape());
    Shape rhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(rhs_dnums.update_window_dims(), dim);
        },
        rhs_scatter_update->shape());
    // Concatenate the indices and updates.
    if (index_concat_is_safe && same_dimension_numbers &&
        index_concat_dimension &&
        lhs_scatter_index->shape().element_type() ==
            rhs_scatter_index->shape().element_type() &&
        ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                          MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                        rhs_scatter_operand));
      TF_ASSIGN_OR_RETURN(HloInstruction * new_index,
                          MakeConcatHlo({lhs_scatter_index, rhs_scatter_index},
                                        *index_concat_dimension));
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_update,
          MakeConcatHlo({lhs_scatter_update, rhs_scatter_update},
                        *update_concat_dimension));
      return ReplaceWithNewInstruction(
          add, HloInstruction::CreateScatter(
                   add->shape(), new_operand, new_index, new_update,
                   lhs->to_apply(), lhs_dnums, false, false));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                      rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs));
    return ReplaceInstruction(add, lhs);
  } else if (rhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, rhs);
  } else if (lhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, lhs);
  }
  return OkStatus();
}

StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
    HloInstruction* conjunction) {
  HloInstruction *lhs, *rhs;
  if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
    return false;
  }
  struct LessThanCompareInfo {  // (LT var constant)
    HloInstruction* var;
    int64_t constant;
  };

  auto get_compare_info =
      [&](HloInstruction* cmp) -> std::optional<LessThanCompareInfo> {
    HloInstruction *lhs, *rhs;
    auto scalar_shape_matcher =
        m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32);
    if (Match(cmp, m::Compare(m::Op(&lhs),
                              m::Constant(&rhs).WithShape(scalar_shape_matcher))
                       .WithComparisonDirection(ComparisonDirection::kLt))) {
      return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
    } else if (Match(
                   cmp,
                   m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher),
                              m::Op(&rhs))
                       .WithComparisonDirection(ComparisonDirection::kGt))) {
      return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}};
    }
    return std::nullopt;
  };

  std::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
  std::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
  if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
    int64_t new_bound = std::min(lhs_info->constant, rhs_info->constant);
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        conjunction,
        HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
                                      MakeScalarLike(lhs_info->var, new_bound),
                                      ComparisonDirection::kLt)));
    return true;
  }
  return false;
}
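
// Illustration (not from the original source): and(lt(x, 5), lt(x, 10)) and
// and(lt(x, 5), gt(10, x)) both collapse to lt(x, 5), since the conjunction
// of two upper bounds on the same variable is the tighter bound min(5, 10).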

Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
  // Simplify logical and
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A && True => A
    VLOG(10) << "trying transform [A && True => A]: "
             << logical_and->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(logical_and, lhs)) {
      return OkStatus();
    }
    // True && A => A
    VLOG(10) << "trying transform [True && A => A]: "
             << logical_and->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfCompatible(logical_and, rhs)) {
      return OkStatus();
    }
  }

  // A && False => False, or A & 0 => 0
  VLOG(10) << "trying transform [A && False => False]: "
           << logical_and->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(logical_and, rhs)) {
    return OkStatus();
  }

  // False && A => False, or 0 & A => 0
  VLOG(10) << "trying transform [False && A => False]: "
           << logical_and->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfCompatible(logical_and, lhs)) {
    return OkStatus();
  }

  // Simplify tautological conjunctions.
  TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
                      TrySimplifyTautologicalCompare(logical_and));
  if (found_tautological_compare) {
    return OkStatus();
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
  // If a bitcast feeds a bitcast, make it a single bitcast.
  // Make sure the whole chain of bitcasts is optimized.
  if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) {
    TF_RETURN_IF_ERROR(HandleBitcast(bitcast->mutable_operand(0)));
  }
  HloInstruction* op;
  if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op));
  }
  // All bitcasts can be eliminated (assuming layout constraints are satisfied).
  ReplaceInstructionIfCompatible(bitcast, bitcast->mutable_operand(0));
  return OkStatus();
}

// Compute a pair of maps for a bitcast operation, specifically between its
// result logical dimensions and the original logical dimensions of the
// operand. The maps are computed by matching the physical layout dimensions
// (minor-to-major) of the operand and the bitcasted result. Overall they
// record how the different logical dimensions of the operand may be combined
// or split in the resulting shape and in which orders they are
// combined/split. The function returns std::nullopt if unsuccessful, e.g.,
// when such a logical dimension mapping cannot be constructed, as when
// bitcasting {4,4} to {2,8}.
std::optional<std::vector<std::vector<int64_t>>>
AlgebraicSimplifierVisitor::ComputeBitcastDimMap(const Shape& bitcast_shape,
                                                 const Shape& operand_shape) {
  std::vector<std::vector<int64_t>> operand_dim_map(
      operand_shape.dimensions_size());
  int64_t bitcast_rank = bitcast_shape.dimensions_size();
  int64_t operand_rank = operand_shape.dimensions_size();
  int64_t cur_bitcast_size = 1, cur_operand_size = 1;
  int64_t operand_pos = -1, operand_dim = -1;
  for (int64_t bitcast_pos = 0; bitcast_pos < bitcast_rank; ++bitcast_pos) {
    int64_t bitcast_dim = bitcast_shape.layout().minor_to_major(bitcast_pos);
    if (operand_pos >= operand_rank) {
      if (bitcast_shape.dimensions(bitcast_dim) != 1) {
        VLOG(3) << "Abort b/c bitcasted size is bigger than operand size.\n";
        return std::nullopt;
      }
      continue;
    }
    CHECK_LT(bitcast_dim, bitcast_shape.dimensions_size());
    int64_t bitcast_dim_size = bitcast_shape.dimensions()[bitcast_dim];
    auto prev_bitcast_size = cur_bitcast_size;
    cur_bitcast_size *= bitcast_dim_size;
    VLOG(2) << "bitcast pos = " << bitcast_pos << "\n";
    VLOG(2) << "bitcast size = " << cur_bitcast_size << "\n";
    if (cur_operand_size < cur_bitcast_size &&
        prev_bitcast_size < cur_operand_size) {
      // Here we are bitcasting (m1,n1) to (m2,n2), with m1 > m2 and m2 * n2
      // < m1, so (m1,n1) is re-partitioned instead of split or combined.
      VLOG(3) << "Abort b/c re-partitioning a group of dimensions is not "
                 "supported.\n";
      return std::nullopt;
    }
    while (operand_pos < operand_rank) {
      if (operand_pos < 0 || cur_operand_size < cur_bitcast_size) {
        VLOG(2) << "operand size < bitcast size\n";
        operand_pos++;
        if (operand_pos >= operand_rank) {
          VLOG(2)
              << "Abort due to size inconsistency: bitcasted size > operand "
                 "size.\n";
          return std::nullopt;
        }
        operand_dim = operand_shape.layout().minor_to_major(operand_pos);
        int64_t op_dim_size = operand_shape.dimensions()[operand_dim];
        cur_operand_size *= op_dim_size;
        VLOG(3) << "operand size = " << cur_operand_size << "\n";
        if (cur_operand_size > cur_bitcast_size &&
            op_dim_size < bitcast_dim_size && operand_pos > 0) {
          // Here we are bitcasting (m1,n1) to (m2,n2), with n1 < n2 and
          // m1 * n1 > m2, so (m1,n1) is re-partitioned instead of split or
          // combined.
          VLOG(3) << "Abort b/c re-partitioning a group of dimensions is not "
                     "supported.\n";
          return std::nullopt;
        }
      }
      CHECK_GE(operand_dim, 0);
      if (operand_shape.dimensions(operand_dim) > 1) {
        CHECK_LT(operand_dim, operand_dim_map.size());
        operand_dim_map[operand_dim].push_back(bitcast_dim);
        VLOG(3) << "operand dim_map[operand_dim] add " << bitcast_dim << " at "
                << operand_dim << "\n";
      }
      if (cur_operand_size >= cur_bitcast_size) {
        VLOG(3) << cur_operand_size << ">=" << cur_bitcast_size << "\n";
        CHECK_GE(operand_dim, 0);
        // If operand_dim is a degenerate one, move on to the next dimension.
        if (operand_shape.dimensions()[operand_dim] == 1) {
          operand_pos++;
        }
        break;
      }
    }
  }
  return operand_dim_map;
}
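
// Worked example (illustrative, not from the original source): bitcasting
// f32[2,3]{1,0} to f32[6]{0} walks both layouts minor-to-major; operand dims
// 1 and 0 are both absorbed into bitcast dim 0, giving operand_dim_map =
// {{0}, {0}}.  Bitcasting {4,4} to {2,8} fails as described above, because
// neither a pure split nor a pure combine can relate the dimensions.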

std::optional<Shape> AlgebraicSimplifierVisitor::ReshapeLayoutDimensions(
    const Shape& original_shape, const Shape& result_shape,
    const std::vector<std::vector<int64_t>>& original_map,
    const std::vector<std::vector<int64_t>>& result_map) {
  auto original_dimensions = original_shape.layout().minor_to_major();
  Shape new_shape = result_shape;
  auto* reshaped_dimensions =
      new_shape.mutable_layout()->mutable_minor_to_major();
  int64_t bitcast_pos = -1;
  for (int64_t op_pos = 0; op_pos < original_dimensions.size(); ++op_pos) {
    int64_t op_dim = original_dimensions[op_pos];
    VLOG(3) << "op_pos = " << op_pos << "\n";
    VLOG(3) << "op_dim = " << op_dim << "\n";
    if (original_map.size() <= op_dim) {
      VLOG(3) << "Skip because original_map has too few dimensions.\n";
      continue;
    }
    auto bit_dims = original_map[op_dim];
    for (int64_t bitcast_dim : bit_dims) {
      if (result_shape.dimensions(bitcast_dim) == 1) {
        // Postpone all degenerate dimensions (those with size 1) to the end.
        continue;
      }
      VLOG(3) << "Add new reshaped dimension:" << bitcast_dim << "\n";
      if (bitcast_pos < 0 ||
          (*reshaped_dimensions)[bitcast_pos] != bitcast_dim) {
        bitcast_pos++;
        // If bitcast_pos has been over-incremented, the new bitcast would
        // have to combine non-contiguous dimensions in op. Abort.
        if (bitcast_pos >= reshaped_dimensions->size()) {
          VLOG(3) << "bitcast pos is over-incremented:" << bitcast_pos << "\n";
          return std::nullopt;
        }
        (*reshaped_dimensions)[bitcast_pos] = bitcast_dim;
      }
      auto op_dims = result_map[bitcast_dim];
      if (op_dims.size() > 1 && op_pos > 0) {
        // Check that op dimensions that are combined into bitcast_dim are not
        // non-contiguous or reordered to be different from how they appear in
        // result_map.
        int64_t op_dim_prev = original_dimensions[op_pos - 1];
        // If the current dimension is not the first being combined into
        // bitcast_dim, or is not contiguous with the previous dimension,
        // abort.
        if (op_dims[0] != op_dim &&
            (original_map[op_dim_prev].empty() ||
             original_map[op_dim_prev][0] != bitcast_dim)) {
          VLOG(2) << "Abort b/c op dimensions that are combined into "
                     "bitcast_dim are not contiguous in the result.\n";
          return std::nullopt;
        }
        // Now perform the dimension re-ordering check in the bitcast.
        for (int i = 0; i < op_dims.size(); ++i) {
          if (op_dims[i] == op_dim_prev) {
            if (i == op_dims.size() - 1 || op_dims[i + 1] != op_dim) {
              VLOG(2) << "Abort b/c op dimensions that are combined into "
                         "bitcast_dim are reordered in the new bitcast.\n";
              return std::nullopt;
            }
          }
        }
      }
    }
  }
  for (int i = 0; i < result_shape.rank(); ++i) {
    if (result_shape.dimensions(i) == 1) {
      bitcast_pos++;
      // Since there is a possibility of over-incrementing bitcast_pos,
      // we need such a check here as well before accessing the vector.
      // Over-incrementing is possible when the result's dimension is
      // smaller than the original dimension.
      if (bitcast_pos >= reshaped_dimensions->size()) {
        VLOG(3) << "bitcast pos is over-incremented:" << bitcast_pos << "\n";
        return std::nullopt;
      }
      (*reshaped_dimensions)[bitcast_pos] = i;
    }
  }
  CHECK_EQ(bitcast_pos + 1, result_shape.rank());
  return new_shape;
}

std::vector<std::vector<int64_t>>
AlgebraicSimplifierVisitor::InvertBitcastDimMap(
    const Shape& original_shape, const Shape& bitcast_shape,
    const std::vector<std::vector<int64_t>>& original_map) {
  std::vector<std::vector<int64_t>> result_map(bitcast_shape.dimensions_size());
  // Invert the operand map into result map.
  for (auto i = 0; i < original_shape.rank(); ++i) {
    auto j = original_shape.layout().minor_to_major(i);
    VLOG(3) << "traversing minor to major (" << i << ")=" << j;
    for (auto k : original_map[j]) {
      VLOG(3) << "setting result_map[" << k << "] = " << j << "\n";
      result_map[k].push_back(j);
    }
  }
  return result_map;
}
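
// Continuing the f32[2,3]{1,0} -> f32[6]{0} illustration from
// ComputeBitcastDimMap (again an assumed example, not from the original
// source): inverting operand_dim_map = {{0}, {0}} visits the operand dims in
// minor-to-major order (1, then 0) and yields result_map = {{1, 0}}, i.e.
// bitcast dim 0 is the combination of operand dims 1 and 0.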
1243 
SwapCopyBitcastCopy(HloInstruction * root_copy)1244 bool AlgebraicSimplifierVisitor::SwapCopyBitcastCopy(
1245     HloInstruction* root_copy) {
1246   if (root_copy->opcode() != HloOpcode::kCopy) {
1247     return false;
1248   }
1249   HloInstruction* bitcast = root_copy->mutable_operand(0);
1250   if (bitcast->opcode() != HloOpcode::kBitcast) {
1251     return false;
1252   }
1253   // All bitcasts above can be collapsed.
1254   HloInstruction* copy = bitcast->mutable_operand(0);
1255   while (copy->opcode() == HloOpcode::kBitcast) {
1256     copy = copy->mutable_operand(0);
1257   }
1258   if (copy->opcode() != HloOpcode::kCopy) {
1259     return false;
1260   }
1261   VLOG(2) << "Processing " << copy->ToString() << "\n"
1262           << bitcast->ToString() << "\n"
1263           << root_copy->ToString() << "\n";
1264   HloInstruction* op = copy->mutable_operand(0);
1265   // Compute a pair of maps between op dimensions and bitcast dimensions.
1266   auto dim_map = ComputeBitcastDimMap(bitcast->shape(), copy->shape());
1267   if (!dim_map.has_value()) {
1268     VLOG(3) << "Failed to compute bitcast map.";
1269     return false;
1270   }
1271   std::vector<std::vector<int64_t>> operand_map = dim_map.value();
1272   if (!ValidateTilingOfBitcast(bitcast->shape(), copy->shape(), operand_map)) {
1273     VLOG(2) << "Abort because bitcast changes tiling assignment.\n";
1274     return false;
1275   }
1276   std::vector<std::vector<int64_t>> result_map =
1277       InvertBitcastDimMap(copy->shape(), bitcast->shape(), operand_map);
1278   if (ValidateTilingOfBitcast(bitcast->shape(), op->shape(), operand_map)) {
1279     auto new_shape = ReshapeLayoutDimensions(op->shape(), bitcast->shape(),
1280                                              operand_map, result_map);
1281     if (!new_shape.has_value() || !IsValidLayout(new_shape.value())) {
1282       return false;
1283     }
1284     auto repl = HloInstruction::CreateUnary(
1285         root_copy->shape(), HloOpcode::kCopy,
1286         bitcast->AddInstruction(
1287             bitcast->CloneWithNewOperands(new_shape.value(), {op})));
1288     VLOG(2) << "Replace with " << repl->operand(0)->ToString() << "\n"
1289             << repl->ToString() << "\n";
1290     TF_CHECK_OK(ReplaceWithNewInstruction(root_copy, std::move(repl)));
1291     return true;
1292   }
1293 
1294   if (ValidateTilingOfBitcast(copy->shape(), root_copy->shape(), result_map)) {
1295     auto new_shape = ReshapeLayoutDimensions(root_copy->shape(), copy->shape(),
1296                                              result_map, operand_map);
1297     if (!new_shape.has_value() || !IsValidLayout(new_shape.value())) {
1298       return false;
1299     }
1300     auto repl = HloInstruction::CreateUnary(
1301         root_copy->shape(), HloOpcode::kBitcast,
1302         bitcast->AddInstruction(
1303             root_copy->CloneWithNewOperands(new_shape.value(), {op})));
1304     VLOG(2) << "Replace with " << repl->operand(0)->ToString() << "\n"
1305             << repl->ToString() << "\n";
1306     TF_CHECK_OK(ReplaceWithNewInstruction(root_copy, std::move(repl)));
1307     return true;
1308   }
1309   return false;
1310 }
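// Illustrative sketch (hypothetical HLO, layouts elided): given the pattern
//   c0 = copy(op)   b = bitcast(c0)   c1 = copy(b)
// the rewrite either folds c0 into the bitcast, yielding copy(bitcast'(op)),
// or moves the final copy below the bitcast, yielding bitcast(copy'(op)),
// depending on which direction passes the tiling validation above.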
1311 
1312 Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
1313     HloInstruction* bitcast) {
1314   TF_ASSIGN_OR_RETURN(bool replaced,
1315                       TrySimplifyTautologicalBitcastConvert(bitcast));
1316   if (replaced) {
1317     return OkStatus();
1318   }
1319   // Eliminate bitcast converts between identical shapes.
1320   ReplaceInstructionIfCompatible(bitcast, bitcast->mutable_operand(0));
1321   return OkStatus();
1322 }
1323 
1324 Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
1325   if (SwapCopyBitcastCopy(copy)) {
1326     return OkStatus();
1327   }
1328   // If a copy feeds a copy, make it a single copy.
1329   HloInstruction* op;
1330   if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
1331     if (ShapeUtil::Equal(op->shape(), copy->shape())) {
1332       return ReplaceInstruction(copy, op);
1333     }
1334     return ReplaceWithNewInstruction(
1335         copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
1336   }
1337   // All copies can be eliminated (assuming layout constraints are satisfied).
1338   if (ReplaceInstructionIfCompatible(copy, copy->mutable_operand(0))) {
1339     return OkStatus();
1340   }
1341 
1342   if (HloInstruction* bitcast_operand =
1343           BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
1344     ReplaceWithBitcast(copy, bitcast_operand);
1345     return OkStatus();
1346   }
1347 
1348   // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical bitcast.
1349   if (copy->operand(0)->opcode() == HloOpcode::kReshape &&
1350       copy->operand(0)->user_count() == 1 &&
1351       ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) {
1352     return ReplaceWithNewInstruction(
1353         copy,
1354         copy->operand(0)->CloneWithNewOperands(
1355             copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)}));
1356   }
1357   return OkStatus();
1358 }
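// For example (illustrative HLO): with equal shapes,
//   c0 = f32[4,4] copy(p0)
//   c1 = f32[4,4] copy(c0)
// collapses to p0 directly; if c1's shape differs from p0's (e.g. a
// different layout), the chain still collapses to a single copy(p0).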
1359 
1360 Status AlgebraicSimplifierVisitor::HandleConcatenate(
1361     HloInstruction* concatenate) {
1362   absl::Span<HloInstruction* const> operands(concatenate->operands());
1363   if (operands.size() == 1) {
1364     // Unary concatenates are useless.
1365     ReplaceInstructionIfCompatible(concatenate, operands[0]);
1366     return OkStatus();
1367   }
1368   // Filter out empty operands.
1369   std::vector<HloInstruction*> nonempty_operands;
1370   for (HloInstruction* operand : operands) {
1371     if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
1372       nonempty_operands.push_back(operand);
1373     }
1374   }
1375   if (nonempty_operands.size() < operands.size()) {
1376     HloInstruction* replacement;
1377     if (nonempty_operands.empty()) {
1378       replacement = operands[0];
1379     } else if (nonempty_operands.size() == 1) {
1380       replacement = nonempty_operands[0];
1381     } else {
1382       replacement =
1383           concatenate->AddInstruction(concatenate->CloneWithNewOperands(
1384               concatenate->shape(), nonempty_operands));
1385     }
1386     VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
1387              << replacement->ToString();
1388     ReplaceInstructionIfCompatible(concatenate, replacement);
1389     return OkStatus();
1390   }
1391 
1392   if (options_.is_layout_sensitive()) {
1393     return OkStatus();
1394   }
1395 
1396   // concat(x, concat(y, z)) -> concat(x, y, z).  We only do this in
1397   // layout-insensitive mode because some backends may have (late,
1398   // layout-sensitive) passes that break up ops with many operands into smaller
1399   // pieces.  This would undo that.
1400   absl::InlinedVector<HloInstruction*, 8> unnested_concat_operands;
1401   for (HloInstruction* operand : operands) {
1402     if (operand->opcode() == HloOpcode::kConcatenate &&
1403         operand->concatenate_dimension() ==
1404             concatenate->concatenate_dimension()) {
1405       for (HloInstruction* instr : operand->operands()) {
1406         unnested_concat_operands.push_back(instr);
1407       }
1408     } else {
1409       unnested_concat_operands.push_back(operand);
1410     }
1411   }
1412   if (unnested_concat_operands.size() != concatenate->operand_count()) {
1413     return ReplaceWithNewInstruction(
1414         concatenate, HloInstruction::CreateConcatenate(
1415                          concatenate->shape(), unnested_concat_operands,
1416                          concatenate->concatenate_dimension()));
1417   }
1418 
1419   // Check if we can merge "adjacent" slice operands which take slices from the
1420   // same other op. For simplicity we only merge unstrided slices.
1421   int64_t concatenate_dimension = concatenate->concatenate_dimension();
1422   std::vector<HloInstruction*> new_operands;
1423   int64_t i = 0;
1424   while (i < operands.size()) {
1425     if (operands[i]->opcode() != HloOpcode::kSlice ||
1426         !IsUnstridedSlice(operands[i])) {
1427       new_operands.push_back(operands[i]);
1428       ++i;
1429       continue;
1430     }
1431     int64_t slice_end = operands[i]->slice_limits(concatenate_dimension);
1432     HloInstruction* slice_operand = operands[i]->mutable_operand(0);
1433     int64_t j = i + 1;
1434     while (j < operands.size()) {
1435       if (operands[j]->opcode() != HloOpcode::kSlice ||
1436           !IsUnstridedSlice(operands[j]) ||
1437           operands[j]->operand(0) != slice_operand ||
1438           operands[j]->slice_starts(concatenate_dimension) != slice_end) {
1439         break;
1440       }
1441       // Check that all the slice_start values are the same in all other
1442       // dimensions. This implies that the slice_limit values are also the same,
1443       // because operands of concatenate need to have the same shape, and we
1444       // already checked that the slices are unstrided.
1445       bool same_other_starts = true;
1446       for (int64_t k = 0; k < operands[j]->slice_starts().size(); ++k) {
1447         if (k == concatenate_dimension) {
1448           continue;
1449         }
1450         if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
1451           same_other_starts = false;
1452           break;
1453         }
1454       }
1455       if (!same_other_starts) {
1456         break;
1457       }
1458       slice_end = operands[j]->slice_limits(concatenate_dimension);
1459       ++j;
1460     }
1461     if (j - i > 1) {
1462       Shape new_slice_shape = operands[i]->shape();
1463       new_slice_shape.set_dimensions(
1464           concatenate_dimension,
1465           slice_end - operands[i]->slice_starts(concatenate_dimension));
1466       simplifier_->UpdateLayout(&new_slice_shape);
1467       auto new_limit_indices = operands[i]->slice_limits();
1468       new_limit_indices[concatenate_dimension] = slice_end;
1469       auto new_slice_op =
1470           operands[i]->AddInstruction(HloInstruction::CreateSlice(
1471               new_slice_shape, slice_operand,
1472               /*start_indices=*/operands[i]->slice_starts(),
1473               /*limit_indices=*/new_limit_indices,
1474               /*strides=*/operands[i]->slice_strides()));
1475       new_operands.push_back(new_slice_op);
1476     } else {
1477       new_operands.push_back(operands[i]);
1478     }
1479     i = j;
1480   }
1481   if (new_operands.size() < operands.size()) {
1482     auto replacement = concatenate->AddInstruction(
1483         concatenate->CloneWithNewOperands(concatenate->shape(), new_operands));
1484     ReplaceInstructionIfCompatible(concatenate, replacement);
1485     return OkStatus();
1486   }
1487 
1488   if (operands.size() == 2) {
1489     // A binary concat with a broadcasted scalar as an operand can be converted
1490     // into a pad, which is simpler to fold into other operations.
1491     bool is_effective_low_pad = Match(
1492         operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
1493     bool is_effective_high_pad = Match(
1494         operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
1495     if (!is_effective_low_pad && !is_effective_high_pad) {
1496       return OkStatus();
1497     }
1498     PaddingConfig padding_config;
1499     for (int64_t dim = 0; dim < operands[0]->shape().rank(); ++dim) {
1500       auto padding_config_dim = padding_config.add_dimensions();
1501       padding_config_dim->set_edge_padding_high(0);
1502       padding_config_dim->set_edge_padding_low(0);
1503       padding_config_dim->set_interior_padding(0);
1504       if (dim == concatenate_dimension) {
1505         if (is_effective_low_pad) {
1506           padding_config_dim->set_edge_padding_low(
1507               operands[0]->shape().dimensions(dim));
1508         } else {
1509           padding_config_dim->set_edge_padding_high(
1510               operands[1]->shape().dimensions(dim));
1511         }
1512       }
1513     }
1514     int64_t operand_to_pad = is_effective_low_pad ? 1 : 0;
1515     int64_t pad_value_operand = is_effective_low_pad ? 0 : 1;
1516     HloInstruction* pad = concatenate->AddInstruction(HloInstruction::CreatePad(
1517         concatenate->shape(), operands[operand_to_pad],
1518         operands[pad_value_operand]->mutable_operand(0), padding_config));
1519     return ReplaceInstruction(concatenate, pad);
1520   }
1521 
1522   if (absl::c_count(operands, operands[0]) == operands.size() &&
1523       operands[0]->shape().dimensions(concatenate_dimension) == 1) {
1524     Shape new_shape = operands[0]->shape();
1525     DimensionVector broadcast_dims;
1526     for (int64_t i = 0; i < new_shape.rank(); ++i) {
1527       if (i == concatenate_dimension) {
1528         continue;
1529       }
1530       broadcast_dims.push_back(i);
1531     }
1532     new_shape.DeleteDimension(concatenate_dimension);
1533     return ReplaceInstruction(
1534         concatenate,
1535         MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(),
1536                          broadcast_dims, concatenate->shape()));
1537   }
1538   return OkStatus();
1539 }
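// For example (illustrative HLO): adjacent unstrided slices of the same
// operand,
//   s0 = f32[2,8] slice(p0), slice={[0:2], [0:8]}
//   s1 = f32[3,8] slice(p0), slice={[2:5], [0:8]}
//   c  = f32[5,8] concatenate(s0, s1), dimensions={0}
// merge into c = f32[5,8] slice(p0), slice={[0:5], [0:8]}, since s1 starts
// exactly where s0 ends along the concatenate dimension.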
1540 
1541 StatusOr<bool>
1542 AlgebraicSimplifierVisitor::TrySimplifyTautologicalBitcastConvert(
1543     HloInstruction* bitcast) {
1544   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcastConvert);
1545   PrimitiveType outer_to = bitcast->shape().element_type();
1546   HloInstruction* concat = bitcast->mutable_operand(0);
1547   if (concat->opcode() != HloOpcode::kConcatenate) {
1548     return false;
1549   }
1550   std::vector<HloInstruction*> outer_inputs;
1551   std::vector<HloInstruction*> to_remove_bitcasts;
1552   for (int i = 0; i < concat->operand_count(); i++) {
1553     HloInstruction* in = concat->mutable_operand(i);
1554     if (in->opcode() != HloOpcode::kBitcastConvert ||
1555         in->operand(0)->shape().element_type() != outer_to) {
1556       return false;
1557     }
1558     outer_inputs.push_back(in->mutable_operand(0));
1559     to_remove_bitcasts.push_back(in);
1560   }
1561 
1562   const int64_t concat_dim = concat->concatenate_dimension();
1563   TF_ASSIGN_OR_RETURN(HloInstruction * new_concat,
1564                       MakeConcatHlo(outer_inputs, concat_dim));
1565   TF_RETURN_IF_ERROR(ReplaceInstruction(bitcast, new_concat));
1566 
1567   return true;
1568 }
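// For example (illustrative HLO): when every concat input converts back
// from the outer target type,
//   b0 = s32[4] bitcast-convert(a0)   (a0 = f32[4])
//   b1 = s32[4] bitcast-convert(a1)   (a1 = f32[4])
//   c  = s32[8] concatenate(b0, b1), dimensions={0}
//   r  = f32[8] bitcast-convert(c)
// simplifies to r = f32[8] concatenate(a0, a1), dimensions={0}.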
1569 
1570 static HloInstruction* BuildTupleConstant(HloComputation* computation,
1571                                           const LiteralSlice& literal,
1572                                           AlgebraicSimplifier* simplifier) {
1573   if (literal.shape().IsTuple()) {
1574     std::vector<HloInstruction*> elems;
1575     elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
1576     for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
1577       elems.push_back(BuildTupleConstant(
1578           computation, LiteralSlice(literal, {i}), simplifier));
1579     }
1580     return computation->AddInstruction(HloInstruction::CreateTuple(elems));
1581   } else {
1582     return computation->AddInstruction(
1583         simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
1584   }
1585 }
1586 
1587 Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
1588   // Tuple constants aren't directly supported by any backend. Expand them into
1589   // explicit Tuple instructions.
1590   if (constant->shape().IsTuple()) {
1591     return ReplaceInstruction(
1592         constant,
1593         BuildTupleConstant(computation_, constant->literal(), simplifier_));
1594   }
1595 
1596   if (constant->shape().element_type() == TOKEN) {
1597     return OkStatus();
1598   }
1599 
1600   // If a literal is all the same element, replace it with a scalar broadcast.
1601   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1602       constant->literal().IsAllFirst()) {
1603     Literal unique_scalar(
1604         LiteralUtil::GetFirstScalarLiteral(constant->literal()));
1605     HloInstruction* scalar = constant->AddInstruction(
1606         simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
1607     return ReplaceWithNewInstruction(
1608         constant,
1609         HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
1610   }
1611 
1612   // If a literal is an increasing sequence from zero, replace it with an iota.
1613   if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
1614       constant->literal().IsR1Iota()) {
1615     return ReplaceWithNewInstruction(
1616         constant, HloInstruction::CreateIota(constant->shape(), 0));
1617   }
1618 
1619   if (std::optional<int64_t> stride = constant->literal().IsR1StridedIota()) {
1620     // Replace the constant with iota * stride.
1621     HloInstruction* stride_hlo = MakeScalarLike(constant, *stride);
1622     HloInstruction* iota = constant->AddInstruction(
1623         HloInstruction::CreateIota(constant->shape(), 0));
1624     return ReplaceWithNewInstruction(
1625         constant,
1626         HloInstruction::CreateBinary(constant->shape(), HloOpcode::kMultiply,
1627                                      iota, stride_hlo));
1628   }
1629 
1630   return OkStatus();
1631 }
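// For example (illustrative): constant({7, 7, 7, 7}) becomes
// broadcast(constant(7)), constant({0, 1, 2, 3}) becomes iota, and
// constant({0, 3, 6, 9}) becomes multiply(iota, broadcast(constant(3)))
// via the strided-iota case above.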
1632 
1633 Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
1634   HloInstruction *lhs, *rhs;
1635   CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
1636   // A - 0 => A
1637   VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
1638   if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(sub, lhs)) {
1639     return OkStatus();
1640   }
1641 
1642   // Canonicalize subtraction of a constant to addition.
1643   VLOG(10) << "trying transform [A - Const => A + (-Const)]";
1644   if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
1645       Match(sub, m::Subtract(m::NonConstant(&lhs),
1646                              m::Broadcast(m::Constant(&rhs))))) {
1647     HloInstruction* negative_const = rhs->AddInstruction(
1648         HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
1649     if (const HloInstruction* broadcast =
1650             DynCast<HloBroadcastInstruction>(sub->operand(1))) {
1651       negative_const = rhs->AddInstruction(HloInstruction::CreateBroadcast(
1652           broadcast->shape(), negative_const, broadcast->dimensions()));
1653     }
1654     return ReplaceWithNewInstruction(
1655         sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
1656                                           negative_const));
1657   }
1658 
1659   // A - A => 0 for integer A.
1660   VLOG(10) << "trying transform [A - A => 0] for integer A.";
1661   if (lhs == rhs && ShapeUtil::ElementIsIntegral(sub->shape())) {
1662     return ReplaceInstruction(sub, MakeScalarLike(sub, 0));
1663   }
1664 
1665   return OkStatus();
1666 }
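// For example (illustrative): subtract(x, constant(3)) becomes
// add(x, constant(-3)), so later passes only need to reason about addition.
// The A - A => 0 fold is restricted to integers because for floating point
// x - x need not be zero (e.g. when x is NaN or infinite).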
1667 namespace {
1668 template <typename T>
1669 Status InvertConstant(const HloInstruction& constant, Literal* result) {
1670   return result->Populate<T>([&](absl::Span<const int64_t> indices) {
1671     return T{1.0} / constant.literal().Get<T>(indices);
1672   });
1673 }
1674 
1675 template <typename T>
1676 std::unique_ptr<HloInstruction> TryDivideToShift(
1677     HloInstruction* divide, HloComputation* computation,
1678     AlgebraicSimplifier* simplifier) {
1679   HloInstruction *a, *b, *c;
1680   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1681 
1682   if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
1683       !Match(b, m::ConstantEffectiveScalar(&c)) &&
1684       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
1685     return nullptr;
1686   }
1687 
1688   if (ShapeUtil::ElementIsSigned(divide->shape())) {
1689     int64_t b_value = c->literal().GetFirstElement<T>();
1690     if (b_value > 0 && absl::has_single_bit(static_cast<uint64_t>(b_value))) {
1691       // Handle negative dividends by negating the result of the division.
1692       HloInstruction* zero_like_a = MakeScalarLike(a, 0);
1693 
1694       Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
1695       simplifier->UpdateLayout(&changed_shape);
1696       auto* dividend_is_negative =
1697           divide->AddInstruction(HloInstruction::CreateCompare(
1698               changed_shape, a, zero_like_a, ComparisonDirection::kLt));
1699 
1700       auto* negated_dividend = divide->AddInstruction(
1701           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
1702 
1703       auto* abs_dividend = divide->AddInstruction(HloInstruction::CreateTernary(
1704           a->shape(), HloOpcode::kSelect, dividend_is_negative,
1705           negated_dividend, a));
1706 
1707       auto* quotient = divide->AddInstruction(HloInstruction::CreateBinary(
1708           divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
1709           MakeScalarLike(abs_dividend, Log2Floor<uint64_t>(b_value))));
1710 
1711       auto* negated_quotient =
1712           divide->AddInstruction(HloInstruction::CreateUnary(
1713               quotient->shape(), HloOpcode::kNegate, quotient));
1714 
1715       return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
1716                                            dividend_is_negative,
1717                                            negated_quotient, quotient);
1718     }
1719   } else {
1720     uint64_t b_value = c->literal().GetFirstElement<T>();
1721     if (absl::has_single_bit(b_value)) {
1722       return HloInstruction::CreateBinary(
1723           divide->shape(), HloOpcode::kShiftRightLogical, a,
1724           MakeScalarLike(a, Log2Floor(b_value)));
1725     }
1726   }
1727 
1728   return nullptr;
1729 }
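// Illustrative sketch of the rewrite above (hypothetical values): for u32 x,
// x / 8 becomes shift-right-logical(x, 3); for s32 x the shift is applied to
// |x| and the sign restored with selects, matching C-style truncating
// division for negative dividends.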
1730 }  // namespace
1731 
1732 Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
1733   HloInstruction *a, *b, *c, *d;
1734   CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
1735   // A/1 => A
1736   VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
1737   if (IsAll(b, 1) && ReplaceInstructionIfCompatible(divide, a)) {
1738     return OkStatus();
1739   }
1740 
1741   // A / B => A >> log2(B) if B is a power of 2.
1742   switch (divide->shape().element_type()) {
1743     case S8:
1744       if (std::unique_ptr<HloInstruction> shift =
1745               TryDivideToShift<int8_t>(divide, computation_, simplifier_)) {
1746         return ReplaceWithNewInstruction(divide, std::move(shift));
1747       }
1748       break;
1749     case S16:
1750       if (std::unique_ptr<HloInstruction> shift =
1751               TryDivideToShift<int16_t>(divide, computation_, simplifier_)) {
1752         return ReplaceWithNewInstruction(divide, std::move(shift));
1753       }
1754       break;
1755     case S32:
1756       if (std::unique_ptr<HloInstruction> shift =
1757               TryDivideToShift<int32_t>(divide, computation_, simplifier_)) {
1758         return ReplaceWithNewInstruction(divide, std::move(shift));
1759       }
1760       break;
1761     case S64:
1762       if (std::unique_ptr<HloInstruction> shift =
1763               TryDivideToShift<int64_t>(divide, computation_, simplifier_)) {
1764         return ReplaceWithNewInstruction(divide, std::move(shift));
1765       }
1766       break;
1767     case U8:
1768       if (std::unique_ptr<HloInstruction> shift =
1769               TryDivideToShift<uint8_t>(divide, computation_, simplifier_)) {
1770         return ReplaceWithNewInstruction(divide, std::move(shift));
1771       }
1772       break;
1773     case U16:
1774       if (std::unique_ptr<HloInstruction> shift =
1775               TryDivideToShift<uint16_t>(divide, computation_, simplifier_)) {
1776         return ReplaceWithNewInstruction(divide, std::move(shift));
1777       }
1778       break;
1779     case U32:
1780       if (std::unique_ptr<HloInstruction> shift =
1781               TryDivideToShift<uint32_t>(divide, computation_, simplifier_)) {
1782         return ReplaceWithNewInstruction(divide, std::move(shift));
1783       }
1784       break;
1785     case U64:
1786       if (std::unique_ptr<HloInstruction> shift =
1787               TryDivideToShift<uint64_t>(divide, computation_, simplifier_)) {
1788         return ReplaceWithNewInstruction(divide, std::move(shift));
1789       }
1790       break;
1791     default:
1792       break;
1793   }
1794 
1795   Shape* shape;
1796   // exp(A)/exp(B) => exp(A-B)
1797   if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
1798                         .WithShape(m::Shape(&shape)))) {
1799     VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
1800     HloInstruction* subtract = divide->AddInstruction(
1801         HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
1802     return ReplaceWithNewInstruction(
1803         divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
1804   }
1805 
1806   // A/exp(B) => A*exp(-B)
1807   if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
1808     VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
1809     HloInstruction* negate = divide->AddInstruction(
1810         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
1811     HloInstruction* new_exp = divide->mutable_operand(1)->AddInstruction(
1812         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
1813     return ReplaceWithNewInstruction(
1814         divide, HloInstruction::CreateBinary(divide->shape(),
1815                                              HloOpcode::kMultiply, a, new_exp));
1816   }
1817 
1818   // A/pow(B,C) => A*pow(B,-C)
1819   if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
1820     VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
1821     // The output shape of the created negate operator should be the same as the
1822     // input.
1823     const Shape& negate_shape = c->shape();
1824     HloInstruction* negate = divide->mutable_operand(1)->AddInstruction(
1825         HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
1826     // And the power operator should retain the output shape of the old one.
1827     const Shape& new_power_shape = b->shape();
1828     HloInstruction* new_power =
1829         divide->mutable_operand(1)->AddInstruction(HloInstruction::CreateBinary(
1830             new_power_shape, HloOpcode::kPower, b, negate));
1831     return ReplaceWithNewInstruction(
1832         divide, HloInstruction::CreateBinary(
1833                     divide->shape(), HloOpcode::kMultiply, a, new_power));
1834   }
1835 
1836   // A/sqrt(B) => A*rsqrt(B).
1837   if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
1838     auto* rsqrt = divide->mutable_operand(1)->AddInstruction(
1839         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
1840     return ReplaceWithNewInstruction(
1841         divide, HloInstruction::CreateBinary(rsqrt->shape(),
1842                                              HloOpcode::kMultiply, a, rsqrt));
1843   }
1844 
1845   // A/rsqrt(B) => A*sqrt(B).
1846   if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
1847     auto* sqrt = divide->mutable_operand(1)->AddInstruction(
1848         HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
1849     return ReplaceWithNewInstruction(
1850         divide, HloInstruction::CreateBinary(sqrt->shape(),
1851                                              HloOpcode::kMultiply, a, sqrt));
1852   }
1853 
1854   // Simplifying integral division would produce unexpected results.
1855   if (ShapeUtil::ElementIsIntegral(divide->shape())) {
1856     return OkStatus();
1857   }
1858 
1859   // A / Const => A * (1 / Const)
1860   //
1861   // (Backends can do this transformation, but generally only if the constant is
1862   // a scalar.)
1863   if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
1864       (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
1865     Shape result_shape = c->literal().shape();
1866     Literal new_literal(result_shape);
1867     switch (result_shape.element_type()) {
1868       case F16:
1869         TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
1870         break;
1871       case F32:
1872         TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
1873         break;
1874       case BF16:
1875         TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
1876         break;
1877       case F64:
1878         TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
1879         break;
1880       case C64:
1881         TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
1882         break;
1883       case C128:
1884         TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
1885         break;
1886       default:
1887         return OkStatus();
1888     }
1889     auto inverse = c->AddInstruction(
1890         simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
1891     if (b != c) {
1892       inverse = b->AddInstruction(HloInstruction::CreateBroadcast(
1893           b->shape(), inverse, b->dimensions()));
1894     }
1895     TF_ASSIGN_OR_RETURN(auto new_divide,
1896                         MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
1897     return ReplaceInstruction(divide, new_divide);
1898   }
1899 
1900   // (A / B) / (C / D)  =>  (A / B)*(D / C) => (A * D) / (B * C)
1901   if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
1902                               m::Divide(m::Op(&c), m::Op(&d))))) {
1903     TF_ASSIGN_OR_RETURN(auto a_times_d,
1904                         MakeBinaryHlo(HloOpcode::kMultiply, a, d));
1905     TF_ASSIGN_OR_RETURN(auto b_times_c,
1906                         MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1907     TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
1908                                                        a_times_d, b_times_c));
1909 
1910     return ReplaceInstruction(divide, new_divide);
1911   }
1912 
1913   // (A / B) / C => A / (B * C)
1914   if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
1915     TF_ASSIGN_OR_RETURN(auto b_times_c,
1916                         MakeBinaryHlo(HloOpcode::kMultiply, b, c));
1917     TF_ASSIGN_OR_RETURN(auto new_divide,
1918                         MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
1919     return ReplaceInstruction(divide, new_divide);
1920   }
1921 
1922   // A / (B / C) => (A*C) / B
1923   if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
1924     TF_ASSIGN_OR_RETURN(auto a_times_c,
1925                         MakeBinaryHlo(HloOpcode::kMultiply, a, c));
1926     TF_ASSIGN_OR_RETURN(auto new_divide,
1927                         MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
1928     return ReplaceInstruction(divide, new_divide);
1929   }
1930 
1931   // If X is a convert from pred, then
1932   // X / broadcast(Y) => broadcast(1/Y) * X
1933   if (Match(divide,
1934             m::Divide(
1935                 m::Convert(&a,
1936                            m::Op().WithShape(m::Shape().WithElementType(PRED))),
1937                 m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
1938     TF_ASSIGN_OR_RETURN(
1939         auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
1940     auto recip_bcast = divide->mutable_operand(1)->AddInstruction(
1941         HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
1942     TF_ASSIGN_OR_RETURN(auto mul,
1943                         MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
1944     return ReplaceInstruction(divide, mul);
1945   }
1946 
1947   return OkStatus();
1948 }
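// For example (illustrative): divide(x, constant({2.0, 4.0})) becomes
// multiply(x, constant({0.5, 0.25})) via the constant-inversion case, and
// divide(divide(a, b), c) becomes divide(a, multiply(b, c)). These rewrites
// run only for non-integral element types, per the early return above.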
1949 
1950 StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
1951     HloInstruction* dot) {
1952   const Shape& lhs_shape = dot->operand(0)->shape();
1953   int64_t num_degenerate_lhs_dims = 0;
1954   std::vector<int64_t> lhs_dimension_map(lhs_shape.rank(), -1);
1955   for (int64_t i = 0; i < lhs_shape.rank(); ++i) {
1956     if (lhs_shape.dimensions(i) == 1) {
1957       ++num_degenerate_lhs_dims;
1958     } else {
1959       lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
1960     }
1961   }
1962 
1963   const Shape& rhs_shape = dot->operand(1)->shape();
1964   int64_t num_degenerate_rhs_dims = 0;
1965   std::vector<int64_t> rhs_dimension_map(rhs_shape.rank(), -1);
1966   for (int64_t i = 0; i < rhs_shape.rank(); ++i) {
1967     if (rhs_shape.dimensions(i) == 1) {
1968       ++num_degenerate_rhs_dims;
1969     } else {
1970       rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
1971     }
1972   }
1973   if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
1974     return false;
1975   }
1976   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
1977   DotDimensionNumbers new_dnums;
1978   for (int64_t dim : dnums.lhs_batch_dimensions()) {
1979     int64_t new_dim = lhs_dimension_map[dim];
1980     if (new_dim != -1) {
1981       new_dnums.add_lhs_batch_dimensions(new_dim);
1982     }
1983   }
1984   for (int64_t dim : dnums.lhs_contracting_dimensions()) {
1985     int64_t new_dim = lhs_dimension_map[dim];
1986     if (new_dim != -1) {
1987       new_dnums.add_lhs_contracting_dimensions(new_dim);
1988     }
1989   }
1990 
1991   for (int64_t dim : dnums.rhs_batch_dimensions()) {
1992     int64_t new_dim = rhs_dimension_map[dim];
1993     if (new_dim != -1) {
1994       new_dnums.add_rhs_batch_dimensions(new_dim);
1995     }
1996   }
1997   for (int64_t dim : dnums.rhs_contracting_dimensions()) {
1998     int64_t new_dim = rhs_dimension_map[dim];
1999     if (new_dim != -1) {
2000       new_dnums.add_rhs_contracting_dimensions(new_dim);
2001     }
2002   }
2003 
2004   HloInstruction* new_lhs =
2005       num_degenerate_lhs_dims > 0
2006           ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
2007                 ShapeUtil::DropDegenerateDimensions(lhs_shape),
2008                 dot->mutable_operand(0)))
2009           : dot->mutable_operand(0);
2010   HloInstruction* new_rhs =
2011       num_degenerate_rhs_dims > 0
2012           ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
2013                 ShapeUtil::DropDegenerateDimensions(rhs_shape),
2014                 dot->mutable_operand(1)))
2015           : dot->mutable_operand(1);
2016   TF_ASSIGN_OR_RETURN(
2017       auto new_dot,
2018       MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
2019                  /*preferred_element_type=*/dot->shape().element_type()));
2020   if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
2021     TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
2022   } else {
2023     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2024         dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
2025   }
2026   return true;
2027 }
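// For example (illustrative HLO): for
//   dot = f32[1,4,16] dot(f32[1,4,8] lhs, f32[8,16] rhs),
//         lhs_contracting_dims={2}, rhs_contracting_dims={0}
// the degenerate lhs dim 0 is reshaped away, an f32[4,16] dot is built, and
// the result is reshaped back to f32[1,4,16].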
2028 
2029 StatusOr<bool> AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands(
2030     HloInstruction* dot) {
2031   const int64_t rank = dot->shape().rank();
2032   const auto& dnums = dot->dot_dimension_numbers();
2033   HloInstruction* lhs = dot->mutable_operand(0);
2034   HloInstruction* rhs = dot->mutable_operand(1);
2035 
2036   // lhs and rhs must apply the same permutation.
2037   if (lhs->opcode() != HloOpcode::kTranspose ||
2038       rhs->opcode() != HloOpcode::kTranspose ||
2039       lhs->dimensions() != rhs->dimensions()) {
2040     return false;
2041   }
2042   absl::Span<const int64_t> permutation = lhs->dimensions();
2043 
2044   // Dot must be "somewhat canonical": batch dimensions at the beginning, one
2045   // contracting dimension, and one non-contracting dim.
2046   if (absl::MakeSpan(dnums.lhs_batch_dimensions()) !=
2047           absl::MakeSpan(dnums.rhs_batch_dimensions()) ||
2048       dnums.lhs_contracting_dimensions_size() != 1 ||
2049       dnums.rhs_contracting_dimensions_size() != 1 ||
2050       dnums.lhs_contracting_dimensions(0) != rank - 1 ||
2051       dnums.rhs_contracting_dimensions(0) != rank - 2 ||
2052       rank != dnums.lhs_batch_dimensions_size() + 2) {
2053     return false;
2054   }
2055 
2056   // The last two elements of the permutation must be either [rank-2, rank-1]
2057   // (i.e. no permutation) or [rank-1, rank-2].  Otherwise, this means that
2058   // we're permuting batch dimensions with the non-batch dimensions, which isn't
2059   // allowed.
2060   //
2061   // If the permutation ends with [rank - 1, rank - 2] then we're going to flip
2062   // the order of dot operands to dot(b,a).  Otherwise it stays dot(a,b).
2063   bool reorder_operands;
2064   if (permutation.subspan(rank - 2) ==
2065       std::array<int64_t, 2>{rank - 2, rank - 1}) {
2066     reorder_operands = false;
2067   } else if (permutation.subspan(rank - 2) ==
2068              std::array<int64_t, 2>{rank - 1, rank - 2}) {
2069     reorder_operands = true;
2070   } else {
2071     return false;
2072   }
2073 
2074   HloInstruction* new_lhs =
2075       reorder_operands ? rhs->mutable_operand(0) : lhs->mutable_operand(0);
2076   HloInstruction* new_rhs =
2077       reorder_operands ? lhs->mutable_operand(0) : rhs->mutable_operand(0);
2078   auto new_dot = dot->AddInstruction(HloInstruction::CreateDot(
2079       ShapeUtil::PermuteDimensions(permutation, dot->shape()), new_lhs, new_rhs,
2080       dnums,
2081       reorder_operands
2082           ? SwapOperandsInDotPrecisionConfig(dot->precision_config())
2083           : dot->precision_config()));
2084   TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
2085       dot,
2086       HloInstruction::CreateTranspose(dot->shape(), new_dot, permutation)));
2087   return true;
2088 }
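// For example (illustrative): with one batch dim and permutation {0, 2, 1},
// dot(transpose(a), transpose(b)) becomes transpose(dot(b, a)): the
// permutation swaps the two minor dims, so the operands are exchanged and
// the transpose moves to the dot's result.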
2089 
2090 StatusOr<HloInstruction*>
2091 AlgebraicSimplifierVisitor::NormalizeDotOperandToBatchMajorAndContractingMinor(
2092     HloInstruction* dot_operand, absl::Span<const int64_t> batch_dimensions,
2093     absl::Span<const int64_t> contracting_dimensions) {
2094   std::vector<int64_t> transpose_dimensions(batch_dimensions.begin(),
2095                                             batch_dimensions.end());
2096   for (int64_t i = 0; i < dot_operand->shape().rank(); ++i) {
2097     if (!(absl::c_linear_search(batch_dimensions, i) ||
2098           absl::c_linear_search(contracting_dimensions, i))) {
2099       transpose_dimensions.push_back(i);
2100     }
2101   }
2102   transpose_dimensions.insert(transpose_dimensions.end(),
2103                               contracting_dimensions.begin(),
2104                               contracting_dimensions.end());
2105   if (absl::c_is_sorted(transpose_dimensions)) {
2106     return dot_operand;
2107   }
2108   return MakeTransposeHlo(dot_operand, transpose_dimensions);
2109 }
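// For example (illustrative): for a rank-3 operand with
// batch_dimensions={2} and contracting_dimensions={0}, this helper returns
// transpose(operand, dimensions={2, 1, 0}): batch major, free dims in the
// middle, contracting minor.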
2110 
2111 HloInstruction* AlgebraicSimplifierVisitor::AddReduce(
2112     HloInstruction* hlo, absl::Span<const int64_t> dims, PrimitiveType type) {
2113   HloInstruction* zero =
2114       computation_->AddInstruction(simplifier_->CreateConstantWithLayoutUpdated(
2115           LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
2116   HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(type);
2117   Shape shape = ShapeUtil::DeleteDimensions(dims, hlo->shape());
2118   simplifier_->UpdateLayout(&shape);
2119   return computation_->AddInstruction(HloInstruction::CreateReduce(
2120       shape, hlo, zero, dims, AddReduce_computation));
2121 }
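// For example (illustrative): AddReduce(hlo, /*dims=*/{1}, F32) on an
// f32[4,8] input emits reduce(hlo, 0), dimensions={1}, with a scalar add
// computation, yielding an f32[4] result.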
2122 
2123 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
2124     HloInstruction* dot) {
2125   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
2126   if (dnums.lhs_contracting_dimensions_size() != 1 ||
2127       dnums.lhs_batch_dimensions_size() != 0 ||
2128       dot->shape().dimensions_size() != 2) {  // dot output 2D
2129     return nullptr;
2130   }
2131 
2132   const int64_t lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
2133   const int64_t rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
2134   HloInstruction *lhs, *rhs;
2135   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
2136 
2137   TF_ASSIGN_OR_RETURN(
2138       HloInstruction * optimized_lhs_concat,
2139       OptimizeDotOfConcatHelper(dot, lhs, lhs_contracting_dim, rhs,
2140                                 rhs_contracting_dim, /*swapped=*/false));
2141   if (optimized_lhs_concat) {
2142     return optimized_lhs_concat;
2143   }
2144 
2145   return OptimizeDotOfConcatHelper(dot, rhs, rhs_contracting_dim, lhs,
2146                                    lhs_contracting_dim, /*swapped=*/true);
2147 }
2148 
2149 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
2150     HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim,
2151     HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped) {
2152   bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
2153                       lhs->concatenate_dimension() == lhs_contracting_dim &&
2154                       rhs->opcode() == HloOpcode::kConstant;
2155   if (!can_optimize) {
2156     return nullptr;
2157   }
2158 
2159   // We're replacing this:
2160   //
2161   //   +-----+-----+-----+      +-------------------+
2162   //   |     |     |     |      |                   |
2163   //   |     |     |     |      |        R_0        |
2164   //   |     |     |     |      |                   |
2165   //   |     |     |     |      +-------------------+
2166   //   |     |     |     |      |                   |
2167   //   | L_0 | L_1 | L_2 |   *  |        R_1        |
2168   //   |     |     |     |      |                   |
2169   //   |     |     |     |      +-------------------+
2170   //   |     |     |     |      |                   |
2171   //   |     |     |     |      |        R_2        |
2172   //   |     |     |     |      |                   |
2173   //   +-----+-----+-----+      +-------------------+
2174   //
2175   // with this:
2176   //
2177   // [Sum over i]
2178   //
2179   //   +-----+     +-------------------+
2180   //   |     |     |                   |
2181   //   |     |  *  |        R_i        |
2182   //   |     |     |                   |
2183   //   |     |     +-------------------+
2184   //   |     |
2185   //   | L_i |
2186   //   |     |
2187   //   |     |
2188   //   |     |
2189   //   |     |
2190   //   |     |
2191   //   +-----+
2192   //
2193   // where the LHS is a concatenate operation (so we can "split" the LHS tensor
2194   // for free) and the RHS is a constant tensor (and thus can be split at
2195   // compile time).  In the future, we may also want to do this when both the
2196   // LHS and the RHS are concatenate operations that line up along the dimension
2197   // being contracted over.
2198   //
2199   // We should be able to generalize this transform to work on a non-constant
2200   // RHS when/if we have in-place slices or support input-fusing slices into
2201   // Dots.
2202 
2203   // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
2204   // the diagram above).
2205   DotDimensionNumbers new_dot_dnums;
2206   new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
2207                                                        : lhs_contracting_dim);
2208   new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
2209                                                        : rhs_contracting_dim);
2210 
2211   // Here we use the MKN notation, where the contracted dimension has K
2212   // elements and the two non-contracted dimensions have M and N elements.
2213   HloInstruction* add_result = nullptr;
2214   int64_t rhs_contracting_dim_offset = 0;
2215   int64_t n = rhs->shape().dimensions(1 - rhs_contracting_dim);
2216   for (HloInstruction* concat_op : lhs->operands()) {
2217     int64_t sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
2218     Shape rhs_slice_shape(rhs->shape());
2219     rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
2220     simplifier_->UpdateLayout(&rhs_slice_shape);
2221 
2222     std::array<int64_t, 2> start_indices;
2223     start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
2224     start_indices[1 - rhs_contracting_dim] = 0;
2225 
2226     std::array<int64_t, 2> limit_indices;
2227     limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
2228     limit_indices[1 - rhs_contracting_dim] = n;
2229 
2230     HloInstruction* rhs_slice = rhs->AddInstruction(HloInstruction::CreateSlice(
2231         rhs_slice_shape, rhs, /*start_indices=*/start_indices,
2232         /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));
2233 
2234     // TODO(b/69062148): We can get rid of `swapped` once all backends support
2235     // "non-canonical" contraction dimensions (that contracts dimension 1 of the
2236     // LHS with dimension 0 of the RHS).  But for now we keep the same
2237     // contraction dimensions as the incoming dot operation to ensure the new
2238     // dot operations can be lowered.
2239     HloInstruction *new_dot_lhs, *new_dot_rhs;
2240     if (swapped) {
2241       new_dot_lhs = rhs_slice;
2242       new_dot_rhs = concat_op;
2243     } else {
2244       new_dot_lhs = concat_op;
2245       new_dot_rhs = rhs_slice;
2246     }
2247 
2248     auto* new_dot = dot->AddInstruction(
2249         HloInstruction::CreateDot(dot->shape(), new_dot_lhs, new_dot_rhs,
2250                                   new_dot_dnums, dot->precision_config()));
2251 
2252     if (add_result) {
2253       add_result = dot->AddInstruction(HloInstruction::CreateBinary(
2254           dot->shape(), HloOpcode::kAdd, add_result, new_dot));
2255     } else {
2256       add_result = new_dot;
2257     }
2258 
2259     rhs_contracting_dim_offset += sub_k;
2260   }
2261 
2262   return add_result;
2263 }
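// For example (illustrative, K split as 3 + 5): dot(concatenate(L0, L1), R)
// becomes dot(L0, slice(R, [0:3])) + dot(L1, slice(R, [3:8])) along R's
// contracting dimension; since R is a constant, the slices fold away at
// compile time.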
2264 
2265 StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
2266     HloInstruction* dot) {
2267   const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
2268   if (dnums.lhs_contracting_dimensions_size() != 1 ||
2269       dnums.rhs_contracting_dimensions_size() != 1 ||
2270       dnums.lhs_batch_dimensions_size() != 0 ||
2271       dnums.rhs_batch_dimensions_size() != 0 ||
2272       dot->shape().dimensions_size() != 2) {  // dot output 2D
2273     VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
2274     return nullptr;
2275   }
2276 
2277   // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
2278   // Currently a Gather is a DynamicSlice.
2279   auto is_dynamic_slice_constant_combination =
2280       [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
2281         // First operand is a DynamicSlice(Constant).
2282         if (a->opcode() != HloOpcode::kDynamicSlice) {
2283           return false;
2284         }
2285         auto* dynamic_slice_op = a->operand(0);
2286         if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
2287           return false;
2288         }
2289         // Second operand is a Constant.
2290         if (b->opcode() != HloOpcode::kConstant) {
2291           return false;
2292         }
2293         // The DynamicSlice output is a vector.
2294         const Shape& dynamic_slice_shape = a->shape();
2295         if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
2296           return false;
2297         }
2298         // Constant size is the same before and after slice in the contracting
2299         // dimension, otherwise we either must precompute for all possible slice
2300         // indices or dot is invalid.
2301         const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
2302         if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
2303             dynamic_slice_shape.dimensions(a_contracting_dimension)) {
2304           return false;
2305         }
2306         return true;
2307       };
2308 
2309   HloInstruction* lhs = dot->mutable_operand(0);
2310   HloInstruction* rhs = dot->mutable_operand(1);
2311   int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
2312   int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);
2313 
2314   if (!is_dynamic_slice_constant_combination(
2315           lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
2316       !is_dynamic_slice_constant_combination(
2317           rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
2318     VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB) or "
2319                 "dot(ctB, DS(ctA)), where the two constants have equal "
2320                 "contracting dimensions.";
2321     return nullptr;
2322   }
2323 
2324   // LHS is DynamicSlice:
2325   // input: dot(DS(ctA), ctB)
2326   // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
2327   // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
2328   // output: DS(dot(ctA, ctB))
2329   // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.
2330 
2331   // RHS is DynamicSlice:
2332   // input: dot(ctA, DS(ctB))
2333   // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
2334   // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
2335   // output: DS(dot(ctA, ctB))
2336   // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.
2337 
2338   bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
2339   HloDynamicSliceInstruction* dynamic_slice =
2340       lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
2341                            : Cast<HloDynamicSliceInstruction>(rhs);
2342 
2343   // ctA:
2344   HloInstruction* left_operand =
2345       lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
2346   // ctB:
2347   HloInstruction* right_operand =
2348       lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
2349   // Build ctA x ctB.
2350   const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
2351   const int n =
2352       right_operand->shape().dimensions(1 - rhs_contracting_dimension);
2353   auto memoized_shape =
2354       ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
2355   simplifier_->UpdateLayout(&memoized_shape);
2356   auto* memoized_inst = dot->AddInstruction(
2357       HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
2358                                 dnums, dot->precision_config()));
2359   // Get pair {start, 0} or {0, start}.
2360   // Position of start:
2361   int index_of_non_zero_start = lhs_is_dynamic_slice
2362                                     ? 1 - lhs_contracting_dimension
2363                                     : 1 - rhs_contracting_dimension;
2364   // Position of zero:
2365   int index_of_zero_start = 1 - index_of_non_zero_start;
2366 
2367   // Slice out start and 0 components and reorder if necessary.
2368   auto indices_type = dynamic_slice->operand(1)->shape().element_type();
2369   Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
2370   simplifier_->UpdateLayout(&s_shape);
2371   Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
2372   simplifier_->UpdateLayout(&d_shape);
2373   HloInstruction* non_zero_start =
2374       dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
2375   HloInstruction* zero_start =
2376       dynamic_slice->mutable_operand(1 + index_of_zero_start);
2377   std::vector<HloInstruction*> new_start_indices;
2378   if (lhs_is_dynamic_slice) {
2379     new_start_indices = {non_zero_start, zero_start};
2380   } else {
2381     new_start_indices = {zero_start, non_zero_start};
2382   }
2383 
2384   // Build DynamicSlice(ctA x ctB).
2385   const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
2386   const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
2387   auto* memoized_lookup =
2388       dot->AddInstruction(HloInstruction::CreateDynamicSlice(
2389           dot->shape(), memoized_inst, new_start_indices,
2390           {new_slice_m, new_slice_n}));
2391 
2392   return memoized_lookup;
2393 }
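// For example (illustrative): dot(dynamic-slice(ctA, {s, 0}, {1, K}), ctB)
// becomes dynamic-slice(dot(ctA, ctB), {s, 0}, {1, N}): the full M x N
// product is built once, and the dynamic-slice picks out row s.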
2394 
2395 // This function tries to transform
2396 //   dot(reshape(transpose(A)), Const) to
2397 //   dot(reshape(A), reshape(transpose(reshape(Const)))),
2398 // so that the reshape and transpose on the Const side can be constant folded.
2399 //
2400 // The basic idea is that since the accumulation in the dot operation is
2401 // associative, as long as we permute the elements of the contracting
2402 // dimensions on both sides of the dot in the same way, the result of the
2403 // dot is not affected.
2404 StatusOr<HloInstruction*>
2405 AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
2406     HloInstruction* dot) {
2407   // This transformation assumes layout is not assigned yet.
2408   if (options_.is_layout_sensitive()) {
2409     return nullptr;
2410   }
2411 
2412   // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
2413   // remainder of this function easier.
2414   auto dnums = dot->dot_dimension_numbers();
2415   auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
2416   auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
2417   auto* lhs = dot->mutable_operand(0);
2418   auto* rhs = dot->mutable_operand(1);
2419   if (dot->operand(0)->IsConstant()) {
2420     std::swap(lhs, rhs);
2421     std::swap(lhs_contracting_dims, rhs_contracting_dims);
2422   }
2423 
2424   // Require a single contracting dim to make it easier to track the
2425   // contracting dims in the implementation.
2426   if (dnums.lhs_contracting_dimensions_size() != 1) {
2427     return nullptr;
2428   }
2429 
2430   // Pattern match Dot(reshape(transpose(input), constant))
2431   HloInstruction* reshape;
2432   HloInstruction* transpose;
2433   HloInstruction* input;
2434   HloInstruction* constant;
2435   if (!Match(lhs,
2436              m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
2437       !Match(rhs, m::Constant(&constant))) {
2438     return nullptr;
2439   }
2440 
2441   // Check that the reshape squishes some dims into one dim and that this
2442   // one dim is the dot's lhs contracting dim: the size of unmodified_dims
2443   // should be N - 1, where N is the rank of the reshape output (so the
2444   // reshape squishes some dims into one dim), and the lhs contracting dim
2445   // should not be in unmodified_dims (so the squishing target dim is the
2446   // lhs contracting dim).
2447   auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
2448       reshape->operand(0)->shape(), reshape->shape());
2449   CHECK_EQ(lhs_contracting_dims.size(), 1);
2450   if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
2451       absl::c_any_of(unmodified_dims,
2452                      [&](const std::pair<int64_t, int64_t>& p) {
2453                        return p.second == lhs_contracting_dims[0];
2454                      })) {
2455     return nullptr;
2456   }
2457 
2458   // Virtually pull the reshape into the dot so the dot operates on the
2459   // transpose, with "unsquished" lhs contracting dims.  The new contracting
2460   // dims are all of the dims that are modified by the reshape -- that is, every
2461   // dimension that's not in `unmodified_dims[i].first`.
2462   //
2463   // (We don't need to actually create a new dot instruction. We can just keep
2464   // track of lhs and lhs_contracting_dims.)
2465   absl::flat_hash_set<int64_t> unmodified_transpose_dims;
2466   for (const auto& pair : unmodified_dims) {
2467     unmodified_transpose_dims.insert(pair.first);
2468   }
2469   lhs_contracting_dims.Clear();
2470   for (int64_t i = 0; i < transpose->shape().dimensions_size(); ++i) {
2471     if (!unmodified_transpose_dims.contains(i)) {
2472       lhs_contracting_dims.Add(i);
2473     }
2474   }
2475   // We require the "unsquished" lhs contracting dims to be consecutive.
2476   auto is_iota = [](absl::Span<const int64_t> dims) {
2477     return absl::c_adjacent_find(dims, [](const int64_t a, const int64_t b) {
2478              return (b != a + 1);
2479            }) == dims.end();
2480   };
2481   if (!is_iota(lhs_contracting_dims)) {
2482     return nullptr;
2483   }
2484   lhs = lhs->mutable_operand(0);
2485 
2486   // Check that the transpose only permutes the contracting dims.
2487   const auto& transpose_dims = transpose->dimensions();
2488   for (int64_t i = 0; i < transpose_dims.size(); ++i) {
2489     if (transpose_dims[i] != i &&
2490         !absl::c_linear_search(lhs_contracting_dims, i)) {
2491       return nullptr;
2492     }
2493   }
2494   // Virtually pull the transpose into the dot. Now the dot is equivalent to
2495   // a new dot with "permuted" lhs contracting dims.
2496   std::vector<int64_t> permutation;
2497   permutation.reserve(lhs_contracting_dims.size());
2498   for (auto dim : lhs_contracting_dims) {
2499     permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
2500   }
2501   CHECK(IsPermutation(permutation));
2502   auto new_lhs_contracting_dims =
2503       ComposePermutations(lhs_contracting_dims, permutation);
2504   lhs_contracting_dims.Clear();
2505   for (auto dim : new_lhs_contracting_dims) {
2506     lhs_contracting_dims.Add(dim);
2507   }
2508   lhs = lhs->mutable_operand(0);
2509 
2510   // All checks have passed at this point.
2511   //
2512   // Transform lhs. Remove the transpose and reshape by sorting the lhs
2513   // contracting dims and squishing them into a single one. We don't actually
2514   // squish the lhs_contracting_dims here because we still need the unsquished
2515   // contracting dims to invert reshape and transpose.
2516   absl::c_sort(lhs_contracting_dims);
2517   lhs =
2518       dot->AddInstruction(HloInstruction::CreateReshape(reshape->shape(), lhs));
2519 
2520   // Transform rhs. Say the input HLO is:
2521   //
2522   //   t0 = f32[2, 2, 3] parameter(0)
2523   //   t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
2524   //   t2 = f32[2, 6] reshape(t1)
2525   //   t3 = f32[6, 2] constant(...)
2526   //   dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
2527   //                               rhs_contracting_dims={0}
2528   //
2529   // At this point in the function, we have decided that the second and third
2530   // dims of t0 can be switched to remove the transpose, and we have
2531   // "virtually decomposed" the input HLO to:
2532   //
2533   //   t0 = f32[2, 2, 3] parameter(0)
2534   //   t2' = f32[2, 6] reshape(t0)
2535   //   t3' = f32[6, 2] ops-to-be-filled ...
2536   //   dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
2537   //                                 rhs_contracting_dims={0}
2538   //
2539   // The rest of this function fills in the ops of t3'. To do this, we
2540   // unsquish the contracting dimensions in t3 and then apply the inverse of
2541   // the transpose from t1.
2542 
2543   // Invert reshape.
2544   CHECK_EQ(rhs_contracting_dims.size(), 1);
2545   std::vector<int64_t> rhs_unsquished_shape_dims =
2546       SpanToVector(constant->shape().dimensions());
2547   auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
2548                                             rhs_contracting_dims[0]);
2549   for (auto dim : lhs_contracting_dims) {
2550     it = rhs_unsquished_shape_dims.insert(it,
2551                                           transpose->shape().dimensions(dim));
2552     ++it;
2553   }
2554   HloInstruction* rhs_reshape =
2555       dot->AddInstruction(HloInstruction::CreateReshape(
2556           ShapeUtil::MakeShape(constant->shape().element_type(),
2557                                rhs_unsquished_shape_dims),
2558           constant));
2559   rhs = rhs_reshape;
2560 
2561   // Rhs reshape "unsquishes" the single rhs contracting dim into multiple dims.
2562   rhs_contracting_dims.Resize(lhs_contracting_dims.size(), 0);
2563   absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);
2564 
2565   // Invert transpose. First compute the shape.
2566   std::vector<int64_t> rhs_transpose_shape_dims =
2567       SpanToVector(rhs_reshape->shape().dimensions());
2568   it = rhs_transpose_shape_dims.erase(
2569       rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
2570       rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
2571           rhs_contracting_dims.size());
2572   for (auto dim : lhs_contracting_dims) {
2573     it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
2574     ++it;
2575   }
2576   // Then compute the transpose dims.
2577   std::vector<int64_t> rhs_transpose_dims(rhs_reshape->shape().rank());
2578   absl::c_iota(rhs_transpose_dims, 0);
2579   it = rhs_transpose_dims.erase(
2580       rhs_transpose_dims.begin() + rhs_contracting_dims[0],
2581       rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
2582           rhs_contracting_dims.size());
2583   auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
2584   for (auto dim : lhs_contracting_dims) {
2585     it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
2586                                            lhs_contracting_dims[0] +
2587                                            rhs_contracting_dims[0]);
2588     ++it;
2589   }
2590   HloInstruction* rhs_transpose =
2591       dot->AddInstruction(HloInstruction::CreateTranspose(
2592           ShapeUtil::MakeShape(constant->shape().element_type(),
2593                                rhs_transpose_shape_dims),
2594           rhs_reshape, rhs_transpose_dims));
2595   rhs = rhs_transpose;
2596 
2597   // Squish the multiple rhs contracting dims into a single one.
2598   rhs = dot->AddInstruction(
2599       HloInstruction::CreateReshape(constant->shape(), rhs));
2600 
2601   // If we virtually swapped lhs and rhs, we need to swap them back before
2602   // creating the new dot.
2603   if (dot->operand(0)->IsConstant()) {
2604     std::swap(lhs, rhs);
2605   }
2606 
2607   HloInstruction* new_dot = dot->AddInstruction(HloInstruction::CreateDot(
2608       dot->shape(), lhs, rhs, dnums, dot->precision_config()));
2609   return new_dot;
2610 }
2611 
2612 Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
2613   CHECK(computation_ == dot->parent());
2614   const auto& dnums = dot->dot_dimension_numbers();
2615 
2616   HloInstruction *lhs, *rhs;
2617   CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
2618   if (options_.is_layout_sensitive()) {
2619     return OkStatus();
2620   }
2621   // Replace a zero element dot with a broadcast of the constant 0.
2622   if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
2623       ShapeUtil::IsZeroElementArray(lhs->shape()) ||
2624       ShapeUtil::IsZeroElementArray(rhs->shape())) {
2625     auto zero =
2626         dot->AddInstruction(simplifier_->CreateConstantWithLayoutUpdated(
2627             LiteralUtil::Zero(dot->shape().element_type())));
2628     return ReplaceWithNewInstruction(
2629         dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
2630   }
2631 
2632   // If there are no contracting dimensions, a dot can be rewritten as
2633   // mul(broadcast(transpose(x)), broadcast(transpose(y)))
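  // For example (a sketch, not from a test): the outer product
  // dot(f32[3] x, f32[4] y) with no batch or contracting dims becomes
  // mul(broadcast(x) into f32[3,4] dims={0},
  //     broadcast(y) into f32[3,4] dims={1}).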
2634   if (options_.enable_dot_to_multiply_rewrite() &&
2635       dnums.lhs_contracting_dimensions_size() == 0) {
2636     TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
2637                         NormalizeDotOperandToBatchMajorAndContractingMinor(
2638                             lhs, dnums.lhs_batch_dimensions(),
2639                             dnums.lhs_contracting_dimensions()));
2640     if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
2641       new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
2642     }
2643     TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
2644                         NormalizeDotOperandToBatchMajorAndContractingMinor(
2645                             rhs, dnums.rhs_batch_dimensions(),
2646                             dnums.rhs_contracting_dimensions()));
2647     if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
2648       new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
2649     }
2650     if (dot->shape().rank() != lhs->shape().rank()) {
2651       std::vector<int64_t> lhs_broadcast_dims(lhs->shape().rank());
2652       absl::c_iota(lhs_broadcast_dims, 0);
2653       new_lhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
2654           dot->shape(), new_lhs, lhs_broadcast_dims));
2655     }
2656     if (dot->shape().rank() != rhs->shape().rank()) {
2657       std::vector<int64_t> rhs_broadcast_dims(
2658           dnums.lhs_batch_dimensions_size());
2659       absl::c_iota(rhs_broadcast_dims, 0);
2660       for (int64_t i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
2661         rhs_broadcast_dims.push_back(i);
2662       }
2663       new_rhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
2664           dot->shape(), new_rhs, rhs_broadcast_dims));
2665     }
2666     return ReplaceWithNewInstruction(
2667         dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
2668                                           new_lhs, new_rhs));
2669   }
2670 
2671   // If the lhs or rhs has only batch and contracting dimensions, a dot can be
2672   // rewritten as reduce(mul(broadcast(transpose(x)), broadcast(transpose(y))))
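  // For example (a sketch): dot(f32[2,3] x, f32[3] y) with lhs contracting
  // dims {1} and rhs contracting dims {0} becomes roughly
  // reduce(mul(x, broadcast(y) into f32[2,3] dims={1}), dims={1}) -> f32[2].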
2673   if (options_.enable_dot_strength_reduction() &&
2674       ((dnums.lhs_batch_dimensions_size() +
2675             dnums.lhs_contracting_dimensions_size() ==
2676         lhs->shape().rank()) ||
2677        (dnums.rhs_contracting_dimensions_size() +
2678             dnums.rhs_batch_dimensions_size() ==
2679         rhs->shape().rank()))) {
2680     TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
2681                         NormalizeDotOperandToBatchMajorAndContractingMinor(
2682                             lhs, dnums.lhs_batch_dimensions(),
2683                             dnums.lhs_contracting_dimensions()));
2684     if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
2685       new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
2686     }
2687 
2688     TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
2689                         NormalizeDotOperandToBatchMajorAndContractingMinor(
2690                             rhs, dnums.rhs_batch_dimensions(),
2691                             dnums.rhs_contracting_dimensions()));
2692     if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
2693       new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
2694     }
2695 
2696     int64_t lhs_outer_dims =
2697         lhs->shape().rank() - (dnums.lhs_batch_dimensions_size() +
2698                                dnums.lhs_contracting_dimensions_size());
2699     int64_t rhs_outer_dims =
2700         rhs->shape().rank() - (dnums.rhs_batch_dimensions_size() +
2701                                dnums.rhs_contracting_dimensions_size());
2702     CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
2703     if (rhs_outer_dims > 0) {
2704       std::vector<int64_t> lhs_broadcast_dims(
2705           dnums.lhs_batch_dimensions_size());
2706       absl::c_iota(lhs_broadcast_dims, 0);
2707       lhs_broadcast_dims.resize(lhs->shape().rank());
2708       std::iota(lhs_broadcast_dims.begin() + dnums.lhs_batch_dimensions_size(),
2709                 lhs_broadcast_dims.end(),
2710                 dnums.lhs_batch_dimensions_size() + rhs_outer_dims);
2711       new_lhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
2712           new_rhs->shape(), new_lhs, lhs_broadcast_dims));
2713     } else if (lhs_outer_dims > 0) {
2714       std::vector<int64_t> rhs_broadcast_dims(
2715           dnums.rhs_batch_dimensions_size());
2716       absl::c_iota(rhs_broadcast_dims, 0);
2717       rhs_broadcast_dims.resize(rhs->shape().rank());
2718       std::iota(rhs_broadcast_dims.begin() + dnums.rhs_batch_dimensions_size(),
2719                 rhs_broadcast_dims.end(),
2720                 dnums.rhs_batch_dimensions_size() + lhs_outer_dims);
2721       new_rhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
2722           new_lhs->shape(), new_rhs, rhs_broadcast_dims));
2723     }
2724 
2725     TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
2726                         MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
2727     std::vector<int64_t> reduce_dims(dnums.lhs_contracting_dimensions_size());
2728     PrimitiveType dot_type =
2729         ShapeUtil::ElementIsFloating(dot->shape())
2730             ? (dot->shape().element_type() == F64 ? F64 : F32)
2731             : dot->shape().element_type();
2732     new_dot = AsType(new_dot, dot_type);
2733     const int64_t outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
2734     absl::c_iota(reduce_dims, outer_dims + dnums.lhs_batch_dimensions_size());
2735     new_dot = AddReduce(new_dot, reduce_dims, dot_type);
2736     new_dot = AsType(new_dot, dot->shape().element_type());
2737     return ReplaceInstruction(dot, new_dot);
2738   }
2739 
2740   // Simplify dot(reshape(transpose(A)), Const) to:
2741   // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
2742   // and transpose on the Const side can be constant folded.
2743   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
2744                       OptimizeDotOfReorderContractingDims(dot));
2745   if (dot_of_reorder_optimized) {
2746     VLOG(10) << " Replaced dot " << dot->ToString()
2747              << " with new dot operation: "
2748              << dot_of_reorder_optimized->ToString();
2749     return ReplaceInstruction(dot, dot_of_reorder_optimized);
2750   }
2751 
2752   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
2753                       OptimizeDotOfConcat(dot));
2754   if (dot_of_concat_optimized) {
2755     VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
2756                 "constant)...)";
2757     return ReplaceInstruction(dot, dot_of_concat_optimized);
2758   }
2759 
2760   // Simplify dot(ConstA, Gather(Index, ConstB)) to:
2761   // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
2762   // batched version of dot.
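  // E.g. (roughly): the dot against every row of ConstB is precomputed and
  // constant-folded, and the original Index then gathers from that result.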
2763   TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
2764                       OptimizeDotOfGather(dot));
2765   if (dot_of_gather_optimized) {
2766     VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
2767                 "gather(i, dot*(constA, constB))";
2768     return ReplaceInstruction(dot, dot_of_gather_optimized);
2769   }
2770 
2771   TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
2772                       RemoveDegenerateDimensionFromDot(dot));
2773   if (removed_degenerate_dimensions) {
2774     return OkStatus();
2775   }
2776 
2777   TF_ASSIGN_OR_RETURN(bool removed_transposes,
2778                       RemoveTransposesFromDotOperands(dot));
2779   if (removed_transposes) {
2780     return OkStatus();
2781   }
2782 
2783   return OkStatus();
2784 }
2785 
2786 Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
2787   const Shape& operand_shape = gather->operand(0)->shape();
2788   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
2789     return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
2790   }
2791 
2792   // Gathering from a scalar operand is simply a broadcast of that scalar.
2793   if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
2794     HloInstruction* new_operand = gather->mutable_operand(0);
2795     if (operand_shape.rank()) {
2796       TF_ASSIGN_OR_RETURN(new_operand,
2797                           MakeReshapeHlo(ShapeUtil::MakeScalarShape(
2798                                              operand_shape.element_type()),
2799                                          new_operand));
2800     }
2801     HloInstruction* new_gather =
2802         MakeBroadcastHlo(new_operand, {}, gather->shape());
2803     return ReplaceInstruction(gather, new_gather);
2804   }
2805   // If the operand of a gather is very small, it is easier to fuse a
2806   // sequence of selects.
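  // E.g. (illustrative) gathering from f32[3] {v0, v1, v2} becomes roughly
  // select(i >= 2, v2, select(i >= 1, v1, select(i >= 0, v0, v0))), where
  // each v_k is a one-element slice broadcast to the gather's shape.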
2807   const Shape& index_shape = gather->operand(1)->shape();
2808   if (operand_shape.rank() == 1 &&
2809       operand_shape.dimensions(0) <= options_.very_small_gather_size() &&
2810       gather->gather_dimension_numbers().index_vector_dim() ==
2811           index_shape.rank() &&
2812       gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) {
2813     const int64_t operand_elements = operand_shape.dimensions(0);
2814     auto get_value = [&](int64_t i) {
2815       auto slice = gather->AddInstruction(HloInstruction::CreateSlice(
2816           ShapeUtil::MakeShape(operand_shape.element_type(), {1}),
2817           gather->mutable_operand(0), {i}, {i + 1}, {1}));
2818       auto scalar = gather->AddInstruction(HloInstruction::CreateReshape(
2819           ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice));
2820       return gather->AddInstruction(
2821           HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
2822     };
2823     auto result = get_value(0);
2824     auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
2825     simplifier_->UpdateLayout(&pred_shape);
2826     auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
2827                                                    index_shape.element_type());
2828     simplifier_->UpdateLayout(&iter_shape);
2829     for (int64_t i = 0; i < operand_elements; ++i) {
2830       auto index_mask = gather->AddInstruction(HloInstruction::CreateCompare(
2831           pred_shape, gather->mutable_operand(1),
2832           MakeScalarLike(gather->mutable_operand(1), i),
2833           ComparisonDirection::kGe));
2834       result = gather->AddInstruction(
2835           HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
2836                                         index_mask, get_value(i), result));
2837     }
2838     return ReplaceInstruction(gather, result);
2839   }
2840   return OkStatus();
2841 }
2842 
2843 namespace {
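// Tries to fold max(broadcast(low), min(x, broadcast(high))) (and the min/max
// mirror) into clamp(low, x, high). This is only valid when low < high, which
// is verified with the HLO evaluator below. E.g. (illustrative)
// max(bcast(0.0), min(x, bcast(6.0))) becomes clamp(0.0, x, 6.0), i.e. relu6.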
2844 StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
2845     HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
2846     HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) {
2847   HloInstruction* clamp_lower_bound;
2848   CHECK(Match(clamp_lower_bound_bcast,
2849               m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
2850       << clamp_lower_bound_bcast->ToString();
2851 
2852   HloInstruction* clamp_upper_bound;
2853   CHECK(Match(clamp_upper_bound_bcast,
2854               m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
2855       << clamp_upper_bound_bcast->ToString();
2856 
2857   const Literal& lower_bound =
2858       Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
2859   const Literal& upper_bound =
2860       Cast<HloConstantInstruction>(clamp_upper_bound)->literal();
2861 
2862   TF_ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped,
2863                       lower_bound.Reshape({}));
2864   TF_ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped,
2865                       upper_bound.Reshape({}));
2866   std::unique_ptr<HloInstruction> lower_bound_instr =
2867       HloInstruction::CreateConstant(std::move(lower_bound_literal_reshaped));
2868   std::unique_ptr<HloInstruction> upper_bound_instr =
2869       HloInstruction::CreateConstant(std::move(upper_bound_literal_reshaped));
2870 
2871   Shape compare_shape =
2872       ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED);
2873   simplifier->UpdateLayout(&compare_shape);
2874   std::unique_ptr<HloInstruction> cloned_instruction =
2875       HloInstruction::CreateCompare(compare_shape, lower_bound_instr.get(),
2876                                     upper_bound_instr.get(),
2877                                     ComparisonDirection::kLt);
2878 
2879   HloEvaluator evaluator;
2880   TF_ASSIGN_OR_RETURN(auto result,
2881                       evaluator.Evaluate(cloned_instruction.get()));
2882   if (result.IsAll(true)) {
2883     return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
2884                                          clamp_lower_bound_bcast, to_clamp,
2885                                          clamp_upper_bound_bcast);
2886   }
2887   return std::unique_ptr<HloInstruction>();
2888 }
2889 }  // namespace
2890 
2891 Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
2892   HloInstruction *lhs, *rhs;
2893   CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));
2894 
2895   // max(x, -inf) -> x
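  // (For floats this is only sound when NaNs propagate: max(NaN, -inf) must
  // remain NaN, which the minmax_propagate_nan option guarantees.)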
2896   PrimitiveType ty = maximum->shape().element_type();
2897   if (primitive_util::IsIntegralType(ty) ||
2898       (primitive_util::IsFloatingPointType(ty) &&
2899        options_.minmax_propagate_nan())) {
2900     Literal min_val = LiteralUtil::MinValue(ty);
2901     if (IsAll(lhs, min_val)) {
2902       return ReplaceInstruction(maximum, rhs);
2903     }
2904     if (IsAll(rhs, min_val)) {
2905       return ReplaceInstruction(maximum, lhs);
2906     }
2907   }
2908 
2909   HloInstruction* clamp_upper_bound_bcast;
2910   HloInstruction* clamp_lower_bound_bcast;
2911   HloInstruction* to_clamp;
2912   if (Match(maximum, m::MaximumAnyOrder(
2913                          m::Broadcast(&clamp_lower_bound_bcast,
2914                                       m::ConstantEffectiveScalar()),
2915                          m::MinimumAnyOrder(
2916                              m::Op(&to_clamp),
2917                              m::Broadcast(&clamp_upper_bound_bcast,
2918                                           m::ConstantEffectiveScalar()))))) {
2919     TF_ASSIGN_OR_RETURN(auto clamp,
2920                         MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2921                                       clamp_upper_bound_bcast, simplifier_));
2922     if (clamp) {
2923       return ReplaceWithNewInstruction(maximum, std::move(clamp));
2924     }
2925   }
2926 
2927   HloInstruction* clamp_lower_bound;
2928   HloInstruction* clamp_upper_bound;
2929   HloInstruction* max_operand;
2930   HloInstruction* clamp;
2931   if (Match(maximum,
2932             m::MaximumAnyOrder(
2933                 m::Op(&max_operand),
2934                 m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2935                          m::Op(&clamp_upper_bound))))) {
2936     if (max_operand == clamp_lower_bound &&
2937         ReplaceInstructionIfCompatible(maximum, clamp)) {
2938       return OkStatus();
2939     }
2940   }
2941 
2942   return OkStatus();
2943 }
2944 
2945 Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
2946   HloInstruction *lhs, *rhs;
2947   CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));
2948 
2949   // min(x, inf) -> x
2950   PrimitiveType ty = minimum->shape().element_type();
2951   if (primitive_util::IsIntegralType(ty) ||
2952       (primitive_util::IsFloatingPointType(ty) &&
2953        options_.minmax_propagate_nan())) {
2954     Literal max_val = LiteralUtil::MaxValue(ty);
2955     if (IsAll(lhs, max_val)) {
2956       return ReplaceInstruction(minimum, rhs);
2957     }
2958     if (IsAll(rhs, max_val)) {
2959       return ReplaceInstruction(minimum, lhs);
2960     }
2961   }
2962 
2963   HloInstruction* clamp_upper_bound_bcast;
2964   HloInstruction* clamp_lower_bound_bcast;
2965   HloInstruction* to_clamp;
2966   if (Match(minimum, m::MinimumAnyOrder(
2967                          m::Broadcast(&clamp_upper_bound_bcast,
2968                                       m::ConstantEffectiveScalar()),
2969                          m::MaximumAnyOrder(
2970                              m::Op(&to_clamp),
2971                              m::Broadcast(&clamp_lower_bound_bcast,
2972                                           m::ConstantEffectiveScalar()))))) {
2973     TF_ASSIGN_OR_RETURN(auto clamp,
2974                         MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
2975                                       clamp_upper_bound_bcast, simplifier_));
2976     if (clamp) {
2977       return ReplaceWithNewInstruction(minimum, std::move(clamp));
2978     }
2979   }
2980 
2981   return OkStatus();
2982 }
2983 
2984 Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) {
2985   HloInstruction* clamp_lower_bound;
2986   HloInstruction* clamp_upper_bound;
2987   HloInstruction* to_clamp;
2988   CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp),
2989                               m::Op(&clamp_upper_bound))));
2990 
2991   // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b)
2992   if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(),
2993                                m::Op().Is(clamp_upper_bound))) &&
2994       ReplaceInstructionIfCompatible(clamp, to_clamp)) {
2995     return OkStatus();
2996   }
2997 
2998   // Eliminate redundant clamping of replica-id or partition-id.
2999   if ((Match(to_clamp, m::PartitionId()) || Match(to_clamp, m::ReplicaId())) &&
3000       Match(clamp_lower_bound, m::ConstantScalar(0U)) &&
3001       Match(clamp_upper_bound, m::ConstantScalar())) {
3002     int64_t upper_bound = Cast<HloConstantInstruction>(clamp_upper_bound)
3003                               ->literal()
3004                               .GetFirstElement<uint32_t>();
3005     const HloModuleConfig& config = clamp->GetModule()->config();
3006     int64_t runtime_bound = Match(to_clamp, m::PartitionId())
3007                                 ? config.num_partitions()
3008                                 : config.replica_count();
3009 
3010     // If num_partitions or replica_count is 1, treat it as unknown (it is
3011     // often just a default). Since pid/rid < runtime_bound, the
3012     // clamp(0, pid/rid, upper_bound) is redundant if runtime_bound <= upper_bound + 1.
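    // E.g. (illustrative) clamp(0u, partition-id, 7u) with num_partitions = 8
    // is a no-op, because partition-id <= 7 already holds.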
3013     if (runtime_bound != 1 && runtime_bound <= upper_bound + 1) {
3014       return ReplaceInstruction(clamp, to_clamp);
3015     }
3016   }
3017 
3018   return OkStatus();
3019 }
3020 
3021 Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
3022   HloInstruction *lhs, *rhs;
3023   CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
3024   // LHS*1 => LHS
3025   VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
3026   if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(multiply, lhs)) {
3027     return OkStatus();
3028   }
3029   // 1*RHS => RHS
3030   VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
3031   if (IsAll(lhs, 1) && ReplaceInstructionIfCompatible(multiply, rhs)) {
3032     return OkStatus();
3033   }
3034 
3035   // 0*RHS => 0. Only applies to integral types: for floats, 0 * NaN is NaN.
3036   if (IsAll(lhs, 0) &&
3037       primitive_util::IsIntegralType(multiply->shape().element_type()) &&
3038       ReplaceInstructionIfCompatible(multiply, lhs)) {
3039     return OkStatus();
3040   }
3041   // LHS*0 => 0
3042   if (IsAll(rhs, 0) &&
3043       primitive_util::IsIntegralType(multiply->shape().element_type()) &&
3044       ReplaceInstructionIfCompatible(multiply, rhs)) {
3045     return OkStatus();
3046   }
3047 
3048   {
3049     HloInstruction* abs_operand;
3050     if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
3051         !ShapeUtil::ElementIsComplex(abs_operand->shape())) {
3052       TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
3053       TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
3054       MarkAsChanged();
3055       return OkStatus();
3056     }
3057   }
3058 
3059   {
3060     HloInstruction *convert_operand, *operand;
3061     // Mul(Convert(Pred), operand) => select(pred, operand, 0)
3062     if (Match(multiply,
3063               m::MultiplyAnyOrder(
3064                   m::Op(&operand),
3065                   m::Convert(
3066                       m::Op(&convert_operand)
3067                           .WithShape(m::Shape().WithElementType(PRED)))))) {
3068       HloInstruction* zero_like_multiply =
3069           BroadcastZeros(computation_, multiply->shape().element_type(),
3070                          multiply->shape().dimensions());
3071       return ReplaceWithNewInstruction(
3072           multiply, HloInstruction::CreateTernary(
3073                         multiply->shape(), HloOpcode::kSelect, convert_operand,
3074                         operand, zero_like_multiply));
3075     }
3076   }
3077 
3078   {
3079     HloInstruction *a, *b, *c1, *c2;
3080     // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
3081     // constant1*constant2)
3082     if (Match(multiply,
3083               m::MultiplyAnyOrder(
3084                   m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
3085                   m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
3086       TF_ASSIGN_OR_RETURN(auto* product_of_constants,
3087                           MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
3088       if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
3089           !ShapeUtil::IsScalar(multiply->shape())) {
3090         product_of_constants =
3091             multiply->AddInstruction(HloInstruction::CreateBroadcast(
3092                 multiply->shape(), product_of_constants, {}));
3093       }
3094 
3095       return ReplaceWithNewInstruction(
3096           multiply, HloInstruction::CreateBinary(
3097                         multiply->shape(), HloOpcode::kMultiply,
3098                         multiply->AddInstruction(HloInstruction::CreateBinary(
3099                             multiply->shape(), HloOpcode::kMultiply, a, b)),
3100                         product_of_constants));
3101     }
3102   }
3103 
3104   {
3105     HloInstruction *a, *c1, *c2;
3106     // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2)
3107     if (Match(multiply,
3108               m::MultiplyAnyOrder(
3109                   m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
3110                   m::Constant(&c2)))) {
3111       TF_ASSIGN_OR_RETURN(auto* product_of_constants,
3112                           MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
3113       if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
3114           !ShapeUtil::IsScalar(multiply->shape())) {
3115         product_of_constants =
3116             multiply->AddInstruction(HloInstruction::CreateBroadcast(
3117                 multiply->shape(), product_of_constants, {}));
3118       }
3119 
3120       return ReplaceWithNewInstruction(
3121           multiply,
3122           HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply,
3123                                        a, product_of_constants));
3124     }
3125   }
3126 
3127   {
3128     HloInstruction *a, *b, *constant, *op;
3129     // Mul(Mul(a, constant1), Broadcast(b)) =>
3130     // Mul(a, Broadcast(Mul(b, constant1)))
3131     if (Match(multiply,
3132               m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a),
3133                                                       m::Constant(&constant)),
3134                                   m::Op(&op))) ||
3135         Match(multiply,
3136               m::MultiplyAnyOrder(
3137                   m::MultiplyAnyOrder(m::NonConstant(&a),
3138                                       m::Broadcast(m::Constant(&constant))),
3139                   m::Op(&op)))) {
3140       // Check that the other side was a broadcast, and not of a constant.
3141       if (ShapeUtil::IsScalar(constant->shape()) &&
3142           Match(op, m::Broadcast(m::NonConstant()))) {
3143         auto dims = op->dimensions();
3144         b = op->mutable_operand(0);
3145         if (!ShapeUtil::IsScalar(b->shape())) {
3146           constant = multiply->AddInstruction(
3147               HloInstruction::CreateBroadcast(b->shape(), constant, {}));
3148         }
3149 
3150         auto new_mul = multiply->AddInstruction(HloInstruction::CreateBinary(
3151             b->shape(), HloOpcode::kMultiply, b, constant));
3152 
3153         return ReplaceWithNewInstruction(
3154             multiply,
3155             HloInstruction::CreateBinary(
3156                 multiply->shape(), HloOpcode::kMultiply, a,
3157                 multiply->AddInstruction(HloInstruction::CreateBroadcast(
3158                     multiply->shape(), new_mul, dims))));
3159       }
3160     }
3161   }
3162 
3163   VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
3164   HloInstruction *a, *c1, *c2;
3165   if (Match(multiply,
3166             m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
3167                         m::Constant(&c2))) ||
3168       Match(multiply,
3169             m::Multiply(
3170                 m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))),
3171                 m::Broadcast(m::ConstantScalar(&c2))))) {
3172     TF_ASSIGN_OR_RETURN(auto* product_of_constants,
3173                         MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
3174     if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
3175         !ShapeUtil::IsScalar(multiply->shape())) {
3176       product_of_constants =
3177           multiply->AddInstruction(HloInstruction::CreateBroadcast(
3178               multiply->shape(), product_of_constants, {}));
3179     }
3180     return ReplaceWithNewInstruction(
3181         multiply,
3182         HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
3183                                      product_of_constants));
3184   }
3185 
3186   VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
3187            << multiply->ToString();
3188   if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
3189     auto add = multiply->AddInstruction(HloInstruction::CreateBinary(
3190         multiply->shape(), HloOpcode::kAdd, lhs, rhs));
3191     return ReplaceWithNewInstruction(
3192         multiply,
3193         HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
3194   }
3195 
3196   VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
3197            << multiply->ToString();
3198   HloInstruction* b;
3199   if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) &&
3200       IsPositive(b, options_)) {
3201     return ReplaceWithNewInstruction(
3202         multiply,
3203         HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
3204                                      MakeScalarLike(b, 1), b));
3205   }
3206 
3207   return OkStatus();
3208 }
3209 
3210 Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
3211   // negate(negate(x)) => x
3212   HloInstruction* x;
3213   if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
3214       ReplaceInstructionIfCompatible(negate, x)) {
3215     return OkStatus();
3216   }
3217   return OkStatus();
3218 }
3219 
3220 Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
3221   // not(not(x)) => x
3222   HloInstruction* x;
3223   if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
3224       ReplaceInstructionIfCompatible(logical_not, x)) {
3225     return OkStatus();
3226   }
3227   return OkStatus();
3228 }
3229 
3230 Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
3231   HloInstruction *lhs, *rhs;
3232   CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));
3233 
3234   // Simplify logical or
3235   if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
3236       ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
3237     // A || True => True
3238     VLOG(10) << "trying transform [A || True => True]: "
3239              << logical_or->ToString();
3240     if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(logical_or, rhs)) {
3241       return OkStatus();
3242     }
3243     // True || A => True
3244     VLOG(10) << "trying transform [True || A => True]: "
3245              << logical_or->ToString();
3246     if (IsAll(lhs, 1) && ReplaceInstructionIfCompatible(logical_or, lhs)) {
3247       return OkStatus();
3248     }
3249   }
3250 
3251   // A || False => A and A | 0 => A
3252   VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString();
3253   if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(logical_or, lhs)) {
3254     return OkStatus();
3255   }
3256 
3257   // False || A => A and 0 | A => A
3258   VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString();
3259   if (IsAll(lhs, 0) && ReplaceInstructionIfCompatible(logical_or, rhs)) {
3260     return OkStatus();
3261   }
3262 
3263   return OkStatus();
3264 }
3265 
3266 Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
3267   // ln(exp(A)) => A
3268   VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
3269   HloInstruction *a, *b;
3270   if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
3271       ReplaceInstructionIfCompatible(log, a)) {
3272     return OkStatus();
3273   }
3274 
3275   // ln(pow(A,B)) => B*ln(abs(A))
3276   // or B*ln(A) if A is complex.
3277   if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
3278     auto abs_a = ShapeUtil::ElementIsComplex(a->shape())
3279                      ? a
3280                      : log->AddInstruction(HloInstruction::CreateUnary(
3281                            log->shape(), HloOpcode::kAbs, a));
3282     auto new_log = log->AddInstruction(
3283         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a));
3284     auto non_zero_b =
3285         log->mutable_operand(0)->AddInstruction(HloInstruction::CreateBinary(
3286             log->shape(), HloOpcode::kMultiply, new_log, b));
3287     TF_ASSIGN_OR_RETURN(
3288         auto b_is_zero,
3289         MakeCompareHlo(Comparison::Direction::kEq, b, MakeScalarLike(b, 0.0)));
3290     simplifier_->UpdateLayout(b_is_zero->mutable_shape());
3291     return ReplaceWithNewInstruction(
3292         log, HloInstruction::CreateTernary(log->shape(), HloOpcode::kSelect,
3293                                            b_is_zero, MakeScalarLike(log, 0.0),
3294                                            non_zero_b));
3295   }
3296 
3297   if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) {
3298     auto new_log = log->AddInstruction(
3299         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
3300     return ReplaceWithNewInstruction(
3301         log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
3302                                           new_log, MakeScalarLike(log, 0.5)));
3303   }
3304 
3305   if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) {
3306     auto new_log = log->AddInstruction(
3307         HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
3308     return ReplaceWithNewInstruction(
3309         log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
3310                                           new_log, MakeScalarLike(log, -0.5)));
3311   }
3312 
3313   return OkStatus();
3314 }
3315 
3316 Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
3317     HloInstruction* get_tuple_element) {
3318   auto operand = get_tuple_element->mutable_operand(0);
3319   if (operand->opcode() == HloOpcode::kTuple) {
3320     // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
3321     VLOG(10) << "trying transform "
3322              << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
3323              << get_tuple_element->ToString();
3324     if (ReplaceInstructionIfCompatible(
3325             get_tuple_element,
3326             operand->mutable_operand(get_tuple_element->tuple_index()))) {
3327       return OkStatus();
3328     }
3329   }
3330   return OkStatus();
3331 }
3332 
3333 Status AlgebraicSimplifierVisitor::HandleOptimizationBarrier(
3334     HloInstruction* barrier) {
3335   if (!barrier->shape().IsTuple() ||
3336       barrier == computation_->root_instruction()) {
3337     return OkStatus();
3338   }
3339 
3340   // The goal of this transformation is to enable DCE on the tuple elements of
3341   // an optimization barrier operand. To do this safely, every user of the
3342   // barrier must be a get-tuple-element, and an element may be dropped only if
3343   // no such user reads it. Additionally, when the operand is a tuple-producing
3344   // instruction, the value feeding a dropped index must have no user other
3345   // than the tuple itself (and must not be the root); we then create a
3346   // sub-tuple of only the used components to enable module-level DCE.
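  // E.g. (a sketch): barrier(tuple(a, b)) whose only user is
  // get-tuple-element(barrier, 0) becomes barrier(tuple(a)), with the GTE
  // re-indexed via index_map and b exposed for DCE.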
3347   std::vector<bool> used_elements(barrier->shape().tuple_shapes_size());
3348   bool has_non_gte_use = false;
3349   for (auto use : barrier->users()) {
3350     if (use->opcode() != HloOpcode::kGetTupleElement) {
3351       has_non_gte_use = true;
3352       break;
3353     }
3354     used_elements[use->tuple_index()] = true;
3355   }
3356 
3357   HloInstruction* operand = barrier->mutable_operand(0);
3358   if (operand->opcode() == HloOpcode::kTuple) {
3359     for (int64_t i = 0; i < operand->operand_count(); ++i) {
3360       if (used_elements[i]) {
3361         continue;
3362       }
3363       if (operand->operand(i)->user_count() > 1 ||
3364           operand->operand(i) == computation_->root_instruction()) {
3365         used_elements[i] = true;
3366       }
3367     }
3368   }
3369 
3370   if (has_non_gte_use || !absl::c_linear_search(used_elements, false)) {
3371     return OkStatus();
3372   }
3373 
3374   MarkAsChanged();
3375   std::vector<int64_t> index_map(used_elements.size(), -1);
3376   std::vector<HloInstruction*> operands;
3377   int64_t current_index = 0;
3378   for (int64_t element = 0; element < used_elements.size(); ++element) {
3379     if (!used_elements[element]) {
3380       continue;
3381     }
3382     index_map[element] = current_index++;
3383     if (operand->opcode() == HloOpcode::kTuple) {
3384       operands.push_back(operand->mutable_operand(element));
3385     } else {
3386       operands.push_back(barrier->AddInstruction(
3387           HloInstruction::CreateGetTupleElement(operand, element)));
3388     }
3389   }
3390 
3391   HloInstruction* new_operand =
3392       operand->AddInstruction(HloInstruction::CreateTuple(operands));
3393   TF_RETURN_IF_ERROR(barrier->ReplaceOperandWithDifferentShape(0, new_operand));
3394   *barrier->mutable_shape() = new_operand->shape();
3395   for (auto use : barrier->users()) {
3396     CHECK_EQ(use->opcode(), HloOpcode::kGetTupleElement);
3397     use->set_tuple_index(index_map[use->tuple_index()]);
3398   }
3399   return OkStatus();
3400 }
3401 
3402 namespace {
3403 
3404 std::optional<std::vector<int64_t>> ReshapeLeavesDimensionsUnmodified(
3405     const HloInstruction* hlo, absl::Span<const int64_t> input_dim_indices) {
3406   CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
3407   return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
3408       hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
3409 }
3410 
3411 // Returns true if the output of "instruction" is a permutation of the
3412 // elements of "operand". Precondition: "operand" is an operand of
3413 // "instruction".
3414 bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
3415                                           HloInstruction* operand) {
3416   DCHECK(!instruction->OperandIndices(operand).empty());
3417   switch (instruction->opcode()) {
3418     case HloOpcode::kReshape:
3419     case HloOpcode::kReverse:
3420     case HloOpcode::kTranspose:
3421       return true;
3422     case HloOpcode::kSort:
3423       return (!instruction->shape().IsTuple());
3424     default:
3425       return false;
3426   }
3427 }
3428 
3429 // Returns true if the output of "instruction" is a subset of the elements of
3430 // "operand". Precondition: "operand" is an operand of "instruction".
3431 bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
3432                                      HloInstruction* operand) {
3433   const auto operand_indices = instruction->OperandIndices(operand);
3434   CHECK(!operand_indices.empty());
3435   if (operand_indices.size() != 1) {
3436     return false;
3437   }
3438   int64_t operand_index = operand_indices[0];
3439   switch (instruction->opcode()) {
3440     case HloOpcode::kSlice:
3441       CHECK_EQ(0, operand_index);
3442       return true;
3443     case HloOpcode::kDynamicSlice:
3444       return operand_index == 0;
3445     default:
3446       return false;
3447   }
3448 }
3449 
3450 }  // namespace
3451 
3452 Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
3453   HloInstruction* operand;
3454   CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
3455   auto dims = *broadcast->mutable_dimensions();
3456   // A degenerate broadcast that does not change the number of elements and
3457   // keeps its dimensions sorted can be replaced by a reshape.
3458   if (std::is_sorted(dims.begin(), dims.end()) &&
3459       ShapeUtil::ElementsIn(broadcast->shape()) ==
3460           ShapeUtil::ElementsIn(operand->shape())) {
3461     VLOG(10) << "transform broadcast(X) -> reshape(X) where "
3462                 "n(broadcast(X)) == n(X)";
3463     return ReplaceWithNewInstruction(
3464         broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
3465   }
3466 
3467   // A degenerate broadcast that has the same input and output rank can be
3468   // converted into a transpose.
3469   if (broadcast->shape().rank() == operand->shape().rank() &&
3470       ShapeUtil::ElementsIn(broadcast->shape()) ==
3471           ShapeUtil::ElementsIn(operand->shape())) {
3472     VLOG(10) << "transform broadcast(X) -> transpose(X) where "
3473                 "n(broadcast(X)) == n(X)";
3474     return ReplaceWithNewInstruction(
3475         broadcast,
3476         HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
3477   }
3478 
3479   // A broadcast of a reshape which merely inserts 1-sized dimensions can
3480   // elide its operand.
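  // E.g. (illustrative): broadcast(reshape(f32[2,3] x into f32[2,1,3]),
  // dims={0,1,2}) producing f32[2,1,3,4] becomes broadcast(x, dims={0,2}).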
3481   {
3482     std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
3483         operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
3484     if (reshape_degenerate.has_value() &&
3485         reshape_degenerate->deleted_dimensions.empty()) {
3486       absl::c_reverse(reshape_degenerate->inserted_dimensions);
3487       for (auto inserted_index : reshape_degenerate->inserted_dimensions) {
3488         dims.erase(dims.begin() + inserted_index);
3489       }
3490       return ReplaceWithNewInstruction(
3491           broadcast,
3492           HloInstruction::CreateBroadcast(broadcast->shape(),
3493                                           operand->mutable_operand(0), dims));
3494     }
3495   }
3496 
3497   if (options_.enable_sink_broadcast()) {
3498     // A Broadcast that feeds a unary element-wise operation can sink the
3499     // broadcast after the unary element-wise operation.
3500     TF_ASSIGN_OR_RETURN(
3501         bool sink_succeeded,
3502         TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
3503     if (sink_succeeded) {
3504       MarkAsChanged();
3505       return OkStatus();
3506     }
3507   }
3508 
3509   // A scalar broadcast feeding an instruction which only permutes (reshape,
3510   // transpose, sort, reverse) or selects a subset of operand elements (slice,
3511   // dynamic slice) can be replaced with a broadcast directly to the output
3512   // shape of the instruction.
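  // E.g. (illustrative): slice(broadcast(s)) or transpose(broadcast(s)) for a
  // scalar s is just broadcast(s) directly to the user's output shape.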
3513   if (ShapeUtil::IsScalar(operand->shape())) {
3514     for (HloInstruction* user : broadcast->users()) {
3515       // Skip if the broadcast user has no uses itself.
3516       if (user->IsDead()) {
3517         continue;
3518       }
3519       if (OutputIsPermutationOfOperandElements(user, broadcast) ||
3520           OutputIsSubsetOfOperandElements(user, broadcast)) {
3521         VLOG(10) << "transform permuting/subset of a scalar broadcast into "
3522                  << "a single broadcast";
3523         HloInstruction* new_broadcast = user->AddInstruction(
3524             HloInstruction::CreateBroadcast(user->shape(), operand, {}));
3525         // Use HloInstruction::ReplaceAllUsesWith instead of
3526         // HloComputation::ReplaceWithNewInstruction because we are replacing an
3527         // instruction other than the visited instruction.
3528         MarkAsChanged();
3529         return user->ReplaceAllUsesWith(new_broadcast);
3530       }
3531     }
3532     return OkStatus();
3533   }
3534 
3535   // broadcast(iota) -> iota.
3536   if (operand->opcode() == HloOpcode::kIota) {
3537     return ReplaceWithNewInstruction(
3538         broadcast,
3539         HloInstruction::CreateIota(
3540             broadcast->shape(),
3541             dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
3542   }
3543 
3544   // Merge two consecutive broadcasts into a single one.
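  // E.g. (illustrative): broadcast(broadcast(f32[2] x, dims={0}) -> f32[2,3],
  // dims={1,2}) -> f32[4,2,3] becomes broadcast(x, dims={1}).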
3545   if (operand->opcode() == HloOpcode::kBroadcast) {
3546     std::vector<int64_t> new_dimensions;
3547     new_dimensions.reserve(operand->dimensions().size());
3548     for (auto dim : operand->dimensions()) {
3549       new_dimensions.push_back(dims[dim]);
3550     }
3551     return ReplaceWithNewInstruction(
3552         broadcast,
3553         HloInstruction::CreateBroadcast(
3554             broadcast->shape(), operand->mutable_operand(0), new_dimensions));
3555   }
3556   if (options_.is_layout_sensitive()) {
3557     return OkStatus();
3558   }
3559   if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
3560     auto new_operand =
3561         operand->parent()->AddInstruction(HloInstruction::CreateReshape(
3562             ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
3563     std::vector<int64_t> new_dims;
3564     new_dims.reserve(new_operand->shape().rank());
3565     for (int64_t i = 0; i < operand->shape().rank(); ++i) {
3566       if (operand->shape().dimensions(i) != 1) {
3567         new_dims.push_back(dims[i]);
3568       }
3569     }
3570     return ReplaceWithNewInstruction(
3571         broadcast, HloInstruction::CreateBroadcast(broadcast->shape(),
3572                                                    new_operand, new_dims));
3573   }
3574   return OkStatus();
3575 }
3576 
3577 Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
3578   HloInstruction* lhs;
3579   HloInstruction* rhs;
3580   CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));
3581 
3582   if (Cast<HloCompareInstruction>(compare)->type() ==
3583       Comparison::Type::kUnsigned) {
3584     // X u<  0 -> false
3585     if (compare->comparison_direction() == ComparisonDirection::kLt &&
3586         IsAll(rhs, 0)) {
3587       return ReplaceInstruction(compare, MakeScalarLike(compare, false));
3588     }
3589     // X u>= 0 -> true
3590     if (compare->comparison_direction() == ComparisonDirection::kGe &&
3591         IsAll(rhs, 0)) {
3592       return ReplaceInstruction(compare, MakeScalarLike(compare, true));
3593     }
3594     // 0 u>  X -> false
3595     if (compare->comparison_direction() == ComparisonDirection::kGt &&
3596         IsAll(lhs, 0)) {
3597       return ReplaceInstruction(compare, MakeScalarLike(compare, false));
3598     }
3599     // 0 u<= X -> true
3600     if (compare->comparison_direction() == ComparisonDirection::kLe &&
3601         IsAll(lhs, 0)) {
3602       return ReplaceInstruction(compare, MakeScalarLike(compare, true));
3603     }
3604   }
3605 
3606   if (compare->comparison_direction() == ComparisonDirection::kLt &&
3607       lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
3608     return ReplaceInstruction(compare, MakeScalarLike(compare, false));
3609   } else if (compare->comparison_direction() == ComparisonDirection::kGt &&
3610              IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
3611     return ReplaceInstruction(compare, MakeScalarLike(compare, false));
3612   } else if (compare->comparison_direction() == ComparisonDirection::kGe &&
3613              lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
3614     return ReplaceInstruction(compare, MakeScalarLike(compare, true));
3615   } else if (compare->comparison_direction() == ComparisonDirection::kLe &&
3616              IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
3617     return ReplaceInstruction(compare, MakeScalarLike(compare, true));
3618   }
3619   if (lhs == rhs &&
3620       primitive_util::IsIntegralType(lhs->shape().element_type())) {
3621     switch (compare->comparison_direction()) {
3622       case ComparisonDirection::kGt:
3623       case ComparisonDirection::kLt:
3624       case ComparisonDirection::kNe:
3625         return ReplaceInstruction(compare, MakeScalarLike(compare, false));
3626       case ComparisonDirection::kEq:
3627       case ComparisonDirection::kGe:
3628       case ComparisonDirection::kLe:
3629         return ReplaceInstruction(compare, MakeScalarLike(compare, true));
3630     }
3631   }
3632   return OkStatus();
3633 }
3634 
3635 Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
3636   PrimitiveType src_type = convert->operand(0)->shape().element_type();
3637   PrimitiveType dest_type = convert->shape().element_type();
3638   // A conversion to the same element type as the operand is a nop and can be
3639   // removed.  A conversion of a constant can be simplified by making a new
3640   // constant.
3641   if (src_type == dest_type) {
3642     return ReplaceInstruction(convert, convert->mutable_operand(0));
3643   }
3644 
3645   // Eliminate a convert pair if it is a no-op. The following are a few
3646   // example cases that are being handled:
3647   // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
3648   //    and convert(A, $TYPE1) is an upcast
3649   // 2. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of $TYPE2
3650   //    and convert(A, $TYPE1) is an upcast that is an integral conversion from
3651   //    unsigned to signed (a signed-to-unsigned conversion is NOT allowed)
3652   // 3. Tuple(convert(A, $TYPE1) , floor(convert(convert(A, $TYPE1), $TYPE2)),
3653   //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to Tuple(convert(A,
3654   //    $TYPE1) , floor(A), A) -> a case where the first convert has a
3655   //    fan-out
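  // E.g. (illustrative): convert(convert(f32 A, F64), F32) is replaced by A,
  // because f32 -> f64 is a lossless upcast.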
3656   if (IsConvertPairNoOp(convert)) {
3657     return ReplaceInstruction(convert,
3658                               convert->mutable_operand(0)->mutable_operand(0));
3659   }
3660   return OkStatus();
3661 }
3662 
3663 // Complex(Real(c), Imag(c)) -> c
3664 Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
3665   HloInstruction *c0, *c1;
3666   if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
3667       c0 == c1) {
3668     return ReplaceInstruction(complex, c0);
3669   }
3670   return OkStatus();
3671 }
3672 
3673 // Real(Complex(r, i)) -> r
3674 Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
3675   HloInstruction* op;
3676   if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
3677     return ReplaceInstruction(real, op);
3678   }
3679   return OkStatus();
3680 }
3681 
3682 // Imag(Complex(r, i)) -> i
3683 Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
3684   HloInstruction* op;
3685   if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
3686     return ReplaceInstruction(imag, op);
3687   }
3688   return OkStatus();
3689 }
3690 
3691 Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
3692   // iota -> zero if the iota dimension never produces an element other than
3693   // zero.
3694   auto* iota = Cast<HloIotaInstruction>(instruction);
3695   if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
3696     return ReplaceInstruction(iota, MakeScalarLike(iota, 0));
3697   }
3698   return OkStatus();
3699 }
3700 
3701 Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
3702   if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
3703     return ReplaceWithNewInstruction(
3704         pad, HloInstruction::CreateBroadcast(pad->shape(),
3705                                              pad->mutable_operand(1), {}));
3706   }
3707 
3708   // Interior padding on one-sized dimensions has no effect, so clearing it
3709   // makes other simplifications possible once no interior padding remains.
3710   if (HasInteriorPadding(pad->padding_config())) {
3711     PaddingConfig padding_config = pad->padding_config();
3712     bool cleared_interior_padding = false;
3713     for (int64_t i = 0; i < pad->shape().rank(); ++i) {
3714       if (padding_config.dimensions(i).interior_padding() > 0 &&
3715           pad->operand(0)->shape().dimensions(i) == 1) {
3716         cleared_interior_padding = true;
3717         padding_config.mutable_dimensions(i)->set_interior_padding(0);
3718       }
3719     }
3720     if (cleared_interior_padding) {
3721       return ReplaceWithNewInstruction(
3722           pad,
3723           HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
3724                                     pad->mutable_operand(1), padding_config));
3725     }
3726   }
3727 
3728   // Eliminate nop pads (padding all zero), and replace a pad with negative
3729   // padding with a pad with non-negative padding followed by a slice.
3730   bool all_zero = true;
3731   bool has_negative = false;
3732   // Used to possibly split off the unchanged padding dimensions.
3733   std::vector<int64_t> padding_dimensions;
3734   int64_t dimension_index = 0;
3735   for (auto& padding_dimension : pad->padding_config().dimensions()) {
3736     if (padding_dimension.edge_padding_low() < 0 ||
3737         padding_dimension.edge_padding_high() < 0) {
3738       has_negative = true;
3739     }
3740     if (padding_dimension.edge_padding_low() != 0 ||
3741         padding_dimension.edge_padding_high() != 0) {
3742       all_zero = false;
3743       padding_dimensions.push_back(dimension_index);
3744     } else if (padding_dimension.interior_padding()) {
3745       padding_dimensions.push_back(dimension_index);
3746     }
3747     dimension_index++;
3748   }
3749 
3750   if (all_zero) {
3751     if (ReplaceInstructionIfCompatible(pad, pad->mutable_operand(0))) {
3752       return OkStatus();
3753     }
3754   }
3755 
3756   // The context of this optimization can be found at b/163617402
3757   // It tries to capture the case of pad(broadcast(x)), where
3758   // x->shape().dimensions(), or broadcast(x)->dimensions(), is
3759   // a subset of the padded dimensions in pad->config(),
3760   // and the padded dimensions in pad->config() is in turn a strict
3761   // subset of broadcast->shape().dimensions(). The combined op can be
3762   // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends
3763   // x  with dimensions that need to be padded, and broadcast2 extends
3764   // the result of padding to full dimensions.
3765   // TODO(qyi): for future extensions: The condition for broadcast(x)
3766   // ->dimensions() to be a subset of padded dimensions in pad->config()
3767   // does not have to be strictly required, but it makes the calculation
3768   // for optimization easier, so it is required by the current implementation.
3769   // Only the second condition, between the padded dimensions and the
3770   // dimensions of the final shape, has to be enforced for the optimization
3771   // to make sense. If the first constraint is to be removed, the shape
3772   // calculations across the implementation need to be re-adjusted.
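  // A sketch of the rewrite with hypothetical shapes: for x = f32[3]
  // broadcast into f32[2,3,4] along dimension 1, with padding only on
  // dimension 1,
  //   pad(broadcast(x)) -> broadcast2(pad2(broadcast1(x)))
  // where broadcast1 produces f32[3] (only the padded dimension), pad2 pads
  // it, and broadcast2 expands the result back to the full padded shape.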
3773   auto pad_dims = padding_dimensions.size();
3774   if (pad_dims < dimension_index &&
3775       pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
3776       pad->operand(0)->user_count() == 1 &&
3777       pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
3778     // Check that the broadcast operand's dimensions are a subset of
3779     // padding_dimensions. If not, skip the optimization.
3780     bool opt_is_valid = true;
3781     std::vector<int64_t> broadcast_dimensions;
3782     HloBroadcastInstruction* broadcast =
3783         static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
3784     for (auto broadcast_index : broadcast->dimensions()) {
3785       bool found = false;
3786       for (int i = 0; i < pad_dims; ++i) {
3787         if (broadcast_index == padding_dimensions[i]) {
3788           broadcast_dimensions.push_back(i);
3789           found = true;
3790           break;
3791         }
3792       }
3793       if (!found) {
3794         opt_is_valid = false;
3795         break;
3796       }
3797     }
3798     if (opt_is_valid) {
3799       auto pad_shape = pad->shape();
3800       auto broadcast_shape = broadcast->shape();
3801       auto pad_shape1 = pad_shape;
3802       auto broadcast_shape1 = broadcast_shape;
3803       PaddingConfig pad_config;
3804       for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
3805         int64_t j = padding_dimensions[i];
3806         while (--dimension_index > j) {
3807           broadcast_shape1.DeleteDimension(dimension_index);
3808           pad_shape1.DeleteDimension(dimension_index);
3809         }
3810       }
3811       while (--dimension_index >= 0) {
3812         broadcast_shape1.DeleteDimension(dimension_index);
3813         pad_shape1.DeleteDimension(dimension_index);
3814       }
3815       for (auto dimension_to_pad : padding_dimensions) {
3816         auto dimension = pad_config.add_dimensions();
3817         *dimension = pad->padding_config().dimensions(dimension_to_pad);
3818       }
3819       *broadcast->mutable_shape() = broadcast_shape1;
3820       *broadcast->mutable_dimensions() = broadcast_dimensions;
3821       simplifier_->UpdateLayout(broadcast->mutable_shape());
3822       auto pad2 = pad->AddInstruction(pad->CloneWithNewShape(pad_shape1));
3823       *pad2->mutable_padding_config() = pad_config;
3824       simplifier_->UpdateLayout(pad2->mutable_shape());
3825       auto broadcast2 = pad->AddInstruction(
3826           HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
3827       return ReplaceInstruction(pad, broadcast2);
3828     }
3829   }
3830 
3831   if (has_negative && options_.enable_negative_padding_replacement()) {
3832     // Pad has negative padding. Replace with a pad with the non-negative
3833     // padding followed by a slice which effectively performs the negative
3834     // padding.
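    // A minimal sketch: pad(x, edge_low=-2, edge_high=1) becomes
    // slice(pad(x, edge_low=0, edge_high=1), start=2), realizing the negative
    // edge padding as a slice.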
3835     // TODO(b/34628603): Add support for negative padding in the backends, or
3836     // change kPad semantics to disallow negative padding and use slice
3837     // instead.
3838 
3839     // First construct the padding config with non-negative entries and then
3840     // compute the shape of this new pad instruction.
3841     PaddingConfig nonzero_padding = pad->padding_config();
3842     for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
3843       PaddingConfig::PaddingConfigDimension* padding_dimension =
3844           nonzero_padding.mutable_dimensions(i);
3845       // Set negative padding to zero.
3846       if (padding_dimension->edge_padding_low() < 0) {
3847         padding_dimension->set_edge_padding_low(0);
3848       }
3849       if (padding_dimension->edge_padding_high() < 0) {
3850         padding_dimension->set_edge_padding_high(0);
3851       }
3852     }
3853 
3854     TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
3855                         MakePadHlo(pad->mutable_operand(0),
3856                                    pad->mutable_operand(1), nonzero_padding));
3857     // Copy the layout from the original pad instruction. The new pad and the
3858     // slice instruction should both have the same layout.
3859     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3860         pad->shape(), nonzero_pad->mutable_shape()));
3861     simplifier_->UpdateLayout(nonzero_pad->mutable_shape());
3862 
3863     // Second, construct the slice instruction to perform the negative
3864     // padding.
3865     std::vector<int64_t> start_indices;
3866     std::vector<int64_t> end_indices;
3867     std::vector<int64_t> strides;
3868     for (int64_t i = 0; i < pad->padding_config().dimensions_size(); ++i) {
3869       const PaddingConfig::PaddingConfigDimension& padding_dimension =
3870           pad->padding_config().dimensions(i);
3871       int64_t start = 0;
3872       if (padding_dimension.edge_padding_low() < 0) {
3873         start = -1 * padding_dimension.edge_padding_low();
3874       }
3875       int64_t end = nonzero_pad->shape().dimensions(i);
3876       if (padding_dimension.edge_padding_high() < 0) {
3877         end += padding_dimension.edge_padding_high();
3878       }
3879       start_indices.push_back(start);
3880       end_indices.push_back(end);
3881       strides.push_back(1);
3882     }
3883 
3884     TF_ASSIGN_OR_RETURN(
3885         HloInstruction * slice,
3886         MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
3887     TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
3888         pad->shape(), slice->mutable_shape()));
3889     simplifier_->UpdateLayout(slice->mutable_shape());
3890 
3891     // Verify that the slice shape matches the pad shape.
3892     auto equal = Shape::Equal();
3893     if (!options_.is_layout_sensitive()) {
3894       equal.IgnoreTilesInLayout();
3895     }
3896     TF_RET_CHECK(equal(slice->shape(), pad->shape()));
3897 
3898     return ReplaceInstruction(pad, slice);
3899   }
3900 
3901   return OkStatus();
3902 }
3903 
3904 Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
3905   VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
3906   HloInstruction *lhs, *rhs;
3907   CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
3908   if (IsAll(rhs, 0)) {
3909     return ReplaceInstruction(power, MakeScalarLike(power, 1));
3910   }
3911 
3912   VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
3913   if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(power, lhs)) {
3914     return OkStatus();
3915   }
3916 
3917   // pow(exp(A),B) => exp(A*B)
3918   HloInstruction *a, *b;
3919   if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
3920     auto a_times_b = power->AddInstruction(HloInstruction::CreateBinary(
3921         power->shape(), HloOpcode::kMultiply, a, b));
3922     return ReplaceWithNewInstruction(
3923         power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
3924                                            a_times_b));
3925   }
3926 
3927   VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
3928   if (IsAll(rhs, 2)) {
3929     return ReplaceWithNewInstruction(
3930         power, HloInstruction::CreateBinary(power->shape(),
3931                                             HloOpcode::kMultiply, lhs, lhs));
3932   }
3933 
3934   // Pow(A, 3) is used in GELU.
3935   VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
3936   if (IsAll(rhs, 3)) {
3937     HloInstruction* tmp = power->AddInstruction(HloInstruction::CreateBinary(
3938         power->shape(), HloOpcode::kMultiply, lhs, lhs));
3939     return ReplaceWithNewInstruction(
3940         power, HloInstruction::CreateBinary(power->shape(),
3941                                             HloOpcode::kMultiply, lhs, tmp));
3942   }
3943 
3944   VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
3945   if (IsAll(rhs, -1)) {
3946     return ReplaceWithNewInstruction(
3947         power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
3948                                             MakeScalarLike(lhs, 1), lhs));
3949   }
3950 
3951   return OkStatus();
3952 }
3953 
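// Tries to sink a broadcast past its elementwise users when every other
// operand is a compatible broadcast, e.g. (a sketch)
//   add(broadcast(f32[3] x), broadcast(f32[3] y)) -> broadcast(add(x, y))
// so that the elementwise op runs on the smaller, pre-broadcast shape.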
3954 StatusOr<bool>
3955 AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
3956     HloInstruction* broadcast) {
3957   TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
3958   bool changed = false;
3959   if (ShapeUtil::IsScalar(broadcast->shape())) {
3960     return false;
3961   }
3962   HloInstruction* operand = broadcast->mutable_operand(0);
3963   auto is_scalar_broadcast = [](const HloInstruction* instruction) {
3964     return instruction->opcode() == HloOpcode::kBroadcast &&
3965            ShapeUtil::IsScalar(instruction->operand(0)->shape());
3966   };
3967   auto is_equal_broadcast = [operand,
3968                              broadcast](const HloInstruction* instruction) {
3969     return instruction->opcode() == HloOpcode::kBroadcast &&
3970            ShapeUtil::Equal(operand->shape(),
3971                             instruction->operand(0)->shape()) &&
3972            broadcast->dimensions() == instruction->dimensions();
3973   };
3974   auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
3975     return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
3976   };
3977   for (HloInstruction* user : broadcast->users()) {
3978     if (user->IsDead()) {
3979       continue;
3980     }
3981     // Do not move reshapes or broadcasts past copies since the shape the copy
3982     // will operate on will change.
3983     if (user->opcode() == HloOpcode::kCopy) {
3984       continue;
3985     }
3986     // Do not change the shape of fusion nodes in case there are multiple
3987     // shapes inside the fusion node already.
3988     if (user->opcode() == HloOpcode::kFusion) {
3989       continue;
3990     }
3991     if (!user->IsElementwise()) {
3992       continue;
3993     }
3994 
3995     // Check if all the operands of the user are compatible broadcasts for
3996     // sinking. (They are either scalar broadcasts or broadcasts casting
3997     // from/to the same shape/dimensions)
3998     int64_t compatible_broadcast_count = 0;
3999     int64_t broadcast_use_count = 0;
4000     for (HloInstruction* user_operand : user->operands()) {
4001       if (is_compatible_broadcast(user_operand)) {
4002         ++compatible_broadcast_count;
4003       } else if (broadcast == user_operand) {
4004         ++broadcast_use_count;
4005       }
4006     }
4007     if (compatible_broadcast_count + broadcast_use_count !=
4008         user->operand_count()) {
4009       continue;
4010     }
4011     std::vector<HloInstruction*> new_operands;
4012     new_operands.reserve(user->operand_count());
4013 
4014     Shape changed_shape;
4015     for (HloInstruction* user_operand : user->operands()) {
4016       // If this is a broadcast operand that is not our original broadcast input
4017       // to this function then we might need to change the input.
4018       if (is_compatible_broadcast(user_operand)) {
4019         // If this is a broadcast from a scalar value rewrite a broadcast from
4020         // the scalar to the new shape enforced from the other broadcast
4021         // operands.
4022         if (is_scalar_broadcast(user_operand)) {
4023           changed_shape = ShapeUtil::ChangeElementType(
4024               operand->shape(), user_operand->shape().element_type());
4025           simplifier_->UpdateLayout(&changed_shape);
4026           new_operands.push_back(
4027               user_operand->AddInstruction(HloInstruction::CreateBroadcast(
4028                   changed_shape, user_operand->mutable_operand(0), {})));
4029         } else {
4030           // For the non-scalar broadcasts, the operand of the broadcast is
4031           // guaranteed to already have a compatible shape.
4032           new_operands.push_back(user_operand->mutable_operand(0));
4033         }
4034       } else {
4035         CHECK_EQ(broadcast, user_operand);
4036         new_operands.push_back(operand);
4037       }
4038     }
4039     VLOG(4) << "Sinking broadcast after user:";
4040     VLOG(4) << "  old broadcast: " << broadcast->ToString();
4041     VLOG(4) << "  old user: " << user->ToString();
4042     changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
4043                                                  user->shape().element_type());
4044     simplifier_->UpdateLayout(&changed_shape);
4045     HloInstruction* new_user = user->AddInstruction(
4046         user->CloneWithNewOperands(changed_shape, new_operands));
4047     VLOG(4) << "  new user: " << new_user->ToString();
4048     HloInstruction* new_broadcast =
4049         broadcast->AddInstruction(HloInstruction::CreateBroadcast(
4050             user->shape(), new_user, broadcast->dimensions()));
4051     VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
4052     TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
4053     changed = true;
4054   }
4055   return changed;
4056 }
4057 
4058 namespace {
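// Tries to rewrite an integral remainder whose divisor is a (possibly
// broadcast) constant power-of-two scalar as a bitwise AND, with extra sign
// handling for signed element types; returns nullptr when the pattern does
// not apply.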
4059 template <typename T>
4060 std::unique_ptr<HloInstruction> TryRemainderToAnd(
4061     HloInstruction* remainder, HloComputation* computation,
4062     AlgebraicSimplifier* simplifier) {
4063   HloInstruction *a, *b, *c;
4064   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
4065 
4066   if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
4067       !Match(b, m::ConstantEffectiveScalar(&c)) &&
4068       !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
4069     return nullptr;
4070   }
4071 
4072   if (ShapeUtil::ElementIsSigned(remainder->shape())) {
4073     int64_t b_value = c->literal().GetFirstElement<T>();
4074     if (b_value > 0 && absl::has_single_bit(static_cast<uint64_t>(b_value))) {
4075       // Handle negative dividends by masking |a| and negating the result.
4076       HloInstruction* zero_like_a = BroadcastZeros(
4077           computation, a->shape().element_type(), a->shape().dimensions());
4078 
4079       Shape compare_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
4080       simplifier->UpdateLayout(&compare_shape);
4081       auto* dividend_is_negative =
4082           remainder->AddInstruction(HloInstruction::CreateCompare(
4083               compare_shape, a, zero_like_a, ComparisonDirection::kLt));
4084 
4085       auto* negated_dividend = remainder->AddInstruction(
4086           HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));
4087 
4088       auto* abs_dividend =
4089           remainder->AddInstruction(HloInstruction::CreateTernary(
4090               a->shape(), HloOpcode::kSelect, dividend_is_negative,
4091               negated_dividend, a));
4092 
4093       auto* quotient = remainder->AddInstruction(HloInstruction::CreateBinary(
4094           remainder->shape(), HloOpcode::kAnd, abs_dividend,
4095           MakeScalarLike(abs_dividend, b_value - 1)));
4096 
4097       auto* negated_quotient =
4098           remainder->AddInstruction(HloInstruction::CreateUnary(
4099               quotient->shape(), HloOpcode::kNegate, quotient));
4100 
4101       return HloInstruction::CreateTernary(
4102           remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
4103           negated_quotient, quotient);
4104     }
4105   } else {
4106     uint64_t b_value = c->literal().GetFirstElement<T>();
4107     if (absl::has_single_bit(b_value)) {
4108       HloInstruction* mask_amount =
4109           remainder->AddInstruction(simplifier->CreateConstantWithLayoutUpdated(
4110               LiteralUtil::CreateR0<T>(b_value - 1)));
4111       if (!ShapeUtil::IsScalar(b->shape())) {
4112         mask_amount = remainder->AddInstruction(
4113             HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
4114       }
4115       return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
4116                                           a, mask_amount);
4117     }
4118   }
4119   return nullptr;
4120 }
4121 }  // namespace
4122 
4123 Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
4124   HloInstruction *a, *b;
4125   CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
4126 
4127   // (A % B) % B == A % B.
4128   if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
4129     return ReplaceInstruction(remainder, a);
4130   }
4131 
4132   // A % B => A & (B - 1) if B is a power of 2.
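  // A sketch of the rewrite: for u32, x % 8 == x & 7; for a signed type the
  // helper above also handles the sign, roughly
  //   select(x < 0, negate(negate(x) & 7), x & 7).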
4133   switch (remainder->shape().element_type()) {
4134     case S8:
4135       if (std::unique_ptr<HloInstruction> shift =
4136               TryRemainderToAnd<int8_t>(remainder, computation_, simplifier_)) {
4137         return ReplaceWithNewInstruction(remainder, std::move(shift));
4138       }
4139       break;
4140     case S16:
4141       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<int16_t>(
4142               remainder, computation_, simplifier_)) {
4143         return ReplaceWithNewInstruction(remainder, std::move(shift));
4144       }
4145       break;
4146     case S32:
4147       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<int32_t>(
4148               remainder, computation_, simplifier_)) {
4149         return ReplaceWithNewInstruction(remainder, std::move(shift));
4150       }
4151       break;
4152     case S64:
4153       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<int64_t>(
4154               remainder, computation_, simplifier_)) {
4155         return ReplaceWithNewInstruction(remainder, std::move(shift));
4156       }
4157       break;
4158     case U8:
4159       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint8_t>(
4160               remainder, computation_, simplifier_)) {
4161         return ReplaceWithNewInstruction(remainder, std::move(shift));
4162       }
4163       break;
4164     case U16:
4165       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint16_t>(
4166               remainder, computation_, simplifier_)) {
4167         return ReplaceWithNewInstruction(remainder, std::move(shift));
4168       }
4169       break;
4170     case U32:
4171       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint32_t>(
4172               remainder, computation_, simplifier_)) {
4173         return ReplaceWithNewInstruction(remainder, std::move(shift));
4174       }
4175       break;
4176     case U64:
4177       if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint64_t>(
4178               remainder, computation_, simplifier_)) {
4179         return ReplaceWithNewInstruction(remainder, std::move(shift));
4180       }
4181       break;
4182     default:
4183       break;
4184   }
4185 
4186   // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
4187   //
4188   // Currently this only covers the case when N is a broadcasted constant
4189   // scalar.  We could also cover the case when N is a non-broadcasted constant
4190   // with the same value repeated.
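  // For example (a sketch): remainder(iota(s32[5]), broadcast(7)) simplifies
  // to iota(s32[5]), since every iota element is already less than 7.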
4191   HloInstruction* iota;
4192   HloInstruction* divisor;
4193   if (Match(remainder,
4194             m::Remainder(m::Iota(&iota),
4195                          m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
4196     // The iota counts {0, ..., iota_upper_bound - 1}.  (Actually this is
4197     // conservative; the iota may overflow and count up to a smaller value than
4198     // this.  But that's OK for our purposes here.)
4199     int64_t iota_upper_bound = iota->shape().dimensions(
4200         Cast<HloIotaInstruction>(iota)->iota_dimension());
4201     std::optional<int64_t> divisor_val = divisor->literal().GetIntegralAsS64(
4202         std::vector<int64_t>(divisor->shape().dimensions_size(), 0));
4203     if (divisor_val && *divisor_val >= iota_upper_bound) {
4204       return ReplaceInstruction(remainder, iota);
4205     }
4206   }
4207 
4208   // (X + N) % N = X % N, so long as X + N does not overflow.
4209   //
4210   // We don't have range tracking in XLA that would let us know whether X + N
4211   // overflows, so for now we only do this simplification when X is an iota.  We
4212   // could add other operations where it's easy to see a range, such as
4213   // remainder, convert, etc., though at some point we'd probably want a
4214   // range-tracking analysis.
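  // For example (a sketch): (iota(s32[8]) + 5) % 5 can be rewritten to
  // iota(s32[8]) % 5, because 5 + 8 still fits in s32, so X + N cannot
  // overflow.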
4215   HloInstruction* bcast;
4216   HloInstruction* addend;
4217   if (Match(
4218           remainder,
4219           m::Remainder(
4220               m::AddAnyOrder(m::Iota(&iota),
4221                              m::Broadcast(m::ConstantEffectiveScalar(&addend))),
4222               m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
4223       addend == divisor) {
4224     // The iota counts {0, ..., iota_upper_bound - 1}, with the same caveat above
4225     // that iota_upper_bound is conservative, and the true upper bound may be
4226     // smaller.
4227     int64_t iota_upper_bound = iota->shape().dimensions(
4228         Cast<HloIotaInstruction>(iota)->iota_dimension());
4229     std::optional<int64_t> divisor_val = divisor->literal().GetIntegralAsS64(
4230         std::vector<int64_t>(divisor->shape().dimensions_size(), 0));
4231     if (divisor_val) {
4232       // Check whether divisor_val + iota_upper_bound, a bound on X + N, overflows.
4233       std::optional<int64_t> max_val =
4234           OverflowSafeAdd(*divisor_val, iota_upper_bound);
4235       if (max_val.has_value() &&
4236           FitsInIntegralType(*max_val, iota->shape().element_type())) {
4237         return ReplaceWithNewInstruction(
4238             remainder,
4239             HloInstruction::CreateBinary(remainder->shape(),
4240                                          HloOpcode::kRemainder, iota, bcast));
4241       }
4242     }
4243   }
4244 
4245   return OkStatus();
4246 }
4247 
4248 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
4249   auto operand = reshape->mutable_operand(0);
4250 
4251   // Reshape directly to an empty constant if the shape contains a
4252   // zero-element dimension.
4253   if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
4254     // If the instruction doesn't have a layout, use a default layout for
4255     // the literal result.
4256     Shape reshaped_shape = reshape->shape();
4257     if (!LayoutUtil::HasLayout(reshaped_shape)) {
4258       LayoutUtil::SetToDefaultLayout(&reshaped_shape);
4259     }
4260     auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
4261         Literal::CreateFromShape(reshaped_shape));
4262 
4263     return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
4264   }
4265 
4266   // Delete no-op reshapes, i.e. where shape = operand shape.
4267   if (SameShape(reshape, operand)) {
4268     VLOG(3) << "deleting no-op reshape";
4269     return ReplaceInstruction(reshape, operand);
4270   }
4271 
4272   // Merge reshapes.
4273   if (HloOpcode::kReshape == operand->opcode()) {
4274     return ReplaceWithNewInstruction(
4275         reshape, HloInstruction::CreateReshape(reshape->shape(),
4276                                                operand->mutable_operand(0)));
4277   }
4278 
4279   if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
4280     *operand->mutable_shape() = reshape->shape();
4281     return ReplaceInstruction(reshape, operand);
4282   }
4283 
4284   if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
4285     auto opt_dims = ReshapeLeavesDimensionsUnmodified(
4286         reshape, reshape->operand(0)->dimensions());
4287     if (opt_dims.has_value()) {
4288       return ReplaceWithNewInstruction(
4289           reshape,
4290           HloInstruction::CreateBroadcast(
4291               reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
4292               *opt_dims));
4293     }
4294   }
4295 
4296   // reshape(iota) -> iota or a mixed radix calculation like
4297   // s32[2,3,4] reshape(s32[24] iota()) to
4298   // add(
4299   //    add(s32[2,3,4] iota() iota_dimension=2,
4300   //        4 * s32[2,3,4] iota() iota_dimension=1),
4301   //    12 * s32[2,3,4] iota() iota_dimension=0).
4302   if (operand->opcode() == HloOpcode::kIota) {
4303     auto* iota = Cast<HloIotaInstruction>(operand);
4304     auto common_factors =
4305         CommonFactors(reshape->operand(0)->shape().dimensions(),
4306                       reshape->shape().dimensions());
4307     auto iota_dim = absl::c_find_if(
4308         common_factors, [&](const std::pair<int64_t, int64_t>& dim_pair) {
4309           return dim_pair.first == iota->iota_dimension() &&
4310                  reshape->shape().dimensions(dim_pair.second) > 1;
4311         });
4312     auto next_dim = absl::c_find_if(
4313         common_factors, [&](const std::pair<int64_t, int64_t>& dim_pair) {
4314           return dim_pair.first == iota->iota_dimension() + 1;
4315         });
4316     if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
4317       int64_t multiplier = 1;
4318       HloInstruction* new_reshape = nullptr;
4319 
4320       for (int64_t dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
4321            --dim) {
4322         HloInstruction* new_iota = iota->AddInstruction(
4323             HloInstruction::CreateIota(reshape->shape(), dim));
4324         if (new_reshape) {
4325           new_reshape = reshape->AddInstruction(HloInstruction::CreateBinary(
4326               reshape->shape(), HloOpcode::kAdd, new_reshape,
4327               reshape->AddInstruction(HloInstruction::CreateBinary(
4328                   reshape->shape(), HloOpcode::kMultiply, new_iota,
4329                   MakeScalarLike(reshape, multiplier)))));
4330         } else {
4331           new_reshape = new_iota;
4332         }
4333         multiplier *= reshape->shape().dimensions(dim);
4334       }
4335       return ReplaceInstruction(reshape, new_reshape);
4336     }
4337   }
4338 
4339   // Moves the reshape in reshape(dus(...), x, ...)) before dus so that it can
4340   // enable other optimizations, e.g., merging with broadcast, and sparse update
4341   // (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
4342   if (!options_.is_layout_sensitive()) {
4343     HloInstruction* dus;
4344     HloInstruction* slice;
4345     std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
4346         reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
4347     // Dimensions added or removed by the reshape are 1-sized in both the
4348     // update slice and the dynamic-update-slice result.
4349     if (trivial_reshape.has_value() &&
4350         Match(reshape->mutable_operand(0),
4351               m::Op(&dus)
4352                   .WithOpcode(HloOpcode::kDynamicUpdateSlice)
4353                   .WithOperand(1, m::Op(&slice))) &&
4354         !dus->has_sharding() && !dus->operand(0)->has_sharding()) {
4355       auto new_operand = reshape->AddInstruction(HloInstruction::CreateReshape(
4356           reshape->shape(), dus->mutable_operand(0)));
4357       std::vector<int64_t> new_slice_shape;
4358       std::vector<HloInstruction*> new_dus_operands;
4359       new_dus_operands.push_back(new_operand);
4360       new_dus_operands.push_back(nullptr);
4361       auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
4362       const Shape& old_slice_shape = dus->operand(1)->shape();
4363       for (int64_t i = 0; i <= old_slice_shape.rank(); ++i) {
4364         if (absl::c_linear_search(trivial_reshape->deleted_dimensions, i)) {
4365           continue;
4366         }
4367         while (absl::c_linear_search(trivial_reshape->inserted_dimensions,
4368                                      new_slice_shape.size())) {
4369           new_slice_shape.push_back(1);
4370           new_dus_operands.push_back(zero);
4371         }
4372         if (i < old_slice_shape.rank()) {
4373           new_slice_shape.push_back(old_slice_shape.dimensions(i));
4374           new_dus_operands.push_back(dus->mutable_operand(2 + i));
4375         }
4376       }
4377       auto new_slice = reshape->AddInstruction(HloInstruction::CreateReshape(
4378           ShapeUtil::MakeShape(old_slice_shape.element_type(), new_slice_shape),
4379           slice));
4380       new_dus_operands[1] = new_slice;
4381       auto new_dus =
4382           dus->CloneWithNewOperands(reshape->shape(), new_dus_operands);
4383       return ReplaceWithNewInstruction(reshape, std::move(new_dus));
4384     }
4385   }
4386 
4387   // Make this a bitcast if possible.
4388   if (HloInstruction* bitcast_operand =
4389           BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
4390     ReplaceWithBitcast(reshape, bitcast_operand);
4391   }
4392   return OkStatus();
4393 }
4394 
4395 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
4396   // When all the dimensions to reverse are trivial (i.e. the bound is 1),
4397   // there is nothing to be done.
4398   auto dim_is_one = [&](int64_t i) -> bool {
4399     return reverse->shape().dimensions(i) == 1;
4400   };
4401   if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
4402     return ReplaceInstruction(reverse, reverse->mutable_operand(0));
4403   }
4404   return OkStatus();
4405 }
4406 
4407 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
4408     HloInstruction* slice) {
4409   // Only try to do this for effective scalars. We could do the same for slicing
4410   // out larger pieces of padding (replacing with a broadcast of the padding
4411   // value), but this is probably not worth it.
4412   if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
4413     return false;
4414   }
4415 
4416   if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
4417     VLOG(10) << "Trying to simplify scalar slice of concat";
4418     // Only do this for R1, there's no chance of this being useful otherwise.
4419     if (slice->shape().rank() != 1) {
4420       VLOG(10) << "Not folding, slice is not rank 1";
4421       return false;
4422     }
4423     HloConcatenateInstruction* concat =
4424         Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
4425     int64_t operand_start = 0;
4426     int64_t operand_num = 0;
4427     // Weird loop structure to avoid annoying off-by-one errors.
4428     while (true) {
4429       TF_RET_CHECK(operand_num < concat->operand_count());
4430       const HloInstruction* operand = concat->operand(operand_num);
4431       int64_t next_operand_start =
4432           operand_start + operand->shape().dimensions(0);
4433       if (next_operand_start > slice->slice_starts(0)) {
4434         break;
4435       }
4436       operand_start = next_operand_start;
4437       operand_num++;
4438     }
4439 
4440     bool replaced = ReplaceInstructionIfCompatible(
4441         slice, concat->mutable_operand(operand_num));
4442     if (replaced) {
4443       VLOG(10) << "Folding scalar slice of concat into concat operand";
4444     } else {
4445       VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
4446       TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4447           slice, HloInstruction::CreateSlice(
4448                      slice->shape(), concat->mutable_operand(operand_num),
4449                      {slice->slice_starts(0) - operand_start},
4450                      {slice->slice_starts(0) - operand_start + 1},
4451                      slice->slice_strides())));
4452     }
4453     return true;
4454   }
4455 
4456   return false;
4457 }
4458 
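// Tries to rewrite slice(reshape(x)) as reshape(slice(x)) when only the
// major-most dimension is sliced starting at offset 0, e.g. (a sketch)
//   slice(reshape(f32[12] x into f32[4,3]), rows [0,2))
//     -> reshape(slice(x, [0,6)), f32[2,3]).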
4459 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
4460     HloInstruction* slice) {
4461   CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
4462   if (!IsUnstridedSlice(slice)) {
4463     return false;
4464   }
4465   HloInstruction* reshape = slice->mutable_operand(0);
4466   if (reshape->opcode() != HloOpcode::kReshape) {
4467     return false;
4468   }
4469   HloInstruction* new_slice_operand = reshape->mutable_operand(0);
4470   int64_t slice_rank = slice->shape().rank();
4471   std::vector<int64_t> sliced_dims;
4472   for (int64_t i = 0; i < slice_rank; ++i) {
4473     if (slice->slice_starts(i) != 0 ||
4474         slice->slice_limits(i) != reshape->shape().dimensions(i)) {
4475       sliced_dims.push_back(i);
4476     }
4477   }
4478 
4479   if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
4480       slice->slice_starts(0) == 0) {
4481     const Shape& new_slice_shape = new_slice_operand->shape();
4482     const int64_t rank = new_slice_shape.rank();
4483     std::vector<int64_t> new_slice_starts(rank, 0);
4484     std::vector<int64_t> new_slice_strides(rank, 1);
4485     std::vector<int64_t> new_slice_limits(new_slice_shape.dimensions().begin(),
4486                                           new_slice_shape.dimensions().end());
4487     int64_t slice_elements = ShapeUtil::ElementsIn(slice->shape());
4488     for (int64_t i = rank - 1; i >= 0; --i) {
4489       if (slice_elements >= new_slice_limits[i]) {
4490         if (slice_elements % new_slice_limits[i] != 0) {
4491           return false;
4492         }
4493         slice_elements /= new_slice_limits[i];
4494       } else {
4495         new_slice_limits[i] = slice_elements;
4496         slice_elements = 1;
4497       }
4498     }
4499     HloInstruction* new_slice =
4500         slice->AddInstruction(HloInstruction::CreateSlice(
4501             ShapeUtil::MakeShape(new_slice_shape.element_type(),
4502                                  new_slice_limits),
4503             new_slice_operand, new_slice_starts, new_slice_limits,
4504             new_slice_strides));
4505     simplifier_->UpdateLayout(new_slice->mutable_shape());
4506     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4507         slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
4508     return true;
4509   }
4510   return false;
4511 }
4512 
4513 // Allows a slice to move through a reverse, with any necessary updates to the
4514 // slice config.
4515 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
4516     HloInstruction* slice) {
4517   VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
4518           << slice->ToString();
4519   if (Match(slice, m::Slice(m::Reverse()))) {
4520     HloInstruction* reverse = slice->mutable_operand(0);
4521     HloInstruction* reverse_operand = reverse->mutable_operand(0);
4522     std::vector<int64_t> new_starts = slice->slice_starts();
4523     std::vector<int64_t> new_limits = slice->slice_limits();
4524     std::vector<int64_t> new_strides = slice->slice_strides();
4525     for (auto rdim : reverse->dimensions()) {
4526       int64_t start = slice->slice_starts(rdim);
4527       int64_t limit = slice->slice_limits(rdim);
4528       int64_t stride = slice->slice_strides(rdim);
4529       // find_nth allows us to compute the appropriate index to begin
4530       // with during reverse, even in the presence of non-unit strides.
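      // For example, start=0, limit=7, stride=3 selects {0, 3, 6}; find_nth
      // is then 6, the adjusted limit is 7, and the reversed window is
      // derived from that last selected element.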
4531       int64_t find_nth = (limit - start - 1) / stride;
4532       find_nth = start + find_nth * stride;
4533       limit = find_nth + 1;
4534       new_starts[rdim] =
4535           (reverse->shape().dimensions(rdim) - start) - (limit - start);
4536       new_limits[rdim] = reverse->shape().dimensions(rdim) - start;
4537       VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << ","
4538               << limit << " and new (start, limit):" << new_starts[rdim] << ","
4539               << new_limits[rdim];
4540     }
4541     // New slice formed from the reverse_operand, but strides and shape of the
4542     // slice output remain the same. New slice's starts and limits are updated
4543     // for ONLY the reversed dimensions as indicated above.
4544     HloInstruction* new_slice = slice->AddInstruction(
4545         HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
4546                                     new_limits, new_strides));
4547     simplifier_->UpdateLayout(new_slice->mutable_shape());
4548     TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4549         slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice,
4550                                              reverse->dimensions())));
4551     // We do not delete the old reverse, since there might be another
4552     // consumer of that reverse (i.e., full reverse output). DCE should take
4553     // care of any deletion that is necessary if there was no use of reverse.
4554     return true;
4555   }
4556   return false;
4557 }
4558 
4559 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
4560   // Delete no-op slices, i.e. where shape = operand shape.
4561   if (ReplaceInstructionIfCompatible(slice, slice->mutable_operand(0))) {
4562     return OkStatus();
4563   }
4564 
4565   HloInstruction* pad;
4566   HloInstruction* pad_operand;
4567   if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
4568     // Is the result of the slice the pad operand?
4569     bool slice_undoes_pad = true;
4570     // Can the slice be moved to the pad_operand without any padding being read?
4571     bool slice_inside_pad = true;
4572     // Does this slice slice out padding only?
4573     bool slice_in_padding = false;
4574     std::vector<int64_t> new_starts = slice->slice_starts();
4575     std::vector<int64_t> new_limits = slice->slice_limits();
4576     for (int64_t i = 0; i < slice->shape().rank(); ++i) {
4577       const int64_t start = slice->slice_starts(i);
4578       const int64_t stride = slice->slice_strides(i);
4579       const int64_t limit = slice->slice_limits(i);
4580       const int64_t size = pad->shape().dimensions(i);
4581 
4582       const auto& dim = pad->padding_config().dimensions(i);
4583       const int64_t low = dim.edge_padding_low();
4584       const int64_t high = dim.edge_padding_high();
4585       const int64_t interior = dim.interior_padding();
4586       const int64_t edge = size - high;
4587 
4588       if (limit <= low || start >= edge) {
4589         slice_in_padding = true;
4590         break;
4591       }
4592 
4593       if (start != low || stride - 1 != interior) {
4594         slice_undoes_pad = false;
4595       }
4596 
4597       if (start < low || limit > edge || interior != 0 || stride != 1) {
4598         slice_inside_pad = false;
4599       }
4600       new_starts[i] -= low;
4601       new_limits[i] -= low;
4602     }
4603     if (slice_in_padding) {
4604       HloInstruction* broadcast =
4605           MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
4606       *(broadcast->mutable_shape()) = slice->shape();
4607       return ReplaceInstruction(slice, broadcast);
4608     }
4609     if (slice_undoes_pad &&
4610         ReplaceInstructionIfCompatible(slice, pad_operand)) {
4611       return OkStatus();
4612     }
4613     if (slice_inside_pad) {
4614       TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
4615                           MakeSliceHlo(pad_operand, new_starts, new_limits,
4616                                        slice->slice_strides()));
4617       *(new_slice->mutable_shape()) = slice->shape();
4618       return ReplaceInstruction(slice, new_slice);
4619     }
4620   }
4621 
4622   if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
4623       IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
4624     HloInstruction* operand_slice = slice->mutable_operand(0);
4625     std::vector<int64_t> new_slice_starts = slice->slice_starts();
4626     std::vector<int64_t> new_slice_limits = slice->slice_limits();
4627     for (int64_t i = 0; i < new_slice_starts.size(); ++i) {
4628       new_slice_starts[i] += operand_slice->slice_starts(i);
4629       new_slice_limits[i] += operand_slice->slice_starts(i);
4630     }
4631     return ReplaceWithNewInstruction(
4632         slice, HloInstruction::CreateSlice(
4633                    slice->shape(), operand_slice->mutable_operand(0),
4634                    new_slice_starts, new_slice_limits, slice->slice_strides()));
4635   }
4636 
4637   auto only_broadcast_dims_sliced = [&] {
4638     if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
4639       return false;
4640     }
4641     for (int64_t dim : slice->operand(0)->dimensions()) {
4642       if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
4643           slice->slice_limits(dim) !=
4644               slice->operand(0)->shape().dimensions(dim)) {
4645         return false;
4646       }
4647     }
4648     return true;
4649   };
4650   if (only_broadcast_dims_sliced()) {
4651     return ReplaceWithNewInstruction(
4652         slice,
4653         HloInstruction::CreateBroadcast(
4654             slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
4655             slice->mutable_operand(0)->dimensions()));
4656   }
4657 
4658   TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
4659   if (replaced) {
4660     return OkStatus();
4661   }
4662 
4663   HloInstruction* broadcast;
4664   HloInstruction* broadcast_operand;
4665   if (Match(slice,
4666             m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) {
4667     std::vector<int64_t> new_slice_starts;
4668     std::vector<int64_t> new_slice_strides;
4669     std::vector<int64_t> new_slice_limits;
4670     new_slice_starts.reserve(broadcast_operand->shape().rank());
4671     new_slice_strides.reserve(broadcast_operand->shape().rank());
4672     new_slice_limits.reserve(broadcast_operand->shape().rank());
4673     for (int64_t dim : broadcast->dimensions()) {
4674       new_slice_starts.push_back(slice->slice_starts(dim));
4675       new_slice_strides.push_back(slice->slice_strides(dim));
4676       new_slice_limits.push_back(slice->slice_limits(dim));
4677     }
4678     VLOG(3) << "Sink broadcast through slice";
4679     VLOG(3) << "Original slice: " << slice->ToString();
4680     VLOG(3) << "Original broadcast: " << broadcast->ToString();
4681     auto new_slice_shape = broadcast_operand->shape();
4682     for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
4683       int64_t size_i = (new_slice_limits[i] - new_slice_starts[i] +
4684                         new_slice_strides[i] - 1) /
4685                        new_slice_strides[i];
4686       new_slice_shape.set_dimensions(i, size_i);
4687     }
4688     simplifier_->UpdateLayout(&new_slice_shape);
4689     auto new_slice = slice->AddInstruction(HloInstruction::CreateSlice(
4690         new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits,
4691         new_slice_strides));
4692     auto new_broadcast =
4693         broadcast->AddInstruction(HloInstruction::CreateBroadcast(
4694             slice->shape(), new_slice, broadcast->dimensions()));
4695     VLOG(3) << "New slice: " << slice->ToString();
4696     VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4697     return ReplaceInstruction(slice, new_broadcast);
4698   }
4699 
4700   // Try to simplify concat -> slice to an operand of concat.
4701   if (slice->operand(0)->opcode() == HloOpcode::kConcatenate &&
4702       IsUnstridedSlice(slice)) {
4703     HloInstruction* concat = slice->mutable_operand(0);
4704     int64_t concat_dim = concat->concatenate_dimension();
4705     int64_t piece_start = 0;
4706     std::optional<int64_t> start_operand;
4707     std::optional<int64_t> limit_operand;
4708     int64_t concat_start;
4709     int64_t concat_limit;
4710     const int64_t slice_start = slice->slice_starts(concat_dim);
4711     const int64_t slice_limit = slice->slice_limits(concat_dim);
4712     for (int64_t i = 0; i < concat->operand_count(); ++i) {
4713       const HloInstruction* piece = concat->operand(i);
4714       const int64_t piece_size = piece->shape().dimensions(concat_dim);
4715       if (!start_operand && piece_start <= slice_start &&
4716           piece_size + piece_start > slice_start) {
4717         start_operand = i;
4718         concat_start = piece_start;
4719       }
4720       piece_start += piece_size;
4721       if (!limit_operand && piece_start >= slice_limit) {
4722         limit_operand = i + 1;
4723         concat_limit = piece_start;
4724         break;
4725       }
4726     }
4727     if (start_operand && limit_operand &&
4728         *start_operand + 1 == *limit_operand &&
4729         SameShape(concat->operand(*start_operand), slice)) {
4730       return ReplaceInstruction(slice, concat->mutable_operand(*start_operand));
4731     }
4732     if (start_operand && limit_operand &&
4733         *limit_operand - *start_operand < concat->operand_count()) {
4734       std::vector<int64_t> starts = slice->slice_starts();
4735       starts[concat_dim] = starts[concat_dim] - concat_start;
4736       std::vector<int64_t> strides = slice->slice_strides();
4737       std::vector<int64_t> limits = slice->slice_limits();
4738       limits[concat_dim] =
4739           starts[concat_dim] + slice->shape().dimensions(concat_dim);
4740       HloInstruction* operand = concat->mutable_operand(*start_operand);
4741       if (*start_operand + 1 != *limit_operand) {
4742         TF_ASSIGN_OR_RETURN(
4743             HloInstruction * new_concat,
4744             MakeConcatHlo(
4745                 absl::MakeSpan(concat->operands())
4746                     .subspan(*start_operand, *limit_operand - *start_operand),
4747                 concat_dim));
4748         *new_concat->mutable_shape()->mutable_layout() =
4749             concat->shape().layout();
4750         simplifier_->UpdateLayout(new_concat->mutable_shape());
4751         concat->SetupDerivedInstruction(new_concat);
4752         operand = new_concat;
4753       }
4754       return ReplaceWithNewInstruction(
4755           slice, HloInstruction::CreateSlice(slice->shape(), operand, starts,
4756                                              limits, strides));
4757     }
4758   }
4759 
4760   // Do not try to reorder slices and reshapes after layout assignment as it may
4761   // be invalid.
4762   if (!options_.is_layout_sensitive()) {
4763     TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
4764   }
4765   if (replaced) {
4766     return OkStatus();
4767   }
4768 
4769   bool reversed = false;
4770   if (Match(slice, m::Slice(m::Reverse(m::Op())))) {
4771     TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice));
4772   }
4773   if (reversed) {
4774     return OkStatus();
4775   }
4776 
4777   return OkStatus();
4778 }
4779 
4780 Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
4781   VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
4782            << rsqrt->ToString();
4783   HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0);
4784   if (rsqrt_operand->opcode() == HloOpcode::kPower &&
4785       IsAll(rsqrt_operand->operand(1), -2) &&
4786       IsPositive(rsqrt_operand, options_)) {
4787     return ReplaceWithNewInstruction(
4788         rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
4789                                            rsqrt_operand->mutable_operand(0)));
4790   }
4791 
4792   VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
4793            << rsqrt->ToString();
4794   if (rsqrt_operand->opcode() == HloOpcode::kDivide &&
4795       IsAll(rsqrt_operand->operand(0), 1) &&
4796       IsPositive(rsqrt_operand->operand(1), options_)) {
4797     return ReplaceWithNewInstruction(
4798         rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
4799                                            rsqrt_operand->mutable_operand(1)));
4800   }
4801 
4802   return OkStatus();
4803 }
4804 
4805 Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
4806     HloInstruction* dynamic_slice) {
4807   auto operand = dynamic_slice->mutable_operand(0);
4808   if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
4809     return ReplaceInstruction(dynamic_slice, operand);
4810   }
4811   // DynamicSlice where operand has the same size as the output is simply equal
4812   // to operand.
4813   if (SameShape(operand, dynamic_slice)) {
4814     return ReplaceInstruction(dynamic_slice, operand);
4815   }
4816 
4817   HloInstruction* broadcast_operand;
4818   if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) {
4819     std::vector<HloInstruction*> new_indices;
4820     new_indices.reserve(broadcast_operand->shape().rank());
4821     std::vector<int64_t> new_slice_sizes;
4822     new_slice_sizes.reserve(broadcast_operand->shape().rank());
4823 
4824     for (int64_t dim : operand->dimensions()) {
4825       new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
4826       new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
4827     }
4828 
4829     VLOG(3) << "Sink broadcast through dynamic slice";
4830     VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
4831     VLOG(3) << "Original broadcast: " << operand->ToString();
4832     HloInstruction* new_dynamic_slice = broadcast_operand;
4833     if (!new_slice_sizes.empty()) {
4834       auto new_ds_shape = broadcast_operand->shape();
4835       for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
4836         new_ds_shape.set_dimensions(i, new_slice_sizes[i]);
4837       }
4838       simplifier_->UpdateLayout(&new_ds_shape);
4839       new_dynamic_slice =
4840           dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice(
4841               new_ds_shape, broadcast_operand, new_indices, new_slice_sizes));
4842     }
4843     auto new_broadcast =
4844         operand->AddInstruction(HloInstruction::CreateBroadcast(
4845             dynamic_slice->shape(), new_dynamic_slice, operand->dimensions()));
4846     VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
4847     VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4848     return ReplaceInstruction(dynamic_slice, new_broadcast);
4849   }
4850 
4851   HloInstruction *reshape, *reshape_operand;
4852   if (Match(operand, m::Reshape(&reshape, m::Op(&reshape_operand))) &&
4853       reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value() &&
4854       !options_.is_layout_sensitive()) {
4855     int64_t slice_dim = 0;
4856     HloInstruction* zero = MakeScalarLike(dynamic_slice->mutable_operand(1), 0);
4857     std::vector<HloInstruction*> starts;
4858     starts.reserve(reshape_operand->shape().rank());
4859     std::vector<int64_t> slice_sizes;
4860     slice_sizes.reserve(reshape_operand->shape().rank());
4861     for (int64_t dim = 0; dim < reshape_operand->shape().rank(); ++dim) {
4862       if (reshape_operand->shape().dimensions(dim) == 1) {
4863         starts.push_back(zero);
4864         slice_sizes.push_back(1);
4865         continue;
4866       }
4867       while (dynamic_slice->operand(0)->shape().dimensions(slice_dim) == 1) {
4868         ++slice_dim;
4869       }
4870       starts.push_back(dynamic_slice->mutable_operand(1 + slice_dim));
4871       slice_sizes.push_back(dynamic_slice->slice_sizes(slice_dim));
4872       ++slice_dim;
4873     }
4874     HloInstruction* new_dynamic_slice =
4875         dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice(
4876             ShapeUtil::MakeShape(dynamic_slice->shape().element_type(),
4877                                  slice_sizes),
4878             reshape_operand, starts, slice_sizes));
4879     return ReplaceWithNewInstruction(
4880         dynamic_slice, HloInstruction::CreateReshape(dynamic_slice->shape(),
4881                                                      new_dynamic_slice));
4882   }
4883 
4884   HloInstruction *transpose, *transpose_operand;
4885   if (Match(operand, m::Transpose(&transpose, m::Op(&transpose_operand))) &&
4886       !options_.is_layout_sensitive()) {
4887     auto output_to_input = InversePermutation(transpose->dimensions());
4888     HloInstruction* new_slice =
4889         dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice(
4890             ShapeUtil::PermuteDimensions(output_to_input,
4891                                          dynamic_slice->shape()),
4892             transpose_operand,
4893             Permute(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
4894                                    dynamic_slice->operands().end()),
4895                     output_to_input),
4896             Permute(dynamic_slice->dynamic_slice_sizes(), output_to_input)));
4897     return ReplaceWithNewInstruction(
4898         dynamic_slice,
4899         HloInstruction::CreateTranspose(dynamic_slice->shape(), new_slice,
4900                                         transpose->dimensions()));
4901   }
4902 
4903   // Convert a dynamic slice into a slice if all offsets are constant and the
4904   // operand is not constant.
4905   if (operand->opcode() != HloOpcode::kConstant &&
4906       absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
4907                                     dynamic_slice->operands().end()),
4908                      [](HloInstruction* operand) {
4909                        return operand->opcode() == HloOpcode::kConstant &&
4910                               ShapeUtil::ElementIsIntegral(operand->shape());
4911                      })) {
4912     const int64_t rank = operand->shape().rank();
4913     std::vector<int64_t> slice_starts(rank);
4914     std::vector<int64_t> slice_limits(rank);
4915     std::vector<int64_t> slice_strides(rank, 1);
4916 
4917     for (int64_t i = 0; i < rank; ++i) {
4918       std::optional<int64_t> offset =
4919           dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
4920       if (!offset || *offset < 0) {
4921         return OkStatus();
4922       }
4923       const int64_t max_offset =
4924           dynamic_slice->operand(0)->shape().dimensions(i) -
4925           dynamic_slice->shape().dimensions(i);
4926       slice_starts[i] = std::min(max_offset, *offset);
4927       slice_limits[i] =
4928           std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
4929     }
4930     return ReplaceWithNewInstruction(
4931         dynamic_slice,
4932         HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
4933                                     slice_starts, slice_limits, slice_strides));
4934   }
4935 
4936   // Convert the dynamic slice of an iota to just a reference to the index
4937   // (possibly clamped and scaled). Index is always a scalar integer. Output
4938   // should be a rank 1 array of size 1 with element type matching that of the
4939   // scalar index (except the signedness).
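  // Illustrative sketch (hypothetical shapes):
  //   dynamic-slice(s32[8] iota, s32[] %i), dynamic_slice_sizes={1}
  // becomes reshape(clamp(0, %i, 7)) into s32[1]; if the iota was first
  // multiplied by a broadcast scalar %c, the clamped index is multiplied
  // by %c before the reshape.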
4940   const PrimitiveType element_type = dynamic_slice->shape().element_type();
4941   if (operand->shape().rank() == 1 && dynamic_slice->shape().rank() == 1 &&
4942       dynamic_slice->shape().dimensions(0) == 1 &&
4943       (element_type == S32 || element_type == U32)) {
4944     // Match multiply(x, broadcast(scalar)) and return the scalar
4945     // operand of the broadcast.
4946     auto match_multiply_with_scalar =
4947         [&](HloInstruction* hlo) -> HloInstruction* {
4948       if (hlo->opcode() != HloOpcode::kMultiply) {
4949         return nullptr;
4950       }
4951       HloInstruction* broadcast = hlo->mutable_operand(1);
4952       if (broadcast->opcode() == HloOpcode::kBroadcast &&
4953           broadcast->dimensions().empty() &&
4954           ShapeUtil::IsScalar(broadcast->operand(0)->shape())) {
4955         return broadcast->mutable_operand(0);
4956       }
4957       return nullptr;
4958     };
4959 
4960     HloInstruction* multiplier = match_multiply_with_scalar(operand);
4961     if (multiplier) {
4962       operand = operand->mutable_operand(0);
4963     }
4964 
4965     if (operand->opcode() == HloOpcode::kIota) {
4966       // This dynamic_slice will have a single start_index operand (since its
4967       // operand is rank 1).
4968       HloInstruction* index = dynamic_slice->mutable_operand(1);
4969       const PrimitiveType index_type = index->shape().element_type();
4970 
4971       auto create_constant = [&](int64_t value) {
4972         if (index_type == S32) {
4973           return MakeScalarLike<int32_t>(index, value);
4974         } else {
4975           return MakeScalarLike<uint32_t>(index, value);
4976         }
4977       };
4978 
4979       if (index_type == S32 || index_type == U32) {
4980         // Clamp the index to the range of the iota.
4981         int64_t iota_size = operand->shape().dimensions(0);
4982         HloInstruction* low = create_constant(0);
4983         HloInstruction* high = create_constant(iota_size - 1);
4984         HloInstruction* clamped =
4985             dynamic_slice->AddInstruction(HloInstruction::CreateTernary(
4986                 index->shape(), HloOpcode::kClamp, low, index, high));
4987 
4988         // Convert the clamped index from index_type to element_type and
4989         // multiply with the multiplier.
4990         HloInstruction* result = clamped;
4991         if (index_type != element_type) {
4992           Shape result_shp = result->shape();
4993           result_shp.set_element_type(element_type);
4994           result = dynamic_slice->AddInstruction(
4995               HloInstruction::CreateConvert(result_shp, clamped));
4996         }
4997 
4998         if (multiplier) {
4999           result = dynamic_slice->AddInstruction(HloInstruction::CreateBinary(
5000               result->shape(), HloOpcode::kMultiply, result, multiplier));
5001         }
5002 
5003         return ReplaceWithNewInstruction(
5004             dynamic_slice,
5005             HloInstruction::CreateReshape(dynamic_slice->shape(), result));
5006       }
5007     }
5008   }
5009 
5010   // ds(ds(x,id),inner_id) -> ds(x, id + inner_id)
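  // Illustrative sketch (hypothetical shapes): the start indices are added,
  // with the inner index clamped so the combined slice stays in bounds:
  //   ds(ds(f32[10] %x, %i, sizes={4}), %j, sizes={2})
  //     -> ds(%x, %j + clamp(0, %i, 10 - 2), sizes={2})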
5011   if (operand->opcode() == HloOpcode::kDynamicSlice) {
5012     TF_RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWithDifferentShape(
5013         0, operand->mutable_operand(0)));
5014     for (int64_t i = 1; i < dynamic_slice->operand_count(); ++i) {
5015       HloInstruction* index = dynamic_slice->mutable_operand(i);
5016       HloInstruction* inner_index = operand->mutable_operand(i);
5017       inner_index = inner_index->AddInstruction(HloInstruction::CreateTernary(
5018           inner_index->shape(), HloOpcode::kClamp,
5019           MakeScalarLike(inner_index, 0), inner_index,
5020           MakeScalarLike(inner_index,
5021                          operand->operand(0)->shape().dimensions(i - 1) -
5022                              dynamic_slice->dynamic_slice_sizes()[i - 1])));
5023       if (inner_index->shape().element_type() !=
5024           index->shape().element_type()) {
5025         inner_index = inner_index->AddInstruction(
5026             HloInstruction::CreateConvert(index->shape(), inner_index));
5027       }
5028       HloInstruction* combined_index =
5029           operand->AddInstruction(HloInstruction::CreateBinary(
5030               index->shape(), HloOpcode::kAdd, index, inner_index));
5031       TF_RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWith(i, combined_index));
5032     }
5033     MarkAsChanged();
5034   }
5035   return OkStatus();
5036 }
5037 
5038 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
5039     HloInstruction* dynamic_update_slice) {
5040   // Rewrite a DynamicUpdateSlice that matches
5041   // dynamic_update_slice(broadcast(constant), data, constant_index0, ...)
5042   // into a Pad(data, constant).
5043   // Only a Broadcast operand is considered currently; other ops may be
5044   // handled in the future.
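  // Illustrative sketch (hypothetical shapes):
  //   dynamic-update-slice(broadcast(f32[] %c) into f32[8], f32[3] %u,
  //                        s32[] constant(2))
  // becomes pad(%u, %c) with edge padding low=2 and high=8-(2+3)=3.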
5045   HloInstruction* updated = dynamic_update_slice->mutable_operand(0);
5046   HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1);
5047   HloInstruction* pad_value;
5048   if (Match(updated,
5049             m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) {
5050     auto updated_shape = updated->shape();
5051     auto update_shape = dus_update->shape();
5052     auto update_start_indx = dynamic_update_slice->operand(2);
5053     int64_t offset = 0;
5054     bool compatible = true;
5055     // Whether the start indices to the dynamic update slice are given as
5056     // individual scalars or as the output of a tuple/concatenate, we set up
5057     // update_start_indx appropriately.
5058     if (ShapeUtil::IsScalar(update_start_indx->shape())) {
5059       update_start_indx = dynamic_update_slice;
5060       offset = 2;
5061     } else {
5062       if (update_start_indx->opcode() == HloOpcode::kTuple ||
5063           update_start_indx->opcode() == HloOpcode::kConcatenate) {
5064         offset = 0;
5065       } else {
5066         compatible = false;
5067       }
5068     }
5069     PaddingConfig padding_config;
5070     if (compatible) {
5071       for (int64_t dim = 0; dim < updated_shape.rank(); ++dim) {
5072         auto padding_config_dim = padding_config.add_dimensions();
5073         auto slice_dim_start = update_start_indx->operand(dim + offset);
5074         if (!Match(slice_dim_start, m::ConstantScalar())) {
5075           compatible = false;
5076           break;
5077         }
5078         VLOG(2) << "slice: " << slice_dim_start->ToString();
5079         std::optional<int64_t> beg =
5080             slice_dim_start->literal().GetFirstInteger();
5081         if (!beg) {
5082           compatible = false;
5083           break;
5084         }
5085         VLOG(2) << "beg value: " << *beg;
5086         auto update_width = ShapeUtil::GetDimension(update_shape, dim);
5087         auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
5088         // Clamp beg so that it is non-negative.
5089         *beg = std::max<int64_t>(0, *beg);
5090         // Clamp beg so that it is in-bounds.
5091         *beg = std::min<int64_t>(bcast_width - update_width, *beg);
5092         VLOG(2) << "adjusted beg value: " << *beg;
5093         padding_config_dim->set_edge_padding_low(*beg);
5094         padding_config_dim->set_edge_padding_high(bcast_width -
5095                                                   (*beg + update_width));
5096         // dynamic_update_slice does not specify a stride
5097         padding_config_dim->set_interior_padding(0);
5098       }
5099     }
5100 
5101     if (compatible) {
5102       HloInstruction* pad =
5103           dynamic_update_slice->AddInstruction(HloInstruction::CreatePad(
5104               updated_shape, dus_update, pad_value, padding_config));
5105       VLOG(2) << dynamic_update_slice->ToString();
5106       VLOG(2) << " with pad:" << pad->ToString();
5107       VLOG(2) << " Computation before rewrite is: "
5108               << dynamic_update_slice->parent()->ToString();
5109       return ReplaceInstruction(dynamic_update_slice, pad);
5110     }
5111   }
5112 
5113   // A DynamicUpdateSlice where the operand and dus_update have the same
5114   // shape is equal to dus_update.
5115   if (SameShape(dynamic_update_slice, dus_update)) {
5116     return ReplaceInstruction(dynamic_update_slice, dus_update);
5117   }
5118 
5119   // If any dimension of dus_update is 0, elide the DynamicUpdateSlice.  This
5120   // optimization becomes invalid should we later prefer to warn about
5121   // out-of-bounds indices.
5122   if (ShapeUtil::IsZeroElementArray(dus_update->shape())) {
5123     return ReplaceInstruction(dynamic_update_slice, updated);
5124   }
5125 
5126   // dus(a, dus(ds(a,id), c, inner_id), id) is equivalent to dus(a, c, inner_id + id)
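  // Illustrative sketch (hypothetical shapes): the inner update is written
  // straight into %a at the combined offset,
  //   dus(f32[10] %a, dus(ds(%a, %i, sizes={4}), f32[2] %c, %j), %i)
  //     -> dus(%a, %c, %i + clamp(0, %j, 4 - 2))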
5127   if (dus_update->opcode() == HloOpcode::kDynamicUpdateSlice &&
5128       (dus_update->operand(0)->opcode() == HloOpcode::kDynamicSlice &&
5129        dus_update->operand(0)->operand(0) == dynamic_update_slice->operand(0) &&
5130        absl::c_equal(
5131            absl::MakeConstSpan(dynamic_update_slice->operands()).subspan(2),
5132            absl::MakeConstSpan(dus_update->operand(0)->operands())
5133                .subspan(1)))) {
5134     TF_RETURN_IF_ERROR(dynamic_update_slice->ReplaceOperandWithDifferentShape(
5135         1, dus_update->mutable_operand(1)));
5136     for (int64_t i = 2; i < dynamic_update_slice->operand_count(); ++i) {
5137       HloInstruction* index = dynamic_update_slice->mutable_operand(i);
5138       HloInstruction* inner_index = dus_update->mutable_operand(i);
5139       inner_index = inner_index->AddInstruction(HloInstruction::CreateTernary(
5140           inner_index->shape(), HloOpcode::kClamp,
5141           MakeScalarLike(inner_index, 0), inner_index,
5142           MakeScalarLike(
5143               inner_index,
5144               dus_update->shape().dimensions(i - 2) -
5145                   dus_update->operand(1)->shape().dimensions(i - 2))));
5146       if (inner_index->shape().element_type() !=
5147           index->shape().element_type()) {
5148         inner_index = inner_index->AddInstruction(
5149             HloInstruction::CreateConvert(index->shape(), inner_index));
5150       }
5151       HloInstruction* combined_index =
5152           dus_update->AddInstruction(HloInstruction::CreateBinary(
5153               index->shape(), HloOpcode::kAdd, index, inner_index));
5154       TF_RETURN_IF_ERROR(
5155           dynamic_update_slice->ReplaceOperandWith(i, combined_index));
5156     }
5157     MarkAsChanged();
5158     return OkStatus();
5159   }
5160   return OkStatus();
5161 }
5162 
5163 static bool MatchArgMinMax(const HloInstruction* hlo, bool is_max) {
5164   // Create matcher for shared sub-expression.
5165   auto value_pred = m::OrAnyOrder(
5166       m::Compare(m::Parameter(0), m::Parameter(2))
5167           .WithComparisonDirection(is_max ? ComparisonDirection::kGt
5168                                           : ComparisonDirection::kLt),
5169       m::Compare(m::Parameter(0), m::Parameter(0))
5170           .WithComparisonDirection(ComparisonDirection::kNe));
5171 
5172   // Match on argmax reduction computation.
5173   return Match(
5174       hlo,
5175       m::Tuple(
5176           m::Select(value_pred, m::Parameter(0), m::Parameter(2)),
5177           m::Select(
5178               m::OrAnyOrder(
5179                   value_pred,
5180                   m::And(
5181                       m::Compare(m::Parameter(0), m::Parameter(2))
5182                           .WithComparisonDirection(ComparisonDirection::kEq),
5183                       m::Compare(m::Parameter(1), m::Parameter(3))
5184                           .WithComparisonDirection(ComparisonDirection::kLt))),
5185               m::Parameter(1), m::Parameter(3))));
5186 }
5187 
5188 // Match on variadic reduce which computes and returns (min, arg_min).
5189 //
5190 //                   p0   p2    p1    p3
5191 //                  /|\ \/ |\    |\   /|
5192 //                 / | \/\ | \   | \ / |
5193 //                /  | /\ \|  |  |  /\ |
5194 //               Ne  Lt |  \  |  | |  ||
5195 //                 \ /  |  |\ |  | /  ||
5196 //                  Or /  /  Eq  Lt   ||
5197 //                  | /  /    \  /    //
5198 //                  | |  |     And   //
5199 //                  | |  |      |  //
5200 //                  select     select
5201 //                      \     /
5202 //                       tuple
5203 //
5204 static bool MatchArgMin(const HloInstruction* hlo) {
5205   // Match on variadic Reduce ArgMin
5206   if (hlo->opcode() != HloOpcode::kReduce || hlo->operand_count() != 4 ||
5207       !hlo->shape().IsTuple() ||
5208       hlo->operand(1)->opcode() != HloOpcode::kIota ||
5209       !IsScalarConstantInf(hlo->operand(2)) ||
5210       !IsScalarConstantZero(hlo->operand(3))) {
5211     return false;
5212   }
5213   return MatchArgMinMax(hlo->to_apply()->root_instruction(), /*is_max=*/false);
5214 }
5215 
5216 // Match on variadic reduce which computes and returns (max, arg_max).
5217 //
5218 //                   p0   p2    p1    p3
5219 //                  /|\ \/ |\    |\   /|
5220 //                 / | \/\ | \   | \ / |
5221 //                /  | /\ \|  |  |  /\ |
5222 //               Ne  Gt |  \  |  | |  ||
5223 //                 \ /  |  |\ |  | /  ||
5224 //                  Or /  /  Eq  Lt   ||
5225 //                  | /  /    \  /    //
5226 //                  | |  |     And   //
5227 //                  | |  |      |  //
5228 //                  select     select
5229 //                      \     /
5230 //                       tuple
5231 //
5232 static bool MatchArgMax(const HloInstruction* hlo) {
5233   // Match on variadic Reduce ArgMax.
5234   if (hlo->opcode() != HloOpcode::kReduce || hlo->operand_count() != 4 ||
5235       !hlo->shape().IsTuple() ||
5236       hlo->operand(1)->opcode() != HloOpcode::kIota ||
5237       !IsScalarConstantNegInf(hlo->operand(2)) ||
5238       !IsScalarConstantZero(hlo->operand(3))) {
5239     return false;
5240   }
5241   return MatchArgMinMax(hlo->to_apply()->root_instruction(), /*is_max=*/true);
5242 }
5243 
5244 static bool ReductionComputationsEquivalent(const HloComputation& a,
5245                                             const HloComputation& b) {
5246   if (a == b) {
5247     return true;
5248   }
5249 
5250   // Check for simple commutative reduction functions.
5251   enum CommutativeFnKind { kAdd, kMul, kAnd, kOr };
5252   auto categorize_computation =
5253       [](const HloComputation& c) -> std::optional<CommutativeFnKind> {
5254     if (c.num_parameters() != 2) {
5255       return std::nullopt;
5256     }
5257 
5258     const HloInstruction* root = c.root_instruction();
5259     if (Match(root, m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5260       return kAdd;
5261     }
5262     if (Match(root, m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5263       return kMul;
5264     }
5265     if (Match(root, m::AndAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5266       return kAnd;
5267     }
5268     if (Match(root, m::OrAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5269       return kOr;
5270     }
5271     return std::nullopt;
5272   };
5273   auto category_a = categorize_computation(a);
5274   auto category_b = categorize_computation(b);
5275   return category_a.has_value() && category_b.has_value() &&
5276          category_a == category_b;
5277 }
5278 
5279 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
5280   HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
5281   bool multi_output_reduce = reduce->shape().IsTuple();
5282   // For tuple reduce, we require all reduce shapes to be the same, up to the
5283   // element types, so we can just use the first operand and the first result
5284   // as a representative.
5285   auto arg = reduce->inputs()[0];
5286   auto init_value = reduce->init_values()[0];
5287   const Shape& reduce_result_shape =
5288       multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
5289 
5290   absl::Span<const int64_t> dimensions(reduce->dimensions());
5291   HloComputation* function = reduce->to_apply();
5292   if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
5293       ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
5294     if (multi_output_reduce) {
5295       std::vector<HloInstruction*> broadcast_inits;
5296       int64_t inputs = reduce->input_count();
5297       for (int64_t i = 0; i < inputs; ++i) {
5298         broadcast_inits.push_back(reduce->init_values()[i]->AddInstruction(
5299             HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
5300                                             reduce->init_values()[i], {})));
5301       }
5302       return ReplaceWithNewInstruction(
5303           reduce, HloInstruction::CreateTuple(broadcast_inits));
5304     } else {
5305       return ReplaceWithNewInstruction(
5306           reduce,
5307           HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
5308     }
5309   }
5310 
5311   // Turn trivial variadic reductions into normal reductions.
5312   if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 &&
5313       reduce->input_count() == 1 &&
5314       Match(function->root_instruction(), m::Tuple())) {
5315     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
5316         replacements;
5317     replacements[function->root_instruction()] = nullptr;
5318     auto new_function = computation_->parent()->AddEmbeddedComputation(
5319         function->CloneWithReplacements(
5320             &replacements, /*extra_parameters=*/{},
5321             /*context=*/nullptr,
5322             /*suffix=*/"clone",
5323             /*new_root=*/function->root_instruction()->operand(0)));
5324     auto new_reduce = reduce->AddInstruction(
5325         HloInstruction::CreateReduce(reduce_result_shape, arg, init_value,
5326                                      reduce->dimensions(), new_function));
5327     return ReplaceWithNewInstruction(reduce,
5328                                      HloInstruction::CreateTuple({new_reduce}));
5329   }
5330 
5331   // If the reduction produces the same number of elements as its input,
5332   // then its only possible effect is a reshape. Since the init_value is an
5333   // identity of the reduction function, we can therefore replace the reduce
5334   // with a simple reshape, ignoring the reduction function completely.
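  // Illustrative sketch (hypothetical shapes):
  //   reduce(f32[1,4,1] %a, %init, dims={0,2}) -> reshape(%a) into f32[4]
  // since no elements are actually combined.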
5335   if (ShapeUtil::ElementsIn(reduce_result_shape) ==
5336           ShapeUtil::ElementsIn(arg->shape()) &&
5337       (!options_.is_layout_sensitive() ||
5338        options_.ReshapeIsBitcast(arg->shape(), reduce_result_shape))) {
5339     if (multi_output_reduce) {
5340       std::vector<HloInstruction*> reshaped_args;
5341       int64_t inputs = reduce->input_count();
5342       for (int64_t i = 0; i < inputs; ++i) {
5343         reshaped_args.push_back(
5344             reduce->AddInstruction(HloInstruction::CreateReshape(
5345                 reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
5346       }
5347       return ReplaceWithNewInstruction(
5348           reduce, HloInstruction::CreateTuple(reshaped_args));
5349     } else {
5350       return ReplaceWithNewInstruction(
5351           reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
5352     }
5353   }
5354 
5355   if (options_.is_layout_sensitive()) {
5356     return OkStatus();
5357   }
5358 
5359   // TODO(b/131122694): Most of the optimizations below can also be done for
5360   // multi-output reduces.
5361   if (multi_output_reduce) {
5362     return OkStatus();
5363   }
5364 
5365   // A Transpose feeding a reduce can simply permute the reduction dimensions
5366   // field if the output of the reduce is a vector or scalar. A higher-ranked
5367   // result may require a transpose of the output.
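  // Illustrative sketch (hypothetical shapes):
  //   reduce(transpose(f32[5,6] %a, perm={1,0}), dims={0})
  //     -> reduce(%a, dims={1})
  // both produce f32[5], and no transpose of the output is needed.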
5368   if (arg->opcode() == HloOpcode::kTranspose &&
5369       (reduce->shape().rank() < 2 || arg->user_count() == 1 ||
5370        absl::c_all_of(arg->users(), [](HloInstruction* use) {
5371          return use->opcode() == HloOpcode::kReduce;
5372        }))) {
5373     auto transpose_dimensions = arg->dimensions();
5374     std::vector<int64_t> new_reduce_dimensions;
5375     new_reduce_dimensions.reserve(dimensions.size());
5376     for (auto dim : dimensions) {
5377       new_reduce_dimensions.push_back(transpose_dimensions[dim]);
5378     }
5379 
5380     Shape new_reduce_result_shape = ShapeUtil::DeleteDimensions(
5381         new_reduce_dimensions, arg->mutable_operand(0)->shape());
5382     HloInstruction* new_reduce =
5383         reduce->AddInstruction(HloInstruction::CreateReduce(
5384             new_reduce_result_shape, arg->mutable_operand(0), init_value,
5385             new_reduce_dimensions, function));
5386     std::vector<int64_t> new_transpose_dimensions;
5387     for (auto dim : transpose_dimensions) {
5388       if (absl::c_linear_search(new_reduce_dimensions, dim)) {
5389         continue;
5390       }
5391       new_transpose_dimensions.push_back(dim);
5392     }
5393 
5394     // If new transpose dimensions are sorted, then there is no need to
5395     // transpose reduce result.
5396     if (absl::c_is_sorted(new_transpose_dimensions)) {
5397       return ReplaceInstruction(reduce, new_reduce);
5398     }
5399     for (auto& d : new_transpose_dimensions) {
5400       auto old_dim = d;
5401       for (auto reduced_dim : new_reduce_dimensions) {
5402         if (old_dim > reduced_dim) {
5403           --d;
5404         }
5405       }
5406     }
5407     TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose,
5408                         MakeTransposeHlo(new_reduce, new_transpose_dimensions));
5409     return ReplaceInstruction(reduce, new_transpose);
5410   }
5411 
5412   // If a reduce feeds a reduce with the same computation and initial value,
5413   // they can be combined into a single reduce.
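  // Illustrative sketch (hypothetical shapes):
  //   reduce(reduce(f32[4,5,6] %a, dims={1}), dims={1})
  //     -> reduce(%a, dims={1,2})
  // after shifting the outer dims past the already-reduced inner dims.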
5414   if (arg->opcode() == HloOpcode::kReduce &&
5415       init_value->Identical(*arg->operand(1)) &&
5416       ReductionComputationsEquivalent(*function, *arg->to_apply())) {
5417     // Create a new reduce with the combined reduction dimensions of both
5418     // reduces.
5419     std::vector<int64_t> arg_dims = *arg->mutable_dimensions();
5420     absl::c_sort(arg_dims);
5421     std::vector<int64_t> reduce_dims = *reduce->mutable_dimensions();
5422     absl::c_sort(reduce_dims);
5423     // Remap reduce_dims so that they index into the shape of arg's operand.
5424     for (int64_t arg_dim : arg_dims) {
5425       for (int64_t& dim : reduce_dims) {
5426         if (dim >= arg_dim) {
5427           ++dim;
5428         }
5429       }
5430     }
5431     std::vector<int64_t> new_dimensions;
5432     new_dimensions.reserve(arg->dimensions().size() +
5433                            reduce->dimensions().size());
5434     std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
5435                reduce_dims.end(), std::back_inserter(new_dimensions));
5436     return ReplaceWithNewInstruction(
5437         reduce, HloInstruction::CreateReduce(
5438                     reduce_result_shape, arg->mutable_operand(0), init_value,
5439                     new_dimensions, function));
5440   }
5441 
5442   // A reshape that collapses multiple dimensions into a dimension being
5443   // reduced can just reduce all of those dimensions instead of doing a
5444   // collapsing reshape before a reduction.
5445   if (options_.enable_reduce_of_reshape() &&
5446       arg->opcode() == HloOpcode::kReshape) {
5447     std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
5448         ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
5449                                                  arg->shape());
5450     std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
5451     std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
5452     for (auto dim : dimensions) {
5453       arg_dim_in_output[dim] = false;
5454     }
5455     for (auto dim_pair : unmodified_dims) {
5456       arg_dim_unmodified[dim_pair.second] = true;
5457     }
5458     // The goal is to verify that all dimensions that are not removed in the
5459     // reduce are unmodified by the reshape. For example:
5460     // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
5461     bool can_move_reshape_into_reduce = true;
5462     for (int64_t i = 0; i < arg_dim_in_output.size(); ++i) {
5463       if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
5464         can_move_reshape_into_reduce = false;
5465       }
5466     }
5467     if (can_move_reshape_into_reduce) {
5468       MarkAsChanged();
5469       absl::flat_hash_set<int64_t> dimensions_not_to_reduce;
5470       for (auto dim_pair : unmodified_dims) {
5471         if (arg_dim_in_output[dim_pair.second]) {
5472           dimensions_not_to_reduce.insert(dim_pair.first);
5473         }
5474       }
5475       std::vector<int64_t> new_reduce_dimensions;
5476       for (int64_t i = 0; i < arg->operand(0)->shape().rank(); ++i) {
5477         if (!dimensions_not_to_reduce.contains(i)) {
5478           new_reduce_dimensions.push_back(i);
5479         }
5480       }
5481       return ReplaceWithNewInstruction(
5482           reduce, HloInstruction::CreateReduce(
5483                       reduce_result_shape, arg->mutable_operand(0), init_value,
5484                       new_reduce_dimensions, function));
5485     }
5486   }
5487   // Convert Reduce(concat({a,b,...})) to
5488   //   map(reduce(a), map(reduce(b), ...))
5489   //
5490   // This should make fusion easier or use less memory bandwidth in the unfused
5491   // case.
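  // Illustrative sketch (hypothetical shapes, add reducer):
  //   reduce(concatenate({f32[2,3] %a, f32[4,3] %b}, dim=0), dims={0})
  //     -> map(reduce(%a, dims={0}), reduce(%b, dims={0}))
  // where the map applies the same add computation elementwise.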
5492   if (arg->opcode() == HloOpcode::kConcatenate &&
5493       absl::c_linear_search(reduce->dimensions(),
5494                             arg->concatenate_dimension())) {
5495     HloInstruction* old_reduce = nullptr;
5496     for (HloInstruction* operand : arg->operands()) {
5497       HloInstruction* new_reduce = reduce->AddInstruction(
5498           HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
5499                                        reduce->dimensions(), function));
5500       if (old_reduce != nullptr) {
5501         new_reduce = reduce->AddInstruction(HloInstruction::CreateMap(
5502             reduce_result_shape, {old_reduce, new_reduce}, function));
5503       }
5504       old_reduce = new_reduce;
5505     }
5506     return ReplaceInstruction(reduce, old_reduce);
5507   }
5508 
5509   HloInstruction *dot, *lhs, *rhs;
5510   // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
5511   // batch dimensions of the dot. The transformation supports reducing other
5512   // dimensions as well.
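  // Illustrative sketch (hypothetical shapes, add reducer): with batch dim 0,
  //   reduce(dot(f32[8,3,4] %lhs, f32[8,4,5] %rhs), dims={0})
  // becomes a dot in which dim 0 is a contracting dimension instead of a
  // batch dimension, producing f32[3,5] directly.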
5513   if (options_.enable_dot_strength_reduction() &&
5514       Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
5515       Match(reduce->to_apply()->root_instruction(),
5516             m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) &&
5517       absl::c_any_of(reduce->dimensions(), [&](int64_t dim) {
5518         return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size();
5519       })) {
5520     const auto& dnums = dot->dot_dimension_numbers();
5521     DotDimensionNumbers new_dnums = dnums;
5522     new_dnums.clear_lhs_batch_dimensions();
5523     new_dnums.clear_rhs_batch_dimensions();
5524     int64_t removed_dims = 0;
5525     for (int64_t batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size();
5526          ++batch_dim) {
5527       if (absl::c_linear_search(reduce->dimensions(), batch_dim)) {
5528         new_dnums.add_rhs_contracting_dimensions(
5529             dnums.rhs_batch_dimensions(batch_dim));
5530         new_dnums.add_lhs_contracting_dimensions(
5531             dnums.lhs_batch_dimensions(batch_dim));
5532         ++removed_dims;
5533       } else {
5534         new_dnums.add_rhs_batch_dimensions(
5535             dnums.rhs_batch_dimensions(batch_dim));
5536         new_dnums.add_lhs_batch_dimensions(
5537             dnums.lhs_batch_dimensions(batch_dim));
5538       }
5539     }
5540     std::vector<int64_t> reduce_dims;
5541     for (int64_t dim : reduce->dimensions()) {
5542       if (dim >= dnums.lhs_batch_dimensions_size()) {
5543         reduce_dims.push_back(dim - removed_dims);
5544       }
5545     }
5546     TF_ASSIGN_OR_RETURN(
5547         auto new_dot,
5548         MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
5549                    /*preferred_element_type=*/dot->shape().element_type()));
5550     dot->SetupDerivedInstruction(new_dot);
5551     if (reduce_dims.empty()) {
5552       return ReplaceInstruction(hlo, new_dot);
5553     }
5554     TF_ASSIGN_OR_RETURN(
5555         auto new_reduce,
5556         MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd));
5557     reduce->SetupDerivedInstruction(new_reduce);
5558     return ReplaceInstruction(hlo, new_reduce);
5559   }
5560 
5561   // Replace Use(ReduceMax(Arg)) with Use(Gte(ReduceArgMax, 0)).
5562   // Match on Reduce Max with init value -Inf.
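  // Illustrative sketch: if %arg also feeds a variadic reduce %am that
  // computes (max, argmax) over the same dimensions, uses of
  // reduce(%arg, -inf, max) can instead read get-tuple-element(%am), index=0.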
5563   if (reduce->operand_count() == 2 && IsScalarConstantNegInf(init_value) &&
5564       Match(reduce->to_apply()->root_instruction(),
5565             m::MaximumAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5566     // Match on variadic Reduce ArgMax which is also fed by 'arg'.
5567     auto arg_max_candidate =
5568         absl::c_find_if(arg->users(), [&](const HloInstruction* user) {
5569           return user != reduce && user->operand(0) == arg &&
5570                  MatchArgMax(user) &&
5571                  reduce->dimensions() == user->dimensions();
5572         });
5573     if (arg_max_candidate != arg->users().end()) {
5574       // Replace 'reduce' uses with GTE(ArgMax, 0).
5575       return ReplaceWithNewInstruction(
5576           reduce, HloInstruction::CreateGetTupleElement(*arg_max_candidate,
5577                                                         /*index=*/0));
5578     }
5579   }
5580 
5581   // Replace Use(ReduceMin(Arg)) with Use(Gte(ReduceArgMin, 0)).
5582   // Match on Reduce Min with init value Inf.
5583   if (reduce->operand_count() == 2 && IsScalarConstantInf(init_value) &&
5584       Match(reduce->to_apply()->root_instruction(),
5585             m::MinimumAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5586     // Match on variadic Reduce ArgMin which is also fed by 'arg'.
5587     auto arg_min_candidate =
5588         absl::c_find_if(arg->users(), [&](const HloInstruction* user) {
5589           return user != reduce && user->operand(0) == arg &&
5590                  MatchArgMin(user) &&
5591                  reduce->dimensions() == user->dimensions();
5592         });
5593     if (arg_min_candidate != arg->users().end()) {
5594       // Replace 'reduce' uses with GTE(ArgMin, 0).
5595       return ReplaceWithNewInstruction(
5596           reduce, HloInstruction::CreateGetTupleElement(*arg_min_candidate,
5597                                                         /*index=*/0));
5598     }
5599   }
5600   return OkStatus();
5601 }
5602 
5603 Status AlgebraicSimplifierVisitor::HandleReduceWindow(HloInstruction* hlo) {
5604   auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
5605   const bool multi_output_reduce_window = reduce_window->shape().IsTuple();
5606   auto inputs = reduce_window->inputs();
5607   auto init_values = reduce_window->init_values();
5608   auto input_count = reduce_window->input_count();
5609   auto input_shapes = reduce_window->input_shapes();
5610   auto output_shapes = reduce_window->output_shapes();
5611   auto replace_with_span = [&](const std::vector<HloInstruction*>& elements) {
5612     CHECK(multi_output_reduce_window || elements.size() == 1);
5613     if (multi_output_reduce_window) {
5614       return ReplaceWithNewInstruction(reduce_window,
5615                                        HloInstruction::CreateTuple(elements));
5616     }
5617     return ReplaceInstruction(reduce_window, elements[0]);
5618   };
5619   // For tuple reduce, we require all reduce shapes to be the same, up to the
5620   // element types, so we can use just the first operand and the first result as
5621   // a representative.
5622   if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) ||
5623       ShapeUtil::IsZeroElementArray(*output_shapes[0])) {
5624     std::vector<HloInstruction*> broadcast_inits;
5625     for (int64_t i = 0; i < input_count; ++i) {
5626       broadcast_inits.push_back(
5627           hlo->AddInstruction(HloInstruction::CreateBroadcast(
5628               *output_shapes[i], init_values[i], {})));
5629     }
5630     return replace_with_span(broadcast_inits);
5631   }
5632   if (ShapeUtil::IsScalar(*input_shapes[0]) &&
5633       (!multi_output_reduce_window ||
5634        reduce_window->to_apply()->root_instruction()->opcode() ==
5635            HloOpcode::kTuple)) {
5636     std::vector<HloInstruction*> maps;
5637     for (int64_t i = 0; i < input_count; ++i) {
5638       TF_RET_CHECK(ShapeUtil::IsScalar(*input_shapes[i]));
5639       TF_RET_CHECK(ShapeUtil::IsScalar(*output_shapes[i]));
5640       HloInstruction* map_computation_root;
5641       absl::flat_hash_map<const HloInstruction*,
5642                           std::unique_ptr<HloInstruction>>
5643           replacements;
5644       if (multi_output_reduce_window) {
5645         map_computation_root =
5646             reduce_window->to_apply()->root_instruction()->mutable_operand(i);
5647         replacements[reduce_window->to_apply()->root_instruction()] = nullptr;
5648       } else {
5649         map_computation_root = reduce_window->to_apply()->root_instruction();
5650       }
5651       maps.push_back(inputs[i]);
5652     }
5653     return replace_with_span(maps);
5654   }
5655   // Turn trivial variadic reduce windows into normal reduce windows.
5656   auto reduce_function_root = reduce_window->to_apply()->root_instruction();
5657   if (multi_output_reduce_window && input_count == 1 &&
5658       Match(reduce_function_root, m::Tuple())) {
5659     // Make a new reducer which is identical but does not have a tuple
5660     // instruction at the bottom.
5661     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
5662         replacements;
5663     replacements[reduce_function_root] = nullptr;
5664     auto new_function = computation_->parent()->AddEmbeddedComputation(
5665         reduce_window->to_apply()->CloneWithReplacements(
5666             &replacements, /*extra_parameters=*/{},
5667             /*context=*/nullptr,
5668             /*suffix=*/"clone",
5669             /*new_root=*/reduce_function_root->operand(0)));
5670     auto new_reduce_window =
5671         reduce_window->AddInstruction(HloInstruction::CreateReduceWindow(
5672             *output_shapes[0], inputs[0], init_values[0],
5673             reduce_window->window(), new_function));
5674     return ReplaceWithNewInstruction(
5675         reduce_window, HloInstruction::CreateTuple({new_reduce_window}));
5676   }
5677   // TODO(b/73062247) Variadic reduce window is not yet supported in simplifier.
5678   if (multi_output_reduce_window) {
5679     return OkStatus();
5680   }
5681   auto operand = reduce_window->mutable_operand(0);
5682   auto init_value = reduce_window->mutable_operand(1);
5683   auto function = reduce_window->to_apply();
5684   const Window& window = reduce_window->window();
5685 
5686   // reduce-window with a 1x1x..x1 window and no dilation etc can be replaced
5687   // with a trivial elementwise operation, plus a pad op if necessary.
5688   //
5689   // We cowardly refuse to consider this optimization when the reduce-window
5690   // subcomputation is anything other than a simple add/min/max.  Supporting
5691   // more complex subcomputations is possible, but is tantamount to implementing
5692   // jax.vmap()!
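  // Illustrative sketch (hypothetical shapes, add reducer): with window
  // dimension size=1 and padding 1_1,
  //   reduce-window(f32[4] %op, f32[] %init)
  //     -> pad(add(%op, broadcast(%init)), %init) with edge padding 1_1.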
5693   if (absl::c_all_of(window.dimensions(),
5694                      [](const WindowDimension& dim) {
5695                        return dim.size() == 1 &&             //
5696                               dim.stride() == 1 &&           //
5697                               dim.window_dilation() == 1 &&  //
5698                               dim.base_dilation() == 1 &&    //
5699                               !dim.window_reversal();
5700                      }) &&
5701       Match(function->root_instruction(),
5702             m::AnyOf<HloInstruction>(
5703                 m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
5704                 m::MinimumAnyOrder(m::Parameter(0), m::Parameter(1)),
5705                 m::MaximumAnyOrder(m::Parameter(0), m::Parameter(1))))) {
5706     const HloInstruction* nested_root = function->root_instruction();
5707     DimensionVector broadcast_dims(nested_root->shape().dimensions_size());
5708     absl::c_iota(broadcast_dims, 0);
5709     TF_ASSIGN_OR_RETURN(
5710         auto new_op, MakeBinaryHlo(nested_root->opcode(), operand,
5711                                    MakeBroadcastHlo(init_value, broadcast_dims,
5712                                                     operand->shape())));
5713 
5714     if (absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
5715           return dim.padding_low() > 0 || dim.padding_high() > 0;
5716         })) {
5717       PaddingConfig padding_config;
5718       for (const WindowDimension& window_dim : window.dimensions()) {
5719         auto& padding_dim = *padding_config.add_dimensions();
5720         padding_dim.set_edge_padding_low(window_dim.padding_low());
5721         padding_dim.set_edge_padding_high(window_dim.padding_high());
5722         padding_dim.set_interior_padding(0);
5723       }
5724       TF_ASSIGN_OR_RETURN(new_op,
5725                           MakePadHlo(new_op, init_value, padding_config));
5726     }
5727 
5728     return ReplaceInstruction(reduce_window, new_op);
5729   }
5730 
5731   if (options_.enable_window_reduce_to_reduce_replacement()) {
5732     // A reduce window can be expressed as a reduce and a reshape if all
5733     // dimensions either have a window size of one or a window covering the
5734     // entire dimension. When there is no stride, dilation, or padding, this
5735     // is as easy as checking the size of the output shape and window dimension.
5736     //
5737     // The reshape is a bitcast since it adds one-sized dimensions. Often these
5738     // ones are immediately removed as well with another reshape. The
5739     // implementation of reduce tends to be slightly more efficient at reducing
5740     // entire dimensions compared to reduce window.
5741     auto effective_reduce_dims = [&] {
5742       if (window_util::HasStride(window) || window_util::HasDilation(window) ||
5743           window_util::HasPadding(window)) {
5744         return DimensionVector{};
5745       }
5746       DimensionVector reduce_dims;
5747       for (int64_t i = 0; i < window.dimensions_size(); ++i) {
5748         if (window.dimensions(i).size() == 1) {
5749           continue;
5750         } else if (reduce_window->shape().dimensions(i) == 1) {
5751           reduce_dims.push_back(i);
5752         } else {
5753           return DimensionVector{};
5754         }
5755       }
5756       return reduce_dims;
5757     }();
5758 
5759     // If a reduce window can be expressed as a reduce, do so and reshape the
5760     // output.
5761     if (!effective_reduce_dims.empty()) {
5762       Shape reduce_shape = ShapeUtil::DeleteDimensions(effective_reduce_dims,
5763                                                        reduce_window->shape());
5764       simplifier_->UpdateLayout(&reduce_shape);
5765       HloInstruction* reduce = hlo->AddInstruction(HloInstruction::CreateReduce(
5766           /*shape=*/reduce_shape,
5767           /*operand=*/operand,
5768           /*init_value=*/reduce_window->mutable_operand(1),
5769           /*dimensions_to_reduce=*/effective_reduce_dims,
5770           /*reduce_computation=*/function));
5771       return ReplaceWithNewInstruction(
5772           reduce_window,
5773           HloInstruction::CreateReshape(reduce_window->shape(), reduce));
5774     }
5775   }
5776 
5777   // This optimization folds a pad op into reduce_window.
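  // Illustrative sketch (hypothetical values): edge padding low=2/high=1 on a
  // dimension becomes window padding_low += base_dilation * 2 and
  // padding_high += base_dilation * 1; interior padding of k becomes a base
  // dilation of 1 + k (only when the window is not already base-dilated).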
5778   HloInstruction* pad;
5779   const HloInstruction* convert = nullptr;
5780   if (operand->opcode() == HloOpcode::kPad) {
5781     pad = operand;
5782   } else if (operand->opcode() == HloOpcode::kConvert &&
5783              operand->operand(0)->opcode() == HloOpcode::kPad) {
5784     convert = operand;
5785     pad = operand->mutable_operand(0);
5786   } else {
5787     VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
5788     return OkStatus();
5789   }
5790 
5791   VLOG(10) << "Considering folding Pad: " << pad->ToString()
5792            << "\ninto reduce-window: " << reduce_window->ToString()
5793            << (convert != nullptr
5794                    ? absl::StrCat("\nvia convert: ", convert->ToString())
5795                    : "");
5796 
5797   // Do not fold interior padding into a base-dilated reduce-window, since
5798   // the backends do not support that combination.
5799   const PaddingConfig& pad_config = pad->padding_config();
5800   if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) {
5801     VLOG(10) << "Not folding interior pad into base-dilated reduce-window.";
5802     return OkStatus();
5803   }
5804 
5805   // If reduce_window already has padding, the pad value of the pad op and the
5806   // init value of reduce_window must match to allow folding the pad.
5807   const HloInstruction* pad_value = pad->operand(1);
5808   const HloInstruction* reduce_init_value = reduce_window->operand(1);
5809   if (pad_value != reduce_init_value) {
5810     auto literals_are_equivalent = [&] {
5811       auto& pad_literal = pad_value->literal();
5812       auto& reduce_init_literal = reduce_init_value->literal();
5813       if (pad_literal == reduce_init_literal) {
5814         return true;
5815       }
5816       auto converted_pad_literal =
5817           pad_literal.ConvertToShape(reduce_init_value->shape());
5818       if (!converted_pad_literal.ok()) {
5819         return false;
5820       }
5821       return converted_pad_literal.ValueOrDie() == reduce_init_literal;
5822     };
5823     // The pad value is usually a constant, so we handle that case and do not
5824     // try to get more fancy about proving equivalence in cases beyond that.
5825     if (pad_value->opcode() != HloOpcode::kConstant ||
5826         reduce_init_value->opcode() != HloOpcode::kConstant ||
5827         !literals_are_equivalent()) {
5828       VLOG(10) << "Not folding pad into reduce-window due to different pad "
5829                   "values.";
5830       return OkStatus();
5831     }
5832   }
5833 
5834   // If the pad puts a single non-identity value in each window that we're
5835   // reducing, then this is a broadcast.
5836   HloInstruction* pad_operand = pad->mutable_operand(0);
5837   auto is_effective_broadcast = [&] {
5838     if (window_util::HasStride(window)) {
5839       VLOG(10) << "Window has stride.";
5840       return false;
5841     }
5842     if (!window_util::HasSymmetricPadding(pad_config)) {
5843       VLOG(10) << "Window has uneven padding.";
5844       return false;
5845     }
5846     if (HasInteriorPadding(pad_config)) {
5847       VLOG(10) << "Window has interior padding.";
5848       return false;
5849     }
5850     for (int64_t i = 0; i < pad_config.dimensions_size(); ++i) {
5851       const auto& pad_dimension = pad_config.dimensions(i);
5852       if ((pad_dimension.edge_padding_low() != 0 ||
5853            pad_dimension.edge_padding_high() != 0) &&
5854           pad_operand->shape().dimensions(i) != 1) {
5855         VLOG(10) << "Found non-trivial dimension being padded: " << i;
5856         return false;
5857       }
5858     }
5859     VLOG(10) << "Found to be padding trivial dimensions only.";
5860 
5861     for (int64_t i = 0; i < window.dimensions_size(); ++i) {
5862       const auto& pad_dimension = pad_config.dimensions(i);
5863       const WindowDimension& window_dimension = window.dimensions(i);
5864       bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
5865                                     pad_dimension.edge_padding_high() != 0);
5866       if (dimension_has_padding &&
5867           window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
5868         VLOG(10) << "Found window did not cover single unpadded element in "
5869                     "dimension: "
5870                  << i;
5871         return false;
5872       }
5873       if (pad_operand->shape().dimensions(i) != 1 &&
5874           window_dimension.size() != 1) {
5875         VLOG(10) << "Found window covers more than one element in non-trivial "
5876                     "dimension: "
5877                  << i;
5878         return false;
5879       }
5880     }
5881     VLOG(10) << "Found window covers a single unpadded element.";
5882     return true;
5883   };
5884 
5885   HloInstruction* new_reduce_window_operand;
5886   if (convert != nullptr) {
5887     Shape changed_shape = ShapeUtil::ChangeElementType(
5888         pad_operand->shape(), convert->shape().element_type());
5889     simplifier_->UpdateLayout(&changed_shape);
5890     new_reduce_window_operand = hlo->AddInstruction(
5891         HloInstruction::CreateConvert(changed_shape, pad_operand));
5892   } else {
5893     new_reduce_window_operand = pad_operand;
5894   }
5895 
5896   if (is_effective_broadcast()) {
5897     VLOG(10) << "Replacing pad/reduce-window with broadcast.";
5898     auto fadd = [hlo](std::unique_ptr<HloInstruction> x) {
5899       return hlo->AddInstruction(std::move(x));
5900     };
5901     return ReplaceWithNewInstruction(
5902         reduce_window, HloInstruction::CreateBroadcastSequence(
5903                            /*output_shape=*/reduce_window->shape(),
5904                            /*operand=*/new_reduce_window_operand, fadd));
5905   }
5906 
5907   // Carry out the folding of the pad into reduce_window.
5908   VLOG(10) << "Folding pad into reduce-window.";
5909   Window new_window = window;
5910   const int64_t rank = reduce_window->shape().rank();
5911   TF_RET_CHECK(pad_config.dimensions_size() == rank);
5912   TF_RET_CHECK(window.dimensions_size() == rank);
5913   for (int64_t i = 0; i < rank; ++i) {
5914     const auto& pad_dim = pad_config.dimensions(i);
5915     auto& window_dim = *new_window.mutable_dimensions(i);
5916     window_dim.set_padding_low(window_dim.padding_low() +
5917                                window_dim.base_dilation() *
5918                                    pad_dim.edge_padding_low());
5919     window_dim.set_padding_high(window_dim.padding_high() +
5920                                 window_dim.base_dilation() *
5921                                     pad_dim.edge_padding_high());
5922     if (pad_dim.interior_padding() != 0) {
5923       CHECK_EQ(window_dim.base_dilation(), 1);
5924       window_dim.set_base_dilation(1 + pad_dim.interior_padding());
5925     }
5926   }
5927 
5928   return ReplaceWithNewInstruction(
5929       reduce_window, HloInstruction::CreateReduceWindow(
5930                          /*shape=*/reduce_window->shape(),
5931                          /*operand=*/new_reduce_window_operand,
5932                          /*init_value=*/reduce_window->mutable_operand(1),
5933                          /*window=*/new_window,
5934                          /*reduce_computation=*/function));
5935 }
5936 
5937 Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
5938   // select(x, y, y) -> y.
5939   if (select->operand(1) == select->operand(2)) {
5940     return ReplaceInstruction(select, select->mutable_operand(1));
5941   }
5942   // select(true, x, y) -> x.
5943   if (IsAll(select->operand(0), true)) {
5944     return ReplaceInstruction(select, select->mutable_operand(1));
5945   }
5946   // select(false, x, y) -> y.
5947   if (IsAll(select->operand(0), false)) {
5948     return ReplaceInstruction(select, select->mutable_operand(2));
5949   }
5950   // select(not(pred), a, b) -> select(pred, b, a)
5951   if (HloOpcode::kNot == select->operand(0)->opcode()) {
5952     auto pred_operand = select->mutable_operand(0)->mutable_operand(0);
5953     auto on_true = select->mutable_operand(1);
5954     auto on_false = select->mutable_operand(2);
5955     return ReplaceWithNewInstruction(
5956         select,
5957         HloInstruction::CreateTernary(select->shape(), HloOpcode::kSelect,
5958                                       pred_operand, on_false, on_true));
5959   }
5960   return OkStatus();
5961 }
5962 
5963 Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* hlo) {
5964   auto* scatter = Cast<HloScatterInstruction>(hlo);
5965 
5966   if (absl::c_all_of(scatter->scatter_updates(),
5967                      [](const HloInstruction* updates) {
5968                        return ShapeUtil::IsZeroElementArray(updates->shape());
5969                      }) &&
5970       ReplaceInstructionIfCompatible(scatter, scatter->scatter_operands())) {
5971     return OkStatus();
5972   }
5973   if (scatter->scatter_operand_count() == 1 &&
5974       ShapeUtil::IsZeroElementArray(scatter->scatter_indices()->shape()) &&
5975       SameShape(scatter, scatter->scatter_operands()[0]) &&
5976       SameShape(scatter, scatter->scatter_updates()[0])) {
5977     return ReplaceWithNewInstruction(
5978         scatter, HloInstruction::CreateMap(scatter->shape(),
5979                                            {scatter->scatter_operands()[0],
5980                                             scatter->scatter_updates()[0]},
5981                                            scatter->to_apply()));
5982   }
5983   return OkStatus();
5984 }
5985 
5986 Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
5987   auto operand = sort->mutable_operand(0);
5988   int64_t dimension_to_sort = sort->dimensions(0);
5989   if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
5990       operand->shape().dimensions(dimension_to_sort) <= 1) {
5991     if (sort->operand_count() == 1) {
5992       return ReplaceInstruction(sort, operand);
5993     }
5994     // If it is key/value sort, the output of sort is a tuple.
5995     return ReplaceWithNewInstruction(
5996         sort, HloInstruction::CreateTuple(sort->operands()));
5997   }
5998   return OkStatus();
5999 }
6000 
6001 Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
6002   VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
6003   HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
6004   if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
6005       sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
6006     return ReplaceWithNewInstruction(
6007         sqrt, HloInstruction::CreateUnary(
6008                   sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
6009                   sqrt_operand->mutable_operand(0)));
6010   }
6011   return OkStatus();
6012 }
}
6013 namespace {
6014 bool OnlyPermutesDegenerateDims(const Shape& shape,
6015                                 absl::Span<const int64_t> perm) {
6016   std::vector<int64_t> new_permutation;
6017   int64_t degenerate_count = 0;
6018   for (int64_t i = 0; i < perm.size(); ++i) {
6019     if (shape.dimensions(i) != 1) {
6020       new_permutation.push_back(perm[i]);
6021     } else {
6022       ++degenerate_count;
6023     }
6024   }
6025   return degenerate_count > 0 && absl::c_is_sorted(new_permutation);
6026 }
6027 
6028 bool IsPermutationOfIota(absl::Span<const int64_t> elems) {
6029   DimensionVector sorted(elems.begin(), elems.end());
6030   absl::c_sort(sorted);
6031   for (int i = 0; i < sorted.size(); i++) {
6032     if (sorted[i] != i) {
6033       return false;
6034     }
6035   }
6036   return true;
6037 }
6038 
6039 }  // namespace
6040 
6041 Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
6042   auto operand = transpose->mutable_operand(0);
6043   if (std::is_sorted(transpose->dimensions().begin(),
6044                      transpose->dimensions().end())) {
6045     VLOG(10) << "deleting no-op transpose";
6046     return ReplaceInstruction(transpose, operand);
6047   }
6048 
6049   if (HloOpcode::kTranspose == operand->opcode()) {
6050     return ReplaceWithNewInstruction(
6051         transpose, HloInstruction::CreateTranspose(
6052                        transpose->shape(), operand->mutable_operand(0),
6053                        ComposePermutations(operand->dimensions(),
6054                                            transpose->dimensions())));
6055   }
6056 
6057   // Convert transpose(dot(a,b)) to dot(b,a).
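  // Illustrative sketch (hypothetical shapes): for a canonical dot,
  //   transpose(dot(f32[3,4] %a, f32[4,5] %b), dims={1,0})
  //     -> dot(%b, %a) with the contracting dimensions swapped,
  // producing f32[5,3] directly.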
6058   auto do_transpose_of_dot = [&]() -> StatusOr<bool> {
6059     if (operand->opcode() != HloOpcode::kDot || operand->user_count() != 1) {
6060       return false;
6061     }
6062     HloInstruction* dot = operand;
6063     HloInstruction* lhs = dot->mutable_operand(0);
6064     HloInstruction* rhs = dot->mutable_operand(1);
6065 
6066     const int64_t rank = dot->shape().rank();
6067     const auto& dnums = dot->dot_dimension_numbers();
6068 
6069     // Dot must be "somewhat canonical": batch dimensions at the beginning and
6070     // one non-contracting dim.  It's the responsibility of DotDecomposer to
6071     // canonicalize dots.
6072     if (absl::MakeSpan(dnums.lhs_batch_dimensions()) !=
6073             absl::MakeSpan(dnums.rhs_batch_dimensions()) ||
6074         !IsPermutationOfIota(dnums.lhs_batch_dimensions()) ||
6075         dnums.lhs_contracting_dimensions_size() == 0 ||
6076         dnums.lhs_contracting_dimensions_size() +
6077                 dnums.lhs_batch_dimensions_size() + 1 !=
6078             lhs->shape().rank() ||
6079         dnums.rhs_contracting_dimensions_size() == 0 ||
6080         dnums.rhs_contracting_dimensions_size() +
6081                 dnums.rhs_batch_dimensions_size() + 1 !=
6082             rhs->shape().rank()) {
6083       return false;
6084     }
6085 
6086     // Transpose must just be over the last two dims (i.e. the non-batch dims).
    DimensionVector expected_perm(rank);
    absl::c_iota(expected_perm, 0);
    std::swap(expected_perm.rbegin()[0], expected_perm.rbegin()[1]);
    if (transpose->dimensions() != expected_perm) {
      return false;
    }

    DotDimensionNumbers new_dnums = dnums;
    std::swap(*new_dnums.mutable_lhs_contracting_dimensions(),
              *new_dnums.mutable_rhs_contracting_dimensions());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        transpose,
        HloInstruction::CreateDot(
            transpose->shape(), /*lhs=*/rhs, /*rhs=*/lhs, new_dnums,
            SwapOperandsInDotPrecisionConfig(dot->precision_config()))));
    return true;
  };
  TF_ASSIGN_OR_RETURN(bool did_transpose_of_dot, do_transpose_of_dot());
  if (did_transpose_of_dot) {
    return OkStatus();
  }

  // Replace the transpose with a reshape if it only permutes degenerate
  // (size-1) dimensions.
  if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateReshape(
                       transpose->shape(), transpose->mutable_operand(0)));
  }

  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = transpose->shape();
    return ReplaceInstruction(transpose, operand);
  }

  if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) {
    ReplaceWithBitcast(transpose);
    return OkStatus();
  }

  // Replace a reshape of a transpose of a reshape with concatenated slicing
  // if the reshape/transpose combination can be interpreted as a
  // space-to-depth transformation.
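  // Illustrative example (hypothetical shapes):
  //   x = f32[8,3,20] -reshape-> f32[8,3,4,5] -transpose({0,2,1,3})->
  //   f32[8,4,3,5] -reshape-> f32[32,3,5]
  // is rewritten so that the outer reshape instead consumes four f32[8,3,5]
  // slices of x (split along the size-20 dimension) concatenated along
  // dimension 1, i.e. an f32[8,12,5] concat.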
  if (operand->opcode() == HloOpcode::kReshape &&
      transpose->user_count() == 1 &&
      HloOpcode::kReshape == transpose->users()[0]->opcode()) {
    VLOG(2) << "trying depth-to-space transform";
    HloInstruction* reshape_operand = operand->mutable_operand(0);
    HloInstruction* outer_reshape = transpose->users()[0];
    TF_ASSIGN_OR_RETURN(
        bool did_transform, ([&]() -> StatusOr<bool> {
          if (operand->shape().dimensions_size() !=
              reshape_operand->shape().dimensions_size() + 1) {
            return false;
          }

          // Check that the reshape is splitting a single dimension into two.
          int64_t split_dim = 0;
          bool found_split_dims = false;
          for (int64_t dim = 0; dim < reshape_operand->shape().rank(); dim++) {
            if (operand->shape().dimensions(dim) !=
                reshape_operand->shape().dimensions(dim)) {
              const int64_t expected_size =
                  operand->shape().dimensions(dim) *
                  operand->shape().dimensions(dim + 1);
              if (reshape_operand->shape().dimensions(dim) == expected_size) {
                split_dim = dim;
                found_split_dims = true;
                break;
              }
              return false;
            }
          }
          if (!found_split_dims) {
            return false;
          }
          for (int64_t dim = split_dim + 1;
               dim < reshape_operand->shape().rank(); dim++) {
            if (operand->shape().dimensions(dim + 1) !=
                reshape_operand->shape().dimensions(dim)) {
              return false;
            }
          }

          const int64_t num_chunks = operand->shape().dimensions(split_dim);
          const int64_t chunk_size = operand->shape().dimensions(split_dim + 1);

          // This optimization is only beneficial for a small number of chunks.
          // TODO(b/196832483): Determine the appropriate upper bound here.
          const int64_t kMaxChunksForTransformation = 5;
          if (num_chunks > kMaxChunksForTransformation) {
            return false;
          }

          // Determine where the smaller split dimension is being placed in the
          // transpose.
          int64_t transpose_dim = 0;
          bool found_transpose_dim = false;
          for (int64_t dim = 0; dim < operand->shape().rank(); dim++) {
            if (transpose->dimensions(dim) == split_dim) {
              transpose_dim = dim;
              found_transpose_dim = true;
              break;
            }
          }

          // Check that only the small split dimension is reordered in the
          // transpose.
          if (!found_transpose_dim || transpose_dim == split_dim ||
              transpose_dim == split_dim + 1) {
            return false;
          }
          for (int64_t dim = 0; dim < operand->shape().rank(); dim++) {
            int64_t offset = 0;
            if (dim > transpose_dim) {
              offset--;
            }
            if (dim > split_dim) {
              offset++;
            }

            if (dim != transpose_dim &&
                transpose->dimensions(dim) != dim + offset) {
              return false;
            }
          }

          // Check that the outer reshape has the same shape as the input,
          // with the transformed dimensions appropriately scaled by num_chunks.
          for (int64_t dim = 0; dim < reshape_operand->shape().rank(); dim++) {
            if (dim == transpose_dim - 1) {
              if (outer_reshape->shape().dimensions(dim) !=
                  reshape_operand->shape().dimensions(dim) * num_chunks) {
                return false;
              }
            } else if (dim == split_dim) {
              if (outer_reshape->shape().dimensions(dim) !=
                  reshape_operand->shape().dimensions(dim) / num_chunks) {
                return false;
              }
            } else if (outer_reshape->shape().dimensions(dim) !=
                       reshape_operand->shape().dimensions(dim)) {
              return false;
            }
          }

          // Create a concat-of-slices, slicing to create chunks of the expected
          // size on the smaller split dimension.
          std::vector<HloInstruction*> slices;
          for (int64_t i = 0; i < num_chunks; i++) {
            std::vector<int64_t> start_indices;
            std::vector<int64_t> end_indices;
            std::vector<int64_t> strides;
            const auto rank = reshape_operand->shape().rank();
            start_indices.reserve(rank);
            end_indices.reserve(rank);
            strides.reserve(rank);
            for (int64_t dim = 0; dim < rank; dim++) {
              if (dim == split_dim) {
                start_indices.push_back(i * chunk_size);
                end_indices.push_back(i * chunk_size + chunk_size);
              } else {
                start_indices.push_back(0);
                end_indices.push_back(reshape_operand->shape().dimensions(dim));
              }
              strides.push_back(1);
            }
            TF_ASSIGN_OR_RETURN(HloInstruction* const slice,
                                MakeSliceHlo(reshape_operand, start_indices,
                                             end_indices, strides));
            slices.push_back(slice);
            VLOG(2) << "slice " << i << " " << slice->ToString();
          }

          TF_ASSIGN_OR_RETURN(HloInstruction* const concat,
                              MakeConcatHlo(slices, transpose_dim));
          VLOG(2) << "concat " << concat->ToString();
          TF_RETURN_IF_ERROR(
              outer_reshape->ReplaceOperandWithDifferentShape(0, concat));

          return true;
        }()));
    if (did_transform) {
      MarkAsChanged();
      return OkStatus();
    }
  }

  return OkStatus();
}

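// Illustrative example (hypothetical configuration): folding
//   conv(pad(x, /*edge_low=*/1, /*edge_high=*/1 on a spatial dim, value 0), w)
// into conv(x, w) with that window dimension's padding_low/padding_high each
// increased by 1; interior padding k on the kPad instead becomes base
// dilation 1 + k, provided the two kinds of padding never have to compose.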
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
    HloInstruction* convolution) {
  HloInstruction *lhs, *a, *b;
  if (Match(convolution,
            m::Convolution(m::Pad(&lhs, m::Op(&a), m::ConstantScalar(0)),
                           m::Op(&b)))) {
    const auto& window = convolution->window();
    const ConvolutionDimensionNumbers& dnums =
        convolution->convolution_dimension_numbers();

    const auto& padding = lhs->padding_config();

    // Can't pad batch or feature dims.
    for (int64_t dim :
         {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
      const auto& p = padding.dimensions(dim);
      if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
          p.interior_padding() != 0) {
        return false;
      }
    }

    // Compute the window which is the result of merging the kPad and the
    // convolution's existing window.
    Window new_window = window;
    for (int64_t dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
      auto& w = *new_window.mutable_dimensions(dim);
      const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
      // Edge padding composes with itself in the straightforward way, but
      // composing interior padding is nontrivial, and we cowardly refuse to
      // think about it. If we see interior padding in either the kPad or conv,
      // bail if there's any sort of padding in the other.
      if (p.interior_padding() != 0 &&
          (w.padding_low() != 0 || w.padding_high() != 0 ||
           w.base_dilation() != 1)) {
        return false;
      }
      if (w.base_dilation() != 1 &&
          (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
           p.interior_padding() != 0)) {
        return false;
      }

      w.set_padding_low(w.padding_low() + p.edge_padding_low());
      w.set_padding_high(w.padding_high() + p.edge_padding_high());
      if (p.interior_padding() != 0) {
        CHECK_EQ(w.base_dilation(), 1);
        w.set_base_dilation(1 + p.interior_padding());
      }
    }

    auto new_conv =
        convolution->CloneWithNewOperands(convolution->shape(), {a, b});
    new_conv->set_window(new_window);
    TF_RETURN_IF_ERROR(
        ReplaceWithNewInstruction(convolution, std::move(new_conv)));
    return true;
  }
  return false;
}

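// Illustrative example (hypothetical configuration): folding
//   conv(x, pad(w, /*interior_padding=*/1 on a kernel spatial dim, value 0))
// into conv(x, w) with window_dilation = 2 on that dimension, since
// interior-padding a filter by k is exactly window (rhs) dilation 1 + k.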
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (rhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(rhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = rhs->padding_config();

  // Can't pad or dilate feature dims.
  for (int64_t dim : {dnums.kernel_input_feature_dimension(),
                      dnums.kernel_output_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
  Window new_window = convolution->window();
  for (int64_t dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));

    // We can only do this transformation if p adds dilation to the filter --
    // edge padding on the filter is not supported in conv.
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
      return false;
    }

    // Nothing to do if the kPad for this dim is entirely a nop.
    if (p.interior_padding() == 0) {
      continue;
    }

    // We cowardly refuse to think about how dilation composes with itself;
    // bail if both the kPad and conv have dilation on this dimension.
    if (w.window_dilation() > 1) {
      return false;
    }
    CHECK_EQ(w.window_dilation(), 1);
    w.set_window_dilation(1 + p.interior_padding());
    w.set_size(rhs->operand(0)->shape().dimensions(
        dnums.kernel_spatial_dimensions(dim)));
  }

  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs, rhs->mutable_operand(0)});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}

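// Swapping is profitable when the operand used as the kernel is spatially
// larger than the other operand, as can happen in gradient convolutions: a
// naive implementation's flop count scales with the product of the kernel's
// spatial sizes, so (illustrative, hypothetical sizes) a 100x100 kernel
// convolved over a suitably padded 3x3 input is far cheaper computed as a
// 3x3 kernel over a 100x100 input, with the window reversed and its
// dilations swapped to keep the result identical.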
StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
    HloInstruction* convolution) {
  if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
    return false;
  }
  if (convolution->feature_group_count() > 1 ||
      convolution->batch_group_count() > 1) {
    return false;
  }

  const auto& dnums = convolution->convolution_dimension_numbers();
  const auto& window_dims = convolution->window().dimensions();
  Window swapped_window;

  HloInstruction *input = convolution->mutable_operand(0),
                 *kernel = convolution->mutable_operand(1);
  int64_t kernel_product = 1;
  int64_t swapped_kernel_product = 1;
  DimensionVector reverse_dimensions;
  for (int64_t spatial_dim = 0;
       spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
    const int64_t kernel_size = window_dims[spatial_dim].size();
    const bool can_be_group_or_contraction =
        !window_dims[spatial_dim].window_reversal() &&
        window_dims[spatial_dim].padding_low() == 0 &&
        window_dims[spatial_dim].padding_high() == 0 &&
        window_dims[spatial_dim].window_dilation() == 1;
    const bool is_group_dim =
        can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == kernel_size &&
        window_dims[spatial_dim].stride() == kernel_size - 1;
    const int64_t input_size =
        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
    const bool is_pure_contraction_dim =
        kernel_size == input_size && can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == 1 &&
        window_dims[spatial_dim].stride() == 1;
    if (is_group_dim || is_pure_contraction_dim) {
      *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
      continue;
    }

    const int64_t dilated_kernel_size =
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
    const int64_t dilated_input_size =
        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();

    // Don't decide to swap if the input size is one, since many convolution
    // implementations can easily handle that special case efficiently.
    kernel_product *= kernel_size;
    swapped_kernel_product *=
        input_size == 1 && window_dims[spatial_dim].stride() == 1 &&
                window_dims[spatial_dim].window_dilation() == 1 &&
                window_dims[spatial_dim].padding_high() == kernel_size - 1 &&
                window_dims[spatial_dim].padding_low() == kernel_size - 1
            ? kernel_size
            : input_size;

    auto new_dim = swapped_window.add_dimensions();
    new_dim->set_size(input_size);
    // If the kernel is not reversed, the activations must be manually reversed.
    if (!window_dims[spatial_dim].window_reversal()) {
      reverse_dimensions.push_back(
          dnums.kernel_spatial_dimensions(spatial_dim));
    }
    // The input is not originally reversed, so it must be reversed to move the
    // kernel.
    new_dim->set_window_reversal(true);
    // Base dilation and window dilation switch places.
    new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
    new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
    new_dim->set_stride(window_dims[spatial_dim].stride());
    new_dim->set_padding_low(dilated_input_size +
                             window_dims[spatial_dim].padding_low() -
                             dilated_kernel_size);
    new_dim->set_padding_high(dilated_input_size +
                              window_dims[spatial_dim].padding_high() -
                              dilated_kernel_size);
  }

  // Don't transform if a naive convolution implementation would not have fewer
  // flops.
  if (kernel_product <= swapped_kernel_product) {
    return false;
  }
  ConvolutionDimensionNumbers swapped_dnums;
  *swapped_dnums.mutable_output_spatial_dimensions() =
      dnums.output_spatial_dimensions();
  // Swap batch and output feature of the output.
  swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
  swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());

  // Swap input dnums with kernel dnums.
  *swapped_dnums.mutable_input_spatial_dimensions() =
      dnums.kernel_spatial_dimensions();
  swapped_dnums.set_input_batch_dimension(
      dnums.kernel_output_feature_dimension());
  swapped_dnums.set_input_feature_dimension(
      dnums.kernel_input_feature_dimension());

  // Swap kernel dnums with input dnums.
  *swapped_dnums.mutable_kernel_spatial_dimensions() =
      dnums.input_spatial_dimensions();
  swapped_dnums.set_kernel_output_feature_dimension(
      dnums.input_batch_dimension());
  swapped_dnums.set_kernel_input_feature_dimension(
      dnums.input_feature_dimension());

  PrecisionConfig precision_config;
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(1));
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(0));
  if (!reverse_dimensions.empty()) {
    TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_convolution,
      MakeConvolveHlo(
          kernel, input, /*feature_group_count=*/1,
          /*batch_group_count=*/1, swapped_window, swapped_dnums,
          precision_config,
          /*preferred_element_type=*/convolution->shape().element_type()));

  // If we're running on GPU, we need to check that we can actually lower the
  // conv with the given reverse_dims (either none, or rank 2 and all).
  if (!options_.ConvIsLowerable(new_convolution)) {
    TF_RETURN_IF_ERROR(kernel->parent()->RemoveInstruction(new_convolution));
    return false;
  }

  convolution->SetupDerivedInstruction(new_convolution);
  TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

  return true;
}

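// Illustrative example (hypothetical shapes, layout-sensitive mode): a 1x1
// convolution conv(f32[N,H,W,C] x, f32[1,1,C,O] w) with the feature
// dimension most minor is a plain matmul over features: bitcast x to
// f32[N*H*W, C] and w to f32[C, O], emit a dot with contracting dimensions
// {1} x {0}, and bitcast the f32[N*H*W, O] result back to the convolution's
// shape.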
StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const auto& window = convolution->window();
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (!options_.enable_conv_simplification()) {
    return false;
  }

  // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
  // layout-insensitive mode, for fear of adding nontrivial reshapes.
  if (!options_.is_layout_sensitive()) {
    return false;
  }

  const Shape& input_shape = lhs->shape();
  const Shape& filter_shape = rhs->shape();
  const Shape& convolution_shape = convolution->shape();
  TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));

  // Require the spatial dimensions in the kernel to have a bound of one.
  for (int64_t i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
      return false;
    }
  }

  // Stride ignores part of the output, which matrix multiplication does not
  // do, so require no stride. Padding and base (lhs) dilation both implicitly
  // extend the data, which matrix multiplication also does not do, so require
  // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
  // for a 1x1 window, so window dilation is no problem.
  if (window_util::HasStride(window) || window_util::HasPadding(window) ||
      window_util::HasBaseDilation(window)) {
    return false;
  }

  // Also, the shapes must align for a row-major matmul:
  // - the input and output have the same layout.
  // - for input/output, the channel dimension must be the most minor. Other
  //   spatial dims can be in any order.
  // - for filters, the input channel dimension must be more major than the
  //   output channel dimension. The width+height don't matter because
  //   they are 1.
  //
  // These constraints are harsh. If the channel dimension is the most major
  // and/or the layout of input/output feature dimensions are reversed, we can
  // still convert Conv into a more efficient Matmul with operand
  // transposition (such as the transposition flags in cuBLAS SGEMM).
  if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
      LayoutUtil::Minor(input_shape.layout(), 0) !=
          dnums.input_feature_dimension() ||
      LayoutUtil::Minor(convolution_shape.layout(), 0) !=
          dnums.output_feature_dimension() ||
      // The input feature dimension should come later in the minor-to-major
      // order.
      (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_input_feature_dimension()) <
       PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_output_feature_dimension()))) {
    return false;
  }

  if (convolution->feature_group_count() != 1 ||
      convolution->batch_group_count() != 1) {
    return false;
  }
  auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
    return operand->AddInstruction(
        HloInstruction::CreateBitcast(shape, operand));
  };

  // Replace it with a dot, with bitcasts around it to get the right shape.
  const int64_t input_channels =
      input_shape.dimensions(dnums.input_feature_dimension());
  const int64_t output_channels =
      filter_shape.dimensions(dnums.kernel_output_feature_dimension());

  // Computes the product of the non-feature dimensions.
  int64_t conv_width = 1;
  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
    if (i != dnums.input_feature_dimension()) {
      conv_width *= input_shape.dimensions(i);
    }
  }

  // We already checked feature_dimension is most minor, so data in input_shape
  // and row-major {conv_width,input_channels} are bitwise identical.
  Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      input_shape.element_type(), {conv_width, input_channels});
  simplifier_->UpdateLayout(&new_input_shape);
  // We already checked input_feature_dimension is more major than
  // output_feature_dimension, so data in filter_shape and row-major
  // {input_channels,output_channels} are bitwise identical.
  Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      filter_shape.element_type(), {input_channels, output_channels});
  simplifier_->UpdateLayout(&new_filter_shape);
  Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      convolution_shape.element_type(), {conv_width, output_channels});
  simplifier_->UpdateLayout(&dot_output_shape);

  auto new_lhs = add_bitcast(new_input_shape, lhs);
  auto new_rhs = add_bitcast(new_filter_shape, rhs);
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  auto dot = convolution->AddInstruction(HloInstruction::CreateDot(
      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
      convolution->precision_config()));

  TF_RETURN_IF_ERROR(
      ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleConvolution(
    HloInstruction* convolution) {
  if (options_.enable_scalar_multiply_reduction()) {
    TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
  }

  // Zero-sized input or filter.
  if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
    return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
  }

  // Try to merge padding/dilation of the input with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
  if (folded_input_pad) {
    return OkStatus();
  }

  // Try to merge dilation of the filter with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
  if (folded_filter_pad) {
    return OkStatus();
  }

  // Try to swap convolution operands.
  TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
  if (swapped) {
    return OkStatus();
  }
  // Try to replace the convolution with a kDot instruction.
  TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
  if (replaced_with_dot) {
    return OkStatus();
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
  auto* map_computation = map->to_apply();
  auto* map_root = map_computation->root_instruction();
  if (map_root->opcode() == HloOpcode::kParameter) {
    ReplaceInstructionIfCompatible(
        map, map->mutable_operand(map_root->parameter_number()));
    return OkStatus();
  }
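  // A map whose computation root is a scalar constant ignores its operands:
  // replace it with the constant, broadcast to the map's shape if the map is
  // not itself scalar. Illustrative HLO (hypothetical): map(a, b,
  // to_apply={... -> constant(42)}) => broadcast(constant(42)).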
  if (map_root->opcode() == HloOpcode::kConstant) {
    if (!ShapeUtil::IsScalar(map_root->shape())) {
      return OkStatus();
    }
    auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
    if (ShapeUtil::IsScalar(map->shape())) {
      return ReplaceWithNewInstruction(map, std::move(clone));
    }
    return ReplaceWithNewInstruction(
        map, HloInstruction::CreateBroadcast(
                 map->shape(), map->AddInstruction(std::move(clone)), {}));
  }
  // Inline the map if the map computation only contains an elementwise
  // operation that can accept arbitrary shapes.
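  // Illustrative HLO (hypothetical): map(a, b, to_apply={p0, p1 ->
  // add(p0, p1)}) becomes add(a, b), with parameters replaced by the
  // corresponding map operands.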
  if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
    return OkStatus();
  }
  std::vector<HloInstruction*> new_operands;
  for (auto* root_operand : map_root->operands()) {
    if (root_operand->opcode() != HloOpcode::kParameter) {
      return OkStatus();
    }
    new_operands.push_back(
        map->mutable_operand(root_operand->parameter_number()));
  }
  auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
  return ReplaceWithNewInstruction(map, std::move(clone));
}

StatusOr<bool> AlgebraicSimplifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  AlgebraicSimplifierVisitor visitor(options_, this);
  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    if (visitor.Run(comp, options_, this)) {
      changed = true;
    }
  }
  return changed;
}

}  // namespace xla