xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
18 
#include <cstdint>
#include <functional>
#include <string>

#include "absl/strings/string_view.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 
28 namespace xla {
29 // A thin wrapper around llvm_loop.h to make code generating structured control
30 // flow more readable.
31 class KernelSupportLibrary {
32  public:
33   // `b` is the llvm::IRBuilder instance used to generate LLVM IR.
34   // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop
35   // generated by this instance of KernelSupportLibrary.
36   explicit KernelSupportLibrary(
37       llvm::IRBuilder<>* b,
38       llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll,
39       bool prevent_vectorization = true)
b_(b)40       : b_(b),
41         unroll_mode_(unroll_mode),
42         prevent_vectorization_(prevent_vectorization) {}
43 
44   // Generates the following control flow structure:
45   //
46   //   if (`start` < `end`) {
47   //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
48   //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
49   //       `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`;
50   //   }
51   Status ForWithStatus(
52       absl::string_view name, llvm::Value* start, llvm::Value* end,
53       llvm::Value* step,
54       const std::function<Status(llvm::Value* ind_var,
55                                  bool is_first_iteration)>& for_body_generator);
56 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)57   void For(
58       absl::string_view name, llvm::Value* start, llvm::Value* end,
59       llvm::Value* step,
60       const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
61           for_body_generator) {
62     CHECK_EQ(OkStatus(),
63              ForWithStatus(
64                  name, start, end, step,
65                  [&](llvm::Value* ind_var, bool is_first_iteration) -> Status {
66                    for_body_generator(ind_var, is_first_iteration);
67                    return OkStatus();
68                  }));
69   }
70 
ForWithStatus(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<Status (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)71   Status ForWithStatus(
72       absl::string_view name, int64_t start, int64_t end, int64_t step,
73       const std::function<Status(
74           llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) {
75     return ForWithStatus(name, /*start=*/b_->getInt64(start),
76                          /*end=*/b_->getInt64(end),
77                          /*step=*/b_->getInt64(step), for_body_generator);
78   }
79 
For(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)80   void For(
81       absl::string_view name, int64_t start, int64_t end, int64_t step,
82       const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
83           for_body_generator) {
84     For(name, /*start=*/b_->getInt64(start),
85         /*end=*/b_->getInt64(end),
86         /*step=*/b_->getInt64(step), for_body_generator);
87   }
88 
89   // Generates the following control flow structure if `peel_first_iteration` is
90   // true:
91   //
92   //   if (`start` < `end`) {
93   //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`;
94   //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
95   //       `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`;
96   //   }
97   //
98   // and the following if `peel_first_iteration` is false:
99   //
100   //   for (i64 i = `start`; i s< `end`; i += `step`)
101   //     `for_body_generator(/*ind_var=*/,i,
102   //                         /*is_first_iteration=*/,(i != `start`))`;
103   Status ForWithStatus(
104       absl::string_view name, llvm::Value* start, llvm::Value* end,
105       llvm::Value* step, bool peel_first_iteration,
106       const std::function<Status(llvm::Value* ind_var,
107                                  llvm::Value* is_first_iteration)>&
108           for_body_generator);
109 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)110   void For(absl::string_view name, llvm::Value* start, llvm::Value* end,
111            llvm::Value* step, bool peel_first_iteration,
112            const std::function<void(llvm::Value* ind_var,
113                                     llvm::Value* is_first_iteration)>&
114                for_body_generator) {
115     TF_CHECK_OK(ForWithStatus(
116         name, start, end, step, peel_first_iteration,
117         [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status {
118           for_body_generator(ind_var, is_first_iteration);
119           return OkStatus();
120         }));
121   }
122 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,bool peel_first_iteration,const std::function<Status (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)123   Status ForWithStatus(
124       absl::string_view name, llvm::Value* start, llvm::Value* end,
125       int64_t step, bool peel_first_iteration,
126       const std::function<Status(llvm::Value* ind_var,
127                                  llvm::Value* is_first_iteration)>&
128           for_body_generator) {
129     return ForWithStatus(
130         name, /*start=*/start, /*end=*/end,
131         /*step=*/llvm::ConstantInt::get(start->getType(), step),
132         peel_first_iteration, for_body_generator);
133   }
134 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)135   void For(absl::string_view name, llvm::Value* start, llvm::Value* end,
136            int64_t step, bool peel_first_iteration,
137            const std::function<void(llvm::Value* ind_var,
138                                     llvm::Value* is_first_iteration)>&
139                for_body_generator) {
140     For(name, /*start=*/start, /*end=*/end,
141         /*step=*/llvm::ConstantInt::get(start->getType(), step),
142         peel_first_iteration, for_body_generator);
143   }
144 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)145   Status ForWithStatus(
146       absl::string_view name, llvm::Value* start, llvm::Value* end,
147       llvm::Value* step,
148       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
149     return ForWithStatus(name, start, end, step,
150                          /*peel_first_iteration=*/false,
151                          [&](llvm::Value* indvar, llvm::Value*) -> Status {
152                            return for_body_generator(indvar);
153                          });
154   }
155 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)156   void For(
157       absl::string_view name, llvm::Value* start, llvm::Value* end,
158       llvm::Value* step,
159       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
160     For(name, start, end, step,
161         /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) {
162           return for_body_generator(indvar);
163         });
164   }
165 
ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)166   Status ForWithStatus(
167       absl::string_view name, llvm::Value* start, llvm::Value* end,
168       int64_t step,
169       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
170     return ForWithStatus(name, start, end,
171                          llvm::ConstantInt::get(start->getType(), step),
172                          /*peel_first_iteration=*/false,
173                          [&](llvm::Value* indvar, llvm::Value*) -> Status {
174                            return for_body_generator(indvar);
175                          });
176   }
177 
For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)178   void For(
179       absl::string_view name, llvm::Value* start, llvm::Value* end,
180       int64_t step,
181       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
182     For(name, start, end, llvm::ConstantInt::get(start->getType(), step),
183         for_body_generator);
184   }
185 
ForWithStatus(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)186   Status ForWithStatus(
187       absl::string_view name, int64_t start, int64_t end, int64_t step,
188       const std::function<Status(llvm::Value* ind_var)>& for_body_generator) {
189     return ForWithStatus(name, /*start=*/b_->getInt64(start),
190                          /*end=*/b_->getInt64(end),
191                          /*step=*/b_->getInt64(step), for_body_generator);
192   }
193 
For(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)194   void For(
195       absl::string_view name, int64_t start, int64_t end, int64_t step,
196       const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
197     For(name, /*start=*/b_->getInt64(start),
198         /*end=*/b_->getInt64(end),
199         /*step=*/b_->getInt64(step), for_body_generator);
200   }
201 
202   // Generates the following control flow structure:
203   //
204   //   if (`condition`)
205   //     `true_block_generator()`;
206   //   else
207   //      `false_block_generator()`;
208   // The else is skipped if false_block_generator is null.
209   Status IfWithStatus(
210       absl::string_view name, llvm::Value* condition,
211       const std::function<Status()>& true_block_generator,
212       const std::function<Status()>& false_block_generator = nullptr);
213 
214   Status IfWithStatus(
215       llvm::Value* condition,
216       const std::function<Status()>& true_block_generator,
217       const std::function<Status()>& false_block_generator = []() -> Status {
218         return OkStatus();
219       }) {
220     return IfWithStatus("", condition, true_block_generator,
221                         false_block_generator);
222   }
223 
224   void If(llvm::Value* condition,
225           const std::function<void()>& true_block_generator,
226           const std::function<void()>& false_block_generator = nullptr) {
227     If("", condition, true_block_generator, false_block_generator);
228   }
229 
230   void If(absl::string_view name, llvm::Value* condition,
231           const std::function<void()>& true_block_generator,
232           const std::function<void()>& false_block_generator = nullptr) {
233     if (false_block_generator != nullptr) {
234       TF_CHECK_OK(IfWithStatus(
235           name, condition,
236           [&]() {
237             true_block_generator();
238             return OkStatus();
239           },
240           [&]() {
241             false_block_generator();
242             return OkStatus();
243           }));
244     } else {
245       TF_CHECK_OK(IfWithStatus(name, condition, [&]() {
246         true_block_generator();
247         return OkStatus();
248       }));
249     }
250   }
251 
252   using ArgumentVector = absl::Span<llvm::Value* const>;
253 
254   // Generates the following control flow structure:
255   //
256   //  define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
257   //    kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
258   //  }
259   //
260   //  ...
261   //  call @`kernel_name`(arguments[0], arguments[1] ...)
262   //  ...
263   //
264   // If a function called `kernel_name` is already present in the module then
265   // that function is re-used.  In that sense we're using the llvm::Module as a
266   // cache of outlined kernels, keyed by function name.
267   //
268   // If any of the values in `arguments` is nullptr (i.e. a nullptr
269   // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass
270   // in a nullptr llvm::Value* in its position to `kernel_body_generator`.
271   // Currently we only support at most one nullptr value in `arguments`.
272   static void EmitAndCallOutlinedKernel(
273       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
274       absl::string_view kernel_name, ArgumentVector arguments,
275       const std::function<void(ArgumentVector)>& kernel_body_generator);
276 
277   // Thin wrappers around the more general EmitAndCallOutlinedKernel above.
EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)278   static void EmitAndCallOutlinedKernel(
279       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
280       absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
281       llvm::Value* arg2,
282       const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
283           kernel_body_generator) {
284     EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2},
285                               [&](ArgumentVector args) {
286                                 kernel_body_generator(args[0], args[1],
287                                                       args[2]);
288                               });
289   }
290 
EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,llvm::Value * arg3,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)291   static void EmitAndCallOutlinedKernel(
292       const HloModuleConfig& module_config, llvm::IRBuilder<>* b,
293       absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1,
294       llvm::Value* arg2, llvm::Value* arg3,
295       const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
296                                llvm::Value*)>& kernel_body_generator) {
297     EmitAndCallOutlinedKernel(
298         module_config, b, kernel_name, {arg0, arg1, arg2, arg3},
299         [&](ArgumentVector args) {
300           kernel_body_generator(args[0], args[1], args[2], args[3]);
301         });
302   }
303 
304  private:
305   llvm::IRBuilder<>* b_;
306   llvm_ir::UnrollMode unroll_mode_;
307   bool prevent_vectorization_;
308 };
309 }  // namespace xla
310 
311 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_
312