/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"

#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr char kNHWC[] = "NHWC";
constexpr char kNCHW[] = "NCHW";
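// Thresholds gating the swap to NHWC on GPU: the graph is converted to NHWC
// (instead of the default NCHW) when at least kVoltaGPURatioThreshold of the
// GPUs are Volta (compute capability 7.0) or newer and at least
// kConvGPUFP16Threshold of the GPU convolutions compute in DT_HALF, since
// fp16 convolutions on Tensor Cores favor the NHWC layout.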
constexpr float kVoltaGPURatioThreshold = 0.5;
constexpr float kConvGPUFP16Threshold = 0.5;

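// Formats a MutableNodeView as its node name, for use with absl::StrJoin.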
struct MutableNodeViewFormatter {
  void operator()(std::string* out, utils::MutableNodeView* node_view) const {
    absl::StrAppend(out, node_view->node()->name());
  }
};

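// Returns {number of GPUs, number of Volta-or-newer GPUs} in the cluster. A
// GPU counts as Volta or newer when its "architecture" environment entry
// parses to a compute capability >= 7.0.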
inline std::pair<int, int> GetNumGPUs(const Cluster& cluster) {
  auto devices = cluster.GetDevices();
  int num_gpus = 0;
  int num_volta = 0;
  for (const auto& device : devices) {
    if (device.second.type() != kGPU) {
      continue;
    }
    num_gpus++;
    auto compute_capability_it =
        device.second.environment().find("architecture");
    if (compute_capability_it == device.second.environment().end()) {
      continue;
    }
    double compute_capability = 0.0;
    if (absl::SimpleAtod(compute_capability_it->second, &compute_capability) &&
        compute_capability >= 7.0) {
      num_volta++;
    }
  }
  return {num_gpus, num_volta};
}

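// Returns true if the fraction of Conv2D/Conv3D nodes placed on `device`
// whose "T" attribute matches `data_type` meets kConvGPUFP16Threshold.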
inline bool NumConvOnDeviceWithDataTypeOverThreshold(
    const TransposeContext& context, absl::string_view device,
    const DataType& data_type) {
  int num_conv_gpu = 0;
  int num_conv_gpu_fp16 = 0;

  for (const auto& node : context.graph_view->GetNodes()) {
    const auto* node_def = node.node();
    if (!IsConv2D(*node_def) && !IsConv3D(*node_def)) {
      continue;
    }
    const string& device_name = GetDeviceName(*node_def);
    string device_type;
    string task;
    if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device_type) ||
        !absl::StrContains(absl::AsciiStrToLower(device_type),
                           absl::AsciiStrToLower(device))) {
      continue;
    }
    num_conv_gpu++;
    const auto* t_attr = node.GetAttr("T");
    if (t_attr == nullptr) {
      continue;
    }
    if (t_attr->type() == data_type) {
      num_conv_gpu_fp16++;
    }
  }

  if (num_conv_gpu == 0) return false;

  return (static_cast<float>(num_conv_gpu_fp16) /
          static_cast<float>(num_conv_gpu)) >= kConvGPUFP16Threshold;
}

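// Picks the source and destination data formats for the rewrite. The default
// is NHWC -> NCHW; the direction is reversed (NCHW -> NHWC) when NHWC is
// explicitly enforced, or when no layout is enforced and the Volta and fp16
// convolution thresholds above are both met.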
inline std::pair<string, string> GetSrcAndDstDataFormats(
    const TransposeContext& context, int num_gpus, int num_voltas) {
  string src_format = kNHWC;
  string dst_format = kNCHW;

  const bool is_NHWC_enforced =
      (!context.enforced_layout.empty() && context.enforced_layout == "NHWC");
  const bool should_swap =
      ((static_cast<float>(num_voltas) / static_cast<float>(num_gpus)) >=
       kVoltaGPURatioThreshold) &&
      NumConvOnDeviceWithDataTypeOverThreshold(context, kGPU, DT_HALF);
  // Swap only if NHWC is enforced, or if no layout is enforced and the device
  // configuration meets both thresholds.
  if (is_NHWC_enforced || (context.enforced_layout.empty() && should_swap)) {
    std::swap(src_format, dst_format);
  }

  return {src_format, dst_format};
}

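// Runs the matching transposer on every layout-sensitive op (e.g. Conv2D),
// rewriting its data format attributes and inserting Transpose nodes around
// it as needed.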
Status ExpandLayoutSensitiveOp(TransposeContext* context,
                               TransposerFactory* transposer_factory) {
  const int num_nodes = context->num_nodes;
  for (int i = 0; i < num_nodes; ++i) {
    auto* node_view = context->graph_view->GetNode(i);
    auto* node_def = node_view->node();
    if (IsLayoutSensitiveOp(*node_def)) {
      std::shared_ptr<Transposer> transposer =
          transposer_factory->GetTransposer(*node_def);
      if (transposer == nullptr) {
        return Status(
            error::NOT_FOUND,
            absl::StrCat(
                "Layout sensitive operation should have a transposer. Node: ",
                node_def->DebugString()));
      }
      TF_RETURN_IF_ERROR(transposer->TransposeNode(context, node_view));
    }
  }
  return OkStatus();
}

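// Pushes the Transpose nodes introduced above through layout-agnostic ops
// (e.g. element-wise ops), so that back-to-back transposes can later cancel.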
Status ExpandLayoutAgnosticOp(TransposeContext* context,
                              TransposerFactory* transposer_factory) {
  const int num_nodes = context->num_nodes;
  for (int i = 0; i < num_nodes; ++i) {
    auto* node_view = context->graph_view->GetNode(i);
    auto* node_def = node_view->node();
    if (IsLayoutAgnosticOp(*node_def)) {
      const auto& transposer = transposer_factory->GetTransposer(*node_def);
      if (transposer == nullptr) {
        return Status(
            error::NOT_FOUND,
            absl::StrCat(
                "Layout agnostic operation should have a transposer. Node: ",
                node_def->DebugString()));
      }
      TF_RETURN_IF_ERROR(transposer->TransposeNode(context, node_view));
    }
  }
  return OkStatus();
}

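// Returns true if the two Transpose nodes have constant permutations that are
// inverses of each other, e.g. [0, 3, 1, 2] (NHWC->NCHW) followed by
// [0, 2, 3, 1] (NCHW->NHWC).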
inline bool IsCancellableConstPermTransposeNodePair(
    const utils::MutableNodeView& fanout_transpose,
    const utils::MutableNodeView& fanin_transpose) {
  Tensor fanout_tensor;
  if (!GetValueAttrFromConstInputNode(fanout_transpose, IsTranspose, 1,
                                      &fanout_tensor)) {
    return false;
  }
  Tensor fanin_tensor;
  if (!GetValueAttrFromConstInputNode(fanin_transpose, IsTranspose, 1,
                                      &fanin_tensor)) {
    return false;
  }
  if (fanout_tensor.NumElements() != fanin_tensor.NumElements()) {
    return false;
  }

  // Using dst->src to permute on src->dst will result in
  // seq(0, ..., num_elements - 1) if they are cancellable.
  const auto& fanout_tensor_data = fanout_tensor.unaligned_flat<int32>();
  const auto& fanin_tensor_data = fanin_tensor.unaligned_flat<int32>();
  const int num_elements = fanout_tensor.NumElements();
  for (int i = 0; i < num_elements; ++i) {
    if (fanout_tensor_data(fanin_tensor_data(i)) != i) {
      return false;
    }
  }
  return true;
}

inline bool IsCancellableDataFormatNodePair(
    const utils::MutableNodeView& fanout_transpose,
    const utils::MutableNodeView& fanin_transpose) {
  if (!IsDataFormatOp(fanout_transpose) || !IsDataFormatOp(fanin_transpose)) {
    return false;
  }

  auto src_dst_match = [](const utils::MutableNodeView& src,
                          const utils::MutableNodeView& dst) {
    const auto* src_format = src.GetAttr(kAttrSrcFormat);
    if (src_format == nullptr) {
      return false;
    }
    const auto* dst_format = dst.GetAttr(kAttrDstFormat);
    if (dst_format == nullptr) {
      return false;
    }
    return src_format->s() == dst_format->s();
  };

  // If the src_format of node A equals the dst_format of node B, and the
  // dst_format of node A equals the src_format of node B, then the two nodes
  // are cancellable.
  return src_dst_match(fanin_transpose, fanout_transpose) &&
         src_dst_match(fanout_transpose, fanin_transpose);
}

inline bool IsCancellableNodePair(
    const utils::MutableNodeView& fanout_transpose,
    const utils::MutableNodeView& fanin_transpose) {
  return IsCancellableConstPermTransposeNodePair(fanout_transpose,
                                                 fanin_transpose) ||
         IsCancellableDataFormatNodePair(fanout_transpose, fanin_transpose);
}

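// Removes pairs of cancelling nodes inserted by this pass (only nodes with
// index >= the original node count are examined). For each cancellable pair,
// the fanin of the inner node is forwarded to all fanouts of the outer node,
// and both nodes (plus their constant permutation inputs, if any) are
// removed.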
Status EraseCancellableNodes(TransposeContext* context) {
  const int original_num_nodes = context->num_nodes;
  utils::MutableGraphView* graph_view = context->graph_view.get();
  utils::Mutation* mutation = graph_view->GetMutationBuilder();
  const int num_nodes = graph_view->NumNodes();

  for (int i = original_num_nodes; i < num_nodes; ++i) {
    auto* node = graph_view->GetNode(i);
    if (node->NumRegularFanins() < 1) {
      continue;
    }
    const auto& regular_fanin_0 = node->GetRegularFanin(0);
    auto* fanin_node = regular_fanin_0.node_view();
    // TODO(lyandy): Lift restriction once original nodes in the graph can be
    // pruned away.
    if (fanin_node->node_index() < original_num_nodes) {
      continue;
    }
    if (!IsCancellableNodePair(*node, *fanin_node)) {
      continue;
    }
    const auto& fanin_to_forward = fanin_node->GetRegularFanin(0);
    TensorId fanin_id_to_forward(fanin_to_forward.node_view()->GetName(),
                                 fanin_to_forward.index());
    for (const auto& regular_fanout : node->GetRegularFanout(0)) {
      mutation->AddOrUpdateRegularFanin(regular_fanout.node_view(),
                                        regular_fanout.index(),
                                        fanin_id_to_forward);
    }
    mutation->RemoveNode(node);
    if (node->NumRegularFanins() > 1) {
      mutation->RemoveNode(node->GetRegularFanin(1).node_view());
    }
    mutation->RemoveNode(fanin_node);
    if (fanin_node->NumRegularFanins() > 1) {
      mutation->RemoveNode(fanin_node->GetRegularFanin(1).node_view());
    }
  }
  return mutation->Apply();
}

// TODO(ezhulenev): This is a temporary workaround for a graph pattern
// in Resnet models. We should be able to push down transpose nodes across Pad
// and many other ops, and then rely on cancellation to remove them.
//
// From: Transpose[NHWC->NCHW] -> Pad[paddings] -> Transpose[NCHW->NHWC]
// To:   Pad[Permute(paddings)]
Status EraseCancellableNodesAroundPad(TransposeContext* context) {
  utils::MutableGraphView* graph_view = context->graph_view.get();
  utils::Mutation* mutation = graph_view->GetMutationBuilder();

  absl::flat_hash_set<utils::MutableNodeView*> cancelled_transposes;

  const int num_nodes = graph_view->NumNodes();
  for (int i = 0; i < num_nodes; ++i) {
    // Transpose node after Pad.
    auto* transpose_after = graph_view->GetNode(i);
    if (!IsTranspose(*transpose_after->node())) continue;

    // This transpose was already cancelled in a previous loop iteration.
    if (cancelled_transposes.contains(transpose_after)) continue;

    // Pad node.
    const auto& transpose_after_fanin = transpose_after->GetRegularFanin(0);
    auto* pad = transpose_after_fanin.node_view();
    if (!IsPad(*pad->node())) continue;

    // Transpose node before Pad.
    const auto& pad_fanin_0 = pad->GetRegularFanin(0);
    auto* transpose_before = pad_fanin_0.node_view();
    if (!IsTranspose(*transpose_before->node())) continue;

    // The output of the Transpose before Pad must be used only by the Pad
    // node.
    if (transpose_before->NumRegularFanouts() != 1) continue;

    // The two Transposes must be cancellable.
    if (!IsCancellableConstPermTransposeNodePair(*transpose_after,
                                                 *transpose_before))
      continue;

    // Paddings must be known constant values.
    Tensor paddings_t;
    if (!GetValueAttrFromConstInputNode(*pad, IsPad, 1, &paddings_t)) continue;

    // The paddings value must be used by the Pad node only.
    const auto& pad_fanin_1 = pad->GetRegularFanin(1);
    auto* paddings = pad_fanin_1.node_view();
    if (paddings->NumRegularFanouts() != 1) continue;

    // Get the permutation after the padding.
    Tensor permute_t;
    if (!GetValueAttrFromConstInputNode(*transpose_after, IsTranspose, 1,
                                        &permute_t))
      continue;

    // Pad output might be used multiple times by different Transpose nodes.
    // If they all have an identical permutation, we can cancel all of them.
    std::vector<utils::MutableNodeView*> pad_fanout_transposes;
    pad_fanout_transposes.emplace_back(transpose_after);

    bool pad_has_unsupported_fanout = false;
    for (auto& fanout : pad->GetRegularFanout(0)) {
      auto* extra_transpose = fanout.node_view();
      if (extra_transpose == transpose_after) continue;

      // Check that the fanout is a Transpose identical to transpose_after.
      Tensor extra_permute_t;
      if (!GetValueAttrFromConstInputNode(*extra_transpose, IsTranspose, 1,
                                          &extra_permute_t) ||
          extra_permute_t.tensor_data() != permute_t.tensor_data()) {
        pad_has_unsupported_fanout = true;
        break;
      }

      pad_fanout_transposes.emplace_back(extra_transpose);
    }
    if (pad_has_unsupported_fanout) continue;

    VLOG(0) << "Cancel Transpose nodes around Pad:"
            << " transpose_before=" << transpose_before->node()->name()
            << " pad=" << pad->node()->name() << " transpose_after="
            << absl::StrJoin(pad_fanout_transposes, ",",
                             MutableNodeViewFormatter());

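    // Note: paddings is a [rank, 2] tensor of (pad_before, pad_after) pairs;
    // PermuteDouble below moves each pair as a unit under the permutation.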
    // Permute paddings in place according to permutation in second transpose.
    auto permutation_s = absl::Span<int32>(permute_t.flat<int32>().data(),
                                           permute_t.NumElements());
    auto paddings_s = absl::Span<int32>(paddings_t.flat<int32>().data(),
                                        paddings_t.NumElements());
    TF_RETURN_IF_ERROR(
        PermuteDouble(absl::StrCat("paddings in ", pad->GetName()),
                      permutation_s, &paddings_s));

    // Update paddings constant value with a permuted tensor.
    AttrValue permuted_paddings_tensor;
    paddings_t.AsProtoTensorContent(permuted_paddings_tensor.mutable_tensor());
    mutation->AddOrUpdateNodeAttr(paddings, "value", permuted_paddings_tensor);

    // Transform Transpose nodes into Identity nodes.
    const auto transpose_to_identity =
        [&cancelled_transposes,
         &mutation](utils::MutableNodeView* transpose) -> void {
      mutation->UpdateNodeOp(transpose, "Identity");
      mutation->RemoveNodeAttr(transpose, "Tperm");
      mutation->RemoveRegularFanin(transpose, 1);
      cancelled_transposes.insert(transpose);
    };

    transpose_to_identity(transpose_before);
    absl::c_for_each(pad_fanout_transposes, transpose_to_identity);
  }

  return mutation->Apply();
}

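// Removes the inferred _output_shapes attr, which the layout changes have
// made stale, from every node except _Arg nodes.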
Status EraseOutputShapeAttrs(TransposeContext* context) {
  utils::MutableGraphView* graph_view = context->graph_view.get();
  utils::Mutation* mutation = graph_view->GetMutationBuilder();
  const int num_nodes = graph_view->NumNodes();
  for (int i = 0; i < num_nodes; ++i) {
    auto* node = graph_view->GetNode(i);
    if (IsArg(*node->node())) {
      continue;
    }
    mutation->RemoveNodeAttr(node, kAttrOutputShape);
    TF_RETURN_IF_ERROR(mutation->Apply());
  }
  return OkStatus();
}

}  // namespace

// When there is a GPU, the computation graph is converted to NCHW format.
// When there are only CPUs, no conversion is performed by default, unless the
// user chose to convert the graph to a desired format. Currently, NCHW ->
// NHWC format conversion is available on CPU.
Status GenericLayoutOptimizer::Optimize(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* output) {
  if (cluster == nullptr) {
    LOG(WARNING)
        << "generic layout optimizer was called with cluster == nullptr";
    return errors::Aborted("cluster == nullptr.");
  }
  if (!enforced_layout_.empty() && enforced_layout_ != "NHWC" &&
      enforced_layout_ != "NCHW") {
    return Status(
        tensorflow::error::Code::INVALID_ARGUMENT,
        absl::StrCat("Invalid value for enforced_layout: ", enforced_layout_,
                     ". Supported layouts: 'NHWC', 'NCHW'."));
  }
  const auto num_gpus_and_num_volta = GetNumGPUs(*cluster);
  const int num_gpus = num_gpus_and_num_volta.first;

  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;

  TransposeContext context;
  context.enforced_layout = enforced_layout_;

  if (num_gpus > 0) {
    TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
        /*assume_valid_feeds=*/is_aggressive, item, cluster, &context));

    const auto src_dst_formats = GetSrcAndDstDataFormats(
        context, num_gpus, num_gpus_and_num_volta.second);
    context.AssignDeviceAndDataFormats(kGPU, src_dst_formats.first,
                                       src_dst_formats.second);
  } else {
    TF_RETURN_IF_ERROR(TransposeContext::InitializeTransposeContext(
        /*assume_valid_feeds=*/is_aggressive, item, cluster, &context));
    switch (cpu_layout_conversion_) {
      case RewriterConfig::NCHW_TO_NHWC:
        context.AssignDeviceAndDataFormats(kCPU, kNCHW, kNHWC);
        break;
      // TODO(intel-tf): Add functionality for NHWC_TO_NCHW layout conversion
      // on CPU.
      case RewriterConfig::NHWC_TO_NCHW:
        return errors::Aborted(
            "Conversion from NHWC to NCHW is currently not available for "
            "CPU.");
      default:
        *output = item.graph;
        VLOG(2) << "No layout conversion will take place for CPU.";
        return OkStatus();
    }
  }

  TransposerFactory transposer_factory;
  TF_RETURN_IF_ERROR(ExpandLayoutSensitiveOp(&context, &transposer_factory));
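  // The rest of the pipeline is only needed if the layout-sensitive expansion
  // above actually inserted new (Transpose) nodes, or in aggressive mode.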
  if (context.graph.node_size() > context.num_nodes || is_aggressive) {
    TF_RETURN_IF_ERROR(ExpandLayoutAgnosticOp(&context, &transposer_factory));
    TF_RETURN_IF_ERROR(EraseCancellableNodes(&context));
    TF_RETURN_IF_ERROR(EraseCancellableNodesAroundPad(&context));
    // TODO(lyandy): Remove sorting once other optimizers are migrated to using
    // `utils::GraphView`.
    TF_RETURN_IF_ERROR(
        context.graph_view->SortTopologically(/*ignore_cycles=*/false, {}));
  }
  TF_RETURN_IF_ERROR(EraseOutputShapeAttrs(&context));

  *output = context.graph;
  return OkStatus();
}

}  // end namespace grappler
}  // end namespace tensorflow