1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_gpu.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/Builders.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Value.h"  // from @llvm-project
31 #include "mlir/Pass/Pass.h"  // from @llvm-project
32 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
33 #include "mlir/Support/LLVM.h"  // from @llvm-project
34 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
35 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
36 #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
37 #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h"
38 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 
41 namespace mlir {
42 namespace TFL {
43 namespace tac {
44 namespace {
45 
46 struct DeviceTransformGPUPass
47     : public mlir::PassWrapper<DeviceTransformGPUPass,
48                                OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_IDmlir::TFL::tac::__anon107daea10111::DeviceTransformGPUPass49   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DeviceTransformGPUPass)
50 
51   llvm::StringRef getArgument() const final {
52     return "tfl-device-transform-gpu";
53   }
getDescriptionmlir::TFL::tac::__anon107daea10111::DeviceTransformGPUPass54   llvm::StringRef getDescription() const final {
55     return "Suitable transformation for gpu only.";
56   }
getDependentDialectsmlir::TFL::tac::__anon107daea10111::DeviceTransformGPUPass57   void getDependentDialects(DialectRegistry& registry) const override {
58     registry.insert<TF::TensorFlowDialect>();
59   }
60   void runOnOperation() override;
61 };
62 
runOnOperation()63 void DeviceTransformGPUPass::runOnOperation() {
64   auto func = getOperation();
65   auto* ctx = &getContext();
66   RewritePatternSet patterns = GetHardwareRewritePatternsGPU(ctx);
67   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
68 }
69 
70 }  // namespace
71 
GetHardwareRewritePatternsGPU(MLIRContext * context)72 RewritePatternSet GetHardwareRewritePatternsGPU(MLIRContext* context) {
73   GpuHardware gpu_hardware;
74   return gpu_hardware.GetTransformations(context);
75 }
76 
CreateDeviceTransformGPUPass()77 std::unique_ptr<OperationPass<func::FuncOp>> CreateDeviceTransformGPUPass() {
78   return std::make_unique<DeviceTransformGPUPass>();
79 }
80 
81 static PassRegistration<DeviceTransformGPUPass> pass;
82 
83 }  // namespace tac
84 }  // namespace TFL
85 }  // namespace mlir
86