xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/while_loop_concat_code_motion.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONCAT_CODE_MOTION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONCAT_CODE_MOTION_H_
18 
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
21 #include "tensorflow/compiler/xla/statusor.h"
22 
23 namespace xla {
24 
25 // A pass that tries to lift concatenation out of a while loop, and replace
26 // piece-wise subcomputations in the loop body with one on the concatenated
27 // shape.
28 //
29 // For example:
30 //
31 // loop = while (a, b, c, d) {
32 //   e = concat(a, b)
33 //   f = some-op(e) <with the same shape as e>
34 //   s0 = slice(f) first half
35 //   s1 = slice(f) second half
36 //   a_1 = add(a, s0)
37 //   b_1 = add(b, s1)
38 //   a_new = add(a_1, c)
39 //   b_new = add(b_1, d)
40 //   c_new = add(a_new, c)
41 //   d_new = add(b_new, d)
42 //   ROOT tuple(a_new, b_new, c_new, d_new)
43 // }
44 //
45 // will be transformed to
46 //
47 // ab = concat(a, b)
48 // cd = concat(c, d)
49 // while (ab, cd) {
50 //   f = some-op(ab)
51 //   ab_1 = add(ab, f)
52 //   ab_new = add(ab_1, cd)
53 //   cd_new = add(ab_new, cd)
54 //   ROOT tuple(ab_new, cd_new)
55 // }
56 // a_new = slice(ab_new) first half
57 // b_new = slice(ab_new) second half
58 // c_new = slice(cd_new) first half
59 // d_new = slice(cd_new) second half
60 class WhileLoopConcatCodeMotion : public HloModulePass {
61  public:
WhileLoopConcatCodeMotion(int64_t min_operand_count_to_optimize)62   explicit WhileLoopConcatCodeMotion(int64_t min_operand_count_to_optimize)
63       : min_operand_count_to_optimize_(min_operand_count_to_optimize) {}
64   ~WhileLoopConcatCodeMotion() override = default;
65 
name()66   absl::string_view name() const override {
67     static constexpr absl::string_view kName = "while-loop-concat-code-motion";
68     return kName;
69   }
70   using HloPassInterface::Run;
71   StatusOr<bool> Run(
72       HloModule* module,
73       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
74 
75  private:
76   const int64_t min_operand_count_to_optimize_;
77 };
78 }  // namespace xla
79 
80 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONCAT_CODE_MOTION_H_
81