// Copyright 2020 The TensorFlow Runtime Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Pattern to lower lmhlo ops to GPU device code and GPU dialect ops
// (gpu.launch_func and gpu.memcpy) with the help of the XLA GPU IR emitter.

#include <iterator>
#include <numeric>
#include <tuple>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/GPU/IR/GPUDialect.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tfrt/transforms/lmhlo_to_gpu/lmhlo_to_gpu_binary.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/target_constants.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tfrt/gpu/passes/passes.h"  // from @tf_runtime

#if TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm_rocdl_path.h"
#else
#include "tensorflow/compiler/xla/service/gpu/nvptx_helper.h"
#endif

namespace tensorflow {

using mlir::ArrayRef;
using mlir::FloatType;
using mlir::Operation;
using mlir::SmallVector;
using mlir::Value;
using mlir::arith::ConstantFloatOp;
using mlir::arith::ConstantIntOp;
using mlir::arith::ConstantOp;
using mlir::memref::GetGlobalOp;
using xla::gpu::DeviceToDeviceCopyThunk;
using xla::gpu::IrEmitterContext;
using xla::gpu::IrEmitterUnnested;
using xla::gpu::KernelThunk;
using xla::gpu::Thunk;
using xla::gpu::ThunkSequence;
using ConstantInfo = xla::gpu::GpuExecutable::ConstantInfo;

namespace {

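// Creates a scalar constant of `type` corresponding to the given 32-bit memset
// bit pattern (used when lowering memset thunks to gpu.memset ops below).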
mlir::Value MakeBitPatternConstant(mlir::OpBuilder& builder, mlir::Location loc,
                                   mlir::Type type, uint32_t bit_pattern) {
  // In XLA, a 1-byte bit pattern is copied to fill a 32-bit word when
  // `Memset32BitValueThunk` is constructed, so to get back an `i1` constant we
  // only need to check whether any bit is set to `1`.
  if (type.isInteger(1)) {
    return builder.create<ConstantOp>(loc, builder.getBoolAttr(bit_pattern));
  }

  if (type.isInteger(32)) {
    llvm::APInt i32(32, bit_pattern);
    return builder.create<ConstantIntOp>(loc, i32.getSExtValue(), type);
  }

  if (type.isF32()) {
    llvm::APFloat f32(llvm::APInt(32, bit_pattern).bitsToFloat());
    return builder.create<ConstantFloatOp>(loc, f32, type.cast<FloatType>());
  }

  llvm_unreachable("unsupported type");
}

// Replaces lmhlo ops within a module with gpu.launch_func and gpu.memcpy ops.
struct KernelOpsPattern : mlir::OpRewritePattern<mlir::ModuleOp> {
  KernelOpsPattern(mlir::MLIRContext* context, GpuBinaryOptions options)
      : mlir::OpRewritePattern<mlir::ModuleOp>(context), options(options) {}

  using OpRewritePattern<mlir::ModuleOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::ModuleOp module_op, mlir::PatternRewriter& rewriter) const override;

  GpuBinaryOptions options;
};

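// Data gathered by Match() that Rewrite() needs to lower a single op: the
// matched op, its buffer arguments, the emitted thunks and constants, and the
// compiled GPU module (PTX or HSACO).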
struct RewriteData {
  Operation* op;
  mlir::SmallVector<Value, 4> arguments;
  std::vector<xla::BufferAllocation> allocations;
  std::unique_ptr<ThunkSequence> thunks;
  std::vector<ConstantInfo> constants;
  std::string gpu_module_data;
};

}  // namespace

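// Wraps error messages and xla::Status failures into llvm::Error so they
// compose with the llvm::Expected values returned below.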
static llvm::Error MakeError(llvm::StringRef message) {
  return llvm::createStringError(llvm::inconvertibleErrorCode(), message);
}
static llvm::Error MakeError(xla::Status status) {
  return MakeError(status.error_message());
}

// Clones `op` into a function within a new module, with `arguments` becoming
// the function arguments. Each entry of `get_global_ops` is the defining
// memref.get_global op of the corresponding argument, or null if the argument
// is not produced by one.
static std::tuple<mlir::OwningOpRef<mlir::ModuleOp>, mlir::func::FuncOp>
CloneToModule(Operation* op, mlir::ValueRange arguments,
              mlir::MutableArrayRef<GetGlobalOp> get_global_ops) {
  auto loc = op->getLoc();
  auto* context = op->getContext();
  mlir::OpBuilder builder(context);

  mlir::OwningOpRef<mlir::ModuleOp> module_op =
      builder.create<mlir::ModuleOp>(loc);
  builder.setInsertionPointToEnd(module_op->getBody());
  // Clone and annotate the memref.global ops that the memref.get_global ops
  // refer to. The lmhlo.alloc index refers to one of the function arguments.
  for (auto pair : llvm::enumerate(get_global_ops)) {
    if (!pair.value()) continue;
    Operation* global_op = mlir::SymbolTable::lookupNearestSymbolFrom(
        pair.value(), pair.value().getNameAttr());
    auto attr = builder.getIndexAttr(pair.index());
    builder.clone(*global_op)->setAttr("lmhlo.alloc", attr);
  }

  // If 'op' is a gpu.launch_func, clone referenced gpu.module.
  if (auto launch_func_op = llvm::dyn_cast<mlir::gpu::LaunchFuncOp>(op)) {
    builder.clone(*mlir::SymbolTable::lookupNearestSymbolFrom(
        op, launch_func_op.getKernelModuleName()));
  }

  auto func_type = builder.getType<mlir::FunctionType>(
      mlir::TypeRange(arguments), mlir::TypeRange());
  auto func_name = op->getParentOfType<mlir::func::FuncOp>().getName();
  auto func_op = builder.create<mlir::func::FuncOp>(loc, func_name, func_type);
  // Annotate the function arguments if they refer to a memref.global op.
  for (auto pair : llvm::enumerate(get_global_ops)) {
    if (!pair.value()) continue;
    auto attr = builder.getStringAttr(pair.value().getName());
    func_op.setArgAttr(pair.index(), "lmhlo.constant_name", attr);
  }
  func_op.setPublic();

  builder.setInsertionPointToEnd(func_op.addEntryBlock());
  mlir::BlockAndValueMapping mapping;
  for (const auto& pair : llvm::zip_first(arguments, func_op.getArguments()))
    mapping.map(std::get<0>(pair), std::get<1>(pair));
  // Clone the memref.get_global ops.
  for (auto get_global_op : get_global_ops) {
    if (!get_global_op) continue;
    mapping.map(get_global_op, builder.clone(*get_global_op)->getResult(0));
  }
  auto* clone = builder.clone(*op, mapping);
  auto name_loc = mlir::NameLoc::get(builder.getStringAttr(func_name));
  clone->setLoc(mlir::FusedLoc::get(context, {loc, name_loc}));
  builder.create<mlir::lmhlo::TerminatorOp>(loc);

  return std::make_tuple(std::move(module_op), func_op);
}

// Converts the arguments' shaped types into buffer allocations.
static llvm::Expected<std::vector<xla::BufferAllocation>> GetAllocations(
    ArrayRef<Value> arguments, ArrayRef<GetGlobalOp> get_global_ops) {
  std::vector<xla::BufferAllocation> allocations;
  allocations.reserve(arguments.size());
  for (Value argument : arguments) {
    mlir::ShapedType type = argument.getType().dyn_cast<mlir::ShapedType>();
    if (!type || !type.hasStaticShape())
      return MakeError("Expected static shapes");
    auto element_size_bytes = xla::GetElementTypeBytes(type.getElementType());
    if (!element_size_bytes.ok()) return MakeError(element_size_bytes.status());
    size_t size = *element_size_bytes * type.getNumElements();
    allocations.emplace_back(allocations.size(), size, 0);
  }
  for (auto pair : llvm::zip_first(allocations, get_global_ops))
    std::get<0>(pair).set_constant(std::get<1>(pair));
  return allocations;
}

// Emits thunks and an llvm device code module for the given func_op.
static llvm::Expected<
    std::tuple<std::unique_ptr<ThunkSequence>, std::vector<ConstantInfo>>>
Emit(mlir::func::FuncOp func_op,
     absl::Span<const xla::BufferAllocation> allocations,
     const GpuBinaryOptions& gpu_options,
     const xla::HloModuleConfig& hlo_module_config, llvm::Module* llvm_module) {
#if TENSORFLOW_USE_ROCM
  const char* target_triple = xla::gpu::amdgpu::TargetTriple();
  const char* data_layout = xla::gpu::amdgpu::DataLayout();
#else
  const char* target_triple = xla::gpu::nvptx::TargetTriple();
  const char* data_layout = xla::gpu::nvptx::DataLayout();
#endif

  llvm_module->setTargetTriple(target_triple);
  llvm_module->setDataLayout(data_layout);

  IrEmitterContext ir_emitter_context(
      /*hlo_module=*/nullptr, /*buffer_assignment=*/nullptr,
      gpu_options.platform_name, gpu_options.gpu_device_info,
      gpu_options.cuda_compute_capability, gpu_options.rocm_compute_capability,
      func_op->getContext(), llvm_module);

  ir_emitter_context.set_allocations(allocations);

  auto ir_emitter =
      IrEmitterUnnested::Create(hlo_module_config, &ir_emitter_context);
  if (!ir_emitter.ok()) return MakeError(ir_emitter.status());

  auto emit_status = (*ir_emitter)->EmitLmhloRegion(&func_op.getBody());
  if (!emit_status.ok()) return MakeError(emit_status);

  return std::make_tuple((*ir_emitter)->ConsumeThunkSequence(),
                         std::move(ir_emitter_context.constants()));
}

// Returns the data needed to rewrite `op`, without changing the IR.
static llvm::Expected<RewriteData> Match(Operation* op,
                                         GpuBinaryOptions gpu_options) {
  mlir::SmallVector<Value> arguments;
  llvm::copy_if(
      op->getOperands(), std::back_inserter(arguments),
      // Filter block/thread size arguments of gpu.launch_func.
      [](Value value) { return value.getType().isa<mlir::ShapedType>(); });
  mlir::SetVector<Value> captures;
  getUsedValuesDefinedAbove(op->getRegions(), captures);
  llvm::copy(captures, std::back_inserter(arguments));

  // Collect arguments that are defined by a memref.get_global op. The
  // created module's annotations make the ir emitter recognize them as
  // constants.
  SmallVector<GetGlobalOp, 4> get_global_ops;
  get_global_ops.reserve(arguments.size());
  llvm::transform(
      arguments, std::back_inserter(get_global_ops),
      [](Value argument) { return argument.getDefiningOp<GetGlobalOp>(); });

  auto allocations = GetAllocations(arguments, get_global_ops);
  if (!allocations) return allocations.takeError();
  auto module_op = CloneToModule(op, arguments, get_global_ops);

  xla::HloModuleConfig hlo_module_config;
  xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
  hlo_module_config.set_debug_options(options);

  llvm::LLVMContext llvm_context;
  auto llvm_module = std::make_unique<llvm::Module>("", llvm_context);

  auto emit_result = Emit(std::get<mlir::func::FuncOp>(module_op), *allocations,
                          gpu_options, hlo_module_config, llvm_module.get());
  if (!emit_result) return emit_result.takeError();
  auto thunks = std::move(std::get<0>(*emit_result));
  auto constants = std::move(std::get<1>(*emit_result));
  // Inline sequential thunks into the `thunks` vector.
  for (auto it = thunks->begin(); it != thunks->end();) {
    if (it->get()->kind() == Thunk::kSequential) {
      auto sequence = std::move(
          static_cast<xla::gpu::SequentialThunk*>(it->get())->thunks());
      it = thunks->erase(it);
      it = thunks->insert(it, std::make_move_iterator(sequence.begin()),
                          std::make_move_iterator(sequence.end()));
    } else {
      ++it;
    }
  }
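  // Reject any thunk kind that Rewrite() below does not know how to lower.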
  if (!llvm::all_of(*thunks, [](const auto& thunk) {
        Thunk::Kind kinds[] = {Thunk::kKernel, Thunk::kCopy,
                               Thunk::kMemset32BitValue, Thunk::kMemzero};
        auto equal = [&](Thunk::Kind kind) { return thunk->kind() == kind; };
        return llvm::any_of(kinds, equal);
      })) {
    return MakeError("Expected only kernel, copy, memset, and memzero thunks");
  }

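  // Compile the emitted LLVM module to device code: HSACO on ROCm, PTX on
  // CUDA.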
#if TENSORFLOW_USE_ROCM
  auto libdevice_dir = tensorflow::RocdlRoot();
  xla::gpu::GpuVersion gpu_version{gpu_options.rocm_compute_capability};
  auto hsaco = xla::gpu::amdgpu::CompileToHsaco(
      llvm_module.get(), gpu_version, hlo_module_config, libdevice_dir);
  if (!hsaco.ok()) return MakeError(hsaco.status());
  StatusOr<std::string> ptx(std::string(hsaco->begin(), hsaco->end()));
#else
  auto libdevice_dir = xla::gpu::GetLibdeviceDir(hlo_module_config);
  auto ptx = xla::gpu::nvptx::CompileToPtx(llvm_module.get(),
                                           gpu_options.cuda_compute_capability,
                                           hlo_module_config, libdevice_dir);
  if (!ptx.ok()) return MakeError(ptx.status());
#endif

  return RewriteData{op,
                     std::move(arguments),
                     std::move(*allocations),
                     std::move(thunks),
                     std::move(constants),
                     std::move(*ptx)};
}

// Replaces op with gpu.launch_func and gpu.memcpy ops.
static void Rewrite(Operation* op, mlir::PatternRewriter& rewriter,
                    mlir::SymbolTable& symbol_table, ArrayRef<Value> arguments,
                    ThunkSequence* thunks, ArrayRef<ConstantInfo> constants,
                    mlir::StringRef gpu_module_data) {
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  auto loc = op->getLoc();

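  // Create a gpu.module that carries the compiled device binary and will hold
  // the kernel declarations.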
  rewriter.setInsertionPoint(op->getParentOfType<mlir::func::FuncOp>());
  auto gpu_module = rewriter.create<mlir::gpu::GPUModuleOp>(loc, "gpu_module");
  symbol_table.insert(gpu_module);
  gpu_module->setAttr(tfrt::gpu::GetGpuBinaryAttrName(),
                      rewriter.getStringAttr(gpu_module_data));

  // Annotate memref.global ops with the gpu.module symbol, and annotate the
  // gpu.module op with memref.global symbols which require initialization.
  SmallVector<mlir::Attribute, 4> const_attrs;
  for (const auto& constant : constants) {
    auto global_op = mlir::SymbolTable::lookupNearestSymbolFrom(
        op, rewriter.getStringAttr(constant.symbol_name));
    if (!global_op) {
      LOG(WARNING) << "memref.global op not found for constant. Possibly "
                   << "unused (spurious) constant.";
      continue;
    }
    global_op->setAttr(tfrt::gpu::GetGpuModuleAttrName(),
                       mlir::SymbolRefAttr::get(gpu_module));
    if (!constant.content.empty())
      const_attrs.emplace_back(mlir::SymbolRefAttr::get(global_op));
  }
  if (!const_attrs.empty()) {
    gpu_module->setAttr(tfrt::gpu::GetGpuConstantsAttrName(),
                        rewriter.getArrayAttr(const_attrs));
  }

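  // Lower each thunk to the corresponding gpu dialect op: gpu.memcpy,
  // gpu.memset, or a kernel declaration plus gpu.launch_func.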
  for (const auto& thunk : *thunks) {
    if (thunk->kind() == Thunk::kCopy) {
      const auto* copy_thunk =
          static_cast<const DeviceToDeviceCopyThunk*>(thunk.get());
      auto get_argument = [&](const xla::BufferAllocation::Slice& slice) {
        assert(slice.offset() == 0 && slice.size() == copy_thunk->size_bytes());
        return arguments[slice.index()];
      };
      rewriter.setInsertionPoint(op);
      rewriter.create<mlir::gpu::MemcpyOp>(
          loc, mlir::TypeRange(), mlir::ValueRange(),
          get_argument(copy_thunk->destination()),
          get_argument(copy_thunk->source()));
      continue;
    }

    auto rewrite_memset = [&](const xla::BufferAllocation::Slice& slice,
                              uint32_t memset_value) {
      assert(slice.offset() == 0);
      Value buffer_arg = arguments[slice.index()];
      auto element_type =
          buffer_arg.getType().cast<mlir::MemRefType>().getElementType();
      rewriter.setInsertionPoint(op);
      Value value =
          MakeBitPatternConstant(rewriter, loc, element_type, memset_value);
      rewriter.create<mlir::gpu::MemsetOp>(
          loc, mlir::TypeRange(), mlir::ValueRange(), buffer_arg, value);
    };

    if (thunk->kind() == Thunk::kMemset32BitValue) {
      const auto* memset_thunk =
          static_cast<const xla::gpu::Memset32BitValueThunk*>(thunk.get());
      rewrite_memset(memset_thunk->destination(), memset_thunk->value());
      continue;
    }
    if (thunk->kind() == Thunk::kMemzero) {
      const auto* memzero_thunk =
          static_cast<const xla::gpu::MemzeroThunk*>(thunk.get());
      rewrite_memset(memzero_thunk->destination(), 0);
      continue;
    }

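    // The remaining kind is a kernel thunk: declare the kernel inside the
    // gpu.module and launch it with the thunk's launch dimensions.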
    const auto* kernel_thunk = static_cast<const KernelThunk*>(thunk.get());
    rewriter.setInsertionPointToStart(gpu_module.getBody());
    SmallVector<Value, 4> kernel_args;
    for (auto kernel_arg : kernel_thunk->arguments())
      kernel_args.push_back(arguments[kernel_arg->index()]);
    auto func_type = rewriter.getType<mlir::FunctionType>(
        mlir::TypeRange(mlir::ValueRange(kernel_args)), mlir::TypeRange());
    mlir::gpu::GPUFuncOp kernel_func = rewriter.create<mlir::gpu::GPUFuncOp>(
        loc, kernel_thunk->kernel_name(), func_type);
    kernel_func->setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(),
                         rewriter.getUnitAttr());
    rewriter.setInsertionPointToEnd(&kernel_func.getBody().back());
    rewriter.create<mlir::gpu::ReturnOp>(loc);

    rewriter.setInsertionPoint(op);
    auto make_const_idx = [&](int64_t value) {
      auto attr = rewriter.getIndexAttr(value);
      return rewriter.create<mlir::arith::ConstantOp>(loc, attr).getResult();
    };
    auto make_kernel_dim3 = [&](const auto& dim3) {
      return mlir::gpu::KernelDim3{make_const_idx(dim3.x),
                                   make_const_idx(dim3.y),
                                   make_const_idx(dim3.z)};
    };
    const auto& launch_dims = kernel_thunk->launch_dimensions();
    auto grid_size = make_kernel_dim3(launch_dims.block_counts());
    auto block_size = make_kernel_dim3(launch_dims.thread_counts_per_block());

    rewriter.create<mlir::gpu::LaunchFuncOp>(
        loc, kernel_func, grid_size, block_size,
        /*shared_memory_size_bytes=*/nullptr, kernel_args);
  }

  rewriter.eraseOp(op);
}

// An overload set of predicates for deciding which operations should go
// through the XLA GPU code emitters.
template <typename OpTy>
static bool HasGpuEmitter(OpTy) {
  return true;
}

// Select custom calls that have corresponding GPU emitters.
static bool HasGpuEmitter(mlir::lmhlo::CustomCallOp custom_call) {
  llvm::StringRef target = custom_call.getCallTargetName();
  return target == "SliceToDynamic" || target == "PadToStatic";
}

mlir::LogicalResult KernelOpsPattern::matchAndRewrite(
    mlir::ModuleOp module_op, mlir::PatternRewriter& rewriter) const {
  SmallVector<RewriteData, 4> rewrites;

  // Get data to rewrite kernel ops without changing the IR.
  auto walk = [&](auto op_type_tag) {
    using OpTy = decltype(op_type_tag);

    return module_op.walk([&](OpTy op) -> mlir::WalkResult {
      if (!HasGpuEmitter(op)) return mlir::success();

      auto data = Match(op, options);
      if (auto err = data.takeError())
        return rewriter.notifyMatchFailure(op, toString(std::move(err)));

      rewrites.emplace_back(std::move(*data));
      return mlir::success();
    });
  };

  // Compile all operations that have GPU code emitters to the GPU binary.
  if (walk(mlir::lmhlo::FusionOp()).wasInterrupted() ||
      walk(mlir::lmhlo::RngGetAndUpdateStateOp()).wasInterrupted() ||
      walk(mlir::lmhlo::ScatterOp()).wasInterrupted() ||
      walk(mlir::lmhlo::SelectAndScatterOp()).wasInterrupted() ||
      walk(mlir::lmhlo::SortOp()).wasInterrupted() ||
      walk(mlir::lmhlo::CustomCallOp()).wasInterrupted() ||
      walk(mlir::gpu::LaunchFuncOp()).wasInterrupted())
    return mlir::failure();

  if (rewrites.empty()) {
    return rewriter.notifyMatchFailure(module_op, "No kernel ops");
  }

  // Mark module as gpu.container_module.
  rewriter.updateRootInPlace(module_op, [&] {
    module_op->setAttr(mlir::gpu::GPUDialect::getContainerModuleAttrName(),
                       rewriter.getUnitAttr());
  });

  // Replace the kernel ops with gpu.launch_func.
  mlir::SymbolTable symbol_table(module_op);
  for (const auto& data : rewrites) {
    Rewrite(data.op, rewriter, symbol_table, data.arguments, data.thunks.get(),
            data.constants, data.gpu_module_data);
  }

  return mlir::success();
}

void populateKernelOpsPattern(mlir::RewritePatternSet& patterns,
                              GpuBinaryOptions options) {
  patterns.add<KernelOpsPattern>(patterns.getContext(), options);
}

}  // namespace tensorflow