1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/grappler/costs/graph_properties.h"
30 #include "tensorflow/core/grappler/utils.h"
31 #include "tensorflow/core/grappler/utils/frame.h"
32 #include "tensorflow/core/grappler/utils/graph_view.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/status.h"
35 
36 namespace tensorflow {
37 namespace grappler {
38 
// String constants shared by the layout optimizer declarations below.
// NOTE(review): the kAttr* values look like NodeDef attribute names and
// kGPU/kCPU like device type names -- confirm against their uses in the
// implementation file.
constexpr char kAttrSrcFormat[] = "src_format";
constexpr char kAttrDstFormat[] = "dst_format";
// The value is the plural "_output_shapes" even though the identifier is
// singular.
constexpr char kAttrOutputShape[] = "_output_shapes";
constexpr char kGPU[] = "GPU";
constexpr char kCPU[] = "CPU";
44 
45 // TransposeContext owns all data members. Must initialize GraphProperties,
46 // FrameView, GraphDef and MutableGraphView with the same graph. NodeDef
47 // pointers in FrameView, GraphDef and MutableGraphView must point to nodes in
48 // the same GraphDef instance.
49 struct TransposeContext {
50   // Initializes TransposeContext with given GrapplerItem. Because initializing
51   // FrameMap and GraphProperties may return error, we initialize
52   // TransposeContext outside constructor.
53   static Status InitializeTransposeContext(bool assume_valid_feeds,
54                                            const GrapplerItem& item,
55                                            const Cluster* cluster,
56                                            TransposeContext* context);
57 
InitializeTransposeContextTransposeContext58   static Status InitializeTransposeContext(const GrapplerItem& item,
59                                            const Cluster* cluster,
60                                            TransposeContext* context) {
61     return InitializeTransposeContext(false, item, cluster, context);
62   }
63 
64   // Sets data formats to convert from and to for specified device type.
65   void AssignDeviceAndDataFormats(absl::string_view target_device,
66                                   absl::string_view src_format,
67                                   absl::string_view dst_format);
68 
69   FrameView frames;
70   GraphDef graph;
71   // Number of nodes in the original graph. As new nodes are appended to the end
72   // of the graph, all new nodes should have a node index greater than or equal
73   // to this.
74   int num_nodes;
75   absl::flat_hash_set<string> nodes_to_preserve;
76   std::unique_ptr<GraphProperties> graph_properties;
77   std::unique_ptr<utils::MutableGraphView> graph_view;
78 
79   string target_device;
80   string src_format;
81   string dst_format;
82   absl::flat_hash_map<char, int> src_dim_indices;
83   absl::flat_hash_map<char, int> dst_dim_indices;
84   std::vector<int> src_to_dst;
85   std::vector<int> dst_to_src;
86 
87   string enforced_layout;
88 };
89 
90 class Transposer {
91  public:
Transposer()92   explicit Transposer() {}
93 
94   Transposer(const Transposer&) = delete;
95   Transposer& operator=(const Transposer&) = delete;
96 
~Transposer()97   virtual ~Transposer() {}
98 
99   // Returns true iff the node should be processed by this transposer.
100   // NodeProcessors may perform additional oprand specific checks before
101   // processing if necessary.
102   // Following common conditions are checked:
103   // * node's device matches target device
104   // * node's source format matches config's source format
105   // * node has output
106   bool ShouldProcess(const TransposeContext& context,
107                      const utils::MutableNodeView& node) const;
108 
109   // Transposes given node from src format to dst format. Also perform other
110   // necessary operations to guarantee the graph produce the same result.
111   // Eg. Add Transpose node sets before fanin ports and after fanout ports.
112   virtual Status TransposeNode(TransposeContext* context,
113                                utils::MutableNodeView* node) = 0;
114 
115   // Creates a Const node for permutation. If node with node_name already exits,
116   // return and reuse it.
117   Status CreateConstPermNode(TransposeContext* context,
118                              absl::string_view node_name,
119                              absl::string_view device,
120                              absl::Span<const int> permutation,
121                              absl::string_view control_node_name,
122                              utils::MutationNewNode* added_node);
123 
124   // Creates a TransposeNode with given properties. If node with node_name
125   // already exits, return and reuse it.
126   // A const perm node is also created and connected to the 2nd fanin.
127   // control_node_name is ignored if it is empty.
128   Status CreateTransposeNode(
129       TransposeContext* context, absl::string_view name_format,
130       const DataType& data_type, absl::string_view device,
131       TensorShapeProto fanin_shape, absl::Span<const int> permutation,
132       absl::string_view control_node_name, utils::MutationNewNode* added_node,
133       string* transpose_node_name);
134 
135   // Update all edges between dst_node->fanin[dst_ports] and dst_node by
136   // inserting an op node.
137   Status UpdateFaninEdgesWithOp(TransposeContext* context,
138                                 absl::Span<const int> dst_ports,
139                                 utils::MutableNodeView* dst_node,
140                                 absl::string_view op);
141 
142   // Update all edges between src_node:src_ports and nodes take
143   // src_node:src_ports as fanin. Also update attr _output_shape of src_node.
144   Status UpdateFanoutEdgesWithOp(TransposeContext* context,
145                                  absl::Span<const int> src_ports,
146                                  utils::MutableNodeView* src_node,
147                                  absl::string_view op);
148 
149   // Creates a DataFromat node with given properties.
150   // DataFromat op is either DataFormatVecPermute or DataFormatDimMap.
151   Status CreateDataFormatNode(TransposeContext* context,
152                               absl::string_view node_name, absl::string_view op,
153                               absl::string_view device,
154                               const DataType& data_type, bool is_fanin_on_host,
155                               bool is_src_format_to_dst_format,
156                               utils::MutationNewNode* added_node);
157 
158  protected:
159   int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
160   bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
161                          int n) const;
162   bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
163                           absl::Span<const int> ports, int n) const;
164   int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
165   bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
166                         int n) const;
167 
168   // Checks if fanin at specified port(s) has dimensions `dims` iff fanin is a
169   // Const. If fanin is not a Const, no dimensions will be checked and this will
170   // return true.
171   bool IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, int port,
172                                absl::Span<const int> dims) const;
173   bool IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
174                                 absl::Span<const int> ports,
175                                 absl::Span<const int> dims) const;
176   bool CanProcessNode(const TransposeContext& context,
177                       const utils::MutableNodeView& node) const;
178   // Update all edges between dst_node->fanin[dst_ports] and dst_node.
179   // A node with op is created and inserted between all edges.
180   // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap.
181   Status UpdateEdge(TransposeContext* context, absl::string_view name_format,
182                     absl::string_view op, const AttrValue* input_shape,
183                     bool is_in_frame, bool is_src_format_to_dst_format,
184                     const int src_port, const int dst_port,
185                     utils::MutableNodeView* src_node,
186                     utils::MutableNodeView* dst_node);
187   string GetFaninNameFormat(absl::string_view node_name, int port,
188                             absl::string_view src_format,
189                             absl::string_view dst_format);
190   string GetFanoutNameFormat(absl::string_view node_name, int port, int index,
191                              absl::string_view src_format,
192                              absl::string_view dst_format);
193   string LayoutOptimizerNode(absl::string_view node_name);
194   string GetReshapeNodeNameFormat(absl::string_view node_name, int index,
195                                   absl::string_view src_format,
196                                   absl::string_view dst_format);
197   string GetShapeConstNodeNameFormat(absl::string_view node_name, int index);
198 };
199 
200 class LayoutSensitiveOpTransposer : public Transposer {
201  public:
LayoutSensitiveOpTransposer()202   explicit LayoutSensitiveOpTransposer() : Transposer() {}
203 
204   // Updates attrs data_format, ksize, strides of the given node to dst_format.
205   // _output_shape is updated during UpdateOutputEdges.
206   Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node);
207 };
208 
209 // Layout sensitive op transposers.
210 
211 class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
212  public:
DefaultLayoutSensitiveOpTransposer()213   explicit DefaultLayoutSensitiveOpTransposer()
214       : LayoutSensitiveOpTransposer() {}
215 
216   Status TransposeNode(TransposeContext* context,
217                        utils::MutableNodeView* node) override;
218 };
219 
220 class BiasAddTransposer : public LayoutSensitiveOpTransposer {
221  public:
BiasAddTransposer()222   explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
223 
224   Status TransposeNode(TransposeContext* context,
225                        utils::MutableNodeView* node) override;
226 };
227 
228 class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
229  public:
AvgPoolGradTransposer()230   explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
231 
232   Status TransposeNode(TransposeContext* context,
233                        utils::MutableNodeView* node) override;
234 };
235 
236 class BiasAddGradTransposer : public LayoutSensitiveOpTransposer {
237  public:
BiasAddGradTransposer()238   explicit BiasAddGradTransposer() : LayoutSensitiveOpTransposer() {}
239 
240   Status TransposeNode(TransposeContext* context,
241                        utils::MutableNodeView* node) override;
242 };
243 
244 class Conv2DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
245  public:
Conv2DBackpropFilterTransposer()246   explicit Conv2DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
247 
248   Status TransposeNode(TransposeContext* context,
249                        utils::MutableNodeView* node) override;
250 };
251 
252 class Conv2DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
253  public:
Conv2DBackpropInputTransposer()254   explicit Conv2DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
255 
256   Status TransposeNode(TransposeContext* context,
257                        utils::MutableNodeView* node) override;
258 };
259 
260 class Conv3DTransposer : public LayoutSensitiveOpTransposer {
261  public:
Conv3DTransposer()262   explicit Conv3DTransposer() : LayoutSensitiveOpTransposer() {}
263 
264   Status TransposeNode(TransposeContext* context,
265                        utils::MutableNodeView* node) override;
266 };
267 
268 class Conv3DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
269  public:
Conv3DBackpropFilterTransposer()270   explicit Conv3DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
271 
272   Status TransposeNode(TransposeContext* context,
273                        utils::MutableNodeView* node) override;
274 };
275 
276 class Conv3DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
277  public:
Conv3DBackpropInputTransposer()278   explicit Conv3DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
279 
280   Status TransposeNode(TransposeContext* context,
281                        utils::MutableNodeView* node) override;
282 };
283 
284 class FusedBatchNormExTransposer : public LayoutSensitiveOpTransposer {
285  public:
FusedBatchNormExTransposer()286   explicit FusedBatchNormExTransposer() : LayoutSensitiveOpTransposer() {}
287 
288   Status TransposeNode(TransposeContext* context,
289                        utils::MutableNodeView* node) override;
290 };
291 
292 class FusedBatchNormGradTransposer : public LayoutSensitiveOpTransposer {
293  public:
FusedBatchNormGradTransposer()294   explicit FusedBatchNormGradTransposer() : LayoutSensitiveOpTransposer() {}
295 
296   Status TransposeNode(TransposeContext* context,
297                        utils::MutableNodeView* node) override;
298 
299  private:
300   bool IsTraining(const utils::MutableNodeView& node) const;
301 };
302 
303 class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer {
304  public:
MaxPoolV2Transposer()305   explicit MaxPoolV2Transposer() : LayoutSensitiveOpTransposer() {}
306 
307   Status TransposeNode(TransposeContext* context,
308                        utils::MutableNodeView* node) override;
309 };
310 
311 class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer {
312  public:
MaxPoolGradTransposer()313   explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
314 
315   Status TransposeNode(TransposeContext* context,
316                        utils::MutableNodeView* node) override;
317 };
318 
319 class MaxPoolGradV2Transposer : public LayoutSensitiveOpTransposer {
320  public:
MaxPoolGradV2Transposer()321   explicit MaxPoolGradV2Transposer() : LayoutSensitiveOpTransposer() {}
322 
323   Status TransposeNode(TransposeContext* context,
324                        utils::MutableNodeView* node) override;
325 };
326 
327 // Layout agnostic op transposers.
328 
329 class LayoutAgnosticOpTransposer : public Transposer {
330  public:
LayoutAgnosticOpTransposer()331   explicit LayoutAgnosticOpTransposer() : Transposer() {}
332 
333  protected:
334   bool IsAfterDstToSrcTransform(const TransposeContext& context,
335                                 const utils::MutableNodeView& node) const;
336 
337   std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
338                                            const utils::MutableNodeView& node,
339                                            int rank) const;
340 };
341 
342 class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
343  public:
DefaultLayoutAgnosticOpTransposer()344   explicit DefaultLayoutAgnosticOpTransposer() : LayoutAgnosticOpTransposer() {}
345 
346   Status TransposeNode(TransposeContext* context,
347                        utils::MutableNodeView* node) override;
348 };
349 
350 class AddNTransposer : public LayoutAgnosticOpTransposer {
351  public:
AddNTransposer()352   explicit AddNTransposer() : LayoutAgnosticOpTransposer() {}
353 
354   Status TransposeNode(TransposeContext* context,
355                        utils::MutableNodeView* node) override;
356 };
357 
358 class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
359  public:
BinaryOpTransposer()360   explicit BinaryOpTransposer() : LayoutAgnosticOpTransposer() {}
361 
362   Status TransposeNode(TransposeContext* context,
363                        utils::MutableNodeView* node) override;
364 
365  private:
366   bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
367   bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
368   std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
369                                        int rank);
370   Status AddNodeShapeConst(utils::Mutation* mutation,
371                            absl::string_view node_name,
372                            absl::string_view node_device, bool node_in_frame,
373                            int num_channels, absl::string_view depended_node,
374                            int rank);
375   Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
376                         absl::string_view node_device,
377                         absl::string_view input_name,
378                         absl::string_view shape_const_node_name,
379                         const DataType& data_type);
380   Status MaybeReshapeVectorFanin(TransposeContext* context,
381                                  utils::MutableNodeView* node, int rank);
382 };
383 
384 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {
385  public:
ConcatOpTransposer()386   explicit ConcatOpTransposer() : LayoutAgnosticOpTransposer() {}
387 
388   Status TransposeNode(TransposeContext* context,
389                        utils::MutableNodeView* node) override;
390 };
391 
392 class FillOpTransposer : public LayoutAgnosticOpTransposer {
393  public:
FillOpTransposer()394   explicit FillOpTransposer() : LayoutAgnosticOpTransposer() {}
395 
396   Status TransposeNode(TransposeContext* context,
397                        utils::MutableNodeView* node) override;
398 };
399 
400 class IdentityNTransposer : public LayoutAgnosticOpTransposer {
401  public:
IdentityNTransposer()402   explicit IdentityNTransposer() : LayoutAgnosticOpTransposer() {}
403 
404   Status TransposeNode(TransposeContext* context,
405                        utils::MutableNodeView* node) override;
406 };
407 
408 class MergeTransposer : public LayoutAgnosticOpTransposer {
409  public:
MergeTransposer()410   explicit MergeTransposer() : LayoutAgnosticOpTransposer() {}
411 
412   Status TransposeNode(TransposeContext* context,
413                        utils::MutableNodeView* node) override;
414 
415  private:
416   bool IsEveryFaninAfterDstToSrcTransform(
417       const TransposeContext& context,
418       const utils::MutableNodeView& node) const;
419 };
420 
421 class PadTransposer : public LayoutAgnosticOpTransposer {
422  public:
PadTransposer()423   explicit PadTransposer() : LayoutAgnosticOpTransposer() {}
424 
425   Status TransposeNode(TransposeContext* context,
426                        utils::MutableNodeView* node) override;
427 };
428 
429 class ReduceTransposer : public LayoutAgnosticOpTransposer {
430  public:
ReduceTransposer()431   explicit ReduceTransposer() : LayoutAgnosticOpTransposer() {}
432 
433   Status TransposeNode(TransposeContext* context,
434                        utils::MutableNodeView* node) override;
435 
436  private:
437   bool KeepDims(const utils::MutableNodeView& node);
438   bool IsAlongAxis(const Tensor& tensor, absl::Span<const int> axis, int rank);
439   bool IsReduceAxisSupported(const TransposeContext& context,
440                              const utils::MutableNodeView& node, int rank);
441 };
442 
443 class ReverseV2Transposer : public LayoutAgnosticOpTransposer {
444  public:
ReverseV2Transposer()445   explicit ReverseV2Transposer() : LayoutAgnosticOpTransposer() {}
446 
447   Status TransposeNode(TransposeContext* context,
448                        utils::MutableNodeView* node) override;
449 };
450 
451 class SelectTransposer : public LayoutAgnosticOpTransposer {
452  public:
SelectTransposer()453   explicit SelectTransposer() : LayoutAgnosticOpTransposer() {}
454 
455   Status TransposeNode(TransposeContext* context,
456                        utils::MutableNodeView* node) override;
457 
458  protected:
459   bool IsFaninScalarVector4D(const utils::MutableNodeView& fanin, int port);
460   std::vector<int> GetFaninPorts(const utils::MutableNodeView& fanin, int port);
461 };
462 
463 class ShapeTransposer : public LayoutAgnosticOpTransposer {
464  public:
ShapeTransposer()465   explicit ShapeTransposer() : LayoutAgnosticOpTransposer() {}
466 
467   Status TransposeNode(TransposeContext* context,
468                        utils::MutableNodeView* node) override;
469 };
470 
471 class ShapeNTransposer : public LayoutAgnosticOpTransposer {
472  public:
ShapeNTransposer()473   explicit ShapeNTransposer() : LayoutAgnosticOpTransposer() {}
474 
475   Status TransposeNode(TransposeContext* context,
476                        utils::MutableNodeView* node) override;
477 };
478 
479 class SliceTransposer : public LayoutAgnosticOpTransposer {
480  public:
SliceTransposer()481   explicit SliceTransposer() : LayoutAgnosticOpTransposer() {}
482 
483   Status TransposeNode(TransposeContext* context,
484                        utils::MutableNodeView* node) override;
485 };
486 
487 class SplitTransposer : public LayoutAgnosticOpTransposer {
488  public:
SplitTransposer()489   explicit SplitTransposer() : LayoutAgnosticOpTransposer() {}
490 
491   Status TransposeNode(TransposeContext* context,
492                        utils::MutableNodeView* node) override;
493 };
494 
495 class SplitVTransposer : public LayoutAgnosticOpTransposer {
496  public:
SplitVTransposer()497   explicit SplitVTransposer() : LayoutAgnosticOpTransposer() {}
498 
499   Status TransposeNode(TransposeContext* context,
500                        utils::MutableNodeView* node) override;
501 };
502 
503 class SqueezeTransposer : public LayoutAgnosticOpTransposer {
504  public:
SqueezeTransposer()505   explicit SqueezeTransposer() : LayoutAgnosticOpTransposer() {}
506 
507   Status TransposeNode(TransposeContext* context,
508                        utils::MutableNodeView* node) override;
509 
510  private:
511   bool IsInputConvertible(const TransposeContext& context,
512                           const utils::MutableNodeView& node) const;
513   bool IsAlongAxis(const AttrValue& attr, absl::Span<const int> axis,
514                    int rank) const;
515   bool IsDimsSupported(const TransposeContext& context,
516                        const utils::MutableNodeView& node) const;
517   Status UpdateSqueezeDims(TransposeContext* context,
518                            utils::MutableNodeView* node);
519 };
520 
521 class StridedSliceTransposer : public LayoutAgnosticOpTransposer {
522  public:
StridedSliceTransposer()523   explicit StridedSliceTransposer() : LayoutAgnosticOpTransposer() {}
524 
525   Status TransposeNode(TransposeContext* context,
526                        utils::MutableNodeView* node) override;
527 
528  private:
529   bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask);
530   bool HasOnlyBeginEndMask(const utils::MutableNodeView& node);
531   Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node,
532                      absl::string_view mask);
533 };
534 
535 class SwitchTransposer : public LayoutAgnosticOpTransposer {
536  public:
SwitchTransposer()537   explicit SwitchTransposer() : LayoutAgnosticOpTransposer() {}
538 
539   Status TransposeNode(TransposeContext* context,
540                        utils::MutableNodeView* node) override;
541 };
542 
543 class TernaryOpTransposer : public LayoutAgnosticOpTransposer {
544  public:
TernaryOpTransposer()545   explicit TernaryOpTransposer() : LayoutAgnosticOpTransposer() {}
546 
547   Status TransposeNode(TransposeContext* context,
548                        utils::MutableNodeView* node) override;
549 };
550 
551 class TileTransposer : public LayoutAgnosticOpTransposer {
552  public:
TileTransposer()553   explicit TileTransposer() : LayoutAgnosticOpTransposer() {}
554 
555   Status TransposeNode(TransposeContext* context,
556                        utils::MutableNodeView* node) override;
557 };
558 
559 class UnaryGradTransposer : public LayoutAgnosticOpTransposer {
560  public:
UnaryGradTransposer()561   explicit UnaryGradTransposer() : LayoutAgnosticOpTransposer() {}
562 
563   Status TransposeNode(TransposeContext* context,
564                        utils::MutableNodeView* node) override;
565 };
566 
567 // Utils.
568 
569 // Permutes elements according to permutation and replaces the original values.
570 // Permutation and values must have same size.
571 template <typename T>
PermuteSingle(absl::string_view location,absl::Span<const int> permutation,T * values)572 Status PermuteSingle(absl::string_view location,
573                      absl::Span<const int> permutation, T* values) {
574   DCHECK(values != nullptr);
575   int permutation_size = permutation.size();
576   if (values->size() != permutation_size) {
577     return Status(tensorflow::error::Code::INVALID_ARGUMENT,
578                   absl::StrCat("Size of values ", values->size(),
579                                " does not match size of permutation ",
580                                permutation_size, " @ ", location));
581   }
582   typedef typename T::value_type V;
583   std::vector<V> elements(values->begin(), values->end());
584   int index = 0;
585   for (V& element : *values) {
586     element = elements[permutation[index++]];
587   }
588   return OkStatus();
589 }
590 
591 // Permutes two elements at a time according to permutation and replaces the
592 // original values. Values must be twice the size of permutation.
593 template <typename T>
PermuteDouble(absl::string_view location,absl::Span<const int> permutation,T * values)594 Status PermuteDouble(absl::string_view location,
595                      absl::Span<const int> permutation, T* values) {
596   DCHECK(values != nullptr);
597   int permutation_size = permutation.size();
598   if (values->size() != permutation_size * 2) {
599     return Status(tensorflow::error::Code::INVALID_ARGUMENT,
600                   absl::StrCat("Size of values ", values->size(),
601                                " does not match twice the size of permutation ",
602                                permutation_size, " @ ", location));
603   }
604   typedef typename T::value_type V;
605   std::vector<V> elements(values->begin(), values->end());
606   for (int i = 0; i < values->size(); i = i + 2) {
607     const int permutation_index = permutation[i / 2];
608     (*values)[i] = elements[permutation_index * 2];
609     (*values)[i + 1] = elements[permutation_index * 2 + 1];
610   }
611   return OkStatus();
612 }
613 
// Free helper functions. Only the declarations are visible in this header;
// the notes below are hedged where the behavior cannot be read off the
// signature.

// Returns a device name string for `node`.
// NOTE(review): whether this is the raw device field or a parsed/normalized
// form is not visible here -- confirm in the implementation file.
string GetDeviceName(const NodeDef& node);

// Op classification predicates on a NodeDef. The set of ops each predicate
// accepts is defined in the implementation file.
bool IsDefaultLayoutSensitiveOp(const NodeDef& node);

bool IsLayoutSensitiveOp(const NodeDef& node);

bool IsDefaultLayoutAgnosticOp(const NodeDef& node);

bool IsLayoutAgnosticOp(const NodeDef& node);

bool IsTernaryOp(const NodeDef& node);

bool IsUnaryGrad(const NodeDef& node);

bool IsMaxPoolV2(const NodeDef& node);

bool IsMaxPoolGradV2(const NodeDef& node);

bool IsMaxPoolGradGradV1(const NodeDef& node);

bool IsMaxPoolGradGradV2(const NodeDef& node);

bool IsBinaryOp(const NodeDef& node);

bool IsReduceOp(const NodeDef& node);

// Return lists of port indices for `node`.
// NOTE(review): presumably the regular (non-control) fanin / fanout ports --
// confirm in the implementation file.
std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node);

std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node);

// Fetches the value of the constant input to `node` at `index`, but only if
// `predicate` evaluates to true. Returns true if `tensor` was populated with
// data.
bool GetValueAttrFromConstInputNode(
    const utils::MutableNodeView& node,
    const std::function<bool(const NodeDef&)>& predicate, int index,
    Tensor* tensor);

// Predicate on `node`.
// NOTE(review): presumably true for the DataFormatVecPermute /
// DataFormatDimMap ops mentioned in Transposer -- confirm in the
// implementation file.
bool IsDataFormatOp(const utils::MutableNodeView& node);

// Builds a char -> int map from `data_format`.
// NOTE(review): presumably each dimension character mapped to its position
// in the format string -- confirm in the implementation file.
absl::flat_hash_map<char, int> GetDimensionIndices(
    absl::string_view data_format);

// Builds a permutation vector from `src_dim_indices` and `dst_format`.
// NOTE(review): presumably entry i is the source index of the i-th character
// of `dst_format` -- confirm in the implementation file.
std::vector<int> GetPermutation(
    const absl::flat_hash_map<char, int>& src_dim_indices,
    absl::string_view dst_format);
659 
660 }  // namespace grappler
661 }  // namespace tensorflow
662 
663 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
664