1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass converts each tfrt_dist.remote_execute_func op into a combination
// of a tfrt_dist.register_tfrt_function op and a tfrt_dist.remote_execute op.
// The function to be executed on the remote host, together with every function
// it references, is serialized into a string attribute of the
// tfrt_dist.register_tfrt_function op.
20
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/SymbolTable.h" // from @llvm-project
28 #include "mlir/IR/Types.h" // from @llvm-project
29 #include "mlir/IR/Visitors.h" // from @llvm-project
30 #include "mlir/Pass/Pass.h" // from @llvm-project
31 #include "mlir/Pass/PassManager.h" // from @llvm-project
32 #include "mlir/Support/LLVM.h" // from @llvm-project
33 #include "mlir/Transforms/Passes.h" // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
38 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
39 #include "tensorflow/core/util/device_name_utils.h"
40 #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime
41 #include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime
42 #include "tfrt/distributed_runtime/opdefs/kernels.h" // from @tf_runtime
43 #include "tfrt/distributed_runtime/opdefs/types.h" // from @tf_runtime
44 #include "tfrt/test_kernels/opdefs/test_kernels.h" // from @tf_runtime
45
46 namespace tensorflow {
47
48 namespace {
49
// Name of the function attribute holding the remote host's device-name prefix;
// it is used to build a default device for results lacking kTFRTDevice.
constexpr char kHost[] = "host";
// Name of the per-result attribute carrying the result's device string.
constexpr char kTFRTDevice[] = "tfrt.device";
52
53 struct DistRemoteRunEncapsulatePass
54 : public PassWrapper<DistRemoteRunEncapsulatePass,
55 OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_IDtensorflow::__anon5d3031ed0111::DistRemoteRunEncapsulatePass56 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DistRemoteRunEncapsulatePass)
57
58 llvm::StringRef getArgument() const final {
59 return "tfrt-dist-remote-run-encapsulate";
60 }
getDescriptiontensorflow::__anon5d3031ed0111::DistRemoteRunEncapsulatePass61 llvm::StringRef getDescription() const final {
62 return "This pass looks for a remote_run_func and serialize the callee to "
63 "a string attribute attached to a remote_register operation, "
64 "followed by a remote_execute invocation.";
65 }
66 void runOnOperation() override;
67
getDependentDialectstensorflow::__anon5d3031ed0111::DistRemoteRunEncapsulatePass68 void getDependentDialects(DialectRegistry& registry) const override {
69 registry.insert<tfrt::dist::DistributedDialect>();
70 }
71 };
72
EncapsulateFuncAndSerialize(func::FuncOp entry_func,std::string * serialized_func_module)73 LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func,
74 std::string* serialized_func_module) {
75 ModuleOp module = entry_func->getParentOfType<ModuleOp>();
76 SymbolTable entry_module_table(module);
77 SmallVector<func::FuncOp, 4> referenced({entry_func});
78
79 // Create a new module to hold func and all referenced functions.
80 OwningOpRef<mlir::ModuleOp> module_for_func =
81 ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
82 SymbolTable symbol_table(module_for_func.get());
83
84 while (!referenced.empty()) {
85 func::FuncOp func = referenced.pop_back_val();
86
87 // Skip functions that have already been cloned into new module.
88 if (symbol_table.lookup<func::FuncOp>(func.getName())) continue;
89
90 // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
91 // all found FuncOps to new_module to make sure new_module is
92 // self-contained.
93 Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
94 assert(uses && "expected to be able to collect symbol uses");
95 for (SymbolTable::SymbolUse use : *uses) {
96 func::FuncOp referenced_func = entry_module_table.lookup<func::FuncOp>(
97 use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
98
99 // Skip Symbols that do not map to a function.
100 if (!referenced_func) continue;
101
102 referenced.emplace_back(referenced_func);
103 }
104
105 func::FuncOp clone = func.clone();
106 if (clone.getName() == entry_func.getName()) {
107 clone.setPublic();
108 } else {
109 clone.setPrivate();
110 }
111 symbol_table.insert(clone);
112 }
113
114 *serialized_func_module =
115 tensorflow::SerializeMlirModule(module_for_func.get());
116 return success();
117 }
118
runOnOperation()119 void DistRemoteRunEncapsulatePass::runOnOperation() {
120 mlir::TF::RuntimeDevices devices;
121 ModuleOp module = getOperation();
122 SymbolTable symtab(module);
123 Type chain_type = tfrt::compiler::ChainType::get(&getContext());
124 Type remote_object_id_ty = tfrt::dist::RemoteObjectIdType::get(&getContext());
125 Type tensor_handle_ty = tfrt::corert::TensorHandleType::get(&getContext());
126 module.walk([&](tfrt::dist::RemoteExecuteFuncOp remote_exec_op) {
127 FlatSymbolRefAttr callee_sym = remote_exec_op.calleeAttr();
128 func::FuncOp callee = symtab.lookup<func::FuncOp>(callee_sym.getValue());
129 if (!callee) {
130 remote_exec_op.emitOpError("callee function ")
131 << callee_sym.getValue() << " is not found";
132 signalPassFailure();
133 return WalkResult::interrupt();
134 }
135 std::string txt_module;
136 if (failed(EncapsulateFuncAndSerialize(callee, &txt_module))) {
137 remote_exec_op.emitOpError("failed to serialize the callee function ")
138 << callee.getName();
139 signalPassFailure();
140 return WalkResult::interrupt();
141 }
142 Location loc = remote_exec_op.getLoc();
143 StringAttr callee_name =
144 StringAttr::get(&getContext(), callee_sym.getValue());
145 OpBuilder builder(remote_exec_op);
146 auto register_op = builder.create<tfrt::dist::RegisterTFRTFunctionOp>(
147 loc, chain_type, remote_exec_op.in_op_chain(), remote_exec_op.context(),
148 remote_exec_op.remote_task(),
149 StringAttr::get(&getContext(), txt_module), callee_name);
150
151 // Build the device assignment for the results
152 // TODO(tfrt-devs): Define properly MLIR types and operations
153 SmallVector<Attribute, 8> result_devices;
154 for (const auto& result : llvm::enumerate(remote_exec_op.results())) {
155 StringAttr device =
156 callee.getResultAttrOfType<StringAttr>(result.index(), kTFRTDevice);
157 if (!device) {
158 // The result might not have the device attribute if it is added by
159 // the tf-to-tfrt pass. Use the first CPU on the remote host as the
160 // device of this result.
161 DeviceNameUtils::ParsedName parsed_name;
162 if (StringAttr host_attr = callee->getAttrOfType<StringAttr>(kHost)) {
163 auto host = host_attr.getValue();
164 DeviceNameUtils::ParseFullName({host.data(), host.size()},
165 &parsed_name);
166 }
167 parsed_name.has_type = true;
168 parsed_name.type = "CPU";
169 parsed_name.has_id = true;
170 parsed_name.id = 0;
171 device = StringAttr::get(
172 &getContext(), DeviceNameUtils::ParsedNameToString(parsed_name));
173 }
174 result_devices.push_back(std::move(device));
175 }
176 // IDEA(donglin): Update the create_remote_execute_spec kernel to use Device
177 // object instead of Device string.
178 Type remote_spec_ty = tfrt::dist::RemoteExecuteSpecType::get(&getContext());
179 auto result_devices_attr = ArrayAttr::get(&getContext(), result_devices);
180 auto remote_spec = builder.create<tfrt::dist::CreateRemoteExecuteSpecOp>(
181 loc, remote_spec_ty, remote_exec_op.context(), result_devices_attr);
182 // If original argument is already tfrt_dist.remote_object_id, use it
183 // directly. If it is TensorHandle, insert an op to extract the
184 // tfrt_dist.remote_object_id from it. Otherwise, emit an error.
185 SmallVector<Value, 4> arguments;
186 for (Value value : remote_exec_op.callee_args()) {
187 if (value.getType().isa<tfrt::dist::RemoteObjectIdType>()) {
188 arguments.push_back(value);
189 } else if (value.getType().isa<tfrt::corert::TensorHandleType>()) {
190 auto new_op = builder.create<tfrt::dist::GetRemoteObjectIdFromTHOp>(
191 loc, remote_object_id_ty, value);
192 arguments.push_back(new_op.result());
193 } else {
194 remote_exec_op.emitOpError(
195 "callee argument type should be either "
196 "TensorHandle or RemoteObjectId");
197 signalPassFailure();
198 return WalkResult::interrupt();
199 }
200 }
201 // Result types are 1 chain, followed by `num_th_results + 1`
202 // tfrt_dist.remote_object_id results, followed by `num_th_results`
203 // corert.tensorhandle results.
204 int32_t num_th_results = remote_exec_op.results().size() - 1;
205 SmallVector<Type, 8> result_types;
206 result_types.push_back(chain_type);
207 for (int count : llvm::seq<int>(0, num_th_results + 1)) {
208 (void)count;
209 result_types.push_back(remote_object_id_ty);
210 }
211 for (int count : llvm::seq<int>(0, num_th_results)) {
212 (void)count;
213 result_types.push_back(tensor_handle_ty);
214 }
215 auto new_remote_exec_th_op = builder.create<tfrt::dist::RemoteExecuteTHOp>(
216 loc, result_types, register_op.out_op_chain(), remote_exec_op.context(),
217 remote_exec_op.remote_task(), remote_spec, num_th_results,
218 callee_name.getValue(), std::move(arguments));
219 // The part of the new results to replace the original results are 2 chains,
220 // followed `num_th_results` corert.tesnorhandle results from the callee
221 // function.
222 SmallVector<Value, 4> new_results;
223 new_results.push_back(new_remote_exec_th_op.getResult(0));
224 new_results.push_back(new_remote_exec_th_op.getResult(1));
225 for (int i : llvm::seq<int>(0, num_th_results)) {
226 new_results.push_back(
227 new_remote_exec_th_op.getResult(i + 2 + num_th_results));
228 }
229 remote_exec_op.replaceAllUsesWith(new_results);
230 remote_exec_op.erase();
231
232 return WalkResult::advance();
233 });
234 }
235
236 } // namespace
237
CreateDistRemoteRunEncapsulatePass()238 std::unique_ptr<OperationPass<ModuleOp>> CreateDistRemoteRunEncapsulatePass() {
239 return std::make_unique<DistRemoteRunEncapsulatePass>();
240 }
241
242 static PassRegistration<DistRemoteRunEncapsulatePass> pass;
243
244 } // namespace tensorflow
245