/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h"

#include <algorithm>
#include <deque>
#include <numeric>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr char kOptimizedSuffix[] = "LayoutOptimizer";
constexpr char kAttrKSize[] = "ksize";
constexpr char kAttrStrides[] = "strides";
constexpr char kAttrDilations[] = "dilations";
constexpr char kAttrExplicitPaddings[] = "explicit_paddings";
constexpr char kAttrDataFormat[] = "data_format";
constexpr char kAttrIsTraining[] = "is_training";
constexpr char kAttrValue[] = "value";
constexpr char kAttrN[] = "N";
constexpr char kAttrT[] = "T";
constexpr char kAttrNumSplit[] = "num_split";
constexpr char kAttrNumOuts[] = "num_outs";
constexpr char kAttrKeepDims[] = "keep_dims";
constexpr char kAttrSqueezeDims[] = "squeeze_dims";
constexpr char kOpTranspose[] = "Transpose";
constexpr char kOpDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kOpDataFormatDimMap[] = "DataFormatDimMap";
constexpr char kOpConst[] = "Const";
constexpr char kReshape[] = "Reshape";
constexpr char kReshapeConst[] = "ReshapeConst";
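// Rank of the 4D layouts this optimizer primarily targets; kUnknownRank and
// kInvalidRank are sentinels returned by the rank-query helpers
// (GetFanoutPortRank etc.).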
constexpr int kRank = 4;
constexpr int kUnknownRank = -1;
constexpr int kInvalidRank = -2;

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format,
                                bool* missing) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s() == src_data_format;
  }
  *missing = true;
  return false;
}

inline bool AttrDataFormatMatch(const utils::MutableNodeView& node,
                                absl::string_view src_data_format) {
  bool missing = false;
  return AttrDataFormatMatch(node, src_data_format, &missing);
}

bool IsNonFloatingConv2D(const utils::MutableNodeView& node) {
  if (IsConv2D(*node.node()) || IsConv2DBackpropInput(*node.node())) {
    const auto* attr = node.GetAttr(kAttrT);
    if (attr != nullptr) {
      return !kDataTypeIsFloating.Contains(attr->type());
    }
  }
  return false;
}

// Utils for layout agnostic transposer.

bool IsComparisonOp(const NodeDef& node) {
  bool is_compare = IsApproximateEqual(node) || IsEqual(node) ||
                    IsGreater(node) || IsGreaterEqual(node) || IsLess(node) ||
                    IsLessEqual(node) || IsNotEqual(node);
  return is_compare;
}

std::vector<int> GetRegularFaninPorts(const utils::MutableNodeView& node) {
  const int num_regular_fanins = node.NumRegularFanins();
  std::vector<int> values(num_regular_fanins);
  std::iota(values.begin(), values.end(), 0);
  return values;
}

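// Returns the ports of a Concat/ConcatV2 node that carry data tensors. For
// "Concat" the axis is input 0, so data inputs are 1..N; for "ConcatV2" the
// axis is the last input, so data inputs are 0..N-1.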
std::vector<int> GetConcatDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* n_attr = node.GetAttr(kAttrN);
  const int n = n_attr != nullptr ? n_attr->i() : 0;
  const int start = (node.GetOp() == "Concat") ? 1 : 0;
  const int end = start + n;
  std::vector<int> values(end - start);
  std::iota(values.begin(), values.end(), start);
  return values;
}

struct ComparatorByNodeNameAndIndex {
  bool operator()(const utils::MutableFaninView& node1,
                  const utils::MutableFaninView& node2) const {
    auto* node1_view = node1.node_view();
    auto* node2_view = node2.node_view();
    auto name_compare = node1_view->GetName().compare(node2_view->GetName());
    if (name_compare == 0) {
      return node1.index() < node2.index();
    }
    return name_compare < 0;
  }
};

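// Returns true if output `output_port` of `node` lives in host memory, or if
// no kernel is registered for the node's device (in which case the node will
// run on the host). Returns false when the device name cannot be parsed.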
bool IsHostMemory(const NodeDef& node, int output_port) {
  DeviceNameUtils::ParsedName parsed_name;
  if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
    DeviceType device_type(parsed_name.type);
    Status s = FindKernelDef(device_type, node, nullptr, nullptr);
    if (s.ok()) {
      tensorflow::MemoryTypeVector in_mtypes;
      tensorflow::MemoryTypeVector out_mtypes;
      s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
                                         node, &in_mtypes, &out_mtypes);
      if (s.ok()) {
        if (out_mtypes[output_port] == HOST_MEMORY) {
          return true;
        }
      }
    } else {
      return true;
    }
  }
  return false;
}

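// Maps dimension labels to their indices in a data format string, e.g. with
// dim_indices built from "NCHW", labels {'H', 'W'} yield {2, 3}.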
std::vector<int> GetDimensionIndicesFromLabel(
    const absl::flat_hash_map<char, int>& dim_indices,
    absl::Span<const char> labels) {
  std::vector<int> indices;
  indices.reserve(labels.size());
  for (const auto& label : labels) {
    indices.push_back(dim_indices.at(label));
  }
  return indices;
}

// RAII-styled object for keeping track of 4D to 5D data format
// upgrade/conversion. Currently only NHWC -> NDHWC and NCHW -> NCDHW are
// supported.
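//
// Illustrative usage inside a rank-5 transposer (a sketch, not verbatim from
// any caller):
//
//   ScopedDataFormatUpgrader upgrader(context, /*rank=*/5);
//   // context->src_format/dst_format are now "NDHWC"/"NCDHW"; on scope exit
//   // the original 4D formats are restored.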
class ScopedDataFormatUpgrader {
 public:
  ScopedDataFormatUpgrader(TransposeContext* context, int rank)
      : context_(context) {
    if (rank == 5 && IsSupportedDataFormat(context_->src_format) &&
        IsSupportedDataFormat(context_->dst_format)) {
      old_src_format_ = context_->src_format;
      old_dst_format_ = context_->dst_format;
      std::string new_src_format = GetUpgradedDataFormat(context_->src_format);
      std::string new_dst_format = GetUpgradedDataFormat(context_->dst_format);
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           new_src_format, new_dst_format);
      upgraded_ = true;
    }
  }

  ScopedDataFormatUpgrader(const ScopedDataFormatUpgrader&) = delete;
  ScopedDataFormatUpgrader& operator=(const ScopedDataFormatUpgrader&) = delete;

  ~ScopedDataFormatUpgrader() {
    if (upgraded_) {
      context_->AssignDeviceAndDataFormats(context_->target_device,
                                           old_src_format_, old_dst_format_);
    }
  }

 private:
  bool IsSupportedDataFormat(absl::string_view data_format) {
    return data_format == "NHWC" || data_format == "NCHW";
  }

  std::string GetUpgradedDataFormat(absl::string_view data_format) {
    if (data_format == "NHWC") {
      return "NDHWC";
    }

    DCHECK_EQ(data_format, "NCHW");
    return "NCDHW";
  }

  TransposeContext* context_ = nullptr;
  bool upgraded_ = false;
  std::string old_src_format_;
  std::string old_dst_format_;
};

}  // namespace

// TransposeContext.

Status TransposeContext::InitializeTransposeContext(bool assume_valid_feeds,
                                                    const GrapplerItem& item,
                                                    const Cluster* cluster,
                                                    TransposeContext* context) {
  DCHECK(context != nullptr);
  context->graph_properties = std::make_unique<GraphProperties>(item);
  TF_RETURN_IF_ERROR(
      context->graph_properties->InferStatically(assume_valid_feeds));
  TF_RETURN_IF_ERROR(
      context->graph_properties->AnnotateOutputShapes(&context->graph));
  Status status;
  context->graph_view =
      std::make_unique<utils::MutableGraphView>(&context->graph, &status);
  TF_RETURN_IF_ERROR(status);
  context->num_nodes = context->graph.node_size();
  const auto& nodes_to_preserve = item.NodesToPreserve();
  context->nodes_to_preserve = absl::flat_hash_set<string>(
      nodes_to_preserve.begin(), nodes_to_preserve.end());
  TF_RETURN_IF_ERROR(context->frames.InferFromGraph(context->graph));
  return OkStatus();
}

// Sets the data formats to convert from and to for the specified device type.
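// For example, with src_format "NHWC" and dst_format "NCHW", src_to_dst is
// {0, 3, 1, 2} and dst_to_src is {0, 2, 3, 1}.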
void TransposeContext::AssignDeviceAndDataFormats(
    absl::string_view target_device, absl::string_view src_format,
    absl::string_view dst_format) {
  this->target_device = string(target_device);
  this->src_format = string(src_format);
  this->dst_format = string(dst_format);
  this->src_dim_indices = GetDimensionIndices(src_format);
  this->dst_dim_indices = GetDimensionIndices(dst_format);
  this->src_to_dst = GetPermutation(this->src_dim_indices, dst_format);
  this->dst_to_src = GetPermutation(this->dst_dim_indices, src_format);
}

// Transposer.

bool Transposer::ShouldProcess(const TransposeContext& context,
                               const utils::MutableNodeView& node) const {
  const auto* node_def = node.node();
  const string& device_name = GetDeviceName(*node_def);
  string device;
  string task;
  const bool is_on_target_device =
      DeviceNameUtils::SplitDeviceName(device_name, &task, &device) &&
      absl::StrContains(absl::AsciiStrToLower(device),
                        absl::AsciiStrToLower(context.target_device));

  // Only check the data format for layout-sensitive ops.
  const bool data_format_match = !IsLayoutSensitiveOp(*node_def) ||
                                 AttrDataFormatMatch(node, context.src_format);

  // Only transpose floating-point nodes.
  const bool is_integer_conv2d = IsNonFloatingConv2D(node);

  return is_on_target_device && data_format_match && !is_integer_conv2d &&
         !context.nodes_to_preserve.contains(node_def->name()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

Status Transposer::CreateConstPermNode(TransposeContext* context,
                                       absl::string_view node_name,
                                       absl::string_view device,
                                       absl::Span<const int> permutation,
                                       absl::string_view control_node_name,
                                       utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  NodeDef node;
  node.set_name(string(node_name));
  node.set_op(kOpConst);
  node.set_device(string(device));

  if (!control_node_name.empty()) {
    node.add_input(string(control_node_name));
  }

  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32,
                TensorShape({static_cast<int64_t>(permutation.size())}));
  for (int i = 0, end = permutation.size(); i < end; i++) {
    tensor.flat<int>()(i) = permutation[i];
  }
  tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
  node.mutable_attr()->insert({"value", attr_tensor});

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::CreateTransposeNode(
    TransposeContext* context, absl::string_view name_format,
    const DataType& data_type, absl::string_view device,
    TensorShapeProto fanin_shape, absl::Span<const int> permutation,
    absl::string_view control_node_name, utils::MutationNewNode* added_node,
    string* transpose_node_name) {
  const string node_name = absl::Substitute(name_format, kOpTranspose);
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));
  *transpose_node_name = node_name;

  NodeDef node;
  node.set_name(node_name);
  node.set_op(kOpTranspose);
  node.set_device(string(device));

  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  AttrValue attr_data_type_perm;
  attr_data_type_perm.set_type(DT_INT32);
  node.mutable_attr()->insert({"Tperm", attr_data_type_perm});

  if (!fanin_shape.unknown_rank()) {
    TF_RETURN_IF_ERROR(
        PermuteSingle(absl::StrCat("fanin shape in ", node.name()),
                      permutation, fanin_shape.mutable_dim()));
    AttrValue attr_output_shape;
    *attr_output_shape.mutable_list()->add_shape() = fanin_shape;
    node.mutable_attr()->insert({kAttrOutputShape, attr_output_shape});
  }

  // Create the permutation Const node.
  utils::MutationNewNode const_perm_added_node;
  const string const_perm_node_name =
      absl::Substitute(name_format, "PermConst");
  TF_RETURN_IF_ERROR(CreateConstPermNode(context, const_perm_node_name, device,
                                         permutation, control_node_name,
                                         &const_perm_added_node));
  // Add a placeholder for the 1st input; the real fanin is connected later in
  // UpdateEdge().
  node.add_input("");
  // Connect const_perm_node to the 2nd input of transpose_node.
  node.add_input(const_perm_node_name);

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context,
                                          absl::Span<const int> dst_ports,
                                          utils::MutableNodeView* dst_node,
                                          absl::string_view op) {
  const bool is_in_frame = context->frames.IsInFrame(*dst_node->node());
  for (int dst_port : dst_ports) {
    auto& fanin_port = dst_node->GetRegularFanin(dst_port);
    auto* fanin_node_view = fanin_port.node_view();

    TF_RETURN_IF_ERROR(
        UpdateEdge(context,
                   GetFaninNameFormat(dst_node->GetName(), dst_port,
                                      context->src_format, context->dst_format),
                   op, /*input_shape=*/nullptr, /*is_in_frame=*/is_in_frame,
                   /*is_src_format_to_dst_format=*/true, fanin_port.index(),
                   dst_port, fanin_node_view, dst_node));
  }
  return OkStatus();
}

Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context,
                                           absl::Span<const int> src_ports,
                                           utils::MutableNodeView* src_node,
                                           absl::string_view op) {
  // Update attr _output_shapes for output ports.
  const auto* output_shape_attr = src_node->GetAttr(kAttrOutputShape);
  AttrValue shape_attr_copy;
  if (op == kOpTranspose && output_shape_attr != nullptr) {
    shape_attr_copy = *output_shape_attr;
    for (int port : src_ports) {
      auto* shape = shape_attr_copy.mutable_list()->mutable_shape(port);
      if (shape->unknown_rank()) continue;
      TF_RETURN_IF_ERROR(
          PermuteSingle(absl::StrCat("output shape attribute at port ", port,
                                     " in ", src_node->GetName()),
                        context->src_to_dst, shape->mutable_dim()));
    }
    context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
        src_node, kAttrOutputShape, shape_attr_copy);
  }

  const bool is_in_frame = context->frames.IsInFrame(*src_node->node());
  // The fanout set may be modified within the loop, so make a copy first.
  // Sort the copy by node name and index so that the names of the generated
  // transpose nodes are deterministic.
  for (int src_port : src_ports) {
    const auto& fanouts_src_port = src_node->GetRegularFanout(src_port);
    std::vector<utils::MutableFaninView> sorted_fanouts(
        fanouts_src_port.begin(), fanouts_src_port.end());
    std::sort(sorted_fanouts.begin(), sorted_fanouts.end(),
              ComparatorByNodeNameAndIndex());
    int num_downstream_transposers = 0;
    for (const auto& fanout : sorted_fanouts) {
      TF_RETURN_IF_ERROR(UpdateEdge(
          context,
          GetFanoutNameFormat(src_node->GetName(), src_port,
                              num_downstream_transposers++, context->src_format,
                              context->dst_format),
          op, &shape_attr_copy, /*is_in_frame=*/is_in_frame,
          /*is_src_format_to_dst_format=*/false, src_port, fanout.index(),
          src_node, fanout.node_view()));
    }
  }
  return OkStatus();
}

Status Transposer::CreateDataFormatNode(
    TransposeContext* context, absl::string_view node_name,
    absl::string_view op, absl::string_view device, const DataType& data_type,
    bool is_fanin_on_host, bool is_src_format_to_dst_format,
    utils::MutationNewNode* added_node) {
  auto* graph_view = context->graph_view.get();
  DCHECK(!graph_view->HasNode(node_name));

  // Create the node.
  NodeDef node;
  node.set_name(string(node_name));

  // Set up the parameters of the node.
  node.set_op(string(op));
  node.set_device(string(device));
  AttrValue attr_data_type;
  attr_data_type.set_type(data_type);
  node.mutable_attr()->insert({"T", attr_data_type});

  // The inputs of a DataFormat op could be in host memory for ops such as
  // Reshape. In such cases, run the kernel on the host too.
  if (is_fanin_on_host) {
    AttrValue attr_kernel;
    attr_kernel.set_s("host");
    node.mutable_attr()->insert({"_kernel", attr_kernel});
  }

  AttrValue src_format;
  src_format.set_s(is_src_format_to_dst_format ? context->src_format
                                               : context->dst_format);
  node.mutable_attr()->insert({kAttrSrcFormat, src_format});
  AttrValue dst_format;
  dst_format.set_s(is_src_format_to_dst_format ? context->dst_format
                                               : context->src_format);
  node.mutable_attr()->insert({kAttrDstFormat, dst_format});

  // Add a placeholder for the 1st input field.
  node.add_input("");

  Status status;
  *added_node =
      graph_view->GetMutationBuilder()->AddNode(std::move(node), &status);
  return status;
}

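// Inserts a layout-conversion node (Transpose, DataFormatVecPermute, or
// DataFormatDimMap) on the edge src_node:src_port -> dst_node:dst_port, so
// that the edge becomes src_node:src_port -> added_node:0 -> dst_node:dst_port.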
Status Transposer::UpdateEdge(
    TransposeContext* context, absl::string_view name_format,
    absl::string_view op, const AttrValue* input_shape, bool is_in_frame,
    bool is_src_format_to_dst_format, const int src_port, const int dst_port,
    utils::MutableNodeView* src_node, utils::MutableNodeView* dst_node) {
  DCHECK(src_node != nullptr);
  DCHECK(dst_node != nullptr);
  auto* src_node_def = src_node->node();
  auto* dst_node_def = dst_node->node();

  // TODO(lyandy): Minimize device parsing/fetching.
  const string device = GetDeviceName(
      is_src_format_to_dst_format ? *dst_node_def : *src_node_def);
  DataType data_type =
      is_src_format_to_dst_format
          ? context->graph_properties
                ->GetInputProperties(dst_node->GetName())[dst_port]
                .dtype()
          : context->graph_properties
                ->GetOutputProperties(src_node->GetName())[src_port]
                .dtype();

  utils::MutationNewNode added_node;
  string added_node_name;
  if (op == kOpTranspose) {
    TensorShapeProto input_shape_proto;
    input_shape_proto.set_unknown_rank(true);
    if (input_shape != nullptr) {
      input_shape_proto = input_shape->list().shape(src_port);
    } else {
      const auto* src_node_shape_attr = src_node->GetAttr(kAttrOutputShape);
      if (src_node_shape_attr != nullptr) {
        input_shape_proto = src_node_shape_attr->list().shape(src_port);
      }
    }
    const string control_node_name =
        is_in_frame ? AsControlDependency(src_node_def->name()) : "";
    const std::vector<int>& permutation =
        is_src_format_to_dst_format ? context->src_to_dst : context->dst_to_src;
    TF_RETURN_IF_ERROR(CreateTransposeNode(
        context, name_format, data_type, device, input_shape_proto, permutation,
        control_node_name, &added_node, &added_node_name));
  } else if (op == kOpDataFormatVecPermute || op == kOpDataFormatDimMap) {
    DeviceNameUtils::ParsedName parsed_name;
    bool is_fanin_on_host = DeviceNameUtils::ParseFullName(
                                GetDeviceName(*src_node_def), &parsed_name) &&
                            parsed_name.type != "CPU" &&
                            IsHostMemory(*src_node_def, src_port);
    const string node_name = absl::Substitute(name_format, op);
    TF_RETURN_IF_ERROR(CreateDataFormatNode(
        context, node_name, op, device, data_type, is_fanin_on_host,
        is_src_format_to_dst_format, &added_node));
    added_node_name = node_name;
  } else {
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("Unsupported op \"", op,
                               "\". Supported ops are Transpose, "
                               "DataFormatVecPermute, DataFormatDimMap."));
  }

  // Connect src_node to the 1st input of added_node.
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  mutation->AddOrUpdateRegularFanin(added_node, 0,
                                    {src_node->GetName(), src_port});

  // Connect the output of added_node to dst_node:dst_port.
  mutation->AddOrUpdateRegularFanin(dst_node, dst_port, {added_node_name, 0});

  return OkStatus();
}

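// Returns the rank recorded in the _output_shapes attr for the given output
// port, kUnknownRank if the shape has unknown rank, or kInvalidRank if the
// attr is missing or does not cover the port.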
int Transposer::GetFanoutPortRank(const utils::MutableNodeView& node,
                                  int port) const {
  const auto* output_shape_attr = node.GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr ||
      output_shape_attr->list().shape_size() <= port) {
    return kInvalidRank;
  }
  const auto& shape = output_shape_attr->list().shape(port);
  if (shape.unknown_rank()) {
    return kUnknownRank;
  }
  return shape.dim_size();
}

bool Transposer::IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
                                   int n) const {
  return GetFanoutPortRank(node, port) == n;
}

bool Transposer::IsFanoutPortsRankN(const utils::MutableNodeView& node,
                                    absl::Span<const int> ports, int n) const {
  for (const auto& port : ports) {
    if (!IsFanoutPortRankN(node, port, n)) {
      return false;
    }
  }
  return true;
}

int Transposer::GetFaninPortRank(const utils::MutableNodeView& node,
                                 int port) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    return GetFanoutPortRank(*regular_fanin.node_view(), regular_fanin.index());
  }
  return kInvalidRank;
}

bool Transposer::IsFaninPortRankN(const utils::MutableNodeView& node, int port,
                                  int n) const {
  return GetFaninPortRank(node, port) == n;
}

bool Transposer::IsFaninPortDimsNIfConst(const utils::MutableNodeView& node,
                                         int port,
                                         absl::Span<const int> dims) const {
  if (port < node.NumRegularFanins() && port >= 0) {
    const auto& regular_fanin = node.GetRegularFanin(port);
    const auto* fanin_node_view = regular_fanin.node_view();
    if (!IsConstant(*fanin_node_view->node())) {
      return true;
    }
    // If the fanin is a Const, check the tensor to see if dimensions match.
    const auto* value_attr = fanin_node_view->GetAttr(kAttrValue);
    if (value_attr == nullptr) {
      return false;
    }
    Tensor tensor;
    if (!tensor.FromProto(value_attr->tensor())) {
      return false;
    }
    const int dims_size = dims.size();
    if (tensor.dims() != dims_size) {
      return false;
    }
    for (int i = 0; i < dims_size; ++i) {
      if (tensor.dim_size(i) != dims[i]) {
        return false;
      }
    }
    return true;
  }
  return false;
}

bool Transposer::IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
                                          absl::Span<const int> ports,
                                          absl::Span<const int> dims) const {
  for (const auto& port : ports) {
    if (!IsFaninPortDimsNIfConst(node, port, dims)) {
      return false;
    }
  }
  return true;
}

bool Transposer::CanProcessNode(const TransposeContext& context,
                                const utils::MutableNodeView& node) const {
  return !context.nodes_to_preserve.contains(node.GetName()) &&
         !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0);
}

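// Name-format helpers. The "$0" placeholder in the returned formats is later
// filled in via absl::Substitute with the name of the created op (e.g.
// "Transpose" or "PermConst").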
string Transposer::GetFaninNameFormat(absl::string_view node_name, int port,
                                      absl::string_view src_format,
                                      absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-$0", src_format, "To", dst_format,
                      "-", kOptimizedSuffix);
}

string Transposer::GetFanoutNameFormat(absl::string_view node_name, int port,
                                       int index, absl::string_view src_format,
                                       absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", port, "-", index, "-$0", dst_format, "To",
                      src_format, "-", kOptimizedSuffix);
}

string Transposer::LayoutOptimizerNode(absl::string_view node_name) {
  return absl::StrCat(node_name, "-", kOptimizedSuffix);
}

string Transposer::GetReshapeNodeNameFormat(absl::string_view node_name,
                                            int index,
                                            absl::string_view src_format,
                                            absl::string_view dst_format) {
  return absl::StrCat(node_name, "-", index, "-", kReshape, src_format, "To",
                      dst_format);
}

string Transposer::GetShapeConstNodeNameFormat(absl::string_view node_name,
                                               int index) {
  return absl::StrCat(node_name, "-", index, "-", kReshapeConst);
}

// Layout sensitive transposer.

inline string GetLayoutSensitiveNodeDataFormat(
    const utils::MutableNodeView& node) {
  const auto* attr = node.GetAttr(kAttrDataFormat);
  if (attr != nullptr) {
    return attr->s();
  }
  return "";
}

Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context,
                                               utils::MutableNodeView* node) {
  utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
  AttrValue data_format_attr;
  data_format_attr.set_s(context->dst_format);
  mutation->AddOrUpdateNodeAttr(node, kAttrDataFormat, data_format_attr);

  auto permute_attr = [&context, &node,
                       &mutation](absl::string_view attr_name) {
    const auto* attr = node->GetAttr(attr_name);
    if (attr != nullptr) {
      AttrValue attr_copy(*attr);
      TF_RETURN_IF_ERROR(PermuteSingle(
          absl::StrCat(attr_name, " attribute in ", node->GetName()),
          context->src_to_dst, attr_copy.mutable_list()->mutable_i()));
      mutation->AddOrUpdateNodeAttr(node, attr_name, attr_copy);
    }
    return OkStatus();
  };

  // Update attrs.
  TF_RETURN_IF_ERROR(permute_attr(kAttrStrides));
  TF_RETURN_IF_ERROR(permute_attr(kAttrKSize));
  TF_RETURN_IF_ERROR(permute_attr(kAttrDilations));

  const auto* explicit_paddings_attr = node->GetAttr(kAttrExplicitPaddings);
  if (explicit_paddings_attr != nullptr && explicit_paddings_attr->has_list() &&
      explicit_paddings_attr->list().i_size() > 0) {
    AttrValue explicit_paddings_attr_copy(*explicit_paddings_attr);
    TF_RETURN_IF_ERROR(PermuteDouble(
        absl::StrCat("explicit_paddings attribute in ", node->GetName()),
        context->src_to_dst,
        explicit_paddings_attr_copy.mutable_list()->mutable_i()));
    mutation->AddOrUpdateNodeAttr(node, kAttrExplicitPaddings,
                                  explicit_paddings_attr_copy);
  }

  return OkStatus();
}

Status DefaultLayoutSensitiveOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutSensitiveOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsAvgPoolGrad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 1, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  // This TransposeNode handles BiasAdd but not BiasAddV1, since only BiasAdd
  // supports the data_format attribute.
  DCHECK(IsBiasAddV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, rank)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAdd itself only needs NCHW/NHWC to determine whether the C dimension
  // is the second or the last dim. Therefore, we use the original 4D data
  // format in the context to update the node. For the input/output tensor,
  // the corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status BiasAddGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsBiasAddGrad(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  // BiasAddGrad itself only needs NCHW/NHWC to determine whether the C
  // dimension is the second or the last dim. Therefore, we use the original
  // 4D data format in the context to update the node. For the input tensor,
  // the corresponding 4D or 5D data format is needed.
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  // No need to update the output shape, as it is always 1-D with size equal
  // to the feature dimension of `out_backprop`, regardless of whether NCHW or
  // NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropFilter(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropFilter(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update the output shape, as it is always of shape
  // [filter_height, filter_width, in_channels, out_channels], regardless of
  // whether NCHW or NHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv2DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv2DBackpropInput(*node->node()) ||
         IsDepthwiseConv2dNativeBackpropInput(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }

  const auto& fanin = node->GetRegularFanin(0);
  auto* fanin_node = fanin.node_view();
  const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr == nullptr) {
    VLOG(3) << "Cannot compute the shape of " << fanin_node->GetName()
            << " because it is missing attribute " << kAttrOutputShape;
    return OkStatus();
  }
  TensorShapeProto fanin_shape = output_shape_attr->list().shape(fanin.index());
  if (fanin_shape.dim_size() != 1) {
    VLOG(3) << fanin_node->GetName() << " is not a vector.";
    return OkStatus();
  }

  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsConv3D(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropFilterTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropFilterV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 2}, node, kOpTranspose));
  // No need to update the output shape, as it is always of shape
  // [filter_depth, filter_height, filter_width, in_channels, out_channels],
  // regardless of whether NCDHW or NDHWC is used.
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status Conv3DBackpropInputTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsConv3DBackpropInputV2(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status FusedBatchNormExTransposer::TransposeNode(TransposeContext* context,
                                                 utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormEx(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
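  // _FusedBatchNormEx with a side input has 6 regular fanins; the side input
  // (fanin 5) shares the data layout and must be transposed along with the
  // data input.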
  if (node->NumRegularFanins() == 6) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0, 5}, node, kOpTranspose));
  } else {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool FusedBatchNormGradTransposer::IsTraining(
    const utils::MutableNodeView& node) const {
  const auto* is_training_attr = node.GetAttr(kAttrIsTraining);
  if (is_training_attr != nullptr) {
    return is_training_attr->b();
  }
  return false;
}

Status FusedBatchNormGradTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsFusedBatchNormGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsTraining(*node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolV2(*node->node()));
  // Check the data input's shape instead of the node's output shape, because
  // shape inference for MaxPoolV2 cannot infer the shape when ksize or
  // strides is not constant.
  const auto& data_fanin = node->GetRegularFanin(0);
  auto* data_fanin_node = data_fanin.node_view();
  if (!ShouldProcess(*context, *node) ||
      !IsFanoutPortRankN(*data_fanin_node, data_fanin.index(), 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context,
                                              utils::MutableNodeView* node) {
  DCHECK(IsMaxPoolGradV2(*node->node()) || IsMaxPoolGradGradV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateNode(context, node));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {3, 4}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Layout agnostic transposer.

inline bool IsValidConstPermTransposeNode(const utils::MutableNodeView& node,
                                          absl::Span<const int> permutation) {
  Tensor tensor;
  if (!GetValueAttrFromConstInputNode(node, IsTranspose, 1, &tensor)) {
    return false;
  }
  const int permutation_size = permutation.size();
  if (tensor.NumElements() != permutation_size) {
    return false;
  }

  const auto& tensor_data = tensor.unaligned_flat<int32>();
  for (int i = 0; i < permutation_size; i++) {
    if (permutation[i] != tensor_data(i)) {
      return false;
    }
  }
  return true;
}

inline bool IsValidDataFormatNode(const utils::MutableNodeView& node,
                                  absl::string_view src_format,
                                  absl::string_view dst_format) {
  if (!IsDataFormatOp(node)) {
    return false;
  }
  const auto* src_format_attr = node.GetAttr(kAttrSrcFormat);
  if (src_format_attr == nullptr || src_format_attr->s() != src_format) {
    return false;
  }
  const auto* dst_format_attr = node.GetAttr(kAttrDstFormat);
  if (dst_format_attr == nullptr || dst_format_attr->s() != dst_format) {
    return false;
  }
  return true;
}

inline bool IsLayoutOptimizerAddedDstToSrcTranspose(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         IsValidConstPermTransposeNode(node, context.dst_to_src);
}

inline bool IsLayoutOptimizerAddedDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) {
  return node.node_index() >= context.num_nodes &&
         (IsValidConstPermTransposeNode(node, context.dst_to_src) ||
          IsValidDataFormatNode(node, context.dst_format, context.src_format));
}

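// Returns true if a data fanin of `node`, traced backwards through chains of
// layout-agnostic ops, originates from a dst-to-src transform added by this
// optimizer. Only then can transposing `node` let the inserted transpose
// pairs cancel out.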
bool LayoutAgnosticOpTransposer::IsAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  std::deque<utils::MutableNodeView*> queue;
  absl::flat_hash_set<utils::MutableNodeView*> visited_nodes;
  auto data_node_pos = GetDataFaninPorts(node);
  for (const int pos : data_node_pos) {
    const auto& fanin = node.GetRegularFanin(pos);
    auto* fanin_node = fanin.node_view();
    queue.push_back(fanin_node);
    visited_nodes.insert(fanin_node);
  }
  // In most cases this loop exits after one iteration, as the graph is
  // already topologically sorted.
  while (!queue.empty()) {
    utils::MutableNodeView* current_node = queue.front();
    queue.pop_front();
    if (IsLayoutOptimizerAddedDstToSrcTransform(context, *current_node)) {
      return true;
    }
    // We only continue searching if the path is connected through
    // format-agnostic nodes.
    if (IsLayoutAgnosticOp(*current_node->node())) {
      auto current_node_pos = GetDataFaninPorts(*current_node);
      for (const auto& pos : current_node_pos) {
        const auto& fanin = current_node->GetRegularFanin(pos);
        auto* fanin_node = fanin.node_view();
        if (visited_nodes.insert(fanin_node).second) {
          queue.push_back(fanin_node);
        }
      }
    }
  }
  return false;
}

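// Returns the ports whose rank-`rank` fanins are either optimizer-added
// dst-to-src transposes, or layout-agnostic nodes that are themselves
// downstream of such a transform.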
std::vector<int> LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts(
    const TransposeContext& context, const utils::MutableNodeView& node,
    int rank) const {
  std::vector<int> ports;
  const int num_regular_fanins = node.NumRegularFanins();
  ports.reserve(num_regular_fanins);
  for (int i = 0; i < num_regular_fanins; ++i) {
    const auto& regular_fanin = node.GetRegularFanin(i);
    auto* regular_fanin_node = regular_fanin.node_view();
    int regular_fanin_port = regular_fanin.index();
    if ((IsFanoutPortRankN(*regular_fanin_node, regular_fanin_port, rank)) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      ports.push_back(i);
    }
  }
  return ports;
}

Status DefaultLayoutAgnosticOpTransposer::TransposeNode(
    TransposeContext* context, utils::MutableNodeView* node) {
  DCHECK(IsDefaultLayoutAgnosticOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status AddNTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsAddN(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool BinaryOpTransposer::IsNDOperateWithMD(const utils::MutableNodeView& node,
                                           int n, int m) {
  return IsFaninPortRankN(node, 0, n) && IsFaninPortRankN(node, 1, m);
}

bool BinaryOpTransposer::IsFaninShapeSupported(
    const utils::MutableNodeView& node, int rank) {
  return (IsNDOperateWithMD(node, rank, 0) ||
          IsNDOperateWithMD(node, rank, 1) ||
          IsNDOperateWithMD(node, rank, rank) ||
          IsNDOperateWithMD(node, 0, rank) || IsNDOperateWithMD(node, 1, rank));
}

std::vector<int> BinaryOpTransposer::GetNDDataFaninPorts(
    const utils::MutableNodeView& node, int rank) {
  std::vector<int> values;
  if (IsFaninPortRankN(node, 0, rank)) {
    values.push_back(0);
  }
  if (IsFaninPortRankN(node, 1, rank)) {
    values.push_back(1);
  }
  return values;
}

Status BinaryOpTransposer::AddNodeReshape(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, absl::string_view input_name,
    absl::string_view shape_const_node_name, const DataType& data_type) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.add_input(string(input_name));
  new_node.add_input(string(shape_const_node_name));
  new_node.set_op(kReshape);
  new_node.set_device(string(node_device));

  AttrValue attr_type_indices;
  attr_type_indices.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"Tshape", attr_type_indices});

  AttrValue attr_type_params;
  attr_type_params.set_type(data_type);
  new_node.mutable_attr()->insert({"T", attr_type_params});

  Status status;
  mutation->AddNode(std::move(new_node), &status);
  return status;
}

Status BinaryOpTransposer::AddNodeShapeConst(
    utils::Mutation* mutation, absl::string_view node_name,
    absl::string_view node_device, bool node_in_frame, int num_channels,
    absl::string_view depended_node, int rank) {
  NodeDef new_node;
  new_node.set_name(string(node_name));
  new_node.set_op(kOpConst);
  new_node.set_device(string(node_device));
  AttrValue attr_data_type;
  attr_data_type.set_type(DT_INT32);
  new_node.mutable_attr()->insert({"dtype", attr_data_type});

  AttrValue attr_tensor;
  Tensor tensor(DT_INT32, TensorShape({rank}));
  std::vector<int> shape(rank, 1);
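  // All dimensions are 1 except the channel dimension. This assumes the
  // destination format places channels at index 1 (NCHW/NCDHW), so the
  // reshape target is [1, C, 1, ..., 1].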
1255 shape[1] = num_channels;
1256 for (int i = 0; i < static_cast<int>(shape.size()); i++) {
1257 tensor.flat<int>()(i) = shape[i];
1258 }
1259 tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
1260 new_node.mutable_attr()->insert({"value", attr_tensor});
1261 if (node_in_frame) {
1262 // This is to ensure the transpose node and the const node are in the same
1263 // frame.
1264 // TODO(halehri): Add Test that exercises this condition.
1265 new_node.add_input(AsControlDependency(string(depended_node)));
1266 }
1267
1268 Status status;
1269 mutation->AddNode(std::move(new_node), &status);
1270 return status;
1271 }
1272
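// After the rank-N operand of a broadcasting binary op is transposed (e.g.
// NHWC -> NCHW), a rank-1 operand of shape [C] would no longer line up with
// the channel dimension. Reshape it to a rank-N shape such as [1, C, 1, 1]
// so that broadcasting remains correct in the destination layout.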
Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context,
                                                   utils::MutableNodeView* node,
                                                   int rank) {
  int vector_index = -1;
  if (IsNDOperateWithMD(*node, rank, 1)) {
    vector_index = 1;
  } else if (IsNDOperateWithMD(*node, 1, rank)) {
    vector_index = 0;
  }
  if (vector_index != -1) {
    const string& node_name = node->GetName();
    const string& node_device = node->GetDevice();
    string reshape_node_name = LayoutOptimizerNode(GetReshapeNodeNameFormat(
        node_name, vector_index, context->src_format, context->dst_format));
    string shape_const_node_name = LayoutOptimizerNode(
        GetShapeConstNodeNameFormat(node_name, vector_index));
    const auto& fanin = node->GetRegularFanin(vector_index);
    auto* fanin_node = fanin.node_view();
    const auto* output_shape_attr = fanin_node->GetAttr(kAttrOutputShape);
    if (output_shape_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrOutputShape);
    }
    int vector_size =
        output_shape_attr->list().shape(fanin.index()).dim(0).size();
    utils::Mutation* mutation = context->graph_view->GetMutationBuilder();
    TF_RETURN_IF_ERROR(
        AddNodeShapeConst(mutation, shape_const_node_name, node_device,
                          context->frames.IsInFrame(*node->node()), vector_size,
                          fanin_node->GetName(), rank));
    const auto* t_attr = node->GetAttr(kAttrT);
    if (t_attr == nullptr) {
      return errors::InvalidArgument("Missing attribute ", kAttrT);
    }
    TF_RETURN_IF_ERROR(
        AddNodeReshape(mutation, reshape_node_name, node_device,
                       TensorIdToString({fanin_node->GetName(), fanin.index()}),
                       shape_const_node_name, t_attr->type()));
    mutation->AddOrUpdateRegularFanin(node, vector_index,
                                      {reshape_node_name, 0});
  }
  return OkStatus();
}

Status BinaryOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsBinaryOp(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) || !IsFaninShapeSupported(*node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetNDDataFaninPorts(*node, rank), node, kOpTranspose));
  TF_RETURN_IF_ERROR(MaybeReshapeVectorFanin(context, node, rank));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

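// For Concat the axis is input 0; for ConcatV2 the axis is the last input,
// at index N (the number of tensors being concatenated). In either case the
// axis operand is rewritten with DataFormatDimMap rather than transposed.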
Status ConcatOpTransposer::TransposeNode(TransposeContext* context,
                                         utils::MutableNodeView* node) {
  DCHECK(IsConcat(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetConcatDataFaninPorts(*node), node, kOpTranspose));
  int axis_node = 0;
  if (node->GetOp() == "ConcatV2") {
    const auto* n_attr = node->GetAttr(kAttrN);
    if (n_attr != nullptr) {
      axis_node = n_attr->i();
    }
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {axis_node}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

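// Fill's input 0 is the output shape vector rather than a data tensor, so it
// is permuted with DataFormatVecPermute instead of a Transpose.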
Status FillOpTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsFill(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 0, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status IdentityNTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsIdentityN(*node->node()));
  const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4);

  // Temporarily upgrade the context to obtain the number of 5D fanin ports.
  std::vector<int> ports_5d;
  {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    ports_5d = GetVariadicNDFaninPorts(*context, *node, 5);
  }

  if (!ShouldProcess(*context, *node)) {
    return OkStatus();
  }

  if (!ports_4d.empty()) {
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_4d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_4d, node, kOpTranspose));
  }

  if (!ports_5d.empty()) {
    ScopedDataFormatUpgrader data_format_upgrader(context, 5);
    TF_RETURN_IF_ERROR(
        UpdateFaninEdgesWithOp(context, ports_5d, node, kOpTranspose));
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, ports_5d, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

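// Merge can only be rewritten when every data fanin is already in the
// destination layout: each rank-4 fanin must either come from a layout
// agnostic op that sits after a dst-to-src transform, or be a dst-to-src
// Transpose added by this optimizer.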
bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  for (const auto& regular_fanin : node.GetRegularFanins()) {
    auto* regular_fanin_node = regular_fanin.node_view();
    if (IsFanoutPortRankN(*regular_fanin_node, regular_fanin.index(), 4) &&
        ((IsAfterDstToSrcTransform(context, *regular_fanin_node) &&
          IsLayoutAgnosticOp(*regular_fanin_node->node())) ||
         IsLayoutOptimizerAddedDstToSrcTranspose(context,
                                                 *regular_fanin_node))) {
      continue;
    }
    return false;
  }
  return true;
}

Status MergeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsMerge(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsEveryFaninAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, GetDataFaninPorts(*node),
                                            node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status PadTransposer::TransposeNode(TransposeContext* context,
                                    utils::MutableNodeView* node) {
  DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) ||
         IsPad(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4, 2}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool ReduceTransposer::KeepDims(const utils::MutableNodeView& node) {
  const auto* keep_dims_attr = node.GetAttr(kAttrKeepDims);
  if (keep_dims_attr != nullptr) {
    return keep_dims_attr->b();
  }
  return false;
}

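// Returns true iff the 1-D axis tensor has exactly axis.size() elements and
// every element, after normalizing negative indices by adding `rank`,
// appears in `axis`.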
bool ReduceTransposer::IsAlongAxis(const Tensor& tensor,
                                   absl::Span<const int> axis, int rank) {
  const int axis_size = axis.size();
  if (tensor.dims() != 1 || tensor.dim_size(0) != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = 0;
    if (tensor.dtype() == DT_INT32) {
      local_axis = tensor.flat<int32>()(i);
    } else {
      local_axis = tensor.flat<int64_t>()(i);
    }
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

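// A reduction is safe to transform when keep_dims is set (the output keeps
// rank 4/5 and can simply be transposed back), or when the constant reduction
// axes form one of the combinations below, whose results do not depend on the
// input layout (e.g. for rank 4: reducing over all of N,H,W,C, over H,W,C,
// over N,H,W, over H,W, or over C alone).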
bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context,
                                             const utils::MutableNodeView& node,
                                             int rank) {
  if (KeepDims(node)) {
    return true;
  }
  const auto& regular_fanin_1 = node.GetRegularFanin(1);
  auto* axis_node = regular_fanin_1.node_view();
  if (!IsConstant(*axis_node->node())) {
    return false;
  }
  const auto* value_attr = axis_node->GetAttr(kAttrValue);
  if (value_attr == nullptr) {
    return false;
  }
  Tensor tensor;
  if (!tensor.FromProto(value_attr->tensor())) {
    LOG(ERROR) << "Failed to parse TensorProto.";
    return false;
  }
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  if (rank == 5) {
    return IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W', 'C'}), 5) ||
           IsAlongAxis(tensor, indices({'N', 'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'D', 'H', 'W'}), 5) ||
           IsAlongAxis(tensor, indices({'C'}), 5);
  }
  DCHECK_EQ(rank, 4);
  return IsAlongAxis(tensor, indices({'N', 'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W', 'C'}), 4) ||
         IsAlongAxis(tensor, indices({'N', 'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'H', 'W'}), 4) ||
         IsAlongAxis(tensor, indices({'C'}), 4);
}

Status ReduceTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsReduceOp(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsReduceAxisSupported(*context, *node, rank) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  if (KeepDims(*node)) {
    TF_RETURN_IF_ERROR(
        UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  }
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ReverseV2Transposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsReverseV2(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

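// Select's condition (input 0) is only handled when its rank is 0 (scalar),
// 1 (vector), or 4 (same rank as the then/else operands).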
bool SelectTransposer::IsFaninScalarVector4D(
    const utils::MutableNodeView& fanin, int port) {
  return IsFanoutPortRankN(fanin, port, 0) ||
         IsFanoutPortRankN(fanin, port, 1) || IsFanoutPortRankN(fanin, port, 4);
}

std::vector<int> SelectTransposer::GetFaninPorts(
    const utils::MutableNodeView& fanin, int port) {
  // Input 0 may be a scalar, a vector whose size matches the first dimension
  // of inputs 1 and 2, or a tensor with the same shape as inputs 1 and 2.
  // Only in the last case does it need to be transposed along with them.
  if (IsFanoutPortRankN(fanin, port, 4)) {
    return {0, 1, 2};
  }
  return {1, 2};
}

Status SelectTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSelect(*node->node()));
  const auto& regular_fanin_0 = node->GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninScalarVector4D(*regular_fanin_0_node, regular_fanin_0.index()) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(
      context, GetFaninPorts(*regular_fanin_0_node, regular_fanin_0.index()),
      node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsShape(*node->node()));
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, {0}, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status ShapeNTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsShapeN(*node->node()));
  // ShapeN requires all input tensors to have the same dimensions. Therefore,
  // we simply use the 0th fanin port.
  const int rank = GetFaninPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  const auto ports = GetVariadicNDFaninPorts(*context, *node, rank);
  if (!ShouldProcess(*context, *node) || ports.empty()) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, ports, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpDataFormatVecPermute));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SliceTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSlice(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2}, {rank}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1, 2}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitTransposer::TransposeNode(TransposeContext* context,
                                      utils::MutableNodeView* node) {
  DCHECK(IsSplit(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SplitVTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSplitV(*node->node()));
  const auto ports = GetDataFanoutPorts(*node);
  if (!ShouldProcess(*context, *node) || !IsFanoutPortsRankN(*node, ports, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {2}, node, kOpDataFormatDimMap));
  TF_RETURN_IF_ERROR(
      UpdateFanoutEdgesWithOp(context, ports, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

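// Squeeze is convertible only if the fanin's recorded output shape is rank 4
// with both spatial dimensions (H and W in the source format) equal to 1, so
// that squeezing them out is valid regardless of layout.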
bool SqueezeTransposer::IsInputConvertible(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  const auto& regular_fanin_0 = node.GetRegularFanin(0);
  auto* regular_fanin_0_node = regular_fanin_0.node_view();
  const auto* output_shape_attr =
      regular_fanin_0_node->GetAttr(kAttrOutputShape);
  if (output_shape_attr != nullptr) {
    auto& shape = output_shape_attr->list().shape(regular_fanin_0.index());
    if (shape.dim_size() != kRank) {
      return false;
    }
    const int height_dim = context.src_dim_indices.at('H');
    const int width_dim = context.src_dim_indices.at('W');
    if (shape.dim(height_dim).size() == 1 && shape.dim(width_dim).size() == 1) {
      return true;
    }
  }
  return false;
}

bool SqueezeTransposer::IsAlongAxis(const AttrValue& attr,
                                    absl::Span<const int> axis,
                                    int rank) const {
  const auto& list = attr.list();
  // If the list is empty, the Squeeze op squeezes all dimensions of size 1.
  int axis_size = axis.size();
  if (list.i_size() == 0) {
    return true;
  } else if (list.i_size() != axis_size) {
    return false;
  }
  for (int i = 0; i < axis_size; ++i) {
    int local_axis = list.i(i);
    if (local_axis < 0) {
      local_axis += rank;
    }
    bool along_axis = false;
    for (int dim : axis) {
      if (local_axis == dim) {
        along_axis = true;
        break;
      }
    }
    if (!along_axis) {
      return false;
    }
  }
  return true;
}

bool SqueezeTransposer::IsDimsSupported(
    const TransposeContext& context, const utils::MutableNodeView& node) const {
  auto indices = [&context](absl::Span<const char> labels) {
    return GetDimensionIndicesFromLabel(context.src_dim_indices, labels);
  };
  const auto* squeeze_dims_attr = node.GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return false;
  }
  return (IsFanoutPortRankN(node, 0, 2) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'H', 'W'}), kRank)) ||
         (IsFanoutPortRankN(node, 0, 1) &&
          IsAlongAxis(*squeeze_dims_attr, indices({'N', 'H', 'W'}), kRank));
}

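// Remaps each entry of the squeeze_dims attribute from its position in the
// source format to its position in the destination format (via dst_to_src)
// and writes the sorted result back to the node.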
Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context,
                                            utils::MutableNodeView* node) {
  const auto* squeeze_dims_attr = node->GetAttr(kAttrSqueezeDims);
  if (squeeze_dims_attr == nullptr) {
    return errors::InvalidArgument("Missing attribute ", kAttrSqueezeDims);
  }
  const int num_input_dims = context->src_format.length();
  const int min_squeeze_dim = -num_input_dims;
  std::vector<int> squeeze_dims_mapped;
  const int squeeze_dims_size = squeeze_dims_attr->list().i_size();
  squeeze_dims_mapped.reserve(squeeze_dims_size);
  for (int i = 0; i < squeeze_dims_size; ++i) {
    int dim = squeeze_dims_attr->list().i(i);
    if (dim < min_squeeze_dim || dim >= num_input_dims) {
      return errors::InvalidArgument(
          "Attribute '", kAttrSqueezeDims, "' contains out of range index '",
          dim, "', index must be between [", min_squeeze_dim, ", ",
          num_input_dims, ")");
    }
    if (dim < 0) {
      dim += num_input_dims;
    }
    squeeze_dims_mapped.push_back(context->dst_to_src[dim]);
  }
  std::sort(squeeze_dims_mapped.begin(), squeeze_dims_mapped.end());
  AttrValue squeeze_dims;
  squeeze_dims.mutable_list()->mutable_i()->Reserve(squeeze_dims_size);
  for (const auto& dim : squeeze_dims_mapped) {
    squeeze_dims.mutable_list()->mutable_i()->Add(dim);
  }
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(
      node, kAttrSqueezeDims, squeeze_dims);
  return OkStatus();
}

Status SqueezeTransposer::TransposeNode(TransposeContext* context,
                                        utils::MutableNodeView* node) {
  DCHECK(IsSqueeze(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsDimsSupported(*context, *node) ||
      !IsInputConvertible(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateSqueezeDims(context, node));
  return context->graph_view->GetMutationBuilder()->Apply();
}

bool StridedSliceTransposer::IsMaskZero(const utils::MutableNodeView& node,
                                        absl::string_view mask) {
  const auto* mask_attr = node.GetAttr(mask);
  if (mask_attr != nullptr) {
    return mask_attr->i() == 0;
  }
  return true;
}

bool StridedSliceTransposer::HasOnlyBeginEndMask(
    const utils::MutableNodeView& node) {
  return IsMaskZero(node, "ellipsis_mask") &&
         IsMaskZero(node, "new_axis_mask") &&
         IsMaskZero(node, "shrink_axis_mask");
}

Status StridedSliceTransposer::PermuteMask(TransposeContext* context,
                                           utils::MutableNodeView* node,
                                           absl::string_view mask) {
  // Computes the permutation of the mask based on the src and dst format.
  // For example:
  //   src_format = NHWC
  //   dst_format = NCHW
  //   src_to_dst permutation = [0, 3, 1, 2]
  //   mask   : 0010 (note that bit positions correspond to dimension indices,
  //            i.e. the bits are written in reverse order of the src format:
  //            CWHN)
  //   result : 0100 (WHCN)
  const auto* mask_attr = node->GetAttr(mask);
  const int mask_i = mask_attr != nullptr ? mask_attr->i() : 0;
  if (mask_i < 0 || mask_i > 15) {
    return errors::InvalidArgument("invalid mask value: ", mask_i);
  }
  int result = 0;
  for (int i = 0, end = context->src_to_dst.size(); i < end; i++) {
    const int final_pos = context->src_to_dst[i];
    const int position_mask = 1 << final_pos;
    const int bit_i = (mask_i & position_mask) >> final_pos;
    result |= bit_i << i;
  }
  AttrValue new_mask_attr;
  new_mask_attr.set_i(result);
  context->graph_view->GetMutationBuilder()->AddOrUpdateNodeAttr(node, mask,
                                                                 new_mask_attr);
  return OkStatus();
}

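// StridedSlice is rewritten only when the ellipsis, new-axis, and shrink-axis
// masks are all zero, since those masks change the output rank and would not
// survive a simple permutation. The begin/end masks and the begin/end/strides
// vectors are permuted into the destination layout.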
Status StridedSliceTransposer::TransposeNode(TransposeContext* context,
                                             utils::MutableNodeView* node) {
  DCHECK(IsStridedSlice(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortsDimsNIfConst(*node, {1, 2, 3}, {4}) ||
      !HasOnlyBeginEndMask(*node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "begin_mask"));
  TF_RETURN_IF_ERROR(PermuteMask(context, node, "end_mask"));
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {1, 2, 3}, node,
                                            kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status SwitchTransposer::TransposeNode(TransposeContext* context,
                                       utils::MutableNodeView* node) {
  DCHECK(IsSwitch(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, GetDataFanoutPorts(*node),
                                             node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TernaryOpTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsTernaryOp(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1, 2}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status TileTransposer::TransposeNode(TransposeContext* context,
                                     utils::MutableNodeView* node) {
  DCHECK(IsTile(*node->node()));
  if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) ||
      !IsFaninPortDimsNIfConst(*node, 1, {4}) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(UpdateFaninEdgesWithOp(context, {0}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {1}, node, kOpDataFormatVecPermute));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

Status UnaryGradTransposer::TransposeNode(TransposeContext* context,
                                          utils::MutableNodeView* node) {
  DCHECK(IsUnaryGrad(*node->node()));
  const int rank = GetFanoutPortRank(*node, 0);
  if (rank != 4 && rank != 5) {
    return OkStatus();
  }
  ScopedDataFormatUpgrader data_format_upgrader(context, rank);
  if (!ShouldProcess(*context, *node) ||
      !IsAfterDstToSrcTransform(*context, *node)) {
    return OkStatus();
  }
  VLOG(3) << "GenericLayoutOptimizer: transforming node '" << node->GetName()
          << "' with op '" << node->GetOp() << "' from data format '"
          << context->src_format << "' to '" << context->dst_format << "'";
  TF_RETURN_IF_ERROR(
      UpdateFaninEdgesWithOp(context, {0, 1}, node, kOpTranspose));
  TF_RETURN_IF_ERROR(UpdateFanoutEdgesWithOp(context, {0}, node, kOpTranspose));
  return context->graph_view->GetMutationBuilder()->Apply();
}

// Utils.

string GetDeviceName(const NodeDef& node) { return node.device(); }

bool IsDefaultLayoutSensitiveOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* default_layout_sensitive_ops =
      new absl::flat_hash_set<std::string>(
          {"AvgPool", "Conv2D", "DepthwiseConv2dNative", "DepthToSpace",
           "FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3",
           "FusedConv2DBiasActivation", "MaxPool", "SpaceToDepth"});
  return default_layout_sensitive_ops->find(node.op()) !=
         default_layout_sensitive_ops->end();
}

bool IsLayoutSensitiveOp(const NodeDef& node) {
  return IsDefaultLayoutSensitiveOp(node) || IsAvgPoolGrad(node) ||
         IsBiasAddV2(node) || IsBiasAddGrad(node) ||
         IsConv2DBackpropFilter(node) || IsConv2DBackpropInput(node) ||
         IsDepthwiseConv2dNativeBackpropFilter(node) ||
         IsDepthwiseConv2dNativeBackpropInput(node) ||
         IsFusedBatchNormEx(node) || IsFusedBatchNormGrad(node) ||
         IsMaxPoolV2(node) || IsMaxPoolGrad(node) || IsMaxPoolGradV2(node) ||
         IsMaxPoolGradGradV1(node) || IsMaxPoolGradGradV2(node) ||
         IsConv3D(node) || IsConv3DBackpropInputV2(node) ||
         IsConv3DBackpropFilterV2(node);
}

bool IsDefaultLayoutAgnosticOp(const NodeDef& node) {
  static absl::flat_hash_set<string>* agnostic_nodes =
      new absl::flat_hash_set<std::string>({"Abs",
                                            "Acos",
                                            "Acosh",
                                            "Angle",
                                            "Asin",
                                            "Asinh",
                                            "Atan",
                                            "Atanh",
                                            "Bitcast",
                                            "Cast",
                                            "Ceil",
                                            "CheckNumerics",
                                            "ComplexAbs",
                                            "Conj",
                                            "Cos",
                                            "Cosh",
                                            "Digamma",
                                            "Elu",
                                            "Enter",
                                            "Erf",
                                            "Erfc",
                                            "Exit",
                                            "Exp",
                                            "Expm1",
                                            "FakeQuantWithMinMaxVars",
                                            "FakeQuantWithMinMaxArgs",
                                            "Floor",
                                            "GuaranteeConst",
                                            "Identity",
                                            "Imag",
                                            "Inv",
                                            "IsFinite",
                                            "IsInf",
                                            "IsNan",
                                            "LeakyRelu",
                                            "Lgamma",
                                            "Log",
                                            "LogicalNot",
                                            "Log1p",
                                            "Neg",
                                            "NextIteration",
                                            "OnesLike",
                                            "PreventGradient",
                                            "QuantizeAndDequantizeV2",
                                            "QuantizeAndDequantizeV3",
                                            "QuantizeAndDequantizeV4",
                                            "Real",
                                            "Reciprocal",
                                            "Relu",
                                            "Relu6",
                                            "Rint",
                                            "Selu",
                                            "Sigmoid",
                                            "Sign",
                                            "Sin",
                                            "Sinh",
                                            "Snapshot",
                                            "Softplus",
                                            "Round",
                                            "Rsqrt",
                                            "Sqrt",
                                            "Square",
                                            "StopGradient",
                                            "Tan",
                                            "Tanh",
                                            "ZerosLike"});
  return agnostic_nodes->find(node.op()) != agnostic_nodes->end();
}

bool IsLayoutAgnosticOp(const NodeDef& node) {
  return IsDefaultLayoutAgnosticOp(node) || IsAddN(node) || IsBinaryOp(node) ||
         IsIdentityN(node) || IsMerge(node) || IsMirrorPad(node) ||
         IsMirrorPadGrad(node) || IsPad(node) || IsSelect(node) ||
         IsSwitch(node) || IsTernaryOp(node) || IsUnaryGrad(node) ||
         IsConcat(node) || IsReverseV2(node) || IsTile(node) || IsShape(node) ||
         IsShapeN(node) || IsFill(node) || IsSlice(node) || IsSplit(node) ||
         IsSqueeze(node) || IsSplitV(node) || IsStridedSlice(node) ||
         IsReduceOp(node);
}

bool IsTernaryOp(const NodeDef& node) { return IsBetainc(node); }

bool IsUnaryGrad(const NodeDef& node) {
  bool is_unary_grad =
      IsEluGrad(node) || IsInvGrad(node) || IsLeakyReluGrad(node) ||
      IsReciprocalGrad(node) || IsRelu6Grad(node) || IsReluGrad(node) ||
      IsRsqrtGrad(node) || IsSeluGrad(node) || IsSigmoidGrad(node) ||
      IsSoftplusGrad(node) || IsSoftsignGrad(node) || IsSqrtGrad(node) ||
      IsTanhGrad(node);
  return is_unary_grad;
}

bool IsMaxPoolV2(const NodeDef& node) { return node.op() == "MaxPoolV2"; }

bool IsMaxPoolGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradV2";
}

bool IsMaxPoolGradGradV1(const NodeDef& node) {
  return node.op() == "MaxPoolGradGrad";
}

bool IsMaxPoolGradGradV2(const NodeDef& node) {
  return node.op() == "MaxPoolGradGradV2";
}

bool IsBinaryOp(const NodeDef& node) {
  bool is_binary =
      IsAdd(node) || IsAtan2(node) || IsComparisonOp(node) || IsComplex(node) ||
      IsDiv(node) || IsFloorDiv(node) || IsIgamma(node) || IsIgammac(node) ||
      IsLogicalAnd(node) || IsLogicalOr(node) || IsMaximum(node) ||
      IsMinimum(node) || IsMod(node) || IsMul(node) || IsPolygamma(node) ||
      IsPow(node) || IsRealDiv(node) || IsSquaredDifference(node) ||
      IsSub(node) || IsTruncateDiv(node) || IsTruncateMod(node) || IsZeta(node);
  return is_binary;
}

bool IsReduceOp(const NodeDef& node) {
  return IsSum(node) || IsMean(node) || IsProd(node) || IsMax(node) ||
         IsMin(node) || IsAll(node) || IsAny(node);
}

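// Returns the input ports of `node` that carry data tensors, as opposed to
// auxiliary arguments such as shapes, sizes, or axes; which ports those are
// depends on the op.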
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsAvgPoolGrad(*node_def) || IsSplit(*node_def)) {
    return {1};
  }
  if (IsStridedSliceGrad(*node_def)) {
    return {4};
  }
  if (IsBinaryOp(*node_def) || IsUnaryGrad(*node_def)) {
    return {0, 1};
  }
  if (IsTernaryOp(*node_def) || IsSelect(*node_def) ||
      IsMaxPoolGrad(*node_def) || IsMaxPoolGradV2(*node_def) ||
      IsMaxPoolGradGradV1(*node_def) || IsMaxPoolGradGradV2(*node_def)) {
    return {0, 1, 2};
  }
  if (IsShapeN(*node_def) || IsIdentityN(*node_def) || IsAddN(*node_def) ||
      IsMerge(*node_def)) {
    return GetRegularFaninPorts(node);
  }
  if (IsConcat(*node_def)) {
    return GetConcatDataFaninPorts(node);
  }
  if (node.NumRegularFanins() > 0) {
    return {0};
  }
  return {};
}

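// Returns the output ports of `node` that carry data tensors. Variadic ops
// such as Split/SplitV and Switch report one port per output; everything
// else defaults to port 0.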
std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node) {
  const auto* node_def = node.node();
  if (IsIdentityN(*node_def) || IsShape(*node_def) || IsShapeN(*node_def)) {
    return GetDataFaninPorts(node);
  }
  if (IsSplit(*node_def) || IsSplitV(*node_def)) {
    const auto* num_split_attr = node.GetAttr(kAttrNumSplit);
    if (num_split_attr == nullptr) {
      return {0};
    }
    std::vector<int> values(num_split_attr->i());
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  if (IsSwitch(*node_def)) {
    const auto* num_outs_attr = node.GetAttr(kAttrNumOuts);
    const int num_outs = num_outs_attr != nullptr ? num_outs_attr->i() : 2;
    std::vector<int> values(num_outs);
    std::iota(values.begin(), values.end(), 0);
    return values;
  }
  return {0};
}

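// If `node` satisfies `predicate` and its `index`-th regular fanin is a Const
// with an int32 value, copies that value into `tensor` and returns true.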
bool GetValueAttrFromConstInputNode(
    const utils::MutableNodeView& node,
    const std::function<bool(const NodeDef&)>& predicate, int index,
    Tensor* tensor) {
  if (!predicate(*node.node())) {
    return false;
  }
  const auto& regular_fanin = node.GetRegularFanin(index);
  auto* regular_fanin_node = regular_fanin.node_view();
  if (!IsConstant(*regular_fanin_node->node())) {
    return false;
  }
  const auto* value_attr = regular_fanin_node->GetAttr(kAttrValue);
  if (value_attr == nullptr || value_attr->tensor().dtype() != DT_INT32) {
    return false;
  }
  if (!tensor->FromProto(value_attr->tensor())) {
    return false;
  }

  return true;
}

bool IsDataFormatOp(const utils::MutableNodeView& node) {
  const string& op = node.GetOp();
  return op == kOpDataFormatDimMap || op == kOpDataFormatVecPermute;
}

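// Maps each dimension label of a data format string to its index, e.g.
// "NCHW" -> {'N': 0, 'C': 1, 'H': 2, 'W': 3}.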
absl::flat_hash_map<char, int> GetDimensionIndices(
    absl::string_view data_format) {
  const int size = data_format.size();
  absl::flat_hash_map<char, int> index;
  index.reserve(size);
  for (int i = 0; i < size; i++) {
    index[data_format[i]] = i;
  }
  return index;
}

std::vector<int> GetPermutation(
    const absl::flat_hash_map<char, int>& src_dim_indices,
    absl::string_view dst_format) {
  // Generate permutation for transformation between src and dst format.
  // Example:
  //   src = NWHC, dst = NCWH
  //   index = { N:0 W:1 H:2 C:3 }
  //   permutation = [0, 3, 1, 2]
  DCHECK(src_dim_indices.size() == dst_format.size());
  std::vector<int> permutation;
  const int size = dst_format.size();
  permutation.reserve(size);
  for (int i = 0; i < size; i++) {
    permutation.push_back(src_dim_indices.at(dst_format[i]));
  }
  return permutation;
}

}  // namespace grappler
}  // namespace tensorflow