xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/ir_function.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
17 
18 #include <iterator>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
22 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
23 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 
26 namespace xla {
27 namespace cpu {
28 
GetComputeFunctionParams(llvm::Module * llvm_module,const int64_t num_dynamic_loop_bounds)29 static std::vector<llvm::Type*> GetComputeFunctionParams(
30     llvm::Module* llvm_module, const int64_t num_dynamic_loop_bounds) {
31   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
32   llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
33   llvm::Type* i64_ptr_type =
34       llvm::Type::getInt64PtrTy(llvm_module->getContext());
35   std::vector<llvm::Type*> compute_function_params(
36       {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type,
37        i8_ptr_type});
38   if (num_dynamic_loop_bounds > 0) {
39     compute_function_params.push_back(i64_ptr_type);
40   }
41   compute_function_params.push_back(i64_ptr_type);
42   return compute_function_params;
43 }
44 
// Constructs an IrFunction and immediately builds the LLVM function skeleton
// by calling Initialize().
//
// Args:
//   function_name: name forwarded to Initialize() for the created function.
//   linkage: LLVM linkage forwarded to Initialize().
//   module_config: module configuration forwarded to Initialize().
//   llvm_module: module stored in llvm_module_; Initialize() creates the
//     function in it.
//   b: IR builder stored in b_. Initialize() moves its insertion point into
//     the newly created function.
//   num_dynamic_loop_bounds: stored in num_dynamic_loop_bounds_; when greater
//     than zero the function gets an extra "dynamic_loop_bounds" argument.
IrFunction::IrFunction(const std::string& function_name,
                       llvm::Function::LinkageTypes linkage,
                       const HloModuleConfig& module_config,
                       llvm::Module* llvm_module, llvm::IRBuilder<>* b,
                       int64_t num_dynamic_loop_bounds)
    : b_(b),
      llvm_module_(llvm_module),
      // Initialized from the builder before Initialize() changes the
      // insertion point. NOTE(review): presumably an insert-point guard that
      // restores the caller's insertion point on destruction -- confirm
      // against the member's declaration in the header.
      caller_insert_point_guard_(*b),
      num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
  Initialize(function_name, linkage, module_config);
}
56 
// Finishes the function body: emits an unconditional branch from the
// builder's current insertion point to the "return" block created in
// Initialize(). Member destructors run after this body.
// NOTE(review): presumably caller_insert_point_guard_ then restores the
// caller's insertion point -- confirm against the header.
IrFunction::~IrFunction() {
  // Branch to function return.
  b_->CreateBr(return_block_);
}
61 
GetDynamicLoopBounds()62 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
63   DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
64   for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
65     dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
66     dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
67   }
68   return dynamic_loop_bounds;
69 }
70 
Initialize(const std::string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config)71 void IrFunction::Initialize(const std::string& function_name,
72                             llvm::Function::LinkageTypes linkage,
73                             const HloModuleConfig& module_config) {
74   // The function signature is:
75   //   void function(i8* retval, i8* run_options, i8** params, i8**
76   //   buffer_table,
77   //                 i64* dynamic_loop_bounds, i64* prof_counters)
78   //
79   // For thread local functions:
80   //   retval: points to the returned value.
81   //   params: address of an array with pointers to parameters.
82   //   buffer_table: is null
83   //
84   // For global functions:
85   //   retval: is null
86   //   params: is null
87   //   buffer_table: address of an array with pointers to temporary buffers and
88   //     entry computation parameters (but not to constant buffers).
89   //
90   // Therefore, the generated function's signature (FunctionType) is statically
91   // determined - parameter unpacking is done in code generated into the
92   // function, rather than by a prologue dictated by the platform ABI.
93   //
94   //                      /--------------\
95   //   retval ----------> | return value |
96   //                      \--------------/
97   //
98   //                      /-------------------------------\
99   //   run_options -----> | xla::ExecutableRunOptions |
100   //                      \-------------------------------/
101   //
102   //                     /---------------------------------------------\
103   //   params -------->  |  param 0  |  param 1  | ..... |  param N-1  |
104   //                     |   addr    |   addr    |       |   addr      |
105   //                     \---------------------------------------------/
106   //                          |           |                   |
107   //                          |           |                   |
108   //                          V           V                   V
109   //                     /---------\  /---------\         /-----------\
110   //                     | param 0 |  | param 1 |         | param N-1 |
111   //                     \---------/  \---------/         \-----------/
112   //
113   //                     /---------------------------------------------\
114   //   buffer_table--->  |  buff  0  |  guff  1  | ..... |  buff  N-1  |
115   //                     |   addr    |   addr    |       |   addr      |
116   //                     \---------------------------------------------/
117   //                          |           |                   |
118   //                          |           |                   |
119   //                          V           V                   V
120   //                     /---------\  /---------\         /-----------\
121   //                     | temp  0 |  | temp  1 |         | temp  N-1 |
122   //                     \---------/  \---------/         \-----------/
123   //
124   //                        /--------------------------------------------\
125   // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
126   //  (elided for aot)      \--------------------------------------------/
127   //
128   //                     /---------------------------------------------\
129   //   prof counters ->  | counter 0 | counter 1 | ..... | counter N-1 |
130   //                     \---------------------------------------------/
131 
132   // Even though the type of params and buffer_table is void** in the host's
133   // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
134   // the code to use GEPs to unravel the indirection layers.
135   llvm::FunctionType* function_type = llvm::FunctionType::get(
136       /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
137       /*Params=*/
138       GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
139       /*isVarArg=*/false);
140 
141   // Functions with local linkage get an inlining bonus.  Because we know
142   // a-priori that embedded functions (non-entry functions) will not have its
143   // name resolved, give it local linkage.
144   function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
145                                          function_name, llvm_module_);
146 
147   // Set meaningful names for the function's arguments: useful for debugging.
148   llvm::Function::arg_iterator arg_iter = function_->arg_begin();
149   arg_iter->setName("retval");
150   result_arg_ = &*arg_iter;
151   (++arg_iter)->setName("run_options");
152   exec_run_options_arg_ = &*arg_iter;
153   (++arg_iter)->setName("params");
154   parameters_arg_ = &*arg_iter;
155   (++arg_iter)->setName("buffer_table");
156   buffer_table_arg_ = &*arg_iter;
157   (++arg_iter)->setName("status");
158   status_arg_ = &*arg_iter;
159   if (num_dynamic_loop_bounds_ > 0) {
160     (++arg_iter)->setName("dynamic_loop_bounds");
161     dynamic_loop_bounds_arg_ = &*arg_iter;
162   }
163   (++arg_iter)->setName("prof_counters");
164   profile_counters_arg_ = &*arg_iter;
165 
166   // We know a-priori that the function arguments are guaranteed to point to
167   // disjoint objects.
168   llvm::Argument* retval = result_arg();
169   for (llvm::Argument& argument : function_->args()) {
170     // However, the return buffer aliases the temporaries and thus cannot be
171     // marked noalias.
172     if (&argument == retval) {
173       continue;
174     }
175     function_->addParamAttr(argument.getArgNo(), llvm::Attribute::NoAlias);
176   }
177 
178   return_block_ =
179       llvm::BasicBlock::Create(/*Context=*/llvm_module_->getContext(),
180                                /*Name=*/"return", /*Parent=*/function_);
181 
182   b_->SetInsertPoint(return_block_);
183   b_->CreateRetVoid();
184 
185   b_->SetInsertPoint(llvm::BasicBlock::Create(
186       /*Context=*/llvm_module_->getContext(),
187       /*Name=*/"entry",
188       /*Parent=*/function_,
189       /*InsertBefore=*/return_block_));
190 }
191 
GetDynamicLoopBound(const int64_t offset)192 llvm::Value* IrFunction::GetDynamicLoopBound(const int64_t offset) {
193   CHECK_GT(num_dynamic_loop_bounds_, 0);
194   CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
195   llvm::Type* int64_ty = b_->getInt64Ty();
196   auto gep = b_->CreateGEP(int64_ty, CHECK_NOTNULL(dynamic_loop_bounds_arg_),
197                            b_->getInt64(offset));
198   return b_->CreateLoad(int64_ty, gep,
199                         "dynamic_loop_bound_" + llvm::Twine(offset));
200 }
201 
EncodeArrayFunctionArguments(absl::Span<llvm::Value * const> arguments,absl::string_view name,llvm::IRBuilder<> * b)202 llvm::Value* EncodeArrayFunctionArguments(
203     absl::Span<llvm::Value* const> arguments, absl::string_view name,
204     llvm::IRBuilder<>* b) {
205   llvm::Value* arguments_buffer;
206   llvm::Type* int8ptr_ty = b->getInt8PtrTy();
207   if (arguments.empty()) {
208     arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
209   } else {
210     arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
211         int8ptr_ty, b->getInt32(arguments.size()),
212         absl::StrCat(name, "_parameter_addresses"), b);
213 
214     for (size_t i = 0; i < arguments.size(); i++) {
215       llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
216           arguments[i], b->getInt8PtrTy(),
217           absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
218       llvm::Value* slot_in_param_addresses =
219           b->CreateInBoundsGEP(int8ptr_ty, arguments_buffer, b->getInt64(i));
220       b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
221     }
222   }
223   return arguments_buffer;
224 }
225 
226 // Emits code to allocate an array of parameter address pointers, and store
227 // each address from 'parameter_addresses'.
228 // Returns an array of compute function call arguments (including parameter
229 // address buffer).
GetArrayFunctionCallArguments(absl::Span<llvm::Value * const> parameter_addresses,llvm::IRBuilder<> * b,absl::string_view name,llvm::Value * return_value_buffer,llvm::Value * exec_run_options_arg,llvm::Value * buffer_table_arg,llvm::Value * status_arg,llvm::Value * profile_counters_arg)230 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
231     absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
232     absl::string_view name, llvm::Value* return_value_buffer,
233     llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
234     llvm::Value* status_arg, llvm::Value* profile_counters_arg) {
235   llvm::Value* parameter_addresses_buffer =
236       EncodeArrayFunctionArguments(parameter_addresses, name, b);
237 
238   const auto to_int8_ptr = [=](llvm::Value* ptr) {
239     return b->CreatePointerCast(ptr, b->getInt8PtrTy());
240   };
241   return std::vector<llvm::Value*>{to_int8_ptr(return_value_buffer),
242                                    to_int8_ptr(exec_run_options_arg),
243                                    parameter_addresses_buffer,
244                                    buffer_table_arg,
245                                    status_arg,
246                                    profile_counters_arg};
247 }
248 
249 // Emits a call to a runtime fork/join function which dispatches parallel
250 // calls to 'parallel_function' (and joins threads before returning).
EmitCallToParallelForkJoin(const std::vector<llvm::Value * > & arguments,const Shape & shape,const std::vector<int64_t> & dimension_partition_counts,llvm::IRBuilder<> * b,llvm::Function * parallel_function,const std::string & name)251 Status EmitCallToParallelForkJoin(
252     const std::vector<llvm::Value*>& arguments, const Shape& shape,
253     const std::vector<int64_t>& dimension_partition_counts,
254     llvm::IRBuilder<>* b, llvm::Function* parallel_function,
255     const std::string& name) {
256   llvm::Module* module = b->GetInsertBlock()->getModule();
257 
258   // Build ParallelForkJoin function type.
259   std::vector<llvm::Type*> compute_function_params =
260       GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
261   // Number of parallel compute functions.
262   compute_function_params.push_back(b->getInt32Ty());
263   // Array of partitions. There is an array element for each
264   // partition x partition_dim x 2 (for dimension start and limit).
265   compute_function_params.push_back(
266       llvm::Type::getInt64PtrTy(module->getContext()));
267   // Number of partitioned most-major dimensions in 'shape'.
268   compute_function_params.push_back(b->getInt32Ty());
269   // Function pointer for compute function to be dispatched in parallel.
270   compute_function_params.push_back(
271       llvm::Type::getInt8PtrTy(module->getContext()));
272 
273   llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
274       /*Result=*/llvm::Type::getVoidTy(module->getContext()),
275       /*Params=*/compute_function_params,
276       /*isVarArg=*/false);
277 
278   llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
279       module
280           ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
281                                 fork_join_type)
282           .getCallee());
283   fork_join_func->setCallingConv(llvm::CallingConv::C);
284   fork_join_func->setDoesNotThrow();
285 
286   // Add common compute function arguments.
287   std::vector<llvm::Value*> fork_join_arguments(arguments);
288 
289   // Create ShapePartitionIterator to generate all partitions of 'shape'.
290   ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
291   const int64_t num_partitions = partition_iterator.GetTotalPartitionCount();
292   // Add argument specifying the number of parallel partitions.
293   fork_join_arguments.push_back(b->getInt32(num_partitions));
294 
295   // The number of partitioned most-major dimensions in 'shape'.
296   const int32_t num_partitioned_dims = dimension_partition_counts.size();
297   // A dimension partition consists of two elements: [start_index, limit_index).
298   const int32_t dim_partition_size = 2;
299   // Calculate array partition stride.
300   const int32_t array_partition_stride =
301       num_partitioned_dims * dim_partition_size;
302   // Calculate the total number of elements in the partition array.
303   const int32_t partition_array_size =
304       dim_partition_size * num_partitioned_dims * num_partitions;
305 
306   // Store dimension partition values as llvm constants in 'partitions'.
307   // See comments in runtime_fork_join.cc for array layout description.
308   std::vector<llvm::Constant*> partitions(partition_array_size);
309   for (int32_t i = 0; i < num_partitions; ++i) {
310     std::vector<std::pair<int64_t, int64_t>> dim_partitions =
311         partition_iterator.GetPartition(i);
312     CHECK_EQ(num_partitioned_dims, dim_partitions.size());
313     const int32_t partition_index = i * array_partition_stride;
314     for (int32_t j = 0; j < num_partitioned_dims; ++j) {
315       const std::pair<int64_t, int64_t>& dim_partition = dim_partitions[j];
316       const int32_t index = partition_index + j * dim_partition_size;
317       // Store partition [dim_start, dim_limit) intervals for each dimension.
318       partitions[index] = b->getInt64(dim_partition.first);
319       partitions[index + 1] =
320           b->getInt64(dim_partition.first + dim_partition.second);
321     }
322   }
323 
324   // Create global variable out of dimension partitions in 'partitions'.
325   llvm::ArrayType* partitions_array_type =
326       llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
327   llvm::Constant* partitions_array =
328       llvm::ConstantArray::get(partitions_array_type, partitions);
329   llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
330       /*M=*/*module,
331       /*Ty=*/partitions_array_type,
332       /*isConstant=*/true,
333       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
334       /*Initializer=*/partitions_array,
335       /*Name=*/
336       absl::StrCat(name, "_parallel_dimension_partitions"));
337 
338   // Add argument specifying parallel dimension partitions.
339   fork_join_arguments.push_back(
340       b->CreateBitCast(global_partitions_array,
341                        llvm::Type::getInt64PtrTy(module->getContext())));
342   // Add argument specifying the number of partitioned most-major dimensions.
343   fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
344   // Add argument for parallel compute function pointer.
345   fork_join_arguments.push_back(
346       b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
347   // Emit call to parallel fork/join.
348   b->CreateCall(fork_join_func, fork_join_arguments);
349 
350   return OkStatus();
351 }
352 
353 }  // namespace cpu
354 }  // namespace xla
355