1 // Copyright 2022 The TensorFlow Runtime Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_jitrt.h"
16 
17 #include <cstdint>
18 #include <numeric>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 
23 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
24 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"  // from @llvm-project
25 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"  // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/Dialect/GPU/Transforms/Passes.h"  // from @llvm-project
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
29 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
35 #include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
36 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
37 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
38 #include "mlir/IR/TypeRange.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassManager.h"  // from @llvm-project
41 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
42 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
44 #include "mlir/Transforms/Passes.h"  // from @llvm-project
45 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
46 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
47 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo_gpu/IR/lhlo_gpu_ops.h"
48 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
49 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
50 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
51 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
52 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_thunk.h"
53 #include "tfrt/gpu/kernels/gpu_ops.h"  // from @tf_runtime
54 #include "tfrt/gpu/passes/passes.h"  // from @tf_runtime
55 
56 namespace tensorflow {
57 namespace {
58 
59 #define GEN_PASS_CLASSES
60 #include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/jitrt_passes.h.inc"
61 
62 using mlir::Attribute;
63 using mlir::DialectRegistry;
64 using mlir::FunctionType;
65 using mlir::IntegerAttr;
66 using mlir::MLIRContext;
67 using mlir::ModuleOp;
68 using mlir::NamedAttribute;
69 using mlir::Operation;
70 using mlir::OperationPass;
71 using mlir::success;
72 using mlir::SymbolTable;
73 using mlir::Type;
74 using mlir::TypeRange;
75 using mlir::WalkResult;
76 using mlir::arith::ConstantOp;
77 using mlir::arith::IndexCastOp;
78 using mlir::detail::PassOptions;
79 using mlir::func::CallOp;
80 using mlir::func::FuncOp;
81 using mlir::func::ReturnOp;
82 using mlir::gpu::GPUModuleOp;
83 using mlir::gpu::LaunchFuncOp;
84 using mlir::gpu::MemcpyOp;
85 using mlir::gpu::MemsetOp;
86 using mlir::lmhlo::AllGatherOp;
87 using mlir::lmhlo::AllReduceOp;
88 using mlir::lmhlo::AllToAllOp;
89 using mlir::lmhlo::CaseOp;
90 using mlir::lmhlo::CollectivePermuteOp;
91 using mlir::lmhlo::CustomCallOp;
92 using mlir::lmhlo::FftOp;
93 using mlir::lmhlo::InfeedOp;
94 using mlir::lmhlo::OutfeedOp;
95 using mlir::lmhlo::PartitionIdOp;
96 using mlir::lmhlo::ReduceScatterOp;
97 using mlir::lmhlo::ReplicaIdOp;
98 using mlir::lmhlo::TerminatorOp;
99 using mlir::lmhlo::WhileOp;
100 using mlir::lmhlo_gpu::AllReduceDoneOp;
101 using mlir::lmhlo_gpu::AllReduceStartOp;
102 using mlir::lmhlo_gpu::CholeskyOp;
103 using mlir::lmhlo_gpu::ConvBackwardFilterOp;
104 using mlir::lmhlo_gpu::ConvBackwardInputOp;
105 using mlir::lmhlo_gpu::ConvForwardFusedOp;
106 using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
107 using mlir::lmhlo_gpu::ConvForwardOp;
108 using mlir::lmhlo_gpu::CublasLtMatmulOp;
109 using mlir::lmhlo_gpu::GEMMOp;
110 using mlir::memref::AllocaOp;
111 using mlir::memref::GetGlobalOp;
112 
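// Name of the function attribute that marks the custom call declarations
// created by the lowerings in this file; its value is the runtime custom call
// target that the JitRt runtime resolves as a direct custom call.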
113 static constexpr const char kDirectCustomCall[] = "rt.direct_custom_call";
114 
115 class ConvertLmhloConstantToArgPass
116     : public ConvertLmhloConstantToArgPassBase<ConvertLmhloConstantToArgPass> {
117  public:
118   ConvertLmhloConstantToArgPass() = default;
119   explicit ConvertLmhloConstantToArgPass(int64_t min_num_elements) {
120     this->min_num_elements_ = min_num_elements;
121   }
122 
123   void runOnOperation() override;
124 
125   void getDependentDialects(DialectRegistry& registry) const override {
126     registry.insert<mlir::memref::MemRefDialect>();
127   }
128 };
129 
130 class ConvertGpuToJitRtPass
131     : public ConvertGpuToJitRtPassBase<ConvertGpuToJitRtPass> {
132   void runOnOperation() override;
133 
134   void getDependentDialects(DialectRegistry& registry) const override {
135     registry.insert<mlir::func::FuncDialect, mlir::arith::ArithmeticDialect>();
136   }
137 };
138 
139 class ConvertLmhloGpuToJitRtPass
140     : public ConvertLmhloGpuToJitRtPassBase<ConvertLmhloGpuToJitRtPass> {
141   void runOnOperation() override;
142 
143   void getDependentDialects(DialectRegistry& registry) const override {
144     registry.insert<mlir::func::FuncDialect, mlir::arith::ArithmeticDialect,
145                     mlir::scf::SCFDialect, mlir::memref::MemRefDialect,
146                     mlir::cf::ControlFlowDialect>();
147   }
148 };
149 
150 }  // namespace
151 
152 // -------------------------------------------------------------------------- //
153 
154 class GpuModuleOpLowering : public OpRewritePattern<GPUModuleOp> {
155  public:
156   using OpRewritePattern::OpRewritePattern;
157 
158   LogicalResult matchAndRewrite(GPUModuleOp op,
159                                 PatternRewriter& rewriter) const override {
160     rewriter.eraseOp(op);
161     return success();
162   }
163 };
164 
165 // -------------------------------------------------------------------------- //
166 
167 class TerminatorOpLowering : public OpRewritePattern<TerminatorOp> {
168  public:
169   using OpRewritePattern::OpRewritePattern;
170 
171   LogicalResult matchAndRewrite(TerminatorOp op,
172                                 PatternRewriter& rewriter) const override {
173     rewriter.replaceOpWithNewOp<ReturnOp>(op);
174     return mlir::success();
175   }
176 };
177 
178 // -------------------------------------------------------------------------- //
179 
180 template <typename IoFeedOp>
181 class IoFeedOpLowering : public OpRewritePattern<IoFeedOp> {
182  private:
183   static StringRef CustomCallTarget(InfeedOp) { return "xla.gpu.infeed"; }
184   static StringRef CustomCallTarget(OutfeedOp) { return "xla.gpu.outfeed"; }
185 
186  public:
187   explicit IoFeedOpLowering(MLIRContext* ctx)
188       : OpRewritePattern<IoFeedOp>(ctx) {}
189 
190   LogicalResult matchAndRewrite(IoFeedOp op,
191                                 PatternRewriter& rewriter) const override {
192     MLIRContext* ctx = this->getContext();
193     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
194 
195     // Custom call target.
196     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
197                           b.getStringAttr(CustomCallTarget(op)));
198 
199     // Create a custom call function declaration.
200     auto custom_call_type =
201         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
202     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
203     auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
204                                       custom_call_type, custom_call_attrs);
205     custom_call.setPrivate();
206 
207     SymbolTable sym_table(op->template getParentOfType<ModuleOp>());
208     auto inserted = sym_table.insert(custom_call);
209     rewriter.notifyOperationInserted(custom_call);
210 
211     // Call the runtime intrinsic with the original operands.
212     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
213                                         op.getOperands());
214     call->setAttr(b.getStringAttr("config"), op.getConfigAttr());
215 
216     // Erase the original infeed/outfeed operation.
217     rewriter.eraseOp(op);
218 
219     return success();
220   }
221 };
222 
223 class InfeedOpLowering : public IoFeedOpLowering<InfeedOp> {
224  public:
225   using IoFeedOpLowering::IoFeedOpLowering;
226 };
227 
228 class OutfeedOpLowering : public IoFeedOpLowering<OutfeedOp> {
229  public:
230   using IoFeedOpLowering::IoFeedOpLowering;
231 };
232 
233 // -------------------------------------------------------------------------- //
234 
235 class MemcpyOpLowering : public OpRewritePattern<MemcpyOp> {
236  public:
237   using OpRewritePattern::OpRewritePattern;
238 
239   // We use a heuristic to identify the direction of the memcpy operation: if
240   // the operand was allocated by an alloca op or is a global memref, then it
241   // must be a memref on the host.
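  //
  // Illustrative example (hypothetical IR): the copy below is lowered to the
  // "xla.gpu.memcpy.d2h" custom call because the destination is an alloca.
  //
  //   %host = memref.alloca() : memref<i1>
  //   gpu.memcpy %host, %device : memref<i1>, memref<i1>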
242   static bool IsHostMemRef(Value value) {
243     auto* op = value.getDefiningOp();
244     return llvm::isa_and_nonnull<memref::AllocaOp, memref::GetGlobalOp>(op);
245   }
246 
247   LogicalResult matchAndRewrite(MemcpyOp op,
248                                 PatternRewriter& rewriter) const override {
249     MLIRContext* ctx = getContext();
250     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
251 
252     // Identify the direction of the memcpy operation.
253     auto memcpy = [&]() {
254       if (IsHostMemRef(op.dst())) return "xla.gpu.memcpy.d2h";
255       if (IsHostMemRef(op.src())) return "xla.gpu.memcpy.h2d";
256       return "xla.gpu.memcpy.d2d";
257     }();
258 
259     // Custom call target.
260     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
261                           b.getStringAttr(memcpy));
262 
263     // Create a custom call function declaration.
264     auto custom_call_type =
265         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
266     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
267     auto custom_call = FuncOp::create(op.getLoc(), memcpy, custom_call_type,
268                                       custom_call_attrs);
269     custom_call.setPrivate();
270 
271     SymbolTable sym_table(op->getParentOfType<ModuleOp>());
272     auto inserted = sym_table.insert(custom_call);
273     rewriter.notifyOperationInserted(custom_call);
274 
275     // Create a function launch call operation.
276     rewriter.replaceOpWithNewOp<CallOp>(op, inserted, TypeRange(),
277                                         op.getOperands());
278 
279     return success();
280   }
281 };
282 
283 // -------------------------------------------------------------------------- //
284 
285 class MemsetOpLowering : public OpRewritePattern<MemsetOp> {
286  private:
287   static constexpr const char kCustomCallTarget[] = "xla.gpu.memset";
288 
289  public:
290   using OpRewritePattern::OpRewritePattern;
291 
292   LogicalResult matchAndRewrite(MemsetOp op,
293                                 PatternRewriter& rewriter) const override {
294     MLIRContext* ctx = getContext();
295     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
296 
297     // Custom call target.
298     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
299                           b.getStringAttr(kCustomCallTarget));
300 
301     // Create a custom call function declaration.
302     auto custom_call_type =
303         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
304     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
305     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
306                                       custom_call_type, custom_call_attrs);
307     custom_call.setPrivate();
308 
309     SymbolTable sym_table(op->getParentOfType<ModuleOp>());
310     auto inserted = sym_table.insert(custom_call);
311     rewriter.notifyOperationInserted(custom_call);
312 
313     // Create a function launch call operation.
314     rewriter.replaceOpWithNewOp<CallOp>(op, inserted, TypeRange(),
315                                         op.getOperands());
316 
317     return success();
318   }
319 };
320 
321 // -------------------------------------------------------------------------- //
322 
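// Lowers `gpu.launch_func` to a call to the `xla.gpu.func.launch` custom call:
// grid and block dimensions are cast to i32 and prepended to the kernel
// arguments, and the kernel name and the GPU binary (taken from the parent
// `gpu.module` "binary" attribute) are attached to the call as attributes.
//
// A rough sketch of the result (hypothetical IR, types elided):
//
//   func.func private @xla.gpu.func.launch(...)
//       attributes {rt.direct_custom_call = "xla.gpu.func.launch"}
//   ...
//   call @xla.gpu.func.launch(%gx, %gy, %gz, %bx, %by, %bz, %arg0, ...)
//       {kernel = "fusion", ptx = "..."} : (...) -> ()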
323 class LaunchFuncOpLowering : public OpRewritePattern<LaunchFuncOp> {
324  private:
325   static constexpr const char kCustomCallTarget[] = "xla.gpu.func.launch";
326 
327  public:
328   using OpRewritePattern::OpRewritePattern;
329 
330   LogicalResult matchAndRewrite(LaunchFuncOp op,
331                                 PatternRewriter& rewriter) const override {
332     MLIRContext* ctx = getContext();
333     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
334 
335     ModuleOp module = op->getParentOfType<ModuleOp>();
336 
337     // Cast grid and block dimensions to i32 before passing to the custom call.
338     auto cast = [&](mlir::Value value) {
339       return b.create<IndexCastOp>(b.getI32Type(), value);
340     };
341 
342     // Prepare arguments for the custom call.
343     llvm::SmallVector<Value> args = {
344         cast(op.gridSizeX()),  cast(op.gridSizeY()),  cast(op.gridSizeZ()),
345         cast(op.blockSizeX()), cast(op.blockSizeY()), cast(op.blockSizeZ())};
346 
347     // Add kernel arguments.
348     llvm::copy(op.operands(), std::back_inserter(args));
349 
350     // Types of the custom call arguments.
351     llvm::SmallVector<Type> args_types = TypeRange(ValueRange(args));
352 
353     // Custom call target.
354     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
355                           b.getStringAttr(kCustomCallTarget));
356 
357     // Create a custom call function declaration.
358     auto custom_call_type = FunctionType::get(ctx, args_types, TypeRange());
359     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
360     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
361                                       custom_call_type, custom_call_attrs);
362     custom_call.setPrivate();
363 
364     SymbolTable sym_table(module);
365     auto inserted = sym_table.insert(custom_call);
366     rewriter.notifyOperationInserted(custom_call);
367 
368     // Get the compiled gpu function.
369     auto* kernel = SymbolTable::lookupNearestSymbolFrom(op, op.kernel());
370     assert(kernel && "kernel not found");
371 
372     // Get the compiled GPU binary from the device kernel module.
373     auto gpu_module = kernel->getParentOfType<mlir::gpu::GPUModuleOp>();
374     auto gpu_binary = gpu_module->getAttrOfType<mlir::StringAttr>("binary");
375 
376     // Create a function launch call operation.
377     auto call = b.create<CallOp>(inserted, TypeRange(), args);
378     call->setAttr(b.getStringAttr("ptx"), gpu_binary);
379     call->setAttr(b.getStringAttr("kernel"), op.getKernelName());
380 
381     // Erase the original gpu launch operation.
382     rewriter.eraseOp(op);
383 
384     return success();
385   }
386 };
387 
388 // -------------------------------------------------------------------------- //
389 
390 // Every Gemm operation in the module gets assigned a unique id that is passed
391 // to the custom call handler. This id is used for caching resources between the
392 // different invocations of the same gemm operation.
393 class GemmUidGenerator {
394  public:
395   GemmUidGenerator() : cnt_(0) {}
396   int64_t uid() { return cnt_.fetch_add(1); }
397 
398  private:
399   std::atomic<int64_t> cnt_;
400 };
401 
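// Lowers `lmhlo_gpu.gemm` to a call to the `xla.gpu.gemm` custom call and
// copies the backend specific attributes onto the call, together with a
// unique `uid` used to cache per-operation resources at run time.
//
// A rough sketch of the result (hypothetical IR, attribute values elided):
//
//   call @xla.gpu.gemm(%lhs, %rhs, %out)
//       {uid = 0 : i64, algorithm = ..., alpha_real = ..., alpha_imag = ...,
//        beta = ..., dot_dims = ...} : (...) -> ()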
402 class GemmOpLowering : public OpRewritePattern<GEMMOp> {
403  private:
404   static constexpr const char kCustomCallTarget[] = "xla.gpu.gemm";
405 
406  public:
407   GemmOpLowering(MLIRContext* ctx, GemmUidGenerator& uid)
408       : OpRewritePattern<GEMMOp>(ctx), uid_(uid) {}
409 
410   LogicalResult matchAndRewrite(GEMMOp op,
411                                 PatternRewriter& rewriter) const override {
412     MLIRContext* ctx = this->getContext();
413     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
414 
415     ModuleOp module = op->template getParentOfType<ModuleOp>();
416 
417     // Custom call target.
418     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
419                           b.getStringAttr(kCustomCallTarget));
420 
421     // Create a custom call function declaration.
422     auto custom_call_type =
423         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
424     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
425     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
426                                       custom_call_type, custom_call_attrs);
427     custom_call.setPrivate();
428 
429     SymbolTable sym_table(module);
430     auto inserted = sym_table.insert(custom_call);
431     rewriter.notifyOperationInserted(custom_call);
432 
433     // Convert Gemm to a function call.
434     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
435                                         op.getOperands());
436 
437     // Assign a unique id to this instance of a gemm operation.
438     call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid()));
439 
440     // Copy backend specific attributes.
441     auto algorithm_attr = op.getAlgorithm()
442                               ? op.getAlgorithmAttr()
443                               : b.getI64IntegerAttr(se::blas::kDefaultGemmAlgo);
444     call->setAttr(b.getStringAttr("algorithm"), algorithm_attr);
445     call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr());
446     call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr());
447     call->setAttr(b.getStringAttr("beta"), op.getBetaAttr());
448     call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers());
449 
450     // Erase the original gemm operation.
451     rewriter.eraseOp(op);
452 
453     return success();
454   }
455 
456  private:
457   GemmUidGenerator& uid_;
458 };
459 
460 // -------------------------------------------------------------------------- //
461 
462 class CublasLtMatmulOpLowering : public OpRewritePattern<CublasLtMatmulOp> {
463  private:
464   static constexpr const char kCustomCallTarget[] = "xla.gpu.cublas.lt.matmul";
465 
466  public:
467   CublasLtMatmulOpLowering(MLIRContext* ctx, GemmUidGenerator& uid)
468       : OpRewritePattern<CublasLtMatmulOp>(ctx), uid_(uid) {}
469 
470   LogicalResult matchAndRewrite(CublasLtMatmulOp op,
471                                 PatternRewriter& rewriter) const override {
472     MLIRContext* ctx = this->getContext();
473     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
474 
475     ModuleOp module = op->template getParentOfType<ModuleOp>();
476 
477     std::string matmul;
478     switch (op.getOperands().size()) {
479       case 4:
480         matmul = kCustomCallTarget;
481         break;
482       case 5:
483         matmul = absl::StrCat(kCustomCallTarget, ".bias");
484         break;
485       default:
486         return op.emitOpError("unexpected number of operands for matmul");
487     }
488 
489     // Custom call target.
490     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
491                           b.getStringAttr(matmul));
492 
493     // Create a custom call function declaration.
494     auto custom_call_type =
495         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
496     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
497     auto custom_call = FuncOp::create(op.getLoc(), matmul, custom_call_type,
498                                       custom_call_attrs);
499     custom_call.setPrivate();
500 
501     SymbolTable sym_table(module);
502     auto inserted = sym_table.insert(custom_call);
503     rewriter.notifyOperationInserted(custom_call);
504 
505     // Convert matmul to a function call.
506     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
507                                         op.getOperands());
508 
509     // Assign a unique id to this instance of a matmul operation.
510     call->setAttr(b.getStringAttr("uid"), b.getI64IntegerAttr(uid_.uid()));
511 
512     // Copy backend specific attributes.
513     call->setAttr(b.getStringAttr("algorithm"), op.getAlgorithmAttr());
514     call->setAttr(b.getStringAttr("alpha_imag"), op.getAlphaImagAttr());
515     call->setAttr(b.getStringAttr("alpha_real"), op.getAlphaRealAttr());
516     call->setAttr(b.getStringAttr("beta"), op.getBetaAttr());
517     call->setAttr(b.getStringAttr("dot_dims"), op.getDotDimensionNumbers());
518     call->setAttr(b.getStringAttr("epilogue"), op.getEpilogueAttr());
519 
520     // TODO(ezhulenev): Today we can't pass an array of enum attributes to the
521     // custom call. Also we do not have a corresponding precision enum on the
522     // SE/XLA side, so we encode it as an i32 array (tensor).
523     if (auto precisions = op.getPrecisionConfig()) {
524       llvm::SmallVector<int32_t> values;
525       for (auto precision : *precisions) {
526         auto value = precision.cast<mhlo::PrecisionAttr>().getValue();
527         values.push_back(static_cast<int32_t>(value));
528       }
529       call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr(values));
530     } else {
531       call->setAttr(b.getStringAttr("precision"), b.getI32TensorAttr({0, 0}));
532     }
533 
534     // Erase the original matmul operation.
535     rewriter.eraseOp(op);
536 
537     return success();
538   }
539 
540  private:
541   GemmUidGenerator& uid_;
542 };
543 
544 // -------------------------------------------------------------------------- //
545 
546 template <typename Conv>
547 class ConvOpLowering : public OpRewritePattern<Conv> {
548  private:
549   static StringRef CustomCallTarget(ConvForwardOp) {
550     return "xla.gpu.conv.forward";
551   }
552   static StringRef CustomCallTarget(ConvForwardFusedOp) {
553     return "xla.gpu.conv.forward.fused";
554   }
555   static StringRef CustomCallTarget(ConvForwardFusedSideInputOp) {
556     return "xla.gpu.conv.forward.fused.side_input";
557   }
558   static StringRef CustomCallTarget(ConvBackwardFilterOp) {
559     return "xla.gpu.conv.backward.filter";
560   }
561   static StringRef CustomCallTarget(ConvBackwardInputOp) {
562     return "xla.gpu.conv.backward.input";
563   }
564 
565  public:
566   explicit ConvOpLowering(MLIRContext* ctx) : OpRewritePattern<Conv>(ctx) {}
567 
568   LogicalResult matchAndRewrite(Conv op,
569                                 PatternRewriter& rewriter) const override {
570     MLIRContext* ctx = this->getContext();
571     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
572 
573     ModuleOp module = op->template getParentOfType<ModuleOp>();
574 
575     // Custom call target.
576     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
577                           b.getStringAttr(CustomCallTarget(op)));
578 
579     // Create a custom call function declaration.
580     auto custom_call_type =
581         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
582     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
583     auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
584                                       custom_call_type, custom_call_attrs);
585     custom_call.setPrivate();
586 
587     SymbolTable sym_table(module);
588     auto inserted = sym_table.insert(custom_call);
589     rewriter.notifyOperationInserted(custom_call);
590 
591     // Convert Conv to a function call.
592     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
593                                         op.getOperands());
594 
595     // Helper functions to copy attributes from the conv op to the custom call.
596     auto set_attr = [&](StringRef name, Attribute attr) {
597       call->setAttr(b.getStringAttr(name), attr);
598     };
599 
600     auto set_xi64 = [&](StringRef name, Optional<DenseIntElementsAttr> attr) {
601       SmallVector<int64_t> values;
602       if (attr.has_value())
603         values = llvm::to_vector(attr->getValues<int64_t>());
604       set_attr(name, b.getI64TensorAttr(values));
605     };
606 
607     // Convert `BoolElementsAttr` to i64 before passing to the runtime.
608     // TODO(ezhulenev): Allow passing boolean tensors to the JitRt custom calls.
609     auto set_xi1 = [&](StringRef name, Optional<DenseElementsAttr> attr) {
610       SmallVector<int64_t> values;
611       if (attr.has_value())
612         values.assign(attr->getValues<bool>().begin(),
613                       attr->getValues<bool>().end());
614       set_attr(name, b.getI64TensorAttr(values));
615     };
616 
617     // Copy dimension number attributes.
618     call->setAttr(b.getStringAttr("conv_dims"), op.getDimensionNumbers());
619 
620     // Copy convolution window attributes.
621     set_xi1("window_reversal", op.getWindowReversal());
622     set_xi64("window_strides", op.getWindowStrides());
623     set_xi64("lhs_dilation", op.getLhsDilation());
624     set_xi64("rhs_dilation", op.getRhsDilation());
625     set_xi64("padding", op.getPadding());
626 
627     // Copy backend config.
628     call->setAttr(b.getStringAttr("backend_config"), op.getBackendConfig());
629 
630     // Copy remaining attributes.
631     set_attr("feature_group_count", op.getFeatureGroupCountAttr());
632     set_attr("result_scale", op.getResultScaleAttr());
633 
634     // Copy attributes specific for fused convolutions.
635     if (auto fused = dyn_cast<ConvForwardFusedOp>(op.getOperation())) {
636       call->setAttr(b.getStringAttr("activation_mode"),
637                     fused.getActivationModeAttr());
638     }
639 
640     // Copy attributes specific for fused convolutions with side input.
641     if (auto fused = dyn_cast<ConvForwardFusedSideInputOp>(op.getOperation())) {
642       call->setAttr(b.getStringAttr("activation_mode"),
643                     fused.getActivationModeAttr());
644       set_attr("side_input_scale", fused.getSideInputScaleAttr());
645     }
646 
647     // Erase the original conv operation.
648     rewriter.eraseOp(op);
649 
650     return success();
651   }
652 };
653 
654 class ConvForwardOpLowering : public ConvOpLowering<ConvForwardOp> {
655  public:
656   using ConvOpLowering::ConvOpLowering;
657 };
658 
659 class ConvForwardFusedOpLowering : public ConvOpLowering<ConvForwardFusedOp> {
660  public:
661   using ConvOpLowering::ConvOpLowering;
662 };
663 
664 class ConvBackwardFilterOpLowering
665     : public ConvOpLowering<ConvBackwardFilterOp> {
666  public:
667   using ConvOpLowering::ConvOpLowering;
668 };
669 
670 class ConvBackwardInputOpLowering : public ConvOpLowering<ConvBackwardInputOp> {
671  public:
672   using ConvOpLowering::ConvOpLowering;
673 };
674 
675 class ConvForwardFusedSideInputOpLowering
676     : public ConvOpLowering<ConvForwardFusedSideInputOp> {
677  public:
678   using ConvOpLowering::ConvOpLowering;
679 };
680 
681 // -------------------------------------------------------------------------- //
682 
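// Lowers `lmhlo.while` to an `scf.while` loop. The loop predicate lives in a
// device buffer, so the cloned condition region copies it into a host
// `memref.alloca` with `gpu.memcpy`, loads the value, and passes it to
// `scf.condition`.
//
// A rough sketch of the lowered condition region (hypothetical IR):
//
//   %host = memref.alloca() : memref<i1>
//   gpu.memcpy %host, %pred : memref<i1>, memref<i1>
//   %cond = memref.load %host[] : memref<i1>
//   scf.condition(%cond)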
683 class WhileOpLowering : public OpRewritePattern<WhileOp> {
684  public:
685   using OpRewritePattern::OpRewritePattern;
686 
687   LogicalResult matchAndRewrite(WhileOp op,
688                                 PatternRewriter& rewriter) const override {
689     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
690 
691     // Create an `scf.while` loop in place of `lmhlo.while` loop.
692     auto loop = b.create<scf::WhileOp>(TypeRange(), ValueRange());
693 
694     // Predicate buffer placed on the device.
695     assert(op.getNumOperands() == 1 && "expected single cond operand");
696     Value pred = op.getOperand(0);
697 
698     // Clone condition and body blocks into the new loop operation.
699     BlockAndValueMapping mapping;
700     op.getCond().cloneInto(&loop.getBefore(), mapping);
701     op.getBody().cloneInto(&loop.getAfter(), mapping);
702 
703     {  // Replace loop condition terminator.
704       auto* terminator = loop.getBefore().back().getTerminator();
705       b.setInsertionPointAfter(terminator);
706 
707       // Copy predicate buffer to the host ...
708       auto i1 = b.getI1Type();
709       Value pred_on_host = b.create<memref::AllocaOp>(MemRefType::get({}, i1));
710       b.create<gpu::MemcpyOp>(TypeRange(), ValueRange({pred_on_host, pred}));
711 
712       // .. and check if we need to continue loop iteration.
713       Value cond = b.create<memref::LoadOp>(i1, pred_on_host, ValueRange());
714       b.create<scf::ConditionOp>(cond, ValueRange());
715       rewriter.eraseOp(terminator);
716     }
717 
718     {  // Replace loop body terminator.
719       auto* terminator = loop.getAfter().back().getTerminator();
720       b.setInsertionPointAfter(terminator);
721       b.create<scf::YieldOp>(TypeRange(), ValueRange());
722       rewriter.eraseOp(terminator);
723     }
724 
725     // Erase the original while loop.
726     rewriter.eraseOp(op);
727 
728     return success();
729   }
730 };
731 
732 // -------------------------------------------------------------------------- //
733 
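// Lowers `lmhlo.case` to a `cf.switch` on the host. The case index lives in a
// device buffer, so it is first copied into a host `memref.alloca` and loaded.
// An i1 (predicate) index is converted to an i32 index, an out-of-range
// integer index is clamped to the last case, and each case region is inlined
// into its own block that branches into the continuation block.
//
// A rough sketch of the result (hypothetical IR):
//
//   cf.switch %index : i32, [
//     default: ^cont,
//     0: ^case0,
//     1: ^case1
//   ]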
734 class CaseOpLowering : public OpRewritePattern<CaseOp> {
735  public:
736   using OpRewritePattern::OpRewritePattern;
737 
738   LogicalResult matchAndRewrite(CaseOp op,
739                                 PatternRewriter& rewriter) const override {
740     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
741 
742     // Copy index buffer to the host ...
743     auto index_type = op.getIndex().getType().dyn_cast<MemRefType>();
744     Value index_on_host = b.create<memref::AllocaOp>(index_type);
745     b.create<gpu::MemcpyOp>(TypeRange(),
746                             ValueRange({index_on_host, op.getIndex()}));
747 
748     // Get the index value from the buffer.
749     Value index = b.create<memref::LoadOp>(index_type.getElementType(),
750                                            index_on_host, ValueRange());
751 
752     bool is_predicate = index_type.getElementType().isInteger(1);
753 
754     // For binary index (predicate) convert i1 to i32 index.
755     if (is_predicate) {
756       Value c0 = b.create<ConstantOp>(b.getI32IntegerAttr(0));
757       Value c1 = b.create<ConstantOp>(b.getI32IntegerAttr(1));
758       index = b.create<arith::SelectOp>(index, c0, c1);
759     }
760 
761     // For integer index make sure that it is within range.
762     if (!is_predicate) {
763       unsigned n = op.getNumRegions() - 1;
764       Value c0 = b.create<ConstantOp>(b.getI32IntegerAttr(0));
765       Value cN = b.create<ConstantOp>(b.getI32IntegerAttr(n));
766 
767       Value too_small = b.create<arith::CmpIOp>(
768           b.getI1Type(), arith::CmpIPredicate::slt, index, c0);
769       Value too_large = b.create<arith::CmpIOp>(
770           b.getI1Type(), arith::CmpIPredicate::sgt, index, cN);
771 
772       Value out_of_range = b.create<arith::OrIOp>(too_small, too_large);
773       index = b.create<arith::SelectOp>(out_of_range, cN, index);
774     }
775 
776     // Split block right at the case operation.
777     Block* cont = rewriter.splitBlock(op->getBlock(), op->getIterator());
778     Block* orig = cont->getPrevNode();
779 
780     // Prepare case destinations for the `scf.switch` operation.
781     llvm::SmallVector<llvm::APInt> case_values;
782     llvm::SmallVector<Block*> case_blocks;
783     llvm::SmallVector<ValueRange> case_operands;
784 
785     // Create blocks from each of the case regions.
786     for (Region& region : op->getRegions()) {
787       // Move `lmhlo.case` block before the continuation.
788       Block& block = region.front();
789       block.moveBefore(cont);
790 
791       // Erase original `lmhlo.terminator`.
792       rewriter.eraseOp(block.getTerminator());
793 
794       // Branch into the continuation block.
795       b.setInsertionPointToEnd(&block);
796       b.create<cf::BranchOp>(cont);
797 
798       // Add a `cf.switch` case.
799       int32_t idx = case_blocks.size();
800       case_values.push_back(b.getI32IntegerAttr(idx).getValue());
801       case_blocks.push_back(&block);
802       case_operands.push_back({});
803     }
804 
805     // Replace `lmhlo.case` with a `cf.switch` operation on the host.
806     b.setInsertionPointToEnd(orig);
807     b.create<cf::SwitchOp>(index, cont, ValueRange(), case_values, case_blocks,
808                            case_operands);
809 
810     // Erase the original case operation.
811     rewriter.eraseOp(op);
812 
813     return success();
814   }
815 };
816 
817 // -------------------------------------------------------------------------- //
818 
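// Lowers `lmhlo.custom_call` to the `xla.gpu.custom_call` runtime intrinsic.
// If the operation carries a target arguments mapping, the "holes" (target
// arguments and results with no corresponding operand) are filled with empty
// `memref<0xi8>` allocas so that the call still passes one value per target
// argument and result.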
819 class CustomCallOpLowering : public OpRewritePattern<CustomCallOp> {
820  private:
821   static constexpr const char kCustomCallTarget[] = "xla.gpu.custom_call";
822 
823  public:
824   using OpRewritePattern::OpRewritePattern;
825 
826   LogicalResult matchAndRewrite(CustomCallOp op,
827                                 PatternRewriter& rewriter) const override {
828     MLIRContext* ctx = this->getContext();
829     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
830 
831     // Custom call target.
832     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
833                           b.getStringAttr(kCustomCallTarget));
834 
835     // By default all operands are passed to the custom call handler.
836     llvm::SmallVector<Value> operands = op.getOperands();
837 
838     // If the custom call has a target arguments mapping, then we need to pass
839     // empty memrefs in place of holes.
840     if (op.getTargetArgMapping().has_value()) {
841       auto mapping = *op.getTargetArgMapping();
842       int64_t num_args = mapping.getNumArgs();
843       int64_t num_results = mapping.getNumResults();
844 
845       // We represent holes as empty i8 memrefs.
846       Value hole = b.create<AllocaOp>(MemRefType::get({0}, b.getI8Type()));
847       operands = llvm::SmallVector<Value>(num_args + num_results, hole);
848 
849       // Update operands to mapped custom call arguments.
850       auto args = mapping.getArgsToTargetArgs();
851       for (const auto& indexed : llvm::enumerate(args))
852         operands[indexed.value()] = op.getArgs()[indexed.index()];
853 
854       // Update operands to mapped custom call results.
855       auto res = mapping.getResultsToTargetResults();
856       for (const auto& indexed : llvm::enumerate(res))
857         operands[num_args + indexed.value()] = op.getOutput()[indexed.index()];
858     }
859 
860     // Create a custom call function declaration.
861     auto custom_call_type =
862         FunctionType::get(ctx, TypeRange(ValueRange(operands)), TypeRange());
863 
864     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
865     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
866                                       custom_call_type, custom_call_attrs);
867     custom_call.setPrivate();
868 
869     SymbolTable sym_table(op->getParentOfType<ModuleOp>());
870     auto inserted = sym_table.insert(custom_call);
871     rewriter.notifyOperationInserted(custom_call);
872 
873     // Call the runtime intrinsic with the original operands.
874     auto call =
875         rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(), operands);
876 
877     // Pass attributes to the custom call handler.
878     auto set_attr = [&](StringRef name, Attribute attr) {
879       call->setAttr(b.getStringAttr(name), attr);
880     };
881 
882     set_attr("api_version", op.getApiVersionAttr());
883     set_attr("backend_config", op.getBackendConfigAttr());
884     set_attr("call_target_name", op.getCallTargetNameAttr());
885 
886     // Erase the original custom call operation.
887     rewriter.eraseOp(op);
888 
889     return success();
890   }
891 };
892 
893 // -------------------------------------------------------------------------- //
894 
895 using GlobalConstantsArgs = llvm::DenseMap<FuncOp, llvm::StringMap<Value>>;
896 
897 // Returns a mapping from a global constant name to the function argument.
898 //
899 // Example:
900 //
901 //   memref.global "private" constant @cst : memref<2x3xf32>
902 //   func @get_global(%arg0: memref<24xi8> {lmhlo.constant_name = "cst"})
903 //
904 // All memref.get_global operations will be replaced by constant arguments
905 // corresponding to the global constant.
906 GlobalConstantsArgs GetConstantArgs(ModuleOp m) {
907   GlobalConstantsArgs mapping;
908 
909   m.walk([&](FuncOp func) {
910     for (unsigned i = 0; i < func.getNumArguments(); ++i) {
911       auto cst = func.getArgAttrOfType<StringAttr>(i, "lmhlo.constant_name");
912       if (cst) mapping[func][cst] = func.getArgument(i);
913     }
914   });
915 
916   return mapping;
917 }
918 
919 class GetGlobalOpLowering : public OpRewritePattern<GetGlobalOp> {
920  public:
921   GetGlobalOpLowering(MLIRContext* ctx, const GlobalConstantsArgs& cst_args)
922       : OpRewritePattern<GetGlobalOp>(ctx), cst_args_(cst_args) {}
923 
924   LogicalResult matchAndRewrite(GetGlobalOp op,
925                                 PatternRewriter& rewriter) const override {
926     // Find global constants mapping for the parent function.
927     auto func_mapping = cst_args_.find(op->getParentOfType<FuncOp>());
928     if (func_mapping == cst_args_.end()) return failure();
929 
930     // Check if the global operation corresponds to the LMHLO constant arg.
931     auto arg = func_mapping->second.find(op.getName());
932     if (arg == func_mapping->second.end()) return failure();
933 
934     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
935     MemRefType memref = op->getResult(0).getType().cast<MemRefType>();
936 
937     // For identity layouts we can replace all loads from a global with the
938     // corresponding argument.
939     if (memref.getLayout().isIdentity()) {
940       Value c0 = b.create<ConstantOp>(rewriter.getIndexAttr(0));
941       rewriter.replaceOpWithNewOp<memref::ViewOp>(op, memref, arg->second, c0,
942                                                   ValueRange());
943       return success();
944     }
945 
946     // For a non-identity layout we first view the constant argument as a flat
947     // memref with the correct element type, and then cast it to the strided
948     // memref corresponding to the original memref layout.
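    //
    // A rough sketch of the result (hypothetical IR, layout details elided):
    //
    //   %c0 = arith.constant 0 : index
    //   %flat = memref.view %arg[%c0][] : memref<24xi8> to memref<6xf32>
    //   %cast = memref.reinterpret_cast %flat to
    //       offset: [0], sizes: [2, 3], strides: [1, 2]
    //       : memref<6xf32> to memref<2x3xf32, #transposed_layout>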
949 
950     // Get the strides and offset from the original memref type.
951     int64_t offset;
952     llvm::SmallVector<int64_t> strides;
953     if (failed(getStridesAndOffset(memref, strides, offset)))
954       return op.emitOpError("failed to compute strides and offset");
955 
956     // Create a 1d view into the corresponding argument.
957     Value c0 = b.create<ConstantOp>(rewriter.getIndexAttr(0));
958     Value flat_view = b.create<memref::ViewOp>(
959         MemRefType::get({memref.getNumElements()}, memref.getElementType()),
960         arg->second, c0, ValueRange());
961 
962     // Cast flat memref view into the original memref type.
963     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
964         op, memref, flat_view, offset, memref.getShape(), strides);
965 
966     return success();
967   }
968 
969  private:
970   const GlobalConstantsArgs& cst_args_;
971 };
972 
973 // -------------------------------------------------------------------------- //
974 
975 class FftOpLowering : public OpRewritePattern<FftOp> {
976  private:
977   static constexpr const char kCustomCallTarget[] = "xla.gpu.fft";
978 
979  public:
980   using OpRewritePattern::OpRewritePattern;
981 
982   LogicalResult matchAndRewrite(FftOp op,
983                                 PatternRewriter& rewriter) const override {
984     MLIRContext* ctx = this->getContext();
985     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
986 
987     ModuleOp module = op->getParentOfType<ModuleOp>();
988 
989     // Custom call target.
990     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
991                           b.getStringAttr(kCustomCallTarget));
992 
993     // Create a custom call function declaration.
994     auto custom_call_type =
995         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
996     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
997     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
998                                       custom_call_type, custom_call_attrs);
999     custom_call.setPrivate();
1000 
1001     SymbolTable sym_table(module);
1002     auto inserted = sym_table.insert(custom_call);
1003     rewriter.notifyOperationInserted(custom_call);
1004 
1005     // Convert Fft to a function call.
1006     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1007                                         op.getOperands());
1008 
1009     // Copy backend specific attributes.
1010     call->setAttr(b.getStringAttr("fft_length"), op.getFftLengthAttr());
1011     call->setAttr(b.getStringAttr("fft_type"), op.getFftTypeAttr());
1012 
1013     // Erase the original Fft operation.
1014     rewriter.eraseOp(op);
1015 
1016     return success();
1017   }
1018 };
1019 
1020 // -------------------------------------------------------------------------- //
1021 
1022 class CholeskyOpLowering : public OpRewritePattern<CholeskyOp> {
1023  private:
1024   static constexpr const char kCustomCallTarget[] = "xla.gpu.cholesky";
1025 
1026  public:
1027   explicit CholeskyOpLowering(MLIRContext* ctx)
1028       : OpRewritePattern<CholeskyOp>(ctx) {}
1029 
1030   LogicalResult matchAndRewrite(CholeskyOp op,
1031                                 PatternRewriter& rewriter) const override {
1032     MLIRContext* ctx = this->getContext();
1033     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1034 
1035     ModuleOp module = op->getParentOfType<ModuleOp>();
1036 
1037     // Custom call target.
1038     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
1039                           b.getStringAttr(kCustomCallTarget));
1040 
1041     // Create a custom call function declaration.
1042     auto custom_call_type =
1043         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
1044     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
1045     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
1046                                       custom_call_type, custom_call_attrs);
1047     custom_call.setPrivate();
1048 
1049     SymbolTable sym_table(module);
1050     auto inserted = sym_table.insert(custom_call);
1051     rewriter.notifyOperationInserted(custom_call);
1052 
1053     // Convert Cholesky to a function call.
1054     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1055                                         op.getOperands());
1056 
1057     const auto& dims =
1058         op.getInput().getType().cast<mlir::MemRefType>().getShape();
1059     if (dims.size() < 2)
1060       return op.emitOpError() << "Input's dimension count (" << dims.size()
1061                               << ") must be 2 or greater.";
1062     int64_t n = dims[dims.size() - 1];
1063     int64_t batch_size =
1064         std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
1065                         [](int64_t a, int64_t b) { return a * b; });
1066 
1067     // Copy backend specific attributes.
1068     call->setAttr(b.getStringAttr("batch_size"),
1069                   b.getI64IntegerAttr(batch_size));
1070     call->setAttr(b.getStringAttr("n"), b.getI64IntegerAttr(n));
1071     call->setAttr(b.getStringAttr("is_lower"), op.getIsLowerAttr());
1072 
1073     // Erase the original Cholesky operation.
1074     rewriter.eraseOp(op);
1075 
1076     return success();
1077   }
1078 };
1079 
1080 // -------------------------------------------------------------------------- //
1081 
1082 // We assign a unique id to every collective operation in the module, so that
1083 // we can efficiently access per-op state at run time. The exception to this
1084 // rule is asynchronous collective operations, where the corresponding pair of
1085 // `start` and `done` operations shares the same unique id.
1086 //
1087 // Asynchronous collective operations pass HLO Token to represent the dependency
1088 // between the `Start` and `Done` operations. When we lower to JitRt custom
1089 // calls we rely on assigning each unique pair of `Start` and `Done` operations
1090 // a unique event id, and use shared "context" owned by the GpuExecutable to
1091 // pass Gpu events from `Start` to `Done` custom call handlers.
1092 //
1093 // TODO(ezhulenev): Once JitRt custom calls support returning values, we should
1094 // explicitly return event id from the `Start` custom call, and pass it to the
1095 // `Done` custom call. Longer term this should become an `!async.token` and rely
1096 // on JitRt asynchronous execution.
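//
// A usage sketch (illustrative, not from the original source):
//
//   CollectiveUidGenerator collective_uid;
//   collective_uid.AssignUid(start, done);  // `start` and `done` share one id
//   FailureOr<int32_t> uid = collective_uid.AssignedUid(done);  // succeeds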
1097 class CollectiveUidGenerator {
1098  public:
1099   CollectiveUidGenerator() : cnt_(0) {}
1100 
1101   // Assigns a unique event id to the pair of start and done operations.
1102   int32_t AssignUid(AllReduceStartOp start, AllReduceDoneOp done) {
1103     int32_t id = next();
1104     uids_[start] = id;
1105     uids_[done] = id;
1106     return id;
1107   }
1108 
1109   FailureOr<int32_t> AssignedUid(Operation* op) {
1110     // Async operations must be assigned a uid ahead of time.
1111     if (isa<AllReduceStartOp, AllReduceDoneOp>(op)) {
1112       auto it = uids_.find(op);
1113       if (it == uids_.end()) return failure();
1114       return it->second;
1115     }
1116     // For every other operation we just assign the next id.
1117     return next();
1118   }
1119 
1120  private:
1121   int32_t next() { return cnt_++; }
1122 
1123   int32_t cnt_;
1124   llvm::DenseMap<Operation*, int32_t> uids_;
1125 };
1126 
1127 template <typename CollectiveOp>
1128 class CollectiveOpLowering : public OpRewritePattern<CollectiveOp> {
1129  private:
1130   static StringRef CustomCallTarget(AllGatherOp) {
1131     return "xla.gpu.all_gather";
1132   }
1133   static StringRef CustomCallTarget(AllReduceOp) {
1134     return "xla.gpu.all_reduce";
1135   }
1136   static StringRef CustomCallTarget(AllReduceStartOp) {
1137     return "xla.gpu.all_reduce_start";
1138   }
1139   static StringRef CustomCallTarget(ReduceScatterOp) {
1140     return "xla.gpu.reduce_scatter";
1141   }
1142   static StringRef CustomCallTarget(AllToAllOp) { return "xla.gpu.all_to_all"; }
1143   static StringRef CustomCallTarget(CollectivePermuteOp) {
1144     return "xla.gpu.collective_permute";
1145   }
1146 
1147   template <typename ReduceOrGatherOp>
1148   static xla::gpu::NcclCollectiveConfig GetNcclCollectiveConfig(
1149       ReduceOrGatherOp op, int /*replica_count*/, int /*num_partitions*/) {
1150     return xla::gpu::GetNcclCollectiveConfigForMlir(op,
1151                                                     op.getUseGlobalDeviceIds());
1152   }
1153   static xla::gpu::NcclCollectiveConfig GetNcclCollectiveConfig(
1154       AllToAllOp op, int /*replica_count*/, int /*num_partitions*/) {
1155     // TODO(b/180174349): LMHLO AllToAll incorrectly has use_global_device_ids
1156     // attribute and it should be removed.
1157     return xla::gpu::GetNcclCollectiveConfigForMlir(op, std::nullopt);
1158   }
1159   static xla::gpu::NcclCollectiveConfig GetNcclCollectiveConfig(
1160       CollectivePermuteOp op, int replica_count, int num_partitions) {
1161     return xla::gpu::NcclCollectivePermuteThunk::GetNcclCollectivePermuteConfig(
1162                op, replica_count, num_partitions)
1163         .config;
1164   }
1165 
1166   template <typename NonCollectivePermuteOp>
1167   static LogicalResult TryDegenerateToMemCopy(
1168       NonCollectivePermuteOp op, const xla::gpu::NcclCollectiveConfig& config,
1169       int replica_count, int num_partitions, PatternRewriter& rewriter) {
1170     if (!config.IsDegenerate(replica_count, num_partitions)) {
1171       return failure();
1172     }
1173 
1174     for (int64_t i = 0; i < op.getInputs().size(); i++) {
1175       rewriter.create<gpu::MemcpyOp>(
1176           op.getLoc(), TypeRange(),
1177           ValueRange({op.getOutputs()[i], op.getOperands()[i]}));
1178     }
1179     return success();
1180   }
1181   static LogicalResult TryDegenerateToMemCopy(
1182       CollectivePermuteOp op, const xla::gpu::NcclCollectiveConfig& config,
1183       int replica_count, int num_partitions, PatternRewriter& rewriter) {
1184     if (!xla::gpu::NcclCollectivePermuteThunk::IsDegenerate(op, replica_count,
1185                                                             num_partitions)) {
1186       return failure();
1187     }
1188 
1189     rewriter.create<gpu::MemcpyOp>(
1190         op.getLoc(), TypeRange(),
1191         ValueRange({op.getOutput(), op.getOperand()}));
1192     return success();
1193   }
1194 
1195   static bool CanImplement(AllGatherOp op) {
1196     return xla::gpu::NcclAllGatherThunk::CanImplement(op);
1197   }
1198   static bool CanImplement(AllReduceOp op) {
1199     return xla::gpu::NcclAllReduceThunk::CanImplement(op);
1200   }
1201   static bool CanImplement(AllReduceStartOp op) {
1202     return xla::gpu::NcclAllReduceStartThunk::CanImplement(op);
1203   }
1204   static bool CanImplement(ReduceScatterOp op) {
1205     return xla::gpu::NcclReduceScatterThunk::CanImplement(op);
1206   }
1207   static bool CanImplement(AllToAllOp op) {
1208     return xla::gpu::NcclAllToAllThunk::CanImplement(op);
1209   }
1210   static bool CanImplement(CollectivePermuteOp op) {
1211     return xla::gpu::NcclCollectivePermuteThunk::CanImplement(op);
1212   }
1213 
1214   template <typename ReduceOp>
1215   static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, ReduceOp op,
1216                                         CallOp call) {
1217     std::optional<xla::ReductionKind> reduction_kind =
1218         xla::gpu::NcclAllReduceThunkBase::MatchAllReduceComputation(
1219             op.getComputation());
1220     if (!reduction_kind.has_value())
1221       return op.emitOpError()
1222              << "Failed to determine reduction computation for AllReduce";
1223 
1224     call->setAttr(
1225         b.getStringAttr("reduction_kind"),
1226         b.getI64IntegerAttr(static_cast<int64_t>(reduction_kind.value())));
1227     return success();
1228   }
1229   static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, AllGatherOp op,
1230                                         CallOp call) {
1231     return success();
1232   }
1233   static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b, AllToAllOp op,
1234                                         CallOp call) {
1235     call->setAttr(b.getStringAttr("has_split_dimension"),
1236                   b.getBoolAttr(op.getSplitDimension().has_value()));
1237     return success();
1238   }
1239   static LogicalResult SetSpecificAttrs(ImplicitLocOpBuilder& b,
1240                                         CollectivePermuteOp op, CallOp call) {
1241     auto source_target_pairs_or =
1242         xla::ConvertNx2Attribute(op.getSourceTargetPairs());
1243     if (!source_target_pairs_or.ok()) {
1244       return op.emitOpError()
1245              << source_target_pairs_or.status().error_message();
1246     }
1247 
1248     // Pass an array of pairs as two vectors.
1249     std::vector<std::pair<int64_t, int64_t>> source_target_pairs =
1250         std::move(source_target_pairs_or.value());
1251     std::vector<int64_t> source_peers, target_peers;
1252     source_peers.reserve(source_target_pairs.size());
1253     target_peers.reserve(source_target_pairs.size());
1254     for (const auto& source_target_pair : source_target_pairs) {
1255       source_peers.push_back(source_target_pair.first);
1256       target_peers.push_back(source_target_pair.second);
1257     }
1258 
1259     auto source_peers_attr = b.getI64TensorAttr(source_peers);
1260     auto target_peers_attr = b.getI64TensorAttr(target_peers);
1261     call->setAttr(b.getStringAttr("source_peers"), source_peers_attr);
1262     call->setAttr(b.getStringAttr("target_peers"), target_peers_attr);
1263     return success();
1264   }
1265 
1266  public:
1267   explicit CollectiveOpLowering(MLIRContext* ctx, CollectiveUidGenerator& uid)
1268       : OpRewritePattern<CollectiveOp>(ctx), uid_(uid) {}
1269 
1270   LogicalResult matchAndRewrite(CollectiveOp op,
1271                                 PatternRewriter& rewriter) const override {
1272     MLIRContext* ctx = this->getContext();
1273     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1274 
1275     ModuleOp module = op->template getParentOfType<ModuleOp>();
1276 
1277     // Construct an NCCL collective config from the parent func attributes.
1278     FuncOp fn = op->template getParentOfType<FuncOp>();
1279     auto replica_count_attr = fn->getAttrOfType<IntegerAttr>("replica_count");
1280     auto num_partitions_attr = fn->getAttrOfType<IntegerAttr>("num_partitions");
1281     const int64_t replica_count = replica_count_attr.getInt();
1282     const int64_t num_partitions = num_partitions_attr.getInt();
1283 
1284     xla::gpu::NcclCollectiveConfig config =
1285         GetNcclCollectiveConfig(op, replica_count, num_partitions);
1286 
1287     // A given collective op is degenerate if all of the groups it forms
1288     // are singletons. In that case no communication is needed and we can
1289     // simply copy the input to the output.
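    // For example (illustrative): a single-replica, single-partition module
    // forms only singleton groups, so the collective reduces to a memcpy.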
1290     if (succeeded(TryDegenerateToMemCopy(op, config, replica_count,
1291                                          num_partitions, rewriter))) {
1292       // For async collectives, erase all corresponding `done` operations.
1293       if (auto start = dyn_cast<AllReduceStartOp>(op.getOperation())) {
1294         auto users = llvm::to_vector(start.getToken().getUsers());
1295         llvm::for_each(users, [&](Operation* user) {
1296           if (isa<AllReduceDoneOp>(user)) rewriter.eraseOp(user);
1297         });
1298       }
1299 
1300       // Erase the original collective operation.
1301       rewriter.eraseOp(op);
1302 
1303       return success();
1304     }
1305 
1306     // Custom call target.
1307     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
1308                           b.getStringAttr(CustomCallTarget(op)));
1309 
1310     // Create a custom call function declaration.
1311     auto custom_call_type =
1312         FunctionType::get(ctx, op.getOperandTypes(), TypeRange());
1313     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
1314     auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
1315                                       custom_call_type, custom_call_attrs);
1316     custom_call.setPrivate();
1317 
1318     SymbolTable sym_table(module);
1319     auto inserted = sym_table.insert(custom_call);
1320     rewriter.notifyOperationInserted(custom_call);
1321 
1322     // Convert collective op to a function call.
1323     auto call = rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1324                                         op.getOperands());
1325 
1326     if (!CanImplement(op)) {
1327       return op.emitOpError()
1328              << "Requested " << CustomCallTarget(op)
1329              << " not implemented on GPU; replica_count: " << replica_count
1330              << ", num_partitions: " << num_partitions << ", group_mode: "
1331              << CollectiveOpGroupModeToString(config.group_mode)
1332              << ", NCCL support: "
1333              << xla::gpu::NcclCollectiveThunk::NcclIsEnabled();
1334     }
1335 
1336     // Copy backend specific attributes.
1337     call->setAttr(b.getStringAttr("group_mode"),
1338                   b.getI64IntegerAttr(static_cast<int64_t>(config.group_mode)));
1339     call->setAttr(b.getStringAttr("op_id"), b.getI64IntegerAttr(config.op_id));
1340     // TODO(b/233930690): Pass the attribute below as a nested array.
1341     // Pass an array of arrays using two vectors; one specifying all the values
1342     // and another specifying the (ending) offsets of each array in the other
1343     // vector. Example: [ [10, 20, 30, 40], [50, 60], [70, 80, 90] ] turns into
1344     // offsets=[4, 6, 9] values=[10, 20, 30, 40, 50, 60, 70, 80, 90].
1345     std::vector<int64_t> replica_group_offsets;
1346     std::vector<int64_t> replica_group_values;
1347     replica_group_offsets.reserve(config.replica_groups.size());
1348     int replica_group_offset = 0;
1349     for (const auto& replica_group : config.replica_groups) {
1350       replica_group_offset += replica_group.replica_ids_size();
1351       replica_group_offsets.push_back(replica_group_offset);
1352       replica_group_values.reserve(replica_group_offset);
1353       for (auto replica_id : replica_group.replica_ids()) {
1354         replica_group_values.push_back(replica_id);
1355       }
1356     }
1357     call->setAttr(b.getStringAttr("replica_group_offsets"),
1358                   b.getI64TensorAttr(replica_group_offsets));
1359     call->setAttr(b.getStringAttr("replica_group_values"),
1360                   b.getI64TensorAttr(replica_group_values));
1361 
1362     // Assign a unique collective operation id.
1363     auto uid = uid_.AssignedUid(op);
1364     if (succeeded(uid)) {
1365       call->setAttr(b.getStringAttr("uid"), b.getI32IntegerAttr(*uid));
1366     } else {
1367       return op.emitOpError("failed to get a unique collective operation id");
1368     }
1369 
1370     // Set attributes specific to the type of collective operation.
1371     auto result = SetSpecificAttrs(b, op, call);
1372     if (failed(result)) return result;
1373 
1374     // For the asynchronous start operation we need to produce a fake token
1375     // that will be removed later, because the corresponding `done` operation
1376     // does not take a token argument. We rely on `unrealized_conversion_cast`
1377     // to create the fake token from an `i8` constant.
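    // Illustrative shape of the replacement IR (the token type depends on the
    // op being lowered):
    //   %c0    = constant 0 : i8
    //   %token = unrealized_conversion_cast %c0 : i8 to <token type>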
1378     if (auto start = dyn_cast<AllReduceStartOp>(op.getOperation())) {
1379       Value token = start.getToken();
1380       Value c0 = b.create<ConstantOp>(b.getI8IntegerAttr(0));
1381       auto fake = b.create<UnrealizedConversionCastOp>(token.getType(), c0);
1382       token.replaceAllUsesWith(fake.getResult(0));
1383     }
1384 
1385     // Erase the original collective operation.
1386     rewriter.eraseOp(op);
1387 
1388     return success();
1389   }
1390 
1391  private:
1392   CollectiveUidGenerator& uid_;
1393 };
1394 
1395 class AllGatherOpLowering : public CollectiveOpLowering<AllGatherOp> {
1396  public:
1397   using CollectiveOpLowering::CollectiveOpLowering;
1398 };
1399 
1400 class AllReduceOpLowering : public CollectiveOpLowering<AllReduceOp> {
1401  public:
1402   using CollectiveOpLowering::CollectiveOpLowering;
1403 };
1404 
1405 class AllReduceStartOpLowering : public CollectiveOpLowering<AllReduceStartOp> {
1406  public:
1407   using CollectiveOpLowering::CollectiveOpLowering;
1408 };
1409 
1410 class ReduceScatterOpLowering : public CollectiveOpLowering<ReduceScatterOp> {
1411  public:
1412   using CollectiveOpLowering::CollectiveOpLowering;
1413 };
1414 
1415 class AllToAllOpLowering : public CollectiveOpLowering<AllToAllOp> {
1416  public:
1417   using CollectiveOpLowering::CollectiveOpLowering;
1418 };
1419 
1420 class CollectivePermuteOpLowering
1421     : public CollectiveOpLowering<CollectivePermuteOp> {
1422  public:
1423   using CollectiveOpLowering::CollectiveOpLowering;
1424 };
1425 
1426 class AllReduceDoneOpLowering : public OpRewritePattern<AllReduceDoneOp> {
1427  private:
1428   static constexpr const char kCustomCallTarget[] = "xla.gpu.all_reduce_done";
1429 
1430  public:
1431   explicit AllReduceDoneOpLowering(MLIRContext* ctx,
1432                                    CollectiveUidGenerator& uid)
1433       : OpRewritePattern<AllReduceDoneOp>(ctx), uid_(uid) {}
1434 
1435   LogicalResult matchAndRewrite(AllReduceDoneOp op,
1436                                 PatternRewriter& rewriter) const override {
1437     MLIRContext* ctx = this->getContext();
1438     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1439 
1440     ModuleOp module = op->getParentOfType<ModuleOp>();
1441 
1442     // Custom call target.
1443     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
1444                           b.getStringAttr(kCustomCallTarget));
1445 
1446     // For the done operation we drop the token argument and communicate the
1447     // async event dependency through the `uid` attribute.
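    // The token is expected to be the first operand; the remaining operands
    // (the buffers) are forwarded to the custom call.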
1448     llvm::SmallVector<Value> operands = op.getOperands().drop_front();
1449 
1450     // Create a custom call function declaration.
1451     auto custom_call_type =
1452         FunctionType::get(ctx, TypeRange(ValueRange(operands)), TypeRange());
1453     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
1454     auto custom_call = FuncOp::create(op.getLoc(), kCustomCallTarget,
1455                                       custom_call_type, custom_call_attrs);
1456     custom_call.setPrivate();
1457 
1458     SymbolTable sym_table(module);
1459     auto inserted = sym_table.insert(custom_call);
1460     rewriter.notifyOperationInserted(custom_call);
1461 
1462     // Convert AllReduceDone to a function call.
1463     auto call =
1464         rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(), operands);
1465 
1466     // Assign a unique collective operation id.
1467     auto uid = uid_.AssignedUid(op);
1468     if (succeeded(uid)) {
1469       call->setAttr(b.getStringAttr("uid"), b.getI32IntegerAttr(*uid));
1470     } else {
1471       return op.emitOpError("failed to get a unique collective operation id");
1472     }
1473 
1474     // Erase the original AllReduceDone operation.
1475     rewriter.eraseOp(op);
1476 
1477     return success();
1478   }
1479 
1480  private:
1481   CollectiveUidGenerator& uid_;
1482 };
1483 
1484 // -------------------------------------------------------------------------- //
1485 
1486 template <typename IdOp>
1487 class IdOpLowering : public OpRewritePattern<IdOp> {
1488  private:
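  // The custom call target is selected by overload resolution on the concrete
  // id op type the template is instantiated with.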
1489   static StringRef CustomCallTarget(ReplicaIdOp) {
1490     return "xla.gpu.replica_id";
1491   }
1492   static StringRef CustomCallTarget(PartitionIdOp) {
1493     return "xla.gpu.partition_id";
1494   }
1495 
1496  public:
1497   explicit IdOpLowering(MLIRContext* ctx) : OpRewritePattern<IdOp>(ctx) {}
1498 
1499   LogicalResult matchAndRewrite(IdOp op,
1500                                 PatternRewriter& rewriter) const override {
1501     MLIRContext* ctx = this->getContext();
1502     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
1503 
1504     ModuleOp module = op->template getParentOfType<ModuleOp>();
1505 
1506     // Custom call target.
1507     NamedAttribute target(b.getStringAttr(kDirectCustomCall),
1508                           b.getStringAttr(CustomCallTarget(op)));
1509 
1510     // Create a custom call function declaration.
1511     auto custom_call_type =
1512         FunctionType::get(ctx, op->getOperandTypes(), TypeRange());
1513     auto custom_call_attrs = ArrayRef<NamedAttribute>(target);
1514     auto custom_call = FuncOp::create(op.getLoc(), CustomCallTarget(op),
1515                                       custom_call_type, custom_call_attrs);
1516     custom_call.setPrivate();
1517 
1518     SymbolTable sym_table(module);
1519     auto inserted = sym_table.insert(custom_call);
1520     rewriter.notifyOperationInserted(custom_call);
1521 
1522     // Convert the id operation to a function call.
1523     rewriter.create<CallOp>(op.getLoc(), inserted, TypeRange(),
1524                             op->getOperands());
1525 
1526     // Erase the original id operation.
1527     rewriter.eraseOp(op);
1528 
1529     return success();
1530   }
1531 };
1532 
1533 class ReplicaIdOpLowering : public IdOpLowering<ReplicaIdOp> {
1534  public:
1535   using IdOpLowering::IdOpLowering;
1536 };
1537 
1538 class PartitionIdOpLowering : public IdOpLowering<PartitionIdOp> {
1539  public:
1540   using IdOpLowering::IdOpLowering;
1541 };
1542 
1543 // -------------------------------------------------------------------------- //
1544 
1545 void ConvertLmhloConstantToArgPass::runOnOperation() {
1546   ModuleOp module = getOperation();
1547   MLIRContext* ctx = module.getContext();
1548 
1549   // Replace memref loads from globals corresponding to the constant arguments.
1550   RewritePatternSet patterns(ctx);
1551   GlobalConstantsArgs cst_args = GetConstantArgs(module);
1552   patterns.insert<GetGlobalOpLowering>(ctx, cst_args);
1553 
1554   // Set up a conversion target that rewrites only GetGlobalOps whose memref
1555   // type has at least `min_num_elements_` elements, and avoids any other
1556   // canonicalizations that can break later passes.
1557   ConversionTarget target(*ctx);
1558   target.addDynamicallyLegalOp<GetGlobalOp>([&](GetGlobalOp op) {
1559     auto memref = op.getType();
1560     return memref.getNumElements() < min_num_elements_;
1561   });
1562   target.addLegalOp<ConstantOp, memref::ViewOp, memref::ReinterpretCastOp>();
1563 
1564   // TODO(ezhulenev): By adding MHLO and LMHLO to a set of legal dialects, we
1565   // suppress any rewrites for these dialects (there are canonicalization
1566   // patterns that interact badly with downstream Gpu binary code generation).
1567   target.addLegalDialect<mhlo::MhloDialect, lmhlo::LmhloDialect>();
1568 
1569   if (failed(applyPartialConversion(module, target, std::move(patterns))))
1570     signalPassFailure();
1571 }
1572 
1573 void ConvertGpuToJitRtPass::runOnOperation() {
1574   ModuleOp module = getOperation();
1575   MLIRContext* ctx = module.getContext();
1576 
1577   // Convert gpu operations to JitRt gpu runtime custom calls.
1578   RewritePatternSet patterns(ctx);
1579   patterns.insert<GpuModuleOpLowering, LaunchFuncOpLowering, MemcpyOpLowering,
1580                   MemsetOpLowering, InfeedOpLowering, OutfeedOpLowering>(ctx);
1581 
1582   if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
1583     return signalPassFailure();
1584 }
1585 
1586 void ConvertLmhloGpuToJitRtPass::runOnOperation() {
1587   ModuleOp module = getOperation();
1588   MLIRContext* ctx = module.getContext();
1589 
1590   // Convert lmhlo_gpu operations to JitRt gpu runtime custom calls.
1591   RewritePatternSet patterns(ctx);
1592 
1593   // Each unique Gemm/Matmul operation in the module will get assigned a uid.
1594   GemmUidGenerator gemm_uid;
1595   patterns.insert<GemmOpLowering, CublasLtMatmulOpLowering>(ctx, gemm_uid);
1596 
1597   // Assign a shared unique id to each pair of async start-done operations;
1598   // every other collective operation gets its own uid.
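  // The matching `done` lowering below reads the same uid, which is how the
  // async event dependency is communicated to the runtime.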
1599   CollectiveUidGenerator collective_uid;
1600   auto walked = module.walk([&](AllReduceStartOp start) -> WalkResult {
1601     Value token = start.getToken();
1602 
1603     // We expect the token to be consumed just once.
1604     if (!token.hasOneUse()) return start.emitOpError("token has multiple uses");
1605 
1606     // Token must be consumed by the corresponding done operation.
1607     auto done = dyn_cast<AllReduceDoneOp>(*token.getUsers().begin());
1608     if (!done) return start.emitOpError("illegal token user");
1609 
1610     collective_uid.AssignUid(start, done);
1611     return WalkResult::advance();
1612   });
1613   if (walked.wasInterrupted()) return signalPassFailure();
1614 
1615   // Patterns for collective operations.
1616   patterns.insert<AllGatherOpLowering, AllReduceOpLowering,
1617                   AllReduceStartOpLowering, AllToAllOpLowering,
1618                   CollectivePermuteOpLowering, ReduceScatterOpLowering>(
1619       ctx, collective_uid);
1620 
1621   // Patterns for every other Gpu operation.
1622   patterns
1623       .insert<FftOpLowering, CholeskyOpLowering, PartitionIdOpLowering,
1624               ReplicaIdOpLowering, WhileOpLowering, CaseOpLowering,
1625               CustomCallOpLowering, TerminatorOpLowering, ConvForwardOpLowering,
1626               ConvForwardFusedOpLowering, ConvForwardFusedSideInputOpLowering,
1627               ConvBackwardFilterOpLowering, ConvBackwardInputOpLowering>(ctx);
1628 
1629   if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
1630     return signalPassFailure();
1631 
1632   // TODO(ezhulenev): We must run the `done` op lowering after the `start` op
1633   // lowering to ensure that all redundant collective operations are safely
1634   // replaced by `memcpy` operations. We should find a better way to achieve
1635   // this goal.
1636   {
1637     RewritePatternSet patterns(ctx);
1638     patterns.insert<AllReduceDoneOpLowering>(ctx, collective_uid);
1639     if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
1640       return signalPassFailure();
1641   }
1642 }
1643 
1644 std::unique_ptr<OperationPass<ModuleOp>> createConvertGpuToJitRtPass() {
1645   return std::make_unique<ConvertGpuToJitRtPass>();
1646 }
1647 
1648 std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloConstantToArgPass() {
1649   return std::make_unique<ConvertLmhloConstantToArgPass>();
1650 }
1651 
1652 std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloConstantToArgPass(
1653     int64_t min_num_elements) {
1654   return std::make_unique<ConvertLmhloConstantToArgPass>(min_num_elements);
1655 }
1656 
1657 std::unique_ptr<OperationPass<ModuleOp>> createConvertLmhloGpuToJitRtPass() {
1658   return std::make_unique<ConvertLmhloGpuToJitRtPass>();
1659 }
1660 
1661 void populateLmhloToJitRtPasses(mlir::OpPassManager& pm,
1662                                 GpuBinaryOptions options) {
1663   // Replace large global memrefs corresponding to XLA constants with
1664   // arguments, so that compiled device kernels do not capture them.
1665   //
1666   // TODO(ezhulenev): The threshold should be consistent with the device kernel
1667   // code generation: if a constant will be embedded into the device module, we
1668   // should not inline it too early. Currently only 1-element constants stay.
1669   pm.addPass(createConvertLmhloConstantToArgPass(/*min_num_elements=*/2));
1670   pm.addPass(createSymbolDCEPass());  // Clean up unused global constants.
1671 
1672   // Small global constants will be embedded into the device modules.
1673   pm.addPass(createConvertLmhloToGpuBinaryPass(options));
1674 
1675   // Convert remaining small global memrefs corresponding to constant arguments.
1676   pm.addPass(createConvertLmhloConstantToArgPass());
1677   pm.addPass(createSymbolDCEPass());  // Clean up unused global constants.
1678 
1679   // Lower all Gpu operations to the JitRt Gpu runtime intrinsics.
1680   pm.addPass(createConvertLmhloGpuToJitRtPass());
1681   pm.addPass(createConvertGpuToJitRtPass());
1682 }
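
// A minimal sketch of building and running this pipeline programmatically
// (assuming a default-constructed GpuBinaryOptions is acceptable):
//
//   mlir::PassManager pm(module.getContext());
//   tensorflow::populateLmhloToJitRtPasses(pm, GpuBinaryOptions());
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }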
1683 
1684 void registerLmhloToJitRtPasses() {
1685   mlir::registerPass([] { return createConvertGpuToJitRtPass(); });
1686   mlir::registerPass([] { return createConvertLmhloConstantToArgPass(); });
1687   mlir::registerPass([] { return createConvertLmhloGpuToJitRtPass(); });
1688 
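  // Once registered, opt-style tools can run the whole pipeline by name, e.g.
  // with a `--lmhlo-to-jitrt` flag (illustrative; the exact invocation depends
  // on the tool).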
1689   mlir::registerPassPipeline(
1690       "lmhlo-to-jitrt", "Lower LMHLO to JitRt IR",
1691       [](OpPassManager& pm, StringRef options,
1692          function_ref<LogicalResult(const Twine&)> errorHandler) {
1693         populateLmhloToJitRtPasses(pm);
1694         return success();
1695       },
1696       /*optHandler=*/[](function_ref<void(const PassOptions&)>) {});
1697 }
1698 
1699 }  // namespace tensorflow
1700