/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TRANSFORMS_REMAPPER_REMAPPING_HELPER_H_
#define TENSORFLOW_CORE_TRANSFORMS_REMAPPER_REMAPPING_HELPER_H_

#include <string>
#include <type_traits>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/transforms/utils/op_cat_helper.h"
#include "tensorflow/core/transforms/utils/utils.h"

namespace mlir {
namespace tfg {

// The following structures store information about the operations to be
// fused. They are mainly used to combine operand info and attributes for a
// fused operation. They are also used by predicate functions such as
// `IsCpuCompatible` and `IsGpuCompatible` to check whether the relevant
// fusion is supported on CPU and GPU, respectively. Another reason to keep
// these structures is to mirror the logic of the current grappler remapper;
// a usage sketch follows the struct definitions below.
// TODO(intel-tf): Remove redundancies once similar functionality is achieved
// by tfg-remapper.
struct ContractionBiasAdd {
  Operation* contraction;
  Operation* bias_add;
};

struct ContractionBiasAddActivation {
  Operation* contraction;
  Operation* bias_add;
  Operation* activation;
};

struct ContractionBiasAddAdd {
  Operation* contraction;
  Operation* bias_add;
  Operation* add;
};

struct ContractionBiasAddAddActivation {
  Operation* contraction;
  Operation* bias_add;
  Operation* add;
  Operation* activation;
};
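
// A minimal usage sketch, kept as a comment so the header's behavior is
// unchanged. The names `dialect`, `matmul_op`, `bias_add_op`, and `relu_op`
// are hypothetical values a pattern matcher might have produced; this is
// illustrative, not the remapper's actual driver code:
//
//   ContractionBiasAddActivation pattern{matmul_op, bias_add_op, relu_op};
//   OpPropertyHelper helper(dialect, /*onednn_enabled=*/true,
//                           /*xla_auto_clustering=*/false);
//   if (helper.IsDeviceCompatible(pattern)) {
//     // Rewrite the three ops into the corresponding fused kernel,
//     // e.g. _FusedMatMul.
//   }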

class OpPropertyHelper : public OpCatHelper {
 public:
  OpPropertyHelper() = default;
  explicit OpPropertyHelper(TFGraphDialect* dialect,
                            bool onednn_enabled = false,
                            bool xla_auto_clustering = false)
      : OpCatHelper(dialect),
        is_onednn_enabled_(onednn_enabled),
        is_xla_auto_clustering_enabled_(xla_auto_clustering) {}

  bool HasControlOperandsOrResultUsers(Operation* op) const {
    TFOp wrapper_op(op);
    bool has_ctl_operands = !wrapper_op.getControlOperands().empty();
    bool has_ctl_ret_users = !wrapper_op.controlRet().getUsers().empty();
    return has_ctl_operands || has_ctl_ret_users;
  }

  // This function is to be used only for an operation that has at least one
  // non-control result.
  bool HasAtMostOneUserOfResult0(Operation* op) const {
    // Every tfg operation has one control result. When the operation also has
    // at least one non-control result, the total number of results is at
    // least 2.
    return op->getNumResults() > 1 &&
           (op->getResult(0).hasOneUse() || op->getResult(0).use_empty());
  }
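
  // Illustration (hedged): an op printed as `%out, %ctl = MatMul(%a, %b)`
  // has getNumResults() == 2; the check above inspects the uses of the data
  // result `%out` (result 0), not those of the trailing control token.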

  bool IsContraction(Operation* op) const {
    return dialect_->IsConv2D(op) || dialect_->IsConv3D(op) ||
           dialect_->IsDepthwiseConv2dNative(op) || dialect_->IsMatMul(op);
  }

  bool HaveSameDataType(Operation* lhs_op, Operation* rhs_op,
                        StringRef attr_name = "T") const {
    auto lhs_attr = lhs_op->getAttrOfType<TypeAttr>(attr_name);
    auto rhs_attr = rhs_op->getAttrOfType<TypeAttr>(attr_name);
    if (!lhs_attr || !rhs_attr) return false;
    return lhs_attr == rhs_attr;
  }

  // This function is currently used by contraction ops.
  bool IsGpuCompatibleDataType(Operation* contraction_op,
                               StringRef attr_name = "T") const {
    auto attr = contraction_op->getAttrOfType<TypeAttr>(attr_name);
    if (!attr) return false;
    Type dtype = attr.getValue();
    if (dialect_->IsConv2D(contraction_op)) {
      return dtype.isa<Float32Type>();
    } else if (dialect_->IsMatMul(contraction_op)) {
      return dtype.isa<Float32Type, Float64Type>();
    } else {
      return false;
    }
  }
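
  // For example (illustrative): a Conv2D with T = f32 or a MatMul with
  // T = f64 is GPU-compatible here, while a Conv2D with T = f16 or any
  // Conv3D is not.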

  // This function is currently used by contraction ops.
  bool IsCpuCompatibleDataType(Operation* contraction_op,
                               StringRef attr_name = "T") const {
    auto attr = contraction_op->getAttrOfType<TypeAttr>(attr_name);
    if (!attr) return false;
    Type dtype = attr.getValue();
    if (is_onednn_enabled_) {
      // Only contraction ops (MatMul, Conv2D, Conv3D, and
      // DepthwiseConv2dNative) and BatchMatMul are supported. BatchMatMul
      // fusions are handled differently than contraction ops.
      bool is_supported = IsContraction(contraction_op) ||
                          dialect_->IsAnyBatchMatMul(contraction_op);
      return is_supported && dtype.isa<Float32Type, BFloat16Type>();
    }

    if (dialect_->IsConv2D(contraction_op)) {
      return dtype.isa<Float32Type, Float64Type>();
    } else if (dialect_->IsMatMul(contraction_op)) {
      return dtype.isa<Float32Type>();
    } else {
      return false;
    }
  }
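
  // For example (illustrative): a MatMul with T = bf16 is CPU-compatible
  // only when oneDNN is enabled; without oneDNN, MatMul is limited to f32
  // and Conv2D to f32/f64.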

  // This function is currently used by convolution-type ops.
  bool IsGpuCompatibleDataFormat(Operation* conv_op,
                                 StringRef attr_name = "data_format") const {
    StringRef data_format;
    if (auto attr = conv_op->getAttrOfType<StringAttr>(attr_name)) {
      data_format = attr.getValue();
    } else {
      return false;
    }
    if (dialect_->IsConv2D(conv_op)) {
      return data_format == "NHWC" || data_format == "NCHW";
    } else {
      return false;
    }
  }

  // This function is currently used by convolution-type ops.
  bool IsCpuCompatibleDataFormat(Operation* conv_op,
                                 StringRef attr_name = "data_format") const {
    StringRef data_format;
    if (auto attr = conv_op->getAttrOfType<StringAttr>(attr_name)) {
      data_format = attr.getValue();
    } else {
      return false;
    }
    if (dialect_->IsConv2D(conv_op)) {
      return data_format == "NHWC" ||
             (is_onednn_enabled_ && data_format == "NCHW");
    } else if (dialect_->IsConv3D(conv_op)) {
      return data_format == "NDHWC" ||
             (is_onednn_enabled_ && data_format == "NCDHW");
    } else {
      return false;
    }
  }
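
  // For example (illustrative): a Conv3D with data_format = "NCDHW" is
  // CPU-compatible only when oneDNN is enabled; "NDHWC" is accepted either
  // way.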

  bool IsGpuCompatible(const ContractionBiasAddActivation& pattern) const {
#if TENSORFLOW_USE_ROCM
    // ROCm does not support _FusedConv2D. Does it support _FusedMatMul?
    return false;
#endif
    // The TF->XLA bridge does not support `_FusedMatMul`, so we avoid
    // creating this op. Furthermore, XLA already does this fusion internally,
    // so there is no true benefit from doing this optimization if XLA is
    // going to compile the unfused operations anyway.
    if (is_xla_auto_clustering_enabled_) return false;
    if (!util::OpHasDevice(pattern.contraction, tensorflow::DEVICE_GPU))
      return false;
    if (!dialect_->IsRelu(pattern.activation)) return false;
    if (dialect_->IsMatMul(pattern.contraction)) {
      return IsGpuCompatibleDataType(pattern.contraction);
    } else {
      // TODO(intel-tf): Add spatial convolution support on GPU.
      return false;
    }
  }
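
  // In effect (illustrative): on a non-ROCm GPU build without XLA
  // auto-clustering, only MatMul + BiasAdd + Relu with T = f32/f64 passes
  // this check.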

  // Currently the GPU does not support contraction + bias_add.
  bool IsGpuCompatible(const ContractionBiasAdd&) const { return false; }

  bool IsCpuCompatible(Operation* contraction_op) const {
    if (!util::OpHasDevice(contraction_op, tensorflow::DEVICE_CPU))
      return false;
    if (dialect_->IsConv2D(contraction_op) ||
        dialect_->IsConv3D(contraction_op)) {
      return IsCpuCompatibleDataType(contraction_op) &&
             IsCpuCompatibleDataFormat(contraction_op);
    } else if (dialect_->IsMatMul(contraction_op) ||
               dialect_->IsAnyBatchMatMul(contraction_op) ||
               dialect_->IsDepthwiseConv2dNative(contraction_op)) {
      return IsCpuCompatibleDataType(contraction_op);
    } else {
      return false;
    }
  }

  template <typename Pattern>
  bool IsDeviceCompatible(const Pattern& pattern) const {
    // Currently, this function is used by contraction-based fusions. Note
    // that the original check listed ContractionBiasAddActivation twice;
    // the struct list above makes ContractionBiasAddAddActivation the
    // evident intent, so it is used here.
    if constexpr (
        !std::is_same<Pattern, ContractionBiasAdd>::value &&
        !std::is_same<Pattern, ContractionBiasAddActivation>::value &&
        !std::is_same<Pattern, ContractionBiasAddAdd>::value &&
        !std::is_same<Pattern, ContractionBiasAddAddActivation>::value) {
      return false;
    }
    return IsGpuCompatible(pattern) || IsCpuCompatible(pattern.contraction);
  }
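
  // Illustrative only (`conv_op` and `bias_op` are hypothetical matched
  // ops): IsDeviceCompatible(ContractionBiasAdd{conv_op, bias_op}) skips the
  // compile-time filter above and reduces to IsCpuCompatible(conv_op), since
  // the GPU overload for this pattern always returns false.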

 private:
  // Default-initialize so the no-argument constructor does not leave these
  // flags indeterminate.
  bool is_onednn_enabled_ = false;
  bool is_xla_auto_clustering_enabled_ = false;
};

}  // namespace tfg
}  // namespace mlir

#endif  // TENSORFLOW_CORE_TRANSFORMS_REMAPPER_REMAPPING_HELPER_H_