1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_factory.h"
17 
18 #include "tensorflow/core/grappler/op_types.h"
19 
20 namespace tensorflow {
21 namespace grappler {
22 
GetTransposer(const NodeDef & node)23 std::shared_ptr<Transposer> TransposerFactory::GetTransposer(
24     const NodeDef& node) {
25   // Check layout sensitive ops.
26   if (IsDefaultLayoutSensitiveOp(node)) {
27     return GetOrCreateIfNotFound<DefaultLayoutSensitiveOpTransposer>(
28         "DefaultLayoutSensitiveOp");
29   }
30   if (IsAvgPoolGrad(node)) {
31     return GetOrCreateIfNotFound<AvgPoolGradTransposer>("AvgPoolGrad");
32   }
33   if (IsBiasAddV2(node)) {
34     return GetOrCreateIfNotFound<BiasAddTransposer>("BiasAdd");
35   }
36   if (IsBiasAddGrad(node)) {
37     return GetOrCreateIfNotFound<BiasAddGradTransposer>("BiasAddGrad");
38   }
39   if (IsConv2DBackpropFilter(node) ||
40       IsDepthwiseConv2dNativeBackpropFilter(node)) {
41     return GetOrCreateIfNotFound<Conv2DBackpropFilterTransposer>(
42         "Conv2DBackpropFilter");
43   }
44   if (IsConv2DBackpropInput(node) ||
45       IsDepthwiseConv2dNativeBackpropInput(node)) {
46     return GetOrCreateIfNotFound<Conv2DBackpropInputTransposer>(
47         "Conv2DBackpropInput");
48   }
49   if (IsConv3D(node)) {
50     return GetOrCreateIfNotFound<Conv3DTransposer>("Conv3D");
51   }
52   if (IsConv3DBackpropInputV2(node)) {
53     return GetOrCreateIfNotFound<Conv3DBackpropInputTransposer>(
54         "Conv3DBackpropInput");
55   }
56   if (IsConv3DBackpropFilterV2(node)) {
57     return GetOrCreateIfNotFound<Conv3DBackpropFilterTransposer>(
58         "Conv3DBackpropFilter");
59   }
60   if (IsFusedBatchNormEx(node)) {
61     return GetOrCreateIfNotFound<FusedBatchNormExTransposer>(
62         "FusedBatchNormEx");
63   }
64   if (IsFusedBatchNormGrad(node)) {
65     return GetOrCreateIfNotFound<FusedBatchNormGradTransposer>(
66         "FusedBatchNormGrad");
67   }
68   if (IsMaxPoolV2(node)) {
69     return GetOrCreateIfNotFound<MaxPoolV2Transposer>("MaxPoolV2");
70   }
71   if (IsMaxPoolGrad(node) || IsMaxPoolGradGradV1(node)) {
72     return GetOrCreateIfNotFound<MaxPoolGradTransposer>("MaxPoolGrad");
73   }
74   if (IsMaxPoolGradV2(node) || IsMaxPoolGradGradV2(node)) {
75     return GetOrCreateIfNotFound<MaxPoolGradV2Transposer>("MaxPoolGradV2");
76   }
77   // Check layout agnostic ops.
78   if (IsDefaultLayoutAgnosticOp(node)) {
79     return GetOrCreateIfNotFound<DefaultLayoutAgnosticOpTransposer>(
80         "DefaultLayoutAgnosticOp");
81   }
82   if (IsAddN(node)) {
83     return GetOrCreateIfNotFound<AddNTransposer>("AddN");
84   }
85   if (IsBinaryOp(node)) {
86     return GetOrCreateIfNotFound<BinaryOpTransposer>("BinaryOp");
87   }
88   if (IsConcat(node)) {
89     return GetOrCreateIfNotFound<ConcatOpTransposer>("Concat");
90   }
91   if (IsFill(node)) {
92     return GetOrCreateIfNotFound<FillOpTransposer>("Fill");
93   }
94   if (IsIdentityN(node)) {
95     return GetOrCreateIfNotFound<IdentityNTransposer>("IdentityN");
96   }
97   if (IsMerge(node)) {
98     return GetOrCreateIfNotFound<MergeTransposer>("Merge");
99   }
100   if (IsMirrorPad(node) || IsMirrorPadGrad(node) || IsPad(node)) {
101     return GetOrCreateIfNotFound<PadTransposer>("Pad");
102   }
103   if (IsReduceOp(node)) {
104     return GetOrCreateIfNotFound<ReduceTransposer>("ReduceOp");
105   }
106   if (IsReverseV2(node)) {
107     return GetOrCreateIfNotFound<ReverseV2Transposer>("ReverseV2");
108   }
109   if (IsSelect(node)) {
110     return GetOrCreateIfNotFound<SelectTransposer>("Select");
111   }
112   if (IsShape(node)) {
113     return GetOrCreateIfNotFound<ShapeTransposer>("Shape");
114   }
115   if (IsShapeN(node)) {
116     return GetOrCreateIfNotFound<ShapeNTransposer>("ShapeN");
117   }
118   if (IsSlice(node)) {
119     return GetOrCreateIfNotFound<SliceTransposer>("Slice");
120   }
121   if (IsSplit(node)) {
122     return GetOrCreateIfNotFound<SplitTransposer>("Split");
123   }
124   if (IsSplitV(node)) {
125     return GetOrCreateIfNotFound<SplitVTransposer>("SplitV");
126   }
127   if (IsSqueeze(node)) {
128     return GetOrCreateIfNotFound<SqueezeTransposer>("Squeeze");
129   }
130   if (IsStridedSlice(node)) {
131     return GetOrCreateIfNotFound<StridedSliceTransposer>("StridedSlice");
132   }
133   if (IsSwitch(node)) {
134     return GetOrCreateIfNotFound<SwitchTransposer>("Switch");
135   }
136   if (IsTernaryOp(node)) {
137     return GetOrCreateIfNotFound<TernaryOpTransposer>("TernaryOp");
138   }
139   if (IsTile(node)) {
140     return GetOrCreateIfNotFound<TileTransposer>("Tile");
141   }
142   if (IsUnaryGrad(node)) {
143     return GetOrCreateIfNotFound<UnaryGradTransposer>("UnaryGrad");
144   }
145   return nullptr;
146 }
147 
148 }  // namespace grappler
149 }  // namespace tensorflow
150