/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace llvm_ir {

namespace {

// Adds the inner comparison loop body where we compare elements.
Status EmitCompareLoopBody(
    int64_t iteration_bound, int64_t num_values,
    llvm::Value* element_pair_index, int64_t xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64_t operand, llvm::Value* index)>
        element_address,
    std::function<llvm::Type*(int64_t operand, llvm::Value* index)>
        element_address_pointee_type,
    std::function<void(int64_t operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64_t value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
  int64_t block_size = xor_mask;
  // Check if it is a value 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
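  // For example, with xor_mask == 7 (2^3 - 1), block_size is 4 and the indices
  // pair up as 0<->7, 1<->6, 2<->5, 3<->4, i.e. the order within the blocks is
  // reversed; with the power-of-2 mask 4 the pairs are 0<->4, 1<->5, 2<->6,
  // 3<->7. In both cases a block of 4 consecutive elements is compared against
  // the adjacent block of 4.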
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
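    // For example, with 'block_size' == 4 an element_pair_index of 5 gives
    // block_id == 1 and index_within_block == 1, so first_element_in_block
    // == 8 and current_keys_index == 9, which below gets paired with
    // 9 ^ 'xor_mask'.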
  }
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    std::vector<llvm::Type*> values_to_compare_types;
    for (int i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare_types.push_back(
          element_address_pointee_type(i, compare_keys_index));

      values_to_compare.push_back(element_address(i, current_keys_index));
      values_to_compare_types.push_back(
          element_address_pointee_type(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Type* pred_type = llvm_ir::PrimitiveTypeToIrType(PRED, module);
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        pred_type, "compare_return_buffer", b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(pred_type, compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64_t i = 0; i < num_values; ++i) {
        // Swap the values.
        auto value1 = b->CreateLoad(values_to_compare_types[i * 2],
                                    values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare_types[i * 2 + 1],
                                    values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return OkStatus();
  });
}

Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64_t dimension_to_sort,
    int64_t dimension_to_sort_bound, absl::Span<const int64_t> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::GlobalVariable*>& param_shmem_buffers,
    int64_t tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
  llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic(
      gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b);
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");

  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within
              // bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64_t i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(
          value,
          b->CreateGEP(
              param_shmem_buffers[i]->getValueType(), param_shmem_buffers[i],
              {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
    });
  }
  // Wait until all reads have happened.
  gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64_t operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand]->getValueType(),
                     param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(shared_memory_address,
                                  llvm::PointerType::getWithSamePointeeType(
                                      llvm::cast<llvm::PointerType>(ptr_type),
                                      /*AddressSpace=*/0));
  };
  auto element_address_pointee_type = [&](int64_t operand, llvm::Value* index) {
    return llvm::GetElementPtrInst::getIndexedType(
        param_shmem_buffers[operand]->getValueType(),
        {tiled_keys_index.GetConstantWithIndexType(0), index});
  };
  auto write_element = [&](int64_t operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand]->getValueType(),
                     param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
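  // Each entry in 'xor_masks' corresponds to one comparison stage of the
  // bitonic sorting network that can be performed entirely within the tile
  // held in shared memory; the barrier at the end of the loop body separates
  // consecutive stages.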
  for (int64_t xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the
    // smaller index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we
    // don't need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownTo(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, element_address_pointee_type, write_element,
                emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address,
                element_address_pointee_type, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address,
          element_address_pointee_type, write_element, emit_compare_callback, b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {},
                                   b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64_t i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto gep = b->CreateGEP(
          param_shmem_buffers[i]->getValueType(), param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index});
      auto gep_type = llvm::GetElementPtrInst::getIndexedType(
          param_shmem_buffers[i]->getValueType(),
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index});
      auto value = b->CreateLoad(gep_type, gep);
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However the very next thing each thread does is reading 2 elements from
  // the operand buffer and writing them back into the same locations in shared
  // memory from which it previously copied them to the operand buffer, and we
  // synchronize after this has happened. We can be sure that a thread always
  // writes to the same location in shared memory because we have exactly
  // tile_size / 2 many threads, and the linear index calculated by
  // ParallelLoopEmitter uses
  // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  return OkStatus();
}
}  // namespace

Status EmitSortInPlace(
    int64_t dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64_t> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64_t num_iterations_in_sort_dim, const int64_t tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop which is the loop where the
  // comparisons happen. In the dimension to sort, if we use tiling, we iterate
  // through it in tiles of 64 elements each, so we use another loop that
  // happens within one thread to process this tile worth of data (thereby
  // combining several comparison stages of the bitonic sort algorithm because
  // they all happen within those 64 elements and are therefore independent of
  // the other comparisons).

  const Shape& keys_shape = values_arrays[0].GetShape();
  int64_t rank = keys_shape.rank();
  int64_t dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64_t> dimensions_in_iteration_order(rank);
  std::vector<int64_t> iteration_order_to_logical_order(rank);
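  // For example, for a keys shape f32[2,64,3] with the default layout {2,1,0}
  // and dimension_to_sort == 1, the loop below fills
  // 'dimensions_in_iteration_order' with {3, 2, num_iterations_in_sort_dim}
  // and 'iteration_order_to_logical_order' with {2, 0, 1}, so the dimension to
  // sort becomes the innermost iteration dimension.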
  int64_t dim = 0;
  for (int64_t dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;

  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);

  // Allocate shared memory for the tiled compare loop.
  std::vector<llvm::GlobalVariable*> param_shmem_buffers(values_arrays.size(),
                                                         nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64_t i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64_t i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64_t j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64_t min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64_t i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
      auto element_address = [&](int64_t operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto element_address_pointee_type = [&](int64_t operand, llvm::Value*) {
        return values_arrays[operand].GetElementLlvmType();
      };
      auto write_element = [&](int64_t operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address,
          element_address_pointee_type, write_element, emit_compare_callback,
          b));
    }
    return OkStatus();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla