/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_

#include <cstdint>
#include <functional>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

class AlgebraicSimplifierOptions {
 public:
  // Platform dependent callback to determine if a reshape `from_shape` to
  // `to_shape` is a bitcast.
  using ReshapeIsBitcastCallback =
      std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
  // Platform dependent callback to determine if a set of reverse dimensions
  // is lowerable.
  using ConvIsLowerableCallback = std::function<bool(HloInstruction* window)>;
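
  // Both callbacks are optional. An illustrative sketch of supplying them
  // (the lambda bodies below are hypothetical placeholders, not the defaults
  // of any particular backend):
  //
  //   AlgebraicSimplifierOptions options(
  //       /*reshape_is_bitcast_callback=*/
  //       [](const Shape& from_shape, const Shape& to_shape) {
  //         return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape);
  //       },
  //       /*conv_is_lowerable_callback=*/
  //       [](HloInstruction* conv) { return true; });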

  explicit AlgebraicSimplifierOptions(
      ReshapeIsBitcastCallback reshape_is_bitcast_callback = {},
      ConvIsLowerableCallback conv_is_lowerable_callback = {})
      : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)),
        conv_is_lowerable_callback_(std::move(conv_is_lowerable_callback)) {}

  // Use the platform specific callback if set. It is not sensible to return
  // true here if the options are not layout sensitive.
  bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const {
    if (!is_layout_sensitive_) {
      return false;
    }
    if (!reshape_is_bitcast_callback_) {
      return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape);
    }
    return reshape_is_bitcast_callback_(from_shape, to_shape);
  }

  // Use the platform specific callback if set. Otherwise, return true.
  bool ConvIsLowerable(HloInstruction* reverse_dims) const {
    if (!conv_is_lowerable_callback_) {
      return true;
    }
    return conv_is_lowerable_callback_(reverse_dims);
  }

  // If is_layout_sensitive is true, then the simplifier preserves layout during
  // transformation. Otherwise, layout is ignored.
  void set_is_layout_sensitive(bool is_layout_sensitive) {
    is_layout_sensitive_ = is_layout_sensitive;
  }

  bool is_layout_sensitive() const { return is_layout_sensitive_; }

  // Enable dot simplification on platforms where it is profitable.
  void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) {
    enable_dot_strength_reduction_ = enable_dot_strength_reduction;
  }

  bool enable_dot_strength_reduction() const {
    return enable_dot_strength_reduction_;
  }
  // Enable the dot->multiply rewrite for dots that are outer products.
  void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) {
    enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite;
  }

  bool enable_dot_to_multiply_rewrite() const {
    return enable_dot_to_multiply_rewrite_;
  }

  // Enable convolution simplification on platforms where it is profitable.
  void set_enable_conv_simplification(bool enable_conv_simplification) {
    enable_conv_simplification_ = enable_conv_simplification;
  }
  bool enable_conv_simplification() const {
    return enable_conv_simplification_;
  }

  // Enable convolution operand swapping on platforms where it is supported.
  void set_enable_conv_operand_swap(bool enable_conv_operand_swap) {
    enable_conv_operand_swap_ = enable_conv_operand_swap;
  }
  bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }

  // Move a constant scalar multiply to one operand or the output of a
  // convolution, whichever has the smallest tensor size, to reduce the
  // number of scalar multiplies.
  void set_enable_scalar_multiply_reduction(
      bool enable_scalar_multiply_reduction) {
    enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction;
  }

  bool enable_scalar_multiply_reduction() const {
    return enable_scalar_multiply_reduction_;
  }

  // Allow the algebraic simplifier to treat floating point values like real
  // numbers.
  void set_enable_floats_are_real(bool enable_floats_are_real) {
    enable_floats_are_real_ = enable_floats_are_real;
  }

  bool enable_floats_are_real() const { return enable_floats_are_real_; }

  // If enable_window_reduce_to_reduce_replacement is true, the kReduceWindow
  // instruction can be optimized by replacement with simpler operations.
  void set_enable_window_reduce_to_reduce_replacement(
      bool enable_window_reduce_to_reduce_replacement) {
    enable_window_reduce_to_reduce_replacement_ =
        enable_window_reduce_to_reduce_replacement;
  }

  bool enable_window_reduce_to_reduce_replacement() const {
    return enable_window_reduce_to_reduce_replacement_;
  }

  // Sets the size of a gather operand that can be unrolled into many selects.
  void set_very_small_gather_size(int64_t size) {
    very_small_gather_size_ = size;
  }

  int64_t very_small_gather_size() const { return very_small_gather_size_; }

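  // Records the name of the cuDNN batchnorm forward-training custom call
  // (see the Metadata comment below); e.g.:
  //   options.set_cudnn_batchnorm_forward_training_metadata(
  //       "__cudnn$batchNormalizationForwardTraining");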
  void set_cudnn_batchnorm_forward_training_metadata(const std::string& c) {
    metadata_.cudnn_batchnorm_forward_training_metadata = c;
  }

  const std::string& get_cudnn_batchnorm_forward_training_metadata() const {
    return metadata_.cudnn_batchnorm_forward_training_metadata;
  }

  void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) {
    enable_reduce_of_reshape_ = enable_reduce_of_reshape;
  }

  bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }

  void set_enable_negative_padding_replacement(
      bool enable_negative_padding_replacement) {
    enable_negative_padding_replacement_ = enable_negative_padding_replacement;
  }

  bool enable_negative_padding_replacement() const {
    return enable_negative_padding_replacement_;
  }

  void set_enable_sink_broadcast(bool enable_sink_broadcast) {
    enable_sink_broadcast_ = enable_sink_broadcast;
  }

  bool enable_sink_broadcast() const { return enable_sink_broadcast_; }

  // If true, min(x, NaN) = NaN.  If false, min(x, NaN) = x.
  //
  // TODO(b/209827141): Remove this and make minmax_propagate_nan
  // unconditionally true.
  bool minmax_propagate_nan() const { return minmax_propagate_nan_; }
  void set_minmax_propagate_nan(bool val) { minmax_propagate_nan_ = val; }

 private:
  // The Metadata struct can be used to store any metadata information
  // encapsulated with the AlgebraicSimplifierOptions that can later be used
  // in an AlgebraicSimplifier pass. For example,
  // cudnn_batchnorm_forward_training_metadata can be used to store the name of
  // a custom call. If the custom call is
  // __cudnn$batchNormalizationForwardTraining, the output with index 2 is
  // guaranteed to be positive. This property has been used to recursively
  // determine if the operand of an instruction is always positive.
  struct Metadata {
    std::string cudnn_batchnorm_forward_training_metadata{""};
    Metadata() {}
  };
  ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
  ConvIsLowerableCallback conv_is_lowerable_callback_;
  bool is_layout_sensitive_{false};
  bool enable_dot_strength_reduction_{true};
  bool enable_dot_to_multiply_rewrite_{true};
  bool enable_conv_simplification_{true};
  bool enable_conv_operand_swap_{true};
  bool enable_scalar_multiply_reduction_{false};
  bool enable_floats_are_real_{false};
  bool enable_window_reduce_to_reduce_replacement_{true};
  bool enable_reduce_of_reshape_{true};
  bool enable_negative_padding_replacement_{true};
  bool enable_sink_broadcast_{true};
  int64_t very_small_gather_size_{4};
  bool minmax_propagate_nan_{true};
  Metadata metadata_;
};

// A pass which performs algebraic simplifications.
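//
// Typical usage (an illustrative sketch; assumes an existing HloModule*
// `module` and the TF_ASSIGN_OR_RETURN convenience macro):
//
//   AlgebraicSimplifierOptions options;
//   options.set_is_layout_sensitive(false);
//   AlgebraicSimplifier simplifier(options);
//   TF_ASSIGN_OR_RETURN(bool changed,
//                       simplifier.Run(module, /*execution_threads=*/{}));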
class AlgebraicSimplifier : public HloModulePass {
 public:
  // If is_layout_sensitive is true, then the simplifier preserves layout during
  // transformation. Otherwise, layout is ignored.
  explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options)
      : options_(options) {}
  ~AlgebraicSimplifier() override = default;
  absl::string_view name() const override { return "algsimp"; }

  // Run algebraic simplification on the given computation. Returns whether the
  // computation was changed.
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;

  // Create constant from literal with tiles and element size updated in the
  // constant's layout.
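  // For example (an illustrative sketch; assumes LiteralUtil is available to
  // the caller):
  //   auto zero = simplifier.CreateConstantWithLayoutUpdated(
  //       LiteralUtil::CreateR0<float>(0.0f));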
  std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated(
      Literal literal) {
    auto constant = HloInstruction::CreateConstant(std::move(literal));
    UpdateLayout(constant->mutable_shape());
    return constant;
  }

 protected:
  AlgebraicSimplifierOptions options_;
};

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case a worklist based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 public:
  explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options,
                                      AlgebraicSimplifier* simplifier)
      : options_(options), simplifier_(simplifier) {}

  Status HandleAbs(HloInstruction* abs) override;

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMaximum(HloInstruction* maximum) override;

  Status HandleMinimum(HloInstruction* minimum) override;

  Status HandleClamp(HloInstruction* clamp) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOptimizationBarrier(HloInstruction* barrier) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleReverse(HloInstruction* reverse) override;

  Status HandleRsqrt(HloInstruction* rsqrt) override;

  Status HandleSlice(HloInstruction* slice) override;

  Status HandleSqrt(HloInstruction* sqrt) override;

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleScatter(HloInstruction* hlo) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation.
  bool Run(HloComputation* computation,
           const AlgebraicSimplifierOptions& options,
           AlgebraicSimplifier* simplifier);

  // Compute a function that maps from bitcasted dimensions to the resulting
  // ones. Returns the function as a vector if successful; std::nullopt
  // otherwise.
  static std::optional<std::vector<std::vector<int64_t>>> ComputeBitcastDimMap(
      const Shape& bitcast_shape, const Shape& operand_shape);
  // Invert the directions of the given bitcast dimension map.
  static std::vector<std::vector<int64_t>> InvertBitcastDimMap(
      const Shape& original_shape, const Shape& bitcast_shape,
      const std::vector<std::vector<int64_t>>& original_map);

  // Modify the layout dimensions of result_shape, so that it becomes the
  // re-shaped result of applying bitcast to the original_shape, by using
  // dim_map to re-shape layout dimensions of original_shape. Returns the
  // result_shape with modified layout if the conversion succeeds; returns
  // std::nullopt if it fails.
  static std::optional<Shape> ReshapeLayoutDimensions(
      const Shape& original_shape, const Shape& result_shape,
      const std::vector<std::vector<int64_t>>& original_map,
      const std::vector<std::vector<int64_t>>& result_map);

  // Allow backend constraints on tiling etc. to invalidate optimizations.
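  // A backend may override this hook; e.g. (a hypothetical sketch, not an
  // existing backend):
  //
  //   class MyBackendVisitor : public AlgebraicSimplifierVisitor {
  //     bool IsValidLayout(const Shape& shape) override {
  //       // Reject shapes without a layout assigned.
  //       return LayoutUtil::HasLayout(shape);
  //     }
  //   };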
  virtual bool IsValidLayout(const Shape& shape) { return true; }

 protected:
  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

 private:
  // Removes degenerate dimension from dot.
  StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

  // Converts to primitive type if the input hlo is not that type, otherwise
  // returns the original hlo.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    Shape changed_shape =
        ShapeUtil::ChangeElementType(hlo->shape(), element_type);
    simplifier_->UpdateLayout(&changed_shape);
    return computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most major,
  // and the contracting dimensions are most minor.
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64_t> batch_dimensions,
      absl::Span<const int64_t> contracting_dimensions);

  // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)) (or
  // transpose(dot(a,b)) if only the batch dims are transposed).
  //
  // Requires the dot has been canonicalized by DotDecomposer into
  //
  //   LHS [batch dims..., non-contracting dim, contracting dim]
  //   RHS [batch dims..., contracting dim, non-contracting dim].
  StatusOr<bool> RemoveTransposesFromDotOperands(HloInstruction* dot);

  // Helper method to perform an add reduction over a list of dimensions.
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64_t> dims,
                            PrimitiveType type);

  // Move scalar multiply to the smallest side of convolution to
  // reduce multiply computations.
  Status ScalarMultiplyReduction(HloInstruction* dot);

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of the
  // operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Change copy(bitcast...(copy)) into copy(bitcast) or bitcast(copy) so that
  // the replicated copies are combined when allowed by layout/tiling assignment
  // constraints.
  bool SwapCopyBitcastCopy(HloInstruction* root_copy);

  // Replace old instruction with new instruction if old and new instructions
  // are compatible (have the same shape and replacement preserves sharding).
  // Updates uses and root instruction. Returns whether a replacement was made.
  bool ReplaceInstructionIfCompatible(HloInstruction* old_instruction,
                                      HloInstruction* new_instruction);
  // Similar to above but tuplizes `new_instructions` if there is more than one
  // instruction.
  bool ReplaceInstructionIfCompatible(
      HloInstruction* old_instruction,
      absl::Span<HloInstruction* const> new_instructions);

  // Returns whether the output shapes of the given instructions are the same
  // for the purposes of simplification. If options_.is_layout_sensitive() is
  // true, then this tests shape equality including layout (ShapeUtil::Equal).
  // If options_.is_layout_sensitive() is false, then this tests shape
  // compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Same as above but takes shape arguments directly.
  bool SameShape(const Shape& lhs, const Shape& rhs) const;

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim,
      HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
      HloInstruction* dot);

  HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
    // Return the cached computation for this type if one was already built.
    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
    if (scalar_add_computation) {
      return scalar_add_computation;
    }

    // Otherwise build a small add(lhs, rhs) computation over scalars of
    // `type` and cache it in the map.
    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(type, {});
    simplifier_->UpdateLayout(&shape);
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  virtual StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to swap convolution operands if they would result in a more efficient
  // convolution.
  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X)).
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Tries to convert slice(reverse(X)) into reverse(slice(X)).
  StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice);

  // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
  // `(< a N)`. This is crucial for being able to figure out the loop trip
  // count.
  //
  // Assumes that the input is a conjunction.
  StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);

  // Tries to simplify (bitcast-convert (concat (bitcast-convert A) ...)) where
  // the types of the inner and outer bitcast-convert cancel out.
  StatusOr<bool> TrySimplifyTautologicalBitcastConvert(HloInstruction* bitcast);

  // Useful when we want to use the same visitor over multiple computations.
  void ResetState(HloComputation* computation);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // Cached computation for adding two scalars of a given type.
  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;

  AlgebraicSimplifier* simplifier_ = nullptr;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_