xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/transforms/remote_run_encapsulate.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass converts each tfrt_dist.remote_execute_func op into a combination
17 // of tfrt_dist.register_tfrt_function op and tfrt_dist.remote_execute op. The
18 // function to be executed in the remote host will be serialized as a string
19 // attribute of the tfrt_dist.register_tfrt_function op.
20 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
28 #include "mlir/IR/Types.h"  // from @llvm-project
29 #include "mlir/IR/Visitors.h"  // from @llvm-project
30 #include "mlir/Pass/Pass.h"  // from @llvm-project
31 #include "mlir/Pass/PassManager.h"  // from @llvm-project
32 #include "mlir/Support/LLVM.h"  // from @llvm-project
33 #include "mlir/Transforms/Passes.h"  // from @llvm-project
34 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
38 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
39 #include "tensorflow/core/util/device_name_utils.h"
40 #include "tfrt/basic_kernels/opdefs/types.h"  // from @tf_runtime
41 #include "tfrt/core_runtime/opdefs/types.h"  // from @tf_runtime
42 #include "tfrt/distributed_runtime/opdefs/kernels.h"  // from @tf_runtime
43 #include "tfrt/distributed_runtime/opdefs/types.h"  // from @tf_runtime
44 #include "tfrt/test_kernels/opdefs/test_kernels.h"  // from @tf_runtime
45 
46 namespace tensorflow {
47 
48 namespace {
49 
50 constexpr const char* kHost = "host";
51 constexpr const char* kTFRTDevice = "tfrt.device";
52 
53 struct DistRemoteRunEncapsulatePass
54     : public PassWrapper<DistRemoteRunEncapsulatePass,
55                          OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_IDtensorflow::__anon5d3031ed0111::DistRemoteRunEncapsulatePass56   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DistRemoteRunEncapsulatePass)
57 
58   llvm::StringRef getArgument() const final {
59     return "tfrt-dist-remote-run-encapsulate";
60   }
getDescriptiontensorflow::__anon5d3031ed0111::DistRemoteRunEncapsulatePass61   llvm::StringRef getDescription() const final {
62     return "This pass looks for a remote_run_func and serialize the callee to "
63            "a string attribute attached to a remote_register operation, "
64            "followed by a remote_execute invocation.";
65   }
66   void runOnOperation() override;
67 
getDependentDialectstensorflow::__anon5d3031ed0111::DistRemoteRunEncapsulatePass68   void getDependentDialects(DialectRegistry& registry) const override {
69     registry.insert<tfrt::dist::DistributedDialect>();
70   }
71 };
72 
EncapsulateFuncAndSerialize(func::FuncOp entry_func,std::string * serialized_func_module)73 LogicalResult EncapsulateFuncAndSerialize(func::FuncOp entry_func,
74                                           std::string* serialized_func_module) {
75   ModuleOp module = entry_func->getParentOfType<ModuleOp>();
76   SymbolTable entry_module_table(module);
77   SmallVector<func::FuncOp, 4> referenced({entry_func});
78 
79   // Create a new module to hold func and all referenced functions.
80   OwningOpRef<mlir::ModuleOp> module_for_func =
81       ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext()));
82   SymbolTable symbol_table(module_for_func.get());
83 
84   while (!referenced.empty()) {
85     func::FuncOp func = referenced.pop_back_val();
86 
87     // Skip functions that have already been cloned into new module.
88     if (symbol_table.lookup<func::FuncOp>(func.getName())) continue;
89 
90     // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone
91     // all found FuncOps to new_module to make sure new_module is
92     // self-contained.
93     Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(func);
94     assert(uses && "expected to be able to collect symbol uses");
95     for (SymbolTable::SymbolUse use : *uses) {
96       func::FuncOp referenced_func = entry_module_table.lookup<func::FuncOp>(
97           use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
98 
99       // Skip Symbols that do not map to a function.
100       if (!referenced_func) continue;
101 
102       referenced.emplace_back(referenced_func);
103     }
104 
105     func::FuncOp clone = func.clone();
106     if (clone.getName() == entry_func.getName()) {
107       clone.setPublic();
108     } else {
109       clone.setPrivate();
110     }
111     symbol_table.insert(clone);
112   }
113 
114   *serialized_func_module =
115       tensorflow::SerializeMlirModule(module_for_func.get());
116   return success();
117 }
118 
runOnOperation()119 void DistRemoteRunEncapsulatePass::runOnOperation() {
120   mlir::TF::RuntimeDevices devices;
121   ModuleOp module = getOperation();
122   SymbolTable symtab(module);
123   Type chain_type = tfrt::compiler::ChainType::get(&getContext());
124   Type remote_object_id_ty = tfrt::dist::RemoteObjectIdType::get(&getContext());
125   Type tensor_handle_ty = tfrt::corert::TensorHandleType::get(&getContext());
126   module.walk([&](tfrt::dist::RemoteExecuteFuncOp remote_exec_op) {
127     FlatSymbolRefAttr callee_sym = remote_exec_op.calleeAttr();
128     func::FuncOp callee = symtab.lookup<func::FuncOp>(callee_sym.getValue());
129     if (!callee) {
130       remote_exec_op.emitOpError("callee function ")
131           << callee_sym.getValue() << " is not found";
132       signalPassFailure();
133       return WalkResult::interrupt();
134     }
135     std::string txt_module;
136     if (failed(EncapsulateFuncAndSerialize(callee, &txt_module))) {
137       remote_exec_op.emitOpError("failed to serialize the callee function ")
138           << callee.getName();
139       signalPassFailure();
140       return WalkResult::interrupt();
141     }
142     Location loc = remote_exec_op.getLoc();
143     StringAttr callee_name =
144         StringAttr::get(&getContext(), callee_sym.getValue());
145     OpBuilder builder(remote_exec_op);
146     auto register_op = builder.create<tfrt::dist::RegisterTFRTFunctionOp>(
147         loc, chain_type, remote_exec_op.in_op_chain(), remote_exec_op.context(),
148         remote_exec_op.remote_task(),
149         StringAttr::get(&getContext(), txt_module), callee_name);
150 
151     // Build the device assignment for the results
152     // TODO(tfrt-devs): Define properly MLIR types and operations
153     SmallVector<Attribute, 8> result_devices;
154     for (const auto& result : llvm::enumerate(remote_exec_op.results())) {
155       StringAttr device =
156           callee.getResultAttrOfType<StringAttr>(result.index(), kTFRTDevice);
157       if (!device) {
158         // The result might not have the device attribute if it is added by
159         // the tf-to-tfrt pass. Use the first CPU on the remote host as the
160         // device of this result.
161         DeviceNameUtils::ParsedName parsed_name;
162         if (StringAttr host_attr = callee->getAttrOfType<StringAttr>(kHost)) {
163           auto host = host_attr.getValue();
164           DeviceNameUtils::ParseFullName({host.data(), host.size()},
165                                          &parsed_name);
166         }
167         parsed_name.has_type = true;
168         parsed_name.type = "CPU";
169         parsed_name.has_id = true;
170         parsed_name.id = 0;
171         device = StringAttr::get(
172             &getContext(), DeviceNameUtils::ParsedNameToString(parsed_name));
173       }
174       result_devices.push_back(std::move(device));
175     }
176     // IDEA(donglin): Update the create_remote_execute_spec kernel to use Device
177     // object instead of Device string.
178     Type remote_spec_ty = tfrt::dist::RemoteExecuteSpecType::get(&getContext());
179     auto result_devices_attr = ArrayAttr::get(&getContext(), result_devices);
180     auto remote_spec = builder.create<tfrt::dist::CreateRemoteExecuteSpecOp>(
181         loc, remote_spec_ty, remote_exec_op.context(), result_devices_attr);
182     // If original argument is already tfrt_dist.remote_object_id, use it
183     // directly. If it is TensorHandle, insert an op to extract the
184     // tfrt_dist.remote_object_id from it. Otherwise, emit an error.
185     SmallVector<Value, 4> arguments;
186     for (Value value : remote_exec_op.callee_args()) {
187       if (value.getType().isa<tfrt::dist::RemoteObjectIdType>()) {
188         arguments.push_back(value);
189       } else if (value.getType().isa<tfrt::corert::TensorHandleType>()) {
190         auto new_op = builder.create<tfrt::dist::GetRemoteObjectIdFromTHOp>(
191             loc, remote_object_id_ty, value);
192         arguments.push_back(new_op.result());
193       } else {
194         remote_exec_op.emitOpError(
195             "callee argument type should be either "
196             "TensorHandle or RemoteObjectId");
197         signalPassFailure();
198         return WalkResult::interrupt();
199       }
200     }
201     // Result types are 1 chain, followed by `num_th_results + 1`
202     // tfrt_dist.remote_object_id results, followed by `num_th_results`
203     // corert.tensorhandle results.
204     int32_t num_th_results = remote_exec_op.results().size() - 1;
205     SmallVector<Type, 8> result_types;
206     result_types.push_back(chain_type);
207     for (int count : llvm::seq<int>(0, num_th_results + 1)) {
208       (void)count;
209       result_types.push_back(remote_object_id_ty);
210     }
211     for (int count : llvm::seq<int>(0, num_th_results)) {
212       (void)count;
213       result_types.push_back(tensor_handle_ty);
214     }
215     auto new_remote_exec_th_op = builder.create<tfrt::dist::RemoteExecuteTHOp>(
216         loc, result_types, register_op.out_op_chain(), remote_exec_op.context(),
217         remote_exec_op.remote_task(), remote_spec, num_th_results,
218         callee_name.getValue(), std::move(arguments));
219     // The part of the new results to replace the original results are 2 chains,
220     // followed `num_th_results` corert.tesnorhandle results from the callee
221     // function.
222     SmallVector<Value, 4> new_results;
223     new_results.push_back(new_remote_exec_th_op.getResult(0));
224     new_results.push_back(new_remote_exec_th_op.getResult(1));
225     for (int i : llvm::seq<int>(0, num_th_results)) {
226       new_results.push_back(
227           new_remote_exec_th_op.getResult(i + 2 + num_th_results));
228     }
229     remote_exec_op.replaceAllUsesWith(new_results);
230     remote_exec_op.erase();
231 
232     return WalkResult::advance();
233   });
234 }
235 
236 }  // namespace
237 
CreateDistRemoteRunEncapsulatePass()238 std::unique_ptr<OperationPass<ModuleOp>> CreateDistRemoteRunEncapsulatePass() {
239   return std::make_unique<DistRemoteRunEncapsulatePass>();
240 }
241 
// File-scope PassRegistration object for DistRemoteRunEncapsulatePass.
// NOTE(review): presumably its construction at static-initialization time is
// what registers the pass under the string returned by getArgument() — confirm
// against the MLIR PassRegistration documentation.
static PassRegistration<DistRemoteRunEncapsulatePass> pass;
243 
244 }  // namespace tensorflow
245