1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/cpu/ir_function.h"

#include <iterator>
#include <limits>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
25
26 namespace xla {
27 namespace cpu {
28
GetComputeFunctionParams(llvm::Module * llvm_module,const int64_t num_dynamic_loop_bounds)29 static std::vector<llvm::Type*> GetComputeFunctionParams(
30 llvm::Module* llvm_module, const int64_t num_dynamic_loop_bounds) {
31 llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
32 llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
33 llvm::Type* i64_ptr_type =
34 llvm::Type::getInt64PtrTy(llvm_module->getContext());
35 std::vector<llvm::Type*> compute_function_params(
36 {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type,
37 i8_ptr_type});
38 if (num_dynamic_loop_bounds > 0) {
39 compute_function_params.push_back(i64_ptr_type);
40 }
41 compute_function_params.push_back(i64_ptr_type);
42 return compute_function_params;
43 }
44
IrFunction(const std::string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config,llvm::Module * llvm_module,llvm::IRBuilder<> * b,int64_t num_dynamic_loop_bounds)45 IrFunction::IrFunction(const std::string& function_name,
46 llvm::Function::LinkageTypes linkage,
47 const HloModuleConfig& module_config,
48 llvm::Module* llvm_module, llvm::IRBuilder<>* b,
49 int64_t num_dynamic_loop_bounds)
50 : b_(b),
51 llvm_module_(llvm_module),
52 caller_insert_point_guard_(*b),
53 num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
54 Initialize(function_name, linkage, module_config);
55 }
56
~IrFunction()57 IrFunction::~IrFunction() {
58 // Branch to function return.
59 b_->CreateBr(return_block_);
60 }
61
GetDynamicLoopBounds()62 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
63 DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
64 for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
65 dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
66 dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
67 }
68 return dynamic_loop_bounds;
69 }
70
Initialize(const std::string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config)71 void IrFunction::Initialize(const std::string& function_name,
72 llvm::Function::LinkageTypes linkage,
73 const HloModuleConfig& module_config) {
74 // The function signature is:
75 // void function(i8* retval, i8* run_options, i8** params, i8**
76 // buffer_table,
77 // i64* dynamic_loop_bounds, i64* prof_counters)
78 //
79 // For thread local functions:
80 // retval: points to the returned value.
81 // params: address of an array with pointers to parameters.
82 // buffer_table: is null
83 //
84 // For global functions:
85 // retval: is null
86 // params: is null
87 // buffer_table: address of an array with pointers to temporary buffers and
88 // entry computation parameters (but not to constant buffers).
89 //
90 // Therefore, the generated function's signature (FunctionType) is statically
91 // determined - parameter unpacking is done in code generated into the
92 // function, rather than by a prologue dictated by the platform ABI.
93 //
94 // /--------------\
95 // retval ----------> | return value |
96 // \--------------/
97 //
98 // /-------------------------------\
99 // run_options -----> | xla::ExecutableRunOptions |
100 // \-------------------------------/
101 //
102 // /---------------------------------------------\
103 // params --------> | param 0 | param 1 | ..... | param N-1 |
104 // | addr | addr | | addr |
105 // \---------------------------------------------/
106 // | | |
107 // | | |
108 // V V V
109 // /---------\ /---------\ /-----------\
110 // | param 0 | | param 1 | | param N-1 |
111 // \---------/ \---------/ \-----------/
112 //
113 // /---------------------------------------------\
114 // buffer_table---> | buff 0 | guff 1 | ..... | buff N-1 |
115 // | addr | addr | | addr |
116 // \---------------------------------------------/
117 // | | |
118 // | | |
119 // V V V
120 // /---------\ /---------\ /-----------\
121 // | temp 0 | | temp 1 | | temp N-1 |
122 // \---------/ \---------/ \-----------/
123 //
124 // /--------------------------------------------\
125 // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
126 // (elided for aot) \--------------------------------------------/
127 //
128 // /---------------------------------------------\
129 // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
130 // \---------------------------------------------/
131
132 // Even though the type of params and buffer_table is void** in the host's
133 // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
134 // the code to use GEPs to unravel the indirection layers.
135 llvm::FunctionType* function_type = llvm::FunctionType::get(
136 /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
137 /*Params=*/
138 GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
139 /*isVarArg=*/false);
140
141 // Functions with local linkage get an inlining bonus. Because we know
142 // a-priori that embedded functions (non-entry functions) will not have its
143 // name resolved, give it local linkage.
144 function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
145 function_name, llvm_module_);
146
147 // Set meaningful names for the function's arguments: useful for debugging.
148 llvm::Function::arg_iterator arg_iter = function_->arg_begin();
149 arg_iter->setName("retval");
150 result_arg_ = &*arg_iter;
151 (++arg_iter)->setName("run_options");
152 exec_run_options_arg_ = &*arg_iter;
153 (++arg_iter)->setName("params");
154 parameters_arg_ = &*arg_iter;
155 (++arg_iter)->setName("buffer_table");
156 buffer_table_arg_ = &*arg_iter;
157 (++arg_iter)->setName("status");
158 status_arg_ = &*arg_iter;
159 if (num_dynamic_loop_bounds_ > 0) {
160 (++arg_iter)->setName("dynamic_loop_bounds");
161 dynamic_loop_bounds_arg_ = &*arg_iter;
162 }
163 (++arg_iter)->setName("prof_counters");
164 profile_counters_arg_ = &*arg_iter;
165
166 // We know a-priori that the function arguments are guaranteed to point to
167 // disjoint objects.
168 llvm::Argument* retval = result_arg();
169 for (llvm::Argument& argument : function_->args()) {
170 // However, the return buffer aliases the temporaries and thus cannot be
171 // marked noalias.
172 if (&argument == retval) {
173 continue;
174 }
175 function_->addParamAttr(argument.getArgNo(), llvm::Attribute::NoAlias);
176 }
177
178 return_block_ =
179 llvm::BasicBlock::Create(/*Context=*/llvm_module_->getContext(),
180 /*Name=*/"return", /*Parent=*/function_);
181
182 b_->SetInsertPoint(return_block_);
183 b_->CreateRetVoid();
184
185 b_->SetInsertPoint(llvm::BasicBlock::Create(
186 /*Context=*/llvm_module_->getContext(),
187 /*Name=*/"entry",
188 /*Parent=*/function_,
189 /*InsertBefore=*/return_block_));
190 }
191
GetDynamicLoopBound(const int64_t offset)192 llvm::Value* IrFunction::GetDynamicLoopBound(const int64_t offset) {
193 CHECK_GT(num_dynamic_loop_bounds_, 0);
194 CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
195 llvm::Type* int64_ty = b_->getInt64Ty();
196 auto gep = b_->CreateGEP(int64_ty, CHECK_NOTNULL(dynamic_loop_bounds_arg_),
197 b_->getInt64(offset));
198 return b_->CreateLoad(int64_ty, gep,
199 "dynamic_loop_bound_" + llvm::Twine(offset));
200 }
201
EncodeArrayFunctionArguments(absl::Span<llvm::Value * const> arguments,absl::string_view name,llvm::IRBuilder<> * b)202 llvm::Value* EncodeArrayFunctionArguments(
203 absl::Span<llvm::Value* const> arguments, absl::string_view name,
204 llvm::IRBuilder<>* b) {
205 llvm::Value* arguments_buffer;
206 llvm::Type* int8ptr_ty = b->getInt8PtrTy();
207 if (arguments.empty()) {
208 arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
209 } else {
210 arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
211 int8ptr_ty, b->getInt32(arguments.size()),
212 absl::StrCat(name, "_parameter_addresses"), b);
213
214 for (size_t i = 0; i < arguments.size(); i++) {
215 llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
216 arguments[i], b->getInt8PtrTy(),
217 absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
218 llvm::Value* slot_in_param_addresses =
219 b->CreateInBoundsGEP(int8ptr_ty, arguments_buffer, b->getInt64(i));
220 b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
221 }
222 }
223 return arguments_buffer;
224 }
225
226 // Emits code to allocate an array of parameter address pointers, and store
227 // each address from 'parameter_addresses'.
228 // Returns an array of compute function call arguments (including parameter
229 // address buffer).
GetArrayFunctionCallArguments(absl::Span<llvm::Value * const> parameter_addresses,llvm::IRBuilder<> * b,absl::string_view name,llvm::Value * return_value_buffer,llvm::Value * exec_run_options_arg,llvm::Value * buffer_table_arg,llvm::Value * status_arg,llvm::Value * profile_counters_arg)230 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
231 absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
232 absl::string_view name, llvm::Value* return_value_buffer,
233 llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
234 llvm::Value* status_arg, llvm::Value* profile_counters_arg) {
235 llvm::Value* parameter_addresses_buffer =
236 EncodeArrayFunctionArguments(parameter_addresses, name, b);
237
238 const auto to_int8_ptr = [=](llvm::Value* ptr) {
239 return b->CreatePointerCast(ptr, b->getInt8PtrTy());
240 };
241 return std::vector<llvm::Value*>{to_int8_ptr(return_value_buffer),
242 to_int8_ptr(exec_run_options_arg),
243 parameter_addresses_buffer,
244 buffer_table_arg,
245 status_arg,
246 profile_counters_arg};
247 }
248
249 // Emits a call to a runtime fork/join function which dispatches parallel
250 // calls to 'parallel_function' (and joins threads before returning).
EmitCallToParallelForkJoin(const std::vector<llvm::Value * > & arguments,const Shape & shape,const std::vector<int64_t> & dimension_partition_counts,llvm::IRBuilder<> * b,llvm::Function * parallel_function,const std::string & name)251 Status EmitCallToParallelForkJoin(
252 const std::vector<llvm::Value*>& arguments, const Shape& shape,
253 const std::vector<int64_t>& dimension_partition_counts,
254 llvm::IRBuilder<>* b, llvm::Function* parallel_function,
255 const std::string& name) {
256 llvm::Module* module = b->GetInsertBlock()->getModule();
257
258 // Build ParallelForkJoin function type.
259 std::vector<llvm::Type*> compute_function_params =
260 GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
261 // Number of parallel compute functions.
262 compute_function_params.push_back(b->getInt32Ty());
263 // Array of partitions. There is an array element for each
264 // partition x partition_dim x 2 (for dimension start and limit).
265 compute_function_params.push_back(
266 llvm::Type::getInt64PtrTy(module->getContext()));
267 // Number of partitioned most-major dimensions in 'shape'.
268 compute_function_params.push_back(b->getInt32Ty());
269 // Function pointer for compute function to be dispatched in parallel.
270 compute_function_params.push_back(
271 llvm::Type::getInt8PtrTy(module->getContext()));
272
273 llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
274 /*Result=*/llvm::Type::getVoidTy(module->getContext()),
275 /*Params=*/compute_function_params,
276 /*isVarArg=*/false);
277
278 llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
279 module
280 ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
281 fork_join_type)
282 .getCallee());
283 fork_join_func->setCallingConv(llvm::CallingConv::C);
284 fork_join_func->setDoesNotThrow();
285
286 // Add common compute function arguments.
287 std::vector<llvm::Value*> fork_join_arguments(arguments);
288
289 // Create ShapePartitionIterator to generate all partitions of 'shape'.
290 ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
291 const int64_t num_partitions = partition_iterator.GetTotalPartitionCount();
292 // Add argument specifying the number of parallel partitions.
293 fork_join_arguments.push_back(b->getInt32(num_partitions));
294
295 // The number of partitioned most-major dimensions in 'shape'.
296 const int32_t num_partitioned_dims = dimension_partition_counts.size();
297 // A dimension partition consists of two elements: [start_index, limit_index).
298 const int32_t dim_partition_size = 2;
299 // Calculate array partition stride.
300 const int32_t array_partition_stride =
301 num_partitioned_dims * dim_partition_size;
302 // Calculate the total number of elements in the partition array.
303 const int32_t partition_array_size =
304 dim_partition_size * num_partitioned_dims * num_partitions;
305
306 // Store dimension partition values as llvm constants in 'partitions'.
307 // See comments in runtime_fork_join.cc for array layout description.
308 std::vector<llvm::Constant*> partitions(partition_array_size);
309 for (int32_t i = 0; i < num_partitions; ++i) {
310 std::vector<std::pair<int64_t, int64_t>> dim_partitions =
311 partition_iterator.GetPartition(i);
312 CHECK_EQ(num_partitioned_dims, dim_partitions.size());
313 const int32_t partition_index = i * array_partition_stride;
314 for (int32_t j = 0; j < num_partitioned_dims; ++j) {
315 const std::pair<int64_t, int64_t>& dim_partition = dim_partitions[j];
316 const int32_t index = partition_index + j * dim_partition_size;
317 // Store partition [dim_start, dim_limit) intervals for each dimension.
318 partitions[index] = b->getInt64(dim_partition.first);
319 partitions[index + 1] =
320 b->getInt64(dim_partition.first + dim_partition.second);
321 }
322 }
323
324 // Create global variable out of dimension partitions in 'partitions'.
325 llvm::ArrayType* partitions_array_type =
326 llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
327 llvm::Constant* partitions_array =
328 llvm::ConstantArray::get(partitions_array_type, partitions);
329 llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
330 /*M=*/*module,
331 /*Ty=*/partitions_array_type,
332 /*isConstant=*/true,
333 /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
334 /*Initializer=*/partitions_array,
335 /*Name=*/
336 absl::StrCat(name, "_parallel_dimension_partitions"));
337
338 // Add argument specifying parallel dimension partitions.
339 fork_join_arguments.push_back(
340 b->CreateBitCast(global_partitions_array,
341 llvm::Type::getInt64PtrTy(module->getContext())));
342 // Add argument specifying the number of partitioned most-major dimensions.
343 fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
344 // Add argument for parallel compute function pointer.
345 fork_join_arguments.push_back(
346 b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
347 // Emit call to parallel fork/join.
348 b->CreateCall(fork_join_func, fork_join_arguments);
349
350 return OkStatus();
351 }
352
353 } // namespace cpu
354 } // namespace xla
355