1 // Copyright 2022 The TensorFlow Runtime Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.h"
16
17 #include <cstdint>
18 #include <numeric>
19 #include <optional>
20 #include <string>
21 #include <utility>
22
23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
24 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" // from @llvm-project
25 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/Dialect/GPU/Transforms/Passes.h" // from @llvm-project
28 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
29 #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
35 #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
36 #include "mlir/IR/MLIRContext.h" // from @llvm-project
37 #include "mlir/IR/SymbolTable.h" // from @llvm-project
38 #include "mlir/IR/TypeRange.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Pass/PassManager.h" // from @llvm-project
41 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
42 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
44 #include "mlir/Transforms/Passes.h" // from @llvm-project
45 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
46 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
47 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
48 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
49 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
50 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
51 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
52 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
53 #include "tfrt/gpu/kernels/gpu_ops.h" // from @tf_runtime
54 #include "tfrt/gpu/passes/passes.h" // from @tf_runtime
55
56 namespace tensorflow {
57 namespace {
58
59 #define GEN_PASS_CLASSES
60 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/jitrt_passes.h.inc"
61
62 using mlir::Attribute;
63 using mlir::DialectRegistry;
64 using mlir::FunctionType;
65 using mlir::IntegerAttr;
66 using mlir::MLIRContext;
67 using mlir::ModuleOp;
68 using mlir::NamedAttribute;
69 using mlir::Operation;
70 using mlir::OperationPass;
71 using mlir::success;
72 using mlir::SymbolTable;
73 using mlir::Type;
74 using mlir::TypeRange;
75 using mlir::WalkResult;
76 using mlir::arith::ConstantOp;
77 using mlir::arith::IndexCastOp;
78 using mlir::detail::PassOptions;
79 using mlir::func::CallOp;
80 using mlir::func::FuncOp;
81 using mlir::func::ReturnOp;
82 using mlir::gpu::GPUModuleOp;
83 using mlir::gpu::LaunchFuncOp;
84 using mlir::gpu::MemcpyOp;
85 using mlir::gpu::MemsetOp;
86 using mlir::lmhlo::AllGatherOp;
87 using mlir::lmhlo::AllReduceOp;
88 using mlir::lmhlo::AllToAllOp;
89 using mlir::lmhlo::CaseOp;
90 using mlir::lmhlo::CollectivePermuteOp;
91 using mlir::lmhlo::CustomCallOp;
92 using mlir::lmhlo::FftOp;
93 using mlir::lmhlo::InfeedOp;
94 using mlir::lmhlo::OutfeedOp;
95 using mlir::lmhlo::PartitionIdOp;
96 using mlir::lmhlo::ReduceScatterOp;
97 using mlir::lmhlo::ReplicaIdOp;
98 using mlir::lmhlo::TerminatorOp;
99 using mlir::lmhlo::WhileOp;
100 using mlir::lmhlo_gpu::AllReduceDoneOp;
101 using mlir::lmhlo_gpu::AllReduceStartOp;
102 using mlir::lmhlo_gpu::CholeskyOp;
103 using mlir::lmhlo_gpu::ConvBackwardFilterOp;
104 using mlir::lmhlo_gpu::ConvBackwardInputOp;
105 using mlir::lmhlo_gpu::ConvForwardFusedOp;
106 using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
107 using mlir::lmhlo_gpu::ConvForwardOp;
108 using mlir::lmhlo_gpu::CublasLtMatmulOp;
109 using mlir::lmhlo_gpu::GEMMOp;
110 using mlir::memref::AllocaOp;
111 using mlir::memref::GetGlobalOp;
112
113 static constexpr const char kDirectCustomCall[] = "rt.direct_custom_call";
114
115 class ConvertLmhloConstantToArgPass
116 : public ConvertLmhloConstantToArgPassBase<ConvertLmhloConstantToArgPass> {
117 public:
118 ConvertLmhloConstantToArgPass() = default;
119 explicit ConvertLmhloConstantToArgPass(int64_t min_num_elements) {
120 this->min_num_elements_ = min_num_elements;
121 }
122
123 void runOnOperation() override;
124
125 void getDependentDialects(DialectRegistry& registry) const override {
126 registry.insert<mlir::memref::MemRefDialect>();
127 }
128 };
129
130 class ConvertGpuToJitRtPass
131 : public ConvertGpuToJitRtPassBase<ConvertGpuToJitRtPass> {
132 void runOnOperation() override;
133
134 void getDependentDialects(DialectRegistry& registry) const override {
135 registry.insert<mlir::func::FuncDialect, mlir::arith::ArithmeticDialect>();
136 }
137 };
138
139 class ConvertLmhloGpuToJitRtPass
140 : public ConvertLmhloGpuToJitRtPassBase<ConvertLmhloGpuToJitRtPass> {
141 void runOnOperation() override;
142
143 void getDependentDialects(DialectRegistry& registry) const override {
144 registry.insert<mlir::func::FuncDialect, mlir::arith::ArithmeticDialect,
145 mlir::scf::SCFDialect, mlir::memref::MemRefDialect,
146 mlir::cf::ControlFlowDialect>();
147 }
148 };
149
150 } // namespace
151
152 // -------------------------------------------------------------------------- //
153
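// Erases `gpu.module` operations: kernel launches are lowered to custom calls
// that capture the compiled device binary directly (see LaunchFuncOpLowering
// below), so the device module is no longer needed after this conversion.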
154 class GpuModuleOpLowering : public OpRewritePattern<GPUModuleOp> {
155 public:
156 using OpRewritePattern::OpRewritePattern;
157
158 LogicalResult matchAndRewrite(GPUModuleOp op,
159 PatternRewriter& rewriter) const override {
160 rewriter.eraseOp(op);
161 return success();
162 }
163 };
164
165 // -------------------------------------------------------------------------- //
166
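// Replaces `lmhlo.terminator` operations with plain `func.return`.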
167 class TerminatorOpLowering : public OpRewritePattern<TerminatorOp> {
168 public:
169 using OpRewritePattern::OpRewritePattern;
170
171 LogicalResult matchAndRewrite(TerminatorOp op,
172 PatternRewriter& rewriter) const override {
173 rewriter.replaceOpWithNewOp<ReturnOp>(op);
174 return mlir::success();
175 }
176 };
177
178 // -------------------------------------------------------------------------- //
179
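// Lowers `lmhlo.infeed` and `lmhlo.outfeed` to the `xla.gpu.infeed` and
// `xla.gpu.outfeed` direct custom calls, forwarding the buffer operands and
// the `config` attribute.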
180 template <typename IoFeedOp>
181 class IoFeedOpLowering : public OpRewritePattern<IoFeedOp> {
182 private:
183 static StringRef CustomCallTarget(InfeedOp) { return "xla.gpu.infeed"; }
184 static StringRef CustomCallTarget(OutfeedOp) { return "xla.gpu.outfeed"; }
185
186 public:
187 explicit IoFeedOpLowering(MLIRContext* ctx)
188 : OpRewritePattern<IoFeedOp>(ctx) {}
189
190 LogicalResult matchAndRewrite(IoFeedOp op,
191 PatternRewriter& rewriter) const override {
192 MLIRContext* ctx = this->getContext();
193 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
194
195 // Custom call target.
196 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
197 b.getStringAttr(CustomCallTarget(op)));
198
199 // Create a custom call function declaration.
200 auto custom_call_type =
201 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
202 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
203 auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
204 custom_call_type, custom_call_attrs);
205 custom_call.setPrivate();
206
207 SymbolTable sym_table(op->template getParentOfType<ModuleOp>());
208 auto inserted = sym_table.insert(custom_call);
209 rewriter.notifyOperationInserted(custom_call);
210
211 // Call the runtime intrinsic with the original operands.
212 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
213 op.getOperands());
214 call->setAttr(b.getStringAttr("config"), op.getConfigAttr());
215
216 // Erase the original infeed/outfeed operation.
217 rewriter.eraseOp(op);
218
219 return success();
220 }
221 };
222
223 class InfeedOpLowering : public IoFeedOpLowering<InfeedOp> {
224 public:
225 using IoFeedOpLowering::IoFeedOpLowering;
226 };
227
228 class OutfeedOpLowering : public IoFeedOpLowering<OutfeedOp> {
229 public:
230 using IoFeedOpLowering::IoFeedOpLowering;
231 };
232
233 // -------------------------------------------------------------------------- //
234
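// Lowers `gpu.memcpy` to one of the `xla.gpu.memcpy.{d2h,h2d,d2d}` custom
// calls, depending on where the source and destination buffers live.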
235 class MemcpyOpLowering : public OpRewritePattern<MemcpyOp> {
236 public:
237 using OpRewritePattern::OpRewritePattern;
238
239 // We use a heuristic to identify the direction of the memcpy operation: if
240 // the operand was allocated by an alloca op or is a global memref, then it
241 // must be a memref on the host.
242 static bool IsHostMemRef(Value value) {
243 auto* op = value.getDefiningOp();
244 return llvm::isa_and_nonnull<memref::AllocaOp, memref::GetGlobalOp>(op);
245 }
246
247 LogicalResult matchAndRewrite(MemcpyOp op,
248 PatternRewriter& rewriter) const override {
249 MLIRContext* ctx = getContext();
250 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
251
252 // Identify the direction of the memcpy operation.
253 auto memcpy = [&]() {
254 if (IsHostMemRef(op.dst())) return "xla.gpu.memcpy.d2h";
255 if (IsHostMemRef(op.src())) return "xla.gpu.memcpy.h2d";
256 return "xla.gpu.memcpy.d2d";
257 }();
258
259 // Custom call target.
260 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
261 b.getStringAttr(memcpy));
262
263 // Create a custom call function declaration.
264 auto custom_call_type =
265 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
266 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
267 auto custom_call = FuncOp::create(op.getLoc(), memcpy, custom_call_type,
268 custom_call_attrs);
269 custom_call.setPrivate();
270
271 SymbolTable sym_table(op->getParentOfType<ModuleOp>());
272 auto inserted = sym_table.insert(custom_call);
273 rewriter.notifyOperationInserted(custom_call);
274
275 // Create a function launch call operation.
276 rewriter.replaceOpWithNewOp<CallOp>(op, inserted, TypeRange(),
277 op.getOperands());
278
279 return success();
280 }
281 };
282
283 // -------------------------------------------------------------------------- //
284
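// Lowers `gpu.memset` to the `xla.gpu.memset` custom call.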
285 class MemsetOpLowering : public OpRewritePattern<MemsetOp> {
286 private:
287 static constexpr const char kCustomCallTarget[] = "xla.gpu.memset";
288
289 public:
290 using OpRewritePattern::OpRewritePattern;
291
292 LogicalResult matchAndRewrite(MemsetOp op,
293 PatternRewriter& rewriter) const override {
294 MLIRContext* ctx = getContext();
295 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
296
297 // Custom call target.
298 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
299 b.getStringAttr(kCustomCallTarget));
300
301 // Create a custom call function declaration.
302 auto custom_call_type =
303 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
304 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
305 auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
306 custom_call_type, custom_call_attrs);
307 custom_call.setPrivate();
308
309 SymbolTable sym_table(op->getParentOfType<ModuleOp>());
310 auto inserted = sym_table.insert(custom_call);
311 rewriter.notifyOperationInserted(custom_call);
312
313 // Create a function launch call operation.
314 rewriter.replaceOpWithNewOp<CallOp>(op, inserted, TypeRange(),
315 op.getOperands());
316
317 return success();
318 }
319 };
320
321 // -------------------------------------------------------------------------- //
322
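// Lowers `gpu.launch_func` to the `xla.gpu.func.launch` custom call. Grid and
// block dimensions are passed as leading i32 arguments followed by the kernel
// arguments, and the compiled device binary and kernel name are attached to
// the call as attributes.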
323 class LaunchFuncOpLowering : public OpRewritePattern<LaunchFuncOp> {
324 private:
325 static constexpr const char kCustomCallTarget[] = "xla.gpu.func.launch";
326
327 public:
328 using OpRewritePattern::OpRewritePattern;
329
330 LogicalResult matchAndRewrite(LaunchFuncOp op,
331 PatternRewriter& rewriter) const override {
332 MLIRContext* ctx = getContext();
333 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
334
335 ModuleOp module = op->getParentOfType<ModuleOp>();
336
337 // Cast grid and block dimensions to i32 before passing to the custom call.
338 auto cast = [&](mlir::Value value) {
339 return b.create<IndexCastOp>(b.getI32Type(), value);
340 };
341
342 // Prepare arguments for the custom call.
343 llvm::SmallVector<Value> args = {
344 cast(op.gridSizeX()), cast(op.gridSizeY()), cast(op.gridSizeZ()),
345 cast(op.blockSizeX()), cast(op.blockSizeY()), cast(op.blockSizeZ())};
346
347 // Add kernel arguments.
348 llvm::copy(op.operands(), std::back_inserter(args));
349
350 // Types of the custom call arguments.
351 llvm::SmallVector<Type> args_types = TypeRange(ValueRange(args));
352
353 // Custom call target.
354 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
355 b.getStringAttr(kCustomCallTarget));
356
357 // Create a custom call function declaration.
358 auto custom_call_type = FunctionType::get(ctx, args_types, TypeRange());
359 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
360 auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
361 custom_call_type, custom_call_attrs);
362 custom_call.setPrivate();
363
364 SymbolTable sym_table(module);
365 auto inserted = sym_table.insert(custom_call);
366 rewriter.notifyOperationInserted(custom_call);
367
368 // Get the compiled gpu function.
369 auto* kernel = SymbolTable::lookupNearestSymbolFrom(op, op.kernel());
370 assert(kernel && "kernel not found");
371
372 // Get the compiled GPU binary from the device kernel module.
373 auto gpu_module = kernel->getParentOfType<mlir::gpu::GPUModuleOp>();
374 auto gpu_binary = gpu_module->getAttrOfType<mlir::StringAttr>("binary");
375
376 // Create a function launch call operation.
377 auto call = b.create<CallOp>(inserted, TypeRange(), args);
378 call->setAttr(b.getStringAttr("ptx"), gpu_binary);
379 call->setAttr(b.getStringAttr("kernel"), op.getKernelName());
380
381 // Erase the original gpu launch operation.
382 rewriter.eraseOp(op);
383
384 return success();
385 }
386 };
387
388 // -------------------------------------------------------------------------- //
389
390 // Every Gemm operation in the module gets assigned a unique id that is passed
391 // to the custom call handler. This id is used for caching resources between
392 // different invocations of the same gemm operation.
393 class GemmUidGenerator {
394 public:
395 GemmUidGenerator() : cnt_(0) {}
396 int64_t uid() { return cnt_.fetch_add(1); }
397
398 private:
399 std::atomic<int64_t> cnt_;
400 };
401
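// Lowers `lmhlo_gpu.gemm` to the `xla.gpu.gemm` custom call, forwarding the
// dot dimension numbers, alpha/beta scalars, the selected algorithm (or the
// default one if none was set) and a unique `uid` used to cache resources
// between invocations.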
402 class GemmOpLowering : public OpRewritePattern<GEMMOp> {
403 private:
404 static constexpr const char kCustomCallTarget[] = "xla.gpu.gemm";
405
406 public:
407 GemmOpLowering(MLIRContext* ctx, GemmUidGenerator& uid)
408 : OpRewritePattern<GEMMOp>(ctx), uid_(uid) {}
409
410 LogicalResult matchAndRewrite(GEMMOp op,
411 PatternRewriter& rewriter) const override {
412 MLIRContext* ctx = this->getContext();
413 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
414
415 ModuleOp module = op->template getParentOfType<ModuleOp>();
416
417 // Custom call target.
418 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
419 b.getStringAttr(kCustomCallTarget));
420
421 // Create a custom call function declaration.
422 auto custom_call_type =
423 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
424 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
425 auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
426 custom_call_type, custom_call_attrs);
427 custom_call.setPrivate();
428
429 SymbolTable sym_table(module);
430 auto inserted = sym_table.insert(custom_call);
431 rewriter.notifyOperationInserted(custom_call);
432
433 // Convert Gemm to a function call.
434 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
435 op.getOperands());
436
437 // Assign a unique id to this instance of a gemm operation.
438 call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid()));
439
440 // Copy backend specific attributes.
441 auto algorithm_attr = op.getAlgorithm()
442 ? op.getAlgorithmAttr()
443 : b.getI64IntegerAttr(se::blas::kDefaultGemmAlgo);
444 call->setAttr(b.getStringAttr("algorithm"), algorithm_attr);
445 call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr());
446 call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr());
447 call->setAttr(b.getStringAttr("beta"), op.getBetaAttr());
448 call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers());
449
450 // Erase the original gemm operation.
451 rewriter.eraseOp(op);
452
453 return success();
454 }
455
456 private:
457 GemmUidGenerator& uid_;
458 };
459
460 // -------------------------------------------------------------------------- //
461
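// Lowers `lmhlo_gpu.cublas.lt.matmul` to the `xla.gpu.cublas.lt.matmul` custom
// call, or to its `.bias` variant when a bias operand is present (five
// operands instead of four).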
462 class CublasLtMatmulOpLowering : public OpRewritePattern<CublasLtMatmulOp> {
463 private:
464 static constexpr const char kCustomCallTarget[] = "xla.gpu.cublas.lt.matmul";
465
466 public:
467 CublasLtMatmulOpLowering(MLIRContext* ctx, GemmUidGenerator& uid)
468 : OpRewritePattern<CublasLtMatmulOp>(ctx), uid_(uid) {}
469
470 LogicalResult matchAndRewrite(CublasLtMatmulOp op,
471 PatternRewriter& rewriter) const override {
472 MLIRContext* ctx = this->getContext();
473 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
474
475 ModuleOp module = op->template getParentOfType<ModuleOp>();
476
477 std::string matmul;
478 switch (op.getOperands().size()) {
479 case 4:
480 matmul = kCustomCallTarget;
481 break;
482 case 5:
483 matmul = absl::StrCat(kCustomCallTarget, ".bias");
484 break;
485 default:
486 return op.emitOpError("unexpected number of operands for matmul");
487 }
488
489 // Custom call target.
490 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
491 b.getStringAttr(matmul));
492
493 // Create a custom call function declaration.
494 auto custom_call_type =
495 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
496 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
497 auto custom_call = FuncOp::create(op.getLoc(), matmul, custom_call_type,
498 custom_call_attrs);
499 custom_call.setPrivate();
500
501 SymbolTable sym_table(module);
502 auto inserted = sym_table.insert(custom_call);
503 rewriter.notifyOperationInserted(custom_call);
504
505 // Convert matmul to a function call.
506 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
507 op.getOperands());
508
509 // Assign a unique id to this instance of a matmul operation.
510 call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid()));
511
512 // Copy backend specific attributes.
513 call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr());
514 call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr());
515 call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr());
516 call->setAttr(b.getStringAttr("beta"), op.getBetaAttr());
517 call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers());
518 call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr());
519
520 // TODO(ezhulenev): Today we can't pass an array of enum attributes to the
521 // custom call. Also we do not have a corresponding precision enum on the
522 // SE/XLA side, so we encode it as an i32 array (tensor).
523 if (auto precisions = op.getPrecisionConfig()) {
524 llvm::SmallVector<int32_t> values;
525 for (auto precision : *precisions) {
526 auto value = precision.cast<mhlo::PrecisionAttr>().getValue();
527 values.push_back(static_cast<int32_t>(value));
528 }
529 call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values));
530 } else {
531 call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0}));
532 }
533
534 // Erase the original matmul operation.
535 rewriter.eraseOp(op);
536
537 return success();
538 }
539
540 private:
541 GemmUidGenerator& uid_;
542 };
543
544 // -------------------------------------------------------------------------- //
545
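// Lowers the `lmhlo_gpu` convolution operations (forward, fused forward and
// backward variants) to the matching `xla.gpu.conv.*` custom calls, copying
// the window, dimension-number and backend-config attributes to the call.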
546 template <typename Conv>
547 class ConvOpLowering : public OpRewritePattern<Conv> {
548 private:
549 static StringRef CustomCallTarget(ConvForwardOp) {
550 return "xla.gpu.conv.forward";
551 }
552 static StringRef CustomCallTarget(ConvForwardFusedOp) {
553 return "xla.gpu.conv.forward.fused";
554 }
555 static StringRef CustomCallTarget(ConvForwardFusedSideInputOp) {
556 return "xla.gpu.conv.forward.fused.side_input";
557 }
558 static StringRef CustomCallTarget(ConvBackwardFilterOp) {
559 return "xla.gpu.conv.backward.filter";
560 }
561 static StringRef CustomCallTarget(ConvBackwardInputOp) {
562 return "xla.gpu.conv.backward.input";
563 }
564
565 public:
566 explicit ConvOpLowering(MLIRContext* ctx) : OpRewritePattern<Conv>(ctx) {}
567
568 LogicalResult matchAndRewrite(Conv op,
569 PatternRewriter& rewriter) const override {
570 MLIRContext* ctx = this->getContext();
571 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
572
573 ModuleOp module = op->template getParentOfType<ModuleOp>();
574
575 // Custom call target.
576 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
577 b.getStringAttr(CustomCallTarget(op)));
578
579 // Create a custom call function declaration.
580 auto custom_call_type =
581 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
582 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
583 auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
584 custom_call_type, custom_call_attrs);
585 custom_call.setPrivate();
586
587 SymbolTable sym_table(module);
588 auto inserted = sym_table.insert(custom_call);
589 rewriter.notifyOperationInserted(custom_call);
590
591 // Convert Conv to a function call.
592 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
593 op.getOperands());
594
595 // Helper functions to copy attributes from the conv op to the custom call.
596 auto set_attr = [&](StringRef name, Attribute attr) {
597 call->setAttr(b.getStringAttr(name), attr);
598 };
599
600 auto set_xi64 = [&](StringRef name, Optional<DenseIntElementsAttr> attr) {
601 SmallVector<int64_t> values;
602 if (attr.has_value())
603 values = llvm::to_vector(attr->getValues<int64_t>());
604 set_attr(name, b.getI64TensorAttr(values));
605 };
606
607 // Convert `BoolElementsAttr` to i64 before passing to the runtime.
608 // TODO(ezhulenev): Allow passing boolean tensors to the JitRt custom calls.
609 auto set_xi1 = [&](StringRef name, Optional<DenseElementsAttr> attr) {
610 SmallVector<int64_t> values;
611 if (attr.has_value())
612 values.assign(attr->getValues<bool>().begin(),
613 attr->getValues<bool>().end());
614 set_attr(name, b.getI64TensorAttr(values));
615 };
616
617 // Copy dimension number attributes.
618 call->setAttr(b.getStringAttr("conv_dims"), op.getDimensionNumbers());
619
620 // Copy convolution window attributes.
621 set_xi1("window_reversal", op.getWindowReversal());
622 set_xi64("window_strides", op.getWindowStrides());
623 set_xi64("lhs_dilation", op.getLhsDilation());
624 set_xi64("rhs_dilation", op.getRhsDilation());
625 set_xi64("padding", op.getPadding());
626
627 // Copy backend config.
628 call->setAttr(b.getStringAttr("backend_config"), op.getBackendConfig());
629
630 // Copy remaining attributes.
631 set_attr("feature_group_count", op.getFeatureGroupCountAttr());
632 set_attr("result_scale", op.getResultScaleAttr());
633
634 // Copy attributes specific for fused convolutions.
635 if (auto fused = dyn_cast<ConvForwardFusedOp>(op.getOperation())) {
636 call->setAttr(b.getStringAttr("activation_mode"),
637 fused.getActivationModeAttr());
638 }
639
640 // Copy attributes specific for fused convolutions with side input.
641 if (auto fused = dyn_cast<ConvForwardFusedSideInputOp>(op.getOperation())) {
642 call->setAttr(b.getStringAttr("activation_mode"),
643 fused.getActivationModeAttr());
644 set_attr("side_input_scale", fused.getSideInputScaleAttr());
645 }
646
647 // Erase the original conv operation.
648 rewriter.eraseOp(op);
649
650 return success();
651 }
652 };
653
654 class ConvForwardOpLowering : public ConvOpLowering<ConvForwardOp> {
655 public:
656 using ConvOpLowering::ConvOpLowering;
657 };
658
659 class ConvForwardFusedOpLowering : public ConvOpLowering<ConvForwardFusedOp> {
660 public:
661 using ConvOpLowering::ConvOpLowering;
662 };
663
664 class ConvBackwardFilterOpLowering
665 : public ConvOpLowering<ConvBackwardFilterOp> {
666 public:
667 using ConvOpLowering::ConvOpLowering;
668 };
669
670 class ConvBackwardInputOpLowering : public ConvOpLowering<ConvBackwardInputOp> {
671 public:
672 using ConvOpLowering::ConvOpLowering;
673 };
674
675 class ConvForwardFusedSideInputOpLowering
676 : public ConvOpLowering<ConvForwardFusedSideInputOp> {
677 public:
678 using ConvOpLowering::ConvOpLowering;
679 };
680
681 // -------------------------------------------------------------------------- //
682
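// Rewrites `lmhlo.while` into an `scf.while` loop. The loop predicate lives in
// a device buffer, so on every iteration it is copied into a host alloca and
// loaded from there to drive the `scf.condition` terminator.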
683 class WhileOpLowering : public OpRewritePattern<WhileOp> {
684 public:
685 using OpRewritePattern::OpRewritePattern;
686
687 LogicalResult matchAndRewrite(WhileOp op,
688 PatternRewriter& rewriter) const override {
689 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
690
691 // Create an `scf.while` loop in place of `lmhlo.while` loop.
692 auto loop = b.create<scf::WhileOp>(TypeRange(), ValueRange());
693
694 // Predicate buffer placed on the device.
695 assert(op.getNumOperands() == 1 && "expected single cond operand");
696 Value pred = op.getOperand(0);
697
698 // Clone condition and body blocks into the new loop operation.
699 BlockAndValueMapping mapping;
700 op.getCond().cloneInto(&loop.getBefore(), mapping);
701 op.getBody().cloneInto(&loop.getAfter(), mapping);
702
703 { // Replace loop condition terminator.
704 auto* terminator = loop.getBefore().back().getTerminator();
705 b.setInsertionPointAfter(terminator);
706
707 // Copy predicate buffer to the host ...
708 auto i1 = b.getI1Type();
709 Value pred_on_host = b.create<memref::AllocaOp>(MemRefType::get({}, i1));
710 b.create<gpu::MemcpyOp>(TypeRange(), ValueRange({pred_on_host, pred}));
711
712 // .. and check if we need to continue loop iteration.
713 Value cond = b.create<memref::LoadOp>(i1, pred_on_host, ValueRange());
714 b.create<scf::ConditionOp>(cond, ValueRange());
715 rewriter.eraseOp(terminator);
716 }
717
718 { // Replace loop body terminator.
719 auto* terminator = loop.getAfter().back().getTerminator();
720 b.setInsertionPointAfter(terminator);
721 b.create<scf::YieldOp>(TypeRange(), ValueRange());
722 rewriter.eraseOp(terminator);
723 }
724
725 // Erase the original while loop.
726 rewriter.eraseOp(op);
727
728 return success();
729 }
730 };
731
732 // -------------------------------------------------------------------------- //
733
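// Rewrites `lmhlo.case` into a host-side `cf.switch`: the index buffer is
// copied from the device, i1 predicates are converted to an i32 index,
// out-of-range indices are clamped to the last branch, and every case region
// becomes a block that branches into the continuation block.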
734 class CaseOpLowering : public OpRewritePattern<CaseOp> {
735 public:
736 using OpRewritePattern::OpRewritePattern;
737
738 LogicalResult matchAndRewrite(CaseOp op,
739 PatternRewriter& rewriter) const override {
740 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
741
742 // Copy index buffer to the host ...
743 auto index_type = op.getIndex().getType().dyn_cast<MemRefType>();
744 Value index_on_host = b.create<memref::AllocaOp>(index_type);
745 b.create<gpu::MemcpyOp>(TypeRange(),
746 ValueRange({index_on_host, op.getIndex()}));
747
748 // Get the index value from the buffer.
749 Value index = b.create<memref::LoadOp>(index_type.getElementType(),
750 index_on_host, ValueRange());
751
752 bool is_predicate = index_type.getElementType().isInteger(1);
753
754 // For binary index (predicate) convert i1 to i32 index.
755 if (is_predicate) {
756 Value c0 = b.create<ConstantOp>(b.getI32IntegerAttr(0));
757 Value c1 = b.create<ConstantOp>(b.getI32IntegerAttr(1));
758 index = b.create<arith::SelectOp>(index, c0, c1);
759 }
760
761 // For integer index make sure that it is within range.
762 if (!is_predicate) {
763 unsigned n = op.getNumRegions() - 1;
764 Value c0 = b.create<ConstantOp>(b.getI32IntegerAttr(0));
765 Value cN = b.create<ConstantOp>(b.getI32IntegerAttr(n));
766
767 Value too_small = b.create<arith::CmpIOp>(
768 b.getI1Type(), arith::CmpIPredicate::slt, index, c0);
769 Value too_large = b.create<arith::CmpIOp>(
770 b.getI1Type(), arith::CmpIPredicate::sgt, index, cN);
771
772 Value out_of_range = b.create<arith::OrIOp>(too_small, too_large);
773 index = b.create<arith::SelectOp>(out_of_range, cN, index);
774 }
775
776 // Split block right at the case operation.
777 Block* cont = rewriter.splitBlock(op->getBlock(), op->getIterator());
778 Block* orig = cont->getPrevNode();
779
780 // Prepare case destinations for the `scf.switch` operation.
781 llvm::SmallVector<llvm::APInt> case_values;
782 llvm::SmallVector<Block*> case_blocks;
783 llvm::SmallVector<ValueRange> case_operands;
784
785 // Create blocks from each of the case regions.
786 for (Region& region : op->getRegions()) {
787 // Move `lmhlo.case` block before the continuation.
788 Block& block = region.front();
789 block.moveBefore(cont);
790
791 // Erase original `lmhlo.terminator`.
792 rewriter.eraseOp(block.getTerminator());
793
794 // Branch into the continuation block.
795 b.setInsertionPointToEnd(&block);
796 b.create<cf::BranchOp>(cont);
797
798 // Add a `cf.switch` case.
799 int32_t idx = case_blocks.size();
800 case_values.push_back(b.getI32IntegerAttr(idx).getValue());
801 case_blocks.push_back(&block);
802 case_operands.push_back({});
803 }
804
805 // Replace `lmhlo.case` with a `cf.switch` operation on the host.
806 b.setInsertionPointToEnd(orig);
807 b.create<cf::SwitchOp>(index, cont, ValueRange(), case_values, case_blocks,
808 case_operands);
809
810 // Erase the original case operation.
811 rewriter.eraseOp(op);
812
813 return success();
814 }
815 };
816
817 // -------------------------------------------------------------------------- //
818
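// Lowers `lmhlo.custom_call` to the generic `xla.gpu.custom_call` runtime
// intrinsic, forwarding the api_version, backend_config and call_target_name
// attributes. Holes in the target argument mapping are passed as empty i8
// memrefs.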
819 class CustomCallOpLowering : public OpRewritePattern<CustomCallOp> {
820 private:
821 static constexpr const char kCustomCallTarget[] = "xla.gpu.custom_call";
822
823 public:
824 using OpRewritePattern::OpRewritePattern;
825
826 LogicalResult matchAndRewrite(CustomCallOp op,
827 PatternRewriter& rewriter) const override {
828 MLIRContext* ctx = this->getContext();
829 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
830
831 // Custom call target.
832 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
833 b.getStringAttr(kCustomCallTarget));
834
835 // By default, all operands are passed to the custom call handler.
836 llvm::SmallVector<Value> operands = op.getOperands();
837
838 // If the custom call has a target arguments mapping, then we need to pass
839 // empty memrefs in place of the holes.
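//
// Illustrative example (not from the original code): with num_args = 3 and
// args_to_target_args = [0, 2], target argument #1 has no matching operand
// and is passed as an empty `memref<0xi8>` hole.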
840 if (op.getTargetArgMapping().has_value()) {
841 auto mapping = *op.getTargetArgMapping();
842 int64_t num_args = mapping.getNumArgs();
843 int64_t num_results = mapping.getNumResults();
844
845 // We represent holes as empty i8 memrefs.
846 Value hole = b.create<AllocaOp>(MemRefType::get({0}, b.getI8Type()));
847 operands = llvm::SmallVector<Value>(num_args + num_results, hole);
848
849 // Update operands to mapped custom call arguments.
850 auto args = mapping.getArgsToTargetArgs();
851 for (const auto& indexed : llvm::enumerate(args))
852 operands[indexed.value()] = op.getArgs()[indexed.index()];
853
854 // Update operands to mapped custom call results.
855 auto res = mapping.getResultsToTargetResults();
856 for (const auto& indexed : llvm::enumerate(res))
857 operands[num_args + indexed.value()] = op.getOutput()[indexed.index()];
858 }
859
860 // Create a custom call function declaration.
861 auto custom_call_type =
862 FunctionType::get(ctx, TypeRange(ValueRange(operands)), TypeRange());
863
864 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
865 auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
866 custom_call_type, custom_call_attrs);
867 custom_call.setPrivate();
868
869 SymbolTable sym_table(op->getParentOfType<ModuleOp>());
870 auto inserted = sym_table.insert(custom_call);
871 rewriter.notifyOperationInserted(custom_call);
872
873 // Call the runtime intrinsic with the original operands.
874 auto call =
875 rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(), operands);
876
877 // Pass attributes to the custom call handler.
878 auto set_attr = [&](StringRef name, Attribute attr) {
879 call->setAttr(b.getStringAttr(name), attr);
880 };
881
882 set_attr("api_version", op.getApiVersionAttr());
883 set_attr("backend_config", op.getBackendConfigAttr());
884 set_attr("call_target_name", op.getCallTargetNameAttr());
885
886 // Erase the original custom call operation.
887 rewriter.eraseOp(op);
888
889 return success();
890 }
891 };
892
893 // -------------------------------------------------------------------------- //
894
895 using GlobalConstantsArgs = llvm::DenseMap<FuncOp, llvm::StringMap<Value>>;
896
897 // Returns a mapping from a global constant name to the function argument.
898 //
899 // Example:
900 //
901 // memref.global "private" constant @cst : memref<2x3xf32>
902 // func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst"})
903 //
904 // All memref.get_global operations will be replaced by constant arguments
905 // corresponding to the global constant.
906 GlobalConstantsArgs GetConstantArgs(ModuleOp m) {
907 GlobalConstantsArgs mapping;
908
909 m.walk([&](FuncOp func) {
910 for (unsigned i = 0; i < func.getNumArguments(); ++i) {
911 auto cst = func.getArgAttrOfType<StringAttr>(i, "lmhlo.constant_name");
912 if (cst) mapping[func][cst] = func.getArgument(i);
913 }
914 });
915
916 return mapping;
917 }
918
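// Replaces `memref.get_global` operations that read LMHLO constants with
// views into the corresponding constant function argument (see
// GetConstantArgs above).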
919 class GetGlobalOpLowering : public OpRewritePattern<GetGlobalOp> {
920 public:
921 GetGlobalOpLowering(MLIRContext* ctx, const GlobalConstantsArgs& cst_args)
922 : OpRewritePattern<GetGlobalOp>(ctx), cst_args_(cst_args) {}
923
924 LogicalResult matchAndRewrite(GetGlobalOp op,
925 PatternRewriter& rewriter) const override {
926 // Find global constants mapping for the parent function.
927 auto func_mapping = cst_args_.find(op->getParentOfType<FuncOp>());
928 if (func_mapping == cst_args_.end()) return failure();
929
930 // Check if the global operation corresponds to the LMHLO constant arg.
931 auto arg = func_mapping->second.find(op.getName());
932 if (arg == func_mapping->second.end()) return failure();
933
934 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
935 MemRefType memref = op->getResult(0).getType().cast<MemRefType>();
936
937 // For identity layouts we can replace all loads from a global with the
938 // corresponding argument.
939 if (memref.getLayout().isIdentity()) {
940 Value c0 = b.create<ConstantOp>(rewriter.getIndexAttr(0));
941 rewriter.replaceOpWithNewOp<memref::ViewOp>(op, memref, arg->second, c0,
942 ValueRange());
943 return success();
944 }
945
946 // For a non-identity layout we first view the constant argument as a flat
947 // memref with the correct element type, and then cast it to the strided
948 // memref corresponding to the original memref layout.
949
950 // Get the strides and offset from the original memref type.
951 int64_t offset;
952 llvm::SmallVector<int64_t> strides;
953 if (failed(getStridesAndOffset(memref, strides, offset)))
954 return op.emitOpError("failed to compute strides and offset");
955
956 // Create a 1d view into the corresponding argument.
957 Value c0 = b.create<ConstantOp>(rewriter.getIndexAttr(0));
958 Value flat_view = b.create<memref::ViewOp>(
959 MemRefType::get({memref.getNumElements()}, memref.getElementType()),
960 arg->second, c0, ValueRange());
961
962 // Cast flat memref view into the original memref type.
963 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
964 op, memref, flat_view, offset, memref.getShape(), strides);
965
966 return success();
967 }
968
969 private:
970 const GlobalConstantsArgs& cst_args_;
971 };
972
973 // -------------------------------------------------------------------------- //
974
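// Lowers `lmhlo.fft` to the `xla.gpu.fft` custom call, forwarding the FFT type
// and FFT length as call attributes.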
975 class FftOpLowering : public OpRewritePattern<FftOp> {
976 private:
977 static constexpr const char kCustomCallTarget[] = "xla.gpu.fft";
978
979 public:
980 using OpRewritePattern::OpRewritePattern;
981
982 LogicalResult matchAndRewrite(FftOp op,
983 PatternRewriter& rewriter) const override {
984 MLIRContext* ctx = this->getContext();
985 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
986
987 ModuleOp module = op->getParentOfType<ModuleOp>();
988
989 // Custom call target.
990 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
991 b.getStringAttr(kCustomCallTarget));
992
993 // Create a custom call function declaration.
994 auto custom_call_type =
995 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
996 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
997 auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
998 custom_call_type, custom_call_attrs);
999 custom_call.setPrivate();
1000
1001 SymbolTable sym_table(module);
1002 auto inserted = sym_table.insert(custom_call);
1003 rewriter.notifyOperationInserted(custom_call);
1004
1005 // Convert Fft to a function call.
1006 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1007 op.getOperands());
1008
1009 // Copy backend specific attributes.
1010 call->setAttr(b.getStringAttr("fft_length"), op.getFftLengthAttr());
1011 call->setAttr(b.getStringAttr("fft_type"), op.getFftTypeAttr());
1012
1013 // Erase the original Fft operation.
1014 rewriter.eraseOp(op);
1015
1016 return success();
1017 }
1018 };
1019
1020 // -------------------------------------------------------------------------- //
1021
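// Lowers `lmhlo_gpu.cholesky` to the `xla.gpu.cholesky` custom call. The
// matrix dimension `n` is the minor-most dimension of the input, and all
// dimensions except the last two are folded into the batch size.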
1022 class CholeskyOpLowering : public OpRewritePattern<CholeskyOp> {
1023 private:
1024 static constexpr const char kCustomCallTarget[] = "xla.gpu.cholesky";
1025
1026 public:
1027 explicit CholeskyOpLowering(MLIRContext* ctx)
1028 : OpRewritePattern<CholeskyOp>(ctx) {}
1029
1030 LogicalResult matchAndRewrite(CholeskyOp op,
1031 PatternRewriter& rewriter) const override {
1032 MLIRContext* ctx = this->getContext();
1033 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1034
1035 ModuleOp module = op->getParentOfType<ModuleOp>();
1036
1037 // Custom call target.
1038 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
1039 b.getStringAttr(kCustomCallTarget));
1040
1041 // Create a custom call function declaration.
1042 auto custom_call_type =
1043 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
1044 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
1045 auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
1046 custom_call_type, custom_call_attrs);
1047 custom_call.setPrivate();
1048
1049 SymbolTable sym_table(module);
1050 auto inserted = sym_table.insert(custom_call);
1051 rewriter.notifyOperationInserted(custom_call);
1052
1053 // Convert Cholesky to a function call.
1054 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1055 op.getOperands());
1056
1057 const auto& dims =
1058 op.getInput().getType().cast<mlir::MemRefType>().getShape();
1059 if (dims.size() < 2)
1060 return op.emitOpError() << "Input's dimension count (" << dims.size()
1061 << ") must be 2 or greater.";
1062 int64_t n = dims[dims.size() - 1];
1063 int64_t batch_size =
1064 std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
1065 [](int64_t a, int64_t b) { return a * b; });
1066
1067 // Copy backend specific attributes.
1068 call->setAttr(b.getStringAttr("batch_size"),
1069 b.getI64IntegerAttr(batch_size));
1070 call->setAttr(b.getStringAttr("n"), b.getI64IntegerAttr(n));
1071 call->setAttr(b.getStringAttr("is_lower"), op.getIsLowerAttr());
1072
1073 // Erase the original Cholesky operation.
1074 rewriter.eraseOp(op);
1075
1076 return success();
1077 }
1078 };
1079
1080 // -------------------------------------------------------------------------- //
1081
1082 // We assign a unique id to all collective operations in the module, so that
1083 // we can efficiently access per-op state at run time. The exception to this
1084 // rule is asynchronous collective operations, where the corresponding pair of
1085 // `start` and `done` operations shares the same unique id.
1086 //
1087 // Asynchronous collective operations pass an HLO token to represent the
1088 // dependency between the `Start` and `Done` operations. When we lower to JitRt
1089 // custom calls we rely on assigning each unique pair of `Start` and `Done`
1090 // operations a unique event id, and use a shared "context" owned by the
1091 // GpuExecutable to pass Gpu events from `Start` to `Done` custom call handlers.
1092 //
1093 // TODO(ezhulenev): Once JitRt custom calls support returning values, we should
1094 // explicitly return the event id from the `Start` custom call, and pass it to
1095 // the `Done` custom call. Longer term this should become an `!async.token` and
1096 // rely on JitRt asynchronous execution.
1097 class CollectiveUidGenerator {
1098 public:
1099 CollectiveUidGenerator() : cnt_(0) {}
1100
1101 // Assigns a unique event id to the pair of start and done operations.
1102 int32_t AssignUid(AllReduceStartOp start, AllReduceDoneOp done) {
1103 int32_t id = next();
1104 uids_[start] = id;
1105 uids_[done] = id;
1106 return id;
1107 }
1108
1109 FailureOr<int32_t> AssignedUid(Operation* op) {
1110 // Async operations must be assigned a uid ahead of time.
1111 if (isa<AllReduceStartOp, AllReduceDoneOp>(op)) {
1112 auto it = uids_.find(op);
1113 if (it == uids_.end()) return failure();
1114 return it->second;
1115 }
1116 // For every other operation we just assign a next id.
1117 return next();
1118 }
1119
1120 private:
1121 int32_t next() { return cnt_++; }
1122
1123 int32_t cnt_;
1124 llvm::DenseMap<Operation*, int32_t> uids_;
1125 };
1126
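// Lowers the LMHLO collective operations (all-gather, all-reduce, all-to-all,
// reduce-scatter and collective-permute) to the matching `xla.gpu.*` custom
// calls. Degenerate collectives, where every communication group is a
// singleton, are rewritten into plain device-to-device memcpys instead.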
1127 template <typename CollectiveOp>
1128 class CollectiveOpLowering : public OpRewritePattern<CollectiveOp> {
1129 private:
1130 static StringRef CustomCallTarget(AllGatherOp) {
1131 return "xla.gpu.all_gather";
1132 }
1133 static StringRef CustomCallTarget(AllReduceOp) {
1134 return "xla.gpu.all_reduce";
1135 }
1136 static StringRef CustomCallTarget(AllReduceStartOp) {
1137 return "xla.gpu.all_reduce_start";
1138 }
1139 static StringRef CustomCallTarget(ReduceScatterOp) {
1140 return "xla.gpu.reduce_scatter";
1141 }
1142 static StringRef CustomCallTarget(AllToAllOp) { return "xla.gpu.all_to_all"; }
1143 static StringRef CustomCallTarget(CollectivePermuteOp) {
1144 return "xla.gpu.collective_permute";
1145 }
1146
1147 template <typename ReduceOrGatherOp>
1148 static xla::gpu::NcclCollectiveConfig GetNcclCollectiveConfig(
1149 ReduceOrGatherOp op, int /*replica_count*/, int /*num_partitions*/) {
1150 return xla::gpu::GetNcclCollectiveConfigForMlir(op,
1151 op.getUseGlobalDeviceIds());
1152 }
1153 static xla::gpu::NcclCollectiveConfig GetNcclCollectiveConfig(
1154 AllToAllOp op, int /*replica_count*/, int /*num_partitions*/) {
1155 // TODO(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids
1156 // attribute and it should be removed.
1157 return xla::gpu::GetNcclCollectiveConfigForMlir(op, std::nullopt);
1158 }
1159 static xla::gpu::NcclCollectiveConfig GetNcclCollectiveConfig(
1160 CollectivePermuteOp op, int replica_count, int num_partitions) {
1161 return xla::gpu::NcclCollectivePermuteThunk::GetNcclCollectivePermuteConfig(
1162 op, replica_count, num_partitions)
1163 .config;
1164 }
1165
1166 template <typename NonCollectivePermuteOp>
1167 static LogicalResult TryDegenerateToMemCopy(
1168 NonCollectivePermuteOp op, const xla::gpu::NcclCollectiveConfig& config,
1169 int replica_count, int num_partitions, PatternRewriter& rewriter) {
1170 if (!config.IsDegenerate(replica_count, num_partitions)) {
1171 return failure();
1172 }
1173
1174 for (int64_t i = 0; i < op.getInputs().size(); i++) {
1175 rewriter.create<gpu::MemcpyOp>(
1176 op.getLoc(), TypeRange(),
1177 ValueRange({op.getOutputs()[i], op.getOperands()[i]}));
1178 }
1179 return success();
1180 }
1181 static LogicalResult TryDegenerateToMemCopy(
1182 CollectivePermuteOp op, const xla::gpu::NcclCollectiveConfig& config,
1183 int replica_count, int num_partitions, PatternRewriter& rewriter) {
1184 if (!xla::gpu::NcclCollectivePermuteThunk::IsDegenerate(op, replica_count,
1185 num_partitions)) {
1186 return failure();
1187 }
1188
1189 rewriter.create<gpu::MemcpyOp>(
1190 op.getLoc(), TypeRange(),
1191 ValueRange({op.getOutput(), op.getOperand()}));
1192 return success();
1193 }
1194
1195 static bool CanImplement(AllGatherOp op) {
1196 return xla::gpu::NcclAllGatherThunk::CanImplement(op);
1197 }
1198 static bool CanImplement(AllReduceOp op) {
1199 return xla::gpu::NcclAllReduceThunk::CanImplement(op);
1200 }
1201 static bool CanImplement(AllReduceStartOp op) {
1202 return xla::gpu::NcclAllReduceStartThunk::CanImplement(op);
1203 }
1204 static bool CanImplement(ReduceScatterOp op) {
1205 return xla::gpu::NcclReduceScatterThunk::CanImplement(op);
1206 }
1207 static bool CanImplement(AllToAllOp op) {
1208 return xla::gpu::NcclAllToAllThunk::CanImplement(op);
1209 }
1210 static bool CanImplement(CollectivePermuteOp op) {
1211 return xla::gpu::NcclCollectivePermuteThunk::CanImplement(op);
1212 }
1213
1214 template <typename ReduceOp>
1215 static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, ReduceOp op,
1216 CallOp call) {
1217 std::optional<xla::ReductionKind> reduction_kind =
1218 xla::gpu::NcclAllReduceThunkBase::MatchAllReduceComputation(
1219 op.getComputation());
1220 if (!reduction_kind.has_value())
1221 return op.emitOpError()
1222 << "Failed to determine reduction computation for AllReduce";
1223
1224 call->setAttr(
1225 b.getStringAttr("reduction_kind"),
1226 b.getI64IntegerAttr(static_cast<int64_t>(reduction_kind.value())));
1227 return success();
1228 }
1229 static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, AllGatherOp op,
1230 CallOp call) {
1231 return success();
1232 }
1233 static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, AllToAllOp op,
1234 CallOp call) {
1235 call->setAttr(b.getStringAttr("has_split_dimension"),
1236 b.getBoolAttr(op.getSplitDimension().has_value()));
1237 return success();
1238 }
1239 static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b,
1240 CollectivePermuteOp op, CallOp call) {
1241 auto source_target_pairs_or =
1242 xla::ConvertNx2Attribute(op.getSourceTargetPairs());
1243 if (!source_target_pairs_or.ok()) {
1244 return op.emitOpError()
1245 << source_target_pairs_or.status().error_message();
1246 }
1247
1248 // Pass an array of pairs as two vectors.
1249 std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
1250 std::move(source_target_pairs_or.value());
1251 std::vector<int64_t> source_peers, target_peers;
1252 source_peers.reserve(source_target_pairs.size());
1253 target_peers.reserve(source_target_pairs.size());
1254 for (const auto& source_target_pair : source_target_pairs) {
1255 source_peers.push_back(source_target_pair.first);
1256 target_peers.push_back(source_target_pair.second);
1257 }
1258
1259 auto source_peers_attr = b.getI64TensorAttr(source_peers);
1260 auto target_peers_attr = b.getI64TensorAttr(target_peers);
1261 call->setAttr(b.getStringAttr("source_peers"), source_peers_attr);
1262 call->setAttr(b.getStringAttr("target_peers"), target_peers_attr);
1263 return success();
1264 }
1265
1266 public:
1267 explicit CollectiveOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid)
1268 : OpRewritePattern<CollectiveOp>(ctx), uid_(uid) {}
1269
1270 LogicalResult matchAndRewrite(CollectiveOp op,
1271 PatternRewriter& rewriter) const override {
1272 MLIRContext* ctx = this->getContext();
1273 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1274
1275 ModuleOp module = op->template getParentOfType<ModuleOp>();
1276
1277 // Construct an NCCL collective config from the parent func attributes.
1278 FuncOp fn = op->template getParentOfType<FuncOp>();
1279 auto replica_count_attr = fn->getAttrOfType<IntegerAttr>("replica_count");
1280 auto num_partitions_attr = fn->getAttrOfType<IntegerAttr>("num_partitions");
1281 const int64_t replica_count = replica_count_attr.getInt();
1282 const int64_t num_partitions = num_partitions_attr.getInt();
1283
1284 xla::gpu::NcclCollectiveConfig config =
1285 GetNcclCollectiveConfig(op, replica_count, num_partitions);
1286
1287 // A given collective op can be degenerate if all the groups formed by it
1288 // are singletons. In such a case, we don't need to do any communication and
1289 // we can just copy the input to the output.
1290 if (succeeded(TryDegenerateToMemCopy(op, config, replica_count,
1291 num_partitions, rewriter))) {
1292 // For async collective erase all corresponding done operations.
1293 if (auto start = dyn_cast<AllReduceStartOp>(op.getOperation())) {
1294 auto users = llvm::to_vector(start.getToken().getUsers());
1295 llvm::for_each(users, [&](Operation* user) {
1296 if (isa<AllReduceDoneOp>(user)) rewriter.eraseOp(user);
1297 });
1298 }
1299
1300 // Erase the original collective operation.
1301 rewriter.eraseOp(op);
1302
1303 return success();
1304 }
1305
1306 // Custom call target.
1307 NamedAttribute target(b.getStringAttr(kDirectCustomCall),
1308 b.getStringAttr(CustomCallTarget(op)));
1309
1310 // Create a custom call function declaration.
1311 auto custom_call_type =
1312 FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
1313 auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
1314 auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
1315 custom_call_type, custom_call_attrs);
1316 custom_call.setPrivate();
1317
1318 SymbolTable sym_table(module);
1319 auto inserted = sym_table.insert(custom_call);
1320 rewriter.notifyOperationInserted(custom_call);
1321
1322 // Convert collective op to a function call.
1323 auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1324 op.getOperands());
1325
1326 if (!CanImplement(op)) {
1327 return op.emitOpError()
1328 << "Requested " << CustomCallTarget(op)
1329 << " not implemented on GPU; replica_count: " << replica_count
1330 << ", num_partitions: " << num_partitions << ", group_mode: "
1331 << CollectiveOpGroupModeToString(config.group_mode)
1332 << ", NCCL support: "
1333 << xla::gpu::NcclCollectiveThunk::NcclIsEnabled();
1334 }
1335
1336 // Copy backend specific attributes.
1337 call->setAttr(b.getStringAttr("group_mode"),
1338 b.getI64IntegerAttr(static_cast<int64_t>(config.group_mode)));
1339 call->setAttr(b.getStringAttr("op_id"), b.getI64IntegerAttr(config.op_id));
1340 // TODO(b/233930690): Pass the attribute below as a nested array.
1341 // Pass an array of arrays using two vectors; one specifying all the values
1342 // and another specifying the (ending) offsets of each array in the other
1343 // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into
1344 // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90].
1345 std::vector<int64_t> replica_group_offsets;
1346 std::vector<int64_t> replica_group_values;
1347 replica_group_offsets.reserve(config.replica_groups.size());
1348 int replica_group_offset = 0;
1349 for (const auto& replica_group : config.replica_groups) {
1350 replica_group_offset += replica_group.replica_ids_size();
1351 replica_group_offsets.push_back(replica_group_offset);
1352 replica_group_values.reserve(replica_group_offset);
1353 for (auto replica_id : replica_group.replica_ids()) {
1354 replica_group_values.push_back(replica_id);
1355 }
1356 }
1357 call->setAttr(b.getStringAttr("replica_group_offsets"),
1358 b.getI64TensorAttr(replica_group_offsets));
1359 call->setAttr(b.getStringAttr("replica_group_values"),
1360 b.getI64TensorAttr(replica_group_values));
1361
1362 // Assign a unique collective operation id.
1363 auto uid = uid_.AssignedUid(op);
1364 if (succeeded(uid)) {
1365 call->setAttr(b.getStringAttr("uid"), b.getI32IntegerAttr(*uid));
1366 } else {
1367 return op.emitOpError("failed to get a unique collective operation id");
1368 }
1369
1370 // Set attributes specific to the type of collective operation.
1371 auto result = SetSpecificAttrs(b, op, call);
1372 if (failed(result)) return result;
1373
1374 // For an asynchronous start operation we need to produce a fake token that
1375 // will be removed later, because the corresponding `done` operation doesn't
1376 // take the token argument. We rely on the `unrealized_conversion_cast`
1377 // operation to create a fake token from the `i8` constant.
    if (auto start = dyn_cast<AllReduceStartOp>(op.getOperation())) {
      Value token = start.getToken();
      Value c0 = b.create<ConstantOp>(b.getI8IntegerAttr(0));
      auto fake = b.create<UnrealizedConversionCastOp>(token.getType(), c0);
      token.replaceAllUsesWith(fake.getResult(0));
    }

    // Erase the original collective operation.
    rewriter.eraseOp(op);

    return success();
  }

 private:
  CollectiveUidGenerator& uid_;
};

class AllGatherOpLowering : public CollectiveOpLowering<AllGatherOp> {
 public:
  using CollectiveOpLowering::CollectiveOpLowering;
};

class AllReduceOpLowering : public CollectiveOpLowering<AllReduceOp> {
 public:
  using CollectiveOpLowering::CollectiveOpLowering;
};

class AllReduceStartOpLowering : public CollectiveOpLowering<AllReduceStartOp> {
 public:
  using CollectiveOpLowering::CollectiveOpLowering;
};

class ReduceScatterOpLowering : public CollectiveOpLowering<ReduceScatterOp> {
 public:
  using CollectiveOpLowering::CollectiveOpLowering;
};

class AllToAllOpLowering : public CollectiveOpLowering<AllToAllOp> {
 public:
  using CollectiveOpLowering::CollectiveOpLowering;
};

class CollectivePermuteOpLowering
    : public CollectiveOpLowering<CollectivePermuteOp> {
 public:
  using CollectiveOpLowering::CollectiveOpLowering;
};

class AllReduceDoneOpLowering : public OpRewritePattern<AllReduceDoneOp> {
 private:
  static constexpr const char kCustomCallTarget[] = "xla.gpu.all_reduce_done";

 public:
  explicit AllReduceDoneOpLowering(MLIRContext* ctx,
                                   CollectiveUidGenerator& uid)
      : OpRewritePattern<AllReduceDoneOp>(ctx), uid_(uid) {}

  LogicalResult matchAndRewrite(AllReduceDoneOp op,
                                PatternRewriter& rewriter) const override {
    MLIRContext* ctx = this->getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    ModuleOp module = op->getParentOfType<ModuleOp>();

    // Custom call target.
    NamedAttribute target(b.getStringAttr(kDirectCustomCall),
                          b.getStringAttr(kCustomCallTarget));

    // For the `done` operation we drop the token argument and communicate the
    // async event dependency through the `uid` attribute instead.
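    // As a sketch, the op is replaced with a call along these lines:
    //   call @xla.gpu.all_reduce_done(%buffers ...) { uid = <shared id> }
    // where <shared id> matches the uid assigned to the corresponding
    // all-reduce start operation.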
    llvm::SmallVector<Value> operands = op.getOperands().drop_front();

    // Create a custom call function declaration.
    auto custom_call_type =
        FunctionType::get(ctx, TypeRange(ValueRange(operands)), TypeRange());
    auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
    auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
                                      custom_call_type, custom_call_attrs);
    custom_call.setPrivate();

    SymbolTable sym_table(module);
    auto inserted = sym_table.insert(custom_call);
    rewriter.notifyOperationInserted(custom_call);

    // Convert AllReduceDone to a function call.
    auto call =
        rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(), operands);

    // Assign a unique collective operation id.
    auto uid = uid_.AssignedUid(op);
    if (succeeded(uid)) {
      call->setAttr(b.getStringAttr("uid"), b.getI32IntegerAttr(*uid));
    } else {
      return op.emitOpError("failed to get a unique collective operation id");
    }

    // Erase the original AllReduceDone operation.
    rewriter.eraseOp(op);

    return success();
  }

 private:
  CollectiveUidGenerator& uid_;
};

// -------------------------------------------------------------------------- //

template <typename IdOp>
class IdOpLowering : public OpRewritePattern<IdOp> {
 private:
  static StringRef CustomCallTarget(ReplicaIdOp) {
    return "xla.gpu.replica_id";
  }
  static StringRef CustomCallTarget(PartitionIdOp) {
    return "xla.gpu.partition_id";
  }

 public:
  explicit IdOpLowering(MLIRContext* ctx) : OpRewritePattern<IdOp>(ctx) {}

  LogicalResult matchAndRewrite(IdOp op,
                                PatternRewriter& rewriter) const override {
    MLIRContext* ctx = this->getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    ModuleOp module = op->template getParentOfType<ModuleOp>();

    // Custom call target.
    NamedAttribute target(b.getStringAttr(kDirectCustomCall),
                          b.getStringAttr(CustomCallTarget(op)));

    // Create a custom call function declaration.
    auto custom_call_type =
        FunctionType::get(ctx, op->getOperandTypes(), TypeRange());
    auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
    auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
                                      custom_call_type, custom_call_attrs);
    custom_call.setPrivate();

    SymbolTable sym_table(module);
    auto inserted = sym_table.insert(custom_call);
    rewriter.notifyOperationInserted(custom_call);

    // Convert the id operation to a function call.
    rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
                            op->getOperands());

    // Erase the original id operation.
    rewriter.eraseOp(op);

    return success();
  }
};
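
// For example (sketch only): an `lmhlo.replica_id` operation is rewritten into
// a private declaration of @xla.gpu.replica_id plus a call that forwards the
// original operands, and the id operation itself is erased. The same lowering
// handles `lmhlo.partition_id` via the @xla.gpu.partition_id target.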

class ReplicaIdOpLowering : public IdOpLowering<ReplicaIdOp> {
 public:
  using IdOpLowering::IdOpLowering;
};

class PartitionIdOpLowering : public IdOpLowering<PartitionIdOp> {
 public:
  using IdOpLowering::IdOpLowering;
};

// -------------------------------------------------------------------------- //

void ConvertLmhloConstantToArgPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext* ctx = module.getContext();

  // Replace memref loads from globals corresponding to the constant arguments.
  RewritePatternSet patterns(ctx);
  GlobalConstantsArgs cst_args = GetConstantArgs(module);
  patterns.insert<GetGlobalOpLowering>(ctx, cst_args);

  // Set up the conversion target to rewrite only GetGlobalOp operations larger
  // than the threshold, and to avoid any other canonicalizations that can
  // break later passes.
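  // As a sketch of the intended effect: a `memref.get_global` whose result has
  // at least `min_num_elements_` elements is illegal and gets rewritten by the
  // GetGlobalOpLowering pattern above, while smaller globals stay legal and
  // untouched.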
  ConversionTarget target(*ctx);
  target.addDynamicallyLegalOp<GetGlobalOp>([&](GetGlobalOp op) {
    auto memref = op.getType();
    return memref.getNumElements() < min_num_elements_;
  });
  target.addLegalOp<ConstantOp, memref::ViewOp, memref::ReinterpretCastOp>();

  // TODO(ezhulenev): By adding MHLO and LMHLO to a set of legal dialects, we
  // suppress any rewrites for these dialects (there are canonicalization
  // patterns that interact badly with downstream Gpu binary code generation).
  target.addLegalDialect<mhlo::MhloDialect, lmhlo::LmhloDialect>();

  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void ConvertGpuToJitRtPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext* ctx = module.getContext();

  // Convert gpu operations to JitRt gpu runtime custom calls.
  RewritePatternSet patterns(ctx);
  patterns.insert<GpuModuleOpLowering, LaunchFuncOpLowering, MemcpyOpLowering,
                  MemsetOpLowering, InfeedOpLowering, OutfeedOpLowering>(ctx);

  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
    return signalPassFailure();
}

void ConvertLmhloGpuToJitRtPass::runOnOperation() {
  ModuleOp module = getOperation();
  MLIRContext* ctx = module.getContext();

  // Convert lmhlo_gpu operations to JitRt gpu runtime custom calls.
  RewritePatternSet patterns(ctx);

  // Each unique Gemm/Matmul operation in the module will get assigned a uid.
  GemmUidGenerator gemm_uid;
  patterns.insert<GemmOpLowering, CublasLtMatmulOpLowering>(ctx, gemm_uid);

  // Assign a shared unique id to each pair of async start-done operations; all
  // other collective operations get their own uid.
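  // For example (sketch), given a pair like
  //   %token = "lmhlo_gpu.all_reduce_start"(...)
  //   "lmhlo_gpu.all_reduce_done"(%token, ...)
  // both lowered custom calls end up with the same `uid` attribute, which lets
  // the runtime match the async start with its completion.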
  CollectiveUidGenerator collective_uid;
  auto walked = module.walk([&](AllReduceStartOp start) -> WalkResult {
    Value token = start.getToken();

    // We expect the token to be consumed just once.
    if (!token.hasOneUse()) return start.emitOpError("token has multiple uses");

    // Token must be consumed by the corresponding done operation.
    auto done = dyn_cast<AllReduceDoneOp>(*token.getUsers().begin());
    if (!done) return start.emitOpError("illegal token user");

    collective_uid.AssignUid(start, done);
    return WalkResult::advance();
  });
  if (walked.wasInterrupted()) return signalPassFailure();

  // Patterns for collective operations.
  patterns.insert<AllGatherOpLowering, AllReduceOpLowering,
                  AllReduceStartOpLowering, AllToAllOpLowering,
                  CollectivePermuteOpLowering, ReduceScatterOpLowering>(
      ctx, collective_uid);

  // Patterns for every other Gpu operation.
  patterns
      .insert<FftOpLowering, CholeskyOpLowering, PartitionIdOpLowering,
              ReplicaIdOpLowering, WhileOpLowering, CaseOpLowering,
              CustomCallOpLowering, TerminatorOpLowering, ConvForwardOpLowering,
              ConvForwardFusedOpLowering, ConvForwardFusedSideInputOpLowering,
              ConvBackwardFilterOpLowering, ConvBackwardInputOpLowering>(ctx);

  if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
    return signalPassFailure();

  // TODO(ezhulenev): We must run the `done` op lowering after the `start` op
  // lowering to ensure that all redundant collective operations are safely
  // replaced by `memcpy` operations. We should find a better way to achieve
  // this.
  {
    RewritePatternSet patterns(ctx);
    patterns.insert<AllReduceDoneOpLowering>(ctx, collective_uid);
    if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
      return signalPassFailure();
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> createConvertGpuToJitRtPass() {
  return std::make_unique<ConvertGpuToJitRtPass>();
}

std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloConstantToArgPass() {
  return std::make_unique<ConvertLmhloConstantToArgPass>();
}

std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloConstantToArgPass(
    int64_t min_num_elements) {
  return std::make_unique<ConvertLmhloConstantToArgPass>(min_num_elements);
}

std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloGpuToJitRtPass() {
  return std::make_unique<ConvertLmhloGpuToJitRtPass>();
}

void populateLmhloToJitRtPasses(mlir::OpPassManager& pm,
                                GpuBinaryOptions options) {
  // Convert large global memrefs corresponding to XLA constants into
  // arguments, so that compiled device kernels do not capture them.
  //
  // TODO(ezhulenev): The threshold should be consistent with the device kernel
  // code generation. If a constant will be embedded into the device module, we
  // should not inline it too early. Currently the threshold is hardcoded so
  // that only single-element constants stay embedded.
  pm.addPass(createConvertLmhloConstantToArgPass(/*min_num_elements=*/2));
  pm.addPass(createSymbolDCEPass());  // Clean up unused global constants.

  // Small global constants will be embedded into the device modules.
  pm.addPass(createConvertLmhloToGpuBinaryPass(options));

  // Convert remaining small global memrefs corresponding to constant arguments.
  pm.addPass(createConvertLmhloConstantToArgPass());
  pm.addPass(createSymbolDCEPass());  // Clean up unused global constants.

  // Lower all Gpu operations to the JitRt Gpu runtime intrinsics.
  pm.addPass(createConvertLmhloGpuToJitRtPass());
  pm.addPass(createConvertGpuToJitRtPass());
}
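
// Example usage (sketch only, assuming `GpuBinaryOptions` is
// default-constructible): append the lowering pipeline to a pass manager and
// run it over a module.
//
//   mlir::PassManager pm(module.getContext());
//   tensorflow::populateLmhloToJitRtPasses(pm, GpuBinaryOptions());
//   if (mlir::failed(pm.run(module))) {
//     // Handle the lowering failure.
//   }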

void registerLmhloToJitRtPasses() {
  mlir::registerPass([] { return createConvertGpuToJitRtPass(); });
  mlir::registerPass([] { return createConvertLmhloConstantToArgPass(); });
  mlir::registerPass([] { return createConvertLmhloGpuToJitRtPass(); });

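  // Once registered, the pipeline below can be invoked by name, e.g. as
  // `-lmhlo-to-jitrt` from an mlir-opt-style tool that links in these passes
  // (the exact driver binary is not prescribed here).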
  mlir::registerPassPipeline(
      "lmhlo-to-jitrt", "Lower LMHLO to JitRt IR",
      [](OpPassManager& pm, StringRef options,
         function_ref<LogicalResult(const Twine&)> errorHandler) {
        populateLmhloToJitRtPasses(pm);
        return success();
      },
      /*optHandler=*/[](function_ref<void(const PassOptions&)>) {});
}

}  // namespace tensorflow