1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 18 19 #include <string> 20 21 #include "absl/strings/string_view.h" 22 #include "llvm/IR/BasicBlock.h" 23 #include "llvm/IR/IRBuilder.h" 24 #include "llvm/IR/Value.h" 25 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" 26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 27 28 namespace xla { 29 // A thin wrapper around llvm_loop.h to make code generating structured control 30 // flow more readable. 31 class KernelSupportLibrary { 32 public: 33 // `b` is the llvm::IRBuilder instance used to generate LLVM IR. 34 // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop 35 // generated by this instance of KernelSupportLibrary. 36 explicit KernelSupportLibrary( 37 llvm::IRBuilder<>* b, 38 llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, 39 bool prevent_vectorization = true) b_(b)40 : b_(b), 41 unroll_mode_(unroll_mode), 42 prevent_vectorization_(prevent_vectorization) {} 43 44 // Generates the following control flow structure: 45 // 46 // if (`start` < `end`) { 47 // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`; 48 // for (i64 i = `start` + `step`; i s< `end`; i += `step`) 49 // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; 50 // } 51 Status ForWithStatus( 52 absl::string_view name, llvm::Value* start, llvm::Value* end, 53 llvm::Value* step, 54 const std::function<Status(llvm::Value* ind_var, 55 bool is_first_iteration)>& for_body_generator); 56 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)57 void For( 58 absl::string_view name, llvm::Value* start, llvm::Value* end, 59 llvm::Value* step, 60 const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& 61 for_body_generator) { 62 CHECK_EQ(OkStatus(), 63 ForWithStatus( 64 name, start, end, step, 65 [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { 66 for_body_generator(ind_var, is_first_iteration); 67 return OkStatus(); 68 })); 69 } 70 ForWithStatus(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<Status (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)71 Status ForWithStatus( 72 absl::string_view name, int64_t start, int64_t end, int64_t step, 73 const std::function<Status( 74 llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { 75 return ForWithStatus(name, /*start=*/b_->getInt64(start), 76 /*end=*/b_->getInt64(end), 77 /*step=*/b_->getInt64(step), for_body_generator); 78 } 79 For(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<void (llvm::Value * ind_var,bool is_first_iteration)> & for_body_generator)80 void For( 81 absl::string_view name, int64_t start, int64_t end, int64_t step, 82 const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& 83 for_body_generator) { 84 For(name, /*start=*/b_->getInt64(start), 85 /*end=*/b_->getInt64(end), 86 /*step=*/b_->getInt64(step), for_body_generator); 87 } 88 89 // Generates the following control flow structure if `peel_first_iteration` is 90 // true: 91 // 92 // if (`start` < `end`) { 93 // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`; 94 // for (i64 i = `start` + `step`; i s< `end`; i += `step`) 95 // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`; 96 // } 97 // 98 // and the following if `peel_first_iteration` is false: 99 // 100 // for (i64 i = `start`; i s< `end`; i += `step`) 101 // `for_body_generator(/*ind_var=*/,i, 102 // /*is_first_iteration=*/,(i != `start`))`; 103 Status ForWithStatus( 104 absl::string_view name, llvm::Value* start, llvm::Value* end, 105 llvm::Value* step, bool peel_first_iteration, 106 const std::function<Status(llvm::Value* ind_var, 107 llvm::Value* is_first_iteration)>& 108 for_body_generator); 109 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)110 void For(absl::string_view name, llvm::Value* start, llvm::Value* end, 111 llvm::Value* step, bool peel_first_iteration, 112 const std::function<void(llvm::Value* ind_var, 113 llvm::Value* is_first_iteration)>& 114 for_body_generator) { 115 TF_CHECK_OK(ForWithStatus( 116 name, start, end, step, peel_first_iteration, 117 [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { 118 for_body_generator(ind_var, is_first_iteration); 119 return OkStatus(); 120 })); 121 } 122 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,bool peel_first_iteration,const std::function<Status (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)123 Status ForWithStatus( 124 absl::string_view name, llvm::Value* start, llvm::Value* end, 125 int64_t step, bool peel_first_iteration, 126 const std::function<Status(llvm::Value* ind_var, 127 llvm::Value* is_first_iteration)>& 128 for_body_generator) { 129 return ForWithStatus( 130 name, /*start=*/start, /*end=*/end, 131 /*step=*/llvm::ConstantInt::get(start->getType(), step), 132 peel_first_iteration, for_body_generator); 133 } 134 For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,bool peel_first_iteration,const std::function<void (llvm::Value * ind_var,llvm::Value * is_first_iteration)> & for_body_generator)135 void For(absl::string_view name, llvm::Value* start, llvm::Value* end, 136 int64_t step, bool peel_first_iteration, 137 const std::function<void(llvm::Value* ind_var, 138 llvm::Value* is_first_iteration)>& 139 for_body_generator) { 140 For(name, /*start=*/start, /*end=*/end, 141 /*step=*/llvm::ConstantInt::get(start->getType(), step), 142 peel_first_iteration, for_body_generator); 143 } 144 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)145 Status ForWithStatus( 146 absl::string_view name, llvm::Value* start, llvm::Value* end, 147 llvm::Value* step, 148 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 149 return ForWithStatus(name, start, end, step, 150 /*peel_first_iteration=*/false, 151 [&](llvm::Value* indvar, llvm::Value*) -> Status { 152 return for_body_generator(indvar); 153 }); 154 } 155 For(absl::string_view name,llvm::Value * start,llvm::Value * end,llvm::Value * step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)156 void For( 157 absl::string_view name, llvm::Value* start, llvm::Value* end, 158 llvm::Value* step, 159 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 160 For(name, start, end, step, 161 /*peel_first_iteration=*/false, [&](llvm::Value* indvar, llvm::Value*) { 162 return for_body_generator(indvar); 163 }); 164 } 165 ForWithStatus(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)166 Status ForWithStatus( 167 absl::string_view name, llvm::Value* start, llvm::Value* end, 168 int64_t step, 169 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 170 return ForWithStatus(name, start, end, 171 llvm::ConstantInt::get(start->getType(), step), 172 /*peel_first_iteration=*/false, 173 [&](llvm::Value* indvar, llvm::Value*) -> Status { 174 return for_body_generator(indvar); 175 }); 176 } 177 For(absl::string_view name,llvm::Value * start,llvm::Value * end,int64_t step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)178 void For( 179 absl::string_view name, llvm::Value* start, llvm::Value* end, 180 int64_t step, 181 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 182 For(name, start, end, llvm::ConstantInt::get(start->getType(), step), 183 for_body_generator); 184 } 185 ForWithStatus(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<Status (llvm::Value * ind_var)> & for_body_generator)186 Status ForWithStatus( 187 absl::string_view name, int64_t start, int64_t end, int64_t step, 188 const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { 189 return ForWithStatus(name, /*start=*/b_->getInt64(start), 190 /*end=*/b_->getInt64(end), 191 /*step=*/b_->getInt64(step), for_body_generator); 192 } 193 For(absl::string_view name,int64_t start,int64_t end,int64_t step,const std::function<void (llvm::Value * ind_var)> & for_body_generator)194 void For( 195 absl::string_view name, int64_t start, int64_t end, int64_t step, 196 const std::function<void(llvm::Value* ind_var)>& for_body_generator) { 197 For(name, /*start=*/b_->getInt64(start), 198 /*end=*/b_->getInt64(end), 199 /*step=*/b_->getInt64(step), for_body_generator); 200 } 201 202 // Generates the following control flow structure: 203 // 204 // if (`condition`) 205 // `true_block_generator()`; 206 // else 207 // `false_block_generator()`; 208 // The else is skipped if false_block_generator is null. 209 Status IfWithStatus( 210 absl::string_view name, llvm::Value* condition, 211 const std::function<Status()>& true_block_generator, 212 const std::function<Status()>& false_block_generator = nullptr); 213 214 Status IfWithStatus( 215 llvm::Value* condition, 216 const std::function<Status()>& true_block_generator, 217 const std::function<Status()>& false_block_generator = []() -> Status { 218 return OkStatus(); 219 }) { 220 return IfWithStatus("", condition, true_block_generator, 221 false_block_generator); 222 } 223 224 void If(llvm::Value* condition, 225 const std::function<void()>& true_block_generator, 226 const std::function<void()>& false_block_generator = nullptr) { 227 If("", condition, true_block_generator, false_block_generator); 228 } 229 230 void If(absl::string_view name, llvm::Value* condition, 231 const std::function<void()>& true_block_generator, 232 const std::function<void()>& false_block_generator = nullptr) { 233 if (false_block_generator != nullptr) { 234 TF_CHECK_OK(IfWithStatus( 235 name, condition, 236 [&]() { 237 true_block_generator(); 238 return OkStatus(); 239 }, 240 [&]() { 241 false_block_generator(); 242 return OkStatus(); 243 })); 244 } else { 245 TF_CHECK_OK(IfWithStatus(name, condition, [&]() { 246 true_block_generator(); 247 return OkStatus(); 248 })); 249 } 250 } 251 252 using ArgumentVector = absl::Span<llvm::Value* const>; 253 254 // Generates the following control flow structure: 255 // 256 // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { 257 // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); 258 // } 259 // 260 // ... 261 // call @`kernel_name`(arguments[0], arguments[1] ...) 262 // ... 263 // 264 // If a function called `kernel_name` is already present in the module then 265 // that function is re-used. In that sense we're using the llvm::Module as a 266 // cache of outlined kernels, keyed by function name. 267 // 268 // If any of the values in `arguments` is nullptr (i.e. a nullptr 269 // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass 270 // in a nullptr llvm::Value* in its position to `kernel_body_generator`. 271 // Currently we only support at most one nullptr value in `arguments`. 272 static void EmitAndCallOutlinedKernel( 273 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 274 absl::string_view kernel_name, ArgumentVector arguments, 275 const std::function<void(ArgumentVector)>& kernel_body_generator); 276 277 // Thin wrappers around the more general EmitAndCallOutlinedKernel above. EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)278 static void EmitAndCallOutlinedKernel( 279 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 280 absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, 281 llvm::Value* arg2, 282 const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>& 283 kernel_body_generator) { 284 EmitAndCallOutlinedKernel(module_config, b, kernel_name, {arg0, arg1, arg2}, 285 [&](ArgumentVector args) { 286 kernel_body_generator(args[0], args[1], 287 args[2]); 288 }); 289 } 290 EmitAndCallOutlinedKernel(const HloModuleConfig & module_config,llvm::IRBuilder<> * b,absl::string_view kernel_name,llvm::Value * arg0,llvm::Value * arg1,llvm::Value * arg2,llvm::Value * arg3,const std::function<void (llvm::Value *,llvm::Value *,llvm::Value *,llvm::Value *)> & kernel_body_generator)291 static void EmitAndCallOutlinedKernel( 292 const HloModuleConfig& module_config, llvm::IRBuilder<>* b, 293 absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, 294 llvm::Value* arg2, llvm::Value* arg3, 295 const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*, 296 llvm::Value*)>& kernel_body_generator) { 297 EmitAndCallOutlinedKernel( 298 module_config, b, kernel_name, {arg0, arg1, arg2, arg3}, 299 [&](ArgumentVector args) { 300 kernel_body_generator(args[0], args[1], args[2], args[3]); 301 }); 302 } 303 304 private: 305 llvm::IRBuilder<>* b_; 306 llvm_ir::UnrollMode unroll_mode_; 307 bool prevent_vectorization_; 308 }; 309 } // namespace xla 310 311 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ 312