1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
18
19 #include <memory>
20 #include <vector>
21
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/core/framework/graph.pb.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/grappler/costs/graph_properties.h"
30 #include "tensorflow/core/grappler/utils.h"
31 #include "tensorflow/core/grappler/utils/frame.h"
32 #include "tensorflow/core/grappler/utils/graph_view.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/lib/core/status.h"
35
36 namespace tensorflow {
37 namespace grappler {
38
// Node attribute names referenced by the transposers.
constexpr char kAttrSrcFormat[] = "src_format";
constexpr char kAttrDstFormat[] = "dst_format";
constexpr char kAttrOutputShape[] = "_output_shapes";

// Device type strings.
constexpr char kGPU[] = "GPU";
constexpr char kCPU[] = "CPU";
44
45 // TransposeContext owns all data members. Must initialize GraphProperties,
46 // FrameView, GraphDef and MutableGraphView with the same graph. NodeDef
47 // pointers in FrameView, GraphDef and MutableGraphView must point to nodes in
48 // the same GraphDef instance.
49 struct TransposeContext {
50 // Initializes TransposeContext with given GrapplerItem. Because initializing
51 // FrameMap and GraphProperties may return error, we initialize
52 // TransposeContext outside constructor.
53 static Status InitializeTransposeContext(bool assume_valid_feeds,
54 const GrapplerItem& item,
55 const Cluster* cluster,
56 TransposeContext* context);
57
InitializeTransposeContextTransposeContext58 static Status InitializeTransposeContext(const GrapplerItem& item,
59 const Cluster* cluster,
60 TransposeContext* context) {
61 return InitializeTransposeContext(false, item, cluster, context);
62 }
63
64 // Sets data formats to convert from and to for specified device type.
65 void AssignDeviceAndDataFormats(absl::string_view target_device,
66 absl::string_view src_format,
67 absl::string_view dst_format);
68
69 FrameView frames;
70 GraphDef graph;
71 // Number of nodes in the original graph. As new nodes are appended to the end
72 // of the graph, all new nodes should have a node index greater than or equal
73 // to this.
74 int num_nodes;
75 absl::flat_hash_set<string> nodes_to_preserve;
76 std::unique_ptr<GraphProperties> graph_properties;
77 std::unique_ptr<utils::MutableGraphView> graph_view;
78
79 string target_device;
80 string src_format;
81 string dst_format;
82 absl::flat_hash_map<char, int> src_dim_indices;
83 absl::flat_hash_map<char, int> dst_dim_indices;
84 std::vector<int> src_to_dst;
85 std::vector<int> dst_to_src;
86
87 string enforced_layout;
88 };
89
90 class Transposer {
91 public:
Transposer()92 explicit Transposer() {}
93
94 Transposer(const Transposer&) = delete;
95 Transposer& operator=(const Transposer&) = delete;
96
~Transposer()97 virtual ~Transposer() {}
98
99 // Returns true iff the node should be processed by this transposer.
100 // NodeProcessors may perform additional oprand specific checks before
101 // processing if necessary.
102 // Following common conditions are checked:
103 // * node's device matches target device
104 // * node's source format matches config's source format
105 // * node has output
106 bool ShouldProcess(const TransposeContext& context,
107 const utils::MutableNodeView& node) const;
108
109 // Transposes given node from src format to dst format. Also perform other
110 // necessary operations to guarantee the graph produce the same result.
111 // Eg. Add Transpose node sets before fanin ports and after fanout ports.
112 virtual Status TransposeNode(TransposeContext* context,
113 utils::MutableNodeView* node) = 0;
114
115 // Creates a Const node for permutation. If node with node_name already exits,
116 // return and reuse it.
117 Status CreateConstPermNode(TransposeContext* context,
118 absl::string_view node_name,
119 absl::string_view device,
120 absl::Span<const int> permutation,
121 absl::string_view control_node_name,
122 utils::MutationNewNode* added_node);
123
124 // Creates a TransposeNode with given properties. If node with node_name
125 // already exits, return and reuse it.
126 // A const perm node is also created and connected to the 2nd fanin.
127 // control_node_name is ignored if it is empty.
128 Status CreateTransposeNode(
129 TransposeContext* context, absl::string_view name_format,
130 const DataType& data_type, absl::string_view device,
131 TensorShapeProto fanin_shape, absl::Span<const int> permutation,
132 absl::string_view control_node_name, utils::MutationNewNode* added_node,
133 string* transpose_node_name);
134
135 // Update all edges between dst_node->fanin[dst_ports] and dst_node by
136 // inserting an op node.
137 Status UpdateFaninEdgesWithOp(TransposeContext* context,
138 absl::Span<const int> dst_ports,
139 utils::MutableNodeView* dst_node,
140 absl::string_view op);
141
142 // Update all edges between src_node:src_ports and nodes take
143 // src_node:src_ports as fanin. Also update attr _output_shape of src_node.
144 Status UpdateFanoutEdgesWithOp(TransposeContext* context,
145 absl::Span<const int> src_ports,
146 utils::MutableNodeView* src_node,
147 absl::string_view op);
148
149 // Creates a DataFromat node with given properties.
150 // DataFromat op is either DataFormatVecPermute or DataFormatDimMap.
151 Status CreateDataFormatNode(TransposeContext* context,
152 absl::string_view node_name, absl::string_view op,
153 absl::string_view device,
154 const DataType& data_type, bool is_fanin_on_host,
155 bool is_src_format_to_dst_format,
156 utils::MutationNewNode* added_node);
157
158 protected:
159 int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const;
160 bool IsFanoutPortRankN(const utils::MutableNodeView& node, int port,
161 int n) const;
162 bool IsFanoutPortsRankN(const utils::MutableNodeView& node,
163 absl::Span<const int> ports, int n) const;
164 int GetFaninPortRank(const utils::MutableNodeView& node, int port) const;
165 bool IsFaninPortRankN(const utils::MutableNodeView& node, int port,
166 int n) const;
167
168 // Checks if fanin at specified port(s) has dimensions `dims` iff fanin is a
169 // Const. If fanin is not a Const, no dimensions will be checked and this will
170 // return true.
171 bool IsFaninPortDimsNIfConst(const utils::MutableNodeView& node, int port,
172 absl::Span<const int> dims) const;
173 bool IsFaninPortsDimsNIfConst(const utils::MutableNodeView& node,
174 absl::Span<const int> ports,
175 absl::Span<const int> dims) const;
176 bool CanProcessNode(const TransposeContext& context,
177 const utils::MutableNodeView& node) const;
178 // Update all edges between dst_node->fanin[dst_ports] and dst_node.
179 // A node with op is created and inserted between all edges.
180 // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap.
181 Status UpdateEdge(TransposeContext* context, absl::string_view name_format,
182 absl::string_view op, const AttrValue* input_shape,
183 bool is_in_frame, bool is_src_format_to_dst_format,
184 const int src_port, const int dst_port,
185 utils::MutableNodeView* src_node,
186 utils::MutableNodeView* dst_node);
187 string GetFaninNameFormat(absl::string_view node_name, int port,
188 absl::string_view src_format,
189 absl::string_view dst_format);
190 string GetFanoutNameFormat(absl::string_view node_name, int port, int index,
191 absl::string_view src_format,
192 absl::string_view dst_format);
193 string LayoutOptimizerNode(absl::string_view node_name);
194 string GetReshapeNodeNameFormat(absl::string_view node_name, int index,
195 absl::string_view src_format,
196 absl::string_view dst_format);
197 string GetShapeConstNodeNameFormat(absl::string_view node_name, int index);
198 };
199
200 class LayoutSensitiveOpTransposer : public Transposer {
201 public:
LayoutSensitiveOpTransposer()202 explicit LayoutSensitiveOpTransposer() : Transposer() {}
203
204 // Updates attrs data_format, ksize, strides of the given node to dst_format.
205 // _output_shape is updated during UpdateOutputEdges.
206 Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node);
207 };
208
209 // Layout sensitive op transposers.
210
211 class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer {
212 public:
DefaultLayoutSensitiveOpTransposer()213 explicit DefaultLayoutSensitiveOpTransposer()
214 : LayoutSensitiveOpTransposer() {}
215
216 Status TransposeNode(TransposeContext* context,
217 utils::MutableNodeView* node) override;
218 };
219
220 class BiasAddTransposer : public LayoutSensitiveOpTransposer {
221 public:
BiasAddTransposer()222 explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {}
223
224 Status TransposeNode(TransposeContext* context,
225 utils::MutableNodeView* node) override;
226 };
227
228 class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer {
229 public:
AvgPoolGradTransposer()230 explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
231
232 Status TransposeNode(TransposeContext* context,
233 utils::MutableNodeView* node) override;
234 };
235
236 class BiasAddGradTransposer : public LayoutSensitiveOpTransposer {
237 public:
BiasAddGradTransposer()238 explicit BiasAddGradTransposer() : LayoutSensitiveOpTransposer() {}
239
240 Status TransposeNode(TransposeContext* context,
241 utils::MutableNodeView* node) override;
242 };
243
244 class Conv2DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
245 public:
Conv2DBackpropFilterTransposer()246 explicit Conv2DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
247
248 Status TransposeNode(TransposeContext* context,
249 utils::MutableNodeView* node) override;
250 };
251
252 class Conv2DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
253 public:
Conv2DBackpropInputTransposer()254 explicit Conv2DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
255
256 Status TransposeNode(TransposeContext* context,
257 utils::MutableNodeView* node) override;
258 };
259
260 class Conv3DTransposer : public LayoutSensitiveOpTransposer {
261 public:
Conv3DTransposer()262 explicit Conv3DTransposer() : LayoutSensitiveOpTransposer() {}
263
264 Status TransposeNode(TransposeContext* context,
265 utils::MutableNodeView* node) override;
266 };
267
268 class Conv3DBackpropFilterTransposer : public LayoutSensitiveOpTransposer {
269 public:
Conv3DBackpropFilterTransposer()270 explicit Conv3DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {}
271
272 Status TransposeNode(TransposeContext* context,
273 utils::MutableNodeView* node) override;
274 };
275
276 class Conv3DBackpropInputTransposer : public LayoutSensitiveOpTransposer {
277 public:
Conv3DBackpropInputTransposer()278 explicit Conv3DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {}
279
280 Status TransposeNode(TransposeContext* context,
281 utils::MutableNodeView* node) override;
282 };
283
284 class FusedBatchNormExTransposer : public LayoutSensitiveOpTransposer {
285 public:
FusedBatchNormExTransposer()286 explicit FusedBatchNormExTransposer() : LayoutSensitiveOpTransposer() {}
287
288 Status TransposeNode(TransposeContext* context,
289 utils::MutableNodeView* node) override;
290 };
291
292 class FusedBatchNormGradTransposer : public LayoutSensitiveOpTransposer {
293 public:
FusedBatchNormGradTransposer()294 explicit FusedBatchNormGradTransposer() : LayoutSensitiveOpTransposer() {}
295
296 Status TransposeNode(TransposeContext* context,
297 utils::MutableNodeView* node) override;
298
299 private:
300 bool IsTraining(const utils::MutableNodeView& node) const;
301 };
302
303 class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer {
304 public:
MaxPoolV2Transposer()305 explicit MaxPoolV2Transposer() : LayoutSensitiveOpTransposer() {}
306
307 Status TransposeNode(TransposeContext* context,
308 utils::MutableNodeView* node) override;
309 };
310
311 class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer {
312 public:
MaxPoolGradTransposer()313 explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {}
314
315 Status TransposeNode(TransposeContext* context,
316 utils::MutableNodeView* node) override;
317 };
318
319 class MaxPoolGradV2Transposer : public LayoutSensitiveOpTransposer {
320 public:
MaxPoolGradV2Transposer()321 explicit MaxPoolGradV2Transposer() : LayoutSensitiveOpTransposer() {}
322
323 Status TransposeNode(TransposeContext* context,
324 utils::MutableNodeView* node) override;
325 };
326
327 // Layout agnostic op transposers.
328
329 class LayoutAgnosticOpTransposer : public Transposer {
330 public:
LayoutAgnosticOpTransposer()331 explicit LayoutAgnosticOpTransposer() : Transposer() {}
332
333 protected:
334 bool IsAfterDstToSrcTransform(const TransposeContext& context,
335 const utils::MutableNodeView& node) const;
336
337 std::vector<int> GetVariadicNDFaninPorts(const TransposeContext& context,
338 const utils::MutableNodeView& node,
339 int rank) const;
340 };
341
342 class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer {
343 public:
DefaultLayoutAgnosticOpTransposer()344 explicit DefaultLayoutAgnosticOpTransposer() : LayoutAgnosticOpTransposer() {}
345
346 Status TransposeNode(TransposeContext* context,
347 utils::MutableNodeView* node) override;
348 };
349
350 class AddNTransposer : public LayoutAgnosticOpTransposer {
351 public:
AddNTransposer()352 explicit AddNTransposer() : LayoutAgnosticOpTransposer() {}
353
354 Status TransposeNode(TransposeContext* context,
355 utils::MutableNodeView* node) override;
356 };
357
358 class BinaryOpTransposer : public LayoutAgnosticOpTransposer {
359 public:
BinaryOpTransposer()360 explicit BinaryOpTransposer() : LayoutAgnosticOpTransposer() {}
361
362 Status TransposeNode(TransposeContext* context,
363 utils::MutableNodeView* node) override;
364
365 private:
366 bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m);
367 bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank);
368 std::vector<int> GetNDDataFaninPorts(const utils::MutableNodeView& node,
369 int rank);
370 Status AddNodeShapeConst(utils::Mutation* mutation,
371 absl::string_view node_name,
372 absl::string_view node_device, bool node_in_frame,
373 int num_channels, absl::string_view depended_node,
374 int rank);
375 Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name,
376 absl::string_view node_device,
377 absl::string_view input_name,
378 absl::string_view shape_const_node_name,
379 const DataType& data_type);
380 Status MaybeReshapeVectorFanin(TransposeContext* context,
381 utils::MutableNodeView* node, int rank);
382 };
383
384 class ConcatOpTransposer : public LayoutAgnosticOpTransposer {
385 public:
ConcatOpTransposer()386 explicit ConcatOpTransposer() : LayoutAgnosticOpTransposer() {}
387
388 Status TransposeNode(TransposeContext* context,
389 utils::MutableNodeView* node) override;
390 };
391
392 class FillOpTransposer : public LayoutAgnosticOpTransposer {
393 public:
FillOpTransposer()394 explicit FillOpTransposer() : LayoutAgnosticOpTransposer() {}
395
396 Status TransposeNode(TransposeContext* context,
397 utils::MutableNodeView* node) override;
398 };
399
400 class IdentityNTransposer : public LayoutAgnosticOpTransposer {
401 public:
IdentityNTransposer()402 explicit IdentityNTransposer() : LayoutAgnosticOpTransposer() {}
403
404 Status TransposeNode(TransposeContext* context,
405 utils::MutableNodeView* node) override;
406 };
407
408 class MergeTransposer : public LayoutAgnosticOpTransposer {
409 public:
MergeTransposer()410 explicit MergeTransposer() : LayoutAgnosticOpTransposer() {}
411
412 Status TransposeNode(TransposeContext* context,
413 utils::MutableNodeView* node) override;
414
415 private:
416 bool IsEveryFaninAfterDstToSrcTransform(
417 const TransposeContext& context,
418 const utils::MutableNodeView& node) const;
419 };
420
421 class PadTransposer : public LayoutAgnosticOpTransposer {
422 public:
PadTransposer()423 explicit PadTransposer() : LayoutAgnosticOpTransposer() {}
424
425 Status TransposeNode(TransposeContext* context,
426 utils::MutableNodeView* node) override;
427 };
428
429 class ReduceTransposer : public LayoutAgnosticOpTransposer {
430 public:
ReduceTransposer()431 explicit ReduceTransposer() : LayoutAgnosticOpTransposer() {}
432
433 Status TransposeNode(TransposeContext* context,
434 utils::MutableNodeView* node) override;
435
436 private:
437 bool KeepDims(const utils::MutableNodeView& node);
438 bool IsAlongAxis(const Tensor& tensor, absl::Span<const int> axis, int rank);
439 bool IsReduceAxisSupported(const TransposeContext& context,
440 const utils::MutableNodeView& node, int rank);
441 };
442
443 class ReverseV2Transposer : public LayoutAgnosticOpTransposer {
444 public:
ReverseV2Transposer()445 explicit ReverseV2Transposer() : LayoutAgnosticOpTransposer() {}
446
447 Status TransposeNode(TransposeContext* context,
448 utils::MutableNodeView* node) override;
449 };
450
451 class SelectTransposer : public LayoutAgnosticOpTransposer {
452 public:
SelectTransposer()453 explicit SelectTransposer() : LayoutAgnosticOpTransposer() {}
454
455 Status TransposeNode(TransposeContext* context,
456 utils::MutableNodeView* node) override;
457
458 protected:
459 bool IsFaninScalarVector4D(const utils::MutableNodeView& fanin, int port);
460 std::vector<int> GetFaninPorts(const utils::MutableNodeView& fanin, int port);
461 };
462
463 class ShapeTransposer : public LayoutAgnosticOpTransposer {
464 public:
ShapeTransposer()465 explicit ShapeTransposer() : LayoutAgnosticOpTransposer() {}
466
467 Status TransposeNode(TransposeContext* context,
468 utils::MutableNodeView* node) override;
469 };
470
471 class ShapeNTransposer : public LayoutAgnosticOpTransposer {
472 public:
ShapeNTransposer()473 explicit ShapeNTransposer() : LayoutAgnosticOpTransposer() {}
474
475 Status TransposeNode(TransposeContext* context,
476 utils::MutableNodeView* node) override;
477 };
478
479 class SliceTransposer : public LayoutAgnosticOpTransposer {
480 public:
SliceTransposer()481 explicit SliceTransposer() : LayoutAgnosticOpTransposer() {}
482
483 Status TransposeNode(TransposeContext* context,
484 utils::MutableNodeView* node) override;
485 };
486
487 class SplitTransposer : public LayoutAgnosticOpTransposer {
488 public:
SplitTransposer()489 explicit SplitTransposer() : LayoutAgnosticOpTransposer() {}
490
491 Status TransposeNode(TransposeContext* context,
492 utils::MutableNodeView* node) override;
493 };
494
495 class SplitVTransposer : public LayoutAgnosticOpTransposer {
496 public:
SplitVTransposer()497 explicit SplitVTransposer() : LayoutAgnosticOpTransposer() {}
498
499 Status TransposeNode(TransposeContext* context,
500 utils::MutableNodeView* node) override;
501 };
502
503 class SqueezeTransposer : public LayoutAgnosticOpTransposer {
504 public:
SqueezeTransposer()505 explicit SqueezeTransposer() : LayoutAgnosticOpTransposer() {}
506
507 Status TransposeNode(TransposeContext* context,
508 utils::MutableNodeView* node) override;
509
510 private:
511 bool IsInputConvertible(const TransposeContext& context,
512 const utils::MutableNodeView& node) const;
513 bool IsAlongAxis(const AttrValue& attr, absl::Span<const int> axis,
514 int rank) const;
515 bool IsDimsSupported(const TransposeContext& context,
516 const utils::MutableNodeView& node) const;
517 Status UpdateSqueezeDims(TransposeContext* context,
518 utils::MutableNodeView* node);
519 };
520
521 class StridedSliceTransposer : public LayoutAgnosticOpTransposer {
522 public:
StridedSliceTransposer()523 explicit StridedSliceTransposer() : LayoutAgnosticOpTransposer() {}
524
525 Status TransposeNode(TransposeContext* context,
526 utils::MutableNodeView* node) override;
527
528 private:
529 bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask);
530 bool HasOnlyBeginEndMask(const utils::MutableNodeView& node);
531 Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node,
532 absl::string_view mask);
533 };
534
535 class SwitchTransposer : public LayoutAgnosticOpTransposer {
536 public:
SwitchTransposer()537 explicit SwitchTransposer() : LayoutAgnosticOpTransposer() {}
538
539 Status TransposeNode(TransposeContext* context,
540 utils::MutableNodeView* node) override;
541 };
542
543 class TernaryOpTransposer : public LayoutAgnosticOpTransposer {
544 public:
TernaryOpTransposer()545 explicit TernaryOpTransposer() : LayoutAgnosticOpTransposer() {}
546
547 Status TransposeNode(TransposeContext* context,
548 utils::MutableNodeView* node) override;
549 };
550
551 class TileTransposer : public LayoutAgnosticOpTransposer {
552 public:
TileTransposer()553 explicit TileTransposer() : LayoutAgnosticOpTransposer() {}
554
555 Status TransposeNode(TransposeContext* context,
556 utils::MutableNodeView* node) override;
557 };
558
559 class UnaryGradTransposer : public LayoutAgnosticOpTransposer {
560 public:
UnaryGradTransposer()561 explicit UnaryGradTransposer() : LayoutAgnosticOpTransposer() {}
562
563 Status TransposeNode(TransposeContext* context,
564 utils::MutableNodeView* node) override;
565 };
566
567 // Utils.
568
569 // Permutes elements according to permutation and replaces the original values.
570 // Permutation and values must have same size.
571 template <typename T>
PermuteSingle(absl::string_view location,absl::Span<const int> permutation,T * values)572 Status PermuteSingle(absl::string_view location,
573 absl::Span<const int> permutation, T* values) {
574 DCHECK(values != nullptr);
575 int permutation_size = permutation.size();
576 if (values->size() != permutation_size) {
577 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
578 absl::StrCat("Size of values ", values->size(),
579 " does not match size of permutation ",
580 permutation_size, " @ ", location));
581 }
582 typedef typename T::value_type V;
583 std::vector<V> elements(values->begin(), values->end());
584 int index = 0;
585 for (V& element : *values) {
586 element = elements[permutation[index++]];
587 }
588 return OkStatus();
589 }
590
591 // Permutes two elements at a time according to permutation and replaces the
592 // original values. Values must be twice the size of permutation.
593 template <typename T>
PermuteDouble(absl::string_view location,absl::Span<const int> permutation,T * values)594 Status PermuteDouble(absl::string_view location,
595 absl::Span<const int> permutation, T* values) {
596 DCHECK(values != nullptr);
597 int permutation_size = permutation.size();
598 if (values->size() != permutation_size * 2) {
599 return Status(tensorflow::error::Code::INVALID_ARGUMENT,
600 absl::StrCat("Size of values ", values->size(),
601 " does not match twice the size of permutation ",
602 permutation_size, " @ ", location));
603 }
604 typedef typename T::value_type V;
605 std::vector<V> elements(values->begin(), values->end());
606 for (int i = 0; i < values->size(); i = i + 2) {
607 const int permutation_index = permutation[i / 2];
608 (*values)[i] = elements[permutation_index * 2];
609 (*values)[i + 1] = elements[permutation_index * 2 + 1];
610 }
611 return OkStatus();
612 }
613
614 string GetDeviceName(const NodeDef& node);
615
616 bool IsDefaultLayoutSensitiveOp(const NodeDef& node);
617
618 bool IsLayoutSensitiveOp(const NodeDef& node);
619
620 bool IsDefaultLayoutAgnosticOp(const NodeDef& node);
621
622 bool IsLayoutAgnosticOp(const NodeDef& node);
623
624 bool IsTernaryOp(const NodeDef& node);
625
626 bool IsUnaryGrad(const NodeDef& node);
627
628 bool IsMaxPoolV2(const NodeDef& node);
629
630 bool IsMaxPoolGradV2(const NodeDef& node);
631
632 bool IsMaxPoolGradGradV1(const NodeDef& node);
633
634 bool IsMaxPoolGradGradV2(const NodeDef& node);
635
636 bool IsBinaryOp(const NodeDef& node);
637
638 bool IsReduceOp(const NodeDef& node);
639
640 std::vector<int> GetDataFaninPorts(const utils::MutableNodeView& node);
641
642 std::vector<int> GetDataFanoutPorts(const utils::MutableNodeView& node);
643
644 // Returns a value of constant input to the `node` at `index`, iff `predicate`
645 // evaluated to true. Returns true if `tensor` was populated with data.
646 bool GetValueAttrFromConstInputNode(
647 const utils::MutableNodeView& node,
648 const std::function<bool(const NodeDef&)>& predicate, int index,
649 Tensor* tensor);
650
651 bool IsDataFormatOp(const utils::MutableNodeView& node);
652
653 absl::flat_hash_map<char, int> GetDimensionIndices(
654 absl::string_view data_format);
655
656 std::vector<int> GetPermutation(
657 const absl::flat_hash_map<char, int>& src_dim_indices,
658 absl::string_view dst_format);
659
660 } // namespace grappler
661 } // namespace tensorflow
662
663 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GENERIC_LAYOUT_OPTIMIZER_TRANSPOSER_H_
664