/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h"

#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h"
#include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h"
#include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
#include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_compiler {

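// Registers the type conversions used when lowering to the TFRT fallback
// dialect: !tfrt.chain and !tfrt_fallback.tf_tensor pass through unchanged,
// tensor types with non-ref element types convert to
// !tfrt_fallback.tf_tensor, and i1 is kept as-is. Ref tensor types and all
// other types are rejected.
//
// A minimal usage sketch (an assumption for illustration; the real pass
// setup lives elsewhere and may differ):
//
//   FallbackConverter converter(context);
//   mlir::ConversionTarget target(*context);
//   mlir::RewritePatternSet patterns(context);
//   // ... populate patterns that consult `converter` for operand/result
//   // types, then run mlir::applyPartialConversion(...).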
FallbackConverter::FallbackConverter(mlir::MLIRContext *context)
    : builder_(context) {
  addConversion([](tfrt::compiler::ChainType type) { return type; });
  addConversion([](tfrt::fallback::TFTensorType type) { return type; });
  addConversion([=](mlir::TensorType type) -> llvm::Optional<mlir::Type> {
    // Ref types are not supported by either the compiler or the runtime.
    if (type.getElementType().isa<mlir::TF::TensorFlowRefType>()) {
      return llvm::None;
    }

    return builder_.getType<tfrt::fallback::TFTensorType>();
  });
  addConversion([=](mlir::Type type) -> llvm::Optional<mlir::Type> {
    if (type == builder_.getI1Type()) return type;
    return llvm::None;
  });
}

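// Converts `value` from a !corert.tensorhandle to a !tfrt_fallback.tf_tensor
// on `device` by inserting a CoreRTTensorHandleToFallbackTensorOp right after
// the defining op (or at the start of the block for block arguments).
// Returns `value` unchanged if it is already a fallback tensor, and a null
// value if it is neither a tensor handle nor a fallback tensor.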
mlir::Value ConvertCoreRTTensorHandleToFallbackTensor(
    mlir::Location loc, llvm::StringRef device, mlir::Value value,
    mlir::ConversionPatternRewriter &rewriter) {
  if (value.getType().isa<tfrt::fallback::TFTensorType>()) return value;

  if (!value.getType().isa<tfrt::corert::TensorHandleType>()) return {};

  mlir::OpBuilder::InsertionGuard guard(rewriter);

  if (device.endswith("CPU:0") && !device.startswith("/job:")) {
    // Canonicalize the CPU device name. This is needed because the corert
    // library only uses the default CPU device name (i.e.
    // "/job:localhost/replica:0/task:0/device:CPU:0") and cannot recognize
    // other legal variants (e.g. "/device:CPU:0").
    //
    // Note that we don't want to change the device name if it has already
    // been canonicalized by the user,
    // e.g. "/job:tpu_worker/replica:0/task:x/device:CPU:0".
    // TODO(tfrt-devs): To make the canonicalization more robust, we should
    // introduce a utility that checks each component of the TF device name.
    device = GetDefaultCpuDeviceName();
  }

  auto *def = value.getDefiningOp();
  if (def) {
    rewriter.setInsertionPointAfter(def);
  } else {
    rewriter.setInsertionPointToStart(value.getParentBlock());
  }

  return rewriter
      .create<tfrt::fallback_async::CoreRTTensorHandleToFallbackTensorOp>(
          loc, rewriter.getType<tfrt::fallback::TFTensorType>(), value, device)
      .getResult(0);
}

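// Converts `value` from a !tfrt_fallback.tf_tensor to a !corert.tensorhandle
// by inserting a FallbackTensorToCoreRTTensorHandleOp at the current
// insertion point. The device is taken from the defining op's "device"
// attribute when present (except for TPU_SYSTEM devices) and defaults to the
// canonical CPU device otherwise. Returns `value` unchanged if it is already
// a tensor handle, and a null value if it is not a fallback tensor.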
mlir::Value ConvertFallbackTensorToCoreRTTensorHandle(
    mlir::Location loc, mlir::Value value,
    mlir::ConversionPatternRewriter &rewriter) {
  if (value.getType().isa<tfrt::corert::TensorHandleType>()) return value;

  if (!value.getType().isa<tfrt::fallback::TFTensorType>()) return {};

  // Use the CPU device by default if no device is specified.
  llvm::StringRef device = GetDefaultCpuDeviceName();
  if (auto *def = value.getDefiningOp()) {
    if (auto device_attr = def->getAttrOfType<mlir::StringAttr>("device")) {
      // NOTE: The TPU_SYSTEM check is just a short-term workaround. The
      // long-term solution is to check the HostMemory annotation of the
      // defining op (defined in the TF OpKernel). If the HostMemory
      // annotation is set for an output tensor, we should use the CPU device
      // here.
      // TODO(b/200896904): Support HostMemory annotation.
      if (!device_attr.getValue().endswith("TPU_SYSTEM:0")) {
        device = device_attr.getValue();
      }
    }
  }

  return rewriter
      .create<tfrt::fallback_async::FallbackTensorToCoreRTTensorHandleOp>(
          loc, rewriter.getType<tfrt::corert::TensorHandleType>(), value,
          device)
      .getResult(0);
}

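// Converts each operand in `operands` to a !corert.tensorhandle and appends
// the results to `new_operands`, inserting any conversion ops right before
// `op`. Emits a warning and returns failure if an operand cannot be
// converted.
//
// Sketch of a typical call site inside a conversion pattern (the surrounding
// matchAndRewrite, `op`, `adaptor`, and `rewriter` below are assumptions for
// illustration only):
//
//   llvm::SmallVector<mlir::Value, 4> new_operands;
//   if (mlir::failed(ConvertCoreRTOperands(op, adaptor.getOperands(),
//                                          &new_operands, rewriter)))
//     return mlir::failure();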
mlir::LogicalResult ConvertCoreRTOperands(
    mlir::Operation *op, mlir::ValueRange operands,
    llvm::SmallVectorImpl<mlir::Value> *new_operands,
    mlir::ConversionPatternRewriter &rewriter) {
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  // Insert before the current op.
  rewriter.setInsertionPoint(op);

  for (auto operand : operands) {
    auto value = ConvertFallbackTensorToCoreRTTensorHandle(op->getLoc(),
                                                           operand, rewriter);
    if (!value) {
      return op->emitWarning("failed to convert to !corert.tensorhandle")
             << operand.getType();
    }

    new_operands->push_back(value);
  }
  return mlir::success();
}

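// Converts each operand in `operands` to a !tfrt_fallback.tf_tensor on
// `device` and appends the results to `new_operands`. Operands that are
// already fallback tensors are passed through unchanged. Emits a warning and
// returns failure if an operand cannot be converted.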
mlir::LogicalResult ConvertFallbackOperands(
    mlir::Operation *op, llvm::StringRef device, mlir::ValueRange operands,
    llvm::SmallVectorImpl<mlir::Value> *new_operands,
    mlir::ConversionPatternRewriter &rewriter) {
  for (auto operand : operands) {
    if (!operand.getType().isa<tfrt::fallback::TFTensorType>()) {
      auto new_operand = ConvertCoreRTTensorHandleToFallbackTensor(
          op->getLoc(), device, operand, rewriter);
      if (!new_operand)
        return op->emitWarning(
            "failed to convert the operand to fallback tensor.");
      new_operands->push_back(new_operand);
    } else {
      new_operands->push_back(operand);
    }
  }
  return mlir::success();
}

}  // namespace tfrt_compiler
}  // namespace tensorflow