/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This pass identifies patterns for dilated convolution and replaces them with
// real convolution ops.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_

#include <cstdint>

#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/utils/validators.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TFL {

// A dilated convolution can be emulated with a regular convolution by chaining
// SpaceToBatch and BatchToSpace ops before and after it:
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND
//
// This method was common before Conv2D fully supported dilated convolution in
// TensorFlow. This transformation detects this "emulation", and replaces it
// with a true dilated convolution, eliminating the SpaceToBatch and
// BatchToSpace ops.
//
// Detecting this alone would be relatively easy. However, in practice some
// extra ops are used, so we detect the following patterns:
//
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
//     BiasAdd
//
//   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
//
//   SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
//
//   SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
//
//
// The Expand/Squeeze combination is used to adapt a 3D array (such as in
// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
// thrown in just for the extra headache. Padding adapts non-conforming input
// sizes, and can be discarded. The bias is necessary, so is kept.
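//
// As a rough illustration (a sketch only, not an exhaustive list of
// producers), TF Python along the following lines was historically lowered to
// the SpaceToBatchND/Conv2D/BatchToSpaceND form matched here; the exact
// argument names and op layout depend on the TF version:
//
//   y = tf.nn.convolution(x, filters, padding="SAME", dilation_rate=[2, 2])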
template <typename Conv2dOpTy>
class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
 private:
  using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;

  // Extract the dilation factor from `block_shape` and pack it in an
  // ArrayAttr.
  llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
      Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
      PatternRewriter& rewriter) const;

 public:
  LogicalResult matchAndRewrite(Conv2dOpTy op,
                                PatternRewriter& rewriter) const override;
};

template <typename Conv2dOpTy>
LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
    Conv2dOpTy op, PatternRewriter& rewriter) const {
  if (!op.getResult().hasOneUse()) {
    return rewriter.notifyMatchFailure(
        op, "result for current op has more than 1 use");
  }
  // Make sure Conv2D has 'VALID' padding.
  if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
    return rewriter.notifyMatchFailure(op,
                                       "Conv2D op doesn't have valid padding");
  }
  // Make sure dilations are all ones if set.
  const ArrayAttr& dilations =
      op->template getAttrOfType<ArrayAttr>("dilations");
  if (dilations && !TFIntListIsAllOnes(dilations)) {
    return rewriter.notifyMatchFailure(op, "dilations should be all 1");
  }

  if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) {
    return rewriter.notifyMatchFailure(
        op, "op's input is not float or the data format isn't NHWC");
  }

  // Allow dynamic width and height dimensions only.
  auto result_ty = op.getResult().getType().template cast<TensorType>();
  if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
      result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
    return rewriter.notifyMatchFailure(
        op, "only dynamic width and height dimensions are allowed");
  }

  // Check if the ConvOp's input is defined by an `Expand` op and its output is
  // used by a `Squeeze` op.
  Operation* producer_op = op.getOperand(0).getDefiningOp();
  if (!producer_op || producer_op->getNumResults() != 1) {
    return rewriter.notifyMatchFailure(
        op, "op doesn't have a producer node that has a single result");
  }
  if (!producer_op->hasOneUse() ||
      *(producer_op->getResult(0).user_begin()) != op) {
    return rewriter.notifyMatchFailure(
        op, "op's input isn't produced by previous operation");
  }

  auto tryGetDirectConsumerOp =
      [&rewriter](Operation* current) -> std::pair<LogicalResult, Operation*> {
    // Check the current operation has a single result.
    if (current->getNumResults() != 1) {
      return {
          rewriter.notifyMatchFailure(current, "op doesn't have single result"),
          nullptr};
    }
    // Check the current operation has a consumer node.
    Operation* consumer_op =
        current->getResult(0).getUses().begin()->getOwner();
    if (!consumer_op) {
      return {
          rewriter.notifyMatchFailure(current, "op doesn't have consumer node"),
          nullptr};
    }
    // Check the current operation's result is used by its successor node.
    if (!current->hasOneUse() ||
        *(current->getResult(0).user_begin()) != consumer_op) {
      return {
          rewriter.notifyMatchFailure(
              current, "op's result isn't directly consumed by the next op"),
          nullptr};
    }
    return {LogicalResult::success(), consumer_op};
  };

  std::pair<LogicalResult, Operation*> maybeConsumer =
      tryGetDirectConsumerOp(op.getOperation());
  if (failed(maybeConsumer.first)) {
    return maybeConsumer.first;
  }
  Operation* consumer_op = maybeConsumer.second;

  TF::ExpandDimsOp expand_op;
  TF::SqueezeOp squeeze_op;
  int64_t expand_axis = -1;
  // Expand + Squeeze op.
  if (llvm::isa<TF::ExpandDimsOp>(producer_op)) {
    if (!llvm::isa<TF::SqueezeOp>(consumer_op)) {
      // Expand/Squeeze ops must come as a pair.
      return rewriter.notifyMatchFailure(
          op, "ExpandDimsOp and SqueezeOp should come in pairs");
    }
    expand_op = llvm::cast<TF::ExpandDimsOp>(producer_op);
    squeeze_op = llvm::cast<TF::SqueezeOp>(consumer_op);
    if (!expand_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          expand_op, "result for current op has more than 1 use");
    }
    if (!squeeze_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          squeeze_op, "result for current op has more than 1 use");
    }
    // Make sure that the axis in `expand_op` is constant.
    if (auto const_op =
            llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
      expand_axis = (*const_op.value()
                          .cast<DenseElementsAttr>()
                          .getValues<APInt>()
                          .begin())
                        .getSExtValue();
      // Canonicalize the axis. Some TF python functions, such as
      // `tf.nn.convolution`, use a negative axis.
      if (expand_axis < 0) {
        // Always expand 3D input to 4D input.
        expand_axis += 4;
      }
    } else {
      return rewriter.notifyMatchFailure(
          expand_op, "ExpandDimsOp doesn't have a constant axis");
    }
    // Make sure that the `squeeze_dims` is equal to `expand_axis`.
    auto squeeze_dims = squeeze_op.squeeze_dims();
    if (squeeze_dims.size() != 1) {
      return rewriter.notifyMatchFailure(
          squeeze_op, "squeeze dims should have exactly 1 dimension specified");
    }
    int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
    if (squeeze_axis < 0) {
      // Always squeeze 4D input to 3D input.
      squeeze_axis += 4;
    }
    if (squeeze_axis != expand_axis) {
      return rewriter.notifyMatchFailure(
          op, "squeeze axis and expand axis don't match");
    }

    // Update previous/next op pointers.
    Operation* tmp = expand_op.input().getDefiningOp();
    if (!tmp || tmp->getNumResults() != 1) {
      return rewriter.notifyMatchFailure(
          producer_op,
          "op doesn't have a producer node that has a single result");
    }
    if (!tmp->hasOneUse() ||
        *(tmp->getResult(0).user_begin()) != producer_op) {
      return rewriter.notifyMatchFailure(
          producer_op, "op's input isn't defined by its previous node");
    }
    producer_op = tmp;
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    consumer_op = maybeConsumer.second;
  }

  // SpaceToBatchND op.
  if (!llvm::isa<TF::SpaceToBatchNDOp>(producer_op)) {
    return rewriter.notifyMatchFailure(producer_op,
                                       "op should be a SpaceToBatchND op");
  }
  // TODO(b/149936532): Check `padding` input, currently ignored.
  TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(producer_op);
  if (!stb_op.getResult().hasOneUse()) {
    return rewriter.notifyMatchFailure(
        stb_op, "result for current op has more than 1 use");
  }

  // Pad op.
  TF::PadOp pad_op;
  ElementsAttr pad_attr;
  if (llvm::isa<TF::PadOp>(consumer_op)) {
    pad_op = llvm::cast<TF::PadOp>(consumer_op);
    if (!pad_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          pad_op, "result for current op has more than 1 use");
    }
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    consumer_op = maybeConsumer.second;
    if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) {
      // If the padding value isn't constant, we can't determine the padding
      // scheme for Conv2D below; in this case just reject the pattern.
      return rewriter.notifyMatchFailure(
          pad_op, "PadOp's padding value isn't constant");
    }
  }

  // BatchToSpaceND + BiasAdd.
  TF::BatchToSpaceNDOp bts_op;
  TF::BiasAddOp biasadd_op;
  bool final_op_is_bts = true;
  if (llvm::isa<TF::BiasAddOp>(consumer_op)) {
    // Must be BiasAdd + BatchToSpaceND.
    biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
    if (!biasadd_op.getResult().hasOneUse()) {
      return rewriter.notifyMatchFailure(
          biasadd_op, "result for current op has more than 1 use");
    }
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    if (failed(maybeConsumer.first)) {
      return maybeConsumer.first;
    }
    if (!llvm::isa<TF::BatchToSpaceNDOp>(maybeConsumer.second)) {
      return rewriter.notifyMatchFailure(
          consumer_op, "op's next node isn't BatchToSpaceND op");
    }
    consumer_op = maybeConsumer.second;
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
  } else if (llvm::isa<TF::BatchToSpaceNDOp>(consumer_op)) {
    // BatchToSpaceND + (optional) BiasAdd.
    bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
    std::pair<LogicalResult, Operation*> maybeConsumer =
        tryGetDirectConsumerOp(consumer_op);
    Operation* tmp = maybeConsumer.second;
    if (tmp && llvm::isa<TF::BiasAddOp>(tmp)) {
      consumer_op = tmp;
      biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
      final_op_is_bts = false;
    }
  } else {
    return rewriter.notifyMatchFailure(
        consumer_op, "next op is neither BiasAdd nor BatchToSpaceND");
  }

  llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
      stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
  if (!dilations_attr.has_value()) {
    return rewriter.notifyMatchFailure(op, "failed to extract dilation rate");
  }

  if (expand_op) {
    if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
      return rewriter.notifyMatchFailure(
          stb_op, "SpaceToBatchND op's input should have RankedTensorType");
    }
  }

  // TODO(b/149936532): Check that the input width & height are multiples of
  // the dilation rate.
  // The TF python library rewrites a dilated conv into the
  // "SpaceToBatch->Conv->BatchToSpace" pattern, where the Conv in the middle
  // always has 'VALID' padding. The padding tensor of `SpaceToBatch` has two
  // contributions: one reduces the padding of the Conv from 'SAME' to 'VALID',
  // and the other makes the input shape a multiple of the dilation rate. The
  // first part, also called `base_padding`, is used here to determine whether
  // the original padding format was 'SAME' or 'VALID'. If the padding is
  // constant, `base_padding` can be recovered from the relations below, which
  // the `paddings` tensor of `SpaceToBatch` and the `crops` tensor of
  // `BatchToSpace` must satisfy:
  //   paddings[i, 0] = base_paddings[i, 0]
  //   0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
  //   (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0
  //   crops[i, 0] = 0
  //   crops[i, 1] = paddings[i, 1] - base_paddings[i, 1]
  //
  // If `paddings` - `crops` != 0, then `base_paddings` != 0, which tells us
  // the original padding was 'SAME' (with one caveat presented below), so the
  // padding attribute is reset to 'SAME' here whenever that is the case. An
  // illustrative example follows.
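  //
  // Illustrative example (our own sketch, assuming stride 1 and a 3x3 kernel):
  // for an input dimension of size 64 with block_shape[i] = 2, an original
  // 'SAME' conv needs base_paddings[i] = [2, 2] (the effective kernel size is
  // 5), which gives paddings[i] = [2, 2] and crops[i] = [0, 0], so
  // paddings - crops != 0. For an original 'VALID' conv, base_paddings[i] =
  // [0, 0], so paddings[i] = [0, 0], crops[i] = [0, 0], and
  // paddings - crops == 0.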
  // TODO(b/149936532): We should not rely solely on `paddings - crops != 0` to
  // determine the original padding format. For example, users can build
  // arbitrary valid examples of `STB->Conv->BTS` that don't represent a
  // dilated conv, and we shouldn't pattern match those. Instead, we need to
  // check the values of `paddings` and `crops` to make sure the graph really
  // stands for a dilated conv.
  auto stb_paddings = stb_op.paddings();
  auto bts_crops = bts_op.crops();
  ElementsAttr stb_paddings_attr, bts_crops_attr;
  if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) ||
      !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
    return rewriter.notifyMatchFailure(
        op,
        "either SpaceToBatchND or BatchToSpaceND "
        "doesn't have constant padding/crops value");
  }
  if (stb_paddings_attr.getType() != bts_crops_attr.getType()) {
    return rewriter.notifyMatchFailure(
        stb_op,
        "SpaceToBatchND op's padding doesn't have same shape/type with "
        "BatchToSpaceND op's crops");
  }
  int64_t m = stb_paddings_attr.getType().getDimSize(0);
  // padding - crop.
  for (uint64_t i = 0; i < m; ++i) {
    for (uint64_t j = 0; j < 2; ++j) {
      // The `crops` tensor has shape [M, 2]; crops[i] = [crop_start, crop_end]
      // specifies the amount to crop from input dimension i + 1. If the input
      // of `BatchToSpaceND` has been padded explicitly, then we need to take
      // that additional padding into account when determining the padding
      // scheme for `Conv2D`.
      int64_t additional_pad =
          pad_attr ? pad_attr.getValues<APInt>()[{i + 1, j}].getSExtValue() : 0;
      if (stb_paddings_attr.getValues<APInt>()[{i, j}].getSExtValue() +
              additional_pad !=
          bts_crops_attr.getValues<APInt>()[{i, j}].getSExtValue()) {
        op->setAttr("padding", rewriter.getStringAttr("SAME"));
        break;
      }
    }
  }

  // Set dilations.
  op->setAttr("dilations", dilations_attr.getValue());

  if (expand_op) {
    // If there is an `expand_op`, we need to rewire the inputs to bypass the
    // `SpaceToBatch`, `BatchToSpace` and `Pad` ops. E.g., turning
    // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
    // BiasAdd' into 'Expand -> Conv2D -> Squeeze -> BiasAdd'.

    // Connect `expand_op` with the input of `stb_op`.
    expand_op.setOperand(0, stb_op.input());
    // Calculate the shape for expand.
    auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
                                         input_shape.end());
    expand_shape.insert(expand_shape.begin() + expand_axis, 1);

    auto expand_result_type = RankedTensorType::get(
        expand_shape, getElementTypeOrSelf(stb_op.input()));
    expand_op.getResult().setType(expand_result_type);

    // Update the conv op's output shape.
    auto bts_output_shape =
        bts_op.output().getType().cast<ShapedType>().getShape();
    SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
                                              bts_output_shape.end());
    conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
    auto conv_result_type = RankedTensorType::get(
        conv_result_shape, getElementTypeOrSelf(stb_op.input()));
    op.getResult().setType(conv_result_type);

    squeeze_op.getResult().setType(bts_op.output().getType());

    // Connect `biasadd_op` with the output of `squeeze_op`.
    if (biasadd_op) {
      biasadd_op.setOperand(0, squeeze_op.output());
      biasadd_op.output().setType(squeeze_op.output().getType());
    }
  } else {
    if (biasadd_op) biasadd_op.setOperand(0, op.output());
    op.setOperand(0, stb_op.input());
    op.getResult().setType(bts_op.getResult().getType());
  }

  if (final_op_is_bts) {
    if (bts_op.input().getDefiningOp<TF::PadOp>()) {
      bts_op.getResult().replaceAllUsesWith(pad_op.input());
    } else {
      bts_op.getResult().replaceAllUsesWith(bts_op.input());
    }
  }

  stb_op.getResult().dropAllUses();
  return success();
}

template <typename Conv2dOpTy>
llvm::Optional<ArrayAttr>
ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
    Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
    PatternRewriter& rewriter) const {
  ElementsAttr stb_bs_attr, bts_bs_attr;
  if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
      !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
    // Returns failure status if block_shape is not a constant.
    return {};
  }
  // Check that the block_shape of `stb_op` and `bts_op` are equal.
  if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
  for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
    if (stb_bs_attr.getValues<Attribute>()[i] !=
        bts_bs_attr.getValues<Attribute>()[i])
      return {};
  }

  int dilation_h_factor = -1, dilation_w_factor = -1;
  // Set dilation factor.
  if (stb_bs_attr.getNumElements() >= 2) {
    dilation_h_factor = stb_bs_attr.getValues<APInt>()[0].getSExtValue();
    dilation_w_factor = stb_bs_attr.getValues<APInt>()[1].getSExtValue();
  } else if (stb_bs_attr.getNumElements() == 1) {
    // For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after
    // `SpaceToBatchND`. Therefore, the `block_shape` of `stb_op` only holds
    // the dilation factor for the W dim, and the dilation factor for the H
    // dim is set to 1.
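    // For example (illustrative): with block_shape = [2], an NWC input
    // (expand_axis == 1) yields the NHWC dilations attribute [1, 1, 2, 1],
    // while an NHC input (expand_axis == 2) yields [1, 2, 1, 1].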
    if (expand_axis == 1) {
      // NWC -> NHWC
      dilation_h_factor = 1;
      dilation_w_factor = stb_bs_attr.getValues<APInt>()[0].getSExtValue();
    } else if (expand_axis == 2) {
      // NHC -> NHWC
      dilation_h_factor = stb_bs_attr.getValues<APInt>()[0].getSExtValue();
      dilation_w_factor = 1;
    }
  }

  if (dilation_h_factor == -1 || dilation_w_factor == -1) {
    return {};
  }

  return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
}

}  // namespace TFL
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_