1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/common_runtime/device_set.h"
17 #include "tensorflow/core/common_runtime/eager/context.h"
18 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
19 #include "tensorflow/core/common_runtime/optimization_registry.h"
20 #include "tensorflow/core/common_runtime/placer.h"
21 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
22 #include "tensorflow/core/framework/graph_to_functiondef.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
25 #include "tfrt/host_context/device.h" // from @tf_runtime
26 #include "tfrt/support/error_util.h" // from @tf_runtime
27
28 namespace tensorflow {
29
namespace {
// Device name used to look up the host CPU device handed to Grappler's
// meta-optimizer (see the enable_grappler path in TransformGraphFunction).
constexpr char kDefaultCpuDeviceName[] = "CPU:0";
}  // namespace
33
TransformGraphFunction(const std::string & func_name,const FunctionDef & fdef,const std::string & device_name,const tensorflow::DeviceSet & device_set,EagerContext * eager_ctx,bool enable_grappler,std::unique_ptr<FunctionBody> * fbody,std::unique_ptr<Graph> graph,tfrt::ArrayRef<const tfrt::Device * > input_devices,FunctionLibraryDefinition * func_lib_def)34 Status TransformGraphFunction(const std::string& func_name,
35 const FunctionDef& fdef,
36 const std::string& device_name,
37 const tensorflow::DeviceSet& device_set,
38 EagerContext* eager_ctx, bool enable_grappler,
39 std::unique_ptr<FunctionBody>* fbody,
40 std::unique_ptr<Graph> graph,
41 tfrt::ArrayRef<const tfrt::Device*> input_devices,
42 FunctionLibraryDefinition* func_lib_def) {
43 const DeviceMgr* device_mgr = eager_ctx->local_device_mgr();
44 if (device_mgr == nullptr)
45 return errors::Internal("Cannot find device manager");
46 DumpGraph("Input function graph", graph.get());
47
48 std::vector<string> ret_node_names;
49 std::vector<string> control_ret_node_names;
50 // Mapping from a function body node name to the control output name.
51 std::unordered_map<string, string> node_name_to_control_ret;
52 std::vector<Node*> arg_nodes, ret_nodes;
53 DataTypeVector ret_types;
54 auto attrs = AttrSlice(&fdef.attr());
55 TF_RETURN_IF_ERROR(GetGraphAndArgRets(
56 func_name, attrs, &fdef, func_lib_def, &graph, &arg_nodes, &ret_nodes,
57 &ret_node_names, &ret_types, &control_ret_node_names));
58 for (const auto& control_ret : fdef.control_ret()) {
59 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
60 }
61 for (Node* node : arg_nodes) {
62 const AttrValue* attr_value;
63 TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
64 int64_t index = attr_value->i();
65 node->set_assigned_device_name(input_devices[index]->name().str());
66 }
67
68 std::vector<string> input_device_names;
69 int input_size = input_devices.size();
70 input_device_names.reserve(input_size);
71 for (int i = 0; i < input_size; ++i) {
72 input_device_names.push_back(input_devices[i]->name().str());
73 }
74
75 std::vector<string> output_device_names;
76 int output_size = fdef.signature().output_arg_size();
77 output_device_names.reserve(output_size);
78 for (int i = 0; i < output_size; ++i) {
79 output_device_names.push_back(device_name);
80 }
81
82 // set default_device for placer.
83 Device* default_device = nullptr;
84 tensorflow::Status s = device_mgr->LookupDevice(device_name, &default_device);
85 if (!s.ok())
86 VLOG(1) << "TransformGraphFunction(): " << device_name << " is unknown."
87 << " default device for placer is not set.";
88
89 TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::PinArgsAndRets(
90 input_device_names, output_device_names, device_set, arg_nodes, ret_nodes,
91 func_lib_def,
92 eager_ctx->AllowSoftPlacement() ? default_device : nullptr));
93 DumpGraph("After running PinArgsAndRets", graph.get());
94
95 ConfigProto config;
96 bool control_rets_updated = false;
97 TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
98 device_set, config, &graph, func_lib_def, &control_ret_node_names,
99 &control_rets_updated));
100
101 if (control_rets_updated) {
102 // Function graph pass may have resulted in different nodes/node names for
103 // control rets.
104 for (const auto& control_ret : control_ret_node_names) {
105 node_name_to_control_ret.emplace(control_ret, control_ret);
106 }
107 } else {
108 for (const auto& control_ret : fdef.control_ret()) {
109 node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
110 }
111 }
112 DumpGraph("After running function optimization pass (bridge)", graph.get());
113
114 // Run function inlining so that placer can place ops in nested functions.
115 GraphOptimizationPassOptions optimization_options;
116 SessionOptions session_options;
117 // In TFRT we don't lower v2 control flow to v1.
118 session_options.config.mutable_experimental()->set_use_tfrt(true);
119 session_options.config.mutable_graph_options()
120 ->mutable_optimizer_options()
121 ->set_do_function_inlining(true);
122 optimization_options.session_options = &session_options;
123 optimization_options.graph = &graph;
124 optimization_options.flib_def = func_lib_def;
125 optimization_options.device_set = &device_set;
126 optimization_options.is_function_graph = true;
127 optimization_options.default_function_device = default_device;
128 optimization_options.function_def = &fdef;
129
130 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
131 OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
132 DumpGraph("After running pre placement passes", graph.get());
133
134 // Run placer before importing GraphDef to MLIR.
135 Placer placer(graph.get(), func_name, func_lib_def, &device_set,
136 default_device, eager_ctx->AllowSoftPlacement(),
137 /*log_device_placement=*/false);
138 TF_RETURN_IF_ERROR(placer.Run());
139 DumpGraph("After running placer", graph.get());
140
141 if (enable_grappler) {
142 Device* cpu_device;
143 TF_RETURN_IF_ERROR(
144 device_mgr->LookupDevice(kDefaultCpuDeviceName, &cpu_device));
145
146 ConfigProto config_proto;
147 config_proto.mutable_experimental()->set_use_tfrt(true);
148 config_proto.mutable_graph_options()
149 ->mutable_optimizer_options()
150 ->set_do_function_inlining(true);
151 // Do not skip grappler optimization even for small graphs.
152 config_proto.mutable_graph_options()
153 ->mutable_rewrite_options()
154 ->set_min_graph_nodes(-1);
155
156 grappler::GrapplerItem::OptimizationOptions grappler_options =
157 grappler::CreateOptOptionsForEager();
158 auto status = grappler::OptimizeGraph(
159 std::move(ret_node_names), std::move(control_ret_node_names),
160 func_lib_def, device_set, cpu_device, config_proto,
161 fdef.signature().name(), grappler_options, &graph);
162 if (!status.ok()) {
163 LOG(WARNING) << "Ignoring multi-device function optimization failure: "
164 << status.ToString();
165 }
166 DumpGraph("After grappler optimization", graph.get());
167 }
168
169 // We must preserve control returns in each of the function components,
170 // otherwise after function inlining we might prune side-effectful nodes.
171 const auto control_ret =
172 [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
173 const auto it = node_name_to_control_ret.find(n->name());
174 if (it != node_name_to_control_ret.end())
175 return absl::make_optional<string>(it->second);
176 return absl::nullopt;
177 };
178 FunctionDef new_func;
179 TF_RETURN_IF_ERROR(
180 GraphToFunctionDef(*graph, func_name, control_ret, &new_func));
181 // Refresh `fbody`.
182 TF_RETURN_IF_ERROR(
183 FunctionDefToBodyHelper(new_func, AttrSlice(), func_lib_def, fbody));
184 return OkStatus();
185 }
186 } // namespace tensorflow
187