/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {

namespace {

namespace m = match;

// Unwraps broadcasts hunting for a constant. If we find one, checks if the
// constant contains only the given value.
bool IsAll(const HloInstruction* op, int8_t value) {
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), value);
    case HloOpcode::kConstant:
      return op->literal().IsAll(value);
    default:
      return false;
  }
}

bool IsAll(const HloInstruction* op, const Literal& scalar) {
  CHECK(ShapeUtil::IsScalar(scalar.shape()));
  switch (op->opcode()) {
    case HloOpcode::kBroadcast:
      return IsAll(op->operand(0), scalar);
    case HloOpcode::kConstant:
      return op->literal().IsAll(scalar);
    default:
      return false;
  }
}

bool IsAnyOperandComplex(const HloInstruction* hlo) {
  for (auto operand : hlo->operands()) {
    if (ShapeUtil::ElementIsComplex(operand->shape())) {
      return true;
    }
  }
  return false;
}

bool IsPositive(const HloInstruction* hlo,
                const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kGetTupleElement: {
      const HloInstruction* gte_operand = hlo->operand(0);
      switch (gte_operand->opcode()) {
        case HloOpcode::kCustomCall: {
          const auto& target = gte_operand->custom_call_target();
          return target ==
                     options.get_cudnn_batchnorm_forward_training_metadata() &&
                 hlo->tuple_index() == 2;
        }
        default:
          return false;
      }
    }
    case HloOpcode::kPower:
    case HloOpcode::kAbs:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSqrt:
      return IsPositive(hlo->operand(0), options);

    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1) &&
             IsPositive(hlo->operand(0), options);
    }
    default:
      return false;
  }
}

std::optional<double> GetConstantValue(const HloInstruction* inst) {
  if (!ShapeUtil::IsEffectiveScalar(inst->shape())) {
    return std::nullopt;
  }
  switch (inst->shape().element_type()) {
    case F16:
      return static_cast<float>(inst->literal().GetFirstElement<half>());
    case BF16:
      return static_cast<float>(inst->literal().GetFirstElement<bfloat16>());
    case F32:
      return inst->literal().GetFirstElement<float>();
    case F64:
      return inst->literal().GetFirstElement<double>();
    default:
      return std::nullopt;
  }
}

static bool IsScalarConstant(const HloInstruction* hlo,
                             const LiteralSlice& literal) {
  return hlo->opcode() == HloOpcode::kConstant &&
         ShapeUtil::IsEffectiveScalar(hlo->shape()) &&
         literal_comparison::Equal(hlo->literal(), literal).ok();
}

static bool IsScalarConstantZero(const HloInstruction* hlo) {
  return IsScalarConstant(hlo, LiteralUtil::Zero(hlo->shape().element_type()));
}

static bool IsScalarConstantNegInf(const HloInstruction* hlo) {
  return !primitive_util::IsComplexType(hlo->shape().element_type()) &&
         IsScalarConstant(hlo,
                          LiteralUtil::MinValue(hlo->shape().element_type()));
}

static bool IsScalarConstantInf(const HloInstruction* hlo) {
  return !primitive_util::IsComplexType(hlo->shape().element_type()) &&
         IsScalarConstant(hlo,
                          LiteralUtil::MaxValue(hlo->shape().element_type()));
}

bool IsNonNegative(const HloInstruction* hlo,
                   const AlgebraicSimplifierOptions& options) {
  // Utility only handles real types.
  if (IsAnyOperandComplex(hlo)) {
    return false;
  }
  switch (hlo->opcode()) {
    case HloOpcode::kMultiply: {
      return hlo->operand(0) == hlo->operand(1);
    }
    case HloOpcode::kAbs: {
      return true;
    }
    case HloOpcode::kBroadcast: {
      return IsNonNegative(hlo->operand(0), options);
    }
    case HloOpcode::kConstant: {
      if (std::optional<double> value = GetConstantValue(hlo)) {
        return *value >= 0.0;
      }
      return false;
    }
    case HloOpcode::kMaximum: {
      return IsNonNegative(hlo->operand(0), options) ||
             IsNonNegative(hlo->operand(1), options);
    }
    case HloOpcode::kSelect: {
      return IsNonNegative(hlo->operand(1), options) &&
             IsNonNegative(hlo->operand(2), options);
    }
    default:
      return IsPositive(hlo, options);
  }
}

// Checks whether `op` is a floating-point constant or broadcast of a constant
// of the form +/- 2^k for some integer k positive, negative, or zero. Such
// values are interesting because multiplying by a power of 2 just moves the
// exponent.
bool IsAllFpConstantPowerOf2(const HloInstruction* op) {
  // Unwrap the broadcast if necessary.
  const HloInstruction* c;
  if (!Match(op, m::ConstantEffectiveScalar(&c)) &&
      !Match(op, m::Broadcast(m::Constant(&c).WithShape(
                     m::Shape().IsEffectiveScalar())))) {
    return false;
  }
  auto val = [&]() -> std::optional<double> {
    switch (c->shape().element_type()) {
      case BF16:
        return static_cast<double>(c->literal().GetFirstElement<bfloat16>());
      case F16:
        return static_cast<double>(
            c->literal().GetFirstElement<Eigen::half>());
      case F32:
        return c->literal().GetFirstElement<float>();
      case F64:
        return c->literal().GetFirstElement<double>();
      default:
        // Cowardly refuse to consider complex types.
        return std::nullopt;
    }
  }();
  if (!val) {
    return false;
  }

  int exp;
  double mantissa = std::frexp(*val, &exp);
  // frexp returns a value in the range (-1, -0.5] U [0.5, 1). A return value
  // of +/-0.5 therefore indicates that the floating point value is a power of
  // 2.
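  // For example, frexp(0.25) yields mantissa 0.5 and exponent -1
  // (0.25 == 0.5 * 2^-1), so 0.25 is classified as a power of 2.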
  return mantissa == 0.5 || mantissa == -0.5;
}

// Returns whether the given transpose produces a result which is bit-wise
// identical to its operand and thus may be replaced with a bitcast.
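// For example, transposing f32[8,16]{1,0} with permutation {1,0} into
// f32[16,8]{0,1} permutes the logical dimensions and the layout together,
// leaving the underlying memory order (and hence the bits) unchanged.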
bool TransposeIsBitcast(const HloInstruction* transpose) {
  CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
  const HloInstruction* operand = transpose->operand(0);
  return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
                                       transpose->dimensions());
}

// Recursive helper for the method below.
HloInstruction* BitcastingOperandOfReshapeOrCopyChainHelper(
    HloInstruction* instr, HloInstruction* operand,
    const AlgebraicSimplifierOptions& options) {
  // Can't replace a chain of copies and reshapes with bitcasts if the
  // compiler used a memory layout which isn't compatible.
  if (options.ReshapeIsBitcast(operand->shape(), instr->shape())) {
    return operand;
  }

  // If the operand is a copy or reshape, try to see if the operand's operand
  // would produce a bitcast with the initial instruction.
  if (HloOpcode::kReshape == operand->opcode() ||
      HloOpcode::kCopy == operand->opcode()) {
    return BitcastingOperandOfReshapeOrCopyChainHelper(
        instr, operand->mutable_operand(0), options);
  }
  return nullptr;
}

// Returns an operand of a chain of reshapes and copies that is bit-wise
// identical to the first reshape or copy in the chain.
HloInstruction* BitcastingOperandOfReshapeOrCopyChain(
    HloInstruction* instr, const AlgebraicSimplifierOptions& options) {
  if (!options.is_layout_sensitive()) {
    return nullptr;
  }
  CHECK(HloOpcode::kReshape == instr->opcode() ||
        HloOpcode::kCopy == instr->opcode());
  return BitcastingOperandOfReshapeOrCopyChainHelper(
      instr, instr->mutable_operand(0), options);
}

bool IsUnstridedSlice(const HloInstruction* hlo) {
  return absl::c_all_of(hlo->slice_strides(),
                        [](int64_t stride) { return stride == 1; });
}

// Returns true if a pair of adjacent converts can be eliminated.
bool IsConvertPairNoOp(const HloInstruction* convert) {
  //    [operand_convert]         [convert]
  // (src)->convert-(intermediate)->convert-(dest)
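  // For example, (f32)->convert-(f64)->convert-(f32) is a no-op pair, since
  // converting f32 to f64 preserves all values.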
  const HloInstruction* operand_convert = convert->operand(0);
  if (operand_convert->opcode() != HloOpcode::kConvert) {
    return false;
  }
  const PrimitiveType src_type =
      operand_convert->operand(0)->shape().element_type();
  const PrimitiveType intermediate_type =
      operand_convert->shape().element_type();

  return src_type == convert->shape().element_type() &&
         primitive_util::CastPreservesValues(src_type, intermediate_type);
}

PrecisionConfig SwapOperandsInDotPrecisionConfig(PrecisionConfig config) {
  CHECK_EQ(config.operand_precision_size(), 2);
  std::swap(config.mutable_operand_precision()->at(0),
            config.mutable_operand_precision()->at(1));
  return config;
}

// Validates that the tiling and padding assignments of the bitcasted shape
// keep it equivalent to the operand shape; returns false if the bitcast
// would make the two shapes non-equivalent.
bool ValidateTilingOfBitcast(
    const Shape& bitcast_shape, const Shape& op_shape,
    const std::vector<std::vector<int64_t>>& operand_map) {
  if (op_shape.layout().tiles().empty() ||
      bitcast_shape.layout().tiles().empty()) {
    return true;
  }
  VLOG(2) << "op shape:" << op_shape.ToString(true) << "\n";
  VLOG(2) << "bitcast shape:" << bitcast_shape.ToString(true) << "\n";
  VLOG(2) << "operand_map size:" << operand_map.size() << "\n";
  auto op_tile = op_shape.layout().tiles(0);
  auto bitcast_tile = bitcast_shape.layout().tiles(0);
  int64_t num_of_tiled_dims = op_tile.dimensions().size(),
          tiled_dim_idx = num_of_tiled_dims - 1;
  if (bitcast_tile.dimensions().size() != num_of_tiled_dims) {
    return false;
  }
  for (auto op_dim : op_shape.layout().minor_to_major()) {
    VLOG(3) << "op_dim = " << op_dim << "\n";
    VLOG(3) << "tiled_dim_idx = " << tiled_dim_idx << "\n";
    VLOG(3) << "tiled_dim_size = " << op_tile.dimension(tiled_dim_idx) << ":"
            << bitcast_tile.dimension(tiled_dim_idx) << "\n";
    if (op_tile.dimensions()[tiled_dim_idx] !=
        bitcast_tile.dimensions()[tiled_dim_idx]) {
      VLOG(2) << "Abort b/c tiled dimension " << op_dim
              << " has different tiling sizes before and after bitcast.\n";
      return false;
    }
    if (operand_map.size() <= op_dim || operand_map[op_dim].empty()) {
      if (op_tile.dimensions()[tiled_dim_idx] != 1) {
        VLOG(2) << "Abort b/c tiled dimension " << op_dim
                << " is unmapped in the bitcast but has a tile size other "
                   "than 1.\n";
        return false;
      }
    } else if (bitcast_shape.dimensions_size() <= operand_map[op_dim][0]) {
      VLOG(2) << "Abort because the bitcasted dimensions are not aligned!\n";
      return false;
    } else if (bitcast_shape.dimensions(operand_map[op_dim][0]) <
               op_shape.dimensions(op_dim)) {
      if (operand_map[op_dim].size() == 1) {
        VLOG(2) << "Abort b/c a dimension (possibly padded) is shrunk to a "
                   "smaller size.\n";
        return false;
      }
      if (tiled_dim_idx > 0) {
        VLOG(2) << "Abort b/c a non-major tiled dimension is split.\n";
        return false;
      }
      if (bitcast_shape.dimensions(operand_map[op_dim][0]) %
                  op_tile.dimensions()[tiled_dim_idx] !=
              0 ||
          op_shape.dimensions(op_dim) %
                  bitcast_shape.dimensions(operand_map[op_dim][0]) !=
              0) {
        VLOG(2) << "Abort b/c tiled dimension " << op_dim
                << " has been split in bitcasted layout\n";
        return false;
      }
    } else if (bitcast_shape.dimensions(operand_map[op_dim][0]) >
               op_shape.dimensions(op_dim)) {
      if (tiled_dim_idx > 0) {
        VLOG(2) << "Abort b/c a non-major tiled dimension is combined.\n";
        return false;
      }
      if (bitcast_shape.dimensions(operand_map[op_dim][0]) %
                  op_shape.dimensions(op_dim) !=
              0 ||
          op_shape.dimensions(op_dim) % op_tile.dimensions()[tiled_dim_idx] !=
              0) {
        VLOG(2) << "Abort b/c tiled dimension " << op_dim
                << " has been combined in bitcasted layout\n";
        return false;
      }
    }
    if (--tiled_dim_idx < 0) {
      break;
    }
  }
  return true;
}

}  // namespace

void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) {
  ResetVisitStates();
  computation_ = computation;
}

bool AlgebraicSimplifierVisitor::Run(HloComputation* computation,
                                     const AlgebraicSimplifierOptions& options,
                                     AlgebraicSimplifier* simplifier) {
  ResetState(computation);
  TF_CHECK_OK(computation->Accept(this));
  return changed();
}

bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
                                           const HloInstruction* rhs) const {
  return SameShape(lhs->shape(), rhs->shape());
}

bool AlgebraicSimplifierVisitor::SameShape(const Shape& lhs,
                                           const Shape& rhs) const {
  if (options_.is_layout_sensitive()) {
    return ShapeUtil::Equal(lhs, rhs);
  } else {
    return ShapeUtil::Compatible(lhs, rhs);
  }
}

namespace {

bool IsOpCodeMultiplyCommutative(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kMultiply:
    case HloOpcode::kTranspose:
    case HloOpcode::kReshape:
    case HloOpcode::kSelect:
      return true;
    default:
      return false;
  }
}

std::unique_ptr<HloInstruction> MakeScalarInstruction(HloInstruction* target,
                                                      float multiplier) {
  switch (target->shape().element_type()) {
    case BF16:
      return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16(
          LiteralUtil::CreateR0<float>(multiplier)));
    case F32:
      return HloInstruction::CreateConstant(
          LiteralUtil::CreateR0<float>(multiplier));
    default:
      LOG(FATAL) << "Unsupported data type: "
                 << target->shape().element_type();
  }
}

}  // namespace

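// Gathers scalar multiplies that feed into or out of a dot, strips them, and
// reapplies their combined product as a single multiply on the smallest of
// {lhs, rhs, dot output}. For example, dot(A * 2.0, B * 3.0) effectively
// becomes dot(A, B) * 6.0, with the multiply placed where it touches the
// fewest elements.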
Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction(
    HloInstruction* dot) {
  // We only process bfloat16 and float32 for now.
  if (dot->shape().element_type() != BF16 &&
      dot->shape().element_type() != F32) {
    return OkStatus();
  }

  auto lhs = dot->mutable_operand(0);
  auto rhs = dot->mutable_operand(1);

  const int64_t dot_size = ShapeUtil::ElementsIn(dot->shape());
  const int64_t lhs_size = ShapeUtil::ElementsIn(lhs->shape());
  const int64_t rhs_size = ShapeUtil::ElementsIn(rhs->shape());

  HloInstruction* target = nullptr;
  // (current node, user, operand_index)
  std::vector<std::tuple<HloInstruction*, HloInstruction*, int64_t>> operands;
  std::vector<HloInstruction*> users;

  // Find which side of the dot has the smallest size:
  // operand 0, operand 1, or output.
  if (dot_size <= std::min(lhs_size, rhs_size)) {
    target = dot;
    if (dot_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (dot_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
  } else if (lhs_size <= rhs_size) {
    target = lhs;
    if (lhs_size < rhs_size) {
      operands.emplace_back(rhs, dot, 1);
    }
    if (lhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  } else {
    target = rhs;
    if (rhs_size < lhs_size) {
      operands.emplace_back(lhs, dot, 0);
    }
    if (rhs_size < dot_size && dot->user_count() == 1) {
      users.push_back(dot->users().front());
    }
  }

  std::vector<float> values;

  // DFS to find scalar multiply ops from the operands.
  while (!operands.empty()) {
    HloInstruction* inst;
    HloInstruction* user;
    int64_t index;
    std::tie(inst, user, index) = operands.back();
    operands.pop_back();

    // Skip the op types that are not commutative with multiply.
    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    // Pattern match a scalar multiply.
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      CHECK_LT(index, user->operand_count());
      CHECK_EQ(inst, user->operands()[index]);

      // When a scalar multiply is found, save its scalar value.
      values.push_back(*GetConstantValue(multiplier));
      // And remove the scalar multiply op.
      TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand));
      inst = operand;
    }

    // Push the operands of inst.
    int64_t i = 0;
    for (auto* operand : inst->operands()) {
      operands.emplace_back(operand, inst, i++);
    }
  }

  // DFS to find scalar multiply ops from the users.
  while (!users.empty()) {
    auto inst = users.back();
    users.pop_back();

    if (!IsOpCodeMultiplyCommutative(inst->opcode())) {
      continue;
    }

    HloInstruction* operand;
    HloInstruction* multiplier;
    if (Match(inst, m::MultiplyAnyOrder(
                        m::Op(&operand),
                        m::Broadcast(m::ConstantScalar(&multiplier))))) {
      values.push_back(*GetConstantValue(multiplier));

      TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand));
      inst = operand;
    }

    // Only continue through instructions with a single user; otherwise,
    // moving the scalar multiply toward the operands would change the values
    // seen by the other users.
    if (inst->user_count() == 1) {
      users.push_back(inst->users().front());
    }
  }

  if (values.empty()) {
    return OkStatus();
  }

  MarkAsChanged();

  // Combine all constant multipliers.
  float multiplier = 1.0;
  for (const float v : values) {
    multiplier *= v;
  }

  // Create a new scalar constant for the combined multiplier.
  HloInstruction* new_const_inst =
      target->AddInstruction(MakeScalarInstruction(target, multiplier));

  // Broadcast the scalar multiplier.
  HloInstruction* new_broadcast = target->AddInstruction(
      HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {}));
  // Create a new scalar multiply instruction.
  HloInstruction* new_multiply =
      target->AddInstruction(HloInstruction::CreateBinary(
          target->shape(), HloOpcode::kMultiply, target, new_broadcast));
  CHECK_EQ(new_multiply->shape(), target->shape());

  // Update the dependency with the rest of the instructions.
  if (target == lhs) {
    return dot->ReplaceOperandWith(0, new_multiply);
  } else if (target == rhs) {
    return dot->ReplaceOperandWith(1, new_multiply);
  } else {
    CHECK_EQ(target, dot);
    return dot->ReplaceAllUsesWith(new_multiply);
  }
}

void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
                                                    HloInstruction* operand) {
  CHECK_EQ(1, instruction->operand_count());
  if (operand == nullptr) {
    operand = instruction->mutable_operand(0);
  }
  CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
           ShapeUtil::ElementsIn(operand->shape()));
  CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
           ShapeUtil::ByteSizeOf(operand->shape()));

  auto bitcast = instruction->AddInstruction(
      HloInstruction::CreateBitcast(instruction->shape(), operand));
  TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}

// Replace the old instruction with the new one if they are compatible, i.e.,
// 1. they have the same shape, and
// 2. the replacement will not cause loss of sharding.
bool AlgebraicSimplifierVisitor::ReplaceInstructionIfCompatible(
    HloInstruction* old_instruction, HloInstruction* new_instruction) {
  if (!SameShape(old_instruction, new_instruction)) {
    return false;
  }
  return ReplaceInstruction(old_instruction, new_instruction,
                            /*preserve_sharding=*/true)
      .ValueOrDie();
}

bool AlgebraicSimplifierVisitor::ReplaceInstructionIfCompatible(
    HloInstruction* old_instruction,
    absl::Span<HloInstruction* const> new_instructions) {
  if (new_instructions.size() == 1) {
    return ReplaceInstructionIfCompatible(old_instruction, new_instructions[0]);
  }
  CHECK(!new_instructions.empty());
  if (!old_instruction->shape().IsTuple() ||
      old_instruction->shape().tuple_shapes_size() != new_instructions.size()) {
    return false;
  }
  for (int i = 0, n = new_instructions.size(); i < n; ++i) {
    if (!SameShape(old_instruction->shape().tuple_shapes(i),
                   new_instructions[i]->shape())) {
      return false;
    }
  }
  return ReplaceInstruction(old_instruction, MaybeMakeTuple(new_instructions),
                            /*preserve_sharding=*/true)
      .ValueOrDie();
}
Status AlgebraicSimplifierVisitor::HandleAbs(HloInstruction* abs) {
  HloInstruction* abs_operand = abs->mutable_operand(0);
  VLOG(10) << "trying transform [Abs(A) => A] " << abs->ToString()
           << " Abs operand is: " << abs_operand->ToString();
  if (IsNonNegative(abs->operand(0), options_)) {
    return ReplaceInstruction(abs, abs_operand);
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs))));

  // A + 0 => A
  VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(add, lhs)) {
    return OkStatus();
  }
  // 0 + A => A
  VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfCompatible(add, rhs)) {
    return OkStatus();
  }

  // Canonicalization: Put constants on the right. This makes the
  // reassociation rules below simpler.
  VLOG(10) << "trying transform [Const + A => A + Const]";
  if (Match(add, m::Add(m::Constant(), m::NonConstant()))) {
    return ReplaceWithNewInstruction(
        add,
        HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs));
  }

  // Reassociate to allow constant folding.
  //
  // Note: This is not general. For example, we won't reassociate
  //
  //   (A + C1) + (B + C2) => A + B + (C1 + C2).
  //
  VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(add, m::Add(m::Add(m::NonConstant(&a),
                               m::Broadcast(m::ConstantScalar(&c1))),
                        m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* sum_of_constants,
                        MakeBinaryHlo(HloOpcode::kAdd, c1, c2));
    if (ShapeUtil::IsScalar(sum_of_constants->shape()) &&
        !ShapeUtil::IsScalar(add->shape())) {
      sum_of_constants = add->AddInstruction(
          HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a,
                                          sum_of_constants));
  }

  // Convert an add over the full shape into an add over a partial shape when
  // only a portion of the add is effective:
  //             zero (fullshape)   rhs (partialshape)
  // .           |                  |
  // . lhs .    dynamic_update_slice (fullshape)
  // . |         |
  // Add (fullshape)
  //
  // to:
  //              lhs
  //              |
  //             dynamic_slice (partialshape)   rhs (partialshape)
  // .           |                              |
  // . lhs .    add (partial_shape)+-----------+
  // . |         |
  // dynamic_update_slice (fullshape)
  //
  // This pattern is discovered in the control flow V2 gradient update.
  if (Match(add,
            m::Add(m::Op(&lhs),
                   m::Op(&rhs)
                       .WithOpcode(HloOpcode::kDynamicUpdateSlice)
                       .WithOperand(
                           0, m::Broadcast(m::ConstantEffectiveScalar(0)))))) {
    const Shape& partial_shape = rhs->operand(1)->shape();
    auto sliced_lhs = lhs->AddInstruction(HloInstruction::CreateDynamicSlice(
        partial_shape, lhs, absl::MakeSpan(rhs->operands()).subspan(2),
        partial_shape.dimensions()));

    auto add_partial = rhs->AddInstruction(
        HloInstruction::CreateBinary(rhs->operand(1)->shape(), HloOpcode::kAdd,
                                     sliced_lhs, rhs->mutable_operand(1)));

    auto dynamic_update_slice_full = HloInstruction::CreateDynamicUpdateSlice(
        lhs->shape(), lhs, add_partial,
        absl::MakeSpan(rhs->operands()).subspan(2));

    return ReplaceWithNewInstruction(add,
                                     std::move(dynamic_update_slice_full));
  }

  // A*C + B*C => (A+B)*C
  //
  //  - If A, B, and C are integers, do this unconditionally. Proof of
  //    correctness: https://rise4fun.com/Alive/u9X.
  //
  //  - If A, B, and C are floating point, do this if C is a scalar constant
  //    or broadcast of scalar constant and is equal to +/- 2^k for some
  //    (possibly negative) integer k.
  //
  //    Multiplying by a power of 2 just moves the exponent, so our answer is
  //    exact modulo rounding of intermediate results so long as
  //
  //     - none of the three products has an exponent which underflows (so the
  //       result is 0 or denormal), and
  //     - none of the three products overflows to inf.
  //
  //    Proof: See algebraic_simplifier_proof_distributive_property.py.
  //
  //    We deem these differences in rounding, underflow, and overflow
  //    acceptable in the ML context.
  //
  //    Furthermore, if `enable_floats_are_real` is true, the simplification
  //    is done nonetheless. This might cause numerical differences even if
  //    there is no underflow or overflow.
  HloInstruction *b, *c;
  if (((Match(lhs, m::Multiply(m::Op(&a), m::Op(&c))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b)))) ||
       (Match(lhs, m::Multiply(m::Op(&c), m::Op(&a))) &&
        Match(rhs, m::MultiplyAnyOrder(m::Op().Is(c), m::Op(&b))))) &&
      // Make sure we would decrease the number of multiplies.
      (lhs->user_count() == 1 && rhs->user_count() == 1) &&
      (ShapeUtil::ElementIsIntegral(add->shape()) ||
       options_.enable_floats_are_real() || IsAllFpConstantPowerOf2(c))) {
    return ReplaceWithNewInstruction(
        add, HloInstruction::CreateBinary(
                 add->shape(), HloOpcode::kMultiply,
                 lhs->AddInstruction(HloInstruction::CreateBinary(
                     add->shape(), HloOpcode::kAdd, a, b)),
                 c));
  }

  if (options_.is_layout_sensitive()) {
    return OkStatus();
  }

  HloInstruction* lhs_scatter_operand = nullptr;
  HloInstruction* rhs_scatter_operand = nullptr;
  HloInstruction* lhs_scatter_update = nullptr;
  HloInstruction* rhs_scatter_update = nullptr;
  HloInstruction* lhs_scatter_index = nullptr;
  HloInstruction* rhs_scatter_index = nullptr;
  bool lhs_scatter = Match(lhs, m::Scatter(m::Op(&lhs_scatter_operand),
                                           m::Op(&lhs_scatter_index),
                                           m::Op(&lhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(lhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  bool rhs_scatter = Match(rhs, m::Scatter(m::Op(&rhs_scatter_operand),
                                           m::Op(&rhs_scatter_index),
                                           m::Op(&rhs_scatter_update))
                                    .WithOneUse()) &&
                     Match(rhs->to_apply()->root_instruction(),
                           m::Add(m::Parameter(), m::Parameter()));
  if (rhs_scatter && lhs_scatter) {
    const auto& lhs_dnums = lhs->scatter_dimension_numbers();
    const auto& rhs_dnums = rhs->scatter_dimension_numbers();
    std::optional<int64_t> index_concat_dimension;
    std::optional<int64_t> update_concat_dimension;
    // Don't try to combine scatters of different ranks.
    if (lhs_scatter_index->shape().rank() !=
        rhs_scatter_index->shape().rank()) {
      return OkStatus();
    }

    int64_t first_index_dim = lhs_scatter_index->shape().rank();
    int64_t first_update_dim = lhs_scatter_update->shape().rank();
    // Find a dimension where it is possible to concatenate the indices and
    // updates. This is the first and only non-equal dimension or the first
    // equally sized dimension.
    for (int64_t d = lhs_scatter_index->shape().rank() - 1,
                 update_dim = lhs_scatter_update->shape().rank() - 1;
         d >= 0; --d) {
      if (d == lhs_dnums.index_vector_dim()) {
        continue;
      }
      while (
          absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) {
        --update_dim;
      }
      if (lhs_scatter_index->shape().dimensions(d) ==
          rhs_scatter_index->shape().dimensions(d)) {
        first_index_dim = d;
        first_update_dim = update_dim--;
        continue;
      }
      // More than one dimension of unequal size was found, bail out.
      if (index_concat_dimension) {
        return OkStatus();
      }
      index_concat_dimension = d;
      update_concat_dimension = update_dim--;
    }
    if (!index_concat_dimension) {
      index_concat_dimension = first_index_dim;
      update_concat_dimension = first_update_dim;
    }

    // A scalar scatter will require additional reshapes of the index and
    // update.
    if (*index_concat_dimension == lhs_scatter_index->shape().rank()) {
      return OkStatus();
    }
    const bool update_concat_is_cheap =
        ShapeUtil::ElementsIn(rhs_scatter_update->shape()) +
            ShapeUtil::ElementsIn(lhs_scatter_update->shape()) <
        ShapeUtil::ElementsIn(lhs->shape());
    if (!update_concat_is_cheap) {
      return OkStatus();
    }
    const bool same_dimension_numbers =
        lhs_dnums.index_vector_dim() == rhs_dnums.index_vector_dim() &&
        absl::c_equal(lhs_dnums.scatter_dims_to_operand_dims(),
                      rhs_dnums.scatter_dims_to_operand_dims()) &&
        absl::c_equal(lhs_dnums.inserted_window_dims(),
                      rhs_dnums.inserted_window_dims()) &&
        absl::c_equal(lhs_dnums.update_window_dims(),
                      rhs_dnums.update_window_dims());
    const bool index_concat_is_safe =
        !lhs->unique_indices() && !rhs->unique_indices() &&
        !DynCast<HloScatterInstruction>(lhs)->indices_are_sorted() &&
        !DynCast<HloScatterInstruction>(rhs)->indices_are_sorted();

    Shape lhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(lhs_dnums.update_window_dims(), dim);
        },
        lhs_scatter_update->shape());
    Shape rhs_update_window = ShapeUtil::FilterDimensions(
        [&](int64_t dim) {
          return absl::c_linear_search(rhs_dnums.update_window_dims(), dim);
        },
        rhs_scatter_update->shape());
    // Concatenate the indices and updates.
    if (index_concat_is_safe && same_dimension_numbers &&
        index_concat_dimension &&
        lhs_scatter_index->shape().element_type() ==
            rhs_scatter_index->shape().element_type() &&
        ShapeUtil::SameDimensions(lhs_update_window, rhs_update_window)) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                          MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                        rhs_scatter_operand));
      TF_ASSIGN_OR_RETURN(HloInstruction * new_index,
                          MakeConcatHlo({lhs_scatter_index, rhs_scatter_index},
                                        *index_concat_dimension));
      TF_ASSIGN_OR_RETURN(
          HloInstruction * new_update,
          MakeConcatHlo({lhs_scatter_update, rhs_scatter_update},
                        *update_concat_dimension));
      return ReplaceWithNewInstruction(
          add, HloInstruction::CreateScatter(
                   add->shape(), new_operand, new_index, new_update,
                   lhs->to_apply(), lhs_dnums, false, false));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * new_operand,
                        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand,
                                      rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, rhs));
    return ReplaceInstruction(add, lhs);
  } else if (rhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs, rhs_scatter_operand));
    TF_RETURN_IF_ERROR(rhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, rhs);
  } else if (lhs_scatter) {
    TF_ASSIGN_OR_RETURN(
        HloInstruction * new_operand,
        MakeBinaryHlo(HloOpcode::kAdd, lhs_scatter_operand, rhs));
    TF_RETURN_IF_ERROR(lhs->ReplaceOperandWith(0, new_operand));
    return ReplaceInstruction(add, lhs);
  }
  return OkStatus();
}

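// Rewrites a conjunction of two comparisons of the same variable against
// effective-scalar s32 constants, e.g. (x < 5) && (x < 3), into a single
// comparison against the tighter bound, e.g. x < 3.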
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
    HloInstruction* conjunction) {
  HloInstruction *lhs, *rhs;
  if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
    return false;
  }
  struct LessThanCompareInfo {  // (LT var constant)
    HloInstruction* var;
    int64_t constant;
  };

  auto get_compare_info =
      [&](HloInstruction* cmp) -> std::optional<LessThanCompareInfo> {
    HloInstruction *lhs, *rhs;
    auto scalar_shape_matcher =
        m::Shape().IsEffectiveScalar().WithElementType(PrimitiveType::S32);
    if (Match(cmp, m::Compare(m::Op(&lhs),
                              m::Constant(&rhs).WithShape(scalar_shape_matcher))
                       .WithComparisonDirection(ComparisonDirection::kLt))) {
      return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
    } else if (Match(
                   cmp,
                   m::Compare(m::Constant(&lhs).WithShape(scalar_shape_matcher),
                              m::Op(&rhs))
                       .WithComparisonDirection(ComparisonDirection::kGt))) {
      return {LessThanCompareInfo{rhs, *lhs->literal().GetFirstInteger()}};
    }
    return std::nullopt;
  };

  std::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
  std::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
  if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
    int64_t new_bound = std::min(lhs_info->constant, rhs_info->constant);
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        conjunction,
        HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
                                      MakeScalarLike(lhs_info->var, new_bound),
                                      ComparisonDirection::kLt)));
    return true;
  }
  return false;
}

Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
  // Simplify logical and.
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A && True => A
    VLOG(10) << "trying transform [A && True => A]: "
             << logical_and->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(logical_and, lhs)) {
      return OkStatus();
    }
    // True && A => A
    VLOG(10) << "trying transform [True && A => A]: "
             << logical_and->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfCompatible(logical_and, rhs)) {
      return OkStatus();
    }
  }

  // A && False => False or A & 0 => 0
  VLOG(10) << "trying transform [A && False => False]: "
           << logical_and->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(logical_and, rhs)) {
    return OkStatus();
  }

  // False && A => False or 0 & A => 0
  VLOG(10) << "trying transform [False && A => False]: "
           << logical_and->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfCompatible(logical_and, lhs)) {
    return OkStatus();
  }

  // Simplify tautological conjunctions.
  TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
                      TrySimplifyTautologicalCompare(logical_and));
  if (found_tautological_compare) {
    return OkStatus();
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
  // If a bitcast feeds a bitcast, make it a single bitcast.
  // Make sure the whole chain of bitcasts is optimized.
  if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) {
    TF_RETURN_IF_ERROR(HandleBitcast(bitcast->mutable_operand(0)));
  }
  HloInstruction* op;
  if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) {
    return ReplaceWithNewInstruction(
        bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op));
  }
  // All bitcasts can be eliminated (assuming layout constraints are
  // satisfied).
  ReplaceInstructionIfCompatible(bitcast, bitcast->mutable_operand(0));
  return OkStatus();
}

// Computes a pair of maps for a bitcast operation, specifically between its
// result logical dimensions and the original logical dimensions of the
// operand. The maps are computed by matching the physical layout dimensions
// (minor-to-major) of the operand and the bitcasted result. Overall they
// record how the different logical dimensions of the operand may be combined
// or split in the resulting shape, and in which order they are
// combined/split. The function returns std::nullopt if unsuccessful (e.g.,
// such a logical dimension mapping cannot be constructed due to cases like
// bitcasting {4,4} to {2,8}).
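// For example, bitcasting f32[6]{0} to f32[2,3]{1,0} records that operand
// dimension 0 is split into result dimensions {1, 0} (listed in the
// minor-to-major order in which they are matched).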
std::optional<std::vector<std::vector<int64_t>>>
AlgebraicSimplifierVisitor::ComputeBitcastDimMap(const Shape& bitcast_shape,
                                                 const Shape& operand_shape) {
  std::vector<std::vector<int64_t>> operand_dim_map(
      operand_shape.dimensions_size());
  int64_t bitcast_rank = bitcast_shape.dimensions_size();
  int64_t operand_rank = operand_shape.dimensions_size();
  int64_t cur_bitcast_size = 1, cur_operand_size = 1;
  int64_t operand_pos = -1, operand_dim = -1;
  for (int64_t bitcast_pos = 0; bitcast_pos < bitcast_rank; ++bitcast_pos) {
    int64_t bitcast_dim = bitcast_shape.layout().minor_to_major(bitcast_pos);
    if (operand_pos >= operand_rank) {
      if (bitcast_shape.dimensions(bitcast_dim) != 1) {
        VLOG(3) << "Abort b/c bitcasted size is bigger than operand size.\n";
        return std::nullopt;
      }
      continue;
    }
    CHECK_LT(bitcast_dim, bitcast_shape.dimensions_size());
    int64_t bitcast_dim_size = bitcast_shape.dimensions()[bitcast_dim];
    auto prev_bitcast_size = cur_bitcast_size;
    cur_bitcast_size *= bitcast_dim_size;
    VLOG(2) << "bitcast pos = " << bitcast_pos << "\n";
    VLOG(2) << "bitcast size = " << cur_bitcast_size << "\n";
    if (cur_operand_size < cur_bitcast_size &&
        prev_bitcast_size < cur_operand_size) {
      // Here we are bitcasting (m1,n1) to (m2,n2), with m1 > m2 and m2 * n2
      // < m1, so (m1,n1) is re-partitioned instead of split or combined.
      VLOG(3) << "Abort b/c re-partitioning a group of dimensions is not "
                 "supported.\n";
      return std::nullopt;
    }
    while (operand_pos < operand_rank) {
      if (operand_pos < 0 || cur_operand_size < cur_bitcast_size) {
        VLOG(2) << "operand size < bitcast size\n";
        operand_pos++;
        if (operand_pos >= operand_rank) {
          VLOG(2)
              << "Abort due to size inconsistency: bitcasted size > operand "
                 "size.\n";
          return std::nullopt;
        }
        operand_dim = operand_shape.layout().minor_to_major(operand_pos);
        int64_t op_dim_size = operand_shape.dimensions()[operand_dim];
        cur_operand_size *= op_dim_size;
        VLOG(3) << "operand size = " << cur_operand_size << "\n";
        if (cur_operand_size > cur_bitcast_size &&
            op_dim_size < bitcast_dim_size && operand_pos > 0) {
          // Here we are bitcasting (m1,n1) to (m2,n2), with n1 < n2 and
          // m1 * n1 > m2, so (m1,n1) is re-partitioned instead of split or
          // combined.
          VLOG(3) << "Abort b/c re-partitioning a group of dimensions is not "
                     "supported.\n";
          return std::nullopt;
        }
      }
      CHECK_GE(operand_dim, 0);
      if (operand_shape.dimensions(operand_dim) > 1) {
        CHECK_LT(operand_dim, operand_dim_map.size());
        operand_dim_map[operand_dim].push_back(bitcast_dim);
        VLOG(3) << "operand dim_map[operand_dim] add " << bitcast_dim << " at "
                << operand_dim << "\n";
      }
      if (cur_operand_size >= cur_bitcast_size) {
        VLOG(3) << cur_operand_size << ">=" << cur_bitcast_size << "\n";
        CHECK_GE(operand_dim, 0);
        // If operand_dim is a degenerate one, move on to the next dimension.
        if (operand_shape.dimensions()[operand_dim] == 1) {
          operand_pos++;
        }
        break;
      }
    }
  }
  return operand_dim_map;
}

std::optional<Shape> AlgebraicSimplifierVisitor::ReshapeLayoutDimensions(
    const Shape& original_shape, const Shape& result_shape,
    const std::vector<std::vector<int64_t>>& original_map,
    const std::vector<std::vector<int64_t>>& result_map) {
  auto original_dimensions = original_shape.layout().minor_to_major();
  Shape new_shape = result_shape;
  auto* reshaped_dimensions =
      new_shape.mutable_layout()->mutable_minor_to_major();
  int64_t bitcast_pos = -1;
  for (int64_t op_pos = 0; op_pos < original_dimensions.size(); ++op_pos) {
    int64_t op_dim = original_dimensions[op_pos];
    VLOG(3) << "op_pos = " << op_pos << "\n";
    VLOG(3) << "op_dim = " << op_dim << "\n";
    if (original_map.size() <= op_dim) {
      VLOG(3) << "Skip because original_map has too few dimensions.\n";
      continue;
    }
    auto bit_dims = original_map[op_dim];
    for (int64_t bitcast_dim : bit_dims) {
      if (result_shape.dimensions(bitcast_dim) == 1) {
        // Postpone all degenerate dimensions (those with size 1) to the end.
        continue;
      }
      VLOG(3) << "Add new reshaped dimension:" << bitcast_dim << "\n";
      if (bitcast_pos < 0 ||
          (*reshaped_dimensions)[bitcast_pos] != bitcast_dim) {
        bitcast_pos++;
        // If bitcast_pos has been over-incremented, the new bitcast would
        // have to combine non-contiguous dimensions in op. Abort.
        if (bitcast_pos >= reshaped_dimensions->size()) {
          VLOG(3) << "bitcast pos is over incremented:" << bitcast_pos << "\n";
          return std::nullopt;
        }
        (*reshaped_dimensions)[bitcast_pos] = bitcast_dim;
      }
      auto op_dims = result_map[bitcast_dim];
      if (op_dims.size() > 1 && op_pos > 0) {
        // Check that op dimensions that are combined into bitcast_dim are not
        // non-contiguous or reordered to be different from how they appear in
        // result_map.
        int64_t op_dim_prev = original_dimensions[op_pos - 1];
        // If the current dimension is not the first being combined into
        // bitcast_dim, or is not contiguous with the previous dimension,
        // abort.
        if (op_dims[0] != op_dim &&
            (original_map[op_dim_prev].empty() ||
             original_map[op_dim_prev][0] != bitcast_dim)) {
          VLOG(2) << "Abort b/c op dimensions that are combined into "
                     "bitcast_dim are not contiguous in the result.\n";
          return std::nullopt;
        }
        // Now perform the dimension re-ordering check in the bitcast.
        for (int i = 0; i < op_dims.size(); ++i) {
          if (op_dims[i] == op_dim_prev) {
            if (i == op_dims.size() - 1 || op_dims[i + 1] != op_dim) {
              VLOG(2) << "Abort b/c op dimensions that are combined into "
                         "bitcast_dim are reordered in the new bitcast.\n";
              return std::nullopt;
            }
          }
        }
      }
    }
  }
  for (int i = 0; i < result_shape.rank(); ++i) {
    if (result_shape.dimensions(i) == 1) {
      bitcast_pos++;
      // Since there is a possibility of over-incrementing bitcast_pos, we
      // need such a check here as well before accessing the vector.
      // Over-incrementing is possible when the result's dimension is smaller
      // than the original dimension.
      if (bitcast_pos >= reshaped_dimensions->size()) {
        VLOG(3) << "bitcast pos is over incremented:" << bitcast_pos << "\n";
        return std::nullopt;
      }
      (*reshaped_dimensions)[bitcast_pos] = i;
    }
  }
  CHECK_EQ(bitcast_pos + 1, result_shape.rank());
  return new_shape;
}

std::vector<std::vector<int64_t>>
AlgebraicSimplifierVisitor::InvertBitcastDimMap(
    const Shape& original_shape, const Shape& bitcast_shape,
    const std::vector<std::vector<int64_t>>& original_map) {
  std::vector<std::vector<int64_t>> result_map(
      bitcast_shape.dimensions_size());
  // Invert the operand map into the result map.
  for (auto i = 0; i < original_shape.rank(); ++i) {
    auto j = original_shape.layout().minor_to_major(i);
    VLOG(3) << "traversing minor to major (" << i << ")=" << j;
    for (auto k : original_map[j]) {
      VLOG(3) << "setting result_map[" << k << "] = " << j << "\n";
      result_map[k].push_back(j);
    }
  }
  return result_map;
}

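// Tries to eliminate one copy in a copy(bitcast(copy(x))) chain by
// re-expressing the bitcast so that it applies directly to x (yielding
// copy(bitcast'(x))) or after the final copy (yielding bitcast'(copy'(x))),
// whichever the tiling assignments of the involved shapes permit.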
bool AlgebraicSimplifierVisitor::SwapCopyBitcastCopy(
    HloInstruction* root_copy) {
  if (root_copy->opcode() != HloOpcode::kCopy) {
    return false;
  }
  HloInstruction* bitcast = root_copy->mutable_operand(0);
  if (bitcast->opcode() != HloOpcode::kBitcast) {
    return false;
  }
  // All bitcasts above can be collapsed.
  HloInstruction* copy = bitcast->mutable_operand(0);
  while (copy->opcode() == HloOpcode::kBitcast) {
    copy = copy->mutable_operand(0);
  }
  if (copy->opcode() != HloOpcode::kCopy) {
    return false;
  }
  VLOG(2) << "Processing " << copy->ToString() << "\n"
          << bitcast->ToString() << "\n"
          << root_copy->ToString() << "\n";
  HloInstruction* op = copy->mutable_operand(0);
  // Compute a pair of maps between op dimensions and bitcast dimensions.
  auto dim_map = ComputeBitcastDimMap(bitcast->shape(), copy->shape());
  if (!dim_map.has_value()) {
    VLOG(3) << "Failed to compute bitcast map.";
    return false;
  }
  std::vector<std::vector<int64_t>> operand_map = dim_map.value();
  if (!ValidateTilingOfBitcast(bitcast->shape(), copy->shape(), operand_map)) {
    VLOG(2) << "Abort because bitcast changes tiling assignment.\n";
    return false;
  }
  std::vector<std::vector<int64_t>> result_map =
      InvertBitcastDimMap(copy->shape(), bitcast->shape(), operand_map);
  if (ValidateTilingOfBitcast(bitcast->shape(), op->shape(), operand_map)) {
    auto new_shape = ReshapeLayoutDimensions(op->shape(), bitcast->shape(),
                                             operand_map, result_map);
    if (!new_shape.has_value() || !IsValidLayout(new_shape.value())) {
      return false;
    }
    auto repl = HloInstruction::CreateUnary(
        root_copy->shape(), HloOpcode::kCopy,
        bitcast->AddInstruction(
            bitcast->CloneWithNewOperands(new_shape.value(), {op})));
    VLOG(2) << "Replace with " << repl->operand(0)->ToString() << "\n"
            << repl->ToString() << "\n";
    TF_CHECK_OK(ReplaceWithNewInstruction(root_copy, std::move(repl)));
    return true;
  }

  if (ValidateTilingOfBitcast(copy->shape(), root_copy->shape(), result_map)) {
    auto new_shape = ReshapeLayoutDimensions(root_copy->shape(), copy->shape(),
                                             result_map, operand_map);
    if (!new_shape.has_value() || !IsValidLayout(new_shape.value())) {
      return false;
    }
    auto repl = HloInstruction::CreateUnary(
        root_copy->shape(), HloOpcode::kBitcast,
        bitcast->AddInstruction(
            root_copy->CloneWithNewOperands(new_shape.value(), {op})));
    VLOG(2) << "Replace with " << repl->operand(0)->ToString() << "\n"
            << repl->ToString() << "\n";
    TF_CHECK_OK(ReplaceWithNewInstruction(root_copy, std::move(repl)));
    return true;
  }
  return false;
}

Status AlgebraicSimplifierVisitor::HandleBitcastConvert(
    HloInstruction* bitcast) {
  TF_ASSIGN_OR_RETURN(bool replaced,
                      TrySimplifyTautologicalBitcastConvert(bitcast));
  if (replaced) {
    return OkStatus();
  }
  // Eliminate bitcast converts between identical shapes.
  ReplaceInstructionIfCompatible(bitcast, bitcast->mutable_operand(0));
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
  if (SwapCopyBitcastCopy(copy)) {
    return OkStatus();
  }
  // If a copy feeds a copy, make it a single copy.
  HloInstruction* op;
  if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) {
    if (ShapeUtil::Equal(op->shape(), copy->shape())) {
      return ReplaceInstruction(copy, op);
    }
    return ReplaceWithNewInstruction(
        copy,
        HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op));
  }
  // All copies can be eliminated (assuming layout constraints are satisfied).
  if (ReplaceInstructionIfCompatible(copy, copy->mutable_operand(0))) {
    return OkStatus();
  }

  if (HloInstruction* bitcast_operand =
          BitcastingOperandOfReshapeOrCopyChain(copy, options_)) {
    ReplaceWithBitcast(copy, bitcast_operand);
    return OkStatus();
  }

  // Replace Copy(Reshape()) with Reshape() if the Reshape is a logical
  // bitcast.
  if (copy->operand(0)->opcode() == HloOpcode::kReshape &&
      copy->operand(0)->user_count() == 1 &&
      ShapeUtil::ReshapeIsBitcast(copy->operand(0)->shape(), copy->shape())) {
    return ReplaceWithNewInstruction(
        copy,
        copy->operand(0)->CloneWithNewOperands(
            copy->shape(), {copy->mutable_operand(0)->mutable_operand(0)}));
  }
  return OkStatus();
}
1359
HandleConcatenate(HloInstruction * concatenate)1360 Status AlgebraicSimplifierVisitor::HandleConcatenate(
1361 HloInstruction* concatenate) {
1362 absl::Span<HloInstruction* const> operands(concatenate->operands());
1363 if (operands.size() == 1) {
1364 // Unary concatenates are useless.
1365 ReplaceInstructionIfCompatible(concatenate, operands[0]);
1366 return OkStatus();
1367 }
1368 // Filter out and remove empty operands.
1369 std::vector<HloInstruction*> nonempty_operands;
1370 for (HloInstruction* operand : operands) {
1371 if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
1372 nonempty_operands.push_back(operand);
1373 }
1374 }
1375 if (nonempty_operands.size() < operands.size()) {
1376 HloInstruction* replacement;
1377 if (nonempty_operands.empty()) {
1378 replacement = operands[0];
1379 } else if (nonempty_operands.size() == 1) {
1380 replacement = nonempty_operands[0];
1381 } else {
1382 replacement =
1383 concatenate->AddInstruction(concatenate->CloneWithNewOperands(
1384 concatenate->shape(), nonempty_operands));
1385 }
1386 VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
1387 << replacement->ToString();
1388 ReplaceInstructionIfCompatible(concatenate, replacement);
1389 return OkStatus();
1390 }
1391
1392 if (options_.is_layout_sensitive()) {
1393 return OkStatus();
1394 }
1395
1396 // concat(x, concat(y, z)) -> concat(x, y, z). We only do this in
1397 // layout-insensitive mode because some backends may have (late,
1398 // layout-sensitive) passes that break up ops with many operands into smaller
1399 // pieces. This would undo that.
1400 absl::InlinedVector<HloInstruction*, 8> unnested_concat_operands;
1401 for (HloInstruction* operand : operands) {
1402 if (operand->opcode() == HloOpcode::kConcatenate &&
1403 operand->concatenate_dimension() ==
1404 concatenate->concatenate_dimension()) {
1405 for (HloInstruction* instr : operand->operands()) {
1406 unnested_concat_operands.push_back(instr);
1407 }
1408 } else {
1409 unnested_concat_operands.push_back(operand);
1410 }
1411 }
1412 if (unnested_concat_operands.size() != concatenate->operand_count()) {
1413 return ReplaceWithNewInstruction(
1414 concatenate, HloInstruction::CreateConcatenate(
1415 concatenate->shape(), unnested_concat_operands,
1416 concatenate->concatenate_dimension()));
1417 }
1418
1419 // Check if we can merge "adjacent" slice operands which take slices from the
1420 // same other op. For simplicity we only merge unstrided slices.
  int64_t concatenate_dimension = concatenate->concatenate_dimension();
  std::vector<HloInstruction*> new_operands;
  int64_t i = 0;
  while (i < operands.size()) {
    if (operands[i]->opcode() != HloOpcode::kSlice ||
        !IsUnstridedSlice(operands[i])) {
      new_operands.push_back(operands[i]);
      ++i;
      continue;
    }
    int64_t slice_end = operands[i]->slice_limits(concatenate_dimension);
    HloInstruction* slice_operand = operands[i]->mutable_operand(0);
    int64_t j = i + 1;
    while (j < operands.size()) {
      if (operands[j]->opcode() != HloOpcode::kSlice ||
          !IsUnstridedSlice(operands[j]) ||
          operands[j]->operand(0) != slice_operand ||
          operands[j]->slice_starts(concatenate_dimension) != slice_end) {
        break;
      }
      // Check that all the slice_start values are the same in all other
      // dimensions. This implies that the slice_limit values are also the
      // same, because operands of concatenate need to have the same shape,
      // and we already checked that the slices are unstrided.
      bool same_other_starts = true;
      for (int64_t k = 0; k < operands[j]->slice_starts().size(); ++k) {
        if (k == concatenate_dimension) {
          continue;
        }
        if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
          same_other_starts = false;
          break;
        }
      }
      if (!same_other_starts) {
        break;
      }
      slice_end = operands[j]->slice_limits(concatenate_dimension);
      ++j;
    }
    if (j - i > 1) {
      Shape new_slice_shape = operands[i]->shape();
      new_slice_shape.set_dimensions(
          concatenate_dimension,
          slice_end - operands[i]->slice_starts(concatenate_dimension));
      simplifier_->UpdateLayout(&new_slice_shape);
      auto new_limit_indices = operands[i]->slice_limits();
      new_limit_indices[concatenate_dimension] = slice_end;
      auto new_slice_op =
          operands[i]->AddInstruction(HloInstruction::CreateSlice(
              new_slice_shape, slice_operand,
              /*start_indices=*/operands[i]->slice_starts(),
              /*limit_indices=*/new_limit_indices,
              /*strides=*/operands[i]->slice_strides()));
      new_operands.push_back(new_slice_op);
    } else {
      new_operands.push_back(operands[i]);
    }
    i = j;
  }
  if (new_operands.size() < operands.size()) {
    auto replacement = concatenate->AddInstruction(
        concatenate->CloneWithNewOperands(concatenate->shape(), new_operands));
    ReplaceInstructionIfCompatible(concatenate, replacement);
    return OkStatus();
  }

  if (operands.size() == 2) {
    // A binary concat with a broadcasted scalar as an operand can be converted
    // into a pad which is simpler to fold into other operations.
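    // For example (a sketch; shapes illustrative), with s a scalar:
    //   concat(broadcast(s), x, dim=0)  =>  pad(x, s), low-padded along dim 0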
    bool is_effective_low_pad = Match(
        operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    bool is_effective_high_pad = Match(
        operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar())));
    if (!is_effective_low_pad && !is_effective_high_pad) {
      return OkStatus();
    }
    PaddingConfig padding_config;
    for (int64_t dim = 0; dim < operands[0]->shape().rank(); ++dim) {
      auto padding_config_dim = padding_config.add_dimensions();
      padding_config_dim->set_edge_padding_high(0);
      padding_config_dim->set_edge_padding_low(0);
      padding_config_dim->set_interior_padding(0);
      if (dim == concatenate_dimension) {
        if (is_effective_low_pad) {
          padding_config_dim->set_edge_padding_low(
              operands[0]->shape().dimensions(dim));
        } else {
          padding_config_dim->set_edge_padding_high(
              operands[1]->shape().dimensions(dim));
        }
      }
    }
    int64_t operand_to_pad = is_effective_low_pad ? 1 : 0;
    int64_t pad_value_operand = is_effective_low_pad ? 0 : 1;
    HloInstruction* pad = concatenate->AddInstruction(HloInstruction::CreatePad(
        concatenate->shape(), operands[operand_to_pad],
        operands[pad_value_operand]->mutable_operand(0), padding_config));
    return ReplaceInstruction(concatenate, pad);
  }

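  // A concatenate of the same operand n times, where the operand has size 1 in
  // the concatenate dimension, is a broadcast of that operand with the
  // concatenate dimension reshaped away. For example (a sketch),
  //   concat(x, x, x) of f32[2,1] along dim 1
  // becomes broadcast(reshape(x) : f32[2]) into f32[2,3].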
  if (absl::c_count(operands, operands[0]) == operands.size() &&
      operands[0]->shape().dimensions(concatenate_dimension) == 1) {
    Shape new_shape = operands[0]->shape();
    DimensionVector broadcast_dims;
    for (int64_t i = 0; i < new_shape.rank(); ++i) {
      if (i == concatenate_dimension) {
        continue;
      }
      broadcast_dims.push_back(i);
    }
    new_shape.DeleteDimension(concatenate_dimension);
    return ReplaceInstruction(
        concatenate,
        MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(),
                         broadcast_dims, concatenate->shape()));
  }
  return OkStatus();
}

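// Tries to remove a bitcast-convert of a concatenate whose inputs are
// themselves bitcast-converts from the outer target type; the round trip is a
// no-op, so (a sketch)
//   bitcast_convert(concat(bitcast_convert(x), bitcast_convert(y)))
// can be replaced by concat(x, y) when x and y already have the target
// element type.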
StatusOr<bool>
AlgebraicSimplifierVisitor::TrySimplifyTautologicalBitcastConvert(
    HloInstruction* bitcast) {
  CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcastConvert);
  PrimitiveType outer_to = bitcast->shape().element_type();
  HloInstruction* concat = bitcast->mutable_operand(0);
  if (concat->opcode() != HloOpcode::kConcatenate) {
    return false;
  }
  std::vector<HloInstruction*> outer_inputs;
  std::vector<HloInstruction*> to_remove_bitcasts;
  for (int i = 0; i < concat->operand_count(); i++) {
    HloInstruction* in = concat->mutable_operand(i);
    if (in->opcode() != HloOpcode::kBitcastConvert ||
        in->operand(0)->shape().element_type() != outer_to) {
      return false;
    }
    outer_inputs.push_back(in->mutable_operand(0));
    to_remove_bitcasts.push_back(in);
  }

  const int64_t concat_dim = concat->concatenate_dimension();
  TF_ASSIGN_OR_RETURN(HloInstruction * new_concat,
                      MakeConcatHlo(outer_inputs, concat_dim));
  TF_RETURN_IF_ERROR(ReplaceInstruction(bitcast, new_concat));

  return true;
}

static HloInstruction* BuildTupleConstant(HloComputation* computation,
                                          const LiteralSlice& literal,
                                          AlgebraicSimplifier* simplifier) {
  if (literal.shape().IsTuple()) {
    std::vector<HloInstruction*> elems;
    elems.reserve(ShapeUtil::TupleElementCount(literal.shape()));
    for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) {
      elems.push_back(BuildTupleConstant(
          computation, LiteralSlice(literal, {i}), simplifier));
    }
    return computation->AddInstruction(HloInstruction::CreateTuple(elems));
  } else {
    return computation->AddInstruction(
        simplifier->CreateConstantWithLayoutUpdated(literal.Clone()));
  }
}

Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) {
  // Tuple constants aren't directly supported by any backend. Expand them into
  // explicit Tuple instructions.
  if (constant->shape().IsTuple()) {
    return ReplaceInstruction(
        constant,
        BuildTupleConstant(computation_, constant->literal(), simplifier_));
  }

  if (constant->shape().element_type() == TOKEN) {
    return OkStatus();
  }

  // If a literal is all the same element, replace it with a scalar broadcast.
  if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
      constant->literal().IsAllFirst()) {
    Literal unique_scalar(
        LiteralUtil::GetFirstScalarLiteral(constant->literal()));
    HloInstruction* scalar = constant->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(std::move(unique_scalar)));
    return ReplaceWithNewInstruction(
        constant,
        HloInstruction::CreateBroadcast(constant->shape(), scalar, {}));
  }

  // If a literal is an increasing sequence from zero, replace it with an iota.
  if (ShapeUtil::ElementsIn(constant->shape()) > 1 &&
      constant->literal().IsR1Iota()) {
    return ReplaceWithNewInstruction(
        constant, HloInstruction::CreateIota(constant->shape(), 0));
  }

  if (std::optional<int64_t> stride = constant->literal().IsR1StridedIota()) {
    // Replace the constant with iota * stride.
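    // For example, the R1 constant {0, 3, 6, 9} is an iota with stride 3 and
    // becomes multiply(iota, broadcast(3)).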
    HloInstruction* stride_hlo = MakeScalarLike(constant, *stride);
    HloInstruction* iota = constant->AddInstruction(
        HloInstruction::CreateIota(constant->shape(), 0));
    return ReplaceWithNewInstruction(
        constant,
        HloInstruction::CreateBinary(constant->shape(), HloOpcode::kMultiply,
                                     iota, stride_hlo));
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs))));
  // A - 0 => A
  VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(sub, lhs)) {
    return OkStatus();
  }

  // Canonicalize subtraction of a constant to addition.
  VLOG(10) << "trying transform [A - Const => A + (-Const)]";
  if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) ||
      Match(sub, m::Subtract(m::NonConstant(&lhs),
                             m::Broadcast(m::Constant(&rhs))))) {
    HloInstruction* negative_const = rhs->AddInstruction(
        HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs));
    if (const HloInstruction* broadcast =
            DynCast<HloBroadcastInstruction>(sub->operand(1))) {
      negative_const = rhs->AddInstruction(HloInstruction::CreateBroadcast(
          broadcast->shape(), negative_const, broadcast->dimensions()));
    }
    return ReplaceWithNewInstruction(
        sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs,
                                          negative_const));
  }

  // A - A => 0 for integer A.
  VLOG(10) << "trying transform [A - A => 0] for integer A.";
  if (lhs == rhs && ShapeUtil::ElementIsIntegral(sub->shape())) {
    return ReplaceInstruction(sub, MakeScalarLike(sub, 0));
  }

  return OkStatus();
}
namespace {
template <typename T>
Status InvertConstant(const HloInstruction& constant, Literal* result) {
  return result->Populate<T>([&](absl::Span<const int64_t> indices) {
    return T{1.0} / constant.literal().Get<T>(indices);
  });
}

template <typename T>
std::unique_ptr<HloInstruction> TryDivideToShift(
    HloInstruction* divide, HloComputation* computation,
    AlgebraicSimplifier* simplifier) {
  HloInstruction *a, *b, *c;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));

  if (ShapeUtil::ElementIsIntegral(divide->shape()) &&
      !Match(b, m::ConstantEffectiveScalar(&c)) &&
      !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
    return nullptr;
  }

  if (ShapeUtil::ElementIsSigned(divide->shape())) {
    int64_t b_value = c->literal().GetFirstElement<T>();
    if (b_value > 0 && absl::has_single_bit(static_cast<uint64_t>(b_value))) {
      // Handle negative dividends by negating the result of the division.
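      // Note that signed integer division truncates toward zero, while an
      // arithmetic shift of a negative value rounds toward negative infinity
      // (e.g. -5 / 4 == -1 but -5 >> 2 == -2), so we shift the absolute value
      // of the dividend and then restore the sign.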
      HloInstruction* zero_like_a = MakeScalarLike(a, 0);

      Shape changed_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
      simplifier->UpdateLayout(&changed_shape);
      auto* dividend_is_negative =
          divide->AddInstruction(HloInstruction::CreateCompare(
              changed_shape, a, zero_like_a, ComparisonDirection::kLt));

      auto* negated_dividend = divide->AddInstruction(
          HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

      auto* abs_dividend = divide->AddInstruction(HloInstruction::CreateTernary(
          a->shape(), HloOpcode::kSelect, dividend_is_negative,
          negated_dividend, a));

      auto* quotient = divide->AddInstruction(HloInstruction::CreateBinary(
          divide->shape(), HloOpcode::kShiftRightLogical, abs_dividend,
          MakeScalarLike(abs_dividend, Log2Floor<uint64_t>(b_value))));

      auto* negated_quotient =
          divide->AddInstruction(HloInstruction::CreateUnary(
              quotient->shape(), HloOpcode::kNegate, quotient));

      return HloInstruction::CreateTernary(divide->shape(), HloOpcode::kSelect,
                                           dividend_is_negative,
                                           negated_quotient, quotient);
    }
  } else {
    uint64_t b_value = c->literal().GetFirstElement<T>();
    if (absl::has_single_bit(b_value)) {
      return HloInstruction::CreateBinary(
          divide->shape(), HloOpcode::kShiftRightLogical, a,
          MakeScalarLike(a, Log2Floor(b_value)));
    }
  }

  return nullptr;
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
  HloInstruction *a, *b, *c, *d;
  CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
  // A/1 => A
  VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
  if (IsAll(b, 1) && ReplaceInstructionIfCompatible(divide, a)) {
    return OkStatus();
  }

  // A / B => A >> log2(B) if B is a power of 2.
  switch (divide->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int8_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int16_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int32_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<int64_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint8_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint16_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint32_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift =
              TryDivideToShift<uint64_t>(divide, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(divide, std::move(shift));
      }
      break;
    default:
      break;
  }

  Shape* shape;
  // exp(A)/exp(B) => exp(A-B)
  if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
                        .WithShape(m::Shape(&shape)))) {
    VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
    HloInstruction* subtract = divide->AddInstruction(
        HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract));
  }

  // A/exp(B) => A*exp(-B)
  if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) {
    VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString();
    HloInstruction* negate = divide->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b));
    HloInstruction* new_exp = divide->mutable_operand(1)->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(divide->shape(),
                                             HloOpcode::kMultiply, a, new_exp));
  }

  // A/pow(B,C) => A*pow(B,-C)
  if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) {
    VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString();
    // The output shape of the created negate operator should be the same as
    // the input.
    const Shape& negate_shape = c->shape();
    HloInstruction* negate = divide->mutable_operand(1)->AddInstruction(
        HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c));
    // And the power operator should retain the output shape of the old one.
    const Shape& new_power_shape = b->shape();
    HloInstruction* new_power =
        divide->mutable_operand(1)->AddInstruction(HloInstruction::CreateBinary(
            new_power_shape, HloOpcode::kPower, b, negate));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(
                    divide->shape(), HloOpcode::kMultiply, a, new_power));
  }

  // A/sqrt(B) => A*rsqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Sqrt(m::Op(&b))))) {
    auto* rsqrt = divide->mutable_operand(1)->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kRsqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(rsqrt->shape(),
                                             HloOpcode::kMultiply, a, rsqrt));
  }

  // A/rsqrt(B) => A*sqrt(B).
  if (Match(divide, m::Divide(m::Op(&a), m::Rsqrt(m::Op(&b))))) {
    auto* sqrt = divide->mutable_operand(1)->AddInstruction(
        HloInstruction::CreateUnary(divide->shape(), HloOpcode::kSqrt, b));
    return ReplaceWithNewInstruction(
        divide, HloInstruction::CreateBinary(sqrt->shape(),
                                             HloOpcode::kMultiply, a, sqrt));
  }

  // Simplifying integral division would produce unexpected results.
  if (ShapeUtil::ElementIsIntegral(divide->shape())) {
    return OkStatus();
  }

  // A / Const => A * (1 / Const)
  //
  // (Backends can do this transformation, but generally only if the constant
  // is a scalar.)
  if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) &&
      (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) {
    Shape result_shape = c->literal().shape();
    Literal new_literal(result_shape);
    switch (result_shape.element_type()) {
      case F16:
        TF_RETURN_IF_ERROR(InvertConstant<half>(*c, &new_literal));
        break;
      case F32:
        TF_RETURN_IF_ERROR(InvertConstant<float>(*c, &new_literal));
        break;
      case BF16:
        TF_RETURN_IF_ERROR(InvertConstant<bfloat16>(*c, &new_literal));
        break;
      case F64:
        TF_RETURN_IF_ERROR(InvertConstant<double>(*c, &new_literal));
        break;
      case C64:
        TF_RETURN_IF_ERROR(InvertConstant<complex64>(*c, &new_literal));
        break;
      case C128:
        TF_RETURN_IF_ERROR(InvertConstant<complex128>(*c, &new_literal));
        break;
      default:
        return OkStatus();
    }
    auto inverse = c->AddInstruction(
        simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone()));
    if (b != c) {
      inverse = b->AddInstruction(HloInstruction::CreateBroadcast(
          b->shape(), inverse, b->dimensions()));
    }
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, inverse));
    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)),
                              m::Divide(m::Op(&c), m::Op(&d))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_d,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, d));
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide,
                                                       a_times_d, b_times_c));

    return ReplaceInstruction(divide, new_divide);
  }

  // (A / B) / C => A / (B * C)
  if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) {
    TF_ASSIGN_OR_RETURN(auto b_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, b, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a, b_times_c));
    return ReplaceInstruction(divide, new_divide);
  }

  // A / (B / C) => (A*C) / B
  if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) {
    TF_ASSIGN_OR_RETURN(auto a_times_c,
                        MakeBinaryHlo(HloOpcode::kMultiply, a, c));
    TF_ASSIGN_OR_RETURN(auto new_divide,
                        MakeBinaryHlo(HloOpcode::kDivide, a_times_c, b));
    return ReplaceInstruction(divide, new_divide);
  }

  // If X is a convert from pred, then
  // X / broadcast(Y) => broadcast(1/Y) * X
  if (Match(divide,
            m::Divide(
                m::Convert(&a,
                           m::Op().WithShape(m::Shape().WithElementType(PRED))),
                m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) {
    TF_ASSIGN_OR_RETURN(
        auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b));
    auto recip_bcast = divide->mutable_operand(1)->AddInstruction(
        HloInstruction::CreateBroadcast(divide->shape(), recip, {}));
    TF_ASSIGN_OR_RETURN(auto mul,
                        MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a));
    return ReplaceInstruction(divide, mul);
  }

  return OkStatus();
}

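// Removes degenerate (size 1) dimensions from the operands of `dot` by
// reshaping them away, remapping the dimension numbers, and reshaping the
// result back if the output shape changed. For example (a sketch; shapes
// illustrative), a dot of f32[1,4,5] and f32[5,1] contracting over the
// 5-sized dimensions becomes a dot of f32[4,5] and f32[5] followed by a
// reshape to the original output shape.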
StatusOr<bool> AlgebraicSimplifierVisitor::RemoveDegenerateDimensionFromDot(
    HloInstruction* dot) {
  const Shape& lhs_shape = dot->operand(0)->shape();
  int64_t num_degenerate_lhs_dims = 0;
  std::vector<int64_t> lhs_dimension_map(lhs_shape.rank(), -1);
  for (int64_t i = 0; i < lhs_shape.rank(); ++i) {
    if (lhs_shape.dimensions(i) == 1) {
      ++num_degenerate_lhs_dims;
    } else {
      lhs_dimension_map[i] = i - num_degenerate_lhs_dims;
    }
  }

  const Shape& rhs_shape = dot->operand(1)->shape();
  int64_t num_degenerate_rhs_dims = 0;
  std::vector<int64_t> rhs_dimension_map(rhs_shape.rank(), -1);
  for (int64_t i = 0; i < rhs_shape.rank(); ++i) {
    if (rhs_shape.dimensions(i) == 1) {
      ++num_degenerate_rhs_dims;
    } else {
      rhs_dimension_map[i] = i - num_degenerate_rhs_dims;
    }
  }
  if (num_degenerate_lhs_dims == 0 && num_degenerate_rhs_dims == 0) {
    return false;
  }
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  DotDimensionNumbers new_dnums;
  for (int64_t dim : dnums.lhs_batch_dimensions()) {
    int64_t new_dim = lhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_lhs_batch_dimensions(new_dim);
    }
  }
  for (int64_t dim : dnums.lhs_contracting_dimensions()) {
    int64_t new_dim = lhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_lhs_contracting_dimensions(new_dim);
    }
  }

  for (int64_t dim : dnums.rhs_batch_dimensions()) {
    int64_t new_dim = rhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_rhs_batch_dimensions(new_dim);
    }
  }
  for (int64_t dim : dnums.rhs_contracting_dimensions()) {
    int64_t new_dim = rhs_dimension_map[dim];
    if (new_dim != -1) {
      new_dnums.add_rhs_contracting_dimensions(new_dim);
    }
  }

  HloInstruction* new_lhs =
      num_degenerate_lhs_dims > 0
          ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::DropDegenerateDimensions(lhs_shape),
                dot->mutable_operand(0)))
          : dot->mutable_operand(0);
  HloInstruction* new_rhs =
      num_degenerate_rhs_dims > 0
          ? dot->parent()->AddInstruction(HloInstruction::CreateReshape(
                ShapeUtil::DropDegenerateDimensions(rhs_shape),
                dot->mutable_operand(1)))
          : dot->mutable_operand(1);
  TF_ASSIGN_OR_RETURN(
      auto new_dot,
      MakeDotHlo(new_lhs, new_rhs, new_dnums, dot->precision_config(),
                 /*preferred_element_type=*/dot->shape().element_type()));
  if (ShapeUtil::Compatible(dot->shape(), new_dot->shape())) {
    TF_RETURN_IF_ERROR(ReplaceInstruction(dot, new_dot));
  } else {
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        dot, HloInstruction::CreateReshape(dot->shape(), new_dot)));
  }
  return true;
}

StatusOr<bool> AlgebraicSimplifierVisitor::RemoveTransposesFromDotOperands(
    HloInstruction* dot) {
  const int64_t rank = dot->shape().rank();
  const auto& dnums = dot->dot_dimension_numbers();
  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);

  // lhs and rhs must apply the same permutation.
  if (lhs->opcode() != HloOpcode::kTranspose ||
      rhs->opcode() != HloOpcode::kTranspose ||
      lhs->dimensions() != rhs->dimensions()) {
    return false;
  }
  absl::Span<const int64_t> permutation = lhs->dimensions();

  // Dot must be "somewhat canonical": batch dimensions at the beginning, one
  // contracting dimension, and one non-contracting dim.
  if (absl::MakeSpan(dnums.lhs_batch_dimensions()) !=
          absl::MakeSpan(dnums.rhs_batch_dimensions()) ||
      dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.rhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_contracting_dimensions(0) != rank - 1 ||
      dnums.rhs_contracting_dimensions(0) != rank - 2 ||
      rank != dnums.lhs_batch_dimensions_size() + 2) {
    return false;
  }

  // The last two elements of the permutation must be either [rank-2, rank-1]
  // (i.e. no permutation) or [rank-1, rank-2]. Otherwise, this means that
  // we're permuting batch dimensions with the non-batch dimensions, which
  // isn't allowed.
  //
  // If the permutation ends with [rank - 1, rank - 2] then we're going to flip
  // the order of dot operands to dot(b,a). Otherwise it stays dot(a,b).
  bool reorder_operands;
  if (permutation.subspan(rank - 2) ==
      std::array<int64_t, 2>{rank - 2, rank - 1}) {
    reorder_operands = false;
  } else if (permutation.subspan(rank - 2) ==
             std::array<int64_t, 2>{rank - 1, rank - 2}) {
    reorder_operands = true;
  } else {
    return false;
  }

  HloInstruction* new_lhs =
      reorder_operands ? rhs->mutable_operand(0) : lhs->mutable_operand(0);
  HloInstruction* new_rhs =
      reorder_operands ? lhs->mutable_operand(0) : rhs->mutable_operand(0);
  auto new_dot = dot->AddInstruction(HloInstruction::CreateDot(
      ShapeUtil::PermuteDimensions(permutation, dot->shape()), new_lhs, new_rhs,
      dnums,
      reorder_operands
          ? SwapOperandsInDotPrecisionConfig(dot->precision_config())
          : dot->precision_config()));
  TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
      dot,
      HloInstruction::CreateTranspose(dot->shape(), new_dot, permutation)));
  return true;
}

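// Transposes the given dot operand, when needed, so that its dimensions are
// ordered [batch dims, non-contracting dims, contracting dims]; returns the
// operand unchanged if its dimensions are already sorted that way.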
StatusOr<HloInstruction*>
AlgebraicSimplifierVisitor::NormalizeDotOperandToBatchMajorAndContractingMinor(
    HloInstruction* dot_operand, absl::Span<const int64_t> batch_dimensions,
    absl::Span<const int64_t> contracting_dimensions) {
  std::vector<int64_t> transpose_dimensions(batch_dimensions.begin(),
                                            batch_dimensions.end());
  for (int64_t i = 0; i < dot_operand->shape().rank(); ++i) {
    if (!(absl::c_linear_search(batch_dimensions, i) ||
          absl::c_linear_search(contracting_dimensions, i))) {
      transpose_dimensions.push_back(i);
    }
  }
  transpose_dimensions.insert(transpose_dimensions.end(),
                              contracting_dimensions.begin(),
                              contracting_dimensions.end());
  if (absl::c_is_sorted(transpose_dimensions)) {
    return dot_operand;
  }
  return MakeTransposeHlo(dot_operand, transpose_dimensions);
}

HloInstruction* AlgebraicSimplifierVisitor::AddReduce(
    HloInstruction* hlo, absl::Span<const int64_t> dims, PrimitiveType type) {
  HloInstruction* zero =
      computation_->AddInstruction(simplifier_->CreateConstantWithLayoutUpdated(
          LiteralUtil::Zero(hlo->shape().element_type()).Clone()));
  HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(type);
  Shape shape = ShapeUtil::DeleteDimensions(dims, hlo->shape());
  simplifier_->UpdateLayout(&shape);
  return computation_->AddInstruction(HloInstruction::CreateReduce(
      shape, hlo, zero, dims, AddReduce_computation));
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat(
    HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_batch_dimensions_size() != 0 ||
      dot->shape().dimensions_size() != 2) {  // dot output 2D
    return nullptr;
  }

  const int64_t lhs_contracting_dim = dnums.lhs_contracting_dimensions(0);
  const int64_t rhs_contracting_dim = dnums.rhs_contracting_dimensions(0);
  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));

  TF_ASSIGN_OR_RETURN(
      HloInstruction * optimized_lhs_concat,
      OptimizeDotOfConcatHelper(dot, lhs, lhs_contracting_dim, rhs,
                                rhs_contracting_dim, /*swapped=*/false));
  if (optimized_lhs_concat) {
    return optimized_lhs_concat;
  }

  return OptimizeDotOfConcatHelper(dot, rhs, rhs_contracting_dim, lhs,
                                   lhs_contracting_dim, /*swapped=*/true);
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper(
    HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim,
    HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped) {
  bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate &&
                      lhs->concatenate_dimension() == lhs_contracting_dim &&
                      rhs->opcode() == HloOpcode::kConstant;
  if (!can_optimize) {
    return nullptr;
  }

  // We're replacing this:
  //
  //   +-----+-----+-----+   +-------------------+
  //   |     |     |     |   |                   |
  //   |     |     |     |   |        R_0        |
  //   |     |     |     |   |                   |
  //   |     |     |     |   +-------------------+
  //   |     |     |     |   |                   |
  //   | L_0 | L_1 | L_2 | * |        R_1        |
  //   |     |     |     |   |                   |
  //   |     |     |     |   +-------------------+
  //   |     |     |     |   |                   |
  //   |     |     |     |   |        R_2        |
  //   |     |     |     |   |                   |
  //   +-----+-----+-----+   +-------------------+
  //
  // with this:
  //
  // [Sum over i]
  //
  //   +-----+   +-------------------+
  //   |     |   |                   |
  //   |     | * |        R_i        |
  //   |     |   |                   |
  //   |     |   +-------------------+
  //   |     |
  //   | L_i |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   |     |
  //   +-----+
  //
  // where the LHS is a concatenate operation (so we can "split" the LHS tensor
  // for free) and the RHS is a constant tensor (and thus can be split at
  // compile time). In the future, we may also want to do this when both the
  // LHS and the RHS are concatenate operations that line up along the
  // dimension being contracted over.
  //
  // We should be able to generalize this transform to work on a non-constant
  // RHS when/if we have in-place slices or support input-fusing slices into
  // Dots.

  // Dimension numbers for the new dot instructions we'll create (L_i * R_i in
  // the diagram above).
  DotDimensionNumbers new_dot_dnums;
  new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim
                                                       : lhs_contracting_dim);
  new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim
                                                       : rhs_contracting_dim);

  // Here we use the MKN notation, where the contracted dimension has K
  // elements and the two non-contracted dimensions have M and N elements.
  HloInstruction* add_result = nullptr;
  int64_t rhs_contracting_dim_offset = 0;
  int64_t n = rhs->shape().dimensions(1 - rhs_contracting_dim);
  for (HloInstruction* concat_op : lhs->operands()) {
    int64_t sub_k = concat_op->shape().dimensions(lhs_contracting_dim);
    Shape rhs_slice_shape(rhs->shape());
    rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k);
    simplifier_->UpdateLayout(&rhs_slice_shape);

    std::array<int64_t, 2> start_indices;
    start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset;
    start_indices[1 - rhs_contracting_dim] = 0;

    std::array<int64_t, 2> limit_indices;
    limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k;
    limit_indices[1 - rhs_contracting_dim] = n;

    HloInstruction* rhs_slice = rhs->AddInstruction(HloInstruction::CreateSlice(
        rhs_slice_shape, rhs, /*start_indices=*/start_indices,
        /*limit_indices=*/limit_indices, /*strides=*/{1, 1}));

    // TODO(b/69062148): We can get rid of `swapped` once all backends support
    // "non-canonical" contraction dimensions (that contracts dimension 1 of
    // the LHS with dimension 0 of the RHS). But for now we keep the same
    // contraction dimensions as the incoming dot operation to ensure the new
    // dot operations can be lowered.
    HloInstruction *new_dot_lhs, *new_dot_rhs;
    if (swapped) {
      new_dot_lhs = rhs_slice;
      new_dot_rhs = concat_op;
    } else {
      new_dot_lhs = concat_op;
      new_dot_rhs = rhs_slice;
    }

    auto* new_dot = dot->AddInstruction(
        HloInstruction::CreateDot(dot->shape(), new_dot_lhs, new_dot_rhs,
                                  new_dot_dnums, dot->precision_config()));

    if (add_result) {
      add_result = dot->AddInstruction(HloInstruction::CreateBinary(
          dot->shape(), HloOpcode::kAdd, add_result, new_dot));
    } else {
      add_result = new_dot;
    }

    rhs_contracting_dim_offset += sub_k;
  }

  return add_result;
}

StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather(
    HloInstruction* dot) {
  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
  if (dnums.lhs_contracting_dimensions_size() != 1 ||
      dnums.rhs_contracting_dimensions_size() != 1 ||
      dnums.lhs_batch_dimensions_size() != 0 ||
      dnums.rhs_batch_dimensions_size() != 0 ||
      dot->shape().dimensions_size() != 2) {  // dot output 2D
    VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations.";
    return nullptr;
  }

  // Optimize either dot(DS(ctA), ctB) or dot(ctB, DS(ctA)).
  // Currently a Gather is a DynamicSlice.
  auto is_dynamic_slice_constant_combination =
      [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) {
        // First operand is a DynamicSlice(Constant).
        if (a->opcode() != HloOpcode::kDynamicSlice) {
          return false;
        }
        auto* dynamic_slice_op = a->operand(0);
        if (dynamic_slice_op->opcode() != HloOpcode::kConstant) {
          return false;
        }
        // Second operand is a Constant.
        if (b->opcode() != HloOpcode::kConstant) {
          return false;
        }
        // The DynamicSlice output is a vector.
        const Shape& dynamic_slice_shape = a->shape();
        if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) {
          return false;
        }
        // Constant size is the same before and after the slice in the
        // contracting dimension; otherwise we would either have to precompute
        // the dot for all possible slice indices, or the dot would be invalid.
        const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape();
        if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) !=
            dynamic_slice_shape.dimensions(a_contracting_dimension)) {
          return false;
        }
        return true;
      };

  HloInstruction* lhs = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0);
  int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0);

  if (!is_dynamic_slice_constant_combination(
          lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) &&
      !is_dynamic_slice_constant_combination(
          rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) {
    VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or "
                "dot(ctB, DS(ctA)), where the two constants have equal "
                "contracting dimensions.";
    return nullptr;
  }

  // LHS is DynamicSlice:
  // input: dot(DS(ctA), ctB)
  // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}.
  // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}.
  // output: DS(dot(ctA, ctB))
  // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}.

  // RHS is DynamicSlice:
  // input: dot(ctA, DS(ctB))
  // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}).
  // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}.
  // output: DS(dot(ctA, ctB))
  // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}.

  bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice;
  HloDynamicSliceInstruction* dynamic_slice =
      lhs_is_dynamic_slice ? Cast<HloDynamicSliceInstruction>(lhs)
                           : Cast<HloDynamicSliceInstruction>(rhs);

  // ctA:
  HloInstruction* left_operand =
      lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs;
  // ctB:
  HloInstruction* right_operand =
      lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0);
  // Build ctA x ctB.
  const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension);
  const int n =
      right_operand->shape().dimensions(1 - rhs_contracting_dimension);
  auto memoized_shape =
      ShapeUtil::MakeShape(dot->shape().element_type(), {m, n});
  simplifier_->UpdateLayout(&memoized_shape);
  auto* memoized_inst = dot->AddInstruction(
      HloInstruction::CreateDot(memoized_shape, left_operand, right_operand,
                                dnums, dot->precision_config()));
  // Get pair {start, 0} or {0, start}.
  // Position of start:
  int index_of_non_zero_start = lhs_is_dynamic_slice
                                    ? 1 - lhs_contracting_dimension
                                    : 1 - rhs_contracting_dimension;
  // Position of zero:
  int index_of_zero_start = 1 - index_of_non_zero_start;

  // Slice out start and 0 components and reorder if necessary.
  auto indices_type = dynamic_slice->operand(1)->shape().element_type();
  Shape s_shape = ShapeUtil::MakeShape(indices_type, {1});
  simplifier_->UpdateLayout(&s_shape);
  Shape d_shape = ShapeUtil::MakeShape(indices_type, {2});
  simplifier_->UpdateLayout(&d_shape);
  HloInstruction* non_zero_start =
      dynamic_slice->mutable_operand(1 + index_of_non_zero_start);
  HloInstruction* zero_start =
      dynamic_slice->mutable_operand(1 + index_of_zero_start);
  std::vector<HloInstruction*> new_start_indices;
  if (lhs_is_dynamic_slice) {
    new_start_indices = {non_zero_start, zero_start};
  } else {
    new_start_indices = {zero_start, non_zero_start};
  }

  // Build DynamicSlice(ctA x ctB).
  const int new_slice_m = lhs_is_dynamic_slice ? 1 : m;
  const int new_slice_n = lhs_is_dynamic_slice ? n : 1;
  auto* memoized_lookup =
      dot->AddInstruction(HloInstruction::CreateDynamicSlice(
          dot->shape(), memoized_inst, new_start_indices,
          {new_slice_m, new_slice_n}));

  return memoized_lookup;
}

// This function tries to transform
// dot(reshape(transpose(A)), Const) to
// dot(reshape(A), reshape(transpose(reshape(Const)))),
// so that the reshape and transpose on the Const side can be constant folded.
//
// The basic idea is that since the accumulation in the dot operation is
// associative, as long as we permute the elements of the contracting
// dimensions on both sides of the dot in the same way, the result of the
// dot is not affected.
StatusOr<HloInstruction*>
AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
    HloInstruction* dot) {
  // This transformation assumes layout is not assigned yet.
  if (options_.is_layout_sensitive()) {
    return nullptr;
  }

  // Canonicalize dot(<constant>, rhs) to dot(rhs, <constant>) to make the
  // remainder of this function easier.
  auto dnums = dot->dot_dimension_numbers();
  auto lhs_contracting_dims = dnums.lhs_contracting_dimensions();
  auto rhs_contracting_dims = dnums.rhs_contracting_dimensions();
  auto* lhs = dot->mutable_operand(0);
  auto* rhs = dot->mutable_operand(1);
  if (dot->operand(0)->IsConstant()) {
    std::swap(lhs, rhs);
    std::swap(lhs_contracting_dims, rhs_contracting_dims);
  }

  // Require single contracting dim to make the implementation easier to
  // track contracting dims.
  if (dnums.lhs_contracting_dimensions_size() != 1) {
    return nullptr;
  }

  // Pattern match Dot(reshape(transpose(input), constant))
  HloInstruction* reshape;
  HloInstruction* transpose;
  HloInstruction* input;
  HloInstruction* constant;
  if (!Match(lhs,
             m::Reshape(&reshape, m::Transpose(&transpose, m::Op(&input)))) ||
      !Match(rhs, m::Constant(&constant))) {
    return nullptr;
  }

  // Check that reshape squishes some dims into one dim and that this one
  // dim is the dot's lhs contracting dim. The size of unmodified_dims should
  // be N - 1, where N is the rank of the reshape output. This means that the
  // reshape squishes some dims into one dim. lhs contracting dim should not
  // be in unmodified_dims. This means that the squishing target dim is the
  // lhs contracting dim.
  auto unmodified_dims = ShapeUtil::DimensionsUnmodifiedByReshape(
      reshape->operand(0)->shape(), reshape->shape());
  CHECK_EQ(lhs_contracting_dims.size(), 1);
  if ((unmodified_dims.size() != reshape->shape().rank() - 1) ||
      absl::c_any_of(unmodified_dims,
                     [&](const std::pair<int64_t, int64_t>& p) {
                       return p.second == lhs_contracting_dims[0];
                     })) {
    return nullptr;
  }

  // Virtually pull the reshape into the dot so the dot operates on the
  // transpose, with "unsquished" lhs contracting dims. The new contracting
  // dims are all of the dims that are modified by the reshape -- that is,
  // every dimension that's not in `unmodified_dims[i].first`.
  //
  // (We don't need to actually create a new dot instruction. We can just keep
  // track of lhs and lhs_contracting_dims.)
  absl::flat_hash_set<int64_t> unmodified_transpose_dims;
  for (const auto& pair : unmodified_dims) {
    unmodified_transpose_dims.insert(pair.first);
  }
  lhs_contracting_dims.Clear();
  for (int64_t i = 0; i < transpose->shape().dimensions_size(); ++i) {
    if (!unmodified_transpose_dims.contains(i)) {
      lhs_contracting_dims.Add(i);
    }
  }
  // We require the "unsquished" lhs contracting dims to be consecutive.
  auto is_iota = [](absl::Span<const int64_t> dims) {
    return absl::c_adjacent_find(dims, [](const int64_t a, const int64_t b) {
             return (b != a + 1);
           }) == dims.end();
  };
  if (!is_iota(lhs_contracting_dims)) {
    return nullptr;
  }
  lhs = lhs->mutable_operand(0);

  // Check that the transpose only permutes the contracting dims.
  const auto& transpose_dims = transpose->dimensions();
  for (int64_t i = 0; i < transpose_dims.size(); ++i) {
    if (transpose_dims[i] != i &&
        !absl::c_linear_search(lhs_contracting_dims, i)) {
      return nullptr;
    }
  }
  // Virtually pull the transpose into the dot. Now the dot is equivalent to
  // a new dot with "permuted" lhs contracting dims.
  std::vector<int64_t> permutation;
  permutation.reserve(lhs_contracting_dims.size());
  for (auto dim : lhs_contracting_dims) {
    permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
  }
  CHECK(IsPermutation(permutation));
  auto new_lhs_contracting_dims =
      ComposePermutations(lhs_contracting_dims, permutation);
  lhs_contracting_dims.Clear();
  for (auto dim : new_lhs_contracting_dims) {
    lhs_contracting_dims.Add(dim);
  }
  lhs = lhs->mutable_operand(0);

  // All checks are passed at this point.
  //
  // Transform lhs. Remove the transpose and reshape by sorting the lhs
  // contracting dims and squishing them into a single one. We don't actually
  // squish the lhs_contracting_dims here because we still need the unsquished
  // contracting dims to invert reshape and transpose.
  absl::c_sort(lhs_contracting_dims);
  lhs =
      dot->AddInstruction(HloInstruction::CreateReshape(reshape->shape(), lhs));

  // Transform rhs. Say the input HLO is:
  //
  // t0 = f32[2, 2, 3] parameter(0)
  // t1 = f32[2, 3, 2] transpose(t0) dimensions={0, 2, 1}
  // t2 = f32[2, 6] reshape(t1)
  // t3 = f32[6, 2] constant(...)
  // dot = f32[2, 2] dot(t2, t3) lhs_contracting_dims={1},
  //                             rhs_contracting_dims={0}
  //
  // At this point in the function, we have decided that the second and third
  // dims of t0 can be switched to remove the transpose, and we have
  // "virtually decomposed" the input HLO to:
  //
  // t0 = f32[2, 2, 3] parameter(0)
  // t2' = f32[2, 6] reshape(t0)
  // t3' = f32[6, 2] ops-to-be-filled ...
  // dot = f32[2, 2] dot(t2', t3') lhs_contracting_dims={1},
  //                               rhs_contracting_dims={0}
  //
  // The rest of this function is to fill in the ops of t3'. To do this, we
  // unsquish the contracting dimensions in t3 and then apply the inverse of
  // the transpose from t1.

  // Invert reshape.
  CHECK_EQ(rhs_contracting_dims.size(), 1);
  std::vector<int64_t> rhs_unsquished_shape_dims =
      SpanToVector(constant->shape().dimensions());
  auto it = rhs_unsquished_shape_dims.erase(rhs_unsquished_shape_dims.begin() +
                                            rhs_contracting_dims[0]);
  for (auto dim : lhs_contracting_dims) {
    it = rhs_unsquished_shape_dims.insert(it,
                                          transpose->shape().dimensions(dim));
    ++it;
  }
  HloInstruction* rhs_reshape =
      dot->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(constant->shape().element_type(),
                               rhs_unsquished_shape_dims),
          constant));
  rhs = rhs_reshape;

  // Rhs reshape "unsquishes" the single rhs contracting dim into multiple
  // dims.
  rhs_contracting_dims.Resize(lhs_contracting_dims.size(), 0);
  absl::c_iota(rhs_contracting_dims, rhs_contracting_dims[0]);

  // Invert transpose. First compute the shape.
  std::vector<int64_t> rhs_transpose_shape_dims =
      SpanToVector(rhs_reshape->shape().dimensions());
  it = rhs_transpose_shape_dims.erase(
      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0],
      rhs_transpose_shape_dims.begin() + rhs_contracting_dims[0] +
          rhs_contracting_dims.size());
  for (auto dim : lhs_contracting_dims) {
    it = rhs_transpose_shape_dims.insert(it, input->shape().dimensions(dim));
    ++it;
  }
  // Then compute the transpose dims.
  std::vector<int64_t> rhs_transpose_dims(rhs_reshape->shape().rank());
  absl::c_iota(rhs_transpose_dims, 0);
  it = rhs_transpose_dims.erase(
      rhs_transpose_dims.begin() + rhs_contracting_dims[0],
      rhs_transpose_dims.begin() + rhs_contracting_dims[0] +
          rhs_contracting_dims.size());
  auto inverse_lhs_transpose_dims = InversePermutation(transpose_dims);
  for (auto dim : lhs_contracting_dims) {
    it = rhs_transpose_dims.insert(it, inverse_lhs_transpose_dims[dim] -
                                           lhs_contracting_dims[0] +
                                           rhs_contracting_dims[0]);
    ++it;
  }
  HloInstruction* rhs_transpose =
      dot->AddInstruction(HloInstruction::CreateTranspose(
          ShapeUtil::MakeShape(constant->shape().element_type(),
                               rhs_transpose_shape_dims),
          rhs_reshape, rhs_transpose_dims));
  rhs = rhs_transpose;

  // Squish the multiple rhs contracting dims into a single one.
  rhs = dot->AddInstruction(
      HloInstruction::CreateReshape(constant->shape(), rhs));

  // If we virtually swapped lhs and rhs, we need to swap them back before
  // creating the new dot.
  if (dot->operand(0)->IsConstant()) {
    std::swap(lhs, rhs);
  }

  HloInstruction* new_dot = dot->AddInstruction(HloInstruction::CreateDot(
      dot->shape(), lhs, rhs, dnums, dot->precision_config()));
  return new_dot;
}

Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
  CHECK(computation_ == dot->parent());
  const auto& dnums = dot->dot_dimension_numbers();

  HloInstruction *lhs, *rhs;
  CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs))));
  if (options_.is_layout_sensitive()) {
    return OkStatus();
  }
  // Replace a zero element dot with a broadcast of the constant 0.
  if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
      ShapeUtil::IsZeroElementArray(lhs->shape()) ||
      ShapeUtil::IsZeroElementArray(rhs->shape())) {
    auto zero =
        dot->AddInstruction(simplifier_->CreateConstantWithLayoutUpdated(
            LiteralUtil::Zero(dot->shape().element_type())));
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
  }

  // If there are no contracting dimensions, a dot can be rewritten as
  // mul(broadcast(transpose(x)),broadcast(transpose(y)))
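  // For example (a sketch; shapes illustrative), dot(x : f32[B,M], y :
  // f32[B,N]) with batch dimension B and no contracting dimensions becomes
  // multiply(broadcast(x), broadcast(y)) in the f32[B,M,N] output shape, with
  // x broadcast along dims {0,1} and y along dims {0,2}.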
  if (options_.enable_dot_to_multiply_rewrite() &&
      dnums.lhs_contracting_dimensions_size() == 0) {
    TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
                        NormalizeDotOperandToBatchMajorAndContractingMinor(
                            lhs, dnums.lhs_batch_dimensions(),
                            dnums.lhs_contracting_dimensions()));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
                        NormalizeDotOperandToBatchMajorAndContractingMinor(
                            rhs, dnums.rhs_batch_dimensions(),
                            dnums.rhs_contracting_dimensions()));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }
    if (dot->shape().rank() != lhs->shape().rank()) {
      std::vector<int64_t> lhs_broadcast_dims(lhs->shape().rank());
      absl::c_iota(lhs_broadcast_dims, 0);
      new_lhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_lhs, lhs_broadcast_dims));
    }
    if (dot->shape().rank() != rhs->shape().rank()) {
      std::vector<int64_t> rhs_broadcast_dims(
          dnums.lhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      for (int64_t i = lhs->shape().rank(); i < dot->shape().rank(); ++i) {
        rhs_broadcast_dims.push_back(i);
      }
      new_rhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
          dot->shape(), new_rhs, rhs_broadcast_dims));
    }
    return ReplaceWithNewInstruction(
        dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
                                          new_lhs, new_rhs));
  }

  // If the lhs or rhs have only batch and contracting dimensions, a dot can be
  // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y))))
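  // For example (a sketch; shapes illustrative), dot(x : f32[K], y : f32[K,N])
  // contracting over K first transposes y so that K is minor, broadcasts x to
  // match, multiplies elementwise, and then sums over the contracted K
  // dimension with an add-reduce.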
  if (options_.enable_dot_strength_reduction() &&
      ((dnums.lhs_batch_dimensions_size() +
            dnums.lhs_contracting_dimensions_size() ==
        lhs->shape().rank()) ||
       (dnums.rhs_contracting_dimensions_size() +
            dnums.rhs_batch_dimensions_size() ==
        rhs->shape().rank()))) {
    TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs,
                        NormalizeDotOperandToBatchMajorAndContractingMinor(
                            lhs, dnums.lhs_batch_dimensions(),
                            dnums.lhs_contracting_dimensions()));
    if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) {
      new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type());
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs,
                        NormalizeDotOperandToBatchMajorAndContractingMinor(
                            rhs, dnums.rhs_batch_dimensions(),
                            dnums.rhs_contracting_dimensions()));
    if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) {
      new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type());
    }

    int64_t lhs_outer_dims =
        lhs->shape().rank() - (dnums.lhs_batch_dimensions_size() +
                               dnums.lhs_contracting_dimensions_size());
    int64_t rhs_outer_dims =
        rhs->shape().rank() - (dnums.rhs_batch_dimensions_size() +
                               dnums.rhs_contracting_dimensions_size());
    CHECK(lhs_outer_dims == 0 || rhs_outer_dims == 0);
    if (rhs_outer_dims > 0) {
      std::vector<int64_t> lhs_broadcast_dims(
          dnums.lhs_batch_dimensions_size());
      absl::c_iota(lhs_broadcast_dims, 0);
      lhs_broadcast_dims.resize(lhs->shape().rank());
      std::iota(lhs_broadcast_dims.begin() + dnums.lhs_batch_dimensions_size(),
                lhs_broadcast_dims.end(),
                dnums.lhs_batch_dimensions_size() + rhs_outer_dims);
      new_lhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
          new_rhs->shape(), new_lhs, lhs_broadcast_dims));
    } else if (lhs_outer_dims > 0) {
      std::vector<int64_t> rhs_broadcast_dims(
          dnums.rhs_batch_dimensions_size());
      absl::c_iota(rhs_broadcast_dims, 0);
      rhs_broadcast_dims.resize(rhs->shape().rank());
      std::iota(rhs_broadcast_dims.begin() + dnums.rhs_batch_dimensions_size(),
                rhs_broadcast_dims.end(),
                dnums.rhs_batch_dimensions_size() + lhs_outer_dims);
      new_rhs = dot->AddInstruction(HloInstruction::CreateBroadcast(
          new_lhs->shape(), new_rhs, rhs_broadcast_dims));
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * new_dot,
                        MakeBinaryHlo(HloOpcode::kMultiply, new_lhs, new_rhs));
    std::vector<int64_t> reduce_dims(dnums.lhs_contracting_dimensions_size());
    PrimitiveType dot_type =
        ShapeUtil::ElementIsFloating(dot->shape())
            ? (dot->shape().element_type() == F64 ? F64 : F32)
            : dot->shape().element_type();
    new_dot = AsType(new_dot, dot_type);
    const int64_t outer_dims = std::max(rhs_outer_dims, lhs_outer_dims);
    absl::c_iota(reduce_dims, outer_dims + dnums.lhs_batch_dimensions_size());
    new_dot = AddReduce(new_dot, reduce_dims, dot_type);
    new_dot = AsType(new_dot, dot->shape().element_type());
    return ReplaceInstruction(dot, new_dot);
  }

  // Simplify dot(reshape(transpose(A)), Const) to:
  // dot(reshape(A), reshape(transpose(reshape(Const)))), so that the reshape
  // and transpose on the Const side can be constant folded.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_reorder_optimized,
                      OptimizeDotOfReorderContractingDims(dot));
  if (dot_of_reorder_optimized) {
    VLOG(10) << " Replaced dot " << dot->ToString()
             << " with new dot operation: "
             << dot_of_reorder_optimized->ToString();
    return ReplaceInstruction(dot, dot_of_reorder_optimized);
  }

  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized,
                      OptimizeDotOfConcat(dot));
  if (dot_of_concat_optimized) {
    VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., "
                "constant)...)";
    return ReplaceInstruction(dot, dot_of_concat_optimized);
  }

  // Simplify dot(ConstA, Gather(Index, ConstB)) to:
  // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately
  // batched version of dot.
  TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized,
                      OptimizeDotOfGather(dot));
  if (dot_of_gather_optimized) {
    VLOG(10) << "Replaced dot(constA, gather(i, constB)) with "
                "gather(i, dot*(constA, constB))";
    return ReplaceInstruction(dot, dot_of_gather_optimized);
  }

  TF_ASSIGN_OR_RETURN(bool removed_degenerate_dimensions,
                      RemoveDegenerateDimensionFromDot(dot));
  if (removed_degenerate_dimensions) {
    return OkStatus();
  }

  TF_ASSIGN_OR_RETURN(bool removed_transposes,
                      RemoveTransposesFromDotOperands(dot));
  if (removed_transposes) {
    return OkStatus();
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) {
  const Shape& operand_shape = gather->operand(0)->shape();
  if (ShapeUtil::IsZeroElementArray(operand_shape)) {
    return ReplaceInstruction(gather, MakeScalarLike(gather, 0));
  }

  // Gathering from a scalar operand is simply a broadcast of that scalar.
  if (ShapeUtil::IsEffectiveScalar(operand_shape)) {
    HloInstruction* new_operand = gather->mutable_operand(0);
    if (operand_shape.rank()) {
      TF_ASSIGN_OR_RETURN(new_operand,
                          MakeReshapeHlo(ShapeUtil::MakeScalarShape(
                                             operand_shape.element_type()),
                                         new_operand));
    }
    HloInstruction* new_gather =
        MakeBroadcastHlo(new_operand, {}, gather->shape());
    return ReplaceInstruction(gather, new_gather);
  }
  // If the operand of a gather is very small, it is cheaper to expand the
  // gather into a sequence of selects, which are easier to fuse.
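  // For example (a sketch), gathering from an R1 operand {a, b, c} with index
  // i expands to select(i >= 2, c, select(i >= 1, b, a)) per output element.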
  const Shape& index_shape = gather->operand(1)->shape();
  if (operand_shape.rank() == 1 &&
      operand_shape.dimensions(0) <= options_.very_small_gather_size() &&
      gather->gather_dimension_numbers().index_vector_dim() ==
          index_shape.rank() &&
      gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) {
    const int64_t operand_elements = operand_shape.dimensions(0);
    auto get_value = [&](int64_t i) {
      auto slice = gather->AddInstruction(HloInstruction::CreateSlice(
          ShapeUtil::MakeShape(operand_shape.element_type(), {1}),
          gather->mutable_operand(0), {i}, {i + 1}, {1}));
      auto scalar = gather->AddInstruction(HloInstruction::CreateReshape(
          ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice));
      return gather->AddInstruction(
          HloInstruction::CreateBroadcast(gather->shape(), scalar, {}));
    };
    auto result = get_value(0);
    auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED);
    simplifier_->UpdateLayout(&pred_shape);
    auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(),
                                                   index_shape.element_type());
    simplifier_->UpdateLayout(&iter_shape);
    for (int64_t i = 0; i < operand_elements; ++i) {
      auto index_mask = gather->AddInstruction(HloInstruction::CreateCompare(
          pred_shape, gather->mutable_operand(1),
          MakeScalarLike(gather->mutable_operand(1), i),
          ComparisonDirection::kGe));
      result = gather->AddInstruction(
          HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect,
                                        index_mask, get_value(i), result));
    }
    return ReplaceInstruction(gather, result);
  }
  return OkStatus();
}

namespace {
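// Attempts to convert max(broadcast(lo), min(x, broadcast(hi))) (or the dual
// min-of-max pattern) into clamp(lo, x, hi). Returns a clamp instruction if
// the constant bounds satisfy lo < hi, and a null pointer otherwise.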
StatusOr<std::unique_ptr<HloInstruction>> MinMaxToClamp(
    HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp,
    HloInstruction* clamp_upper_bound_bcast, AlgebraicSimplifier* simplifier) {
  HloInstruction* clamp_lower_bound;
  CHECK(Match(clamp_lower_bound_bcast,
              m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound))))
      << clamp_lower_bound_bcast->ToString();

  HloInstruction* clamp_upper_bound;
  CHECK(Match(clamp_upper_bound_bcast,
              m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound))))
      << clamp_upper_bound_bcast->ToString();

  const Literal& lower_bound =
      Cast<HloConstantInstruction>(clamp_lower_bound)->literal();
  const Literal& upper_bound =
      Cast<HloConstantInstruction>(clamp_upper_bound)->literal();

  TF_ASSIGN_OR_RETURN(Literal lower_bound_literal_reshaped,
                      lower_bound.Reshape({}));
  TF_ASSIGN_OR_RETURN(Literal upper_bound_literal_reshaped,
                      upper_bound.Reshape({}));
  std::unique_ptr<HloInstruction> lower_bound_instr =
      HloInstruction::CreateConstant(std::move(lower_bound_literal_reshaped));
  std::unique_ptr<HloInstruction> upper_bound_instr =
      HloInstruction::CreateConstant(std::move(upper_bound_literal_reshaped));

  Shape compare_shape =
      ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED);
  simplifier->UpdateLayout(&compare_shape);
  std::unique_ptr<HloInstruction> cloned_instruction =
      HloInstruction::CreateCompare(compare_shape, lower_bound_instr.get(),
                                    upper_bound_instr.get(),
                                    ComparisonDirection::kLt);

  HloEvaluator evaluator;
  TF_ASSIGN_OR_RETURN(auto result,
                      evaluator.Evaluate(cloned_instruction.get()));
  if (result.IsAll(true)) {
    return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp,
                                         clamp_lower_bound_bcast, to_clamp,
                                         clamp_upper_bound_bcast);
  }
  return std::unique_ptr<HloInstruction>();
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs))));

  // max(x, -inf) -> x
  PrimitiveType ty = maximum->shape().element_type();
  if (primitive_util::IsIntegralType(ty) ||
      (primitive_util::IsFloatingPointType(ty) &&
       options_.minmax_propagate_nan())) {
    Literal min_val = LiteralUtil::MinValue(ty);
    if (IsAll(lhs, min_val)) {
      return ReplaceInstruction(maximum, rhs);
    }
    if (IsAll(rhs, min_val)) {
      return ReplaceInstruction(maximum, lhs);
    }
  }

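  // max(broadcast(lo), min(x, broadcast(hi))) -> clamp(lo, x, hi), e.g.
  // max(bcast(2), min(x, bcast(7))) -> clamp(bcast(2), x, bcast(7)); this is
  // only done when MinMaxToClamp verifies that lo < hi.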
  HloInstruction* clamp_upper_bound_bcast;
  HloInstruction* clamp_lower_bound_bcast;
  HloInstruction* to_clamp;
  if (Match(maximum, m::MaximumAnyOrder(
                         m::Broadcast(&clamp_lower_bound_bcast,
                                      m::ConstantEffectiveScalar()),
                         m::MinimumAnyOrder(
                             m::Op(&to_clamp),
                             m::Broadcast(&clamp_upper_bound_bcast,
                                          m::ConstantEffectiveScalar()))))) {
    TF_ASSIGN_OR_RETURN(auto clamp,
                        MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
                                      clamp_upper_bound_bcast, simplifier_));
    if (clamp) {
      return ReplaceWithNewInstruction(maximum, std::move(clamp));
    }
  }

  HloInstruction* clamp_lower_bound;
  HloInstruction* clamp_upper_bound;
  HloInstruction* max_operand;
  HloInstruction* clamp;
  if (Match(maximum,
            m::MaximumAnyOrder(
                m::Op(&max_operand),
                m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp),
                         m::Op(&clamp_upper_bound))))) {
    if (max_operand == clamp_lower_bound &&
        ReplaceInstructionIfCompatible(maximum, clamp)) {
      return OkStatus();
    }
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs))));

  // min(x, inf) -> x
  PrimitiveType ty = minimum->shape().element_type();
  if (primitive_util::IsIntegralType(ty) ||
      (primitive_util::IsFloatingPointType(ty) &&
       options_.minmax_propagate_nan())) {
    Literal max_val = LiteralUtil::MaxValue(ty);
    if (IsAll(lhs, max_val)) {
      return ReplaceInstruction(minimum, rhs);
    }
    if (IsAll(rhs, max_val)) {
      return ReplaceInstruction(minimum, lhs);
    }
  }

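  // Dual of the pattern in HandleMaximum:
  // min(broadcast(hi), max(x, broadcast(lo))) -> clamp(lo, x, hi), again only
  // when the constant bounds satisfy lo < hi.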
  HloInstruction* clamp_upper_bound_bcast;
  HloInstruction* clamp_lower_bound_bcast;
  HloInstruction* to_clamp;
  if (Match(minimum, m::MinimumAnyOrder(
                         m::Broadcast(&clamp_upper_bound_bcast,
                                      m::ConstantEffectiveScalar()),
                         m::MaximumAnyOrder(
                             m::Op(&to_clamp),
                             m::Broadcast(&clamp_lower_bound_bcast,
                                          m::ConstantEffectiveScalar()))))) {
    TF_ASSIGN_OR_RETURN(auto clamp,
                        MinMaxToClamp(clamp_lower_bound_bcast, to_clamp,
                                      clamp_upper_bound_bcast, simplifier_));
    if (clamp) {
      return ReplaceWithNewInstruction(minimum, std::move(clamp));
    }
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) {
  HloInstruction* clamp_lower_bound;
  HloInstruction* clamp_upper_bound;
  HloInstruction* to_clamp;
  CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp),
                              m::Op(&clamp_upper_bound))));

  // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b)
  if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(),
                               m::Op().Is(clamp_upper_bound))) &&
      ReplaceInstructionIfCompatible(clamp, to_clamp)) {
    return OkStatus();
  }

  // Eliminate redundant clamping of replica-id or partition-id.
  if ((Match(to_clamp, m::PartitionId()) || Match(to_clamp, m::ReplicaId())) &&
      Match(clamp_lower_bound, m::ConstantScalar(0U)) &&
      Match(clamp_upper_bound, m::ConstantScalar())) {
    int64_t upper_bound = Cast<HloConstantInstruction>(clamp_upper_bound)
                              ->literal()
                              .GetFirstElement<uint32_t>();
    const HloModuleConfig& config = clamp->GetModule()->config();
    int64_t runtime_bound = Match(to_clamp, m::PartitionId())
                                ? config.num_partitions()
                                : config.replica_count();

    // If num_partitions or replica_count is 1, treat the runtime bound as
    // unknown. Otherwise, since pid/rid < runtime_bound, the
    // clamp(0, pid/rid, upper_bound) is redundant whenever
    // runtime_bound <= upper_bound + 1.
    if (runtime_bound != 1 && runtime_bound <= upper_bound + 1) {
      return ReplaceInstruction(clamp, to_clamp);
    }
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs))));
  // LHS*1 => LHS
  VLOG(10) << "trying transform [LHS*1 => LHS]: " << multiply->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(multiply, lhs)) {
    return OkStatus();
  }
  // 1*RHS => RHS
  VLOG(10) << "trying transform [1*RHS => RHS]: " << multiply->ToString();
  if (IsAll(lhs, 1) && ReplaceInstructionIfCompatible(multiply, rhs)) {
    return OkStatus();
  }

  // 0*RHS => 0. This only applies to integral types, so that NaN semantics
  // are preserved for floating-point values (0 * NaN is NaN, not 0).
  if (IsAll(lhs, 0) &&
      primitive_util::IsIntegralType(multiply->shape().element_type()) &&
      ReplaceInstructionIfCompatible(multiply, lhs)) {
    return OkStatus();
  }
  // LHS*0 => 0
  if (IsAll(rhs, 0) &&
      primitive_util::IsIntegralType(multiply->shape().element_type()) &&
      ReplaceInstructionIfCompatible(multiply, rhs)) {
    return OkStatus();
  }

  {
    HloInstruction* abs_operand;
    if (lhs == rhs && Match(lhs, m::Abs(m::Op(&abs_operand))) &&
        !ShapeUtil::ElementIsComplex(abs_operand->shape())) {
      TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(0, abs_operand));
      TF_RETURN_IF_ERROR(multiply->ReplaceOperandWith(1, abs_operand));
      MarkAsChanged();
      return OkStatus();
    }
  }

  {
    HloInstruction *convert_operand, *operand;
    // Mul(Convert(Pred), operand) => select(pred, operand, 0)
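    // e.g. x * convert<f32>(p) with p : pred becomes
    // select(p, x, broadcast(0)), avoiding both the convert and the multiply.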
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::Op(&operand),
                  m::Convert(
                      m::Op(&convert_operand)
                          .WithShape(m::Shape().WithElementType(PRED)))))) {
      HloInstruction* zero_like_multiply =
          BroadcastZeros(computation_, multiply->shape().element_type(),
                         multiply->shape().dimensions());
      return ReplaceWithNewInstruction(
          multiply, HloInstruction::CreateTernary(
                        multiply->shape(), HloOpcode::kSelect, convert_operand,
                        operand, zero_like_multiply));
    }
  }

  {
    HloInstruction *a, *b, *c1, *c2;
    // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y),
    // constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            multiply->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply, HloInstruction::CreateBinary(
                        multiply->shape(), HloOpcode::kMultiply,
                        multiply->AddInstruction(HloInstruction::CreateBinary(
                            multiply->shape(), HloOpcode::kMultiply, a, b)),
                        product_of_constants));
    }
  }

  {
    HloInstruction *a, *c1, *c2;
    // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2)
    if (Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)),
                  m::Constant(&c2)))) {
      TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                          MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
      if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
          !ShapeUtil::IsScalar(multiply->shape())) {
        product_of_constants =
            multiply->AddInstruction(HloInstruction::CreateBroadcast(
                multiply->shape(), product_of_constants, {}));
      }

      return ReplaceWithNewInstruction(
          multiply,
          HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply,
                                       a, product_of_constants));
    }
  }

  {
    HloInstruction *a, *b, *constant, *op;
    // Mul(Mul(a, constant1), Broadcast(b)) =>
    // Mul(Broadcast(Mul(b, constant1)), a)
    if (Match(multiply,
              m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a),
                                                      m::Constant(&constant)),
                                  m::Op(&op))) ||
        Match(multiply,
              m::MultiplyAnyOrder(
                  m::MultiplyAnyOrder(m::NonConstant(&a),
                                      m::Broadcast(m::Constant(&constant))),
                  m::Op(&op)))) {
      // Check that the other side is a broadcast, and not of a constant.
      if (ShapeUtil::IsScalar(constant->shape()) &&
          Match(op, m::Broadcast(m::NonConstant()))) {
        auto dims = op->dimensions();
        b = op->mutable_operand(0);
        if (!ShapeUtil::IsScalar(b->shape())) {
          constant = multiply->AddInstruction(
              HloInstruction::CreateBroadcast(b->shape(), constant, {}));
        }

        auto new_mul = multiply->AddInstruction(HloInstruction::CreateBinary(
            b->shape(), HloOpcode::kMultiply, b, constant));

        return ReplaceWithNewInstruction(
            multiply,
            HloInstruction::CreateBinary(
                multiply->shape(), HloOpcode::kMultiply, a,
                multiply->AddInstruction(HloInstruction::CreateBroadcast(
                    multiply->shape(), new_mul, dims))));
      }
    }
  }

  VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
  HloInstruction *a, *c1, *c2;
  if (Match(multiply,
            m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)),
                        m::Constant(&c2))) ||
      Match(multiply,
            m::Multiply(
                m::Multiply(m::Op(&a), m::Broadcast(m::ConstantScalar(&c1))),
                m::Broadcast(m::ConstantScalar(&c2))))) {
    TF_ASSIGN_OR_RETURN(auto* product_of_constants,
                        MakeBinaryHlo(HloOpcode::kMultiply, c1, c2));
    if (ShapeUtil::IsScalar(product_of_constants->shape()) &&
        !ShapeUtil::IsScalar(multiply->shape())) {
      product_of_constants =
          multiply->AddInstruction(HloInstruction::CreateBroadcast(
              multiply->shape(), product_of_constants, {}));
    }
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a,
                                     product_of_constants));
  }

  VLOG(10) << "trying to transform exp(LHS) * exp(RHS) => exp(LHS+RHS) "
           << multiply->ToString();
  if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) {
    auto add = multiply->AddInstruction(HloInstruction::CreateBinary(
        multiply->shape(), HloOpcode::kAdd, lhs, rhs));
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add));
  }

  VLOG(10) << "trying transform [rsqrt(B) * rsqrt(B) => 1/B] "
           << multiply->ToString();
  HloInstruction* b;
  if (Match(multiply, m::Multiply(m::Rsqrt(m::Op(&b)), m::Rsqrt(m::Op(&b)))) &&
      IsPositive(b, options_)) {
    return ReplaceWithNewInstruction(
        multiply,
        HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kDivide,
                                     MakeScalarLike(b, 1), b));
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) {
  // negate(negate(x)) => x
  HloInstruction* x;
  if (Match(negate, m::Negate(m::Negate(m::Op(&x)))) &&
      ReplaceInstructionIfCompatible(negate, x)) {
    return OkStatus();
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleNot(HloInstruction* logical_not) {
  // not(not(x)) => x
  HloInstruction* x;
  if (Match(logical_not, m::Not(m::Not(m::Op(&x)))) &&
      ReplaceInstructionIfCompatible(logical_not, x)) {
    return OkStatus();
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleOr(HloInstruction* logical_or) {
  HloInstruction *lhs, *rhs;
  CHECK(Match(logical_or, m::Or(m::Op(&lhs), m::Op(&rhs))));

  // Simplify logical or
  if (ShapeUtil::HasPrimitiveType(lhs->shape(), xla::PRED) &&
      ShapeUtil::HasPrimitiveType(rhs->shape(), xla::PRED)) {
    // A || True => True
    VLOG(10) << "trying transform [A || True => True]: "
             << logical_or->ToString();
    if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(logical_or, rhs)) {
      return OkStatus();
    }
    // True || A => True
    VLOG(10) << "trying transform [True || A => True]: "
             << logical_or->ToString();
    if (IsAll(lhs, 1) && ReplaceInstructionIfCompatible(logical_or, lhs)) {
      return OkStatus();
    }
  }

  // A || False => A and A | 0 => A
  VLOG(10) << "trying transform [A || False => A]: " << logical_or->ToString();
  if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(logical_or, lhs)) {
    return OkStatus();
  }

  // False || A => A and 0 | A => A
  VLOG(10) << "trying transform [False || A => A]: " << logical_or->ToString();
  if (IsAll(lhs, 0) && ReplaceInstructionIfCompatible(logical_or, rhs)) {
    return OkStatus();
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) {
  // ln(exp(A)) => A
  VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
  HloInstruction *a, *b;
  if (Match(log, m::Log(m::Exp(m::Op(&a)))) &&
      ReplaceInstructionIfCompatible(log, a)) {
    return OkStatus();
  }

  // ln(pow(A,B)) => B*ln(abs(A))
  // or B*ln(A) if A is complex.
  if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) {
    auto abs_a = ShapeUtil::ElementIsComplex(a->shape())
                     ? a
                     : log->AddInstruction(HloInstruction::CreateUnary(
                           log->shape(), HloOpcode::kAbs, a));
    auto new_log = log->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, abs_a));
    auto non_zero_b =
        log->mutable_operand(0)->AddInstruction(HloInstruction::CreateBinary(
            log->shape(), HloOpcode::kMultiply, new_log, b));
    TF_ASSIGN_OR_RETURN(
        auto b_is_zero,
        MakeCompareHlo(Comparison::Direction::kEq, b, MakeScalarLike(b, 0.0)));
    simplifier_->UpdateLayout(b_is_zero->mutable_shape());
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateTernary(log->shape(), HloOpcode::kSelect,
                                           b_is_zero, MakeScalarLike(log, 0.0),
                                           non_zero_b));
  }

  if (Match(log, m::Log(m::Sqrt(m::Op(&a))))) {
    auto new_log = log->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, 0.5)));
  }

  if (Match(log, m::Log(m::Rsqrt(m::Op(&a))))) {
    auto new_log = log->AddInstruction(
        HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a));
    return ReplaceWithNewInstruction(
        log, HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply,
                                          new_log, MakeScalarLike(log, -0.5)));
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
    HloInstruction* get_tuple_element) {
  auto operand = get_tuple_element->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kTuple) {
    // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
    VLOG(10) << "trying transform "
             << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
             << get_tuple_element->ToString();
    if (ReplaceInstructionIfCompatible(
            get_tuple_element,
            operand->mutable_operand(get_tuple_element->tuple_index()))) {
      return OkStatus();
    }
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleOptimizationBarrier(
    HloInstruction* barrier) {
  if (!barrier->shape().IsTuple() ||
      barrier == computation_->root_instruction()) {
    return OkStatus();
  }

  // The goal of this transformation is to enable DCE on unused tuple elements
  // of an optimization barrier operand. To do this safely, every user of the
  // barrier must be a get-tuple-element, none of those users may read the
  // element being removed, and the only user of the corresponding operand
  // element should be the tuple instruction feeding the barrier.
  // Additionally, if the operand is itself a tuple-producing instruction, it
  // should also be safe to create a sub-tuple of only the used components, to
  // enable module-level DCE.
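  // For example, barrier(tuple(a, b)) where only get-tuple-element(..., 1) is
  // used becomes barrier(tuple(b)), and that user's tuple index is remapped
  // from 1 to 0 via the index_map built below.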
  std::vector<bool> used_elements(barrier->shape().tuple_shapes_size());
  bool has_non_gte_use = false;
  for (auto use : barrier->users()) {
    if (use->opcode() != HloOpcode::kGetTupleElement) {
      has_non_gte_use = true;
      break;
    }
    used_elements[use->tuple_index()] = true;
  }

  HloInstruction* operand = barrier->mutable_operand(0);
  if (operand->opcode() == HloOpcode::kTuple) {
    for (int64_t i = 0; i < operand->operand_count(); ++i) {
      if (used_elements[i]) {
        continue;
      }
      if (operand->operand(i)->user_count() > 1 ||
          operand->operand(i) == computation_->root_instruction()) {
        used_elements[i] = true;
      }
    }
  }

  if (has_non_gte_use || !absl::c_linear_search(used_elements, false)) {
    return OkStatus();
  }

  MarkAsChanged();
  std::vector<int64_t> index_map(used_elements.size(), -1);
  std::vector<HloInstruction*> operands;
  int64_t current_index = 0;
  for (int64_t element = 0; element < used_elements.size(); ++element) {
    if (!used_elements[element]) {
      continue;
    }
    index_map[element] = current_index++;
    if (operand->opcode() == HloOpcode::kTuple) {
      operands.push_back(operand->mutable_operand(element));
    } else {
      operands.push_back(barrier->AddInstruction(
          HloInstruction::CreateGetTupleElement(operand, element)));
    }
  }

  HloInstruction* new_operand =
      operand->AddInstruction(HloInstruction::CreateTuple(operands));
  TF_RETURN_IF_ERROR(barrier->ReplaceOperandWithDifferentShape(0, new_operand));
  *barrier->mutable_shape() = new_operand->shape();
  for (auto use : barrier->users()) {
    CHECK_EQ(use->opcode(), HloOpcode::kGetTupleElement);
    use->set_tuple_index(index_map[use->tuple_index()]);
  }
  return OkStatus();
}

namespace {

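// Convenience wrapper around ShapeUtil::ReshapeLeavesDimensionsUnmodified for
// a reshape instruction: given dimension indices into the reshape's operand,
// returns the corresponding output dimension indices if the reshape leaves
// those dimensions unmodified, and std::nullopt otherwise.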
std::optional<std::vector<int64_t>> ReshapeLeavesDimensionsUnmodified(
    const HloInstruction* hlo, absl::Span<const int64_t> input_dim_indices) {
  CHECK_EQ(hlo->opcode(), HloOpcode::kReshape);
  return ShapeUtil::ReshapeLeavesDimensionsUnmodified(
      hlo->operand(0)->shape(), hlo->shape(), input_dim_indices);
}

// Returns true if the output of "instruction" is a permutation of the
// elements of "operand". Precondition: "operand" is an operand of
// "instruction".
bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
                                          HloInstruction* operand) {
  DCHECK(!instruction->OperandIndices(operand).empty());
  switch (instruction->opcode()) {
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kTranspose:
      return true;
    case HloOpcode::kSort:
      return (!instruction->shape().IsTuple());
    default:
      return false;
  }
}

// Returns true if the output of "instruction" is a subset of the elements of
// "operand". Precondition: "operand" is an operand of "instruction".
bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
                                     HloInstruction* operand) {
  const auto operand_indices = instruction->OperandIndices(operand);
  CHECK(!operand_indices.empty());
  if (operand_indices.size() != 1) {
    return false;
  }
  int64_t operand_index = operand_indices[0];
  switch (instruction->opcode()) {
    case HloOpcode::kSlice:
      CHECK_EQ(0, operand_index);
      return true;
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    default:
      return false;
  }
}

}  // namespace

Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
  HloInstruction* operand;
  CHECK(Match(broadcast, m::Broadcast(m::Op(&operand))));
  auto dims = *broadcast->mutable_dimensions();
  // A broadcast that neither permutes dimensions (its dimensions are sorted)
  // nor changes the number of elements is degenerate and can be replaced by a
  // reshape.
  if (std::is_sorted(dims.begin(), dims.end()) &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> reshape(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
  }

  // A degenerate broadcast that has the same input and output rank can be
  // converted into a transpose.
  if (broadcast->shape().rank() == operand->shape().rank() &&
      ShapeUtil::ElementsIn(broadcast->shape()) ==
          ShapeUtil::ElementsIn(operand->shape())) {
    VLOG(10) << "transform broadcast(X) -> transpose(X) where "
                "n(broadcast(X)) == n(X)";
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateTranspose(broadcast->shape(), operand, dims));
  }

  // A broadcast whose operand is a reshape that merely inserts 1-sized
  // dimensions can broadcast directly from the reshape's input, eliding the
  // reshape.
  {
    std::optional<ShapeUtil::ShapeEqualityDescriptor> reshape_degenerate =
        operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
    if (reshape_degenerate.has_value() &&
        reshape_degenerate->deleted_dimensions.empty()) {
      absl::c_reverse(reshape_degenerate->inserted_dimensions);
      for (auto inserted_index : reshape_degenerate->inserted_dimensions) {
        dims.erase(dims.begin() + inserted_index);
      }
      return ReplaceWithNewInstruction(
          broadcast,
          HloInstruction::CreateBroadcast(broadcast->shape(),
                                          operand->mutable_operand(0), dims));
    }
  }

  if (options_.enable_sink_broadcast()) {
    // A Broadcast that feeds a unary element-wise operation can sink the
    // broadcast after the unary element-wise operation.
    TF_ASSIGN_OR_RETURN(
        bool sink_succeeded,
        TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast));
    if (sink_succeeded) {
      MarkAsChanged();
      return OkStatus();
    }
  }

  // A scalar broadcast feeding an instruction which only permutes (reshape,
  // transpose, sort, reverse) or selects a subset of operand elements (slice,
  // dynamic slice) can be replaced with a broadcast directly to the output
  // shape of the instruction.
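  // For instance, slice(broadcast(s)) for a scalar s is itself just a
  // broadcast of s, so the user below can be rewritten as a broadcast of s to
  // the user's own shape.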
  if (ShapeUtil::IsScalar(operand->shape())) {
    for (HloInstruction* user : broadcast->users()) {
      // Skip if the broadcast user has no uses itself.
      if (user->IsDead()) {
        continue;
      }
      if (OutputIsPermutationOfOperandElements(user, broadcast) ||
          OutputIsSubsetOfOperandElements(user, broadcast)) {
        VLOG(10) << "transform permuting/subset of a scalar broadcast into "
                 << "a single broadcast";
        HloInstruction* new_broadcast = user->AddInstruction(
            HloInstruction::CreateBroadcast(user->shape(), operand, {}));
        // Use HloInstruction::ReplaceAllUsesWith instead of
        // HloComputation::ReplaceWithNewInstruction because we are replacing
        // an instruction other than the visited instruction.
        MarkAsChanged();
        return user->ReplaceAllUsesWith(new_broadcast);
      }
    }
    return OkStatus();
  }

  // broadcast(iota) -> iota.
  if (operand->opcode() == HloOpcode::kIota) {
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateIota(
            broadcast->shape(),
            dims[Cast<HloIotaInstruction>(operand)->iota_dimension()]));
  }

  // Merge two consecutive broadcasts into a single one.
  if (operand->opcode() == HloOpcode::kBroadcast) {
    std::vector<int64_t> new_dimensions;
    new_dimensions.reserve(operand->dimensions().size());
    for (auto dim : operand->dimensions()) {
      new_dimensions.push_back(dims[dim]);
    }
    return ReplaceWithNewInstruction(
        broadcast,
        HloInstruction::CreateBroadcast(
            broadcast->shape(), operand->mutable_operand(0), new_dimensions));
  }
  if (options_.is_layout_sensitive()) {
    return OkStatus();
  }
  if (ShapeUtil::HasDegenerateDimensions(operand->shape())) {
    auto new_operand =
        operand->parent()->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::DropDegenerateDimensions(operand->shape()), operand));
    std::vector<int64_t> new_dims;
    new_dims.reserve(new_operand->shape().rank());
    for (int64_t i = 0; i < operand->shape().rank(); ++i) {
      if (operand->shape().dimensions(i) != 1) {
        new_dims.push_back(dims[i]);
      }
    }
    return ReplaceWithNewInstruction(
        broadcast, HloInstruction::CreateBroadcast(broadcast->shape(),
                                                   new_operand, new_dims));
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
  HloInstruction* lhs;
  HloInstruction* rhs;
  CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs))));

  if (Cast<HloCompareInstruction>(compare)->type() ==
      Comparison::Type::kUnsigned) {
    // X u< 0 -> false
    if (compare->comparison_direction() == ComparisonDirection::kLt &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // X u>= 0 -> true
    if (compare->comparison_direction() == ComparisonDirection::kGe &&
        IsAll(rhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
    // 0 u> X -> false
    if (compare->comparison_direction() == ComparisonDirection::kGt &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, false));
    }
    // 0 u<= X -> true
    if (compare->comparison_direction() == ComparisonDirection::kLe &&
        IsAll(lhs, 0)) {
      return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }

  if (compare->comparison_direction() == ComparisonDirection::kLt &&
      lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGt &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, false));
  } else if (compare->comparison_direction() == ComparisonDirection::kGe &&
             lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  } else if (compare->comparison_direction() == ComparisonDirection::kLe &&
             IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) {
    return ReplaceInstruction(compare, MakeScalarLike(compare, true));
  }
  if (lhs == rhs &&
      primitive_util::IsIntegralType(lhs->shape().element_type())) {
    switch (compare->comparison_direction()) {
      case ComparisonDirection::kGt:
      case ComparisonDirection::kLt:
      case ComparisonDirection::kNe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, false));
      case ComparisonDirection::kEq:
      case ComparisonDirection::kGe:
      case ComparisonDirection::kLe:
        return ReplaceInstruction(compare, MakeScalarLike(compare, true));
    }
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
  PrimitiveType src_type = convert->operand(0)->shape().element_type();
  PrimitiveType dest_type = convert->shape().element_type();
  // A conversion to the same element type as the operand is a nop and can be
  // removed. A conversion of a constant can be simplified by making a new
  // constant.
  if (src_type == dest_type) {
    return ReplaceInstruction(convert, convert->mutable_operand(0));
  }

  // Eliminate a convert pair if it is a no-op. The following are a few
  // example cases that are being handled:
  // 1. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
  //    $TYPE2 and convert(A, $TYPE1) is an upcast.
  // 2. convert(convert(A, $TYPE1), $TYPE2) is simplified to A if A is of
  //    $TYPE2 and convert(A, $TYPE1) is an upcast and is an integral
  //    conversion from unsigned to signed (only signed to unsigned conversion
  //    is NOT allowed).
  // 3. Tuple(convert(A, $TYPE1), floor(convert(convert(A, $TYPE1), $TYPE2)),
  //    convert(convert(A, $TYPE1), $TYPE2)) is simplified to
  //    Tuple(convert(A, $TYPE1), floor(A), A) -> a case where the first
  //    convert has a fan-out.
  if (IsConvertPairNoOp(convert)) {
    return ReplaceInstruction(convert,
                              convert->mutable_operand(0)->mutable_operand(0));
  }
  return OkStatus();
}

// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
  HloInstruction *c0, *c1;
  if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) &&
      c0 == c1) {
    return ReplaceInstruction(complex, c0);
  }
  return OkStatus();
}

// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
  HloInstruction* op;
  if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) {
    return ReplaceInstruction(real, op);
  }
  return OkStatus();
}

// Imag(Complex(r, i)) -> i
Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
  HloInstruction* op;
  if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) {
    return ReplaceInstruction(imag, op);
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleIota(HloInstruction* instruction) {
  // iota -> zero if the iota dimension never produces an element other than
  // zero.
  auto* iota = Cast<HloIotaInstruction>(instruction);
  if (iota->shape().dimensions(iota->iota_dimension()) <= 1) {
    return ReplaceInstruction(iota, MakeScalarLike(iota, 0));
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
  if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
    return ReplaceWithNewInstruction(
        pad, HloInstruction::CreateBroadcast(pad->shape(),
                                             pad->mutable_operand(1), {}));
  }

  // Interior padding on size-1 dimensions has no effect; clearing it makes
  // other simplifications possible once no interior padding remains.
  if (HasInteriorPadding(pad->padding_config())) {
    PaddingConfig padding_config = pad->padding_config();
    bool cleared_interior_padding = false;
    for (int64_t i = 0; i < pad->shape().rank(); ++i) {
      if (padding_config.dimensions(i).interior_padding() > 0 &&
          pad->operand(0)->shape().dimensions(i) == 1) {
        cleared_interior_padding = true;
        padding_config.mutable_dimensions(i)->set_interior_padding(0);
      }
    }
    if (cleared_interior_padding) {
      return ReplaceWithNewInstruction(
          pad,
          HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
                                    pad->mutable_operand(1), padding_config));
    }
  }

  // Eliminate nop pads (padding all zero), and replace a pad with negative
  // padding with a pad with non-negative padding followed by a slice.
  bool all_zero = true;
  bool has_negative = false;
  // Used to possibly split off the unchanged padding dimensions.
  std::vector<int64_t> padding_dimensions;
  int64_t dimension_index = 0;
  for (auto& padding_dimension : pad->padding_config().dimensions()) {
    if (padding_dimension.edge_padding_low() < 0 ||
        padding_dimension.edge_padding_high() < 0) {
      has_negative = true;
    }
    if (padding_dimension.edge_padding_low() != 0 ||
        padding_dimension.edge_padding_high() != 0) {
      all_zero = false;
      padding_dimensions.push_back(dimension_index);
    } else if (padding_dimension.interior_padding()) {
      padding_dimensions.push_back(dimension_index);
    }
    dimension_index++;
  }

  if (all_zero) {
    if (ReplaceInstructionIfCompatible(pad, pad->mutable_operand(0))) {
      return OkStatus();
    }
  }

  // The context of this optimization can be found at b/163617402.
  // It tries to capture the case of pad(broadcast(x)), where
  // x->shape().dimensions(), or broadcast(x)->dimensions(), is
  // a subset of the padded dimensions in pad->config(),
  // and the padded dimensions in pad->config() are in turn a strict
  // subset of broadcast->shape().dimensions(). The combined op can be
  // rewritten to broadcast2(pad(broadcast1(x))), where broadcast1 extends
  // x with the dimensions that need to be padded, and broadcast2 extends
  // the result of the padding to the full dimensions.
  // TODO(qyi): For future extensions: the condition that
  // broadcast(x)->dimensions() be a subset of the padded dimensions in
  // pad->config() does not have to be strictly required; it is only there to
  // make the calculation for the optimization easier, so it is required by
  // the current implementation. Only the second condition, between the padded
  // dimensions and the dimensions of the final shape, has to be enforced for
  // the optimization to make sense. If the first constraint is removed, the
  // shape calculations across the implementation need to be re-adjusted.
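  // Illustrative example (hypothetical shapes): with x = f32[2],
  // broadcast(x) = f32[3,2,5] (dimensions={1}), and a pad that pads only
  // dimension 1 by one element on each side, this rewrites to
  // broadcast(f32[3,4,5], pad(f32[2] x -> f32[4]), dimensions={1}).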
  auto pad_dims = padding_dimensions.size();
  if (pad_dims < dimension_index &&
      pad->operand(0)->opcode() == HloOpcode::kBroadcast &&
      pad->operand(0)->user_count() == 1 &&
      pad->operand(0)->operand(0)->shape().rank() <= pad_dims) {
    // Check that the broadcast operand's dimensions are a subset of
    // padding_dimensions. If not, skip the optimization.
    bool opt_is_valid = true;
    std::vector<int64_t> broadcast_dimensions;
    HloBroadcastInstruction* broadcast =
        static_cast<HloBroadcastInstruction*>(pad->mutable_operand(0));
    for (auto broadcast_index : broadcast->dimensions()) {
      bool found = false;
      for (int i = 0; i < pad_dims; ++i) {
        if (broadcast_index == padding_dimensions[i]) {
          broadcast_dimensions.push_back(i);
          found = true;
          break;
        }
      }
      if (!found) {
        opt_is_valid = false;
        break;
      }
    }
    if (opt_is_valid) {
      auto pad_shape = pad->shape();
      auto broadcast_shape = broadcast->shape();
      auto pad_shape1 = pad_shape;
      auto broadcast_shape1 = broadcast_shape;
      PaddingConfig pad_config;
      for (int i = padding_dimensions.size() - 1; i >= 0; --i) {
        int64_t j = padding_dimensions[i];
        while (--dimension_index > j) {
          broadcast_shape1.DeleteDimension(dimension_index);
          pad_shape1.DeleteDimension(dimension_index);
        }
      }
      while (--dimension_index >= 0) {
        broadcast_shape1.DeleteDimension(dimension_index);
        pad_shape1.DeleteDimension(dimension_index);
      }
      for (auto dimension_to_pad : padding_dimensions) {
        auto dimension = pad_config.add_dimensions();
        *dimension = pad->padding_config().dimensions(dimension_to_pad);
      }
      *broadcast->mutable_shape() = broadcast_shape1;
      *broadcast->mutable_dimensions() = broadcast_dimensions;
      simplifier_->UpdateLayout(broadcast->mutable_shape());
      auto pad2 = pad->AddInstruction(pad->CloneWithNewShape(pad_shape1));
      *pad2->mutable_padding_config() = pad_config;
      simplifier_->UpdateLayout(pad2->mutable_shape());
      auto broadcast2 = pad->AddInstruction(
          HloInstruction::CreateBroadcast(pad_shape, pad2, padding_dimensions));
      return ReplaceInstruction(pad, broadcast2);
    }
  }

  if (has_negative && options_.enable_negative_padding_replacement()) {
    // Pad has negative padding. Replace with a pad with the non-negative
    // padding followed by a slice which effectively performs the negative
    // padding.
    // TODO(b/34628603): Add support for negative padding in the backends, or
    // change kPad semantics to disallow negative padding and use slice
    // instead.

    // First construct the padding config with non-negative entries and
    // compute the shape of this new pad instruction.
    PaddingConfig nonzero_padding = pad->padding_config();
    for (int i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      PaddingConfig::PaddingConfigDimension* padding_dimension =
          nonzero_padding.mutable_dimensions(i);
      // Set negative padding to zero.
      if (padding_dimension->edge_padding_low() < 0) {
        padding_dimension->set_edge_padding_low(0);
      }
      if (padding_dimension->edge_padding_high() < 0) {
        padding_dimension->set_edge_padding_high(0);
      }
    }

    TF_ASSIGN_OR_RETURN(HloInstruction * nonzero_pad,
                        MakePadHlo(pad->mutable_operand(0),
                                   pad->mutable_operand(1), nonzero_padding));
    // Copy the layout from the original pad instruction. The new pad and the
    // slice instruction should all have the same layout.
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), nonzero_pad->mutable_shape()));
    simplifier_->UpdateLayout(nonzero_pad->mutable_shape());

    // Second, construct the slice instruction to perform the negative
    // padding.
    std::vector<int64_t> start_indices;
    std::vector<int64_t> end_indices;
    std::vector<int64_t> strides;
    for (int64_t i = 0; i < pad->padding_config().dimensions_size(); ++i) {
      const PaddingConfig::PaddingConfigDimension& padding_dimension =
          pad->padding_config().dimensions(i);
      int64_t start = 0;
      if (padding_dimension.edge_padding_low() < 0) {
        start = -1 * padding_dimension.edge_padding_low();
      }
      int64_t end = nonzero_pad->shape().dimensions(i);
      if (padding_dimension.edge_padding_high() < 0) {
        end += padding_dimension.edge_padding_high();
      }
      start_indices.push_back(start);
      end_indices.push_back(end);
      strides.push_back(1);
    }

    TF_ASSIGN_OR_RETURN(
        HloInstruction * slice,
        MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides));
    TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
        pad->shape(), slice->mutable_shape()));
    simplifier_->UpdateLayout(slice->mutable_shape());

    // Verify that the slice shape matches the pad shape.
    auto equal = Shape::Equal();
    if (!options_.is_layout_sensitive()) {
      equal.IgnoreTilesInLayout();
    }
    TF_RET_CHECK(equal(slice->shape(), pad->shape()));

    return ReplaceInstruction(pad, slice);
  }

  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) {
  VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
  HloInstruction *lhs, *rhs;
  CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs))));
  if (IsAll(rhs, 0)) {
    return ReplaceInstruction(power, MakeScalarLike(power, 1));
  }

  VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
  if (IsAll(rhs, 1) && ReplaceInstructionIfCompatible(power, lhs)) {
    return OkStatus();
  }

  // pow(exp(A),B) => exp(A*B)
  HloInstruction *a, *b;
  if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) {
    auto a_times_b = power->AddInstruction(HloInstruction::CreateBinary(
        power->shape(), HloOpcode::kMultiply, a, b));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp,
                                           a_times_b));
  }

  VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
  if (IsAll(rhs, 2)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, lhs));
  }

  // Pow(A, 3) is used in GELU.
  VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString();
  if (IsAll(rhs, 3)) {
    HloInstruction* tmp = power->AddInstruction(HloInstruction::CreateBinary(
        power->shape(), HloOpcode::kMultiply, lhs, lhs));
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(),
                                            HloOpcode::kMultiply, lhs, tmp));
  }

  VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
  if (IsAll(rhs, -1)) {
    return ReplaceWithNewInstruction(
        power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
                                            MakeScalarLike(lhs, 1), lhs));
  }

  return OkStatus();
}

StatusOr<bool>
AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
    HloInstruction* broadcast) {
  TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast);
  bool changed = false;
  if (ShapeUtil::IsScalar(broadcast->shape())) {
    return false;
  }
  HloInstruction* operand = broadcast->mutable_operand(0);
  auto is_scalar_broadcast = [](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::IsScalar(instruction->operand(0)->shape());
  };
  auto is_equal_broadcast = [operand,
                             broadcast](const HloInstruction* instruction) {
    return instruction->opcode() == HloOpcode::kBroadcast &&
           ShapeUtil::Equal(operand->shape(),
                            instruction->operand(0)->shape()) &&
           broadcast->dimensions() == instruction->dimensions();
  };
  auto is_compatible_broadcast = [&](const HloInstruction* instruction) {
    return is_scalar_broadcast(instruction) || is_equal_broadcast(instruction);
  };
  for (HloInstruction* user : broadcast->users()) {
    if (user->IsDead()) {
      continue;
    }
    // Do not move reshapes or broadcasts past copies since the shape the copy
    // will operate on will change.
    if (user->opcode() == HloOpcode::kCopy) {
      continue;
    }
    // Do not change the shape of fusion nodes in case there are multiple
    // shapes inside the fusion node already.
    if (user->opcode() == HloOpcode::kFusion) {
      continue;
    }
    if (!user->IsElementwise()) {
      continue;
    }

    // Check if all the operands of the user are compatible broadcasts for
    // sinking. (They are either scalar broadcasts or broadcasts casting
    // from/to the same shape/dimensions.)
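    // For example, add(broadcast(x), broadcast(y)) with matching broadcast
    // dimensions can become broadcast(add(x, y)), performing the add on the
    // smaller pre-broadcast shape.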
    int64_t compatible_broadcast_count = 0;
    int64_t broadcast_use_count = 0;
    for (HloInstruction* user_operand : user->operands()) {
      if (is_compatible_broadcast(user_operand)) {
        ++compatible_broadcast_count;
      } else if (broadcast == user_operand) {
        ++broadcast_use_count;
      }
    }
    if (compatible_broadcast_count + broadcast_use_count !=
        user->operand_count()) {
      continue;
    }
    std::vector<HloInstruction*> new_operands;
    new_operands.reserve(user->operand_count());

    Shape changed_shape;
    for (HloInstruction* user_operand : user->operands()) {
      // If this is a broadcast operand that is not our original broadcast
      // input to this function then we might need to change the input.
      if (is_compatible_broadcast(user_operand)) {
        // If this is a broadcast from a scalar value, rewrite it as a
        // broadcast from the scalar to the new shape enforced by the other
        // broadcast operands.
        if (is_scalar_broadcast(user_operand)) {
          changed_shape = ShapeUtil::ChangeElementType(
              operand->shape(), user_operand->shape().element_type());
          simplifier_->UpdateLayout(&changed_shape);
          new_operands.push_back(
              user_operand->AddInstruction(HloInstruction::CreateBroadcast(
                  changed_shape, user_operand->mutable_operand(0), {})));
        } else {
          // For the non-scalar broadcasts we guarantee that the shape of the
          // operand of the broadcast is already a compatible shape.
          new_operands.push_back(user_operand->mutable_operand(0));
        }
      } else {
        CHECK_EQ(broadcast, user_operand);
        new_operands.push_back(operand);
      }
    }
    VLOG(4) << "Sinking broadcast after user:";
    VLOG(4) << "  old broadcast: " << broadcast->ToString();
    VLOG(4) << "  old user: " << user->ToString();
    changed_shape = ShapeUtil::ChangeElementType(operand->shape(),
                                                 user->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    HloInstruction* new_user = user->AddInstruction(
        user->CloneWithNewOperands(changed_shape, new_operands));
    VLOG(4) << "  new user: " << new_user->ToString();
    HloInstruction* new_broadcast =
        broadcast->AddInstruction(HloInstruction::CreateBroadcast(
            user->shape(), new_user, broadcast->dimensions()));
    VLOG(4) << "  new broadcast: " << new_broadcast->ToString();
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast));
    changed = true;
  }
  return changed;
}

namespace {
template <typename T>
std::unique_ptr<HloInstruction> TryRemainderToAnd(
    HloInstruction* remainder, HloComputation* computation,
    AlgebraicSimplifier* simplifier) {
  HloInstruction *a, *b, *c;
  CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

  if (ShapeUtil::ElementIsIntegral(remainder->shape()) &&
      !Match(b, m::ConstantEffectiveScalar(&c)) &&
      !Match(b, m::Broadcast(m::ConstantEffectiveScalar(&c)))) {
    return nullptr;
  }

  if (ShapeUtil::ElementIsSigned(remainder->shape())) {
    int64_t b_value = c->literal().GetFirstElement<T>();
    if (b_value > 0 && absl::has_single_bit(static_cast<uint64_t>(b_value))) {
      // Handle negative dividends by negating the result of the division.
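      // Worked example (b = 8): for a = -10, the mask below computes
      // |a| & 7 = 2, and the final select negates it to -2, matching the
      // sign-preserving semantics of kRemainder (-10 % 8 == -2).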
      HloInstruction* zero_like_a = BroadcastZeros(
          computation, a->shape().element_type(), a->shape().dimensions());

      Shape compare_shape = ShapeUtil::ChangeElementType(a->shape(), PRED);
      simplifier->UpdateLayout(&compare_shape);
      auto* dividend_is_negative =
          remainder->AddInstruction(HloInstruction::CreateCompare(
              compare_shape, a, zero_like_a, ComparisonDirection::kLt));

      auto* negated_dividend = remainder->AddInstruction(
          HloInstruction::CreateUnary(a->shape(), HloOpcode::kNegate, a));

      auto* abs_dividend =
          remainder->AddInstruction(HloInstruction::CreateTernary(
              a->shape(), HloOpcode::kSelect, dividend_is_negative,
              negated_dividend, a));

      auto* quotient = remainder->AddInstruction(HloInstruction::CreateBinary(
          remainder->shape(), HloOpcode::kAnd, abs_dividend,
          MakeScalarLike(abs_dividend, b_value - 1)));

      auto* negated_quotient =
          remainder->AddInstruction(HloInstruction::CreateUnary(
              quotient->shape(), HloOpcode::kNegate, quotient));

      return HloInstruction::CreateTernary(
          remainder->shape(), HloOpcode::kSelect, dividend_is_negative,
          negated_quotient, quotient);
    }
  } else {
    uint64_t b_value = c->literal().GetFirstElement<T>();
    if (absl::has_single_bit(b_value)) {
      HloInstruction* mask_amount =
          remainder->AddInstruction(simplifier->CreateConstantWithLayoutUpdated(
              LiteralUtil::CreateR0<T>(b_value - 1)));
      if (!ShapeUtil::IsScalar(b->shape())) {
        mask_amount = remainder->AddInstruction(
            HloInstruction::CreateBroadcast(b->shape(), mask_amount, {}));
      }
      return HloInstruction::CreateBinary(remainder->shape(), HloOpcode::kAnd,
                                          a, mask_amount);
    }
  }
  return nullptr;
}
}  // namespace

Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
  HloInstruction *a, *b;
  CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));

  // (A % B) % B == A % B.
  if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
    return ReplaceInstruction(remainder, a);
  }

  // A % B => A & (B - 1) if B is a power of 2.
  switch (remainder->shape().element_type()) {
    case S8:
      if (std::unique_ptr<HloInstruction> shift =
              TryRemainderToAnd<int8_t>(remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S16:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<int16_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S32:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<int32_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case S64:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<int64_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U8:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint8_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U16:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint16_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U32:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint32_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    case U64:
      if (std::unique_ptr<HloInstruction> shift = TryRemainderToAnd<uint64_t>(
              remainder, computation_, simplifier_)) {
        return ReplaceWithNewInstruction(remainder, std::move(shift));
      }
      break;
    default:
      break;
  }

  // If M < N, then {0, ..., M} % N ==> {0, ..., M}.
  //
  // Currently this only covers the case when N is a broadcasted constant
  // scalar. We could also cover the case when N is a non-broadcasted constant
  // with the same value repeated.
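  // For example, remainder(iota(s32[8]), broadcast(16)) is just iota(s32[8]),
  // since every iota element lies in [0, 8) and 8 <= 16.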
  HloInstruction* iota;
  HloInstruction* divisor;
  if (Match(remainder,
            m::Remainder(m::Iota(&iota),
                         m::Broadcast(m::ConstantEffectiveScalar(&divisor))))) {
    // The iota counts {0, ..., iota_upper_bound - 1}. (Actually this is
    // conservative; the iota may overflow and count up to a smaller value than
    // this. But that's OK for our purposes here.)
    int64_t iota_upper_bound = iota->shape().dimensions(
        Cast<HloIotaInstruction>(iota)->iota_dimension());
    std::optional<int64_t> divisor_val = divisor->literal().GetIntegralAsS64(
        std::vector<int64_t>(divisor->shape().dimensions_size(), 0));
    if (divisor_val && *divisor_val >= iota_upper_bound) {
      return ReplaceInstruction(remainder, iota);
    }
  }
4207
4208 // (X + N) % N = X % N, so long as X + N does not overflow.
4209 //
4210 // We don't have range tracking in XLA that would let us know whether X + N
4211 // overflows, so for now we only do this simplification when X is an iota. We
4212 // could add other operations where it's easy to see a range, such as
4213 // remainder, convert, etc., though at some point we'd probably want a
4214 // range-tracking analysis.
4215 HloInstruction* bcast;
4216 HloInstruction* addend;
4217 if (Match(
4218 remainder,
4219 m::Remainder(
4220 m::AddAnyOrder(m::Iota(&iota),
4221 m::Broadcast(m::ConstantEffectiveScalar(&addend))),
4222 m::Broadcast(&bcast, m::ConstantEffectiveScalar(&divisor)))) &&
4223 addend == divisor) {
4224     // The iota counts {0, ..., iota_upper_bound - 1}, with the same caveat above
4225 // that iota_upper_bound is conservative, and the true upper bound may be
4226 // smaller.
4227 int64_t iota_upper_bound = iota->shape().dimensions(
4228 Cast<HloIotaInstruction>(iota)->iota_dimension());
4229 std::optional<int64_t> divisor_val = divisor->literal().GetIntegralAsS64(
4230         std::vector<int64_t>(divisor->shape().dimensions_size(), 0));
4231 if (divisor_val) {
4232 // Check whether divisor_val + iota_upper_bound - 1 overflows.
4233 std::optional<int64_t> max_val =
4234 OverflowSafeAdd(*divisor_val, iota_upper_bound);
4235 if (max_val.has_value() &&
4236 FitsInIntegralType(*max_val, iota->shape().element_type())) {
4237 return ReplaceWithNewInstruction(
4238 remainder,
4239 HloInstruction::CreateBinary(remainder->shape(),
4240 HloOpcode::kRemainder, iota, bcast));
4241 }
4242 }
4243 }
4244
4245 return OkStatus();
4246 }
4247
4248 Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
4249 auto operand = reshape->mutable_operand(0);
4250
4251 // Reshape directly to empty constant if the shape contains zero-element
4252 // dimension.
4253 if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
4254 // If the instruction doesn't have a layout, use a default layout for
4255 // the literal result.
4256 Shape reshaped_shape = reshape->shape();
4257 if (!LayoutUtil::HasLayout(reshaped_shape)) {
4258 LayoutUtil::SetToDefaultLayout(&reshaped_shape);
4259 }
4260 auto empty_constant = simplifier_->CreateConstantWithLayoutUpdated(
4261 Literal::CreateFromShape(reshaped_shape));
4262
4263 return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
4264 }
4265
4266 // Delete no-op reshapes, i.e. where shape = operand shape.
4267 if (SameShape(reshape, operand)) {
4268 VLOG(3) << "deleting no-op reshape";
4269 return ReplaceInstruction(reshape, operand);
4270 }
4271
4272 // Merge reshapes.
4273 if (HloOpcode::kReshape == operand->opcode()) {
4274 return ReplaceWithNewInstruction(
4275 reshape, HloInstruction::CreateReshape(reshape->shape(),
4276 operand->mutable_operand(0)));
4277 }
4278
4279 if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
4280 *operand->mutable_shape() = reshape->shape();
4281 return ReplaceInstruction(reshape, operand);
4282 }
4283
4284 if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
4285 auto opt_dims = ReshapeLeavesDimensionsUnmodified(
4286 reshape, reshape->operand(0)->dimensions());
4287 if (opt_dims.has_value()) {
4288 return ReplaceWithNewInstruction(
4289 reshape,
4290 HloInstruction::CreateBroadcast(
4291 reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
4292 *opt_dims));
4293 }
4294 }
4295
4296 // reshape(iota) -> iota or a mixed radix calculation like
4297 // s32[2,3,4] reshape(s32[24] iota()) to
4298 // add(
4299 // add(s32[2,3,4] iota() iota_dimension=2,
4300 // 4 * s32[2,3,4] iota() iota_dimension=1),
4301 // 12 * s32[2,3,4] iota() iota_dimension=0).
4302 if (operand->opcode() == HloOpcode::kIota) {
4303 auto* iota = Cast<HloIotaInstruction>(operand);
4304 auto common_factors =
4305 CommonFactors(reshape->operand(0)->shape().dimensions(),
4306 reshape->shape().dimensions());
4307 auto iota_dim = absl::c_find_if(
4308 common_factors, [&](const std::pair<int64_t, int64_t>& dim_pair) {
4309 return dim_pair.first == iota->iota_dimension() &&
4310 reshape->shape().dimensions(dim_pair.second) > 1;
4311 });
4312 auto next_dim = absl::c_find_if(
4313 common_factors, [&](const std::pair<int64_t, int64_t>& dim_pair) {
4314 return dim_pair.first == iota->iota_dimension() + 1;
4315 });
4316 if (iota_dim != common_factors.end() && next_dim != common_factors.end()) {
4317 int64_t multiplier = 1;
4318 HloInstruction* new_reshape = nullptr;
4319
4320 for (int64_t dim = (iota_dim + 1)->second - 1; dim >= iota_dim->second;
4321 --dim) {
4322 HloInstruction* new_iota = iota->AddInstruction(
4323 HloInstruction::CreateIota(reshape->shape(), dim));
4324 if (new_reshape) {
4325 new_reshape = reshape->AddInstruction(HloInstruction::CreateBinary(
4326 reshape->shape(), HloOpcode::kAdd, new_reshape,
4327 reshape->AddInstruction(HloInstruction::CreateBinary(
4328 reshape->shape(), HloOpcode::kMultiply, new_iota,
4329 MakeScalarLike(reshape, multiplier)))));
4330 } else {
4331 new_reshape = new_iota;
4332 }
4333 multiplier *= reshape->shape().dimensions(dim);
4334 }
4335 return ReplaceInstruction(reshape, new_reshape);
4336 }
4337 }
4338
4339 // Moves the reshape in reshape(dus(...), x, ...)) before dus so that it can
4340 // enable other optimizations, e.g., merging with broadcast, and sparse update
4341 // (add(x, dus(broadcast(0), y, ...)) -> dus(x, add(ds(x), y), ...)).
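  // As an illustrative sketch, for a reshape that merely drops a 1-sized
  // dimension:
  //   f32[16] reshape(dynamic-update-slice(f32[16,1] a, f32[4,1] u, i, c0))
  // becomes
  //   dynamic-update-slice(f32[16] reshape(a), f32[4] reshape(u), i)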
4342 if (!options_.is_layout_sensitive()) {
4343 HloInstruction* dus;
4344 HloInstruction* slice;
4345 std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
4346 reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
4347     // The 1-sized dimensions that the reshape inserts or deletes are 1-sized in
4348     // both the update slice and the dynamic-update-slice result.
4349 if (trivial_reshape.has_value() &&
4350 Match(reshape->mutable_operand(0),
4351 m::Op(&dus)
4352 .WithOpcode(HloOpcode::kDynamicUpdateSlice)
4353 .WithOperand(1, m::Op(&slice))) &&
4354 !dus->has_sharding() && !dus->operand(0)->has_sharding()) {
4355 auto new_operand = reshape->AddInstruction(HloInstruction::CreateReshape(
4356 reshape->shape(), dus->mutable_operand(0)));
4357 std::vector<int64_t> new_slice_shape;
4358 std::vector<HloInstruction*> new_dus_operands;
4359 new_dus_operands.push_back(new_operand);
4360 new_dus_operands.push_back(nullptr);
4361 auto zero = MakeScalarLike(dus->mutable_operand(2), 0);
4362 const Shape& old_slice_shape = dus->operand(1)->shape();
4363 for (int64_t i = 0; i <= old_slice_shape.rank(); ++i) {
4364 if (absl::c_linear_search(trivial_reshape->deleted_dimensions, i)) {
4365 continue;
4366 }
4367 while (absl::c_linear_search(trivial_reshape->inserted_dimensions,
4368 new_slice_shape.size())) {
4369 new_slice_shape.push_back(1);
4370 new_dus_operands.push_back(zero);
4371 }
4372 if (i < old_slice_shape.rank()) {
4373 new_slice_shape.push_back(old_slice_shape.dimensions(i));
4374 new_dus_operands.push_back(dus->mutable_operand(2 + i));
4375 }
4376 }
4377 auto new_slice = reshape->AddInstruction(HloInstruction::CreateReshape(
4378 ShapeUtil::MakeShape(old_slice_shape.element_type(), new_slice_shape),
4379 slice));
4380 new_dus_operands[1] = new_slice;
4381 auto new_dus =
4382 dus->CloneWithNewOperands(reshape->shape(), new_dus_operands);
4383 return ReplaceWithNewInstruction(reshape, std::move(new_dus));
4384 }
4385 }
4386
4387 // Make this a bitcast if possible.
4388 if (HloInstruction* bitcast_operand =
4389 BitcastingOperandOfReshapeOrCopyChain(reshape, options_)) {
4390 ReplaceWithBitcast(reshape, bitcast_operand);
4391 }
4392 return OkStatus();
4393 }
4394
4395 Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) {
4396 // When all the dimensions to reverse are trivial (i.e. the bound is 1),
4397 // there is nothing to be done.
4398 auto dim_is_one = [&](int64_t i) -> bool {
4399 return reverse->shape().dimensions(i) == 1;
4400 };
4401 if (absl::c_all_of(reverse->dimensions(), dim_is_one)) {
4402 return ReplaceInstruction(reverse, reverse->mutable_operand(0));
4403 }
4404 return OkStatus();
4405 }
4406
4407 StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyScalarSlice(
4408 HloInstruction* slice) {
4409 // Only try to do this for effective scalars. We could do the same for slicing
4410 // out larger pieces of padding (replacing with a broadcast of the padding
4411 // value), but this is probably not worth it.
4412 if (!ShapeUtil::IsEffectiveScalar(slice->shape())) {
4413 return false;
4414 }
4415
4416 if (slice->operand(0)->opcode() == HloOpcode::kConcatenate) {
4417 VLOG(10) << "Trying to simplify scalar slice of concat";
4418     // Only do this for R1; there's no chance of this being useful otherwise.
4419 if (slice->shape().rank() != 1) {
4420 VLOG(10) << "Not folding, slice is not rank 1";
4421 return false;
4422 }
4423 HloConcatenateInstruction* concat =
4424 Cast<HloConcatenateInstruction>(slice->mutable_operand(0));
4425 int64_t operand_start = 0;
4426 int64_t operand_num = 0;
4427 // Weird loop structure to avoid annoying off-by-one errors.
4428 while (true) {
4429 TF_RET_CHECK(operand_num < concat->operand_count());
4430 const HloInstruction* operand = concat->operand(operand_num);
4431 int64_t next_operand_start =
4432 operand_start + operand->shape().dimensions(0);
4433 if (next_operand_start > slice->slice_starts(0)) {
4434 break;
4435 }
4436 operand_start = next_operand_start;
4437 operand_num++;
4438 }
4439
4440 bool replaced = ReplaceInstructionIfCompatible(
4441 slice, concat->mutable_operand(operand_num));
4442 if (replaced) {
4443 VLOG(10) << "Folding scalar slice of concat into concat operand";
4444 } else {
4445 VLOG(10) << "Folding scalar slice of concat into slice of concat operand";
4446 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4447 slice, HloInstruction::CreateSlice(
4448 slice->shape(), concat->mutable_operand(operand_num),
4449 {slice->slice_starts(0) - operand_start},
4450 {slice->slice_starts(0) - operand_start + 1},
4451 slice->slice_strides())));
4452 }
4453 return true;
4454 }
4455
4456 return false;
4457 }
4458
4459 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
4460 HloInstruction* slice) {
4461 CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
4462 if (!IsUnstridedSlice(slice)) {
4463 return false;
4464 }
4465 HloInstruction* reshape = slice->mutable_operand(0);
4466 if (reshape->opcode() != HloOpcode::kReshape) {
4467 return false;
4468 }
4469 HloInstruction* new_slice_operand = reshape->mutable_operand(0);
4470 int64_t slice_rank = slice->shape().rank();
4471 std::vector<int64_t> sliced_dims;
4472 for (int64_t i = 0; i < slice_rank; ++i) {
4473 if (slice->slice_starts(i) != 0 ||
4474 slice->slice_limits(i) != reshape->shape().dimensions(i)) {
4475 sliced_dims.push_back(i);
4476 }
4477 }
4478
4479 if (sliced_dims.size() == 1 && sliced_dims[0] == 0 &&
4480 slice->slice_starts(0) == 0) {
4481 const Shape& new_slice_shape = new_slice_operand->shape();
4482 const int64_t rank = new_slice_shape.rank();
4483 std::vector<int64_t> new_slice_starts(rank, 0);
4484     std::vector<int64_t> new_slice_strides(rank, 1);
4485 std::vector<int64_t> new_slice_limits(new_slice_shape.dimensions().begin(),
4486 new_slice_shape.dimensions().end());
4487 int64_t slice_elements = ShapeUtil::ElementsIn(slice->shape());
4488 for (int64_t i = rank - 1; i >= 0; --i) {
4489 if (slice_elements >= new_slice_limits[i]) {
4490 if (slice_elements % new_slice_limits[i] != 0) {
4491 return false;
4492 }
4493 slice_elements /= new_slice_limits[i];
4494 } else {
4495 new_slice_limits[i] = slice_elements;
4496 slice_elements = 1;
4497 }
4498 }
4499 HloInstruction* new_slice =
4500 slice->AddInstruction(HloInstruction::CreateSlice(
4501 ShapeUtil::MakeShape(new_slice_shape.element_type(),
4502 new_slice_limits),
4503 new_slice_operand, new_slice_starts, new_slice_limits,
4504             new_slice_strides));
4505 simplifier_->UpdateLayout(new_slice->mutable_shape());
4506 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4507 slice, HloInstruction::CreateReshape(slice->shape(), new_slice)));
4508 return true;
4509 }
4510 return false;
4511 }
4512
4513 // Allows a slice to move through a reverse, applying any necessary updates
4514 // to the slice configuration.
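// For example (an illustrative sketch):
//   slice(reverse(f32[10] x, dims={0}), starts={2}, limits={7}, strides={1})
// becomes
//   reverse(slice(x, starts={3}, limits={8}, strides={1}), dims={0})
// since element i of the sliced reverse is element 10 - 1 - (2 + i) of x.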
4515 StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse(
4516 HloInstruction* slice) {
4517 VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:"
4518 << slice->ToString();
4519 if (Match(slice, m::Slice(m::Reverse()))) {
4520 HloInstruction* reverse = slice->mutable_operand(0);
4521 HloInstruction* reverse_operand = reverse->mutable_operand(0);
4522 std::vector<int64_t> new_starts = slice->slice_starts();
4523 std::vector<int64_t> new_limits = slice->slice_limits();
4524 std::vector<int64_t> new_strides = slice->slice_strides();
4525 for (auto rdim : reverse->dimensions()) {
4526 int64_t start = slice->slice_starts(rdim);
4527 int64_t limit = slice->slice_limits(rdim);
4528 int64_t stride = slice->slice_strides(rdim);
4529 // find_nth allows us to compute the appropriate index to begin
4530 // with during reverse even in the presence of non-unit strides
4531 int64_t find_nth = (limit - start - 1) / stride;
4532 find_nth = start + find_nth * stride;
4533 limit = find_nth + 1;
4534 new_starts[rdim] =
4535 (reverse->shape().dimensions(rdim) - start) - (limit - start);
4536 new_limits[rdim] = reverse->shape().dimensions(rdim) - start;
4537 VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << ","
4538 << limit << " and new (start, limit):" << new_starts[rdim] << ","
4539 << new_limits[rdim];
4540 }
4541 // New slice formed from the reverse_operand, but strides and shape of the
4542 // slice output remains the same. New slice's starts and limits are updated
4543 // for ONLY the reversed dimensions as indicated above.
4544 HloInstruction* new_slice = slice->AddInstruction(
4545 HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts,
4546 new_limits, new_strides));
4547 simplifier_->UpdateLayout(new_slice->mutable_shape());
4548 TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
4549 slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice,
4550 reverse->dimensions())));
4551     // We do not delete the old reverse, since there might be another
4552     // consumer of that reverse (i.e., one that uses the full reverse output).
4553     // DCE will take care of deleting the reverse if it has no remaining users.
4554 return true;
4555 }
4556 return false;
4557 }
4558
4559 Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) {
4560 // Delete no-op slices, i.e. where shape = operand shape.
4561 if (ReplaceInstructionIfCompatible(slice, slice->mutable_operand(0))) {
4562 return OkStatus();
4563 }
4564
4565 HloInstruction* pad;
4566 HloInstruction* pad_operand;
4567 if (Match(slice, m::Slice(m::Pad(&pad, m::Op(&pad_operand), m::Op())))) {
4568     // Whether the result of the slice is exactly the pad operand.
4569     bool slice_undoes_pad = true;
4570     // Whether the slice can be moved onto the pad_operand without reading any padding.
4571     bool slice_inside_pad = true;
4572     // Whether this slice reads padding only.
4573     bool slice_in_padding = false;
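    // As an illustrative sketch, for pad(f32[4] x, low=2, high=2), size 8:
    //   slice starts={2}, limits={6} exactly undoes the padding;
    //   slice starts={0}, limits={2} reads padding only and becomes a
    //     broadcast of the padding value;
    //   slice starts={3}, limits={5} stays inside x and becomes
    //     slice(x, starts={1}, limits={3}).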
4574 std::vector<int64_t> new_starts = slice->slice_starts();
4575 std::vector<int64_t> new_limits = slice->slice_limits();
4576 for (int64_t i = 0; i < slice->shape().rank(); ++i) {
4577 const int64_t start = slice->slice_starts(i);
4578 const int64_t stride = slice->slice_strides(i);
4579 const int64_t limit = slice->slice_limits(i);
4580 const int64_t size = pad->shape().dimensions(i);
4581
4582 const auto& dim = pad->padding_config().dimensions(i);
4583 const int64_t low = dim.edge_padding_low();
4584 const int64_t high = dim.edge_padding_high();
4585 const int64_t interior = dim.interior_padding();
4586 const int64_t edge = size - high;
4587
4588 if (limit <= low || start >= edge) {
4589 slice_in_padding = true;
4590 break;
4591 }
4592
4593 if (start != low || stride - 1 != interior) {
4594 slice_undoes_pad = false;
4595 }
4596
4597 if (start < low || limit > edge || interior != 0 || stride != 1) {
4598 slice_inside_pad = false;
4599 }
4600 new_starts[i] -= low;
4601 new_limits[i] -= low;
4602 }
4603 if (slice_in_padding) {
4604 HloInstruction* broadcast =
4605 MakeBroadcastHlo(pad->mutable_operand(1), {}, slice->shape());
4606 *(broadcast->mutable_shape()) = slice->shape();
4607 return ReplaceInstruction(slice, broadcast);
4608 }
4609 if (slice_undoes_pad &&
4610 ReplaceInstructionIfCompatible(slice, pad_operand)) {
4611 return OkStatus();
4612 }
4613 if (slice_inside_pad) {
4614 TF_ASSIGN_OR_RETURN(HloInstruction * new_slice,
4615 MakeSliceHlo(pad_operand, new_starts, new_limits,
4616 slice->slice_strides()));
4617 *(new_slice->mutable_shape()) = slice->shape();
4618 return ReplaceInstruction(slice, new_slice);
4619 }
4620 }
4621
4622 if (slice->operand(0)->opcode() == HloOpcode::kSlice &&
4623 IsUnstridedSlice(slice) && IsUnstridedSlice(slice->operand(0))) {
4624 HloInstruction* operand_slice = slice->mutable_operand(0);
4625 std::vector<int64_t> new_slice_starts = slice->slice_starts();
4626 std::vector<int64_t> new_slice_limits = slice->slice_limits();
4627 for (int64_t i = 0; i < new_slice_starts.size(); ++i) {
4628 new_slice_starts[i] += operand_slice->slice_starts(i);
4629 new_slice_limits[i] += operand_slice->slice_starts(i);
4630 }
4631 return ReplaceWithNewInstruction(
4632 slice, HloInstruction::CreateSlice(
4633 slice->shape(), operand_slice->mutable_operand(0),
4634 new_slice_starts, new_slice_limits, slice->slice_strides()));
4635 }
4636
4637 auto only_broadcast_dims_sliced = [&] {
4638 if (slice->operand(0)->opcode() != HloOpcode::kBroadcast) {
4639 return false;
4640 }
4641 for (int64_t dim : slice->operand(0)->dimensions()) {
4642 if (slice->slice_starts(dim) != 0 || slice->slice_strides(dim) != 1 ||
4643 slice->slice_limits(dim) !=
4644 slice->operand(0)->shape().dimensions(dim)) {
4645 return false;
4646 }
4647 }
4648 return true;
4649 };
4650 if (only_broadcast_dims_sliced()) {
4651 return ReplaceWithNewInstruction(
4652 slice,
4653 HloInstruction::CreateBroadcast(
4654 slice->shape(), slice->mutable_operand(0)->mutable_operand(0),
4655 slice->mutable_operand(0)->dimensions()));
4656 }
4657
4658 TF_ASSIGN_OR_RETURN(bool replaced, TrySimplifyScalarSlice(slice));
4659 if (replaced) {
4660 return OkStatus();
4661 }
4662
4663 HloInstruction* broadcast;
4664 HloInstruction* broadcast_operand;
4665 if (Match(slice,
4666 m::Slice(m::Broadcast(&broadcast, m::Op(&broadcast_operand))))) {
4667 std::vector<int64_t> new_slice_starts;
4668 std::vector<int64_t> new_slice_strides;
4669 std::vector<int64_t> new_slice_limits;
4670 new_slice_starts.reserve(broadcast_operand->shape().rank());
4671 new_slice_strides.reserve(broadcast_operand->shape().rank());
4672 new_slice_limits.reserve(broadcast_operand->shape().rank());
4673 for (int64_t dim : broadcast->dimensions()) {
4674 new_slice_starts.push_back(slice->slice_starts(dim));
4675 new_slice_strides.push_back(slice->slice_strides(dim));
4676 new_slice_limits.push_back(slice->slice_limits(dim));
4677 }
4678 VLOG(3) << "Sink broadcast through slice";
4679 VLOG(3) << "Original slice: " << slice->ToString();
4680 VLOG(3) << "Original broadcast: " << broadcast->ToString();
4681 auto new_slice_shape = broadcast_operand->shape();
4682 for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
4683 int64_t size_i = (new_slice_limits[i] - new_slice_starts[i] +
4684 new_slice_strides[i] - 1) /
4685 new_slice_strides[i];
4686 new_slice_shape.set_dimensions(i, size_i);
4687 }
4688 simplifier_->UpdateLayout(&new_slice_shape);
4689 auto new_slice = slice->AddInstruction(HloInstruction::CreateSlice(
4690 new_slice_shape, broadcast_operand, new_slice_starts, new_slice_limits,
4691 new_slice_strides));
4692 auto new_broadcast =
4693 broadcast->AddInstruction(HloInstruction::CreateBroadcast(
4694 slice->shape(), new_slice, broadcast->dimensions()));
4695 VLOG(3) << "New slice: " << slice->ToString();
4696 VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4697 return ReplaceInstruction(slice, new_broadcast);
4698 }
4699
4700 // Try to simplify concat -> slice to an operand of concat.
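  // For example (an illustrative sketch):
  //   slice(concatenate(f32[3] a, f32[4] b, f32[5] c), starts={3}, limits={7})
  // is exactly b, while starts={4}, limits={6} becomes
  //   slice(b, starts={1}, limits={3}).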
4701 if (slice->operand(0)->opcode() == HloOpcode::kConcatenate &&
4702 IsUnstridedSlice(slice)) {
4703 HloInstruction* concat = slice->mutable_operand(0);
4704 int64_t concat_dim = concat->concatenate_dimension();
4705 int64_t piece_start = 0;
4706 std::optional<int64_t> start_operand;
4707 std::optional<int64_t> limit_operand;
4708 int64_t concat_start;
4709 int64_t concat_limit;
4710 const int64_t slice_start = slice->slice_starts(concat_dim);
4711 const int64_t slice_limit = slice->slice_limits(concat_dim);
4712 for (int64_t i = 0; i < concat->operand_count(); ++i) {
4713 const HloInstruction* piece = concat->operand(i);
4714 const int64_t piece_size = piece->shape().dimensions(concat_dim);
4715 if (!start_operand && piece_start <= slice_start &&
4716 piece_size + piece_start > slice_start) {
4717 start_operand = i;
4718 concat_start = piece_start;
4719 }
4720 piece_start += piece_size;
4721 if (!limit_operand && piece_start >= slice_limit) {
4722 limit_operand = i + 1;
4723 concat_limit = piece_start;
4724 break;
4725 }
4726 }
4727 if (start_operand && limit_operand &&
4728 *start_operand + 1 == *limit_operand &&
4729 SameShape(concat->operand(*start_operand), slice)) {
4730 return ReplaceInstruction(slice, concat->mutable_operand(*start_operand));
4731 }
4732 if (start_operand && limit_operand &&
4733 *limit_operand - *start_operand < concat->operand_count()) {
4734 std::vector<int64_t> starts = slice->slice_starts();
4735 starts[concat_dim] = starts[concat_dim] - concat_start;
4736 std::vector<int64_t> strides = slice->slice_strides();
4737 std::vector<int64_t> limits = slice->slice_limits();
4738 limits[concat_dim] =
4739 starts[concat_dim] + slice->shape().dimensions(concat_dim);
4740 HloInstruction* operand = concat->mutable_operand(*start_operand);
4741 if (*start_operand + 1 != *limit_operand) {
4742 TF_ASSIGN_OR_RETURN(
4743 HloInstruction * new_concat,
4744 MakeConcatHlo(
4745 absl::MakeSpan(concat->operands())
4746 .subspan(*start_operand, *limit_operand - *start_operand),
4747 concat_dim));
4748 *new_concat->mutable_shape()->mutable_layout() =
4749 concat->shape().layout();
4750 simplifier_->UpdateLayout(new_concat->mutable_shape());
4751 concat->SetupDerivedInstruction(new_concat);
4752 operand = new_concat;
4753 }
4754 return ReplaceWithNewInstruction(
4755 slice, HloInstruction::CreateSlice(slice->shape(), operand, starts,
4756 limits, strides));
4757 }
4758 }
4759
4760 // Do not try to reorder slices and reshapes after layout assignment as it may
4761 // be invalid.
4762 if (!options_.is_layout_sensitive()) {
4763 TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice));
4764 }
4765 if (replaced) {
4766 return OkStatus();
4767 }
4768
4769 bool reversed = false;
4770 if (Match(slice, m::Slice(m::Reverse(m::Op())))) {
4771 TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice));
4772 }
4773 if (reversed) {
4774 return OkStatus();
4775 }
4776
4777 return OkStatus();
4778 }
4779
4780 Status AlgebraicSimplifierVisitor::HandleRsqrt(HloInstruction* rsqrt) {
4781 VLOG(10) << "trying transform [rsqrt(Pow(A, -2)) => |A|] "
4782 << rsqrt->ToString();
4783 HloInstruction* rsqrt_operand = rsqrt->mutable_operand(0);
4784 if (rsqrt_operand->opcode() == HloOpcode::kPower &&
4785 IsAll(rsqrt_operand->operand(1), -2) &&
4786 IsPositive(rsqrt_operand, options_)) {
4787 return ReplaceWithNewInstruction(
4788 rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kAbs,
4789 rsqrt_operand->mutable_operand(0)));
4790 }
4791
4792 VLOG(10) << "trying transform [rsqrt(Divide(1, A)) => sqrt(A)] "
4793 << rsqrt->ToString();
4794 if (rsqrt_operand->opcode() == HloOpcode::kDivide &&
4795 IsAll(rsqrt_operand->operand(0), 1) &&
4796 IsPositive(rsqrt_operand->operand(1), options_)) {
4797 return ReplaceWithNewInstruction(
4798 rsqrt, HloInstruction::CreateUnary(rsqrt->shape(), HloOpcode::kSqrt,
4799 rsqrt_operand->mutable_operand(1)));
4800 }
4801
4802 return OkStatus();
4803 }
4804
4805 Status AlgebraicSimplifierVisitor::HandleDynamicSlice(
4806 HloInstruction* dynamic_slice) {
4807 auto operand = dynamic_slice->mutable_operand(0);
4808 if (ShapeUtil::IsScalar(dynamic_slice->shape())) {
4809 return ReplaceInstruction(dynamic_slice, operand);
4810 }
4811 // DynamicSlice where operand has the same size as the output is simply equal
4812 // to operand.
4813 if (SameShape(operand, dynamic_slice)) {
4814 return ReplaceInstruction(dynamic_slice, operand);
4815 }
4816
4817 HloInstruction* broadcast_operand;
4818 if (Match(operand, m::Broadcast(m::Op(&broadcast_operand)))) {
4819 std::vector<HloInstruction*> new_indices;
4820 new_indices.reserve(broadcast_operand->shape().rank());
4821 std::vector<int64_t> new_slice_sizes;
4822 new_slice_sizes.reserve(broadcast_operand->shape().rank());
4823
4824 for (int64_t dim : operand->dimensions()) {
4825 new_indices.push_back(dynamic_slice->mutable_operand(1 + dim));
4826 new_slice_sizes.push_back(dynamic_slice->slice_sizes(dim));
4827 }
4828
4829 VLOG(3) << "Sink broadcast through dynamic slice";
4830 VLOG(3) << "Original dynamic slice: " << dynamic_slice->ToString();
4831 VLOG(3) << "Original broadcast: " << operand->ToString();
4832 HloInstruction* new_dynamic_slice = broadcast_operand;
4833 if (!new_slice_sizes.empty()) {
4834 auto new_ds_shape = broadcast_operand->shape();
4835 for (int64_t i = 0; i < broadcast_operand->shape().rank(); ++i) {
4836 new_ds_shape.set_dimensions(i, new_slice_sizes[i]);
4837 }
4838 simplifier_->UpdateLayout(&new_ds_shape);
4839 new_dynamic_slice =
4840 dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice(
4841 new_ds_shape, broadcast_operand, new_indices, new_slice_sizes));
4842 }
4843 auto new_broadcast =
4844 operand->AddInstruction(HloInstruction::CreateBroadcast(
4845 dynamic_slice->shape(), new_dynamic_slice, operand->dimensions()));
4846 VLOG(3) << "New dynamic slice: " << dynamic_slice->ToString();
4847 VLOG(3) << "New broadcast: " << new_broadcast->ToString();
4848 return ReplaceInstruction(dynamic_slice, new_broadcast);
4849 }
4850
4851 HloInstruction *reshape, *reshape_operand;
4852 if (Match(operand, m::Reshape(&reshape, m::Op(&reshape_operand))) &&
4853 reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions().has_value() &&
4854 !options_.is_layout_sensitive()) {
4855 int64_t slice_dim = 0;
4856 HloInstruction* zero = MakeScalarLike(dynamic_slice->mutable_operand(1), 0);
4857 std::vector<HloInstruction*> starts;
4858 starts.reserve(reshape_operand->shape().rank());
4859 std::vector<int64_t> slice_sizes;
4860 slice_sizes.reserve(reshape_operand->shape().rank());
4861 for (int64_t dim = 0; dim < reshape_operand->shape().rank(); ++dim) {
4862 if (reshape_operand->shape().dimensions(dim) == 1) {
4863 starts.push_back(zero);
4864 slice_sizes.push_back(1);
4865 continue;
4866 }
4867 while (dynamic_slice->operand(0)->shape().dimensions(slice_dim) == 1) {
4868 ++slice_dim;
4869 }
4870 starts.push_back(dynamic_slice->mutable_operand(1 + slice_dim));
4871 slice_sizes.push_back(dynamic_slice->slice_sizes(slice_dim));
4872 ++slice_dim;
4873 }
4874 HloInstruction* new_dynamic_slice =
4875 dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice(
4876 ShapeUtil::MakeShape(dynamic_slice->shape().element_type(),
4877 slice_sizes),
4878 reshape_operand, starts, slice_sizes));
4879 return ReplaceWithNewInstruction(
4880 dynamic_slice, HloInstruction::CreateReshape(dynamic_slice->shape(),
4881 new_dynamic_slice));
4882 }
4883
4884 HloInstruction *transpose, *transpose_operand;
4885 if (Match(operand, m::Transpose(&transpose, m::Op(&transpose_operand))) &&
4886 !options_.is_layout_sensitive()) {
4887 auto output_to_input = InversePermutation(transpose->dimensions());
4888 HloInstruction* new_slice =
4889 dynamic_slice->AddInstruction(HloInstruction::CreateDynamicSlice(
4890 ShapeUtil::PermuteDimensions(output_to_input,
4891 dynamic_slice->shape()),
4892 transpose_operand,
4893 Permute(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
4894 dynamic_slice->operands().end()),
4895 output_to_input),
4896 Permute(dynamic_slice->dynamic_slice_sizes(), output_to_input)));
4897 return ReplaceWithNewInstruction(
4898 dynamic_slice,
4899 HloInstruction::CreateTranspose(dynamic_slice->shape(), new_slice,
4900 transpose->dimensions()));
4901 }
4902
4903 // Convert a dynamic slice into a slice if all offsets are constant and the
4904 // operand is not constant.
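  // For example (an illustrative sketch):
  //   dynamic-slice(f32[10] x, s32[] constant(3)), slice_sizes={4}
  // becomes
  //   slice(x, starts={3}, limits={7}, strides={1})
  // with the offset first clamped to [0, 10 - 4], matching dynamic-slice
  // semantics.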
4905 if (operand->opcode() != HloOpcode::kConstant &&
4906 absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1,
4907 dynamic_slice->operands().end()),
4908 [](HloInstruction* operand) {
4909 return operand->opcode() == HloOpcode::kConstant &&
4910 ShapeUtil::ElementIsIntegral(operand->shape());
4911 })) {
4912 const int64_t rank = operand->shape().rank();
4913 std::vector<int64_t> slice_starts(rank);
4914 std::vector<int64_t> slice_limits(rank);
4915 std::vector<int64_t> slice_strides(rank, 1);
4916
4917 for (int64_t i = 0; i < rank; ++i) {
4918 std::optional<int64_t> offset =
4919 dynamic_slice->operand(i + 1)->literal().GetFirstInteger();
4920 if (!offset || *offset < 0) {
4921 return OkStatus();
4922 }
4923 const int64_t max_offset =
4924 dynamic_slice->operand(0)->shape().dimensions(i) -
4925 dynamic_slice->shape().dimensions(i);
4926 slice_starts[i] = std::min(max_offset, *offset);
4927 slice_limits[i] =
4928 std::min(max_offset, *offset) + dynamic_slice->shape().dimensions(i);
4929 }
4930 return ReplaceWithNewInstruction(
4931 dynamic_slice,
4932 HloInstruction::CreateSlice(dynamic_slice->shape(), operand,
4933 slice_starts, slice_limits, slice_strides));
4934 }
4935
4936 // Convert the dynamic slice of an iota to just a reference to the index
4937 // (possibly clamped and scaled). Index is always a scalar integer. Output
4938 // should be a rank 1 array of size 1 with element type matching that of the
4939 // scalar index (except the signedness).
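  // As an illustrative sketch (assuming the index type matches the element
  // type and there is no scaling multiplier):
  //   dynamic-slice(s32[10] iota(), s32[] i), slice_sizes={1}
  // becomes
  //   s32[1] reshape(clamp(0, i, 9))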
4940 const PrimitiveType element_type = dynamic_slice->shape().element_type();
4941 if (operand->shape().rank() == 1 && dynamic_slice->shape().rank() == 1 &&
4942 dynamic_slice->shape().dimensions(0) == 1 &&
4943 (element_type == S32 || element_type == U32)) {
4944 // Match multiply(x, broadcast(scalar)) and return the scalar
4945 // constant.
4946 auto match_multiply_with_scalar =
4947 [&](HloInstruction* hlo) -> HloInstruction* {
4948 if (hlo->opcode() != HloOpcode::kMultiply) {
4949 return nullptr;
4950 }
4951 HloInstruction* broadcast = hlo->mutable_operand(1);
4952 if (broadcast->opcode() == HloOpcode::kBroadcast &&
4953 broadcast->dimensions().empty() &&
4954 ShapeUtil::IsScalar(broadcast->operand(0)->shape())) {
4955 return broadcast->mutable_operand(0);
4956 }
4957 return nullptr;
4958 };
4959
4960 HloInstruction* multiplier = match_multiply_with_scalar(operand);
4961 if (multiplier) {
4962 operand = operand->mutable_operand(0);
4963 }
4964
4965 if (operand->opcode() == HloOpcode::kIota) {
4966 // This dynamic_slice will have a single start_index operand (since its
4967 // operand is rank 1).
4968 HloInstruction* index = dynamic_slice->mutable_operand(1);
4969 const PrimitiveType index_type = index->shape().element_type();
4970
4971 auto create_constant = [&](int64_t value) {
4972 if (index_type == S32) {
4973 return MakeScalarLike<int32_t>(index, value);
4974 } else {
4975 return MakeScalarLike<uint32_t>(index, value);
4976 }
4977 };
4978
4979 if (index_type == S32 || index_type == U32) {
4980 // Clamp the index to the range of the iota.
4981 int64_t iota_size = operand->shape().dimensions(0);
4982 HloInstruction* low = create_constant(0);
4983 HloInstruction* high = create_constant(iota_size - 1);
4984 HloInstruction* clamped =
4985 dynamic_slice->AddInstruction(HloInstruction::CreateTernary(
4986 index->shape(), HloOpcode::kClamp, low, index, high));
4987
4988 // Convert the clamped index from index_type to element_type and
4989 // multiply with the multiplier.
4990 HloInstruction* result = clamped;
4991 if (index_type != element_type) {
4992 Shape result_shp = result->shape();
4993 result_shp.set_element_type(element_type);
4994 result = dynamic_slice->AddInstruction(
4995 HloInstruction::CreateConvert(result_shp, clamped));
4996 }
4997
4998 if (multiplier) {
4999 result = dynamic_slice->AddInstruction(HloInstruction::CreateBinary(
5000 result->shape(), HloOpcode::kMultiply, result, multiplier));
5001 }
5002
5003 return ReplaceWithNewInstruction(
5004 dynamic_slice,
5005 HloInstruction::CreateReshape(dynamic_slice->shape(), result));
5006 }
5007 }
5008 }
5009
5010 // ds(ds(x,id),inner_id) -> ds(x, id + inner_id)
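  // As an illustrative sketch:
  //   dynamic-slice(dynamic-slice(x, i), j) -> dynamic-slice(x, i' + j)
  // where i' is i clamped so that the inner slice stays in bounds.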
5011 if (operand->opcode() == HloOpcode::kDynamicSlice) {
5012 TF_RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWithDifferentShape(
5013 0, operand->mutable_operand(0)));
5014 for (int64_t i = 1; i < dynamic_slice->operand_count(); ++i) {
5015 HloInstruction* index = dynamic_slice->mutable_operand(i);
5016 HloInstruction* inner_index = operand->mutable_operand(i);
5017 inner_index = inner_index->AddInstruction(HloInstruction::CreateTernary(
5018 inner_index->shape(), HloOpcode::kClamp,
5019 MakeScalarLike(inner_index, 0), inner_index,
5020 MakeScalarLike(inner_index,
5021 operand->operand(0)->shape().dimensions(i - 1) -
5022 dynamic_slice->dynamic_slice_sizes()[i - 1])));
5023 if (inner_index->shape().element_type() !=
5024 index->shape().element_type()) {
5025 inner_index = inner_index->AddInstruction(
5026 HloInstruction::CreateConvert(index->shape(), inner_index));
5027 }
5028 HloInstruction* combined_index =
5029 operand->AddInstruction(HloInstruction::CreateBinary(
5030 index->shape(), HloOpcode::kAdd, index, inner_index));
5031 TF_RETURN_IF_ERROR(dynamic_slice->ReplaceOperandWith(i, combined_index));
5032 }
5033 MarkAsChanged();
5034 }
5035 return OkStatus();
5036 }
5037
5038 Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
5039 HloInstruction* dynamic_update_slice) {
5040 // Rewriting DynamicUpdateSlice when it matches
5041 // dynamic_update_slice(broadcast(constant),data,constant_index0,...)
5042 // to a Pad(x, constant)
5043   // Only Broadcast is considered currently; other ops may need to be
5044   // handled in the future.
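  // As an illustrative sketch:
  //   dynamic-update-slice(f32[8] broadcast(f32[] constant(0)), f32[3] u,
  //                        s32[] constant(2))
  // becomes
  //   pad(u, constant(0)) with edge padding low=2, high=3.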
5045 HloInstruction* updated = dynamic_update_slice->mutable_operand(0);
5046 HloInstruction* dus_update = dynamic_update_slice->mutable_operand(1);
5047 HloInstruction* pad_value;
5048 if (Match(updated,
5049 m::Broadcast(m::Op(&pad_value).WithShape(m::Shape().IsScalar())))) {
5050 auto updated_shape = updated->shape();
5051 auto update_shape = dus_update->shape();
5052 auto update_start_indx = dynamic_update_slice->operand(2);
5053 int64_t offset = 0;
5054 bool compatible = true;
5055     // Whether the start indices are given as individual scalar operands or
5056     // as the output of a tuple/concatenate, set up update_start_indx
5057     // appropriately.
5058 if (ShapeUtil::IsScalar(update_start_indx->shape())) {
5059 update_start_indx = dynamic_update_slice;
5060 offset = 2;
5061 } else {
5062 if (update_start_indx->opcode() == HloOpcode::kTuple ||
5063 update_start_indx->opcode() == HloOpcode::kConcatenate) {
5064 offset = 0;
5065 } else {
5066 compatible = false;
5067 }
5068 }
5069 PaddingConfig padding_config;
5070 if (compatible) {
5071 for (int64_t dim = 0; dim < updated_shape.rank(); ++dim) {
5072 auto padding_config_dim = padding_config.add_dimensions();
5073 auto slice_dim_start = update_start_indx->operand(dim + offset);
5074 if (!Match(slice_dim_start, m::ConstantScalar())) {
5075 compatible = false;
5076 break;
5077 }
5078 VLOG(2) << "slice: " << slice_dim_start->ToString();
5079 std::optional<int64_t> beg =
5080 slice_dim_start->literal().GetFirstInteger();
5081 if (!beg) {
5082 compatible = false;
5083 break;
5084 }
5085 VLOG(2) << "beg value: " << *beg;
5086 auto update_width = ShapeUtil::GetDimension(update_shape, dim);
5087 auto bcast_width = ShapeUtil::GetDimension(updated_shape, dim);
5088 // Clamp beg so that it is non-negative.
5089 *beg = std::max<int64_t>(0, *beg);
5090 // Clamp beg so that it is in-bounds.
5091 *beg = std::min<int64_t>(bcast_width - update_width, *beg);
5092 VLOG(2) << "adjusted beg value: " << *beg;
5093 padding_config_dim->set_edge_padding_low(*beg);
5094 padding_config_dim->set_edge_padding_high(bcast_width -
5095 (*beg + update_width));
5096 // dynamic_update_slice does not specify a stride
5097 padding_config_dim->set_interior_padding(0);
5098 }
5099 }
5100
5101 if (compatible) {
5102 HloInstruction* pad =
5103 dynamic_update_slice->AddInstruction(HloInstruction::CreatePad(
5104 updated_shape, dus_update, pad_value, padding_config));
5105 VLOG(2) << dynamic_update_slice->ToString();
5106 VLOG(2) << " with pad:" << pad->ToString();
5107 VLOG(2) << " Computation before rewrite is: "
5108 << dynamic_update_slice->parent()->ToString();
5109 return ReplaceInstruction(dynamic_update_slice, pad);
5110 }
5111 }
5112
5113 // DynamicUpdateSlice where operand and dus_update have the same size is
5114 // equal to dus_update.
5115 if (SameShape(dynamic_update_slice, dus_update)) {
5116 return ReplaceInstruction(dynamic_update_slice, dus_update);
5117 }
5118
5119 // If any dimension of dus_update is 0, elide the DynamicUpdateSlice. This
5120 // optimization becomes invalid should we later prefer to warn about out of
5121 // bound indices.
5122 if (ShapeUtil::IsZeroElementArray(dus_update->shape())) {
5123 return ReplaceInstruction(dynamic_update_slice, updated);
5124 }
5125
5126   // dus(a, dus(ds(a, id), c, inner_id), id) is equivalent to dus(a, c, id + inner_id)
5127 if (dus_update->opcode() == HloOpcode::kDynamicUpdateSlice &&
5128 (dus_update->operand(0)->opcode() == HloOpcode::kDynamicSlice &&
5129 dus_update->operand(0)->operand(0) == dynamic_update_slice->operand(0) &&
5130 absl::c_equal(
5131 absl::MakeConstSpan(dynamic_update_slice->operands()).subspan(2),
5132 absl::MakeConstSpan(dus_update->operand(0)->operands())
5133 .subspan(1)))) {
5134 TF_RETURN_IF_ERROR(dynamic_update_slice->ReplaceOperandWithDifferentShape(
5135 1, dus_update->mutable_operand(1)));
5136 for (int64_t i = 2; i < dynamic_update_slice->operand_count(); ++i) {
5137 HloInstruction* index = dynamic_update_slice->mutable_operand(i);
5138 HloInstruction* inner_index = dus_update->mutable_operand(i);
5139 inner_index = inner_index->AddInstruction(HloInstruction::CreateTernary(
5140 inner_index->shape(), HloOpcode::kClamp,
5141 MakeScalarLike(inner_index, 0), inner_index,
5142 MakeScalarLike(
5143 inner_index,
5144 dus_update->shape().dimensions(i - 2) -
5145 dus_update->operand(1)->shape().dimensions(i - 2))));
5146 if (inner_index->shape().element_type() !=
5147 index->shape().element_type()) {
5148 inner_index = inner_index->AddInstruction(
5149 HloInstruction::CreateConvert(index->shape(), inner_index));
5150 }
5151 HloInstruction* combined_index =
5152 dus_update->AddInstruction(HloInstruction::CreateBinary(
5153 index->shape(), HloOpcode::kAdd, index, inner_index));
5154 TF_RETURN_IF_ERROR(
5155 dynamic_update_slice->ReplaceOperandWith(i, combined_index));
5156 }
5157 MarkAsChanged();
5158 return OkStatus();
5159 }
5160 return OkStatus();
5161 }
5162
5163 static bool MatchArgMinMax(const HloInstruction* hlo, bool is_max) {
5164 // Create matcher for shared sub-expression.
5165 auto value_pred = m::OrAnyOrder(
5166 m::Compare(m::Parameter(0), m::Parameter(2))
5167 .WithComparisonDirection(is_max ? ComparisonDirection::kGt
5168 : ComparisonDirection::kLt),
5169 m::Compare(m::Parameter(0), m::Parameter(0))
5170 .WithComparisonDirection(ComparisonDirection::kNe));
5171
5172 // Match on argmax reduction computation.
5173 return Match(
5174 hlo,
5175 m::Tuple(
5176 m::Select(value_pred, m::Parameter(0), m::Parameter(2)),
5177 m::Select(
5178 m::OrAnyOrder(
5179 value_pred,
5180 m::And(
5181 m::Compare(m::Parameter(0), m::Parameter(2))
5182 .WithComparisonDirection(ComparisonDirection::kEq),
5183 m::Compare(m::Parameter(1), m::Parameter(3))
5184 .WithComparisonDirection(ComparisonDirection::kLt))),
5185 m::Parameter(1), m::Parameter(3))));
5186 }
5187
5188 // Match on variadic reduce which computes and returns (min, arg_min).
5189 //
5190 // p0 p2 p1 p3
5191 // /|\ \/ |\ |\ /|
5192 // / | \/\ | \ | \ / |
5193 // / | /\ \| | | /\ |
5194 // Ne Lt | \ | | | ||
5195 // \ / | |\ | | / ||
5196 // Or / / Eq Lt ||
5197 // | / / \ / //
5198 // | | | And //
5199 // | | | | //
5200 // select select
5201 // \ /
5202 // tuple
5203 //
5204 static bool MatchArgMin(const HloInstruction* hlo) {
5205 // Match on variadic Reduce ArgMin
5206 if (hlo->opcode() != HloOpcode::kReduce || hlo->operand_count() != 4 ||
5207 !hlo->shape().IsTuple() ||
5208 hlo->operand(1)->opcode() != HloOpcode::kIota ||
5209 !IsScalarConstantInf(hlo->operand(2)) ||
5210 !IsScalarConstantZero(hlo->operand(3))) {
5211 return false;
5212 }
5213 return MatchArgMinMax(hlo->to_apply()->root_instruction(), /*is_max=*/false);
5214 }
5215
5216 // Match on variadic reduce which computes and returns (max, arg_max).
5217 //
5218 // p0 p2 p1 p3
5219 // /|\ \/ |\ |\ /|
5220 // / | \/\ | \ | \ / |
5221 // / | /\ \| | | /\ |
5222 // Ne Gt | \ | | | ||
5223 // \ / | |\ | | / ||
5224 // Or / / Eq Lt ||
5225 // | / / \ / //
5226 // | | | And //
5227 // | | | | //
5228 // select select
5229 // \ /
5230 // tuple
5231 //
5232 static bool MatchArgMax(const HloInstruction* hlo) {
5233 // Match on variadic Reduce ArgMax.
5234 if (hlo->opcode() != HloOpcode::kReduce || hlo->operand_count() != 4 ||
5235 !hlo->shape().IsTuple() ||
5236 hlo->operand(1)->opcode() != HloOpcode::kIota ||
5237 !IsScalarConstantNegInf(hlo->operand(2)) ||
5238 !IsScalarConstantZero(hlo->operand(3))) {
5239 return false;
5240 }
5241 return MatchArgMinMax(hlo->to_apply()->root_instruction(), /*is_max=*/true);
5242 }
5243
5244 static bool ReductionComputationsEquivalent(const HloComputation& a,
5245 const HloComputation& b) {
5246 if (a == b) {
5247 return true;
5248 }
5249
5250 // Check for simple commutative reduction functions.
5251 enum CommutativeFnKind { kAdd, kMul, kAnd, kOr };
5252 auto categorize_computation =
5253 [](const HloComputation& c) -> std::optional<CommutativeFnKind> {
5254 if (c.num_parameters() != 2) {
5255 return std::nullopt;
5256 }
5257
5258 const HloInstruction* root = c.root_instruction();
5259 if (Match(root, m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5260 return kAdd;
5261 }
5262 if (Match(root, m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5263 return kMul;
5264 }
5265 if (Match(root, m::AndAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5266 return kAnd;
5267 }
5268 if (Match(root, m::OrAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5269 return kOr;
5270 }
5271 return std::nullopt;
5272 };
5273 auto category_a = categorize_computation(a);
5274 auto category_b = categorize_computation(b);
5275 return category_a.has_value() && category_b.has_value() &&
5276 category_a == category_b;
5277 }
5278
5279 Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) {
5280 HloReduceInstruction* reduce = Cast<HloReduceInstruction>(hlo);
5281 bool multi_output_reduce = reduce->shape().IsTuple();
5282 // For tuple reduce, we require all reduce shapes to be the same, up to the
5283   // element types, so we can just use the first operand and the first result as a
5284 // representative.
5285 auto arg = reduce->inputs()[0];
5286 auto init_value = reduce->init_values()[0];
5287 const Shape& reduce_result_shape =
5288 multi_output_reduce ? reduce->shape().tuple_shapes(0) : reduce->shape();
5289
5290 absl::Span<const int64_t> dimensions(reduce->dimensions());
5291 HloComputation* function = reduce->to_apply();
5292 if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
5293 ShapeUtil::IsZeroElementArray(reduce_result_shape)) {
5294 if (multi_output_reduce) {
5295 std::vector<HloInstruction*> broadcast_inits;
5296 int64_t inputs = reduce->input_count();
5297 for (int64_t i = 0; i < inputs; ++i) {
5298 broadcast_inits.push_back(reduce->init_values()[i]->AddInstruction(
5299 HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i),
5300 reduce->init_values()[i], {})));
5301 }
5302 return ReplaceWithNewInstruction(
5303 reduce, HloInstruction::CreateTuple(broadcast_inits));
5304 } else {
5305 return ReplaceWithNewInstruction(
5306 reduce,
5307 HloInstruction::CreateBroadcast(reduce_result_shape, init_value, {}));
5308 }
5309 }
5310
5311 // Turn trivial variadic reductions into normal reductions.
5312 if (multi_output_reduce && reduce->shape().tuple_shapes_size() == 1 &&
5313 reduce->input_count() == 1 &&
5314 Match(function->root_instruction(), m::Tuple())) {
5315 absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
5316 replacements;
5317 replacements[function->root_instruction()] = nullptr;
5318 auto new_function = computation_->parent()->AddEmbeddedComputation(
5319 function->CloneWithReplacements(
5320 &replacements, /*extra_parameters=*/{},
5321 /*context=*/nullptr,
5322 /*suffix=*/"clone",
5323 /*new_root=*/function->root_instruction()->operand(0)));
5324 auto new_reduce = reduce->AddInstruction(
5325 HloInstruction::CreateReduce(reduce_result_shape, arg, init_value,
5326 reduce->dimensions(), new_function));
5327 return ReplaceWithNewInstruction(reduce,
5328 HloInstruction::CreateTuple({new_reduce}));
5329 }
5330
5331 // If the reduction results in the same number of elements, then the only
5332 // possible side effect would be a reshape. Since the init_value is an
5333 // identity of the reduction function, we can therefore replace the reduce
5334 // with a simple reshape, ignoring the reduction function completely.
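  // For example (an illustrative sketch):
  //   f32[4,5] reduce(f32[4,1,5] arg, init), dimensions={1}
  // only drops a 1-sized dimension, so it becomes f32[4,5] reshape(arg).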
5335 if (ShapeUtil::ElementsIn(reduce_result_shape) ==
5336 ShapeUtil::ElementsIn(arg->shape()) &&
5337 (!options_.is_layout_sensitive() ||
5338 options_.ReshapeIsBitcast(arg->shape(), reduce_result_shape))) {
5339 if (multi_output_reduce) {
5340 std::vector<HloInstruction*> reshaped_args;
5341 int64_t inputs = reduce->input_count();
5342 for (int64_t i = 0; i < inputs; ++i) {
5343 reshaped_args.push_back(
5344 reduce->AddInstruction(HloInstruction::CreateReshape(
5345 reduce->shape().tuple_shapes(i), reduce->inputs()[i])));
5346 }
5347 return ReplaceWithNewInstruction(
5348 reduce, HloInstruction::CreateTuple(reshaped_args));
5349 } else {
5350 return ReplaceWithNewInstruction(
5351 reduce, HloInstruction::CreateReshape(reduce_result_shape, arg));
5352 }
5353 }
5354
5355 if (options_.is_layout_sensitive()) {
5356 return OkStatus();
5357 }
5358
5359 // TODO(b/131122694): Most of those optimizations below can be done for
5360 // multi-output reduces.
5361 if (multi_output_reduce) {
5362 return OkStatus();
5363 }
5364
5365 // A Transpose feeding a reduce can simply permute the reduction dimensions
5366   // field if the output of the reduce is a vector or scalar. A higher-rank
5367   // result may require a transpose of the output.
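  // As an illustrative sketch:
  //   reduce(transpose(f32[A,B,C] x, dims={2,0,1}), dimensions={0})
  // becomes
  //   reduce(x, dimensions={2})
  // since reducing dimension 0 of the transpose reduces dimension 2 of x.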
5368 if (arg->opcode() == HloOpcode::kTranspose &&
5369 (reduce->shape().rank() < 2 || arg->user_count() == 1 ||
5370 absl::c_all_of(arg->users(), [](HloInstruction* use) {
5371 return use->opcode() == HloOpcode::kReduce;
5372 }))) {
5373 auto transpose_dimensions = arg->dimensions();
5374 std::vector<int64_t> new_reduce_dimensions;
5375 new_reduce_dimensions.reserve(dimensions.size());
5376 for (auto dim : dimensions) {
5377 new_reduce_dimensions.push_back(transpose_dimensions[dim]);
5378 }
5379
5380 Shape new_reduce_result_shape = ShapeUtil::DeleteDimensions(
5381 new_reduce_dimensions, arg->mutable_operand(0)->shape());
5382 HloInstruction* new_reduce =
5383 reduce->AddInstruction(HloInstruction::CreateReduce(
5384 new_reduce_result_shape, arg->mutable_operand(0), init_value,
5385 new_reduce_dimensions, function));
5386 std::vector<int64_t> new_transpose_dimensions;
5387 for (auto dim : transpose_dimensions) {
5388 if (absl::c_linear_search(new_reduce_dimensions, dim)) {
5389 continue;
5390 }
5391 new_transpose_dimensions.push_back(dim);
5392 }
5393
5394 // If new transpose dimensions are sorted, then there is no need to
5395 // transpose reduce result.
5396 if (absl::c_is_sorted(new_transpose_dimensions)) {
5397 return ReplaceInstruction(reduce, new_reduce);
5398 }
5399 for (auto& d : new_transpose_dimensions) {
5400 auto old_dim = d;
5401 for (auto reduced_dim : new_reduce_dimensions) {
5402 if (old_dim > reduced_dim) {
5403 --d;
5404 }
5405 }
5406 }
5407 TF_ASSIGN_OR_RETURN(HloInstruction * new_transpose,
5408 MakeTransposeHlo(new_reduce, new_transpose_dimensions));
5409 return ReplaceInstruction(reduce, new_transpose);
5410 }
5411
5412 // If a reduce feeds a reduce with the same computation and initial value,
5413 // they can be combined into a single reduce.
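  // As an illustrative sketch:
  //   reduce(reduce(f32[A,B,C] x, dimensions={1}), dimensions={1})
  // becomes
  //   reduce(x, dimensions={1,2})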
5414 if (arg->opcode() == HloOpcode::kReduce &&
5415 init_value->Identical(*arg->operand(1)) &&
5416 ReductionComputationsEquivalent(*function, *arg->to_apply())) {
5417 // Create a new reduce with the combined reduction dimensions of both
5418 // reduces.
5419 std::vector<int64_t> arg_dims = *arg->mutable_dimensions();
5420 absl::c_sort(arg_dims);
5421 std::vector<int64_t> reduce_dims = *reduce->mutable_dimensions();
5422 absl::c_sort(reduce_dims);
5423     // Transform reduce_dims so they index into the shape of arg's operand.
5424 for (int64_t arg_dim : arg_dims) {
5425 for (int64_t& dim : reduce_dims) {
5426 if (dim >= arg_dim) {
5427 ++dim;
5428 }
5429 }
5430 }
5431 std::vector<int64_t> new_dimensions;
5432 new_dimensions.reserve(arg->dimensions().size() +
5433 reduce->dimensions().size());
5434 std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(),
5435 reduce_dims.end(), std::back_inserter(new_dimensions));
5436 return ReplaceWithNewInstruction(
5437 reduce, HloInstruction::CreateReduce(
5438 reduce_result_shape, arg->mutable_operand(0), init_value,
5439 new_dimensions, function));
5440 }
5441
5442 // A reshape that collapses multiple dimensions into a dimension being
5443 // reduced can just reduce all of those dimensions instead of doing a
5444 // collapsing reshape before a reduction.
5445 if (options_.enable_reduce_of_reshape() &&
5446 arg->opcode() == HloOpcode::kReshape) {
5447 std::vector<std::pair<int64_t, int64_t>> unmodified_dims =
5448 ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
5449 arg->shape());
5450 std::vector<bool> arg_dim_in_output(arg->shape().rank(), true);
5451 std::vector<bool> arg_dim_unmodified(arg->shape().rank(), false);
5452 for (auto dim : dimensions) {
5453 arg_dim_in_output[dim] = false;
5454 }
5455 for (auto dim_pair : unmodified_dims) {
5456 arg_dim_unmodified[dim_pair.second] = true;
5457 }
5458 // The goal is to verify that all dimensions that are not removed in the
5459 // reduce are unmodified by the reshape. For example:
5460 // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
5461 bool can_move_reshape_into_reduce = true;
5462 for (int64_t i = 0; i < arg_dim_in_output.size(); ++i) {
5463 if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
5464 can_move_reshape_into_reduce = false;
5465 }
5466 }
5467 if (can_move_reshape_into_reduce) {
5468 MarkAsChanged();
5469 absl::flat_hash_set<int64_t> dimensions_not_to_reduce;
5470 for (auto dim_pair : unmodified_dims) {
5471 if (arg_dim_in_output[dim_pair.second]) {
5472 dimensions_not_to_reduce.insert(dim_pair.first);
5473 }
5474 }
5475 std::vector<int64_t> new_reduce_dimensions;
5476 for (int64_t i = 0; i < arg->operand(0)->shape().rank(); ++i) {
5477 if (!dimensions_not_to_reduce.contains(i)) {
5478 new_reduce_dimensions.push_back(i);
5479 }
5480 }
5481 return ReplaceWithNewInstruction(
5482 reduce, HloInstruction::CreateReduce(
5483 reduce_result_shape, arg->mutable_operand(0), init_value,
5484 new_reduce_dimensions, function));
5485 }
5486 }
5487 // Convert Reduce(concat({a,b,...})) to
5488 // map(reduce(a),map(reduce(b),...,))
5489 //
5490 // This should make fusion easier or use less memory bandwidth in the unfused
5491 // case.
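  // As an illustrative sketch, when the concatenate dimension is reduced:
  //   reduce(concatenate(a, b), dimensions={0}), with concat dimension 0,
  // becomes
  //   map(reduce(a, dimensions={0}), reduce(b, dimensions={0}))
  // where the map applies the same reduction computation.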
5492 if (arg->opcode() == HloOpcode::kConcatenate &&
5493 absl::c_linear_search(reduce->dimensions(),
5494 arg->concatenate_dimension())) {
5495 HloInstruction* old_reduce = nullptr;
5496 for (HloInstruction* operand : arg->operands()) {
5497 HloInstruction* new_reduce = reduce->AddInstruction(
5498 HloInstruction::CreateReduce(reduce_result_shape, operand, init_value,
5499 reduce->dimensions(), function));
5500 if (old_reduce != nullptr) {
5501 new_reduce = reduce->AddInstruction(HloInstruction::CreateMap(
5502 reduce_result_shape, {old_reduce, new_reduce}, function));
5503 }
5504 old_reduce = new_reduce;
5505 }
5506 return ReplaceInstruction(reduce, old_reduce);
5507 }
5508
5509 HloInstruction *dot, *lhs, *rhs;
5510 // Convert Reduce(Dot(X,Y)) to Dot(X,Y) if any of the dimensions reduced were
5511 // batch dimensions of the dot. The transformation supports reducing other
5512 // dimensions as well.
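  // As an illustrative sketch:
  //   reduce(dot(f32[B,M,K] lhs, f32[B,K,N] rhs), dimensions={0})
  // with batch dimension 0 on both sides becomes
  //   dot(lhs, rhs)
  // where B joins K as a contracting dimension, yielding f32[M,N] directly.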
5513 if (options_.enable_dot_strength_reduction() &&
5514 Match(arg, m::Dot(&dot, m::Op(&lhs), m::Op(&rhs)).WithOneUser()) &&
5515 Match(reduce->to_apply()->root_instruction(),
5516 m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) &&
5517 absl::c_any_of(reduce->dimensions(), [&](int64_t dim) {
5518 return dim < dot->dot_dimension_numbers().lhs_batch_dimensions_size();
5519 })) {
5520 const auto& dnums = dot->dot_dimension_numbers();
5521 DotDimensionNumbers new_dnums = dnums;
5522 new_dnums.clear_lhs_batch_dimensions();
5523 new_dnums.clear_rhs_batch_dimensions();
5524 int64_t removed_dims = 0;
5525 for (int64_t batch_dim = 0; batch_dim < dnums.lhs_batch_dimensions_size();
5526 ++batch_dim) {
5527 if (absl::c_linear_search(reduce->dimensions(), batch_dim)) {
5528 new_dnums.add_rhs_contracting_dimensions(
5529 dnums.rhs_batch_dimensions(batch_dim));
5530 new_dnums.add_lhs_contracting_dimensions(
5531 dnums.lhs_batch_dimensions(batch_dim));
5532 ++removed_dims;
5533 } else {
5534 new_dnums.add_rhs_batch_dimensions(
5535 dnums.rhs_batch_dimensions(batch_dim));
5536 new_dnums.add_lhs_batch_dimensions(
5537 dnums.lhs_batch_dimensions(batch_dim));
5538 }
5539 }
5540 std::vector<int64_t> reduce_dims;
5541 for (int64_t dim : reduce->dimensions()) {
5542 if (dim >= dnums.lhs_batch_dimensions_size()) {
5543 reduce_dims.push_back(dim - removed_dims);
5544 }
5545 }
5546 TF_ASSIGN_OR_RETURN(
5547 auto new_dot,
5548 MakeDotHlo(lhs, rhs, new_dnums, dot->precision_config(),
5549 /*preferred_element_type=*/dot->shape().element_type()));
5550 dot->SetupDerivedInstruction(new_dot);
5551 if (reduce_dims.empty()) {
5552 return ReplaceInstruction(hlo, new_dot);
5553 }
5554 TF_ASSIGN_OR_RETURN(
5555 auto new_reduce,
5556 MakeReduceHlo(new_dot, init_value, reduce_dims, HloOpcode::kAdd));
5557 reduce->SetupDerivedInstruction(new_reduce);
5558 return ReplaceInstruction(hlo, new_reduce);
5559 }
5560
5561 // Replace Use(ReduceMax(Arg)) with Use(Gte(ReduceArgMax, 0)).
5562 // Match on Reduce Max with init value -Inf.
5563 if (reduce->operand_count() == 2 && IsScalarConstantNegInf(init_value) &&
5564 Match(reduce->to_apply()->root_instruction(),
5565 m::MaximumAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5566 // Match on variadic Reduce ArgMax which is also fed by 'arg'.
5567 auto arg_max_candidate =
5568 absl::c_find_if(arg->users(), [&](const HloInstruction* user) {
5569 return user != reduce && user->operand(0) == arg &&
5570 MatchArgMax(user) &&
5571 reduce->dimensions() == user->dimensions();
5572 });
5573 if (arg_max_candidate != arg->users().end()) {
5574 // Replace 'reduce' uses with GTE(ArgMax, 0).
5575 return ReplaceWithNewInstruction(
5576 reduce, HloInstruction::CreateGetTupleElement(*arg_max_candidate,
5577 /*index=*/0));
5578 }
5579 }
5580
5581 // Replace Use(ReduceMin(Arg)) with Use(Gte(ReduceArgMin, 0)).
5582 // Match on Reduce Min with init value Inf.
5583 if (reduce->operand_count() == 2 && IsScalarConstantInf(init_value) &&
5584 Match(reduce->to_apply()->root_instruction(),
5585 m::MinimumAnyOrder(m::Parameter(0), m::Parameter(1)))) {
5586 // Match on variadic Reduce ArgMin which is also fed by 'arg'.
5587 auto arg_min_candidate =
5588 absl::c_find_if(arg->users(), [&](const HloInstruction* user) {
5589 return user != reduce && user->operand(0) == arg &&
5590 MatchArgMin(user) &&
5591 reduce->dimensions() == user->dimensions();
5592 });
5593 if (arg_min_candidate != arg->users().end()) {
5594 // Replace 'reduce' uses with GTE(ArgMin, 0).
5595 return ReplaceWithNewInstruction(
5596 reduce, HloInstruction::CreateGetTupleElement(*arg_min_candidate,
5597 /*index=*/0));
5598 }
5599 }
5600 return OkStatus();
5601 }
5602
5603 Status AlgebraicSimplifierVisitor::HandleReduceWindow(HloInstruction* hlo) {
5604 auto* reduce_window = Cast<HloReduceWindowInstruction>(hlo);
5605 const bool multi_output_reduce_window = reduce_window->shape().IsTuple();
5606 auto inputs = reduce_window->inputs();
5607 auto init_values = reduce_window->init_values();
5608 auto input_count = reduce_window->input_count();
5609 auto input_shapes = reduce_window->input_shapes();
5610 auto output_shapes = reduce_window->output_shapes();
5611 auto replace_with_span = [&](const std::vector<HloInstruction*>& elements) {
5612 CHECK(multi_output_reduce_window || elements.size() == 1);
5613 if (multi_output_reduce_window) {
5614 return ReplaceWithNewInstruction(reduce_window,
5615 HloInstruction::CreateTuple(elements));
5616 }
5617 return ReplaceInstruction(reduce_window, elements[0]);
5618 };
5619 // For tuple reduce, we require all reduce shapes to be the same, up to the
5620 // element types, so we can use just the first operand and the first result as
5621 // a representative.
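  // A zero-element input or output array means no window ever contributes a
  // real value, so every output element is just the init value: replace the
  // reduce-window with broadcasts of the init values.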
  if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) ||
      ShapeUtil::IsZeroElementArray(*output_shapes[0])) {
    std::vector<HloInstruction*> broadcast_inits;
    for (int64_t i = 0; i < input_count; ++i) {
      broadcast_inits.push_back(
          hlo->AddInstruction(HloInstruction::CreateBroadcast(
              *output_shapes[i], init_values[i], {})));
    }
    return replace_with_span(broadcast_inits);
  }
  if (ShapeUtil::IsScalar(*input_shapes[0]) &&
      (!multi_output_reduce_window ||
       reduce_window->to_apply()->root_instruction()->opcode() ==
           HloOpcode::kTuple)) {
    std::vector<HloInstruction*> maps;
    for (int64_t i = 0; i < input_count; ++i) {
      TF_RET_CHECK(ShapeUtil::IsScalar(*input_shapes[i]));
      TF_RET_CHECK(ShapeUtil::IsScalar(*output_shapes[i]));
      HloInstruction* map_computation_root;
      absl::flat_hash_map<const HloInstruction*,
                          std::unique_ptr<HloInstruction>>
          replacements;
      if (multi_output_reduce_window) {
        map_computation_root =
            reduce_window->to_apply()->root_instruction()->mutable_operand(i);
        replacements[reduce_window->to_apply()->root_instruction()] = nullptr;
      } else {
        map_computation_root = reduce_window->to_apply()->root_instruction();
      }
      maps.push_back(inputs[i]);
    }
    return replace_with_span(maps);
  }
  // Turn trivial variadic reduce windows into normal reduce windows.
  auto reduce_function_root = reduce_window->to_apply()->root_instruction();
  if (multi_output_reduce_window && input_count == 1 &&
      Match(reduce_function_root, m::Tuple())) {
    // Make a new reducer which is identical but does not have a tuple
    // instruction at the bottom.
    absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
        replacements;
    replacements[reduce_function_root] = nullptr;
    auto new_function = computation_->parent()->AddEmbeddedComputation(
        reduce_window->to_apply()->CloneWithReplacements(
            &replacements, /*extra_parameters=*/{},
            /*context=*/nullptr,
            /*suffix=*/"clone",
            /*new_root=*/reduce_function_root->operand(0)));
    auto new_reduce_window =
        reduce_window->AddInstruction(HloInstruction::CreateReduceWindow(
            *output_shapes[0], inputs[0], init_values[0],
            reduce_window->window(), new_function));
    return ReplaceWithNewInstruction(
        reduce_window, HloInstruction::CreateTuple({new_reduce_window}));
  }
  // TODO(b/73062247) Variadic reduce window is not yet supported in
  // simplifier.
  if (multi_output_reduce_window) {
    return OkStatus();
  }
  auto operand = reduce_window->mutable_operand(0);
  auto init_value = reduce_window->mutable_operand(1);
  auto function = reduce_window->to_apply();
  const Window& window = reduce_window->window();

  // reduce-window with a 1x1x..x1 window and no dilation etc can be replaced
  // with a trivial elementwise operation, plus a pad op if necessary.
  //
  // We cowardly refuse to consider this optimization when the reduce-window
  // subcomputation is anything other than a simple add/min/max. Supporting
  // more complex subcomputations is possible, but is tantamount to
  // implementing jax.vmap()!
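  // For example (a sketch): a reduce-window over f32[8,8] with a 1x1 window
  // and an add reducer computes out[i,j] = init + in[i,j], which is just
  // add(in, broadcast(init)); any window padding becomes an explicit pad op
  // whose padding value is the init value.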
  if (absl::c_all_of(window.dimensions(),
                     [](const WindowDimension& dim) {
                       return dim.size() == 1 &&             //
                              dim.stride() == 1 &&           //
                              dim.window_dilation() == 1 &&  //
                              dim.base_dilation() == 1 &&    //
                              !dim.window_reversal();
                     }) &&
      Match(function->root_instruction(),
            m::AnyOf<HloInstruction>(
                m::AddAnyOrder(m::Parameter(0), m::Parameter(1)),
                m::MinimumAnyOrder(m::Parameter(0), m::Parameter(1)),
                m::MaximumAnyOrder(m::Parameter(0), m::Parameter(1))))) {
    const HloInstruction* nested_root = function->root_instruction();
    DimensionVector broadcast_dims(nested_root->shape().dimensions_size());
    absl::c_iota(broadcast_dims, 0);
    TF_ASSIGN_OR_RETURN(
        auto new_op, MakeBinaryHlo(nested_root->opcode(), operand,
                                   MakeBroadcastHlo(init_value, broadcast_dims,
                                                    operand->shape())));

    if (absl::c_any_of(window.dimensions(), [](const WindowDimension& dim) {
          return dim.padding_low() > 0 || dim.padding_high() > 0;
        })) {
      PaddingConfig padding_config;
      for (const WindowDimension& window_dim : window.dimensions()) {
        auto& padding_dim = *padding_config.add_dimensions();
        padding_dim.set_edge_padding_low(window_dim.padding_low());
        padding_dim.set_edge_padding_high(window_dim.padding_high());
        padding_dim.set_interior_padding(0);
      }
      TF_ASSIGN_OR_RETURN(new_op,
                          MakePadHlo(new_op, init_value, padding_config));
    }

    return ReplaceInstruction(reduce_window, new_op);
  }

  if (options_.enable_window_reduce_to_reduce_replacement()) {
    // A reduce window can be expressed as a reduce and a reshape if all
    // dimensions either have a window size of one or the entire dimension. If
    // there is no stride, dilation, or padding, this is as easy as checking
    // the size of the output shape and window dimension.
    //
    // The reshape is a bitcast since it adds one-sized dimensions. Often
    // these ones are immediately removed as well with another reshape. The
    // implementation of reduce tends to be slightly more efficient at
    // reducing entire dimensions compared to reduce window.
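    // For example (a sketch): a reduce-window over f32[10,32] with a 1x32
    // window produces f32[10,1]; the same values come from reduce over
    // dimension 1, giving f32[10], followed by a reshape back to f32[10,1].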
    auto effective_reduce_dims = [&] {
      if (window_util::HasStride(window) || window_util::HasDilation(window) ||
          window_util::HasPadding(window)) {
        return DimensionVector{};
      }
      DimensionVector reduce_dims;
      for (int64_t i = 0; i < window.dimensions_size(); ++i) {
        if (window.dimensions(i).size() == 1) {
          continue;
        } else if (reduce_window->shape().dimensions(i) == 1) {
          reduce_dims.push_back(i);
        } else {
          return DimensionVector{};
        }
      }
      return reduce_dims;
    }();

    // If a reduce window can be expressed as a reduce, do so and reshape the
    // output.
    if (!effective_reduce_dims.empty()) {
      Shape reduce_shape = ShapeUtil::DeleteDimensions(effective_reduce_dims,
                                                       reduce_window->shape());
      simplifier_->UpdateLayout(&reduce_shape);
      HloInstruction* reduce =
          hlo->AddInstruction(HloInstruction::CreateReduce(
              /*shape=*/reduce_shape,
              /*operand=*/operand,
              /*init_value=*/reduce_window->mutable_operand(1),
              /*dimensions_to_reduce=*/effective_reduce_dims,
              /*reduce_computation=*/function));
      return ReplaceWithNewInstruction(
          reduce_window,
          HloInstruction::CreateReshape(reduce_window->shape(), reduce));
    }
  }

  // This optimization folds a pad op into reduce_window.
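  // For example (a sketch): pad(x, low=1, high=1) feeding a reduce-window
  // whose window already pads low=2 can instead feed x directly into a
  // reduce-window padding low=3, as long as the pad value matches the init
  // value. Interior padding of the pad becomes base dilation of the window.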
  HloInstruction* pad;
  const HloInstruction* convert = nullptr;
  if (operand->opcode() == HloOpcode::kPad) {
    pad = operand;
  } else if (operand->opcode() == HloOpcode::kConvert &&
             operand->operand(0)->opcode() == HloOpcode::kPad) {
    convert = operand;
    pad = operand->mutable_operand(0);
  } else {
    VLOG(10) << "Not folding pad into reduce-window as there is no pad.";
    return OkStatus();
  }

  VLOG(10) << "Considering folding Pad: " << pad->ToString()
           << "\ninto reduce-window: " << reduce_window->ToString()
           << (convert != nullptr
                   ? absl::StrCat("\nvia convert: ", convert->ToString())
                   : "");
  // Do not fold interior padding into a reduce-window whose window is already
  // base-dilated: composing the pad's interior padding with the existing base
  // dilation is not supported.
  const PaddingConfig& pad_config = pad->padding_config();
  if (HasInteriorPadding(pad_config) && window_util::HasBaseDilation(window)) {
    VLOG(10) << "Not folding interior pad into base-dilated reduce-window.";
    return OkStatus();
  }

  // If reduce_window already has padding, the pad value of the pad op and the
  // init value of reduce_window must match to allow folding the pad.
  const HloInstruction* pad_value = pad->operand(1);
  const HloInstruction* reduce_init_value = reduce_window->operand(1);
  if (pad_value != reduce_init_value) {
    auto literals_are_equivalent = [&] {
      auto& pad_literal = pad_value->literal();
      auto& reduce_init_literal = reduce_init_value->literal();
      if (pad_literal == reduce_init_literal) {
        return true;
      }
      auto converted_pad_literal =
          pad_literal.ConvertToShape(reduce_init_value->shape());
      if (!converted_pad_literal.ok()) {
        return false;
      }
      return converted_pad_literal.ValueOrDie() == reduce_init_literal;
    };
    // The pad value is usually a constant, so we handle that case and do not
    // try to get more fancy about proving equivalence in cases beyond that.
    if (pad_value->opcode() != HloOpcode::kConstant ||
        reduce_init_value->opcode() != HloOpcode::kConstant ||
        !literals_are_equivalent()) {
      VLOG(10) << "Not folding pad into reduce-window due to different pad "
                  "values.";
      return OkStatus();
    }
  }

  // If the pad puts a single non-identity value in each window that we're
  // reducing, then this is a broadcast.
  HloInstruction* pad_operand = pad->mutable_operand(0);
  auto is_effective_broadcast = [&] {
    if (window_util::HasStride(window)) {
      VLOG(10) << "Window has stride.";
      return false;
    }
    if (!window_util::HasSymmetricPadding(pad_config)) {
      VLOG(10) << "Window has uneven padding.";
      return false;
    }
    if (HasInteriorPadding(pad_config)) {
      VLOG(10) << "Window has interior padding.";
      return false;
    }
    for (int64_t i = 0; i < pad_config.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      if ((pad_dimension.edge_padding_low() != 0 ||
           pad_dimension.edge_padding_high() != 0) &&
          pad_operand->shape().dimensions(i) != 1) {
        VLOG(10) << "Found non-trivial dimension being padded: " << i;
        return false;
      }
    }
    VLOG(10) << "Found to be padding trivial dimensions only.";

    for (int64_t i = 0; i < window.dimensions_size(); ++i) {
      const auto& pad_dimension = pad_config.dimensions(i);
      const WindowDimension& window_dimension = window.dimensions(i);
      bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 ||
                                    pad_dimension.edge_padding_high() != 0);
      if (dimension_has_padding &&
          window_dimension.size() < pad_dimension.edge_padding_low() + 1) {
        VLOG(10) << "Found window did not cover single unpadded element in "
                    "dimension: "
                 << i;
        return false;
      }
      if (pad_operand->shape().dimensions(i) != 1 &&
          window_dimension.size() != 1) {
        VLOG(10) << "Found window covers more than one element in non-trivial "
                    "dimension: "
                 << i;
        return false;
      }
    }
    VLOG(10) << "Found window covers a single unpadded element.";
    return true;
  };

  HloInstruction* new_reduce_window_operand;
  if (convert != nullptr) {
    Shape changed_shape = ShapeUtil::ChangeElementType(
        pad_operand->shape(), convert->shape().element_type());
    simplifier_->UpdateLayout(&changed_shape);
    new_reduce_window_operand = hlo->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, pad_operand));
  } else {
    new_reduce_window_operand = pad_operand;
  }

  if (is_effective_broadcast()) {
    VLOG(10) << "Replacing pad/reduce-window with broadcast.";
    auto fadd = [hlo](std::unique_ptr<HloInstruction> x) {
      return hlo->AddInstruction(std::move(x));
    };
    return ReplaceWithNewInstruction(
        reduce_window, HloInstruction::CreateBroadcastSequence(
                           /*output_shape=*/reduce_window->shape(),
                           /*operand=*/new_reduce_window_operand, fadd));
  }

  // Carry out the folding of the pad into reduce_window.
  VLOG(10) << "Folding pad into reduce-window.";
  Window new_window = window;
  const int64_t rank = reduce_window->shape().rank();
  TF_RET_CHECK(pad_config.dimensions_size() == rank);
  TF_RET_CHECK(window.dimensions_size() == rank);
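  // Each unit of the pad's edge padding becomes window padding, scaled by the
  // base dilation factor since window padding applies in the base-dilated
  // coordinate space. Interior padding turns into a base dilation of
  // (1 + interior), which is why the two were required not to mix above.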
  for (int64_t i = 0; i < rank; ++i) {
    const auto& pad_dim = pad_config.dimensions(i);
    auto& window_dim = *new_window.mutable_dimensions(i);
    window_dim.set_padding_low(window_dim.padding_low() +
                               window_dim.base_dilation() *
                                   pad_dim.edge_padding_low());
    window_dim.set_padding_high(window_dim.padding_high() +
                                window_dim.base_dilation() *
                                    pad_dim.edge_padding_high());
    if (pad_dim.interior_padding() != 0) {
      CHECK_EQ(window_dim.base_dilation(), 1);
      window_dim.set_base_dilation(1 + pad_dim.interior_padding());
    }
  }

  return ReplaceWithNewInstruction(
      reduce_window, HloInstruction::CreateReduceWindow(
                         /*shape=*/reduce_window->shape(),
                         /*operand=*/new_reduce_window_operand,
                         /*init_value=*/reduce_window->mutable_operand(1),
                         /*window=*/new_window,
                         /*reduce_computation=*/function));
}

Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) {
  // select(x, y, y) -> y.
  if (select->operand(1) == select->operand(2)) {
    return ReplaceInstruction(select, select->mutable_operand(1));
  }
  // select(true, x, y) -> x.
  if (IsAll(select->operand(0), true)) {
    return ReplaceInstruction(select, select->mutable_operand(1));
  }
  // select(false, x, y) -> y.
  if (IsAll(select->operand(0), false)) {
    return ReplaceInstruction(select, select->mutable_operand(2));
  }
  // select(not(pred), a, b) -> select(pred, b, a)
  if (HloOpcode::kNot == select->operand(0)->opcode()) {
    auto pred_operand = select->mutable_operand(0)->mutable_operand(0);
    auto on_true = select->mutable_operand(1);
    auto on_false = select->mutable_operand(2);
    return ReplaceWithNewInstruction(
        select,
        HloInstruction::CreateTernary(select->shape(), HloOpcode::kSelect,
                                      pred_operand, on_false, on_true));
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleScatter(HloInstruction* hlo) {
  auto* scatter = Cast<HloScatterInstruction>(hlo);

  if (absl::c_all_of(scatter->scatter_updates(),
                     [](const HloInstruction* updates) {
                       return ShapeUtil::IsZeroElementArray(updates->shape());
                     }) &&
      ReplaceInstructionIfCompatible(scatter, scatter->scatter_operands())) {
    return OkStatus();
  }
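  // If the indices are a zero-element array and the single update has the
  // same shape as the operand, the one (empty-index) update window covers the
  // whole operand, so the scatter amounts to combining the operand and the
  // update elementwise with the update computation.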
  if (scatter->scatter_operand_count() == 1 &&
      ShapeUtil::IsZeroElementArray(scatter->scatter_indices()->shape()) &&
      SameShape(scatter, scatter->scatter_operands()[0]) &&
      SameShape(scatter, scatter->scatter_updates()[0])) {
    return ReplaceWithNewInstruction(
        scatter, HloInstruction::CreateMap(scatter->shape(),
                                           {scatter->scatter_operands()[0],
                                            scatter->scatter_updates()[0]},
                                           scatter->to_apply()));
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
  auto operand = sort->mutable_operand(0);
  int64_t dimension_to_sort = sort->dimensions(0);
  if (ShapeUtil::IsZeroElementArray(operand->shape()) ||
      operand->shape().dimensions(dimension_to_sort) <= 1) {
    if (sort->operand_count() == 1) {
      return ReplaceInstruction(sort, operand);
    }
    // If it is key/value sort, the output of sort is a tuple.
    return ReplaceWithNewInstruction(
        sort, HloInstruction::CreateTuple(sort->operands()));
  }
  return OkStatus();
}

Status AlgebraicSimplifierVisitor::HandleSqrt(HloInstruction* sqrt) {
  VLOG(10) << "trying transform [sqrt(A*A) => |A|] " << sqrt->ToString();
  HloInstruction* sqrt_operand = sqrt->mutable_operand(0);
  if (sqrt_operand->opcode() == HloOpcode::kMultiply &&
      sqrt_operand->operand(0) == sqrt_operand->operand(1)) {
    return ReplaceWithNewInstruction(
        sqrt, HloInstruction::CreateUnary(
                  sqrt_operand->mutable_operand(0)->shape(), HloOpcode::kAbs,
                  sqrt_operand->mutable_operand(0)));
  }
  return OkStatus();
}

namespace {
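// Returns true if the permutation moves only degenerate (size-1) dimensions:
// there is at least one size-1 dimension, and after dropping the size-1
// dimensions the remaining dimensions stay in increasing order, so the
// transpose amounts to a reshape.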
bool OnlyPermutesDegenerateDims(const Shape& shape,
                                absl::Span<const int64_t> perm) {
  std::vector<int64_t> new_permutation;
  int64_t degenerate_count = 0;
  for (int64_t i = 0; i < perm.size(); ++i) {
    if (shape.dimensions(i) != 1) {
      new_permutation.push_back(perm[i]);
    } else {
      ++degenerate_count;
    }
  }
  return degenerate_count > 0 && absl::c_is_sorted(new_permutation);
}

bool IsPermutationOfIota(absl::Span<const int64_t> elems) {
  DimensionVector sorted(elems.begin(), elems.end());
  absl::c_sort(sorted);
  for (int i = 0; i < sorted.size(); i++) {
    if (sorted[i] != i) {
      return false;
    }
  }
  return true;
}

}  // namespace

Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
  auto operand = transpose->mutable_operand(0);
  if (std::is_sorted(transpose->dimensions().begin(),
                     transpose->dimensions().end())) {
    VLOG(10) << "deleting no-op transpose";
    return ReplaceInstruction(transpose, operand);
  }

  if (HloOpcode::kTranspose == operand->opcode()) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateTranspose(
                       transpose->shape(), operand->mutable_operand(0),
                       ComposePermutations(operand->dimensions(),
                                           transpose->dimensions())));
  }

  // Convert transpose(dot(a,b)) to dot(b,a).
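  // For a canonical dot with batch dims leading and one non-contracting dim
  // per operand, transposing the two trailing output dims of dot(a, b) equals
  // dot(b, a) with the lhs/rhs contracting dims swapped. For example (a
  // sketch): transpose(dot(f32[8,2,3], f32[8,3,5]), {0,2,1}) matches the dot
  // of the swapped operands producing f32[8,5,2].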
  auto do_transpose_of_dot = [&]() -> StatusOr<bool> {
    if (operand->opcode() != HloOpcode::kDot || operand->user_count() != 1) {
      return false;
    }
    HloInstruction* dot = operand;
    HloInstruction* lhs = dot->mutable_operand(0);
    HloInstruction* rhs = dot->mutable_operand(1);

    const int64_t rank = dot->shape().rank();
    const auto& dnums = dot->dot_dimension_numbers();

    // Dot must be "somewhat canonical": batch dimensions at the beginning
    // and one non-contracting dim. It's the responsibility of DotDecomposer
    // to canonicalize dots.
    if (absl::MakeSpan(dnums.lhs_batch_dimensions()) !=
            absl::MakeSpan(dnums.rhs_batch_dimensions()) ||
        !IsPermutationOfIota(dnums.lhs_batch_dimensions()) ||
        dnums.lhs_contracting_dimensions_size() == 0 ||
        dnums.lhs_contracting_dimensions_size() +
                dnums.lhs_batch_dimensions_size() + 1 !=
            lhs->shape().rank() ||
        dnums.rhs_contracting_dimensions_size() == 0 ||
        dnums.rhs_contracting_dimensions_size() +
                dnums.rhs_batch_dimensions_size() + 1 !=
            rhs->shape().rank()) {
      return false;
    }

    // Transpose must just be over the two last dims (i.e. the non-batch
    // dims).
    DimensionVector expected_perm(rank);
    absl::c_iota(expected_perm, 0);
    std::swap(expected_perm.rbegin()[0], expected_perm.rbegin()[1]);
    if (transpose->dimensions() != expected_perm) {
      return false;
    }

    DotDimensionNumbers new_dnums = dnums;
    std::swap(*new_dnums.mutable_lhs_contracting_dimensions(),
              *new_dnums.mutable_rhs_contracting_dimensions());
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
        transpose,
        HloInstruction::CreateDot(
            transpose->shape(), /*lhs=*/rhs, /*rhs=*/lhs, new_dnums,
            SwapOperandsInDotPrecisionConfig(dot->precision_config()))));
    return true;
  };
  TF_ASSIGN_OR_RETURN(bool did_transpose_of_dot, do_transpose_of_dot());
  if (did_transpose_of_dot) {
    return OkStatus();
  }

  // Replace the transpose with a reshape if it only permutes degenerate
  // (size-1) dimensions.
  if (OnlyPermutesDegenerateDims(transpose->shape(), transpose->dimensions())) {
    return ReplaceWithNewInstruction(
        transpose, HloInstruction::CreateReshape(
                       transpose->shape(), transpose->mutable_operand(0)));
  }

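  // The elements produced by an rng are i.i.d., so an rng of the transposed
  // shape is statistically equivalent to transposing the rng's output; with a
  // single user, the transpose can be folded into the rng's shape.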
  if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
    *operand->mutable_shape() = transpose->shape();
    return ReplaceInstruction(transpose, operand);
  }

  if (options_.is_layout_sensitive() && TransposeIsBitcast(transpose)) {
    ReplaceWithBitcast(transpose);
    return OkStatus();
  }

  // Replace reshape of a transpose of a reshape with concatenated slicing if
  // the reshape/transpose combination can be interpreted as a space-to-depth
  // transformation.
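  // For example (a sketch): reshape(f32[B,H,2W] -> f32[B,H,2,W]), transpose
  // to f32[B,2,H,W], reshape to f32[2B,H,W] can be rewritten by slicing the
  // original f32[B,H,2W] into two f32[B,H,W] halves, concatenating them along
  // dimension 1, and letting the outer reshape reinterpret the result.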
  if (operand->opcode() == HloOpcode::kReshape &&
      transpose->user_count() == 1 &&
      HloOpcode::kReshape == transpose->users()[0]->opcode()) {
    VLOG(2) << "trying depth-to-space transform";
    HloInstruction* reshape_operand = operand->mutable_operand(0);
    HloInstruction* outer_reshape = transpose->users()[0];
    TF_ASSIGN_OR_RETURN(
        bool did_transform, ([&]() -> StatusOr<bool> {
          if (operand->shape().dimensions_size() !=
              reshape_operand->shape().dimensions_size() + 1) {
            return false;
          }

          // Check that the reshape is splitting a single dimension into two.
          int64_t split_dim = 0;
          bool found_split_dims = false;
          for (int64_t dim = 0; dim < reshape_operand->shape().rank(); dim++) {
            if (operand->shape().dimensions(dim) !=
                reshape_operand->shape().dimensions(dim)) {
              const int64_t expected_size =
                  operand->shape().dimensions(dim) *
                  operand->shape().dimensions(dim + 1);
              if (reshape_operand->shape().dimensions(dim) == expected_size) {
                split_dim = dim;
                found_split_dims = true;
                break;
              }
              return false;
            }
          }
          if (!found_split_dims) {
            return false;
          }
          for (int64_t dim = split_dim + 1;
               dim < reshape_operand->shape().rank(); dim++) {
            if (operand->shape().dimensions(dim + 1) !=
                reshape_operand->shape().dimensions(dim)) {
              return false;
            }
          }

          const int64_t num_chunks = operand->shape().dimensions(split_dim);
          const int64_t chunk_size =
              operand->shape().dimensions(split_dim + 1);

          // This optimization is only beneficial for a small number of
          // chunks.
          // TODO(b/196832483): Determine the appropriate upper bound here.
          const int64_t kMaxChunksForTransformation = 5;
          if (num_chunks > kMaxChunksForTransformation) {
            return false;
          }

          // Determine where the smaller split dimension is being placed in
          // the transpose.
          int64_t transpose_dim = 0;
          bool found_transpose_dim = false;
          for (int64_t dim = 0; dim < operand->shape().rank(); dim++) {
            if (transpose->dimensions(dim) == split_dim) {
              transpose_dim = dim;
              found_transpose_dim = true;
              break;
            }
          }

          // Check that only the small split dimension is reordered in the
          // transpose.
          if (!found_transpose_dim || transpose_dim == split_dim ||
              transpose_dim == split_dim + 1) {
            return false;
          }
          for (int64_t dim = 0; dim < operand->shape().rank(); dim++) {
            int64_t offset = 0;
            if (dim > transpose_dim) {
              offset--;
            }
            if (dim > split_dim) {
              offset++;
            }

            if (dim != transpose_dim &&
                transpose->dimensions(dim) != dim + offset) {
              return false;
            }
          }

          // Check that the outer reshape has the same shape as the input,
          // with the transformed dimensions appropriately scaled by
          // num_chunks.
          for (int64_t dim = 0; dim < reshape_operand->shape().rank();
               dim++) {
            if (dim == transpose_dim - 1) {
              if (outer_reshape->shape().dimensions(dim) !=
                  reshape_operand->shape().dimensions(dim) * num_chunks) {
                return false;
              }
            } else if (dim == split_dim) {
              if (outer_reshape->shape().dimensions(dim) !=
                  reshape_operand->shape().dimensions(dim) / num_chunks) {
                return false;
              }
            } else if (outer_reshape->shape().dimensions(dim) !=
                       reshape_operand->shape().dimensions(dim)) {
              return false;
            }
          }

          // Create a concat-of-slices, slicing to create chunks of the
          // expected size on the smaller split dimension.
          std::vector<HloInstruction*> slices;
          for (int64_t i = 0; i < num_chunks; i++) {
            std::vector<int64_t> start_indices;
            std::vector<int64_t> end_indices;
            std::vector<int64_t> strides;
            const auto rank = reshape_operand->shape().rank();
            start_indices.reserve(rank);
            end_indices.reserve(rank);
            strides.reserve(rank);
            for (int64_t dim = 0; dim < rank; dim++) {
              if (dim == split_dim) {
                start_indices.push_back(i * chunk_size);
                end_indices.push_back(i * chunk_size + chunk_size);
              } else {
                start_indices.push_back(0);
                end_indices.push_back(
                    reshape_operand->shape().dimensions(dim));
              }
              strides.push_back(1);
            }
            TF_ASSIGN_OR_RETURN(HloInstruction* const slice,
                                MakeSliceHlo(reshape_operand, start_indices,
                                             end_indices, strides));
            slices.push_back(slice);
            VLOG(2) << "slice " << i << " " << slice->ToString();
          }

          TF_ASSIGN_OR_RETURN(HloInstruction* const concat,
                              MakeConcatHlo(slices, transpose_dim));
          VLOG(2) << "concat " << concat->ToString();
          TF_RETURN_IF_ERROR(
              outer_reshape->ReplaceOperandWithDifferentShape(0, concat));

          return true;
        }()));
    if (did_transform) {
      MarkAsChanged();
      return OkStatus();
    }
  }

  return OkStatus();
}

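// Tries to fold a kPad of zeros on the convolution's input into the
// convolution's window: edge padding adds to the window padding and interior
// padding becomes base dilation. Bails when the two kinds of padding would
// have to compose with each other.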
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvInputPad(
    HloInstruction* convolution) {
  HloInstruction *lhs, *a, *b;
  if (Match(convolution,
            m::Convolution(m::Pad(&lhs, m::Op(&a), m::ConstantScalar(0)),
                           m::Op(&b)))) {
    const auto& window = convolution->window();
    const ConvolutionDimensionNumbers& dnums =
        convolution->convolution_dimension_numbers();

    const auto& padding = lhs->padding_config();

    // Can't pad batch or feature dims.
    for (int64_t dim :
         {dnums.input_batch_dimension(), dnums.input_feature_dimension()}) {
      const auto& p = padding.dimensions(dim);
      if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
          p.interior_padding() != 0) {
        return false;
      }
    }

    // Compute the window which is the result of merging the kPad and the
    // convolution's existing window.
    Window new_window = window;
    for (int64_t dim = 0; dim < dnums.input_spatial_dimensions_size(); ++dim) {
      auto& w = *new_window.mutable_dimensions(dim);
      const auto& p = padding.dimensions(dnums.input_spatial_dimensions(dim));
      // Edge padding composes with itself in the straightforward way, but
      // composing interior padding is nontrivial, and we cowardly refuse to
      // think about it. If we see interior padding in either the kPad or
      // conv, bail if there's any sort of padding in the other.
      if (p.interior_padding() != 0 &&
          (w.padding_low() != 0 || w.padding_high() != 0 ||
           w.base_dilation() != 1)) {
        return false;
      }
      if (w.base_dilation() != 1 &&
          (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
           p.interior_padding() != 0)) {
        return false;
      }

      w.set_padding_low(w.padding_low() + p.edge_padding_low());
      w.set_padding_high(w.padding_high() + p.edge_padding_high());
      if (p.interior_padding() != 0) {
        CHECK_EQ(w.base_dilation(), 1);
        w.set_base_dilation(1 + p.interior_padding());
      }
    }

    auto new_conv =
        convolution->CloneWithNewOperands(convolution->shape(), {a, b});
    new_conv->set_window(new_window);
    TF_RETURN_IF_ERROR(
        ReplaceWithNewInstruction(convolution, std::move(new_conv)));
    return true;
  }
  return false;
}

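// Tries to fold a kPad of zeros on the convolution's filter into the
// convolution's window: interior padding of the kernel's spatial dimensions
// becomes window dilation. Edge padding on the filter cannot be expressed by
// a convolution window, so it blocks the fold.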
StatusOr<bool> AlgebraicSimplifierVisitor::FoldConvFilterPad(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (rhs->opcode() != HloOpcode::kPad) {
    return false;
  }

  // Convolution's padding is always zero, so bail if the kPad is adding
  // something other than zero.
  if (!IsAll(rhs->operand(1), 0)) {
    return false;
  }

  const auto& padding = rhs->padding_config();

  // Can't pad or dilate feature dims.
  for (int64_t dim : {dnums.kernel_input_feature_dimension(),
                      dnums.kernel_output_feature_dimension()}) {
    const auto& p = padding.dimensions(dim);
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0 ||
        p.interior_padding() != 0) {
      return false;
    }
  }

  // Compute the window which is the result of merging the kPad and the
  // convolution's existing window.
  Window new_window = convolution->window();
  for (int64_t dim = 0; dim < dnums.kernel_spatial_dimensions_size(); ++dim) {
    auto& w = *new_window.mutable_dimensions(dim);
    const auto& p = padding.dimensions(dnums.kernel_spatial_dimensions(dim));

    // We can only do this transformation if p adds dilation to the filter --
    // edge padding on the filter is not supported in conv.
    if (p.edge_padding_low() != 0 || p.edge_padding_high() != 0) {
      return false;
    }

    // Nothing to do if the kPad for this dim is entirely a nop.
    if (p.interior_padding() == 0) {
      continue;
    }

    // We cowardly refuse to think about how dilation composes with itself;
    // bail if both the kPad and conv have dilation on this dimension.
    if (w.window_dilation() > 1) {
      return false;
    }
    CHECK_EQ(w.window_dilation(), 1);
    w.set_window_dilation(1 + p.interior_padding());
    w.set_size(rhs->operand(0)->shape().dimensions(
        dnums.kernel_spatial_dimensions(dim)));
  }

  auto new_conv = convolution->CloneWithNewOperands(
      convolution->shape(), {lhs, rhs->mutable_operand(0)});
  new_conv->set_window(new_window);
  TF_RETURN_IF_ERROR(
      ReplaceWithNewInstruction(convolution, std::move(new_conv)));
  return true;
}

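// Tries to swap the activations and the kernel of a convolution when a naive
// implementation of the swapped convolution would use fewer flops. The swap
// mirrors the window: base dilation and window dilation trade places, the
// kernel is reversed where needed, and padding is recomputed against the
// dilated sizes.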
StatusOr<bool> AlgebraicSimplifierVisitor::SwapConvOperands(
    HloInstruction* convolution) {
  if (!options_.enable_conv_operand_swap() || options_.is_layout_sensitive()) {
    return false;
  }
  if (convolution->feature_group_count() > 1 ||
      convolution->batch_group_count() > 1) {
    return false;
  }

  const auto& dnums = convolution->convolution_dimension_numbers();
  const auto& window_dims = convolution->window().dimensions();
  Window swapped_window;

  HloInstruction *input = convolution->mutable_operand(0),
                 *kernel = convolution->mutable_operand(1);
  int64_t kernel_product = 1;
  int64_t swapped_kernel_product = 1;
  DimensionVector reverse_dimensions;
  for (int64_t spatial_dim = 0;
       spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) {
    const int64_t kernel_size = window_dims[spatial_dim].size();
    const bool can_be_group_or_contraction =
        !window_dims[spatial_dim].window_reversal() &&
        window_dims[spatial_dim].padding_low() == 0 &&
        window_dims[spatial_dim].padding_high() == 0 &&
        window_dims[spatial_dim].window_dilation() == 1;
    const bool is_group_dim =
        can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == kernel_size &&
        window_dims[spatial_dim].stride() == kernel_size - 1;
    const int64_t input_size =
        input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim));
    const bool is_pure_contraction_dim =
        kernel_size == input_size && can_be_group_or_contraction &&
        window_dims[spatial_dim].base_dilation() == 1 &&
        window_dims[spatial_dim].stride() == 1;
    if (is_group_dim || is_pure_contraction_dim) {
      *(swapped_window.add_dimensions()) = window_dims[spatial_dim];
      continue;
    }

    const int64_t dilated_kernel_size =
        1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation();
    const int64_t dilated_input_size =
        1 + (input_size - 1) * window_dims[spatial_dim].base_dilation();

    // Don't decide to swap if the input size is one, since many convolution
    // implementations can handle that special case efficiently.
    kernel_product *= kernel_size;
    swapped_kernel_product *=
        input_size == 1 && window_dims[spatial_dim].stride() == 1 &&
                window_dims[spatial_dim].window_dilation() == 1 &&
                window_dims[spatial_dim].padding_high() == kernel_size - 1 &&
                window_dims[spatial_dim].padding_low() == kernel_size - 1
            ? kernel_size
            : input_size;

    auto new_dim = swapped_window.add_dimensions();
    new_dim->set_size(input_size);
    // If the kernel is not reversed, the activations must be manually
    // reversed.
    if (!window_dims[spatial_dim].window_reversal()) {
      reverse_dimensions.push_back(
          dnums.kernel_spatial_dimensions(spatial_dim));
    }
    // The input is not originally reversed, so it must be reversed to move
    // the kernel.
    new_dim->set_window_reversal(true);
    // Base dilation and window dilation switch places.
    new_dim->set_base_dilation(window_dims[spatial_dim].window_dilation());
    new_dim->set_window_dilation(window_dims[spatial_dim].base_dilation());
    new_dim->set_stride(window_dims[spatial_dim].stride());
    new_dim->set_padding_low(dilated_input_size +
                             window_dims[spatial_dim].padding_low() -
                             dilated_kernel_size);
    new_dim->set_padding_high(dilated_input_size +
                              window_dims[spatial_dim].padding_high() -
                              dilated_kernel_size);
  }

  // Don't transform if a naive convolution implementation would not have
  // fewer flops.
  if (kernel_product <= swapped_kernel_product) {
    return false;
  }
  ConvolutionDimensionNumbers swapped_dnums;
  *swapped_dnums.mutable_output_spatial_dimensions() =
      dnums.output_spatial_dimensions();
  // Swap batch and output feature of the output.
  swapped_dnums.set_output_batch_dimension(dnums.output_feature_dimension());
  swapped_dnums.set_output_feature_dimension(dnums.output_batch_dimension());

  // Swap input dnums with kernel dnums.
  *swapped_dnums.mutable_input_spatial_dimensions() =
      dnums.kernel_spatial_dimensions();
  swapped_dnums.set_input_batch_dimension(
      dnums.kernel_output_feature_dimension());
  swapped_dnums.set_input_feature_dimension(
      dnums.kernel_input_feature_dimension());

  // Swap kernel dnums with input dnums.
  *swapped_dnums.mutable_kernel_spatial_dimensions() =
      dnums.input_spatial_dimensions();
  swapped_dnums.set_kernel_output_feature_dimension(
      dnums.input_batch_dimension());
  swapped_dnums.set_kernel_input_feature_dimension(
      dnums.input_feature_dimension());

  PrecisionConfig precision_config;
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(1));
  precision_config.add_operand_precision(
      convolution->precision_config().operand_precision(0));
  if (!reverse_dimensions.empty()) {
    TF_ASSIGN_OR_RETURN(kernel, MakeReverseHlo(kernel, reverse_dimensions));
  }
  TF_ASSIGN_OR_RETURN(
      HloInstruction * new_convolution,
      MakeConvolveHlo(
          kernel, input, /*feature_group_count=*/1,
          /*batch_group_count=*/1, swapped_window, swapped_dnums,
          precision_config,
          /*preferred_element_type=*/convolution->shape().element_type()));

  // If we're running on GPU, we need to check that we can actually lower the
  // conv with the given reverse_dims (either none, or rank 2 and all).
  if (!options_.ConvIsLowerable(new_convolution)) {
    TF_RETURN_IF_ERROR(kernel->parent()->RemoveInstruction(new_convolution));
    return false;
  }

  convolution->SetupDerivedInstruction(new_convolution);
  TF_RETURN_IF_ERROR(ReplaceInstruction(convolution, new_convolution));

  return true;
}

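// Rewrites a convolution whose kernel spatial dimensions are all 1 (and which
// has no stride, padding, or base dilation) as a dot: with compatible
// layouts, the data is already laid out as a row-major matmul of
// {product of non-feature dims, input channels} x
// {input channels, output channels}, so only bitcasts are needed around the
// dot.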
StatusOr<bool> AlgebraicSimplifierVisitor::SimplifyConvToDot(
    HloInstruction* convolution) {
  auto* lhs = convolution->mutable_operand(0);
  auto* rhs = convolution->mutable_operand(1);
  const auto& window = convolution->window();
  const ConvolutionDimensionNumbers& dnums =
      convolution->convolution_dimension_numbers();

  if (!options_.enable_conv_simplification()) {
    return false;
  }

  // TODO(b/31337498): For now, we cowardly refuse to do this optimization in
  // layout-insensitive mode, for fear of adding nontrivial reshapes.
  if (!options_.is_layout_sensitive()) {
    return false;
  }

  const Shape& input_shape = lhs->shape();
  const Shape& filter_shape = rhs->shape();
  const Shape& convolution_shape = convolution->shape();
  TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
  TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));

  // Require the spatial dimensions in the kernel to have a bound of one.
  for (int64_t i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) {
    if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(i)) != 1) {
      return false;
    }
  }

  // Stride ignores part of the output, which matrix multiplication does not
  // do, so require no stride. Padding and base (lhs) dilation both implicitly
  // extend the data, which matrix multiplication also does not do, so require
  // no padding and no base (lhs) dilation. Window (rhs) dilation has no
  // effect for a 1x1 window, so window dilation is no problem.
  if (window_util::HasStride(window) || window_util::HasPadding(window) ||
      window_util::HasBaseDilation(window)) {
    return false;
  }

  // Also, the shapes must align for a rowmajor matmul:
  // - the input and output have the same layout.
  // - for input/output, the channel dimension must be the most minor. Other
  //   spatial dims can be in any order.
  // - for filters, the input channel dimension must be more major than the
  //   output channel dimension. The width+height don't matter because
  //   they are 1.
  //
  // These constraints are harsh. If the channel dimension is the most major
  // and/or the layout of input/output feature dimensions are reversed, we can
  // still convert Conv into more efficient Matmul with operand transposition
  // (such as the transposition flags in cuBLAS SGEMM).
  if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
      LayoutUtil::Minor(input_shape.layout(), 0) !=
          dnums.input_feature_dimension() ||
      LayoutUtil::Minor(convolution_shape.layout(), 0) !=
          dnums.output_feature_dimension() ||
      // The input feature dimension should come later in the minor-to-major
      // order.
      (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_input_feature_dimension()) <
       PositionInContainer(LayoutUtil::MinorToMajor(filter_shape),
                           dnums.kernel_output_feature_dimension()))) {
    return false;
  }

  if (convolution->feature_group_count() != 1 ||
      convolution->batch_group_count() != 1) {
    return false;
  }
  auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
    return operand->AddInstruction(
        HloInstruction::CreateBitcast(shape, operand));
  };

  // Replace it with a dot, with bitcasts around it to get the right shape.
  const int64_t input_channels =
      input_shape.dimensions(dnums.input_feature_dimension());
  const int64_t output_channels =
      filter_shape.dimensions(dnums.kernel_output_feature_dimension());

  // Computes the product of the non-feature dimensions.
  int64_t conv_width = 1;
  for (int i = 0; i < input_shape.dimensions_size(); ++i) {
    if (i != dnums.input_feature_dimension()) {
      conv_width *= input_shape.dimensions(i);
    }
  }

  // We already checked feature_dimension is most minor, so data in
  // input_shape and row-major {conv_width,input_channels} are bitwise
  // identical.
  Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      input_shape.element_type(), {conv_width, input_channels});
  simplifier_->UpdateLayout(&new_input_shape);
  // We already checked input_feature_dimension is more major than
  // output_feature_dimension, so data in filter_shape and row-major
  // {input_channels,output_channels} are bitwise identical.
  Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      filter_shape.element_type(), {input_channels, output_channels});
  simplifier_->UpdateLayout(&new_filter_shape);
  Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      convolution_shape.element_type(), {conv_width, output_channels});
  simplifier_->UpdateLayout(&dot_output_shape);

  auto new_lhs = add_bitcast(new_input_shape, lhs);
  auto new_rhs = add_bitcast(new_filter_shape, rhs);
  DotDimensionNumbers dot_dimension_numbers;
  dot_dimension_numbers.add_lhs_contracting_dimensions(1);
  dot_dimension_numbers.add_rhs_contracting_dimensions(0);
  auto dot = convolution->AddInstruction(HloInstruction::CreateDot(
      dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers,
      convolution->precision_config()));

  TF_RETURN_IF_ERROR(
      ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)));
  return true;
}

Status AlgebraicSimplifierVisitor::HandleConvolution(
    HloInstruction* convolution) {
  if (options_.enable_scalar_multiply_reduction()) {
    TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution));
  }

  // Zero-sized input or filter.
  if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) ||
      ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) {
    return ReplaceInstruction(convolution, MakeScalarLike(convolution, 0));
  }

  // Try to merge padding/dilation of the input with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_input_pad, FoldConvInputPad(convolution));
  if (folded_input_pad) {
    return OkStatus();
  }

  // Try to merge dilation of the filter with the convolution's window.
  TF_ASSIGN_OR_RETURN(bool folded_filter_pad, FoldConvFilterPad(convolution));
  if (folded_filter_pad) {
    return OkStatus();
  }

  // Try to swap convolution operands.
  TF_ASSIGN_OR_RETURN(bool swapped, SwapConvOperands(convolution));
  if (swapped) {
    return OkStatus();
  }
  // Try to replace the convolution with a kDot instruction.
  TF_ASSIGN_OR_RETURN(bool replaced_with_dot, SimplifyConvToDot(convolution));
  if (replaced_with_dot) {
    return OkStatus();
  }

  return OkStatus();
}

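// Simplifies trivial maps: a map whose computation returns a parameter
// forwards the corresponding operand; a scalar-constant computation becomes a
// (broadcast) constant; and a computation that is a single elementwise op
// over parameters is inlined as that op over the map's operands.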
Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) {
  auto* map_computation = map->to_apply();
  auto* map_root = map_computation->root_instruction();
  if (map_root->opcode() == HloOpcode::kParameter) {
    ReplaceInstructionIfCompatible(
        map, map->mutable_operand(map_root->parameter_number()));
    return OkStatus();
  }
  if (map_root->opcode() == HloOpcode::kConstant) {
    if (!ShapeUtil::IsScalar(map_root->shape())) {
      return OkStatus();
    }
    auto clone = map_root->CloneWithNewOperands(map_root->shape(), {});
    if (ShapeUtil::IsScalar(map->shape())) {
      return ReplaceWithNewInstruction(map, std::move(clone));
    }
    return ReplaceWithNewInstruction(
        map, HloInstruction::CreateBroadcast(
                 map->shape(), map->AddInstruction(std::move(clone)), {}));
  }
  // Inline the map if the map computation only contains an elementwise
  // operation that can accept arbitrary shapes.
  if (map_root->opcode() == HloOpcode::kFusion || !map_root->IsElementwise()) {
    return OkStatus();
  }
  std::vector<HloInstruction*> new_operands;
  for (auto* root_operand : map_root->operands()) {
    if (root_operand->opcode() != HloOpcode::kParameter) {
      return OkStatus();
    }
    new_operands.push_back(
        map->mutable_operand(root_operand->parameter_number()));
  }
  auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands);
  return ReplaceWithNewInstruction(map, std::move(clone));
}

StatusOr<bool> AlgebraicSimplifier::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  AlgebraicSimplifierVisitor visitor(options_, this);
  for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
    if (visitor.Run(comp, options_, this)) {
      changed = true;
    }
  }
  return changed;
}

}  // namespace xla