/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr char kOptimizedSuffix[] = "LayoutOptimizer";
constexpr char kAttrKSize[] = "ksize";
constexpr char kAttrStrides[] = "strides";
constexpr char kAttrDilations[] = "dilations";
constexpr char kAttrExplicitPaddings[] = "explicit_paddings";
constexpr char kAttrDataFormat[] = "data_format";
constexpr char kAttrIsTraining[] = "is_training";
constexpr char kAttrValue[] = "value";
constexpr char kAttrN[] = "N";
constexpr char kAttrT[] = "T";
constexpr char kAttrNumSplit[] = "num_split";
constexpr char kAttrNumOuts[] = "num_outs";
constexpr char kAttrKeepDims[] = "keep_dims";
constexpr char kAttrSqueezeDims[] = "squeeze_dims";
constexpr char kOpTranspose[] = "Transpose";
constexpr char kOpDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kOpDataFormatDimMap[] = "DataFormatDimMap";
constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst";
constexpr int kRank = 4;
constexpr int kUnknownRank = -1;
constexpr int kInvalidRank = -2;

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format,
                                bool* missing) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s() == src_data_format;
  }
  *missing = true;
  return false;
}

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format) {
  bool missing = false;
  return AttrDataFormatMatch(node, src_data_format, &missing);
}

bool IsNonFloatingConv2D(const utils::MutableNodeView& node) {
  if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) {
    const auto* attr = node.GetAttr(kAttrT);
    if (attr != nullptr) {
      return !kDataTypeIsFloating.Contains(attr->type());
    }
  }
  return false;
}

// Utils for layout agnostic transposer.

bool IsComparisonOp(const NodeDef& node) {
  bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
                    IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
                    IsLessEqual(node) || IsNotEqual(node);
  return is_compare;
}

std::vector<int> GetRegularFaninPorts(const utils::MutableNodeView& node) {
  const int num_regular_fanins = node.NumRegularFanins();
  std::vector<int> values(num_regular_fanins);
  std::iota(values.begin(), values.end(), 0);
  return values;
}

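// Returns the ports of the data inputs to a concat node. For "Concat" the
// concat axis is input 0, so the N data inputs start at port 1; otherwise
// (ConcatV2) the N data inputs start at port 0 and the axis comes last.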
std::vector<int> GetConcatDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* n_attr = node.GetAttr(kAttrN);
  const int n = n_attr != nullptr ? n_attr->i() : 0;
  const int start = (node.GetOp() == "Concat") ? 1 : 0;
  const int end = start + n;
  std::vector<int> values(end - start);
  std::iota(values.begin(), values.end(), start);
  return values;
}

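// Orders fanouts by node name and then by port index, so that the names of
// transpose nodes generated from the sorted order are deterministic.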
struct ComparatorByNodeNameAndIndex {
  bool operator()(const utils::MutableFaninView& node1,
                  const utils::MutableFaninView& node2) const {
    auto* node1_view = node1.node_view();
    auto* node2_view = node2.node_view();
    auto name_compare = node1_view->GetName().compare(node2_view->GetName());
    if (name_compare == 0) {
      return node1.index() < node2.index();
    }
    return name_compare < 0;
  }
};

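// Returns true if the output at `output_port` of `node` is in host memory.
// Conservatively returns true when no kernel is registered for the node on
// its device.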
bool IsHostMemory(const NodeDef& node, int output_port) {
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    DeviceType device_type(parsed_name.type);
    Status s = FindKernelDef(device_type, node, nullptr, nullptr);
    if (s.ok()) {
      tensorflow::MemoryTypeVector in_mtypes;
      tensorflow::MemoryTypeVector out_mtypes;
      s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
                                         node, &in_mtypes, &out_mtypes);
      if (s.ok()) {
        if (out_mtypes[output_port] == HOST_MEMORY) {
          return true;
        }
      }
    } else {
      return true;
    }
  }
  return false;
}

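// Maps each dimension label (e.g. 'N', 'H', 'W', 'C') to its index via
// `dim_indices`, preserving the order of `labels`.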
std::vector<int> GetDimensionIndicesFromLabel(
    const absl::flat_hash_map<char, int>& dim_indices,
    absl::Span<const char> labels) {
  std::vector<int> indices;
  indices.reserve(labels.size());
  for (const auto& label : labels) {
    indices.push_back(dim_indices.at(label));
  }
  return indices;
}

// RAII-styled object for keeping track of 4D to 5D data format
// upgrade/conversion. Currently only NHWC -> NDHWC and NCHW -> NCDHW are
// supported.
class ScopedDataFormatUpgrader {
 public:
  ScopedDataFormatUpgrader(TransposeContext* context, int rank)
      : context_(context) {
    if (rank == 5 && IsSupportedDataFormat(context_->src_format) &&
        IsSupportedDataFormat(context_->dst_format)) {
      old_src_format_ = context_->src_format;
      old_dst_format_ = context_->dst_format;
      std::string new_src_format = GetUpgradedDataFormat(context_->src_format);
      std::string new_dst_format = GetUpgradedDataFormat(context_->dst_format);
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           new_src_format, new_dst_format);
      upgraded_ = true;
    }
  }

  ScopedDataFormatUpgrader(const ScopedDataFormatUpgrader&) = delete;
  ScopedDataFormatUpgrader& operator=(const ScopedDataFormatUpgrader&) = delete;

  ~ScopedDataFormatUpgrader() {
    if (upgraded_) {
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           old_src_format_, old_dst_format_);
    }
  }

 private:
  bool IsSupportedDataFormat(absl::string_view data_format) {
    return data_format == "NHWC" || data_format == "NCHW";
  }

  std::string GetUpgradedDataFormat(absl::string_view data_format) {
    if (data_format == "NHWC") {
      return "NDHWC";
    }

    DCHECK_EQ(data_format, "NCHW");
    return "NCDHW";
  }

  TransposeContext* context_ = nullptr;
  bool upgraded_ = false;
  std::string old_src_format_;
  std::string old_dst_format_;
};

}  // namespace

// TransposeContext.

Status TransposeContext::InitializeTransposeContext(bool assume_valid_feeds,
                                                    const GrapplerItem& item,
                                                    const Cluster* cluster,
                                                    TransposeContext* context) {
  DCHECK(context != nullptr);
  context->graph_properties = std::make_unique<GraphProperties>(item);
  TF_RETURN_IF_ERROR(
      context->graph_properties->InferStatically(assume_valid_feeds));
  TF_RETURN_IF_ERROR(
      context->graph_properties->AnnotateOutputShapes(&context->graph));
  Status status;
  context->graph_view =
      std::make_unique<utils::MutableGraphView>(&context->graph, &status);
  TF_RETURN_IF_ERROR(status);
  context->num_nodes = context->graph.node_size();
  const auto& nodes_to_preserve = item.NodesToPreserve();
  context->nodes_to_preserve = absl::flat_hash_set<string>(
      nodes_to_preserve.begin(), nodes_to_preserve.end());
  TF_RETURN_IF_ERROR(context->frames.InferFromGraph(context->graph));
  return OkStatus();
}

// Sets data formats to convert from and to for the specified device type.
void TransposeContext::AssignDeviceAndDataFormats(
    absl::string_view target_device, absl::string_view src_format,
    absl::string_view dst_format) {
  this->target_device = string(target_device);
  this->src_format = string(src_format);
  this->dst_format = string(dst_format);
  this->src_dim_indices = GetDimensionIndices(src_format);
  this->dst_dim_indices = GetDimensionIndices(dst_format);
  this->src_to_dst = GetPermutation(this->src_dim_indices, dst_format);
  this->dst_to_src = GetPermutation(this->dst_dim_indices, src_format);
}

// Transposer.

bool Transposer::ShouldProcess(const TransposeContext& context,
                               const utils::MutableNodeView& node) const {
  const auto* node_def = node.node();
  const string& device_name = GetDeviceName(*node_def);
  string device;
  string task;
  const bool is_on_target_device =
      DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
      absl::StrContains(absl::AsciiStrToLower(device),
                        absl::AsciiStrToLower(context.target_device));

  // Only check data format for layout sensitive ops.
  const bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
                                 AttrDataFormatMatch(node, context.src_format);

  // Only transpose floating point nodes.
  const bool is_integer_conv2d = IsNonFloatingConv2D(node);

  return is_on_target_device && data_format_match && !is_integer_conv2d &&
         !context.nodes_to_preserve.contains(node_def->name()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

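// Creates a Const node named `node_name` holding the int32 `permutation`
// vector, optionally preceded by a control dependency on `control_node_name`.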
Status Transposer::CreateConstPermNode(TransposeContext* context,
                                       absl::string_view node_name,
                                       absl::string_view device,
                                       absl::Span<const int> permutation,
                                       absl::string_view control_node_name,
                                       utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  NodeDef node;
  node.set_name(string(node_name));
  node.set_op(kOpConst);
  node.set_device(string(device));

  if (!control_node_name.empty()) {
    node.add_input(string(control_node_name));
  }

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({(long long)permutation.size()}));
  for (int i = 0, end = permutation.size(); i < end; i++) {
    tensor.flat<int>()(i) = permutation[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  node.mutable_attr()->insert({"value", attr_tensor});

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

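// Creates a Transpose node, plus the Const node holding its permutation, with
// names derived from `name_format`. If the fanin shape is known, the permuted
// shape is recorded in the new node's output shape attribute.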
Status Transposer::CreateTransposeNode(
    TransposeContext* context, absl::string_view name_format,
    const DataType& data_type, absl::string_view device,
    TensorShapeProto fanin_shape, absl::Span<const int> permutation,
    absl::string_view control_node_name, utils::MutationNewNode* added_node,
    string* transpose_node_name) {
  const string node_name = absl::Substitute(name_format, kOpTranspose);
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));
  *transpose_node_name = node_name;

  NodeDef node;
  node.set_name(node_name);
  node.set_op(kOpTranspose);
  node.set_device(string(device));

  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  AttrValue attr_data_type_perm;
  attr_data_type_perm.set_type(DT_INT32);
  node.mutable_attr()->insert({"Tperm", attr_data_type_perm});

  if (!fanin_shape.unknown_rank()) {
    TF_RETURN_IF_ERROR(
        PermuteSingle(absl::StrCat("fanin shape in ", node.name()), permutation,
                      fanin_shape.mutable_dim()));
    AttrValue attr_output_shape;
    *attr_output_shape.mutable_list()->add_shape() = fanin_shape;
    node.mutable_attr()->insert({kAttrOutputShape, attr_output_shape});
  }

  // Create the Const node holding the permutation.
  utils::MutationNewNode const_perm_added_node;
  const string const_perm_node_name =
      absl::Substitute(name_format, "PermConst");
  TF_RETURN_IF_ERROR(CreateConstPermNode(context, const_perm_node_name, device,
                                         permutation, control_node_name,
                                         &const_perm_added_node));
  // Add placeholder for the 1st input.
  node.add_input("");
  // Connect const_perm_node to the 2nd input of transpose_node.
  node.add_input(const_perm_node_name);

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

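// Inserts an `op` node (e.g. Transpose) on each of the given fanin ports of
// `dst_node`, converting those inputs from the source to the destination
// format.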
Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context,
                                          absl::Span<const int> dst_ports,
                                          utils::MutableNodeView* dst_node,
                                          absl::string_view op) {
  const bool is_in_frame = context->frames.IsInFrame(*dst_node->node());
  for (int dst_port : dst_ports) {
    auto& fanin_port = dst_node->GetRegularFanin(dst_port);
    auto* fanin_node_view = fanin_port.node_view();

    TF_RETURN_IF_ERROR(
        UpdateEdge(context,
                   GetFaninNameFormat(dst_node->GetName(), dst_port,
                                      context->src_format, context->dst_format),
                   op, /*input_shape=*/nullptr, /*is_in_frame=*/is_in_frame,
                   /*is_src_format_to_dst_format=*/true, fanin_port.index(),
                   dst_port, fanin_node_view, dst_node));
  }
  return OkStatus();
}

Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context,
                                           absl::Span<const int> src_ports,
                                           utils::MutableNodeView* src_node,
                                           absl::string_view op) {
  // Update attr _output_shapes for output ports.
  const auto* output_shape_attr = src_node->GetAttr(kAttrOutputShape);
  AttrValue shape_attr_copy;
  if (op == kOpTranspose && output_shape_attr != nullptr) {
    shape_attr_copy = *output_shape_attr;
    for (int port : src_ports) {
      auto* shape = shape_attr_copy.mutable_list()->mutable_shape(port);
      if (shape->unknown_rank()) continue;
      TF_RETURN_IF_ERROR(
          PermuteSingle(absl::StrCat("output shape attribute at port ", port,
                                     " in ", src_node->GetName()),
                        context->src_to_dst, shape->mutable_dim()));
    }
    context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
        src_node, kAttrOutputShape, shape_attr_copy);
  }

  const bool is_in_frame = context->frames.IsInFrame(*src_node->node());
  // The loop below may modify the fanout set, so make a copy first. The copy
  // is sorted by node name and index so that the names of the generated
  // transposers stay deterministic.
  for (int src_port : src_ports) {
    const auto& fanouts_src_port = src_node->GetRegularFanout(src_port);
    std::vector<utils::MutableFaninView> sorted_fanouts(
        fanouts_src_port.begin(), fanouts_src_port.end());
    std::sort(sorted_fanouts.begin(), sorted_fanouts.end(),
              ComparatorByNodeNameAndIndex());
    int num_downstream_transposers = 0;
    for (const auto& fanout : sorted_fanouts) {
      TF_RETURN_IF_ERROR(UpdateEdge(
          context,
          GetFanoutNameFormat(src_node->GetName(), src_port,
                              num_downstream_transposers++, context->src_format,
                              context->dst_format),
          op, &shape_attr_copy, /*is_in_frame=*/is_in_frame,
          /*is_src_format_to_dst_format=*/false, src_port, fanout.index(),
          src_node, fanout.node_view()));
    }
  }
  return OkStatus();
}

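// Creates a DataFormatVecPermute or DataFormatDimMap node that rewrites shape
// vectors or dimension indices between the source and destination formats.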
Status Transposer::CreateDataFormatNode(
    TransposeContext* context, absl::string_view node_name,
    absl::string_view op, absl::string_view device, const DataType& data_type,
    bool is_fanin_on_host, bool is_src_format_to_dst_format,
    utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  // Create the node.
  NodeDef node;
  node.set_name(string(node_name));

  // Set up the parameters of the node.
  node.set_op(string(op));
  node.set_device(string(device));
  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  // The inputs of a DataFormat op could be in host memory for ops such as
  // Reshape. In such cases, run the kernel on the host too.
  if (is_fanin_on_host) {
    AttrValue attr_kernel;
    attr_kernel.set_s("host");
    node.mutable_attr()->insert({"_kernel", attr_kernel});
  }

  AttrValue src_format;
  src_format.set_s(is_src_format_to_dst_format ? context->src_format
                                               : context->dst_format);
  node.mutable_attr()->insert({kAttrSrcFormat, src_format});
  AttrValue dst_format;
  dst_format.set_s(is_src_format_to_dst_format ? context->dst_format
                                               : context->src_format);
  node.mutable_attr()->insert({kAttrDstFormat, dst_format});

  // Add placeholder for the 1st input field.
  node.add_input("");

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

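// Rewires the edge src_node:src_port -> dst_node:dst_port to pass through a
// newly created `op` node (Transpose, DataFormatVecPermute, or
// DataFormatDimMap).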
Status Transposer::UpdateEdge(
    TransposeContext* context, absl::string_view name_format,
    absl::string_view op, const AttrValue* input_shape, bool is_in_frame,
    bool is_src_format_to_dst_format, const int src_port, const int dst_port,
    utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node) {
  DCHECK(src_node != nullptr);
  DCHECK(dst_node != nullptr);
  auto* src_node_def = src_node->node();
  auto* dst_node_def = dst_node->node();

  // TODO(lyandy): Minimize device parsing/fetching.
  const string device = GetDeviceName(
      is_src_format_to_dst_format ? *dst_node_def : *src_node_def);
  DataType data_type =
      is_src_format_to_dst_format
          ? context->graph_properties
                ->GetInputProperties(dst_node->GetName())[dst_port]
                .dtype()
          : context->graph_properties
                ->GetOutputProperties(src_node->GetName())[src_port]
                .dtype();

  utils::MutationNewNode added_node;
  string added_node_name;
  if (op == kOpTranspose) {
    TensorShapeProto input_shape_proto;
    input_shape_proto.set_unknown_rank(true);
    if (input_shape != nullptr) {
      input_shape_proto = input_shape->list().shape(src_port);
    } else {
      const auto* src_node_shape_attr = src_node->GetAttr(kAttrOutputShape);
      if (src_node_shape_attr != nullptr) {
        input_shape_proto = src_node_shape_attr->list().shape(src_port);
      }
    }
    const string control_node_name =
        is_in_frame ? AsControlDependency(src_node_def->name()) : "";
    const std::vector<int>& permutation =
        is_src_format_to_dst_format ? context->src_to_dst : context->dst_to_src;
    TF_RETURN_IF_ERROR(CreateTransposeNode(
        context, name_format, data_type, device, input_shape_proto, permutation,
        control_node_name, &added_node, &added_node_name));
  } else if (op == kOpDataFormatVecPermute || op == kOpDataFormatDimMap) {
    DeviceNameUtils::ParsedName parsed_name;
    bool is_fanin_on_host = DeviceNameUtils::ParseFullName(
                                GetDeviceName(*src_node_def), &parsed_name) &&
                            parsed_name.type != "CPU" &&
                            IsHostMemory(*src_node_def, src_port);
    const string node_name = absl::Substitute(name_format, op);
    TF_RETURN_IF_ERROR(CreateDataFormatNode(
        context, node_name, op, device, data_type, is_fanin_on_host,
        is_src_format_to_dst_format, &added_node));
    added_node_name = node_name;
  } else {
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("Unsupported op \"", op,
                               "\". Supported ops are Transpose, "
                               "DataFormatVecPermute, DataFormatDimMap."));
  }

  // Connect src_node to the 1st input of added_node.
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  mutation->AddOrUpdateRegularFanin(added_node, 0,
                                    {src_node->GetName(), src_port});

  // Connect the output of added_node to dst_node:dst_port.
  mutation->AddOrUpdateRegularFanin(dst_node, dst_port, {added_node_name, 0});

  return OkStatus();
}

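// Returns the rank of the tensor at the given output port, kUnknownRank if
// its shape is unknown, or kInvalidRank if no shape is recorded for the port.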
int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
                                  int port) const {
  const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr ||
      output_shape_attr->list().shape_size() <= port) {
    return kInvalidRank;
  }
  const auto& shape = output_shape_attr->list().shape(port);
  if (shape.unknown_rank()) {
    return kUnknownRank;
  }
  return shape.dim_size();
}

bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
                                   int n) const {
  return GetFanoutPortRank(node, port) == n;
}

bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
                                    absl::Span<const int> ports, int n) const {
  for (const auto& port : ports) {
    if (!IsFanoutPortRankN(node, port, n)) {
      return false;
    }
  }
  return true;
}

int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
                                 int port) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
  }
  return kInvalidRank;
}

bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
                                  int n) const {
  return GetFaninPortRank(node, port) == n;
}

bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
                                         int port,
                                         absl::Span<const int> dims) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    const auto* fanin_node_view = regular_fanin.node_view();
    if (!IsConstant(*fanin_node_view->node())) {
      return true;
    }
    // If fanin is a Const, check tensor to see if dimensions match.
    const auto* value_attr = fanin_node_view->GetAttr(kAttrValue);
    if (value_attr == nullptr) {
      return false;
    }
    Tensor tensor;
    if (!tensor.FromProto(value_attr->tensor())) {
      return false;
    }
    const int dims_size = dims.size();
    if (tensor.dims() != dims_size) {
      return false;
    }
    for (int i = 0; i < dims_size; ++i) {
      if (tensor.dim_size(i) != dims[i]) {
        return false;
      }
    }
    return true;
  }
  return false;
}

bool Transposer::IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
                                          absl::Span<const int> ports,
                                          absl::Span<const int> dims) const {
  for (const auto& port : ports) {
    if (!IsFaninPortDimsNIfConst(node, port, dims)) {
      return false;
    }
  }
  return true;
}

bool Transposer::CanProcessNode(const TransposeContext& context,
                                const utils::MutableNodeView& node) const {
  return !context.nodes_to_preserve.contains(node.GetName()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

string Transposer::GetFaninNameFormat(absl::string_view node_name, int port,
                                      absl::string_view src_format,
                                      absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-$0", src_format, "To", dst_format,
                      "-", kOptimizedSuffix);
}

string Transposer::GetFanoutNameFormat(absl::string_view node_name, int port,
                                       int index, absl::string_view src_format,
                                       absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-", index, "-$0", dst_format, "To",
                      src_format, "-", kOptimizedSuffix);
}

string Transposer::LayoutOptimizerNode(absl::string_view node_name) {
  return absl::StrCat(node_name, "-", kOptimizedSuffix);
}

string Transposer::GetReshapeNodeNameFormat(absl::string_view node_name,
                                            int index,
                                            absl::string_view src_format,
                                            absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", index, "-", kReshape, src_format, "To",
                      dst_format);
}

string Transposer::GetShapeConstNodeNameFormat(absl::string_view node_name,
                                               int index) {
  return absl::StrCat(node_name, "-", index, "-", kReshapeConst);
}

// Layout sensitive transposer.

inline string GetLayoutSensitiveNodeDataFormat(
    const utils::MutableNodeView& node) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s();
  }
  return "";
}

Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
                                               utils::MutableNodeView* node) {
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  AttrValue data_format_attr;
  data_format_attr.set_s(context->dst_format);
  mutation->AddOrUpdateNodeAttr(node, kAttrDataFormat, data_format_attr);

  auto permute_attr = [&context, &node,
                       &mutation](absl::string_view attr_name) {
    const auto* attr = node->GetAttr(attr_name);
    if (attr != nullptr) {
      AttrValue attr_copy(*attr);
      TF_RETURN_IF_ERROR(PermuteSingle(
          absl::StrCat(attr_name, " attribute in ", node->GetName()),
          context->src_to_dst, attr_copy.mutable_list()->mutable_i()));
      mutation->AddOrUpdateNodeAttr(node, attr_name, attr_copy);
    }
    return OkStatus();
  };

  // Update attrs.
  TF_RETURN_IF_ERROR(permute_attr(kAttrStrides));
  TF_RETURN_IF_ERROR(permute_attr(kAttrKSize));
  TF_RETURN_IF_ERROR(permute_attr(kAttrDilations));

  const auto* explicit_paddings_attr = node->GetAttr(kAttrExplicitPaddings);
  if (explicit_paddings_attr != nullptr && explicit_paddings_attr->has_list() &&
      explicit_paddings_attr->list().i_size() > 0) {
    AttrValue explicit_paddings_attr_copy(*explicit_paddings_attr);
    TF_RETURN_IF_ERROR(PermuteDouble(
        absl::StrCat("explicit_paddings attribute in ", node->GetName()),
        context->src_to_dst,
        explicit_paddings_attr_copy.mutable_list()->mutable_i()));
    mutation->AddOrUpdateNodeAttr(node, kAttrExplicitPaddings,
                                  explicit_paddings_attr_copy);
  }

  return OkStatus();
}

Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsAvgPoolGrad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 1, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  // This TransposeNode handles BiasAdd but not BiasAddV1, since only BiasAdd
  // supports the data_format attribute.
  DCHECK(IsBiasAddV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAdd itself only needs NCHW/NHWC to determine whether the C dim is the
  // second or the last dim. Therefore, we use the original 4D data format in
  // the context to update the node. For the input/output tensor, the
  // corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsBiasAddGrad(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAddGrad itself only needs NCHW/NHWC to determine whether the C dim is
  // the second or the last dim. Therefore, we use the original 4D data format
  // in the context to update the node. For the input tensor, the
  // corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  // No need to update the output shape, as it is always a 1-D tensor whose
  // size is the feature dimension of `out_backprop`, regardless of whether
  // NCHW or NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropFilter(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropFilter(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update output shape, as it is always of shape
  // [filter_height, filter_width, in_channels, out_channels], regardless of
  // whether NCHW or NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropInput(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropInput(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }

  const auto& fanin = node->GetRegularFanin(0);
  auto* fanin_node = fanin.node_view();
  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr) {
    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
            << " because it is missing attribute " << kAttrOutputShape;
    return OkStatus();
  }
  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
  if (fanin_shape.dim_size() != 1) {
    VLOG(3) << fanin_node->GetName() << " is not a vector.";
    return OkStatus();
  }

  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsConv3D(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropFilterV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update output shape, as it is always of shape
  // [filter_depth, filter_height, filter_width, in_channels, out_channels],
  // regardless of whether NCDHW or NDHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropInputV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status FusedBatchNormExTransposer::TransposeNode(TransposeContext* context,
                                                 utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormEx(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  if (node->NumRegularFanins() == 6) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0, 5}, node, kOpTranspose));
  } else {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool FusedBatchNormGradTransposer::IsTraining(
    const utils::MutableNodeView& node) const {
  const auto* is_training_attr = node.GetAttr(kAttrIsTraining);
  if (is_training_attr != nullptr) {
    return is_training_attr->b();
  }
  return false;
}

Status FusedBatchNormGradTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolV2(*node->node()));
  // Check the data input's shape instead, because the shape inference of
  // MaxPoolV2 is not able to infer the shape when ksize or strides is not
  // constant.
  const auto& data_fanin = node->GetRegularFanin(0);
  auto* data_fanin_node = data_fanin.node_view();
  if (!ShouldProcess(*context, *node) ||
      !IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context,
                                              utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGradV2(*node->node()) || IsMaxPoolGradGradV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {3, 4}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Layout agnostic transposer.

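// Returns true if `node` is a Transpose whose permutation comes from a Const
// input that matches `permutation`.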
inline bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
                                          absl::Span<const int> permutation) {
  Tensor tensor;
  if (!GetValueAttrFromConstInputNode(node, IsTranspose, 1, &tensor)) {
    return false;
  }
  const int permutation_size = permutation.size();
  if (tensor.NumElements() != permutation_size) {
    return false;
  }

  const auto& tensor_data = tensor.unaligned_flat<int32>();
  for (int i = 0; i < permutation_size; i++) {
    if (permutation[i] != tensor_data(i)) {
      return false;
    }
  }
  return true;
}

inline bool IsValidDataFormatNode(const utils::MutableNodeView& node,
                                  absl::string_view src_format,
                                  absl::string_view dst_format) {
  if (!IsDataFormatOp(node)) {
    return false;
  }
  const auto* src_format_attr = node.GetAttr(kAttrSrcFormat);
  if (src_format_attr == nullptr || src_format_attr->s() != src_format) {
    return false;
  }
  const auto* dst_format_attr = node.GetAttr(kAttrDstFormat);
  if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) {
    return false;
  }
  return true;
}

inline bool IsLayoutOptimizerAddedDstToSrcTranspose(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         IsValidConstPermTransposeNode(node, context.dst_to_src);
}

inline bool IsLayoutOptimizerAddedDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         (IsValidConstPermTransposeNode(node, context.dst_to_src) ||
          IsValidDataFormatNode(node, context.dst_format, context.src_format));
}

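// Returns true if `node` consumes, possibly through a chain of layout
// agnostic ops, the output of a dst-to-src Transpose or DataFormat node that
// was added by this optimizer.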
bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  std::deque<utils::MutableNodeView*> queue;
  absl::flat_hash_set<utils::MutableNodeView*> visited_nodes;
  auto data_node_pos = GetDataFaninPorts(node);
  for (const int pos : data_node_pos) {
    const auto& fanin = node.GetRegularFanin(pos);
    auto* fanin_node = fanin.node_view();
    queue.push_back(fanin_node);
    visited_nodes.insert(fanin_node);
  }
  // The code will exit this while loop in one iteration in most cases, as the
  // graph is already topologically sorted.
  while (!queue.empty()) {
    utils::MutableNodeView* current_node = queue.front();
    queue.pop_front();
    if (IsLayoutOptimizerAddedDstToSrcTransform(context, *current_node)) {
      return true;
    }
    // We only continue searching if the path is connected through
    // format-agnostic nodes.
    if (IsLayoutAgnosticOp(*current_node->node())) {
      auto current_node_pos = GetDataFaninPorts(*current_node);
      for (const auto& pos : current_node_pos) {
        const auto& fanin = current_node->GetRegularFanin(pos);
        auto* fanin_node = fanin.node_view();
        if (visited_nodes.insert(fanin_node).second) {
          queue.push_back(fanin_node);
        }
      }
    }
  }
  return false;
}

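// Returns the ports of fanins that have rank `rank` and are produced, either
// directly or through layout agnostic ops, by a dst-to-src transform added by
// this optimizer.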
std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
    const TransposeContext& context, const utils::MutableNodeView& node,
    int rank) const {
  std::vector<int> ports;
  const int num_regular_fanins = node.NumRegularFanins();
  ports.reserve(num_regular_fanins);
  for (int i = 0; i < num_regular_fanins; ++i) {
    const auto& regular_fanin = node.GetRegularFanin(i);
    auto* regular_fanin_node = regular_fanin.node_view();
    int regular_fanin_port = regular_fanin.index();
    if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      ports.push_back(i);
    }
  }
  return ports;
}

Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AddNTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsAddN(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

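// Rank-shape helpers for binary ops. A binary op can be transposed when one
// operand has the full rank and the other operand has the same rank, is a
// scalar (rank 0), or is a vector (rank 1) that broadcasts against it.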
bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
                                           int n, int m) {
  return IsFaninPortRankN(node, 0, n) && IsFaninPortRankN(node, 1, m);
}

bool BinaryOpTransposer::IsFaninShapeSupported(
    const utils::MutableNodeView& node, int rank) {
  return (IsNDOperateWithMD(node, rank, 0) ||
          IsNDOperateWithMD(node, rank, 1) ||
          IsNDOperateWithMD(node, rank, rank) ||
          IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank));
}

std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
    const utils::MutableNodeView& node, int rank) {
  std::vector<int> values;
  if (IsFaninPortRankN(node, 0, rank)) {
    values.push_back(0);
  }
  if (IsFaninPortRankN(node, 1, rank)) {
    values.push_back(1);
  }
  return values;
}

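// Adds a Reshape node named `node_name` to the mutation. It reshapes
// `input_name` using the shape held by the Const node
// `shape_const_node_name`.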
Status BinaryOpTransposer::AddNodeReshape(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, absl::string_view input_name,
    absl::string_view shape_const_node_name, const DataType& data_type) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.add_input(string(input_name));
  new_node.add_input(string(shape_const_node_name));
  new_node.set_op(kReshape);
  new_node.set_device(string(node_device));

  AttrValue attr_type_indices;
  attr_type_indices.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"Tshape", attr_type_indices});

  AttrValue attr_type_params;
  attr_type_params.set_type(data_type);
  new_node.mutable_attr()->insert({"T", attr_type_params});

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}

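// Adds an int32 Const node holding the broadcast shape for a vector operand:
// all ones except `num_channels` at index 1. Note that placing the channel
// count at index 1 assumes a channels-first (NCHW/NCDHW) destination format.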
Status BinaryOpTransposer::AddNodeShapeConst(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, bool node_in_frame, int num_channels,
    absl::string_view depended_node, int rank) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.set_op(kOpConst);
  new_node.set_device(string(node_device));
  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({rank}));
  std::vector<int> shape(rank, 1);
  shape[1] = num_channels;
  for (int i = 0; i < static_cast<int>(shape.size()); i++) {
    tensor.flat<int>()(i) = shape[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  new_node.mutable_attr()->insert({"value", attr_tensor});
  if (node_in_frame) {
    // This ensures that the transpose node and the const node are in the same
    // frame.
    // TODO(halehri): Add test that exercises this condition.
    new_node.add_input(AsControlDependency(string(depended_node)));
  }

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}

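// If exactly one operand is a vector, reshapes it so that broadcasting still
// works after the data operand has been transposed. For example, with src
// NHWC and dst NCHW, adding a [C] vector to a 4-D tensor becomes an add of a
// [1, C, 1, 1] tensor produced by the Reshape inserted here.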
Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
                                                   utils::MutableNodeView* node,
                                                   int rank) {
  int vector_index = -1;
  if (IsNDOperateWithMD(*node, rank, 1)) {
    vector_index = 1;
  } else if (IsNDOperateWithMD(*node, 1, rank)) {
    vector_index = 0;
  }
  if (vector_index != -1) {
    const string& node_name = node->GetName();
    const string& node_device = node->GetDevice();
    string reshape_node_name = LayoutOptimizerNode(GetReshapeNodeNameFormat(
        node_name, vector_index, context->src_format, context->dst_format));
    string shape_const_node_name = LayoutOptimizerNode(
        GetShapeConstNodeNameFormat(node_name, vector_index));
    const auto& fanin = node->GetRegularFanin(vector_index);
    auto* fanin_node = fanin.node_view();
    const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
    if (output_shape_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrOutputShape);
    }
    int vector_size =
        output_shape_attr->list().shape(fanin.index()).dim(0).size();
    utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
    TF_RETURN_IF_ERROR(
        AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                          context->frames.IsInFrame(*node->node()), vector_size,
                          fanin_node->GetName(), rank));
    const auto* t_attr = node->GetAttr(kAttrT);
    if (t_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrT);
    }
    TF_RETURN_IF_ERROR(
        AddNodeReshape(mutation, reshape_node_name, node_device,
                       TensorIdToString({fanin_node->GetName(), fanin.index()}),
                       shape_const_node_name, t_attr->type()));
    mutation->AddOrUpdateRegularFanin(node, vector_index,
                                      {reshape_node_name, 0});
  }
  return OkStatus();
}

Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsBinaryOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose));
  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

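// For Concat, the axis operand is remapped with DataFormatDimMap. Concat
// carries the axis at input 0, while ConcatV2 carries it after the N data
// inputs, i.e. at input N.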
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsConcat(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetConcatDataFaninPorts(*node), node, kOpTranspose));
  int axis_node = 0;
  if (node->GetOp() == "ConcatV2") {
    const auto* n_attr = node->GetAttr(kAttrN);
    if (n_attr != nullptr) {
      axis_node = n_attr->i();
    }
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

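// Fill takes a shape vector at input 0, so that input is permuted with
// DataFormatVecPermute instead of being transposed.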
Status FillOpTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsFill(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 0, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status IdentityNTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsIdentityN(*node->node()));
  const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);

  // Temporarily upgrade the context to obtain the number of 5D fanin ports.
  std::vector<int> ports_5d;
  {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
  }

  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }

  if (!ports_4d.empty()) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
  }

  if (!ports_5d.empty()) {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

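// Merge can only be transposed when every data fanin has rank 4 and already
// carries dst-format data, since all of its inputs must agree on layout.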
bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  for (const auto& regular_fanin : node.GetRegularFanins()) {
    auto* regular_fanin_node = regular_fanin.node_view();
    if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      continue;
    }
    return false;
  }
  return true;
}

Status MergeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsMerge(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsEveryFaninAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

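// Pad-like ops carry a [4, 2] paddings tensor at input 1; its rows are
// permuted with DataFormatVecPermute to match the new layout.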
Status PadTransposer::TransposeNode(TransposeContext* context,
                                    utils::MutableNodeView* node) {
  DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) ||
         IsPad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4, 2}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool ReduceTransposer::KeepDims(const utils::MutableNodeView& node) {
  const auto* keep_dims_attr = node.GetAttr(kAttrKeepDims);
  if (keep_dims_attr != nullptr) {
    return keep_dims_attr->b();
  }
  return false;
}

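// Returns true if the 1-D axis tensor has the same length as `axis` and every
// entry, after normalizing negative indices by adding `rank`, is one of the
// dimensions in `axis`.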
bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
                                   absl::Span<const int> axis, int rank) {
  const int axis_size = axis.size();
  if (tensor.dims() != 1 || tensor.dim_size(0) != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = 0;
    if (tensor.dtype() == DT_INT32) {
      local_axis = tensor.flat<int32>()(i);
    } else {
      local_axis = tensor.flat<int64_t>()(i);
    }
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

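// When keep_dims is false the reduced output is not transposed back, so only
// reductions over these specific axis combinations are supported; for them
// the output is identical regardless of the input layout once the axis
// operand has been remapped.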
bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context,
                                             const utils::MutableNodeView& node,
                                             int rank) {
  if (KeepDims(node)) {
    return true;
  }
  const auto& regular_fanin_1 = node.GetRegularFanin(1);
  auto* axis_node = regular_fanin_1.node_view();
  if (!IsConstant(*axis_node->node())) {
    return false;
  }
  const auto* value_attr = axis_node->GetAttr(kAttrValue);
  if (value_attr == nullptr) {
    return false;
  }
  Tensor tensor;
  if (!tensor.FromProto(value_attr->tensor())) {
    LOG(ERROR) << "Failed to parse TensorProto.";
    return false;
  }
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  if (rank == 5) {
    return IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'C'}), 5);
  }
  DCHECK_EQ(rank, 4);
  return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'N', 'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'C'}), 4);
}

Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsReduceOp(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsReduceAxisSupported(*context, *node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  if (KeepDims(*node)) {
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

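// ReverseV2 takes the axes to reverse at input 1; they are remapped with
// DataFormatDimMap just like reduction axes.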
Status ReverseV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsReverseV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool SelectTransposer::IsFaninScalarVector4D(
    const utils::MutableNodeView& fanin, int port) {
  return IsFanoutPortRankN(fanin, port, 0) ||
         IsFanoutPortRankN(fanin, port, 1) || IsFanoutPortRankN(fanin, port, 4);
}

std::vector<int> SelectTransposer::GetFaninPorts(
    const utils::MutableNodeView& fanin, int port) {
  // Input 0 can be a scalar, a vector whose size matches the first dimension
  // of inputs 1 and 2, or a tensor with the same shape as inputs 1 and 2.
  // Only in the last case does it need to be transposed along with them.
  if (IsFanoutPortRankN(fanin, port, 4)) {
    return {0, 1, 2};
  }
  return {1, 2};
}

Status SelectTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSelect(*node->node()));
  const auto& regular_fanin_0 = node->GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index()) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()),
      node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsShape(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeNTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsShapeN(*node->node()));
  // ShapeN requires all input tensors to have the same dimensions. Therefore,
  // we simply use the 0th fanin port.
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
  if (!ShouldProcess(*context, *node) || ports.empty()) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SliceTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSlice(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSplit(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitVTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSplitV(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {2}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

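// The squeeze input is convertible only if its rank-4 shape has size 1 in
// both the H and W dimensions, so squeezing the transposed input removes the
// same dimensions.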
bool SqueezeTransposer::IsInputConvertible(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  const auto& regular_fanin_0 = node.GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  const auto* output_shape_attr =
      regular_fanin_0_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr != nullptr) {
    auto& shape = output_shape_attr->list().shape(regular_fanin_0.index());
    if (shape.dim_size() != kRank) {
      return false;
    }
    const int height_dim = context.src_dim_indices.at('H');
    const int width_dim = context.src_dim_indices.at('W');
    if (shape.dim(height_dim).size() == 1 && shape.dim(width_dim).size() == 1) {
      return true;
    }
  }
  return false;
}

bool SqueezeTransposer::IsAlongAxis(const AttrValue& attr,
                                    absl::Span<const int> axis,
                                    int rank) const {
  const auto& list = attr.list();
  // If the list is empty, the Squeeze op squeezes all dimensions of size 1.
  int axis_size = axis.size();
  if (list.i_size() == 0) {
    return true;
  } else if (list.i_size() != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = list.i(i);
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

bool SqueezeTransposer::IsDimsSupported(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  const auto* squeeze_dims_attr = node.GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return false;
  }
  return (IsFanoutPortRankN(node, 0, 2) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'H', 'W'}), kRank)) ||
         (IsFanoutPortRankN(node, 0, 1) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'N', 'H', 'W'}), kRank));
}

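// Remaps the squeeze dimensions from the source to the destination format.
// For example, with src NHWC and dst NCHW, squeeze_dims [1, 2] (H and W in
// NHWC) become [2, 3], the positions of H and W in NCHW.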
Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  const auto* squeeze_dims_attr = node->GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return errors::InvalidArgument("Missing attribute ", kAttrSqueezeDims);
  }
  const int num_input_dims = context->src_format.length();
  const int min_squeeze_dim = -num_input_dims;
  std::vector<int> squeeze_dims_mapped;
  const int squeeze_dims_size = squeeze_dims_attr->list().i_size();
  squeeze_dims_mapped.reserve(squeeze_dims_size);
  for (int i = 0; i < squeeze_dims_size; ++i) {
    int dim = squeeze_dims_attr->list().i(i);
    if (dim < min_squeeze_dim || dim >= num_input_dims) {
      return errors::InvalidArgument(
          "Attribute '", kAttrSqueezeDims, "' contains out of range index '",
          dim, "', index must be between [", min_squeeze_dim, ", ",
          num_input_dims, ")");
    }
    if (dim < 0) {
      dim += num_input_dims;
    }
    squeeze_dims_mapped.push_back(context->dst_to_src[dim]);
  }
  std::sort(squeeze_dims_mapped.begin(), squeeze_dims_mapped.end());
  AttrValue squeeze_dims;
  squeeze_dims.mutable_list()->mutable_i()->Reserve(squeeze_dims_size);
  for (const auto& dim : squeeze_dims_mapped) {
    squeeze_dims.mutable_list()->mutable_i()->Add(dim);
  }
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
      node, kAttrSqueezeDims, squeeze_dims);
  return OkStatus();
}

Status SqueezeTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  DCHECK(IsSqueeze(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsDimsSupported(*context, *node) ||
      !IsInputConvertible(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool StridedSliceTransposer::IsMaskZero(const utils::MutableNodeView& node,
                                        absl::string_view mask) {
  const auto* mask_attr = node.GetAttr(mask);
  if (mask_attr != nullptr) {
    return mask_attr->i() == 0;
  }
  return true;
}

bool StridedSliceTransposer::HasOnlyBeginEndMask(
    const utils::MutableNodeView& node) {
  return IsMaskZero(node, "ellipsis_mask") &&
         IsMaskZero(node, "new_axis_mask") &&
         IsMaskZero(node, "shrink_axis_mask");
}

Status StridedSliceTransposer::PermuteMask(TransposeContext* context,
                                           utils::MutableNodeView* node,
                                           absl::string_view mask) {
  // Computes the permutation of the mask based on the src and dst formats.
  // Example:
  //   src_format = NHWC, dst_format = NCHW,
  //   src_to_dst permutation = [0, 3, 1, 2].
  //   mask   = 0010: bit 1 is set, i.e. dimension H in NHWC. Note that bit i
  //            corresponds to dimension i, so the binary digits above read,
  //            left to right, as dimensions C, W, H, N.
  //   result = 0100: bit 2 is set, the position of H in NCHW (digits read as
  //            W, H, C, N).
  const auto* mask_attr = node->GetAttr(mask);
  const int mask_i = mask_attr != nullptr ? mask_attr->i() : 0;
  if (mask_i < 0 || mask_i > 15) {
    return errors::InvalidArgument("invalid mask value: ", mask_i);
  }
  int result = 0;
  for (int i = 0, end = context->src_to_dst.size(); i < end; i++) {
    const int final_pos = context->src_to_dst[i];
    const int position_mask = 1 << final_pos;
    const int bit_i = (mask_i & position_mask) >> final_pos;
    result |= bit_i << i;
  }
  AttrValue new_mask_attr;
  new_mask_attr.set_i(result);
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(node, mask,
                                                                 new_mask_attr);
  return OkStatus();
}

Status StridedSliceTransposer::TransposeNode(TransposeContext* context,
                                             utils::MutableNodeView* node) {
  DCHECK(IsStridedSlice(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) ||
      !HasOnlyBeginEndMask(*node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask"));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask"));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node,
                                            kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SwitchTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSwitch(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, GetDataFanoutPorts(*node),
                                             node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TernaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsTernaryOp(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TileTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsTile(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsUnaryGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Utils.

string GetDeviceName(const NodeDef& node) { return node.device(); }

bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* default_layout_sensitive_ops =
      new absl::flat_hash_set<std::string>(
          {"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
           "FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
           "FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
  return default_layout_sensitive_ops->find(node.op()) !=
         default_layout_sensitive_ops->end();
}

bool IsLayoutSensitiveOp(const NodeDef& node) {
  return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
         IsBiasAddV2(node) || IsBiasAddGrad(node) ||
         IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
         IsDepthwiseConv2dNativeBackpropFilter(node) ||
         IsDepthwiseConv2dNativeBackpropInput(node) ||
         IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
         IsMaxPoolV2(node) || IsMaxPoolGrad(node) || IsMaxPoolGradV2(node) ||
         IsMaxPoolGradGradV1(node) || IsMaxPoolGradGradV2(node) ||
         IsConv3D(node) || IsConv3DBackpropInputV2(node) ||
         IsConv3DBackpropFilterV2(node);
}

bool IsDefaultLayoutAgnosticOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* agnostic_nodes =
      new absl::flat_hash_set<std::string>({"Abs",
                                            "Acos",
                                            "Acosh",
                                            "Angle",
                                            "Asin",
                                            "Asinh",
                                            "Atan",
                                            "Atanh",
                                            "Bitcast",
                                            "Cast",
                                            "Ceil",
                                            "CheckNumerics",
                                            "ComplexAbs",
                                            "Conj",
                                            "Cos",
                                            "Cosh",
                                            "Digamma",
                                            "Elu",
                                            "Enter",
                                            "Erf",
                                            "Erfc",
                                            "Exit",
                                            "Exp",
                                            "Expm1",
                                            "FakeQuantWithMinMaxVars",
                                            "FakeQuantWithMinMaxArgs",
                                            "Floor",
                                            "GuaranteeConst",
                                            "Identity",
                                            "Imag",
                                            "Inv",
                                            "IsFinite",
                                            "IsInf",
                                            "IsNan",
                                            "LeakyRelu",
                                            "Lgamma",
                                            "Log",
                                            "LogicalNot",
                                            "Log1p",
                                            "Neg",
                                            "NextIteration",
                                            "OnesLike",
                                            "PreventGradient",
                                            "QuantizeAndDequantizeV2",
                                            "QuantizeAndDequantizeV3",
                                            "QuantizeAndDequantizeV4",
                                            "Real",
                                            "Reciprocal",
                                            "Relu",
                                            "Relu6",
                                            "Rint",
                                            "Selu",
                                            "Sigmoid",
                                            "Sign",
                                            "Sin",
                                            "Sinh",
                                            "Snapshot",
                                            "Softplus",
                                            "Round",
                                            "Rsqrt",
                                            "Sqrt",
                                            "Square",
                                            "StopGradient",
                                            "Tan",
                                            "Tanh",
                                            "ZerosLike"});
  return agnostic_nodes->find(node.op()) != agnostic_nodes->end();
}

bool IsLayoutAgnosticOp(const NodeDef& node) {
  return IsDefaultLayoutAgnosticOp(node) || IsAddN(node) || IsBinaryOp(node) ||
         IsIdentityN(node) || IsMerge(node) || IsMirrorPad(node) ||
         IsMirrorPadGrad(node) || IsPad(node) || IsSelect(node) ||
         IsSwitch(node) || IsTernaryOp(node) || IsUnaryGrad(node) ||
         IsConcat(node) || IsReverseV2(node) || IsTile(node) || IsShape(node) ||
         IsShapeN(node) || IsFill(node) || IsSlice(node) || IsSplit(node) ||
         IsSqueeze(node) || IsSplitV(node) || IsStridedSlice(node) ||
         IsReduceOp(node);
}

bool IsTernaryOp(const NodeDef& node) { return IsBetainc(node); }

bool IsUnaryGrad(const NodeDef& node) {
  bool is_unary_grad =
      IsEluGrad(node) || IsInvGrad(node) || IsLeakyReluGrad(node) ||
      IsReciprocalGrad(node) || IsRelu6Grad(node) || IsReluGrad(node) ||
      IsRsqrtGrad(node) || IsSeluGrad(node) || IsSigmoidGrad(node) ||
      IsSoftplusGrad(node) || IsSoftsignGrad(node) || IsSqrtGrad(node) ||
      IsTanhGrad(node);
  return is_unary_grad;
}

bool IsMaxPoolV2(const NodeDef& node) { return node.op() == "MaxPoolV2"; }

bool IsMaxPoolGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradV2";
}

bool IsMaxPoolGradGradV1(const NodeDef& node) {
  return node.op() == "MaxPoolGradGrad";
}

bool IsMaxPoolGradGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradGradV2";
}

bool IsBinaryOp(const NodeDef& node) {
  bool is_binary =
      IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
      IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
      IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
      IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
      IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
      IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
  return is_binary;
}

bool IsReduceOp(const NodeDef& node) {
  return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
         IsMin(node) || IsAll(node) || IsAny(node);
}

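// Returns the ports of `node`'s regular fanins that carry data tensors, as
// opposed to shapes, axes, or other scalar/vector parameters.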
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsAvgPoolGrad(*node_def) || IsSplit(*node_def)) {
    return {1};
  }
  if (IsStridedSliceGrad(*node_def)) {
    return {4};
  }
  if (IsBinaryOp(*node_def) || IsUnaryGrad(*node_def)) {
    return {0, 1};
  }
  if (IsTernaryOp(*node_def) || IsSelect(*node_def) ||
      IsMaxPoolGrad(*node_def) || IsMaxPoolGradV2(*node_def) ||
      IsMaxPoolGradGradV1(*node_def) || IsMaxPoolGradGradV2(*node_def)) {
    return {0, 1, 2};
  }
  if (IsShapeN(*node_def) || IsIdentityN(*node_def) || IsAddN(*node_def) ||
      IsMerge(*node_def)) {
    return GetRegularFaninPorts(node);
  }
  if (IsConcat(*node_def)) {
    return GetConcatDataFaninPorts(node);
  }
  if (node.NumRegularFanins() > 0) {
    return {0};
  }
  return {};
}

std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsIdentityN(*node_def) || IsShape(*node_def) || IsShapeN(*node_def)) {
    return GetDataFaninPorts(node);
  }
  if (IsSplit(*node_def) || IsSplitV(*node_def)) {
    const auto* num_split_attr = node.GetAttr(kAttrNumSplit);
    if (num_split_attr == nullptr) {
      return {0};
    }
    std::vector<int> values(num_split_attr->i());
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  if (IsSwitch(*node_def)) {
    const auto* num_outs_attr = node.GetAttr(kAttrNumOuts);
    const int num_outs = num_outs_attr != nullptr ? num_outs_attr->i() : 2;
    std::vector<int> values(num_outs);
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  return {0};
}

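// Reads the int32 tensor held by `node`'s Const fanin at `index` into
// `tensor`, provided `node` matches `predicate`. Returns false if the fanin
// is not a Const holding DT_INT32 data or the proto cannot be parsed.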
bool GetValueAttrFromConstInputNode(
    const utils::MutableNodeView& node,
    const std::function<bool(const NodeDef&)>& predicate, int index,
    Tensor* tensor) {
  if (!predicate(*node.node())) {
    return false;
  }
  const auto& regular_fanin = node.GetRegularFanin(index);
  auto* regular_fanin_node = regular_fanin.node_view();
  if (!IsConstant(*regular_fanin_node->node())) {
    return false;
  }
  const auto* value_attr = regular_fanin_node->GetAttr(kAttrValue);
  if (value_attr == nullptr || value_attr->tensor().dtype() != DT_INT32) {
    return false;
  }
  if (!tensor->FromProto(value_attr->tensor())) {
    return false;
  }

  return true;
}

bool IsDataFormatOp(const utils::MutableNodeView& node) {
  const string& op = node.GetOp();
  return op == kOpDataFormatDimMap || op == kOpDataFormatVecPermute;
}

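// Maps each dimension label of `data_format` to its index, e.g. "NCHW"
// yields {'N': 0, 'C': 1, 'H': 2, 'W': 3}.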
absl::flat_hash_map<char, int> GetDimensionIndices(
    absl::string_view data_format) {
  const int size = data_format.size();
  absl::flat_hash_map<char, int> index;
  index.reserve(size);
  for (int i = 0; i < size; i++) {
    index[data_format[i]] = i;
  }
  return index;
}

std::vector<int> GetPermutation(
    const absl::flat_hash_map<char, int>& src_dim_indices,
    absl::string_view dst_format) {
  // Generate permutation for transformation between src and dst format.
  // Example:
  // src = NWHC, dst = NCWH
  // index = { N:0 W:1 H:2 C:3 }
  // permutation = [0, 3, 1, 2]
  DCHECK(src_dim_indices.size() == dst_format.size());
  std::vector<int> permutation;
  const int size = dst_format.size();
  permutation.reserve(size);
  for (int i = 0; i < size; i++) {
    permutation.push_back(src_dim_indices.at(dst_format[i]));
  }
  return permutation;
}

}  // namespace grappler
}  // namespace tensorflow