xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/conditional_code_motion.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_
18 
19 #include <string>
20 
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 #include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 
26 namespace xla {
27 
28 namespace conditional_opt {
29 // At the conceptual level, a boundary can be thought of as representing a
30 // single virtual operation, except this virtual operation is conditionally
31 // instantiated into different concrete operations at each conditional branch.
32 // So a boundary is mapped to a single concrete operation if it is outside of
33 // conditional branches, and is mapped to a list of instructions if inside the
34 // branches. This data structure therefore allows a common data structure
35 // representation of the instructions to be moved, whether  they are inside or
36 // outside of the branches. Subsequently, it allows a common implementation
37 // basis to be used for both moving instructions out of and for moving them
38 // inside branches.
39 class Boundary {
40  public:
41   enum class Position { kInsideBranch, kOutsideBranch, kUndefined };
Boundary()42   Boundary() : position_(Position::kUndefined) {}
Boundary(Position p)43   explicit Boundary(Position p) : position_(p) {}
mutable_operands()44   std::vector<HloInstruction*>& mutable_operands() { return operands_; }
operands()45   const std::vector<HloInstruction*>& operands() const { return operands_; }
IsInsideBranch()46   bool IsInsideBranch() const { return position_ == Position::kInsideBranch; }
IsOutsideBranch()47   bool IsOutsideBranch() const { return position_ == Position::kOutsideBranch; }
GetPosition()48   Position GetPosition() const { return position_; }
IsEmpty()49   bool IsEmpty() const { return operands_.empty(); }
ToString()50   std::string ToString() const {
51     std::string res;
52     for (HloInstruction* op : operands_) {
53       res += op->ToString() + ";";
54     }
55     return res;
56   }
57   bool operator==(const Boundary& that) const {
58     return absl::c_equal(operands_, that.operands_);
59   }
60   template <typename H>
AbslHashValue(H h,const Boundary & boundary)61   friend H AbslHashValue(H h, const Boundary& boundary) {
62     return H::combine(std::move(h), boundary.operands_);
63   }
64 
65  private:
66   // Boundary instructions in the conditional branches, one from each branch
67   // of the conditional; or a single operand from outside the conditional.
68   std::vector<HloInstruction*> operands_;
69   Position position_;
70 };
71 
72 // HLO pass that moves identical ops in/out of conditional.
73 // - The definition of identical are the shape of the operands are identical
74 // and their properties are identical.
75 // - Only the identical ops that won't share operands with other ops will
76 // be moved out of conditional.
77 // The cost model of the code motion optimization includes two components:
78 // represented by the move_config_ and reuse_config_ arrays of the optimization.
79 // The move_config_ array uses 1 vs 0 to dictate whether each Hlo Opcode, when
80 // used with its first operand being another given Hlo Opcode, is allowed to
81 // move across any conditional boundary; the reuse_config_ array uses an integer
82 // to represent the force between each pair of HloOpcode regarding how
83 // attractive it is to place these instructions together (both inside or outside
84 // of a conditional). Both arrays use Hlo Opcode only to drive the
85 // configuration, regardless of where the operations are located in the
86 // module.
87 class ConditionalCodeMotion : public HloModulePass {
88  public:
89   // If is_layout_sensitive is true, then the hoist process preserves layout
90   // during identical comparison. Otherwise, layout is ignored.
91   // The search configuration is a single integer but is split into four parts:
92   // (sign, n, m, p), where n,m,p each occupy 8 bits and together make the 24
93   // bits at the end of the int32_t. For the sign part, if search_config is <0,
94   // the reuse_config_ cost model is modified (tuned); if search_config is >0,
95   // the move_config_ cost model is modified (tuned); if search_config == 0,
96   // the default cost model is used with no tuning. When tuning, the entries in
97   // the designated configuration array (move_config_ or reuse_config_) are
98   // flipped between 0 and another default integer, starting from the pth entry
99   // being queried by the optimization and repeated every nth time a new entry
100   // is visited, until a maximal of m entries have been changed. The tuning
101   // start over when optimizing a new model.
102   explicit ConditionalCodeMotion(bool is_layout_sensitive,
103                                  bool pursue_full_conditional_code_motion,
104                                  int64_t search_config = 0)
is_layout_sensitive_(is_layout_sensitive)105       : is_layout_sensitive_(is_layout_sensitive),
106         pursue_full_conditional_code_motion_(
107             /*turn off special case if tuning*/
108             pursue_full_conditional_code_motion && search_config == 0),
109         search_config_index_(0) {
110     search_config_.push_back(search_config);
111     if (search_config != 0) {
112       search_config_map_[0] = search_config_;
113     }
114   }
ConditionalCodeMotion(bool is_layout_sensitive,bool pursue_full_conditional_code_motion,std::string search_config)115   explicit ConditionalCodeMotion(bool is_layout_sensitive,
116                                  bool pursue_full_conditional_code_motion,
117                                  std::string search_config)
118       : is_layout_sensitive_(is_layout_sensitive),
119         pursue_full_conditional_code_motion_(
120             /*turn off special case if tuning*/
121             pursue_full_conditional_code_motion && search_config.empty()),
122         search_config_index_(-1) {
123     ParseSearchConfiguration(search_config);
124   }
125   // Parse a given string in the format of a sequence of i,s,m,t into a
126   // list of transformation search configurations, each configuration generated
127   // by invoking MakeSearchConfig(s,m,t) and will be used for the ith
128   // conditional encountered when optimizing a given module.
129   void ParseSearchConfiguration(const std::string& search_config);
130   // Make a single search configuration for changing transformation decisions:
131   // flip the decisions at position n = flip_start + flip_stride * m, and
132   // m = 0..max_flip.
133   // The following defines how the int64_t search configuration is composed, as
134   // flip_start + (flip_max << kMaxPos) + (flip_stride << kStridePos).
135   // Position (digit) for maximum number of flips.
136   static constexpr int kMaxPos = 16;
137   // Position (digit) for the count-down to the first flip.
138   static constexpr int kStartPos = 0;
139   // Position (digit) for the count-down to the next flip.
140   static constexpr int kStridePos = 32;
141   // Bit mask for extracting the last digits of value.
142   static constexpr int kValueMask = 0xffff;
MakeSearchConfig(int64_t start,int64_t max,int64_t stride)143   static int64_t MakeSearchConfig(int64_t start, int64_t max, int64_t stride) {
144     const int64_t config =
145         (max << kMaxPos) + (start << kStartPos) + (stride << kStridePos);
146     VLOG(2) << "flip stride = " << flip_stride(config) << "\n";
147     VLOG(2) << "flig config = " << config << "\n";
148     return config;
149   }
150 
flip_start(int64_t search_config)151   static int16_t flip_start(int64_t search_config) {
152     return (search_config >> kStartPos) & kValueMask;
153   }
154 
flip_stride(int64_t search_config)155   static int16_t flip_stride(int64_t search_config) {
156     return (search_config >> kStridePos) & kValueMask;
157   }
158 
DecrementMaxFlip(int64_t * search_config)159   static int16_t DecrementMaxFlip(int64_t* search_config) {
160     const int16_t max_flip = ((*search_config) >> kMaxPos) & kValueMask;
161     // Decrement flip count so we can stop if it reaches 0.
162     if (max_flip > 0) {
163       *search_config -= (1 << kMaxPos);
164     }
165     return max_flip;
166   }
167 
name()168   absl::string_view name() const override { return "conditional-code-motion"; }
169   using HloPassInterface::Run;
170   StatusOr<bool> Run(
171       HloModule* module,
172       const absl::flat_hash_set<absl::string_view>& execution_threads) override;
173 
174   // Optimization decision for each boundary of the conditional instruction.
175   class Decision {
176    public:
177     enum class Direction : uint8_t {
178       kMoveOutOfBranch,
179       kMoveIntoBranch,
180       kNoChange
181     };
182 
183    public:
Decision(Direction direction,int benefit)184     Decision(Direction direction, int benefit)
185         : direction_(direction), benefit_(benefit) {}
GetDirection()186     Direction GetDirection() const { return direction_; }
GetBenefit()187     int GetBenefit() const { return benefit_; }
188 
189    private:
190     Direction direction_;
191     int benefit_;
192   };
193   // If the optimization decision is NO_CHANGE, new_boundary is set to nullptr;
194   // otherwise, it is set to the new boundary after proposed optimization.
195   virtual Decision ConsiderCodeMotion(
196       HloInstruction* conditional, const Boundary& cur_boundary,
197       std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries,
198       absl::flat_hash_map<HloInstruction*, int>& visited_count);
199 
200  private:
201   const bool is_layout_sensitive_;
202   const bool pursue_full_conditional_code_motion_;
203   // The following parameterizes the transformation decisions and cost model.
204   std::vector<int64_t> search_config_;
205   int64_t search_config_index_;
206   // Map each conditional to a vector of its search configurations. The key of
207   // the map is the index number of the conditional in a module when traversed
208   // in post order, and the value of the map is the sequence of search
209   // configurations specified with the same index number for the conditional.
210   absl::flat_hash_map<int64_t, std::vector<int64_t>> search_config_map_;
211   std::vector<std::vector<int64_t>> move_config_, reuse_config_;
212 
213   StatusOr<bool> MoveInstructionOut(HloInstruction* conditional,
214                                     std::vector<Boundary>& to_move_out,
215                                     std::vector<Boundary>& new_boundaries);
216   StatusOr<bool> MoveInstructionIn(HloInstruction* conditional,
217                                    std::vector<Boundary>& to_move_in,
218                                    std::vector<Boundary>& new_boundaries);
219   void SetDefaultMoveConfig();
220 };
221 }  // namespace conditional_opt
222 
223 }  // namespace xla
224 
225 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CODE_MOTION_H_
226