1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass inserts corert.transfer ops to make sure every argument of an op is
// on the same device as the op itself.
18
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringMap.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/Types.h" // from @llvm-project
25 #include "mlir/Pass/PassManager.h" // from @llvm-project
26 #include "mlir/Transforms/Passes.h" // from @llvm-project
27 #include "tensorflow/core/util/device_name_utils.h"
28 #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime
29 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
30 #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime
31 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
32
33 namespace tensorflow {
34
35 namespace {
36
37 using DeviceNameUtils = ::tensorflow::DeviceNameUtils;
38
39 constexpr const char *kDeviceAttr = "device";
40 constexpr const char *kTFRTDeviceAttr = "tfrt.device";
41 // TODO(b/175480458): Do not assign default device once every op in the TF
42 // dialect has the device attribute.
43 constexpr const char *kDefaultDevice =
44 "/job:localhost/replica:0/task:0/device:CPU:0";
45
46 // This method canonicalizes the device name so that we can use string
47 // comparison to see if two devices are the same. It does the following
48 // transformations:
49 // 1) Set device ID to 0 if device ID is not already specified.
50 // 2) Change the device type to uppercase string.
CanonicalizeDeviceName(const std::string & device)51 static std::string CanonicalizeDeviceName(const std::string &device) {
52 if (device.empty()) return kDefaultDevice;
53
54 DeviceNameUtils::ParsedName parsed_name;
55 if (!device.empty() && device.at(0) == '/') {
56 DeviceNameUtils::ParseFullName(device, &parsed_name);
57 } else {
58 DeviceNameUtils::ParseFullName("/device:" + device, &parsed_name);
59 }
60
61 if (!parsed_name.has_id) {
62 parsed_name.has_id = true;
63 parsed_name.id = 0;
64 }
65
66 if (parsed_name.type == "cpu")
67 parsed_name.type = "CPU";
68 else if (parsed_name.type == "gpu")
69 parsed_name.type = "GPU";
70 else if (parsed_name.type == "tpu")
71 parsed_name.type = "TPU";
72 return DeviceNameUtils::ParsedNameToString(parsed_name);
73 }
74
75 // Return the device of the given operation.
GetDevice(Operation * op)76 static std::string GetDevice(Operation *op) {
77 std::string device = "";
78 if (StringAttr device_attr = op->getAttrOfType<StringAttr>(kDeviceAttr)) {
79 device = device_attr.getValue().str();
80 } else if (auto execute_op = llvm::dyn_cast<tfrt::corert::ExecuteOp>(op)) {
81 SmallVector<std::pair<StringRef, Attribute>, 4> attrs;
82 execute_op.getOpAttrs(&attrs);
83 for (std::pair<StringRef, Attribute> entry : attrs) {
84 if (entry.first == kDeviceAttr && entry.second.isa<StringAttr>()) {
85 device = entry.second.cast<StringAttr>().getValue().str();
86 break;
87 }
88 }
89 }
90
91 return CanonicalizeDeviceName(device);
92 }
93
94 // Return the device of the given value.
GetDevice(mlir::Value value,func::FuncOp parent_func_op)95 static std::string GetDevice(mlir::Value value, func::FuncOp parent_func_op) {
96 std::string device = "";
97 if (BlockArgument block_arg = value.dyn_cast<BlockArgument>()) {
98 if (StringAttr device_attr = parent_func_op.getArgAttrOfType<StringAttr>(
99 block_arg.getArgNumber(), kTFRTDeviceAttr)) {
100 device = device_attr.getValue().str();
101 }
102 } else {
103 device = GetDevice(value.getDefiningOp());
104 }
105
106 return CanonicalizeDeviceName(device);
107 }
108
109 struct CrossDeviceTransferPass
110 : public PassWrapper<CrossDeviceTransferPass, OperationPass<func::FuncOp>> {
111 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CrossDeviceTransferPass)
112
113 void runOnOperation() override;
114
getArgumenttensorflow::__anon4aa4a9b60111::CrossDeviceTransferPass115 llvm::StringRef getArgument() const final {
116 return "tfrt-cross-device-transfer";
117 }
118
getDescriptiontensorflow::__anon4aa4a9b60111::CrossDeviceTransferPass119 llvm::StringRef getDescription() const final {
120 return "This pass inserts corert.transfer op to make sure any argument of "
121 "any op is on the same device of the op itself.";
122 }
123 };
124
runOnOperation()125 void CrossDeviceTransferPass::runOnOperation() {
126 func::FuncOp func_op = getOperation();
127 llvm::DenseMap<mlir::Value, llvm::StringMap<mlir::Value>>
128 transferred_value_by_value_and_device;
129
130 func_op.getBody().walk([&](Operation *op) {
131 if (op->hasTrait<OpTrait::IsTerminator>()) return WalkResult::advance();
132 // Do not transfer the argument of corert.transfer op.
133 if (llvm::isa<tfrt::corert::TransferOp>(op)) return WalkResult::advance();
134
135 OpBuilder builder(op);
136 std::string dst_device = GetDevice(op);
137 mlir::Type tensor_type_type =
138 builder.getType<::tfrt::compiler::TensorTypeType>();
139 mlir::Type device_type = builder.getType<::tfrt::compiler::DeviceType>();
140
141 for (mlir::Value arg : op->getOperands()) {
142 // Do not transfer non-TensorHandle values.
143 if (!arg.getType().isa<tfrt::corert::TensorHandleType>()) continue;
144
145 // Do not transfer the result of corert.transfer op.
146 if (OpResult op_result = arg.dyn_cast<OpResult>()) {
147 Operation *defining_op = arg.getDefiningOp();
148 if (llvm::isa<tfrt::corert::TransferOp>(defining_op)) continue;
149 }
150
151 std::string src_device = GetDevice(arg, func_op);
152
153 if (DeviceNameUtils::LocalName(src_device) ==
154 DeviceNameUtils::LocalName(dst_device))
155 continue;
156
157 // Re-use the value already transferred to the given device.
158 llvm::StringMap<mlir::Value> &transferred_value_by_device =
159 transferred_value_by_value_and_device[arg];
160 auto iter = transferred_value_by_device.find(dst_device);
161 if (iter != transferred_value_by_device.end()) {
162 op->replaceUsesOfWith(arg, iter->second);
163 continue;
164 }
165
166 mlir::Value chain_in = func_op.getArgument(0);
167 auto get_device_op = builder.create<tfrt::compiler::GetDeviceOp>(
168 op->getLoc(), device_type, chain_in, dst_device);
169 auto get_tensor_type_op =
170 builder.create<tfrt::corert::GetDstTensorTypeOp>(
171 op->getLoc(), tensor_type_type, arg, get_device_op.getResult());
172 auto transfer_op = builder.create<tfrt::corert::TransferOp>(
173 op->getLoc(), arg.getType(), arg, get_device_op.getResult(),
174 get_tensor_type_op.getResult());
175 mlir::Value new_arg = transfer_op.getResult();
176 transferred_value_by_device[dst_device] = new_arg;
177 op->replaceUsesOfWith(arg, new_arg);
178 }
179 return WalkResult::advance();
180 });
181 }
182
183 } // namespace
184
CreateCrossDeviceTransferPass()185 std::unique_ptr<OperationPass<func::FuncOp>> CreateCrossDeviceTransferPass() {
186 return std::make_unique<CrossDeviceTransferPass>();
187 }
188
189 static PassRegistration<CrossDeviceTransferPass> pass;
190
191 } // namespace tensorflow
192