1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONCAT_CODE_MOTION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONCAT_CODE_MOTION_H_ 18 19 #include "tensorflow/compiler/xla/service/hlo_module.h" 20 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" 21 #include "tensorflow/compiler/xla/statusor.h" 22 23 namespace xla { 24 25 // A pass that tries to lift concatenation out of a while loop, and replace 26 // piece-wise subcomputations in the loop body with one on the concatenated 27 // shape. 28 // 29 // For example: 30 // 31 // loop = while (a, b, c, d) { 32 // e = concat(a, b) 33 // f = some-op(e) <with the same shape as e> 34 // s0 = slice(f) first half 35 // s1 = slice(f) second half 36 // a_1 = add(a, s0) 37 // b_1 = add(b, s1) 38 // a_new = add(a_1, c) 39 // b_new = add(b_1, d) 40 // c_new = add(a_new, c) 41 // d_new = add(b_new, d) 42 // ROOT tuple(a_new, b_new, c_new, d_new) 43 // } 44 // 45 // will be transformed to 46 // 47 // ab = concat(a, b) 48 // cd = concat(c, d) 49 // while (ab, cd) { 50 // f = some-op(ab) 51 // ab_1 = add(ab, f) 52 // ab_new = add(ab_1, cd) 53 // cd_new = add(ab_new, cd) 54 // ROOT tuple(ab_new, cd_new) 55 // } 56 // a_new = slice(ab_new) first half 57 // b_new = slice(ab_new) second half 58 // c_new = slice(cd_new) first half 59 // d_new = slice(cd_new) second half 60 class WhileLoopConcatCodeMotion : public HloModulePass { 61 public: WhileLoopConcatCodeMotion(int64_t min_operand_count_to_optimize)62 explicit WhileLoopConcatCodeMotion(int64_t min_operand_count_to_optimize) 63 : min_operand_count_to_optimize_(min_operand_count_to_optimize) {} 64 ~WhileLoopConcatCodeMotion() override = default; 65 name()66 absl::string_view name() const override { 67 static constexpr absl::string_view kName = "while-loop-concat-code-motion"; 68 return kName; 69 } 70 using HloPassInterface::Run; 71 StatusOr<bool> Run( 72 HloModule* module, 73 const absl::flat_hash_set<absl::string_view>& execution_threads) override; 74 75 private: 76 const int64_t min_operand_count_to_optimize_; 77 }; 78 } // namespace xla 79 80 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONCAT_CODE_MOTION_H_ 81