/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/tpu.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace grappler {
namespace internal {

// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64_t kTensorMaxSize = 64;
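// Note that kTensorMaxSize bounds the number of elements, not bytes: e.g. a
// fully-defined [8, 8] DT_INT32 tensor has 64 coefficients and still counts
// as "small" in IsTensorSmall() below.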

// Returns true for ops that are denylisted and should never be swapped onto
// the Host.
bool IsDenylisted(const NodeDef& node) {
  return
      // Collective ops should not be swapped.
      IsCollective(node) ||
      // ControlFlow ops should not be swapped.
      IsControlFlow(node) ||
      // NoOp ops should not be swapped (due to group dependencies).
      IsNoOp(node);
}

// Checks if the tensor is a string, or a small int32/int64/float tensor with
// a statically known shape.
bool IsTensorSmall(const OpInfo::TensorProperties& prop) {
  if (prop.dtype() == DataType::DT_STRING) {
    return true;
  }

  // Only int32, int64, and float tensors are eligible.
  if (prop.dtype() != DataType::DT_INT32 &&
      prop.dtype() != DataType::DT_INT64 &&
      prop.dtype() != DataType::DT_FLOAT) {
    return false;
  }

  // Check that the size is known and small; NumCoefficients returns -1 when
  // the shape is not fully defined.
  const int64_t size = NumCoefficients(prop.shape());
  if (size < 0 || size > kTensorMaxSize) {
    return false;
  }

  return true;
}

// Finds a KernelDef for `node`, greedily returning the first one found across
// `devices`.
Status TryFindKernelDef(const std::vector<DeviceType>& devices,
                        const NodeDef& node, const KernelDef** kdef) {
  for (const DeviceType& device : devices) {
    const KernelDef* kernel = nullptr;
    Status s = FindKernelDef(device, node, &kernel, nullptr);
    if (s.ok()) {
      if (kdef) {
        *kdef = kernel;
      }
      return OkStatus();
    }
  }

  return errors::NotFound("Could not find KernelDef for op: ", node.op());
}

// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
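// A port qualifies when the node is not denylisted, its output is statically
// known to be small, and either the node is placed on CPU or its kernel pins
// that output to HostMemory; Identity-like nodes are checked recursively
// through their fanins.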
Status IsNodeOutputPortHostFriendly(const GraphView& graph,
                                    GraphProperties* properties,
                                    const NodeDef& node, int port_id,
                                    bool* is_candidate) {
  *is_candidate = false;

  // Make sure we are not a denylisted op.
  if (IsDenylisted(node)) {
    return OkStatus();
  }

  // Check to make sure we have the right properties (i.e., statically shaped).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  int output_properties_size = output_properties.size();
  if (port_id >= output_properties_size) {
    LOG(WARNING) << "port_id=" << port_id
                 << " but output_properties.size()=" << output_properties.size()
                 << "\n"
                 << node.DebugString();
    return OkStatus();
  }
  if (!IsTensorSmall(output_properties[port_id])) {
    return OkStatus();
  }

  // These nodes may be optimized away downstream (even if pinned to Host), we
  // should (recursively) check their source.
  if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
    for (const auto& fanin : graph.GetFanins(node, false)) {
      bool fanin_candidate = false;
      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
      if (!fanin_candidate) {
        return OkStatus();
      }
    }
    *is_candidate = true;
    return OkStatus();
  }

  // Check if op's device is on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return OkStatus();
  }

  // Check if op's output port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for: " << node.op();
    return OkStatus();
  }

  // Map the port_id to output_arg_id.
  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
  if (output_arg_id < 0) {
    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                 << node.DebugString() << "\n"
                 << op->DebugString();
    return OkStatus();
  }

  // Find the kernel.
  const KernelDef* kernel = nullptr;
  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
                       &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return OkStatus();
  }

  // Check if the output_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
      *is_candidate = true;
      break;
    }
  }

  return OkStatus();
}

// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
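// An input port qualifies when the node itself runs on CPU, or when its
// kernel pins that input to HostMemory.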
bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
  // If node is on Host, assume its inputs are Host friendly.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    return true;
  }

  // Check if op's input port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for: " << node.op();
    return false;
  }
  const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);

  // Find the kernel.
  const KernelDef* kernel = nullptr;
  s = internal::TryFindKernelDef(
      {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return false;
  }

  // Check if the input_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->input_arg(input_arg_id).name() == host_memory_arg) {
      return true;
    }
  }

  return false;
}

// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is denylisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (at the moment, friendly
//    means small, of an eligible dtype, and pinned to Host).
Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
                           const NodeDef& node, bool* is_candidate) {
  *is_candidate = false;

  // Check if node already on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return OkStatus();
  }

  // Skip these node types.
  if (IsDenylisted(node)) {
    return OkStatus();
  }

  // Check the node can be run on CPU.
  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
  if (!s.ok()) {
    return OkStatus();
  }

  // Check all inputs are Host friendly.
  for (const GraphView::OutputPort& fanin :
       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
    bool fanin_candidate = false;
    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
    if (!fanin_candidate) {
      return OkStatus();
    }
  }

  // Check all outputs are Host friendly.
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {
      return OkStatus();
    }
  }

  *is_candidate = true;
  return OkStatus();
}

// Tries to find a Host device from `devices`. Returns empty string if no
// matching Host device is found.
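// For example, given devices = {"/device:CPU:0", "/device:GPU:0"} and
// device = "/device:GPU:0", this returns "/device:CPU:0".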
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
                         bool has_device_cpu, const string& device) {
  // Force this node onto the CPU.
  if (device.empty() && has_device_cpu) {
    return "/device:CPU:0";
  } else if (absl::StrContains(device, DEVICE_GPU)) {
    // Sometimes the cluster can have:
    //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
    // and we need to handle them properly.
    for (const auto& device_match :
         {std::pair<string, string>("GPU", "CPU:0"),
          std::pair<string, string>("/device", "/device:CPU:0")}) {
      const string device_host =
          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
                          device_match.second);
      if (devices.find(device_host) != devices.end()) {
        return device_host;
      }
    }
  }

  // We couldn't find an appropriate Host device, return no device.
  return "";
}
}  // end namespace internal

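// Pins eligible small ops onto the Host device, then walks the swapped Const
// nodes and restores any whose consumers cannot read them from Host memory.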
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // Skip Legacy TPU bridge graphs.
  if (IsLegacyTPUBridgeGraphDef(*optimized_graph)) {
    return OkStatus();
  }

  GraphProperties properties(item);
  GraphView graph(optimized_graph);

  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));

  // All the Const nodes, and their original devices in topological order.
  std::vector<std::pair<NodeDef*, string>> const_nodes;

  for (auto& node : *optimized_graph->mutable_node()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    bool is_candidate = false;
    TF_RETURN_IF_ERROR(
        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
    if (!is_candidate) {
      continue;
    }

    string device =
        internal::TryFindHostDevice(devices, has_device_cpu, node.device());
    if (!device.empty()) {
      // Keep track of all Const nodes that we swapped.
      if (IsConstant(node)) {
        const_nodes.emplace_back(&node, node.device());
      }
      VLOG(2) << "Moving node " << node.name() << " to device " << device;
      *node.mutable_device() = std::move(device);
    }
  }

  // Traverse all `const_nodes`, and greedily map them back onto their
  // original device.
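  // Keeping a Const on the Host only pays off if every consumer reads it
  // through a HostMemory input; otherwise the placement would likely add a
  // Host-to-device copy at runtime, so the original device is restored.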
  for (auto& it : const_nodes) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = it.first;
    const string& device = it.second;

    // Check all the consumers of this node; if any of them cannot take this
    // input from Host memory, swap this node back onto the original device.
    for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
      // The consumer is not Host friendly, swap it back to the original device.
      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
                                                 fanout.port_id)) {
        VLOG(2) << "Swapping node " << node->name() << " back to device "
                << device;
        node->set_device(device);
        break;
      }
    }
  }
  return OkStatus();
}

}  // end namespace grappler
}  // end namespace tensorflow