/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
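//
// For example (illustrative shapes, not taken from the original comment):
// for a shape f32[8,4,16] with the default {2,1,0} layout and
// dims_middle = {1}, this should return {8, 4, 16}, i.e.
// (major = 8, middle = 4, minor = 16).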
std::array<int64_t, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64_t> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64_t, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

Shape GetShapeFromTensorType(mlir::Value value) {
  constexpr char kDefaultLayoutAttrName[] = "xla_shape";

  mlir::Operation* op = value.getDefiningOp();
  CHECK(op);
  CHECK(value.getType().isa<mlir::TensorType>());
  Shape shape;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>(kDefaultLayoutAttrName)) {
    shape = *xla::ParseShape(
        absl::string_view(attr.getValue().data(), attr.getValue().size()));
  } else {
    shape = TypeToShape(value.getType());
  }
  return shape;
}

}  // namespace

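// Whether `dot` is a plain (possibly batched) matrix multiplication that the
// GPU backend can treat as a GEMM. Illustrative HLO that should satisfy the
// checks below (shapes chosen here for exposition, not taken from a real
// module):
//
//   %dot = f32[16,32,64]{2,1,0} dot(f32[16,32,128]{2,1,0} %lhs,
//                                   f32[16,128,64]{2,1,0} %rhs),
//       lhs_batch_dims={0}, rhs_batch_dims={0},
//       lhs_contracting_dims={2}, rhs_contracting_dims={1}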
bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  PrimitiveType output_primitive_type = dot.shape().element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == BF16 ||
       output_primitive_type == F32 || output_primitive_type == F64 ||
       output_primitive_type == C64 || output_primitive_type == C128) ||
      (output_primitive_type == S32 && lhs_shape.element_type() == S8 &&
       rhs_shape.element_type() == S8);
  bool shapes_are_valid =
      type_is_allowed &&
      IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) &&
      !ShapeUtil::IsZeroElementArray(lhs_shape) &&
      !ShapeUtil::IsZeroElementArray(rhs_shape);

  if (!shapes_are_valid) {
    return false;
  }

  // The size of the reduction dimension should match. The shape inference
  // guarantees this invariant, so the check here is for programming
  // errors.
  CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
           rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));

  return true;
}

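// Tile sizes along the three components of the normalized reduction shape
// described by ReductionDimensions; ReductionIsRaceFree() below uses these to
// bound which reductions can be emitted without atomics.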
std::array<int64_t, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    se::CudaComputeCapability cuda_compute_capability) {
  if (reduction_dimensions.is_row_reduction) {
    int64_t tile_z = std::min(reduction_dimensions.dimensions[0],
                              BatchedReductionRaceFreeBound());
    return {tile_z, 1, 16};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

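// Classifies the reduction as a row or column reduction and computes the
// sizes of its contiguous components. Illustrative examples (shapes assumed
// for exposition): reducing f32[100,200]{1,0} over dimension {1} should yield
// a row reduction with dimensions {1, 100, 200}, while reducing it over
// dimension {0} should yield a column reduction with dimensions {1, 100, 200}.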
static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64_t> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64_t dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64_t, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64_t, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

static bool IsUnnestedReductionFasterThanElemental(
    const ReductionDimensions& reduction_dimensions) {
  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    // For very small reductions with a power-of-two size, we can fit multiple
    // reductions inside a single warp, which is more efficient than a loop.
    return (reduction_dimensions.dimensions[2] >= WarpSize()) ||
           ((WarpSize() % reduction_dimensions.dimensions[2]) == 0);
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  int64_t major_size = reduction_dimensions.dimensions[1];
  int64_t minor_size = reduction_dimensions.dimensions[2];

  // Rule generated by sweeping the search space of small column reductions.
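  // For example (illustrative numbers): with WarpSize() == 32, a column
  // reduction with major_size == 48 and minor_size == 16 prefers the
  // elemental emitter, since 48 < 2 * 32 and 16 < 32.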
  bool prefer_elemental_emitter =
      (major_size < WarpSize()) ||
      (major_size < 2 * WarpSize() && minor_size < WarpSize()) ||
      (major_size < 4 * WarpSize() && minor_size < 8) ||
      (major_size < 8 * WarpSize() && minor_size < 3);

  return !prefer_elemental_emitter;
}

// Whether we can/should use the unnested emitter for reduction.
static bool IsReductionFromOrToContiguousDimensionsImpl(
    const Shape& operand_shape, absl::Span<int64_t const> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  return (LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                               dims_to_keep) ||
          LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                               dims_to_reduce)) &&
         IsUnnestedReductionFasterThanElemental(
             GetReductionKindAndContiguousComponentsImpl(operand_shape,
                                                         dims_to_reduce));
}

bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  return reduce.opcode() == HloOpcode::kReduce &&
         IsReductionFromOrToContiguousDimensionsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

bool IsReductionFromOrToContiguousDimensions(mlir::Operation* op) {
  auto reduce = mlir::dyn_cast<mlir::mhlo::ReduceOp>(op);
  if (!reduce) {
    return false;
  }

  mlir::Value first_operand = reduce.operands()[0];
  Shape operand_shape = GetShape(first_operand);

  llvm::SmallVector<int64_t> dimensions_to_reduce;
  for (const llvm::APInt& d : reduce.dimensions()) {
    dimensions_to_reduce.push_back(d.getZExtValue());
  }

  return IsReductionFromOrToContiguousDimensionsImpl(operand_shape,
                                                     dimensions_to_reduce);
}

bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion) {
    return false;
  }

  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
    return absl::c_all_of(
        strides, [](const llvm::APInt& stride) { return stride == 1; });
  };

  for (mlir::Value value : fusion.getFusionResults()) {
    auto slice =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
    if (!slice) {
      return false;
    }
    if (verify_no_strides && !is_non_strided(slice.strides())) {
      return false;
    }
  }
  return true;
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  return GetReductionKindAndContiguousComponentsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = GetShape(input);
  llvm::SmallVector<int64_t> dimensions_to_reduce;
  for (const llvm::APInt& d :
       mlir::cast<mlir::mhlo::ReduceOp>(reduce).dimensions()) {
    dimensions_to_reduce.push_back(d.getZExtValue());
  }
  return GetReductionKindAndContiguousComponentsImpl(operand_shape,
                                                     dimensions_to_reduce);
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
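//
// Illustrative use (variable names assumed for exposition):
//   EmitPrintf("block=%d value=%f\n", {block_id, elem}, &b);
// Integer arguments narrower than 32 bits are promoted to i32 and floats to
// double before being packed into the argument buffer, matching vprintf's
// variadic promotion rules.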
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Variadic arguments implicit promotion [1] converts float to double,
  // and bool/char/short are converted to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value,
        builder->CreateGEP(arguments_type, arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit call to AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // AMDGPU device function requires first argument as i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // AMDGPU device function always returns an i32 type.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit call to NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 1)});
}

llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
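  // For example (illustrative): a 64-bit double is bitcast to <2 x i32>, each
  // 32-bit lane is shuffled independently, and the pieces are reassembled into
  // the original type below.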
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* is_thread0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));

  llvm::Value* is_block0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
  return b->CreateAnd(is_thread0, is_block0);
}

bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
                                      const HloInstruction* first_reduce) {
  if (IsReductionFromOrToContiguousDimensions(*inst)) {
    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                            inst->operand(0)->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                            inst->operand(1)->shape()) &&
           first_reduce->dimensions() == inst->dimensions();
  }
  return ShapeUtil::CompatibleIgnoringElementType(
             first_reduce->operand(0)->shape(), inst->shape()) &&
         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                           inst->shape().layout());
}

// Given an LMHLO op, returns the operand index of the first output operand.
//
// Note that an operand aliased to an output isn't an output, even though in
// that case WritesMlirBuffer() returns true on that operand.
//
// An operand is anything that is !WritesMlirBuffer() or that aliases a later
// operand. An output is the opposite: it is WritesMlirBuffer() and does not
// alias any later operand.
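//
// For example (illustrative operand order): for an lmhlo op whose operands
// are (%in0, %in1, %out), where only %out is written and no operand aliases a
// later one, this returns 2, the index of %out.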
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  int i;
  for (i = op->getOperands().size() - 1; i >= 0; i--) {
    const bool aliased =
        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
                  op->getOperand(i)) != op->getOperands().end();
    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
      break;
    }
  }
  return i + 1;
}

std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getInputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> operands;
    operands.reserve(output_start);
    for (int i = 0; i < output_start; i++) {
      operands.push_back(op->getOperand(i));
    }
    return operands;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getOperands().begin(),
                                    op->getOperands().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getOutputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> outputs;
    for (int i = output_start; i < op->getNumOperands(); i++) {
      outputs.push_back(op->getOperand(i));
    }
    return outputs;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getResults().begin(),
                                    op->getResults().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
                                                                  effects);
  return absl::c_any_of(
      effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
        return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
      });
}

static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits.
  if (type.getElementType().isInteger(/*width=*/1)) {
    return type.getNumElements();
  } else {
    return type.cast<mlir::ShapedType>().getSizeInBits() / CHAR_BIT;
  }
}

static int64_t GetAllocationIndex(mlir::BlockArgument func_arg,
                                  std::string* constant_name) {
  auto func_op =
      mlir::cast<mlir::func::FuncOp>(func_arg.getParentRegion()->getParentOp());
  if (constant_name) {
    if (auto constant_name_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
            func_arg.getArgNumber(), "lmhlo.constant_name")) {
      *constant_name = constant_name_attr.getValue().str();
    }
  }
  return func_arg.getArgNumber();
}

StatusOr<BufferAllocation::Slice> GetAllocationSlice(
    mlir::Value v, absl::Span<const BufferAllocation> allocations,
    std::string* constant_name) {
  if (constant_name) {
    constant_name->clear();
  }

  int64_t size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  // We match the following patterns here:
  //  base := ViewOp(arg) | get_global_memref (global_memref) | arg
  //  root := base | MemRefReinterpretCastOp(base)
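  //
  // Illustrative MLIR (names and sizes assumed for exposition):
  //   %view = memref.view %arg0[%c128][] : memref<1024xi8> to memref<4x8xf32>
  // would map to a Slice over %arg0's allocation with byte offset 128 and
  // size 4 * 8 * sizeof(float) = 128 bytes.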

  if (auto cast = mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
          v.getDefiningOp())) {
    v = cast.getViewSource();
  }
  if (auto view =
          mlir::dyn_cast_or_null<mlir::memref::ViewOp>(v.getDefiningOp())) {
    TF_RET_CHECK(view.getSource().isa<mlir::BlockArgument>());

    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(
            view.getSource().cast<mlir::BlockArgument>(), constant_name)],
        mlir::cast<mlir::arith::ConstantOp>(view.getByteShift().getDefiningOp())
            .getValue()
            .cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue(),
        size);
  }
  if (auto get_global = mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
          v.getDefiningOp())) {
    auto module = get_global->getParentOfType<mlir::ModuleOp>();
    if (constant_name) {
      *constant_name = get_global.getName().str();
    }
    auto global = mlir::cast<mlir::memref::GlobalOp>(
        module.lookupSymbol(get_global.getName()));
    int64_t index =
        global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
    return BufferAllocation::Slice(&allocations[index], 0,
                                   allocations[index].size());
  }
  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(arg, constant_name)], 0, size);
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg)) or arg");
}

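// Whether a fusion whose single result is a dynamic-update-slice can be
// emitted in place, i.e. the buffer slice of the tensor being updated is the
// same slice as the fusion output buffer.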
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter = mlir::dyn_cast<mlir::bufferization::ToTensorOp>(
      dus.operand().getDefiningOp());

  if (!parameter) {
    return false;
  }

  auto maybe_lhs = GetAllocationSlice(parameter.getMemref(), allocations);
  auto maybe_rhs = GetAllocationSlice(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}

Shape GetShape(mlir::Value value) {
  if (value.getType().isa<mlir::MemRefType>()) {
    return TypeToShape(value.getType());
  } else if (value.getType().isa<mlir::TensorType>()) {
    return GetShapeFromTensorType(value);
  } else if (value.getType().isa<mlir::TupleType>()) {
    return TypeToShape(value.getType());
  }
  LOG(FATAL) << "Unexpected value type to get shape for";
  return {};
}

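// Whether the reduction can be emitted without atomics: a row reduction
// qualifies if its width fits within MinThreadsXRowReduction() threads times
// the x-tile and its depth stays within BatchedReductionRaceFreeBound(); a
// column reduction qualifies if its height fits within one warp times the
// y-tile.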
bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions,
                         const std::array<int64_t, 3>& reduction_tiling) {
  return (reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[2] <=
              MinThreadsXRowReduction() * reduction_tiling[2] &&
          reduction_dimensions.dimensions[0] <=
              BatchedReductionRaceFreeBound()) ||
         (!reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[1] <=
              WarpSize() * reduction_tiling[1]);
}

// A recursive function to inspect the users of a parameter to determine
// whether it's safe for a parameter to participate in a shared-memory
// transpose.
//
// Consider a fusion parameter P for which we might want to use a shmem
// transpose.  If we do, we use a GPU thread block to preload a tile of P with
// indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
// cooperatively, where z, y, x are the indices for the normalized input/output
// tensor (see the document for FindTranspose021 for the definition of
// normalized tensor for 0-2-1 transpose). This shmem transpose implementation
// requires that the computation of the output tile only read elements within
// the preload tile. If this is not true, we can't use a shmem transpose for P.
//
// If the computation of output element [z, y, x] only requires the element of
// P with the same indices, the shmem transpose implementation can be applied
// to P safely. This is a sufficient but not necessary condition. We check all
// the transitive users of P to see if we can find a user that may cause an
// exception to the situation. If such a user is not found, we conclude that P
// is safe for shmem transpose.
//
// This is trivially true for elementwise operations and some "data-movement"
// ops like kTuple. However, it's not true for operations that can change the
// dimensions of the inputs (e.g. pad, slice) or for the bitcast operation.
// For example:
//
// fused_computation {
//   param_0 = f32[64,64]{1,0} parameter(0)
//   ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
// }
// The output element at logical address [0, 63] depends on the input element
// at logical address [63, 0], which would not be within the shared-memory
// block.
//
// TODO(bixia): In order to extend this for kInput fusion, i.e. reduction
// with transpose, we only need to end the use-chain checking with the input of
// a reduce operation. In this case, the above description of "output" applies
// to the result of such a use-chain, which provides the input to the reduce
// operation.
static bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
  if (hlo->IsElementwise()) {
    return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
      return IsInstructionSafeForShmemTranspose(user);
    });
  }

  // Needs to be kept in sync with `ShmemTransposeSupportedForInputs` below.
  switch (hlo->opcode()) {
    // Non-elementwise instructions that don't cause the shmem transpose
    // to be unsafe, including the instructions that don't currently fuse.
    case HloOpcode::kGetDimensionSize:
      // The result of the operation doesn't rely on the content of the
      // tensor. As such, there is no need to further inspect its users.
      return true;
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMap:
    case HloOpcode::kParameter:
    case HloOpcode::kTuple:
      return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
        return IsInstructionSafeForShmemTranspose(user);
      });

    default:
      return false;
  }
}

// Given a group of input parameters that are a 0-2-1 transpose of the outputs
// of a fusion kernel, returns the input parameters that are safe for the
// shared memory transpose implementation.
//
// When a tile-based shared memory transpose is used to implement an input
// with 0-2-1 transpose, we preload a tile of the input elements [z, y..y+31,
// x..x+31] to compute the output tile elements of the same indices.
// Preloading the input tile this way is only safe when the computation of the
// output tile elements does not need any input element outside the preloaded
// tile. We inspect all the transitive users of the input parameter up to the
// fusion root instruction to see if we can find any instruction that can make
// preloading the input tile unsafe.
static std::vector<int64_t> FilterInputsForShmemTranspose(
    const HloComputation* fused_computation, std::vector<int64_t> input_ids) {
  std::vector<int64_t> filtered_input_ids;
  for (int64_t input_id : input_ids) {
    const HloInstruction* instr =
        fused_computation->parameter_instruction(input_id);
    if (IsInstructionSafeForShmemTranspose(instr)) {
      filtered_input_ids.push_back(input_id);
    } else {
      VLOG(10) << "Input not safe for shmem transpose " << instr->ToString();
    }
  }
  return filtered_input_ids;
}

// Whether code emission supports shared memory transposes for one or more
// `input_ids`.
bool ShmemTransposeSupportedForInputs(const HloInstruction& instr,
                                      std::vector<int64_t> input_ids) {
  if (instr.opcode() == HloOpcode::kFusion) {
    return !FilterInputsForShmemTranspose(
                instr.fused_instructions_computation(), input_ids)
                .empty();
  }
  // Needs to be kept in sync with `IsInstructionSafeForShmemTranspose` above.
  return instr.IsElementwise() ||
         instr.opcode() == HloOpcode::kGetDimensionSize;
}

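// Returns the reduced 0-2-1 transpose dimensions shared by all operands that
// are a 0-2-1 transpose of `output_shape`, together with the indices of those
// operands. Returns nullopt if no operand is such a transpose or if two
// operands imply different transpositions.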
static std::optional<TransposeDimsAndParams> FindTranspose021DimsAndParameters(
    const std::vector<Shape>& operand_shapes, const Shape& output_shape) {
  std::vector<int64_t> params_012;
  std::optional<Vector3> reduced_dims_021;
  for (int64_t operand_idx = 0; operand_idx < operand_shapes.size();
       ++operand_idx) {
    std::optional<Vector3> find_transpose_result =
        ShapeUtil::FindTranspose021(operand_shapes[operand_idx], output_shape);
    if (!find_transpose_result.has_value()) {
      continue;
    } else if (!reduced_dims_021.has_value()) {
      reduced_dims_021 = *find_transpose_result;
    } else if (!absl::c_equal(*reduced_dims_021, *find_transpose_result)) {
      // There is more than one possible transpose. Instead of picking one
      // transpose, we simply give up here.
      VLOG(3) << "021 transpose not matched; More than one possible "
                 "transposition of parameters: "
              << VectorString(*reduced_dims_021) << " and "
              << VectorString(*find_transpose_result);
      return std::nullopt;
    }
    params_012.push_back(operand_idx);
  }
  if (!reduced_dims_021.has_value()) {
    return std::nullopt;
  }
  return TransposeDimsAndParams{*reduced_dims_021, params_012};
}

std::optional<TransposeDimsAndParams> Match021Transpose(
    const HloComputation* fused_computation) {
  // If a dimension is smaller than this, untiled transposition may be more
  // efficient.
  static const int64_t kMinDimensionToTransposeTiled = 16;

  const HloInstruction* root = fused_computation->root_instruction();
  if (root->shape().IsTuple()) {
    // TODO(cheshire): Why are we not pattern-matching other outputs?
    root = root->operand(0);
  }
  const Shape& output_shape = root->shape();

  std::vector<Shape> param_shapes;
  absl::c_for_each(fused_computation->parameter_instructions(),
                   [&](const HloInstruction* param) {
                     param_shapes.push_back(param->shape());
                   });
  std::optional<TransposeDimsAndParams> reduced_dims_and_params_021 =
      FindTranspose021DimsAndParameters(param_shapes, output_shape);

  if (!reduced_dims_and_params_021.has_value()) {
    VLOG(3) << "021 transposition not found on instruction "
            << root->ToString();
    return std::nullopt;
  }
  std::vector<int64_t> params_012 = reduced_dims_and_params_021->params;
  const Vector3& reduced_dims_021 = reduced_dims_and_params_021->dims;

  if (reduced_dims_021.at(1) < kMinDimensionToTransposeTiled ||
      reduced_dims_021.at(2) < kMinDimensionToTransposeTiled) {
    VLOG(3) << "021 transpose not matched: dimensions of transposition "
            << root->ToString() << " are too small";
    return std::nullopt;
  }

  params_012 = FilterInputsForShmemTranspose(fused_computation, params_012);
  if (params_012.empty()) {
    VLOG(3) << "021 transpose on " << root->ToString()
            << " not matched: no inputs matched after filtering "
               "for shmem access";
    return std::nullopt;
  }

  // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
  // elements are of size 4 bytes), and CUDA has an architectural limit of
  // 48kb shared memory per SM.  (This is increased to 96kb in Volta, but we
  // don't use this, in part because it eats into our L1 cache space.)
  //
  // For correctness we need to ensure that we don't make more than 48kb worth
  // of shmem tiles per block.  And for performance, we'd probably like to use
  // significantly less, so that we can fit more than one block at a time on a
  // gpu core.
  //
  // We say without benchmarks that we want at least 3 blocks/core,
  // corresponding to 3 shmem tiles if the elements are 32 bits wide.  We
  // choose which params get the shmem transpose treatment arbitrarily; it's
  // not clear if there's a Right Choice.
  //
  // This is only sound if tiled transposes are the only place where we use
  // shared memory in fusions.  If in the future other fusible ops use shared
  // memory, we'll have to adjust this heuristic.
  constexpr int kMinBlocksPerCore = 3;
  constexpr int64_t kShmemPerCore = 48 * 1024;
  int64_t shmem_used = 0;
  for (int64_t i = 0; i < params_012.size(); ++i) {
    const Shape& operand_shape =
        fused_computation->parameter_instruction(params_012[i])->shape();
    shmem_used +=
        32 * 33 *
        ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());

    if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
      // Erase this element and everything after it from params_012.
      params_012.resize(i);
      break;
    }
  }


  if (params_012.empty()) {
    VLOG(3)
        << "021 transpose on " << root->ToString()
        << " not matched: no inputs matched after filtering for shmem budget";
    return std::nullopt;
  }

  return TransposeDimsAndParams{reduced_dims_021, params_012};
}

}  // namespace gpu
}  // namespace xla