/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_

#include <cstdint>
#include <functional>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

class AlgebraicSimplifierOptions {
 public:
  // Platform-dependent callback to determine if a reshape `from_shape` to
  // `to_shape` is a bitcast.
  using ReshapeIsBitcastCallback =
      std::function<bool(const Shape& from_shape, const Shape& to_shape)>;
  // Platform-dependent callback to determine if a set of reverse dimensions
  // is lowerable.
  using ConvIsLowerableCallback = std::function<bool(HloInstruction* window)>;

  explicit AlgebraicSimplifierOptions(
      ReshapeIsBitcastCallback reshape_is_bitcast_callback = {},
      ConvIsLowerableCallback conv_is_lowerable_callback = {})
      : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)),
        conv_is_lowerable_callback_(std::move(conv_is_lowerable_callback)) {}

  // Uses the platform-specific callback if set. It is not sensible to return
  // true here if the options are not layout sensitive.
  bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const {
    if (!is_layout_sensitive_) {
      return false;
    }
    if (!reshape_is_bitcast_callback_) {
      return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape);
    }
    return reshape_is_bitcast_callback_(from_shape, to_shape);
  }

  // Uses the platform-specific callback if set. Otherwise, returns true.
  bool ConvIsLowerable(HloInstruction* reverse_dims) const {
    if (!conv_is_lowerable_callback_) {
      return true;
    }
    return conv_is_lowerable_callback_(reverse_dims);
  }

  // If is_layout_sensitive is true, then the simplifier preserves layouts
  // during transformation. Otherwise, layout is ignored.
  void set_is_layout_sensitive(bool is_layout_sensitive) {
    is_layout_sensitive_ = is_layout_sensitive;
  }

  bool is_layout_sensitive() const { return is_layout_sensitive_; }
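  // Example (a minimal sketch, not part of the original header): a backend
  // that wants layout-sensitive simplification with its own bitcast test
  // could construct the options as
  //
  //   AlgebraicSimplifierOptions options(
  //       [](const Shape& from, const Shape& to) {
  //         return ShapeUtil::ReshapeIsBitcast(from, to);
  //       });
  //   options.set_is_layout_sensitive(true);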
  // Enable dot strength reduction on platforms where it is profitable.
  void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) {
    enable_dot_strength_reduction_ = enable_dot_strength_reduction;
  }

  bool enable_dot_strength_reduction() const {
    return enable_dot_strength_reduction_;
  }

  // Enable the dot->multiply rewrite for a dot that is an outer product.
  void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) {
    enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite;
  }

  bool enable_dot_to_multiply_rewrite() const {
    return enable_dot_to_multiply_rewrite_;
  }

  // Enable convolution simplification on platforms where it is profitable.
  void set_enable_conv_simplification(bool enable_conv_simplification) {
    enable_conv_simplification_ = enable_conv_simplification;
  }
  bool enable_conv_simplification() const {
    return enable_conv_simplification_;
  }

  // Enable convolution operand swapping on platforms where it is supported.
  void set_enable_conv_operand_swap(bool enable_conv_operand_swap) {
    enable_conv_operand_swap_ = enable_conv_operand_swap;
  }
  bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; }

  // Move a constant scalar multiply to the convolution operand or output with
  // the smallest tensor size, to reduce the number of scalar multiplies.
  void set_enable_scalar_multiply_reduction(
      bool enable_scalar_multiply_reduction) {
    enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction;
  }

  bool enable_scalar_multiply_reduction() const {
    return enable_scalar_multiply_reduction_;
  }

  // Allow the algebraic simplifier to treat floating point values like real
  // numbers.
  void set_enable_floats_are_real(bool enable_floats_are_real) {
    enable_floats_are_real_ = enable_floats_are_real;
  }

  bool enable_floats_are_real() const { return enable_floats_are_real_; }

  // If enable_window_reduce_to_reduce_replacement is true, then a
  // kReduceWindow instruction can be optimized by replacement with simpler
  // operations.
  void set_enable_window_reduce_to_reduce_replacement(
      bool enable_window_reduce_to_reduce_replacement) {
    enable_window_reduce_to_reduce_replacement_ =
        enable_window_reduce_to_reduce_replacement;
  }

  bool enable_window_reduce_to_reduce_replacement() const {
    return enable_window_reduce_to_reduce_replacement_;
  }

  // Sets the size of a gather operand that can be unrolled into many selects.
  void set_very_small_gather_size(int64_t size) {
    very_small_gather_size_ = size;
  }

  int64_t very_small_gather_size() const { return very_small_gather_size_; }

  void set_cudnn_batchnorm_forward_training_metadata(const std::string& c) {
    metadata_.cudnn_batchnorm_forward_training_metadata = c;
  }

  const std::string& get_cudnn_batchnorm_forward_training_metadata() const {
    return metadata_.cudnn_batchnorm_forward_training_metadata;
  }

  void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) {
    enable_reduce_of_reshape_ = enable_reduce_of_reshape;
  }

  bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; }

  void set_enable_negative_padding_replacement(
      bool enable_negative_padding_replacement) {
    enable_negative_padding_replacement_ = enable_negative_padding_replacement;
  }

  bool enable_negative_padding_replacement() const {
    return enable_negative_padding_replacement_;
  }

  void set_enable_sink_broadcast(bool enable_sink_broadcast) {
    enable_sink_broadcast_ = enable_sink_broadcast;
  }

  bool enable_sink_broadcast() const { return enable_sink_broadcast_; }

  // If true, min(x, NaN) = NaN. If false, min(x, NaN) = x.
  //
  // TODO(b/209827141): Remove this and make minmax_propagate_nan
  // unconditionally true.
  bool minmax_propagate_nan() const { return minmax_propagate_nan_; }
  void set_minmax_propagate_nan(bool val) { minmax_propagate_nan_ = val; }

 private:
  // The Metadata struct can be used to store any metadata information
  // encapsulated with the AlgebraicSimplifierOptions that can later be used
  // in an AlgebraicSimplifier pass. For example,
  // cudnn_batchnorm_forward_training_metadata can be used to store the name
  // of a custom call. If the custom call is
  // __cudnn$batchNormalizationForwardTraining, the output with index 2 is
  // guaranteed to be positive. This property has been used to recursively
  // determine if the operand of an instruction is always positive.
  struct Metadata {
    std::string cudnn_batchnorm_forward_training_metadata{""};
    Metadata() {}
  };
  ReshapeIsBitcastCallback reshape_is_bitcast_callback_;
  ConvIsLowerableCallback conv_is_lowerable_callback_;
  bool is_layout_sensitive_{false};
  bool enable_dot_strength_reduction_{true};
  bool enable_dot_to_multiply_rewrite_{true};
  bool enable_conv_simplification_{true};
  bool enable_conv_operand_swap_{true};
  bool enable_scalar_multiply_reduction_{false};
  bool enable_floats_are_real_{false};
  bool enable_window_reduce_to_reduce_replacement_{true};
  bool enable_reduce_of_reshape_{true};
  bool enable_negative_padding_replacement_{true};
  bool enable_sink_broadcast_{true};
  int64_t very_small_gather_size_{4};
  bool minmax_propagate_nan_{true};
  Metadata metadata_;
};
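// Example usage (a minimal sketch, not part of the original header; assumes a
// valid HloModule* `module` and the usual XLA status macros):
//
//   AlgebraicSimplifierOptions options;
//   options.set_enable_dot_strength_reduction(false);
//   AlgebraicSimplifier simplifier(options);
//   TF_ASSIGN_OR_RETURN(bool changed,
//                       simplifier.Run(module, /*execution_threads=*/{}));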
// A pass which performs algebraic simplifications.
class AlgebraicSimplifier : public HloModulePass {
 public:
  // If is_layout_sensitive is true, then the simplifier preserves layouts
  // during transformation. Otherwise, layout is ignored.
  explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options)
      : options_(options) {}
  ~AlgebraicSimplifier() override = default;
  absl::string_view name() const override { return "algsimp"; }

  // Runs algebraic simplification on the given module. Returns whether the
  // module was changed.
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override;

  // Creates a constant from the literal, with tiles and element size updated
  // in the constant's layout.
  std::unique_ptr<HloInstruction> CreateConstantWithLayoutUpdated(
      Literal literal) {
    auto constant = HloInstruction::CreateConstant(std::move(literal));
    UpdateLayout(constant->mutable_shape());
    return constant;
  }

 protected:
  AlgebraicSimplifierOptions options_;
};

// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
// more general case, a worklist-based approach would be needed.
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
 public:
  explicit AlgebraicSimplifierVisitor(
      const AlgebraicSimplifierOptions& options,
      AlgebraicSimplifier* simplifier)
      : options_(options), simplifier_(simplifier) {}

  Status HandleAbs(HloInstruction* abs) override;

  Status HandleAdd(HloInstruction* add) override;

  Status HandleAnd(HloInstruction* logical_and) override;

  Status HandleBitcast(HloInstruction* bitcast) override;

  Status HandleBitcastConvert(HloInstruction* bitcast) override;

  Status HandleBroadcast(HloInstruction* broadcast) override;

  Status HandleCompare(HloInstruction* compare) override;

  Status HandleConcatenate(HloInstruction* concatenate) override;

  Status HandleConstant(HloInstruction* constant) override;

  Status HandleCopy(HloInstruction* copy) override;

  Status HandleConvert(HloInstruction* convert) override;

  Status HandleComplex(HloInstruction* complex) override;

  Status HandleReal(HloInstruction* real) override;

  Status HandleImag(HloInstruction* imag) override;

  Status HandleIota(HloInstruction* instruction) override;

  Status HandleConvolution(HloInstruction* convolution) override;

  Status HandleDivide(HloInstruction* divide) override;

  Status HandleDot(HloInstruction* dot) override;

  Status HandleGather(HloInstruction* gather) override;

  Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;

  Status HandleLog(HloInstruction* log) override;

  Status HandleMaximum(HloInstruction* maximum) override;

  Status HandleMinimum(HloInstruction* minimum) override;

  Status HandleClamp(HloInstruction* clamp) override;

  Status HandleMultiply(HloInstruction* multiply) override;

  Status HandleNegate(HloInstruction* negate) override;

  Status HandleNot(HloInstruction* logical_not) override;

  Status HandleOptimizationBarrier(HloInstruction* barrier) override;

  Status HandleOr(HloInstruction* logical_or) override;

  Status HandlePad(HloInstruction* pad) override;

  Status HandlePower(HloInstruction* power) override;

  Status HandleRemainder(HloInstruction* remainder) override;

  Status HandleReshape(HloInstruction* reshape) override;

  Status HandleReduce(HloInstruction* hlo) override;

  Status HandleReduceWindow(HloInstruction* hlo) override;

  Status HandleReverse(HloInstruction* reverse) override;

  Status HandleRsqrt(HloInstruction* rsqrt) override;

  Status HandleSlice(HloInstruction* slice) override;

  Status HandleSqrt(HloInstruction* sqrt) override;

  Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;

  Status HandleDynamicUpdateSlice(
      HloInstruction* dynamic_update_slice) override;
  Status HandleScatter(HloInstruction* hlo) override;

  Status HandleSelect(HloInstruction* select) override;

  Status HandleSort(HloInstruction* sort) override;

  Status HandleTranspose(HloInstruction* transpose) override;

  Status HandleSubtract(HloInstruction* sub) override;

  Status HandleMap(HloInstruction* map) override;

  // Runs the visitor on a computation.
  bool Run(HloComputation* computation,
           const AlgebraicSimplifierOptions& options,
           AlgebraicSimplifier* simplifier);

  // Computes a function that maps from bitcasted dimensions to the resulting
  // ones. Returns the function as a vector if successful; std::nullopt
  // otherwise.
  static std::optional<std::vector<std::vector<int64_t>>> ComputeBitcastDimMap(
      const Shape& bitcast_shape, const Shape& operand_shape);
  // Inverts the directions of the given bitcast dimension map.
  static std::vector<std::vector<int64_t>> InvertBitcastDimMap(
      const Shape& original_shape, const Shape& bitcast_shape,
      const std::vector<std::vector<int64_t>>& original_map);

  // Modifies the layout dimensions of result_shape so that it becomes the
  // reshaped result of applying bitcast to the original_shape, by using
  // dim_map to reshape the layout dimensions of original_shape. Returns the
  // result_shape with modified layout if the conversion succeeds; returns
  // std::nullopt if it fails.
  static std::optional<Shape> ReshapeLayoutDimensions(
      const Shape& original_shape, const Shape& result_shape,
      const std::vector<std::vector<int64_t>>& original_map,
      const std::vector<std::vector<int64_t>>& result_map);

  // Allow backend constraints on tiling etc. to invalidate optimizations.
  virtual bool IsValidLayout(const Shape& shape) { return true; }

 protected:
  // The backend-specific options selected for the algebraic simplifier.
  const AlgebraicSimplifierOptions& options_;

 private:
  // Removes degenerate dimensions from dot.
  StatusOr<bool> RemoveDegenerateDimensionFromDot(HloInstruction* dot);

  // Converts to the primitive type if the input hlo is not that type;
  // otherwise returns the original hlo.
  HloInstruction* AsType(HloInstruction* hlo,
                         const PrimitiveType element_type) {
    if (hlo->shape().element_type() == element_type) {
      return hlo;
    }
    Shape changed_shape =
        ShapeUtil::ChangeElementType(hlo->shape(), element_type);
    simplifier_->UpdateLayout(&changed_shape);
    return computation_->AddInstruction(
        HloInstruction::CreateConvert(changed_shape, hlo));
  }

  // Transposes a dot operand such that the batch dimensions are the most
  // major, and the contracting dimensions are the most minor.
  StatusOr<HloInstruction*> NormalizeDotOperandToBatchMajorAndContractingMinor(
      HloInstruction* dot_operand, absl::Span<const int64_t> batch_dimensions,
      absl::Span<const int64_t> contracting_dimensions);

  // Simplifies dot(transpose(a), transpose(b)) to transpose(dot(b, a)) (or
  // transpose(dot(a, b)) if only the batch dims are transposed).
  //
  // Requires the dot to have been canonicalized by DotDecomposer into
  //
  //   LHS [batch dims..., non-contracting dim, contracting dim]
  //   RHS [batch dims..., contracting dim, non-contracting dim].
  StatusOr<bool> RemoveTransposesFromDotOperands(HloInstruction* dot);

  // Helper method to perform an add reduction over a list of dimensions.
  HloInstruction* AddReduce(HloInstruction* hlo, absl::Span<const int64_t> dims,
                            PrimitiveType type);

  // Moves a scalar multiply to the smallest side of a convolution to
  // reduce multiply computations.
  Status ScalarMultiplyReduction(HloInstruction* dot);

  // Convenience method for replacing an instruction with a bitcast. If operand
  // is not null, then the bitcast will use the specified operand instead of
  // the operand of the instruction.
  void ReplaceWithBitcast(HloInstruction* instruction,
                          HloInstruction* operand = nullptr);

  // Changes copy(bitcast...(copy)) into copy(bitcast) or bitcast(copy) so that
  // the replicated copies are combined when allowed by layout/tiling
  // assignment constraints.
  bool SwapCopyBitcastCopy(HloInstruction* root_copy);

  // Replaces the old instruction with the new one if they are compatible, i.e.
  // they have the same shape and the replacement preserves sharding. Updates
  // uses and the root instruction. Returns whether a replacement was made.
  bool ReplaceInstructionIfCompatible(HloInstruction* old_instruction,
                                      HloInstruction* new_instruction);
  // Similar to above, but tuplizes `new_instructions` if there is more than
  // one instruction.
  bool ReplaceInstructionIfCompatible(
      HloInstruction* old_instruction,
      absl::Span<HloInstruction* const> new_instructions);

  // Returns whether the shapes of the outputs of the given instructions are
  // the same for the purposes of simplification. If
  // options_.is_layout_sensitive() is true, then this tests shape equality
  // including layout (ShapeUtil::Equal). Otherwise, it tests shape
  // compatibility (ShapeUtil::Compatible).
  bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;

  // Same as above, but takes shape arguments directly.
  bool SameShape(const Shape& lhs, const Shape& rhs) const;

  // A Broadcast that feeds an element-wise operation with a unique non-scalar
  // operand can sink to after the operation.
  StatusOr<bool> TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(
      HloInstruction* broadcast);

  StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot);
  StatusOr<HloInstruction*> OptimizeDotOfConcatHelper(
      HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim,
      HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped);

  StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot);

  StatusOr<HloInstruction*> OptimizeDotOfReorderContractingDims(
      HloInstruction* dot);

  HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
    if (scalar_add_computation) {
      return scalar_add_computation;
    }

    HloComputation::Builder b("scalar_add_computation");
    Shape shape = ShapeUtil::MakeShape(type, {});
    simplifier_->UpdateLayout(&shape);
    auto scalar_lhs = b.AddInstruction(
        HloInstruction::CreateParameter(0, shape, "scalar_lhs"));
    auto scalar_rhs = b.AddInstruction(
        HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
    auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
    scalar_add_computation =
        computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    return scalar_add_computation;
  }

  // Tries to fold a kPad in the input or filter into the convolution
  // instruction's window.
  virtual StatusOr<bool> FoldConvInputPad(HloInstruction* convolution);
  StatusOr<bool> FoldConvFilterPad(HloInstruction* convolution);

  // Tries to swap convolution operands if they would result in a more
  // efficient convolution.
  StatusOr<bool> SwapConvOperands(HloInstruction* convolution);

  // Tries to use a kDot in place of the given convolution.
  StatusOr<bool> SimplifyConvToDot(HloInstruction* convolution);

  // Tries to simplify a slice where the result of the slice is a scalar.
  StatusOr<bool> TrySimplifyScalarSlice(HloInstruction* slice);

  // Tries to convert slice(reshape(X)) into reshape(slice(X)).
  StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);

  // Tries to convert slice(reverse(X)) into reverse(slice(X)).
  StatusOr<bool> TryToReorderSliceAndReverse(HloInstruction* slice);

  // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
  // `(< a N)`. This is crucial for being able to figure out the loop trip
  // count.
  //
  // Assumes that the input is a conjunction.
  StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);

  // Tries to simplify (bitcast-convert (concat (bitcast-convert A) ...)) where
  // the types of the inner and outer bitcast-convert cancel out.
  StatusOr<bool> TrySimplifyTautologicalBitcastConvert(HloInstruction* bitcast);

  // Useful when we want to use the same visitor over multiple computations.
  void ResetState(HloComputation* computation);

  // Current HloComputation instance the AlgebraicSimplifierVisitor is
  // traversing.
  HloComputation* computation_;

  // Cached computation for adding two scalars of a given type.
  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;

  AlgebraicSimplifier* simplifier_ = nullptr;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_