/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"

namespace xla {
namespace llvm_ir {

namespace {

// Adds the inner comparison loop body where we compare elements.
Status EmitCompareLoopBody(
    int64_t iteration_bound, int64_t num_values,
    llvm::Value* element_pair_index, int64_t xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64_t operand, llvm::Value* index)>
        element_address,
    std::function<llvm::Type*(int64_t operand, llvm::Value* index)>
        element_address_pointee_type,
    std::function<void(int64_t operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64_t value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
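  //
  // A concrete illustration (added example): with 'xor_mask' = 2 (a power of
  // 2, block size 2), index i is compared with i ^ 2, pairing 0<->2 and 1<->3
  // within each group of 4. With 'xor_mask' = 3 (= 2^2 - 1, block size 2),
  // index i is compared with i ^ 3, pairing 0<->3 and 1<->2: the same two
  // blocks, but with the 'right' block traversed in reverse order, which is
  // exactly the first comparison pass of a bitonic merge stage.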
  int64_t block_size = xor_mask;
  // Check if it is a value 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
  }
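  // For illustration (added example): with 'block_size' = 2 and
  // 'iteration_bound' = 8, element_pair_index 0, 1, 2, 3 maps to
  // current_keys_index 0, 1, 4, 5, i.e. only 'left' elements, skipping the
  // 'right' blocks {2, 3} and {6, 7} that they are compared against.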
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    std::vector<llvm::Type*> values_to_compare_types;
    for (int i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare_types.push_back(
          element_address_pointee_type(i, compare_keys_index));

      values_to_compare.push_back(element_address(i, current_keys_index));
      values_to_compare_types.push_back(
          element_address_pointee_type(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Type* pred_type = llvm_ir::PrimitiveTypeToIrType(PRED, module);
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        pred_type, "compare_return_buffer", b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(pred_type, compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64_t i = 0; i < num_values; ++i) {
        // Swap the values.
        auto value1 = b->CreateLoad(values_to_compare_types[i * 2],
                                    values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare_types[i * 2 + 1],
                                    values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return OkStatus();
  });
}
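
// For reference, a host-side C++ sketch of the net effect of running the
// emitted body over all element pair indices, for one operand with an
// ascending comparator (added for illustration; not part of the emitted IR):
//
//   for (int64_t pair = 0; pair < iteration_bound / 2; ++pair) {
//     int64_t i = /* 'pair' remapped to a 'left' element as above */;
//     int64_t j = i ^ xor_mask;
//     if (i < j && j < iteration_bound && keys[j] < keys[i]) {
//       std::swap(keys[i], keys[j]);
//     }
//   }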

Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64_t dimension_to_sort,
    int64_t dimension_to_sort_bound, absl::Span<const int64_t> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::GlobalVariable*>& param_shmem_buffers,
    int64_t tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
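  // Each thread handles one element pair per comparison pass, so a tile of
  // 'tile_size' elements is processed by tile_size / 2 threads; the range
  // metadata below makes this bound on threadIdx.x known to LLVM.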
  llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic(
      gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b);
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");

  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within
              // bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };
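
  // In other words, each invocation copies the two adjacent elements at
  // sort-dimension positions 2 * i and 2 * i + 1 (where i is the value of
  // tiled_keys_index in the sort dimension) to or from shared memory slots
  // 2 * t and 2 * t + 1, where t is the thread id, subject to the bounds
  // checks above.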

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64_t i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(
          value,
          b->CreateGEP(
              param_shmem_buffers[i]->getValueType(), param_shmem_buffers[i],
              {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
    });
  }
  // Wait until all reads have happened.
  gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64_t operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand]->getValueType(),
                     param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(shared_memory_address,
                                  llvm::PointerType::getWithSamePointeeType(
                                      llvm::cast<llvm::PointerType>(ptr_type),
                                      /*AddressSpace=*/0));
  };
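  // For example (illustrative types): for an f32 operand cached in a
  // [64 x float] addrspace(3) tile, the GEP above yields a
  // float addrspace(3)* and the cast produces a generic float*.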
  auto element_address_pointee_type = [&](int64_t operand,
                                          llvm::Value* index) {
    return llvm::GetElementPtrInst::getIndexedType(
        param_shmem_buffers[operand]->getValueType(),
        {tiled_keys_index.GetConstantWithIndexType(0), index});
  };
  auto write_element = [&](int64_t operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand]->getValueType(),
                     param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
  for (int64_t xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the
    // smaller index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we
    // don't need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
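      // For example (added for illustration): with dimension_to_sort_bound =
      // 100 and tile_size = 64, full tiles cover [0, 64); threads whose first
      // element index (2 * tiled_keys_index) is >= RoundDownTo(100, 64) == 64
      // belong to the last, partial tile of size 100 % 64 == 36.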
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownTo(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, element_address_pointee_type, write_element,
                emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address,
                element_address_pointee_type, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address,
          element_address_pointee_type, write_element, emit_compare_callback, b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {},
                                   b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64_t i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto gep = b->CreateGEP(
          param_shmem_buffers[i]->getValueType(), param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index});
      auto gep_type = llvm::GetElementPtrInst::getIndexedType(
          param_shmem_buffers[i]->getValueType(),
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index});
      auto value = b->CreateLoad(gep_type, gep);
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However, the very next thing each thread does is read 2 elements from the
  // operand buffer and write them into the same locations in shared memory
  // from which it previously copied them to the operand buffer, and we
  // synchronize after this has happened. We can be sure that a thread always
  // writes to the same locations in shared memory because we have exactly
  // tile_size / 2 many threads, and the linear index calculated by
  // ParallelLoopEmitter is
  // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  return OkStatus();
}
}  // namespace

Status EmitSortInPlace(
    int64_t dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64_t> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64_t num_iterations_in_sort_dim, const int64_t tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop, which is the loop where the
  // comparisons happen. In the dimension to sort, if we use tiling, we iterate
  // through it in tiles of 64 elements each, so we use another loop that
  // happens within one thread to process this tile worth of data (thereby
  // combining several comparison stages of the bitonic sort algorithm,
  // because they all happen within those 64 elements and are therefore
  // independent of the other comparisons).
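  //
  // An illustrative note on 'xor_masks' (added; the exact sequence is chosen
  // by the caller): following the scheme described in EmitCompareLoopBody,
  // stage s of a bitonic sort first uses the mask 2^(s+1) - 1 and then the
  // masks 2^(s-1), ..., 2, 1, so sorting a dimension of bound 8 uses the mask
  // sequence 1, 3, 1, 7, 2, 1. With tiling, consecutive masks whose
  // comparisons stay within one tile can be passed in a single call so that
  // they are all processed from shared memory.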

  const Shape& keys_shape = values_arrays[0].GetShape();
  int64_t rank = keys_shape.rank();
  int64_t dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64_t> dimensions_in_iteration_order(rank);
  std::vector<int64_t> iteration_order_to_logical_order(rank);
  int64_t dim = 0;
  for (int64_t dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;
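  // Worked example (added for illustration): for keys of shape f32[16,100]
  // with the default {1,0} layout and dimension_to_sort = 1, the loop above
  // visits dimension 1 first (minor) and skips it, so
  // dimensions_in_iteration_order = {16, num_iterations_in_sort_dim} and
  // iteration_order_to_logical_order = {0, 1}: the sort dimension ends up
  // last, i.e. innermost.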

  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);

  // Allocate shared memory for the tiled compare loop.
  std::vector<llvm::GlobalVariable*> param_shmem_buffers(values_arrays.size(),
                                                         nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64_t i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }
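  // For instance (illustrative): with tile_size = 64 and an f32 operand, this
  // creates a [64 x float] module-level buffer in shared memory for each sort
  // operand; tiling (and hence this allocation) is only used when more than
  // one xor_mask is processed per kernel invocation.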

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64_t i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64_t j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64_t min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64_t i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
      auto element_address = [&](int64_t operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto element_address_pointee_type = [&](int64_t operand, llvm::Value*) {
        return values_arrays[operand].GetElementLlvmType();
      };
      auto write_element = [&](int64_t operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address,
          element_address_pointee_type, write_element, emit_compare_callback,
          b));
    }
    return OkStatus();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla