/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"

namespace xla {
namespace gpu {

namespace {

// Returns whether the given shape is rank 2, excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// the dimensions more minor than the given dimensions, and middle is the size
// of the given dimensions.
std::array<int64_t, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64_t> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64_t, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}
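
// Illustrative example (not part of the original source): for a shape
// f32[2,3,4,5] with layout {3,2,1,0} and dims_middle = {1, 2}, the
// minor-to-major walk visits dimensions 3, 2, 1, 0, so the returned partition
// is {major, middle, minor} = {2, 3 * 4, 5} = {2, 12, 5}.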

Shape GetShapeFromTensorType(mlir::Value value) {
  constexpr char kDefaultLayoutAttrName[] = "xla_shape";

  mlir::Operation* op = value.getDefiningOp();
  CHECK(op);
  CHECK(value.getType().isa<mlir::TensorType>());
  Shape shape;
  if (auto attr = op->getAttrOfType<mlir::StringAttr>(kDefaultLayoutAttrName)) {
    shape = *xla::ParseShape(
        absl::string_view(attr.getValue().data(), attr.getValue().size()));
  } else {
    shape = TypeToShape(value.getType());
  }
  return shape;
}

}  // namespace

bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  PrimitiveType output_primitive_type = dot.shape().element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == BF16 ||
       output_primitive_type == F32 || output_primitive_type == F64 ||
       output_primitive_type == C64 || output_primitive_type == C128) ||
      (output_primitive_type == S32 && lhs_shape.element_type() == S8 &&
       rhs_shape.element_type() == S8);
  bool shapes_are_valid =
      type_is_allowed &&
      IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) &&
      !ShapeUtil::IsZeroElementArray(lhs_shape) &&
      !ShapeUtil::IsZeroElementArray(rhs_shape);

  if (!shapes_are_valid) {
    return false;
  }

  // The sizes of the contracting (reduction) dimensions must match. Shape
  // inference guarantees this invariant, so the check here guards against
  // programming errors.
  CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
           rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));

  return true;
}
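
// Illustrative example (not part of the original source): an HLO such as
//   dot = f32[16,64]{1,0} dot(f32[16,32]{1,0} %a, f32[32,64]{1,0} %b),
//         lhs_contracting_dims={1}, rhs_contracting_dims={0}
// has no batch dimensions, rank-2 operands and output, non-empty shapes, and
// an allowed output type, so IsMatrixMultiplication returns true for it.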

std::array<int64_t, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    se::CudaComputeCapability cuda_compute_capability) {
  if (reduction_dimensions.is_row_reduction) {
    int64_t tile_z = std::min(reduction_dimensions.dimensions[0],
                              BatchedReductionRaceFreeBound());
    return {tile_z, 1, 16};
  }

  // Column reduction.
  return {1, 128, 1};
}
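
// Illustrative example (not part of the original source): for a row reduction
// with dimensions {4, 1, 65536}, the returned tiling is
// {std::min<int64_t>(4, BatchedReductionRaceFreeBound()), 1, 16}; any column
// reduction gets the fixed tiling {1, 128, 1}.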

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64_t> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64_t dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64_t, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64_t, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}
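
// Illustrative example (not part of the original source): reducing
// f32[100,200]{1,0} over dimension {1} keeps dimension {0}, which is
// consecutive in the layout; PartitionShapeByMiddleDimensions yields
// {1, 100, 200}, so the result is a row reduction with dimensions
// {1, 100, 200}, i.e. 100 independent rows of length 200.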

static bool IsUnnestedReductionFasterThanElemental(
    const ReductionDimensions& reduction_dimensions) {
  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    // For very small reductions with a power-of-two size, we can fit multiple
    // reductions inside a single warp, which is more efficient than a loop.
    return (reduction_dimensions.dimensions[2] >= WarpSize()) ||
           ((WarpSize() % reduction_dimensions.dimensions[2]) == 0);
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  int64_t major_size = reduction_dimensions.dimensions[1];
  int64_t minor_size = reduction_dimensions.dimensions[2];

  // Rule generated by sweeping the search space of small column reductions.
  bool prefer_elemental_emitter =
      (major_size < WarpSize()) ||
      (major_size < 2 * WarpSize() && minor_size < WarpSize()) ||
      (major_size < 4 * WarpSize() && minor_size < 8) ||
      (major_size < 8 * WarpSize() && minor_size < 3);

  return !prefer_elemental_emitter;
}
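
// Illustrative example (not part of the original source), assuming
// WarpSize() == 32: a row reduction with dimensions {1, 100, 8} is still
// handled by the unnested emitter because 32 % 8 == 0 (several reductions fit
// in one warp), while a column reduction with dimensions {1, 48, 16} prefers
// the elemental emitter since 48 < 2 * 32 and 16 < 32.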

// Whether we can/should use the unnested emitter for reduction.
static bool IsReductionFromOrToContiguousDimensionsImpl(
    const Shape& operand_shape, absl::Span<int64_t const> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  return (LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                               dims_to_keep) ||
          LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                               dims_to_reduce)) &&
         IsUnnestedReductionFasterThanElemental(
             GetReductionKindAndContiguousComponentsImpl(operand_shape,
                                                         dims_to_reduce));
}

bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  return reduce.opcode() == HloOpcode::kReduce &&
         IsReductionFromOrToContiguousDimensionsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

bool IsReductionFromOrToContiguousDimensions(mlir::Operation* op) {
  auto reduce = mlir::dyn_cast<mlir::mhlo::ReduceOp>(op);
  if (!reduce) {
    return false;
  }

  mlir::Value first_operand = reduce.operands()[0];
  Shape operand_shape = GetShape(first_operand);

  llvm::SmallVector<int64_t> dimensions_to_reduce;
  for (const llvm::APInt& d : reduce.dimensions()) {
    dimensions_to_reduce.push_back(d.getZExtValue());
  }

  return IsReductionFromOrToContiguousDimensionsImpl(operand_shape,
                                                     dimensions_to_reduce);
}

bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion) {
    return false;
  }

  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
    return absl::c_all_of(
        strides, [](const llvm::APInt& stride) { return stride == 1; });
  };

  for (mlir::Value value : fusion.getFusionResults()) {
    auto slice =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
    if (!slice) {
      return false;
    }
    if (verify_no_strides && !is_non_strided(slice.strides())) {
      return false;
    }
  }
  return true;
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  return GetReductionKindAndContiguousComponentsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = GetShape(input);
  llvm::SmallVector<int64_t> dimensions_to_reduce;
  for (const llvm::APInt& d :
       mlir::cast<mlir::mhlo::ReduceOp>(reduce).dimensions()) {
    dimensions_to_reduce.push_back(d.getZExtValue());
  }
  return GetReductionKindAndContiguousComponentsImpl(operand_shape,
                                                     dimensions_to_reduce);
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Default argument promotion for variadic functions [1] converts float to
  // double and bool/char/short to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value,
        builder->CreateGEP(arguments_type, arguments_ptr,
                           {builder->getInt64(0), builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}
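
// Illustrative usage (not part of the original source): to print a thread's
// f32 value while debugging a kernel, one could emit, at the current
// insertion point of an llvm::IRBuilder<> b,
//   EmitPrintf("value = %f\n", {value}, &b);
// where the float argument is promoted to double to match vprintf's variadic
// calling convention.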

// Helper function to emit a call to the AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // The AMDGPU device function requires its first argument to be i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // The AMDGPU device function always returns an i32.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit a call to the NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 1)});
}

llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}
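
// Illustrative example (not part of the original source): shuffling an f64
// value down the warp takes the wide path above: bit_width is 64, so
// num_segments is 2, the value is bitcast to a <2 x i32> vector, each 32-bit
// segment is shuffled independently, and the pieces are reassembled into an
// f64 by the final trunc/bitcast sequence.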

llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* is_thread0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));

  llvm::Value* is_block0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
  return b->CreateAnd(is_thread0, is_block0);
}

bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
                                      const HloInstruction* first_reduce) {
  if (IsReductionFromOrToContiguousDimensions(*inst)) {
    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                            inst->operand(0)->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                            inst->operand(1)->shape()) &&
           first_reduce->dimensions() == inst->dimensions();
  }
  return ShapeUtil::CompatibleIgnoringElementType(
             first_reduce->operand(0)->shape(), inst->shape()) &&
         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                           inst->shape().layout());
}

// Given an LMHLO op, returns the operand index of the first output operand.
//
// Notice that an operand aliased to an output isn't itself an output, even
// though WritesMlirBuffer() returns true on that operand in that case.
//
// An operand either is !WritesMlirBuffer() or aliases (equals) a later
// operand. An output is the opposite: it WritesMlirBuffer() and does not equal
// any later operand.
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  int i;
  for (i = op->getOperands().size() - 1; i >= 0; i--) {
    const bool aliased =
        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
                  op->getOperand(i)) != op->getOperands().end();
    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
      break;
    }
  }
  return i + 1;
}
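
// Illustrative example (not part of the original source): for a bufferized
// op along the lines of
//   "lmhlo.add"(%lhs, %rhs, %out)
//       : (memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
// only %out is written and it aliases no later operand, so the function
// returns 2: operands [0, 2) are inputs and operand 2 is the output.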

std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getInputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> operands;
    operands.reserve(output_start);
    for (int i = 0; i < output_start; i++) {
      operands.push_back(op->getOperand(i));
    }
    return operands;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getOperands().begin(),
                                    op->getOperands().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getOutputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> outputs;
    for (int i = output_start; i < op->getNumOperands(); i++) {
      outputs.push_back(op->getOperand(i));
    }
    return outputs;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getResults().begin(),
                                    op->getResults().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
                                                                  effects);
  return absl::c_any_of(
      effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
        return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
      });
}

static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits.
  if (type.getElementType().isInteger(/*width=*/1)) {
    return type.getNumElements();
  } else {
    return type.cast<mlir::ShapedType>().getSizeInBits() / CHAR_BIT;
  }
}
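
// Illustrative example (not part of the original source): memref<128xf32> is
// sized as 128 * 32 bits = 512 bytes, while memref<128xi1> is sized as 128
// bytes because each i1 element occupies a full byte in the underlying
// allocation.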

static int64_t GetAllocationIndex(mlir::BlockArgument func_arg,
                                  std::string* constant_name) {
  auto func_op =
      mlir::cast<mlir::func::FuncOp>(func_arg.getParentRegion()->getParentOp());
  if (constant_name) {
    if (auto constant_name_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
            func_arg.getArgNumber(), "lmhlo.constant_name")) {
      *constant_name = constant_name_attr.getValue().str();
    }
  }
  return func_arg.getArgNumber();
}

StatusOr<BufferAllocation::Slice> GetAllocationSlice(
    mlir::Value v, absl::Span<const BufferAllocation> allocations,
    std::string* constant_name) {
  if (constant_name) {
    constant_name->clear();
  }

  int64_t size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  // We match the following patterns here:
  //   base := ViewOp(arg) | get_global_memref (global_memref) | arg
  //   root := base | MemRefReinterpretCastOp(base)

  if (auto cast = mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
          v.getDefiningOp())) {
    v = cast.getViewSource();
  }
  if (auto view =
          mlir::dyn_cast_or_null<mlir::memref::ViewOp>(v.getDefiningOp())) {
    TF_RET_CHECK(view.getSource().isa<mlir::BlockArgument>());

    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(
            view.getSource().cast<mlir::BlockArgument>(), constant_name)],
        mlir::cast<mlir::arith::ConstantOp>(view.getByteShift().getDefiningOp())
            .getValue()
            .cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue(),
        size);
  }
  if (auto get_global = mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
          v.getDefiningOp())) {
    auto module = get_global->getParentOfType<mlir::ModuleOp>();
    if (constant_name) {
      *constant_name = get_global.getName().str();
    }
    auto global = mlir::cast<mlir::memref::GlobalOp>(
        module.lookupSymbol(get_global.getName()));
    int64_t index =
        global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
    return BufferAllocation::Slice(&allocations[index], 0,
                                   allocations[index].size());
  }
  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(arg, constant_name)], 0, size);
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg)) or arg");
}

bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter = mlir::dyn_cast<mlir::bufferization::ToTensorOp>(
      dus.operand().getDefiningOp());

  if (!parameter) {
    return false;
  }

  auto maybe_lhs = GetAllocationSlice(parameter.getMemref(), allocations);
  auto maybe_rhs = GetAllocationSlice(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}

Shape GetShape(mlir::Value value) {
  if (value.getType().isa<mlir::MemRefType>()) {
    return TypeToShape(value.getType());
  } else if (value.getType().isa<mlir::TensorType>()) {
    return GetShapeFromTensorType(value);
  } else if (value.getType().isa<mlir::TupleType>()) {
    return TypeToShape(value.getType());
  }
  LOG(FATAL) << "Unexpected value type to get shape for";
  return {};
}

bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions,
                         const std::array<int64_t, 3>& reduction_tiling) {
  return (reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[2] <=
              MinThreadsXRowReduction() * reduction_tiling[2] &&
          reduction_dimensions.dimensions[0] <=
              BatchedReductionRaceFreeBound()) ||
         (!reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[1] <=
              WarpSize() * reduction_tiling[1]);
}
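
// Illustrative example (not part of the original source), assuming
// WarpSize() == 32: with the column reduction tiling {1, 128, 1} from
// GetReductionTiling, a column reduction is considered race free as long as
// its reduced dimension (dimensions[1]) does not exceed 32 * 128 = 4096
// elements; larger column reductions are not considered race free by this
// check.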

// A recursive function to inspect the users of a parameter to determine
// whether it's safe for a parameter to participate in a shared-memory
// transpose.
//
// Consider a fusion parameter P for which we might want to use a shmem
// transpose. If we do, we use a GPU thread block to preload a tile of P with
// indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
// cooperatively, where z, y, x are the indices for the normalized input/output
// tensor (see the document for FindTranspose021 for the definition of
// normalized tensor for 0-2-1 transpose). This shmem transpose implementation
// requires that the computation of the output tile only read elements within
// the preload tile. If this is not true, we can't use a shmem transpose for P.
//
// If the computation of output element [z, y, x] only requires the element of
// P with the same indices, the shmem transpose implementation can be applied
// to P safely. This is a sufficient but not necessary condition. We check all
// the transitive users of P to see if we can find a user that may cause an
// exception to the situation. If such a user is not found, we conclude that P
// is safe for shmem transpose.
//
// This is trivially true for elementwise operations and some "data-movement"
// ops like kTuple. However, it's not true for operations that can change the
// dimensions of the inputs (e.g. pad, slice) or for the bitcast operation.
// For example:
//
// fused_computation {
//   param_0 = f32[64,64]{1,0} parameter(0)
//   ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
// }
// The output element at logical address [0, 63] depends on the input element
// at logical address [63, 0], which would not be within the shared-memory
// block.
//
// TODO(bixia): In order to extend this to kInput fusion, that is, reduction
// with transpose, we only need to end the use-chain checking at the input of
// a reduce operation. In that case, the above description of "output" applies
// to the result of such a use-chain, which provides the input to the reduce
// operation.
static bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
  if (hlo->IsElementwise()) {
    return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
      return IsInstructionSafeForShmemTranspose(user);
    });
  }

  // Needs to be kept in sync with `ShmemTransposeSupportedForInputs` below.
  switch (hlo->opcode()) {
    // Non-elementwise instructions that don't cause the shmem transpose
    // to be unsafe, including the instructions that don't currently fuse.
    case HloOpcode::kGetDimensionSize:
      // The result of the operation doesn't rely on the content of the
      // tensor. As such, there is no need to further inspect its users.
      return true;
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMap:
    case HloOpcode::kParameter:
    case HloOpcode::kTuple:
      return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
        return IsInstructionSafeForShmemTranspose(user);
      });

    default:
      return false;
  }
}

// Given a group of input parameters that are 0-2-1 transposes of the outputs
// of a fusion kernel, returns the input parameters that are safe for the
// shared memory transpose implementation.
//
// When a tile based shared memory transpose is used to implement an input
// with 0-2-1 transpose, we preload a tile of the input elements [z, y..y+31,
// x..x+31] to compute the output tile elements of the same indices.
// Preloading the input tile this way is only safe when the computation of the
// output tile elements does not need any input element outside the preloaded
// tile. We inspect all the transitive users of the input parameter up to the
// fusion root instruction to see if we can find any instruction that can make
// preloading the input tile unsafe.
static std::vector<int64_t> FilterInputsForShmemTranspose(
    const HloComputation* fused_computation, std::vector<int64_t> input_ids) {
  std::vector<int64_t> filtered_input_ids;
  for (int64_t input_id : input_ids) {
    const HloInstruction* instr =
        fused_computation->parameter_instruction(input_id);
    if (IsInstructionSafeForShmemTranspose(instr)) {
      filtered_input_ids.push_back(input_id);
    } else {
      VLOG(10) << "Input not safe for shmem transpose " << instr->ToString();
    }
  }
  return filtered_input_ids;
}

// Whether code emission supports shared memory transposes for one or more
// `input_ids`.
bool ShmemTransposeSupportedForInputs(const HloInstruction& instr,
                                      std::vector<int64_t> input_ids) {
  if (instr.opcode() == HloOpcode::kFusion) {
    return !FilterInputsForShmemTranspose(
                instr.fused_instructions_computation(), input_ids)
                .empty();
  }
  // Needs to be kept in sync with `IsInstructionSafeForShmemTranspose` above.
  return instr.IsElementwise() ||
         instr.opcode() == HloOpcode::kGetDimensionSize;
}

static std::optional<TransposeDimsAndParams> FindTranspose021DimsAndParameters(
    const std::vector<Shape>& operand_shapes, const Shape& output_shape) {
  std::vector<int64_t> params_012;
  std::optional<Vector3> reduced_dims_021;
  for (int64_t operand_idx = 0; operand_idx < operand_shapes.size();
       ++operand_idx) {
    std::optional<Vector3> find_transpose_result =
        ShapeUtil::FindTranspose021(operand_shapes[operand_idx], output_shape);
    if (!find_transpose_result.has_value()) {
      continue;
    } else if (!reduced_dims_021.has_value()) {
      reduced_dims_021 = *find_transpose_result;
    } else if (!absl::c_equal(*reduced_dims_021, *find_transpose_result)) {
      // There is more than one possible transpose. Instead of picking one
      // transpose, we simply give up here.
      VLOG(3) << "021 transpose not matched; More than one possible "
                 "transposition of parameters: "
              << VectorString(*reduced_dims_021) << " and "
              << VectorString(*find_transpose_result);
      return std::nullopt;
    }
    params_012.push_back(operand_idx);
  }
  if (!reduced_dims_021.has_value()) {
    return std::nullopt;
  }
  return TransposeDimsAndParams{*reduced_dims_021, params_012};
}

std::optional<TransposeDimsAndParams> Match021Transpose(
    const HloComputation* fused_computation) {
  // If a dimension is smaller than this, untiled transposition may be more
  // efficient.
  static const int64_t kMinDimensionToTransposeTiled = 16;

  const HloInstruction* root = fused_computation->root_instruction();
  if (root->shape().IsTuple()) {
    // TODO(cheshire): Why are we not pattern-matching other outputs?
    root = root->operand(0);
  }
  const Shape& output_shape = root->shape();

  std::vector<Shape> param_shapes;
  absl::c_for_each(fused_computation->parameter_instructions(),
                   [&](const HloInstruction* param) {
                     param_shapes.push_back(param->shape());
                   });
  std::optional<TransposeDimsAndParams> reduced_dims_and_params_021 =
      FindTranspose021DimsAndParameters(param_shapes, output_shape);

  if (!reduced_dims_and_params_021.has_value()) {
    VLOG(3) << "021 transposition not found on instruction "
            << root->ToString();
    return std::nullopt;
  }
  std::vector<int64_t> params_012 = reduced_dims_and_params_021->params;
  const Vector3& reduced_dims_021 = reduced_dims_and_params_021->dims;

  if (reduced_dims_021.at(1) < kMinDimensionToTransposeTiled ||
      reduced_dims_021.at(2) < kMinDimensionToTransposeTiled) {
    VLOG(3) << "021 transpose not matched: dimensions of transposition "
            << root->ToString() << " are too small";
    return std::nullopt;
  }

  params_012 = FilterInputsForShmemTranspose(fused_computation, params_012);
  if (params_012.empty()) {
    VLOG(3) << "021 transpose on " << root->ToString()
            << " not matched: no inputs matched after filtering "
               "for shmem access";
    return std::nullopt;
  }

  // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
  // elements are of size 4 bytes), and CUDA has an architectural limit of
  // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we
  // don't use this, in part because it eats into our L1 cache space.)
  //
  // For correctness we need to ensure that we don't make more than 48kb worth
  // of shmem tiles per block. And for performance, we'd probably like to use
  // significantly less, so that we can fit more than one block at a time on a
  // gpu core.
  //
  // We say without benchmarks that we want at least 3 threads/block,
  // corresponding to 3 shmem tiles if the elements are 32 bits wide. We
  // choose which params get the shmem transpose treatment arbitrarily; it's
  // not clear if there's a Right Choice.
  //
  // This is only sound if tiled transposes are the only place where we use
  // shared memory in fusions. If in the future other fusible ops use shared
  // memory, we'll have to adjust this heuristic.
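  //
  // Illustrative arithmetic (not part of the original source): one f32 tile
  // is 32 * 33 * 4 = 4224 bytes, so with kMinBlocksPerCore = 3 the loop below
  // stops admitting parameters once 3 * shmem_used would exceed the 48KB
  // budget, i.e. after at most three f32 tiles (3 * 3 * 4224 = 38016 bytes
  // still fits, whereas a fourth tile would push 3 * 16896 = 50688 bytes over
  // the 49152-byte limit).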
  constexpr int kMinBlocksPerCore = 3;
  constexpr int64_t kShmemPerCore = 48 * 1024;
  int64_t shmem_used = 0;
  for (int64_t i = 0; i < params_012.size(); ++i) {
    const Shape& operand_shape =
        fused_computation->parameter_instruction(params_012[i])->shape();
    shmem_used +=
        32 * 33 *
        ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());

    if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
      // Erase this element and everything after it from params_012.
      params_012.resize(i);
      break;
    }
  }

  if (params_012.empty()) {
    VLOG(3)
        << "021 transpose on " << root->ToString()
        << " not matched: no inputs matched after filtering for shmem budget";
    return std::nullopt;
  }

  return TransposeDimsAndParams{reduced_dims_021, params_012};
}

}  // namespace gpu
}  // namespace xla