/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/tpu.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace grappler {
namespace internal {

// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
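// Note: this limit is a count of tensor elements (as measured by
// NumCoefficients below), not a size in bytes.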
constexpr int64_t kTensorMaxSize = 64;

// Returns true for nodes that are denylisted and must never be swapped.
bool IsDenylisted(const NodeDef& node) {
  return
      // Collective ops should not be swapped.
      IsCollective(node) ||
      // ControlFlow ops should not be swapped.
      IsControlFlow(node) ||
      // NoOp ops should not be swapped (due to group dependencies).
      IsNoOp(node);
}

// Checks if a tensor is swappable: a string tensor (of any size), or an
// int32/int64/float tensor whose size is known and small.
bool IsTensorSmall(const OpInfo::TensorProperties& prop) {
  if (prop.dtype() == DataType::DT_STRING) {
    return true;
  }

  // Only int32, int64, and float tensors are eligible.
  if (prop.dtype() != DataType::DT_INT32 &&
      prop.dtype() != DataType::DT_INT64 &&
      prop.dtype() != DataType::DT_FLOAT) {
    return false;
  }

  // Check the size is known and small; NumCoefficients is negative when the
  // shape is not fully known.
  const int64_t size = NumCoefficients(prop.shape());
  if (size < 0 || size > kTensorMaxSize) {
    return false;
  }

  return true;
}

// Finds a KernelDef for `node`, greedily returning the first one found in
// `devices`.
Status TryFindKernelDef(const std::vector<DeviceType>& devices,
                        const NodeDef& node, const KernelDef** kdef) {
  for (const DeviceType& device : devices) {
    const KernelDef* kernel = nullptr;
    Status s = FindKernelDef(device, node, &kernel, nullptr);
    if (s.ok()) {
      if (kdef) {
        *kdef = kernel;
      }
      return OkStatus();
    }
  }

  return errors::NotFound("Could not find KernelDef for op: ", node.op());
}

// Checks if a node's output port is Host friendly.
// Roughly, this means checking whether the output port lives in Host memory.
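// Lookup failures (missing shape properties, OpDef, or KernelDef) are logged
// and treated conservatively: *is_candidate stays false and OkStatus is
// returned.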
Status IsNodeOutputPortHostFriendly(const GraphView& graph,
                                    GraphProperties* properties,
                                    const NodeDef& node, int port_id,
                                    bool* is_candidate) {
  *is_candidate = false;

  // Make sure we are not a denylisted op.
  if (IsDenylisted(node)) {
    return OkStatus();
  }

  // Make sure shape properties are available (i.e., statically inferred).
  if (!properties->has_properties()) {
    // This is an expensive call; do it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  int output_properties_size = output_properties.size();
  if (port_id >= output_properties_size) {
    LOG(WARNING) << "port_id=" << port_id
                 << " but output_properties.size()=" << output_properties.size()
                 << "\n"
                 << node.DebugString();
    return OkStatus();
  }
  if (!IsTensorSmall(output_properties[port_id])) {
    return OkStatus();
  }

  // These nodes may be optimized away downstream (even if pinned to Host), so
  // we (recursively) check their sources instead.
  if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
    for (const auto& fanin : graph.GetFanins(node, false)) {
      bool fanin_candidate = false;
      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
      if (!fanin_candidate) {
        return OkStatus();
      }
    }
    *is_candidate = true;
    return OkStatus();
  }

  // Check if the op is already placed on a CPU device.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return OkStatus();
  }

  // Check if the op's output port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for: " << node.op();
    return OkStatus();
  }

  // Map the port_id to output_arg_id.
  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
  if (output_arg_id < 0) {
    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                 << node.DebugString() << "\n"
                 << op->DebugString();
    return OkStatus();
  }

  // Find the kernel.
  const KernelDef* kernel = nullptr;
  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
                       &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return OkStatus();
  }

  // Check if the output_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
      *is_candidate = true;
      break;
    }
  }

  return OkStatus();
}

// Checks if a node's input port is Host friendly.
// Roughly, this means checking whether the input port is on Host memory.
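// Unlike the output-port check above, this returns a plain bool, so lookup
// failures simply yield false (not Host friendly).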
bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
  // If the node is on Host, assume its inputs are Host friendly.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    return true;
  }

  // Check if the op's input port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for: " << node.op();
    return false;
  }
  const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);

  // Find the kernel.
  const KernelDef* kernel = nullptr;
  s = internal::TryFindKernelDef(
      {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return false;
  }

  // Check if the input_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->input_arg(input_arg_id).name() == host_memory_arg) {
      return true;
    }
  }

  return false;
}

// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if the node is denylisted.
// 2] Check if the node can run on Host.
// 3] Check that all inputs/outputs are Host "friendly" (currently, friendly
//    means a small int/float/string tensor, or one pinned to Host memory).
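// Note: a node whose assigned device already contains "CPU" is accepted
// immediately, before any of the checks above run.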
Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
                           const NodeDef& node, bool* is_candidate) {
  *is_candidate = false;

  // Check if the node is already on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return OkStatus();
  }

  // Skip denylisted node types.
  if (IsDenylisted(node)) {
    return OkStatus();
  }

  // Check that the node has a CPU kernel.
  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
  if (!s.ok()) {
    return OkStatus();
  }

  // Check all inputs are Host friendly.
  for (const GraphView::OutputPort& fanin :
       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
    bool fanin_candidate = false;
    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
    if (!fanin_candidate) {
      return OkStatus();
    }
  }

  // Check all outputs are Host friendly.
  if (!properties->has_properties()) {
    // This is an expensive call; do it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {
      return OkStatus();
    }
  }

  *is_candidate = true;
  return OkStatus();
}

// Tries to find a Host device in `devices`. Returns an empty string if no
// matching Host device is found.
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
                         bool has_device_cpu, const string& device) {
  // Force this node onto the CPU.
  if (device.empty() && has_device_cpu) {
    return "/device:CPU:0";
  } else if (absl::StrContains(device, DEVICE_GPU)) {
    // Sometimes the cluster can have:
    //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
    // and we need to handle them properly.
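    // For example, "/device:XLA_GPU:0" is first rewritten to
    // "/device:XLA_CPU:0" (everything from the last "GPU" onwards is replaced
    // with "CPU:0"); if that device is not in `devices`, "/device:CPU:0" is
    // tried next.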
    for (const auto& device_match :
         {std::pair<string, string>("GPU", "CPU:0"),
          std::pair<string, string>("/device", "/device:CPU:0")}) {
      const string device_host =
          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
                          device_match.second);
      if (devices.find(device_host) != devices.end()) {
        return device_host;
      }
    }
  }

  // We couldn't find an appropriate Host device, so return no device.
  return "";
}
}  // end namespace internal

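// A minimal sketch of driving this pass directly (assuming a GrapplerItem
// `item` built elsewhere; a null cluster falls back to "/device:CPU:0"):
//
//   PinToHostOptimizer optimizer;
//   GraphDef output;
//   TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));
//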
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // Skip legacy TPU bridge graphs.
  if (IsLegacyTPUBridgeGraphDef(*optimized_graph)) {
    return OkStatus();
  }

  GraphProperties properties(item);
  GraphView graph(optimized_graph);

  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));

  // All the Const nodes that we swap onto Host, with their original devices,
  // in topological order.
  std::vector<std::pair<NodeDef*, string>> const_nodes;

  for (auto& node : *optimized_graph->mutable_node()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    bool is_candidate = false;
    TF_RETURN_IF_ERROR(
        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
    if (!is_candidate) {
      continue;
    }

    string device =
        internal::TryFindHostDevice(devices, has_device_cpu, node.device());
    if (!device.empty()) {
      // Keep track of all Const nodes that we swapped.
      if (IsConstant(node)) {
        const_nodes.emplace_back(&node, node.device());
      }
      VLOG(2) << "Moving node " << node.name() << " to device " << device;
      *node.mutable_device() = std::move(device);
    }
  }

  // Traverse all `const_nodes`, greedily swapping them back to their original
  // device when necessary.
  for (auto& it : const_nodes) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = it.first;
    const string& device = it.second;

    // Check all the consumers of this node; if any of their input ports are
    // not Host friendly, swap this node back onto its original device.
    for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
      // The consumer is not Host friendly; swap the node back to the original
      // device.
      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
                                                 fanout.port_id)) {
        VLOG(2) << "Swapping node " << node->name() << " back to device "
                << device;
        node->set_device(device);
        break;
      }
    }
  }
  return OkStatus();
}

}  // end namespace grappler
}  // end namespace tensorflow