xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/eager/transform_graph_function.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/device_set.h"
17 #include "tensorflow/core/common_runtime/eager/context.h"
18 #include "tensorflow/core/common_runtime/function_optimization_registry.h"
19 #include "tensorflow/core/common_runtime/optimization_registry.h"
20 #include "tensorflow/core/common_runtime/placer.h"
21 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
22 #include "tensorflow/core/framework/graph_to_functiondef.h"
23 #include "tensorflow/core/grappler/grappler_item.h"
24 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
25 #include "tfrt/host_context/device.h"  // from @tf_runtime
26 #include "tfrt/support/error_util.h"  // from @tf_runtime
27 
28 namespace tensorflow {
29 
namespace {
// Device name passed to DeviceMgr::LookupDevice() in TransformGraphFunction()
// to obtain the CPU device that is handed to grappler::OptimizeGraph().
constexpr char kDefaultCpuDeviceName[] = "CPU:0";
}  // namespace
33 
TransformGraphFunction(const std::string & func_name,const FunctionDef & fdef,const std::string & device_name,const tensorflow::DeviceSet & device_set,EagerContext * eager_ctx,bool enable_grappler,std::unique_ptr<FunctionBody> * fbody,std::unique_ptr<Graph> graph,tfrt::ArrayRef<const tfrt::Device * > input_devices,FunctionLibraryDefinition * func_lib_def)34 Status TransformGraphFunction(const std::string& func_name,
35                               const FunctionDef& fdef,
36                               const std::string& device_name,
37                               const tensorflow::DeviceSet& device_set,
38                               EagerContext* eager_ctx, bool enable_grappler,
39                               std::unique_ptr<FunctionBody>* fbody,
40                               std::unique_ptr<Graph> graph,
41                               tfrt::ArrayRef<const tfrt::Device*> input_devices,
42                               FunctionLibraryDefinition* func_lib_def) {
43   const DeviceMgr* device_mgr = eager_ctx->local_device_mgr();
44   if (device_mgr == nullptr)
45     return errors::Internal("Cannot find device manager");
46   DumpGraph("Input function graph", graph.get());
47 
48   std::vector<string> ret_node_names;
49   std::vector<string> control_ret_node_names;
50   // Mapping from a function body node name to the control output name.
51   std::unordered_map<string, string> node_name_to_control_ret;
52   std::vector<Node*> arg_nodes, ret_nodes;
53   DataTypeVector ret_types;
54   auto attrs = AttrSlice(&fdef.attr());
55   TF_RETURN_IF_ERROR(GetGraphAndArgRets(
56       func_name, attrs, &fdef, func_lib_def, &graph, &arg_nodes, &ret_nodes,
57       &ret_node_names, &ret_types, &control_ret_node_names));
58   for (const auto& control_ret : fdef.control_ret()) {
59     node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
60   }
61   for (Node* node : arg_nodes) {
62     const AttrValue* attr_value;
63     TF_RETURN_IF_ERROR(node->attrs().Find("index", &attr_value));
64     int64_t index = attr_value->i();
65     node->set_assigned_device_name(input_devices[index]->name().str());
66   }
67 
68   std::vector<string> input_device_names;
69   int input_size = input_devices.size();
70   input_device_names.reserve(input_size);
71   for (int i = 0; i < input_size; ++i) {
72     input_device_names.push_back(input_devices[i]->name().str());
73   }
74 
75   std::vector<string> output_device_names;
76   int output_size = fdef.signature().output_arg_size();
77   output_device_names.reserve(output_size);
78   for (int i = 0; i < output_size; ++i) {
79     output_device_names.push_back(device_name);
80   }
81 
82   // set default_device for placer.
83   Device* default_device = nullptr;
84   tensorflow::Status s = device_mgr->LookupDevice(device_name, &default_device);
85   if (!s.ok())
86     VLOG(1) << "TransformGraphFunction(): " << device_name << " is unknown."
87             << " default device for placer is not set.";
88 
89   TF_RETURN_IF_ERROR(ProcessFunctionLibraryRuntime::PinArgsAndRets(
90       input_device_names, output_device_names, device_set, arg_nodes, ret_nodes,
91       func_lib_def,
92       eager_ctx->AllowSoftPlacement() ? default_device : nullptr));
93   DumpGraph("After running PinArgsAndRets", graph.get());
94 
95   ConfigProto config;
96   bool control_rets_updated = false;
97   TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
98       device_set, config, &graph, func_lib_def, &control_ret_node_names,
99       &control_rets_updated));
100 
101   if (control_rets_updated) {
102     // Function graph pass may have resulted in different nodes/node names for
103     // control rets.
104     for (const auto& control_ret : control_ret_node_names) {
105       node_name_to_control_ret.emplace(control_ret, control_ret);
106     }
107   } else {
108     for (const auto& control_ret : fdef.control_ret()) {
109       node_name_to_control_ret.emplace(control_ret.second, control_ret.first);
110     }
111   }
112   DumpGraph("After running function optimization pass (bridge)", graph.get());
113 
114   // Run function inlining so that placer can place ops in nested functions.
115   GraphOptimizationPassOptions optimization_options;
116   SessionOptions session_options;
117   // In TFRT we don't lower v2 control flow to v1.
118   session_options.config.mutable_experimental()->set_use_tfrt(true);
119   session_options.config.mutable_graph_options()
120       ->mutable_optimizer_options()
121       ->set_do_function_inlining(true);
122   optimization_options.session_options = &session_options;
123   optimization_options.graph = &graph;
124   optimization_options.flib_def = func_lib_def;
125   optimization_options.device_set = &device_set;
126   optimization_options.is_function_graph = true;
127   optimization_options.default_function_device = default_device;
128   optimization_options.function_def = &fdef;
129 
130   TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
131       OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));
132   DumpGraph("After running pre placement passes", graph.get());
133 
134   // Run placer before importing GraphDef to MLIR.
135   Placer placer(graph.get(), func_name, func_lib_def, &device_set,
136                 default_device, eager_ctx->AllowSoftPlacement(),
137                 /*log_device_placement=*/false);
138   TF_RETURN_IF_ERROR(placer.Run());
139   DumpGraph("After running placer", graph.get());
140 
141   if (enable_grappler) {
142     Device* cpu_device;
143     TF_RETURN_IF_ERROR(
144         device_mgr->LookupDevice(kDefaultCpuDeviceName, &cpu_device));
145 
146     ConfigProto config_proto;
147     config_proto.mutable_experimental()->set_use_tfrt(true);
148     config_proto.mutable_graph_options()
149         ->mutable_optimizer_options()
150         ->set_do_function_inlining(true);
151     // Do not skip grappler optimization even for small graphs.
152     config_proto.mutable_graph_options()
153         ->mutable_rewrite_options()
154         ->set_min_graph_nodes(-1);
155 
156     grappler::GrapplerItem::OptimizationOptions grappler_options =
157         grappler::CreateOptOptionsForEager();
158     auto status = grappler::OptimizeGraph(
159         std::move(ret_node_names), std::move(control_ret_node_names),
160         func_lib_def, device_set, cpu_device, config_proto,
161         fdef.signature().name(), grappler_options, &graph);
162     if (!status.ok()) {
163       LOG(WARNING) << "Ignoring multi-device function optimization failure: "
164                    << status.ToString();
165     }
166     DumpGraph("After grappler optimization", graph.get());
167   }
168 
169   // We must preserve control returns in each of the function components,
170   // otherwise after function inlining we might prune side-effectful nodes.
171   const auto control_ret =
172       [&node_name_to_control_ret](const Node* n) -> absl::optional<string> {
173     const auto it = node_name_to_control_ret.find(n->name());
174     if (it != node_name_to_control_ret.end())
175       return absl::make_optional<string>(it->second);
176     return absl::nullopt;
177   };
178   FunctionDef new_func;
179   TF_RETURN_IF_ERROR(
180       GraphToFunctionDef(*graph, func_name, control_ret, &new_func));
181   // Refresh `fbody`.
182   TF_RETURN_IF_ERROR(
183       FunctionDefToBodyHelper(new_func, AttrSlice(), func_lib_def, fbody));
184   return OkStatus();
185 }
186 }  // namespace tensorflow
187