1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"
16
17 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
18 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
19 #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
20 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
21 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
22
23 namespace tensorflow {
24 namespace tfrt_compiler {
25
FallbackConverter(mlir::MLIRContext * context)26 FallbackConverter::FallbackConverter(mlir::MLIRContext *context)
27 : builder_(context) {
28 addConversion([](tfrt::compiler::ChainType type) { return type; });
29 addConversion([](tfrt::fallback::TFTensorType type) { return type; });
30 addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
31 // Ref types are not supported in both compiler and runtime.
32 if (type.getElementType().isa<mlir::TF::TensorFlowRefType>()) {
33 return llvm::None;
34 }
35
36 return builder_.getType<tfrt::fallback::TFTensorType>();
37 });
38 addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
39 if (type == builder_.getI1Type()) return type;
40 return llvm::None;
41 });
42 }
43
ConvertCoreRTTensorHandleToFallbackTensor(mlir::Location loc,llvm::StringRef device,mlir::Value value,mlir::ConversionPatternRewriter & rewriter)44 mlir::Value ConvertCoreRTTensorHandleToFallbackTensor(
45 mlir::Location loc, llvm::StringRef device, mlir::Value value,
46 mlir::ConversionPatternRewriter &rewriter) {
47 if (value.getType().isa<tfrt::fallback::TFTensorType>()) return value;
48
49 if (!value.getType().isa<tfrt::corert::TensorHandleType>()) return {};
50
51 mlir::OpBuilder::InsertionGuard guard(rewriter);
52
53 if (device.endswith("CPU:0") && !device.startswith("/job:")) {
54 // Canonicalize CPU device name. This is needed as corert library only uses
55 // the default CPU device name (i.e.
56 // "/job:localhost/replica:0/task:0/device:CPU:0") and cannot recoganize
57 // other legal variants (e.g. "/device:CPU:0").
58 //
59 // Note that we don't want to make change to the device name if it is
60 // already canonicalized by users.
61 // e.g. "/job:tpu_worker/replica:0/task:x/device:CPU:0".
62 // TODO(tfrt-devs): to make the canonicalization more robust we should
63 // introduce a util to check each component of the TF device name.
64 device = GetDefaultCpuDeviceName();
65 }
66
67 auto *def = value.getDefiningOp();
68 if (def) {
69 rewriter.setInsertionPointAfter(def);
70 } else {
71 rewriter.setInsertionPointToStart(value.getParentBlock());
72 }
73
74 return rewriter
75 .create<tfrt::fallback_async::CoreRTTensorHandleToFallbackTensorOp>(
76 loc, rewriter.getType<tfrt::fallback::TFTensorType>(), value, device)
77 .getResult(0);
78 }
79
ConvertFallbackTensorToCoreRTTensorHandle(mlir::Location loc,mlir::Value value,mlir::ConversionPatternRewriter & rewriter)80 mlir::Value ConvertFallbackTensorToCoreRTTensorHandle(
81 mlir::Location loc, mlir::Value value,
82 mlir::ConversionPatternRewriter &rewriter) {
83 if (value.getType().isa<tfrt::corert::TensorHandleType>()) return value;
84
85 if (!value.getType().isa<tfrt::fallback::TFTensorType>()) return {};
86
87 // Use CPU device by default if no device is specified.
88 llvm::StringRef device = GetDefaultCpuDeviceName();
89 if (auto *def = value.getDefiningOp()) {
90 if (auto device_attr = def->getAttrOfType<mlir::StringAttr>("device")) {
91 // NOTE: The TPU_SYSTEM check is just a short term workaround. The long
92 // term solution should be checking the HostMemory annotation of the
93 // defining op (it should be defined in TF OpKernel). If HostMemory
94 // annotation is set for an output tensor, we should use CPU device here.
95 // TODO(b/200896904): Support HostMemory annotation.
96 if (!device_attr.getValue().endswith("TPU_SYSTEM:0")) {
97 device = device_attr.getValue();
98 }
99 }
100 }
101
102 return rewriter
103 .create<tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
104 loc, rewriter.getType<tfrt::corert::TensorHandleType>(), value,
105 device)
106 .getResult(0);
107 }
108
ConvertCoreRTOperands(mlir::Operation * op,mlir::ValueRange operands,llvm::SmallVectorImpl<mlir::Value> * new_operands,mlir::ConversionPatternRewriter & rewriter)109 mlir::LogicalResult ConvertCoreRTOperands(
110 mlir::Operation *op, mlir::ValueRange operands,
111 llvm::SmallVectorImpl<mlir::Value> *new_operands,
112 mlir::ConversionPatternRewriter &rewriter) {
113 mlir::OpBuilder::InsertionGuard guard(rewriter);
114 // Insert before the current op.
115 rewriter.setInsertionPoint(op);
116
117 for (auto operand : operands) {
118 auto value = ConvertFallbackTensorToCoreRTTensorHandle(op->getLoc(),
119 operand, rewriter);
120 if (!value) {
121 return op->emitWarning("failed to convert to !corert.tensorhandle")
122 << operand.getType();
123 }
124
125 new_operands->push_back(value);
126 }
127 return success();
128 }
129
ConvertFallbackOperands(mlir::Operation * op,llvm::StringRef device,mlir::ValueRange operands,llvm::SmallVectorImpl<mlir::Value> * new_operands,mlir::ConversionPatternRewriter & rewriter)130 mlir::LogicalResult ConvertFallbackOperands(
131 mlir::Operation *op, llvm::StringRef device, mlir::ValueRange operands,
132 llvm::SmallVectorImpl<mlir::Value> *new_operands,
133 mlir::ConversionPatternRewriter &rewriter) {
134 for (auto operand : operands) {
135 if (!operand.getType().isa<tfrt::fallback::TFTensorType>()) {
136 auto new_operand = ConvertCoreRTTensorHandleToFallbackTensor(
137 op->getLoc(), device, operand, rewriter);
138 if (!new_operand)
139 return op->emitWarning(
140 "failed to convert the operand to fallback tensor.");
141 new_operands->push_back(new_operand);
142 } else {
143 new_operands->push_back(operand);
144 }
145 }
146 return success();
147 }
148
149 } // namespace tfrt_compiler
150 } // namespace tensorflow
151