/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"

#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr char kNHWC[] = "NHWC";
constexpr char kNCHW[] = "NCHW";
constexpr float kVoltaGPURatioThreshold = 0.5;
constexpr float kConvGPUFP16Threshold = 0.5;

struct MutableNodeViewFormatter {
  void operator()(std::string* out, utils::MutableNodeView* node_view) const {
    absl::StrAppend(out, node_view->node()->name());
  }
};

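// Returns the number of GPUs in the cluster and, of those, the number whose
// reported compute capability ("architecture") is 7.0 or higher, i.e. Volta
// or newer.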
inline std::pair<int, int> GetNumGPUs(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  int num_gpus = 0;
  int num_volta = 0;
  for (const auto& device : devices) {
    if (device.second.type() != kGPU) {
      continue;
    }
    num_gpus++;
    auto compute_capability_it =
        device.second.environment().find("architecture");
    if (compute_capability_it == device.second.environment().end()) {
      continue;
    }
    double compute_capability = 0.0;
    if (absl::SimpleAtod(compute_capability_it->second, &compute_capability) &&
        compute_capability >= 7.0) {
      num_volta++;
    }
  }
  return {num_gpus, num_volta};
}

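// Returns true if the fraction of Conv2D/Conv3D nodes placed on `device`
// whose "T" attribute equals `data_type` is at least kConvGPUFP16Threshold.
// Returns false if no such convolution is found.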
inline bool NumConvOnDeviceWithDataTypeOverThreshold(
    const TransposeContext& context, absl::string_view device,
    const DataType& data_type) {
  int num_conv_gpu = 0;
  int num_conv_gpu_fp16 = 0;

  for (const auto& node : context.graph_view->GetNodes()) {
    const auto* node_def = node.node();
    if (!IsConv2D(*node_def) && !IsConv3D(*node_def)) {
      continue;
    }
    const string& device_name = GetDeviceName(*node_def);
    string device_type;
    string task;
    if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device_type) ||
        !absl::StrContains(absl::AsciiStrToLower(device_type),
                           absl::AsciiStrToLower(device))) {
      continue;
    }
    num_conv_gpu++;
    const auto* t_attr = node.GetAttr("T");
    if (t_attr == nullptr) {
      continue;
    }
    if (t_attr->type() == data_type) {
      num_conv_gpu_fp16++;
    }
  }

  if (num_conv_gpu == 0) return false;

  return (static_cast<float>(num_conv_gpu_fp16) /
          static_cast<float>(num_conv_gpu)) >= kConvGPUFP16Threshold;
}

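// Chooses the (src, dst) data format pair for the rewrite. By default the
// graph is converted from NHWC to NCHW; the direction is reversed
// (NCHW -> NHWC) if NHWC is enforced, or if no layout is enforced, at least
// half of the GPUs are Volta or newer, and at least half of the GPU
// convolutions compute in DT_HALF.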
inline std::pair<string, string> GetSrcAndDstDataFormats(
    const TransposeContext& context, int num_gpus, int num_voltas) {
  string src_format = kNHWC;
  string dst_format = kNCHW;

  const bool is_NHWC_enforced =
      (!context.enforced_layout.empty() && context.enforced_layout == "NHWC");
  const bool should_swap =
      ((static_cast<float>(num_voltas) / static_cast<float>(num_gpus)) >=
       kVoltaGPURatioThreshold) &&
      NumConvOnDeviceWithDataTypeOverThreshold(context, kGPU, DT_HALF);
  // We swap only if NHWC is enforced, or if no layout is enforced and the
  // device configuration meets the thresholds.
  if (is_NHWC_enforced || (context.enforced_layout.empty() && should_swap)) {
    std::swap(src_format, dst_format);
  }

  return {src_format, dst_format};
}

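// Hands every layout-sensitive op to the transposer registered for it, which
// rewrites the op for the target layout and inserts the Transpose nodes it
// needs. Returns NOT_FOUND if a layout-sensitive op has no transposer.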
Status ExpandLayoutSensitiveOp(TransposeContext* context,
                               TransposerFactory* transposer_factory) {
  const int num_nodes = context->num_nodes;
  for (int i = 0; i < num_nodes; ++i) {
    auto* node_view = context->graph_view->GetNode(i);
    auto* node_def = node_view->node();
    if (IsLayoutSensitiveOp(*node_def)) {
      std::shared_ptr<Transposer> transposer =
          transposer_factory->GetTransposer(*node_def);
      if (transposer == nullptr) {
        return Status(
            error::NOT_FOUND,
            absl::StrCat(
                "Layout sensitive operation should have a transposer. Node: ",
                node_def->DebugString()));
      }
      TF_RETURN_IF_ERROR(transposer->TransposeNode(context, node_view));
    }
  }
  return OkStatus();
}

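// Hands every layout-agnostic op to its transposer so that the Transpose
// nodes introduced by ExpandLayoutSensitiveOp can be pushed through these ops
// and later cancelled out.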
Status ExpandLayoutAgnosticOp(TransposeContext* context,
                              TransposerFactory* transposer_factory) {
  const int num_nodes = context->num_nodes;
  for (int i = 0; i < num_nodes; ++i) {
    auto* node_view = context->graph_view->GetNode(i);
    auto* node_def = node_view->node();
    if (IsLayoutAgnosticOp(*node_def)) {
      const auto& transposer = transposer_factory->GetTransposer(*node_def);
      if (transposer == nullptr) {
        return Status(
            error::NOT_FOUND,
            absl::StrCat(
                "Layout agnostic operation should have a transposer. Node: ",
                node_def->DebugString()));
      }
      TF_RETURN_IF_ERROR(transposer->TransposeNode(context, node_view));
    }
  }
  return OkStatus();
}

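// Returns true if the two Transpose nodes have constant permutations that are
// inverses of each other. For example, a fanin permutation {0, 3, 1, 2}
// (NHWC -> NCHW) followed by a fanout permutation {0, 2, 3, 1}
// (NCHW -> NHWC) composes to the identity, so the pair can be removed.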
inline bool IsCancellableConstPermTransposeNodePair(
    const utils::MutableNodeView& fanout_transpose,
    const utils::MutableNodeView& fanin_transpose) {
  Tensor fanout_tensor;
  if (!GetValueAttrFromConstInputNode(fanout_transpose, IsTranspose, 1,
                                      &fanout_tensor)) {
    return false;
  }
  Tensor fanin_tensor;
  if (!GetValueAttrFromConstInputNode(fanin_transpose, IsTranspose, 1,
                                      &fanin_tensor)) {
    return false;
  }
  if (fanout_tensor.NumElements() != fanin_tensor.NumElements()) {
    return false;
  }

  // Using dst->src to permute on src->dst will result in
  // seq(0, ..., num_elements - 1) if they are cancellable.
  const auto& fanout_tensor_data = fanout_tensor.unaligned_flat<int32>();
  const auto& fanin_tensor_data = fanin_tensor.unaligned_flat<int32>();
  const int num_elements = fanout_tensor.NumElements();
  for (int i = 0; i < num_elements; ++i) {
    if (fanout_tensor_data(fanin_tensor_data(i)) != i) {
      return false;
    }
  }
  return true;
}

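// Returns true if the two data-format ops undo each other, i.e. each node's
// src_format matches the other node's dst_format.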
inline bool IsCancellableDataFormatNodePair(
    const utils::MutableNodeView& fanout_transpose,
    const utils::MutableNodeView& fanin_transpose) {
  if (!IsDataFormatOp(fanout_transpose) || !IsDataFormatOp(fanin_transpose)) {
    return false;
  }

  auto src_dst_match = [](const utils::MutableNodeView& src,
                          const utils::MutableNodeView& dst) {
    const auto* src_format = src.GetAttr(kAttrSrcFormat);
    if (src_format == nullptr) {
      return false;
    }
    const auto* dst_format = dst.GetAttr(kAttrDstFormat);
    if (dst_format == nullptr) {
      return false;
    }
    return src_format->s() == dst_format->s();
  };

  // If the src_format of node A is equal to the dst_format of node B, and the
  // dst_format of node A is equal to the src_format of node B, then they are
  // cancellable.
  return src_dst_match(fanin_transpose, fanout_transpose) &&
         src_dst_match(fanout_transpose, fanin_transpose);
}

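// Returns true if the pair cancels out, either as mutually inverse
// constant-permutation Transposes or as mutually inverse data-format ops.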
inline bool IsCancellableNodePair(
    const utils::MutableNodeView& fanout_transpose,
    const utils::MutableNodeView& fanin_transpose) {
  return IsCancellableConstPermTransposeNodePair(fanout_transpose,
                                                 fanin_transpose) ||
         IsCancellableDataFormatNodePair(fanout_transpose, fanin_transpose);
}

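// Removes back-to-back cancelling pairs among the newly added nodes (those
// with index >= context->num_nodes): the upstream input of the pair is
// forwarded directly to the consumers of the downstream node, and both nodes
// (plus their constant permutation inputs, if any) are deleted.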
Status EraseCancellableNodes(TransposeContext* context) {
  const int original_num_nodes = context->num_nodes;
  utils::MutableGraphView* graph_view = context->graph_view.get();
  utils::Mutation* mutation = graph_view->GetMutationBuilder();
  const int num_nodes = graph_view->NumNodes();

  for (int i = original_num_nodes; i < num_nodes; ++i) {
    auto* node = graph_view->GetNode(i);
    if (node->NumRegularFanins() < 1) {
      continue;
    }
    const auto& regular_fanin_0 = node->GetRegularFanin(0);
    auto* fanin_node = regular_fanin_0.node_view();
    // TODO(lyandy): Lift restriction once original nodes in the graph can be
    // pruned away.
    if (fanin_node->node_index() < original_num_nodes) {
      continue;
    }
    if (!IsCancellableNodePair(*node, *fanin_node)) {
      continue;
    }
    const auto& fanin_to_forward = fanin_node->GetRegularFanin(0);
    TensorId fanin_id_to_forward(fanin_to_forward.node_view()->GetName(),
                                 fanin_to_forward.index());
    for (const auto& regular_fanout : node->GetRegularFanout(0)) {
      mutation->AddOrUpdateRegularFanin(regular_fanout.node_view(),
                                        regular_fanout.index(),
                                        fanin_id_to_forward);
    }
    mutation->RemoveNode(node);
    if (node->NumRegularFanins() > 1) {
      mutation->RemoveNode(node->GetRegularFanin(1).node_view());
    }
    mutation->RemoveNode(fanin_node);
    if (fanin_node->NumRegularFanins() > 1) {
      mutation->RemoveNode(fanin_node->GetRegularFanin(1).node_view());
    }
  }
  return mutation->Apply();
}

// TODO(ezhulenev): This is a temporary workaround for a graph pattern
// in Resnet models. We should be able to push down transpose nodes across Pad
// and many other ops, and then rely on cancellation to remove them.
//
// From: Transpose[NHWC->NCHW] -> Pad[paddings] -> Transpose[NCHW->NHWC]
// To:   Pad[Permute(paddings)]
Status EraseCancellableNodesAroundPad(TransposeContext* context) {
  utils::MutableGraphView* graph_view = context->graph_view.get();
  utils::Mutation* mutation = graph_view->GetMutationBuilder();

  absl::flat_hash_set<utils::MutableNodeView*> cancelled_transposes;

  const int num_nodes = graph_view->NumNodes();
  for (int i = 0; i < num_nodes; ++i) {
    // Transpose node after Pad.
    auto* transpose_after = graph_view->GetNode(i);
    if (!IsTranspose(*transpose_after->node())) continue;

    // This transpose was already cancelled in a previous loop iteration.
    if (cancelled_transposes.contains(transpose_after)) continue;

    // Pad node.
    const auto& transpose_after_fanin = transpose_after->GetRegularFanin(0);
    auto* pad = transpose_after_fanin.node_view();
    if (!IsPad(*pad->node())) continue;

    // Transpose node before Pad.
    const auto& pad_fanin_0 = pad->GetRegularFanin(0);
    auto* transpose_before = pad_fanin_0.node_view();
    if (!IsTranspose(*transpose_before->node())) continue;

    // The output of transpose_before is consumed only by the Pad node.
    if (transpose_before->NumRegularFanouts() != 1) continue;

    // Transposes are cancellable.
    if (!IsCancellableConstPermTransposeNodePair(*transpose_after,
                                                 *transpose_before))
      continue;

    // Paddings are known constant values.
    Tensor paddings_t;
    if (!GetValueAttrFromConstInputNode(*pad, IsPad, 1, &paddings_t)) continue;

    // The paddings value is consumed only by the Pad node.
    const auto& pad_fanin_1 = pad->GetRegularFanin(1);
    auto* paddings = pad_fanin_1.node_view();
    if (paddings->NumRegularFanouts() != 1) continue;

    // Get permutation after the padding.
    Tensor permute_t;
    if (!GetValueAttrFromConstInputNode(*transpose_after, IsTranspose, 1,
                                        &permute_t))
      continue;

    // Pad output might be used multiple times by different Transpose nodes.
    // If they all have identical permutations, we can cancel all of them.
    std::vector<utils::MutableNodeView*> pad_fanout_transposes;
    pad_fanout_transposes.emplace_back(transpose_after);

    bool pad_has_unsupported_fanout = false;
    for (auto& fanout : pad->GetRegularFanout(0)) {
      auto* extra_transpose = fanout.node_view();
      if (extra_transpose == transpose_after) continue;

      // Check that fanout is a Transpose identical to transpose_after.
      Tensor extra_permute_t;
      if (!GetValueAttrFromConstInputNode(*extra_transpose, IsTranspose, 1,
                                          &extra_permute_t) ||
          extra_permute_t.tensor_data() != permute_t.tensor_data()) {
        pad_has_unsupported_fanout = true;
        break;
      }

      pad_fanout_transposes.emplace_back(extra_transpose);
    }
    if (pad_has_unsupported_fanout) continue;

    VLOG(0) << "Cancel Transpose nodes around Pad:"
            << " transpose_before=" << transpose_before->node()->name()
            << " pad=" << pad->node()->name() << " transpose_after="
            << absl::StrJoin(pad_fanout_transposes, ",",
                             MutableNodeViewFormatter());

    // Permute paddings in place according to the permutation in the second
    // transpose.
    auto permutation_s = absl::Span<int32>(permute_t.flat<int32>().data(),
                                           permute_t.NumElements());
    auto paddings_s = absl::Span<int32>(paddings_t.flat<int32>().data(),
                                        paddings_t.NumElements());
    TF_RETURN_IF_ERROR(
        PermuteDouble(absl::StrCat("paddings in ", pad->GetName()),
                      permutation_s, &paddings_s));

    // Update paddings constant value with the permuted tensor.
    AttrValue permuted_paddings_tensor;
    paddings_t.AsProtoTensorContent(permuted_paddings_tensor.mutable_tensor());
    mutation->AddOrUpdateNodeAttr(paddings, "value", permuted_paddings_tensor);

    // Transform Transpose nodes into Identity nodes.
    const auto transpose_to_identity =
        [&cancelled_transposes,
         &mutation](utils::MutableNodeView* transpose) -> void {
      mutation->UpdateNodeOp(transpose, "Identity");
      mutation->RemoveNodeAttr(transpose, "Tperm");
      mutation->RemoveRegularFanin(transpose, 1);
      cancelled_transposes.insert(transpose);
    };

    transpose_to_identity(transpose_before);
    absl::c_for_each(pad_fanout_transposes, transpose_to_identity);
  }

  return mutation->Apply();
}

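// Removes the kAttrOutputShape attribute from every node except _Arg nodes;
// the recorded output shapes may no longer be valid after the layout rewrite.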
Status EraseOutputShapeAttrs(TransposeContext* context) {
  utils::MutableGraphView* graph_view = context->graph_view.get();
  utils::Mutation* mutation = graph_view->GetMutationBuilder();
  const int num_nodes = graph_view->NumNodes();
  for (int i = 0; i < num_nodes; ++i) {
    auto* node = graph_view->GetNode(i);
    if (IsArg(*node->node())) {
      continue;
    }
    mutation->RemoveNodeAttr(node, kAttrOutputShape);
    TF_RETURN_IF_ERROR(mutation->Apply());
  }
  return OkStatus();
}

}  // namespace

// When there is a GPU, the computation graph is converted to NCHW format.
// When there is only a CPU, no conversion happens by default unless the user
// has chosen to convert the graph to a desired format. Currently, NCHW ->
// NHWC format conversion is available on CPU.
Status GenericLayoutOptimizer::Optimize(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* output) {
  if (cluster == nullptr) {
    LOG(WARNING)
        << "generic layout optimizer was called with cluster == nullptr";
    return errors::Aborted("cluster == nullptr.");
  }
  if (!enforced_layout_.empty() && enforced_layout_ != "NHWC" &&
      enforced_layout_ != "NCHW") {
    return Status(
        tensorflow::error::Code::INVALID_ARGUMENT,
        absl::StrCat("Invalid value for enforced_layout: ", enforced_layout_,
                     ". Supported layouts: 'NHWC', 'NCHW'."));
  }
  const auto num_gpus_and_num_volta = GetNumGPUs(*cluster);
  const int num_gpus = num_gpus_and_num_volta.first;

  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;

  TransposeContext context;
  context.enforced_layout = enforced_layout_;

  if (num_gpus > 0) {
    TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
        /*assume_valid_feeds=*/is_aggressive, item, cluster, &context));

    const auto src_dst_formats = GetSrcAndDstDataFormats(
        context, num_gpus, num_gpus_and_num_volta.second);
    context.AssignDeviceAndDataFormats(kGPU, src_dst_formats.first,
                                       src_dst_formats.second);
  } else {
    TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
        /*assume_valid_feeds=*/is_aggressive, item, cluster, &context));
    switch (cpu_layout_conversion_) {
      case RewriterConfig::NCHW_TO_NHWC:
        context.AssignDeviceAndDataFormats(kCPU, kNCHW, kNHWC);
        break;
      // TODO(intel-tf): Add functionality for NHWC_TO_NCHW layout conversion
      // on CPU.
      case RewriterConfig::NHWC_TO_NCHW:
        return errors::Aborted(
            "Conversion from NHWC to NCHW is currently not available for "
            "CPU.");
      default:
        *output = item.graph;
        VLOG(2) << "No layout conversion will take place for CPU.";
        return OkStatus();
    }
  }

  TransposerFactory transposer_factory;
  TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory));
  if (context.graph.node_size() > context.num_nodes || is_aggressive) {
    TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory));
    TF_RETURN_IF_ERROR(EraseCancellableNodes(&context));
    TF_RETURN_IF_ERROR(EraseCancellableNodesAroundPad(&context));
    // TODO(lyandy): Remove sorting once other optimizers are migrated to using
    // `utils::GraphView`.
    TF_RETURN_IF_ERROR(
        context.graph_view->SortTopologically(/*ignore_cycles=*/false, {}));
  }
  TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context));

  *output = context.graph;
  return OkStatus();
}

}  // end namespace grappler
}  // end namespace tensorflow