xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
// This pass identifies op patterns that emulate a dilated convolution and
// replaces them with a real (dilated) convolution op.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
19 #define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
20 
21 #include <cstdint>
22 
23 #include "llvm/Support/Casting.h"
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Matchers.h"  // from @llvm-project
28 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
29 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
30 #include "mlir/Pass/Pass.h"  // from @llvm-project
31 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
32 #include "tensorflow/compiler/mlir/lite/utils/validators.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
34 
35 namespace mlir {
36 namespace TFL {
37 
38 // A dilated convolution can be emulated with a regular convolution by chaining
39 // SpaceToBatch and BatchToSpace ops before and after it:
40 //
41 //     SpaceToBatchND -> Conv2D -> BatchToSpaceND
42 //
43 // This method was common before Conv2D fully supported dilated convolution in
44 // TensorFlow. This transformation detects this "emulation", and replaces it
45 // with a true dilated convolution, eliminating the SpaceToBatch and
46 // BatchtoSpace ops.
47 //
48 // Detecting this alone would be relatively easy. However, in practice some
49 // extra ops are used, so we detect the following patterns:
50 //
51 //
52 //   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd
53 //
54 //   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND ->
55 //   BiasAdd
56 //
57 //   SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND
58 //
59 //   SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd
60 //
61 //   SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd
62 //
63 //
64 // The Expand/Squeeze combination is used to adapt a 3D array (such as in
65 // WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are
66 // thrown in just for the extra headache. Padding adapts non-conforming input
67 // sizes, and can be discarded. The bias is necessary, so is kept.
68 template <typename Conv2dOpTy>
69 class ConvertTFDilatedConvOp : public OpRewritePattern<Conv2dOpTy> {
70  private:
71   using OpRewritePattern<Conv2dOpTy>::OpRewritePattern;
72 
73   // Extract the dilation factor from `block_shape` and pack it in an ArrayAttr.
74   llvm::Optional<ArrayAttr> ExtractDilationsAttrFromBlockShape(
75       Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
76       PatternRewriter& rewriter) const;
77 
78  public:
79   LogicalResult matchAndRewrite(Conv2dOpTy op,
80                                 PatternRewriter& rewriter) const override;
81 };
82 
83 template <typename Conv2dOpTy>
matchAndRewrite(Conv2dOpTy op,PatternRewriter & rewriter)84 LogicalResult ConvertTFDilatedConvOp<Conv2dOpTy>::matchAndRewrite(
85     Conv2dOpTy op, PatternRewriter& rewriter) const {
86   if (!op.getResult().hasOneUse()) {
87     return rewriter.notifyMatchFailure(
88         op, "result for current op has more than 1 use");
89   }
90   // Make sure Conv2D has 'VALID' padding.
91   if (op->template getAttrOfType<StringAttr>("padding").getValue() != "VALID") {
92     return rewriter.notifyMatchFailure(op,
93                                        "Conv2D op doesn't have valid padding");
94   }
95   // Make sure dilations are all ones if set.
96   const ArrayAttr& dilations =
97       op->template getAttrOfType<ArrayAttr>("dilations");
98   if (dilations && !TFIntListIsAllOnes(dilations)) {
99     return rewriter.notifyMatchFailure(op, "dilations should be all 1");
100   }
101 
102   if (!TFTypeIsFloat32Tensor(op.input()) || !TFDataFormatIsNHWC(op)) {
103     return rewriter.notifyMatchFailure(
104         op, "op's input is not float or the data format isn't NHWC");
105   }
106 
107   // Allow dynamic width and height dimensions only.
108   auto result_ty = op.getResult().getType().template cast<TensorType>();
109   if (!result_ty.hasRank() || result_ty.getRank() != 4 ||
110       result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) {
111     return rewriter.notifyMatchFailure(
112         op, "only dynamic width and height dimensions are allowed");
113   }
114 
115   // Check if the ConvOp's input is defined by `Expand` op, and the output used
116   // by `Squeeze` op.
117   Operation* producer_op = op.getOperand(0).getDefiningOp();
118   if (!producer_op || producer_op->getNumResults() != 1) {
119     return rewriter.notifyMatchFailure(
120         op, "op doesn't have a producer node that has a single result");
121   }
122   if (!producer_op->hasOneUse() ||
123       *(producer_op->getResult(0).user_begin()) != op) {
124     return rewriter.notifyMatchFailure(
125         op, "op's input isn't produced by previous operation");
126   }
127 
128   auto tryGetDirectConsumerOp =
129       [&rewriter](Operation* current) -> std::pair<LogicalResult, Operation*> {
130     // Check the current operation has a single result.
131     if (current->getNumResults() != 1) {
132       return {
133           rewriter.notifyMatchFailure(current, "op doesn't have single result"),
134           nullptr};
135     }
136     // Check the current operation has a consumer node.
137     Operation* consumer_op =
138         current->getResult(0).getUses().begin()->getOwner();
139     if (!consumer_op) {
140       return {
141           rewriter.notifyMatchFailure(current, "op doesn't have consumer node"),
142           nullptr};
143     }
144     // Check the current operation's result is used by its successor node.
145     if (!current->hasOneUse() ||
146         *(current->getResult(0).user_begin()) != consumer_op) {
147       return {
148           rewriter.notifyMatchFailure(
149               current, "op's result isn't directly consumed by the next op"),
150           nullptr};
151     }
152     return {LogicalResult::success(), consumer_op};
153   };
154 
155   std::pair<LogicalResult, Operation*> maybeConsumer =
156       tryGetDirectConsumerOp(op.getOperation());
157   if (failed(maybeConsumer.first)) {
158     return maybeConsumer.first;
159   }
160   Operation* consumer_op = maybeConsumer.second;
161 
162   TF::ExpandDimsOp expand_op;
163   TF::SqueezeOp squeeze_op;
164   int64_t expand_axis = -1;
165   // Expand + Squeeze op.
166   if (llvm::isa<TF::ExpandDimsOp>(producer_op)) {
167     if (!llvm::isa<TF::SqueezeOp>(consumer_op)) {
168       // Expand/Squeeze op must come in pair.
169       return rewriter.notifyMatchFailure(
170           op, "ExpandDimsOp and SqueezeOp should come in pair");
171     }
172     expand_op = llvm::cast<TF::ExpandDimsOp>(producer_op);
173     squeeze_op = llvm::cast<TF::SqueezeOp>(consumer_op);
174     if (!expand_op.getResult().hasOneUse()) {
175       return rewriter.notifyMatchFailure(
176           expand_op, "result for current op has more than 1 use");
177     }
178     if (!squeeze_op.getResult().hasOneUse()) {
179       return rewriter.notifyMatchFailure(
180           squeeze_op, "result for current op has more than 1 use");
181     }
182     // Make sure that the axis in `expand_op` is constant.
183     if (auto const_op =
184             llvm::dyn_cast<TF::ConstOp>(expand_op.dim().getDefiningOp())) {
185       expand_axis = (*const_op.value()
186                           .cast<DenseElementsAttr>()
187                           .getValues<APInt>()
188                           .begin())
189                         .getSExtValue();
190       // Canonicalize axis. Some TF python functions, such as
191       // `tf.nn.convolution`, use negative axis.
192       if (expand_axis < 0) {
193         // Always expand 3D input to 4D input.
194         expand_axis += 4;
195       }
196     } else {
197       return rewriter.notifyMatchFailure(
198           expand_op, "ExpandDimsOp doesn't have a constant axis");
199     }
200     // Make sure that the `squeeze_dims` is equal to `expand_axis`.
201     auto squeeze_dims = squeeze_op.squeeze_dims();
202     if (squeeze_dims.size() != 1) {
203       return rewriter.notifyMatchFailure(
204           squeeze_op, "squeeze dims should have exactly 1 dimension specified");
205     }
206     int64_t squeeze_axis = squeeze_dims[0].cast<IntegerAttr>().getInt();
207     if (squeeze_axis < 0) {
208       // Always squeeze 4D input to 3D input.
209       squeeze_axis += 4;
210     }
211     if (squeeze_axis != expand_axis) {
212       return rewriter.notifyMatchFailure(
213           op, "squeeze axis and expand axis doesn't match");
214     }
215 
216     // Update previous/next op pointer.
217     Operation* tmp = expand_op.input().getDefiningOp();
218     if (!tmp || tmp->getNumResults() != 1) {
219       return rewriter.notifyMatchFailure(
220           producer_op,
221           "op doesn't have a producer node that has a single result");
222     }
223     if (!tmp->hasOneUse() || *(tmp->getResult(0).user_begin()) != producer_op) {
224       return rewriter.notifyMatchFailure(
225           producer_op, "op's input isn't defined by its previous node");
226     }
227     producer_op = tmp;
228     std::pair<LogicalResult, Operation*> maybeConsumer =
229         tryGetDirectConsumerOp(consumer_op);
230     if (failed(maybeConsumer.first)) {
231       return maybeConsumer.first;
232     }
233     consumer_op = maybeConsumer.second;
234   }
235 
236   // SpaceToBatchND op.
237   if (!llvm::isa<TF::SpaceToBatchNDOp>(producer_op)) {
238     return rewriter.notifyMatchFailure(producer_op,
239                                        "op should be a SpaceToBatchND op");
240   }
241   // TODO(b/149936532): Check `padding` input, currently ignored.
242   TF::SpaceToBatchNDOp stb_op = llvm::cast<TF::SpaceToBatchNDOp>(producer_op);
243   if (!stb_op.getResult().hasOneUse()) {
244     return rewriter.notifyMatchFailure(
245         stb_op, "result for current op has more than 1 use");
246   }
247 
248   // Pad op.
249   TF::PadOp pad_op;
250   ElementsAttr pad_attr;
251   if (llvm::isa<TF::PadOp>(consumer_op)) {
252     pad_op = llvm::cast<TF::PadOp>(consumer_op);
253     if (!pad_op.getResult().hasOneUse()) {
254       return rewriter.notifyMatchFailure(
255           pad_op, "result for current op has more than 1 use");
256     }
257     std::pair<LogicalResult, Operation*> maybeConsumer =
258         tryGetDirectConsumerOp(consumer_op);
259     if (failed(maybeConsumer.first)) {
260       return maybeConsumer.first;
261     }
262     consumer_op = maybeConsumer.second;
263     if (!matchPattern(pad_op.paddings(), m_Constant(&pad_attr))) {
264       // If the padding value isn't constant, we can't determine the padding
265       // scheme for Conv2D below, in this case just reject the pattern.
266       return rewriter.notifyMatchFailure(
267           pad_op, "PadOp's padding value isn't constant");
268     }
269   }
270 
271   // BatchToSpaceND + BiasAdd.
272   TF::BatchToSpaceNDOp bts_op;
273   TF::BiasAddOp biasadd_op;
274   bool final_op_is_bts = true;
275   if (llvm::isa<TF::BiasAddOp>(consumer_op)) {
276     // Must be BiasAdd + BatchToSpaceND.
277     biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
278     if (!biasadd_op.getResult().hasOneUse()) {
279       return rewriter.notifyMatchFailure(
280           biasadd_op, "result for current op has more than 1 use");
281     }
282     std::pair<LogicalResult, Operation*> maybeConsumer =
283         tryGetDirectConsumerOp(consumer_op);
284     if (failed(maybeConsumer.first)) {
285       return maybeConsumer.first;
286     }
287     if (!llvm::isa<TF::BatchToSpaceNDOp>(maybeConsumer.second)) {
288       return rewriter.notifyMatchFailure(
289           consumer_op, "op's next node isn't BatchToSpaceND op");
290     }
291     consumer_op = maybeConsumer.second;
292     bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
293   } else if (llvm::isa<TF::BatchToSpaceNDOp>(consumer_op)) {
294     // BatchToSpaceND + (optional) BiasAdd.
295     bts_op = llvm::cast<TF::BatchToSpaceNDOp>(consumer_op);
296     std::pair<LogicalResult, Operation*> maybeConsumer =
297         tryGetDirectConsumerOp(consumer_op);
298     Operation* tmp = maybeConsumer.second;
299     if (tmp && llvm::isa<TF::BiasAddOp>(tmp)) {
300       consumer_op = tmp;
301       biasadd_op = llvm::cast<TF::BiasAddOp>(consumer_op);
302       final_op_is_bts = false;
303     }
304   } else {
305     return rewriter.notifyMatchFailure(
306         consumer_op, "next op is neither BiasAdd nor BatchToSpaceND");
307   }
308 
309   llvm::Optional<ArrayAttr> dilations_attr = ExtractDilationsAttrFromBlockShape(
310       stb_op.block_shape(), bts_op.block_shape(), expand_axis, rewriter);
311   if (!dilations_attr.has_value()) {
312     return rewriter.notifyMatchFailure(op, "failed to extract dilation rate");
313   }
314 
315   if (expand_op) {
316     if (stb_op.input().getType().dyn_cast<RankedTensorType>() == nullptr) {
317       return rewriter.notifyMatchFailure(
318           stb_op, "SpaceToBatchND op's input should have RankedTensorType");
319     }
320   }
321 
322   // TODO(b/149936532): Check that the input width & height are multiples of
323   // dilation rate.
324   // TF python library will rewrite dilated conv to
325   // "SpaceToBatch->Conv->BatchToSpace" pattern, and the Conv in the middle
326   // always has 'VALID' padding. The padding tensor in `SpaceToBatch` has two
327   // parts of contributions, one is to reduce padding of CONV from 'SAME' to
328   // 'VALID', and another is to make input shape multiples of dilation rate. The
329   // first part of padding, which is also called `base_padding` will be used
330   // here to determine if the original padding format is 'SAME' or 'VALID'.
331   // According to the following formula we will compute the `base_padding` if
332   // it's a constant. Basically, `paddings` tensor in `SpaceToBatch` and `crops`
333   // tensor  in `BatchToSpace` must satisfy the following:
334   //  paddings[i, 0] = base_paddings[i, 0].
335   //  0 <= paddings[i, 1] - base_paddings[i, 1] < block_shape[i]
336   // (input_shape[i] + paddings[i, 0] + paddings[i, 1]) % block_shape[i] == 0.
337   //  crops[i, 0] = 0.
338   //  crops[i, 1] = paddings[i, 1] - base_paddings[i, 1].
339 
340   //  If `paddings` - `crops` != 0, this means that `base_paddings` != 0, which
341   // tells us the original padding is 'SAME' (with one caveat presented below).
342   // Here we need to reset the padding back to `SAME` if `base_padding`
343   // != 0.
344   // TODO(b/149936532): We might not simply rely on `paddings - crops != 0` to
345   // determine the original padding format. For example, users can build
346   // arbitrary valid examples of `STB->Conv->BTS` which doesn't represent a
347   // dilated conv, hence we shouldn't pattern match here. Instead, we need to
348   // check values of `paddings` and `crops` to make sure it really stands for
349   // a dilated conv.
350   auto stb_paddings = stb_op.paddings();
351   auto bts_crops = bts_op.crops();
352   ElementsAttr stb_paddings_attr, bts_crops_attr;
353   if (!matchPattern(stb_paddings, m_Constant(&stb_paddings_attr)) ||
354       !matchPattern(bts_crops, m_Constant(&bts_crops_attr))) {
355     return rewriter.notifyMatchFailure(
356         op,
357         "either SpaceToBatchND or BatchToSpaceND "
358         "doesn't have constant padding/crops value");
359   }
360   if (stb_paddings_attr.getType() != bts_crops_attr.getType()) {
361     return rewriter.notifyMatchFailure(
362         stb_op,
363         "SpaceToBatchND op's padding doesn't have same shape/type with "
364         "BatchToSpaceND op's crops");
365   }
366   int64_t m = stb_paddings_attr.getType().getDimSize(0);
367   // padding - crop.
368   for (uint64_t i = 0; i < m; ++i) {
369     for (uint64_t j = 0; j < 2; ++j) {
370       // `crops` tensor has shape [M, 2], crops[i] = [crop_start, crop_end]
371       // specifies the amount to crop from input dimension i + 1. If the input
372       // of `BatchToSpaceND` has been padded explicitly, then we need to
373       // take into account the additional padding when determining the padding
374       // scheme for `Conv2D`.
375       int64_t addtional_pad =
376           pad_attr ? pad_attr.getValues<APInt>()[{i + 1, j}].getSExtValue() : 0;
377       if (stb_paddings_attr.getValues<APInt>()[{i, j}].getSExtValue() +
378               addtional_pad !=
379           bts_crops_attr.getValues<APInt>()[{i, j}].getSExtValue()) {
380         op->setAttr("padding", rewriter.getStringAttr("SAME"));
381         break;
382       }
383     }
384   }
385 
386   // Set dilations
387   op->setAttr("dilations", dilations_attr.getValue());
388 
389   if (expand_op) {
390     // If there is `expand_op`, we need to rewire the inputs to bypass the
391     // `SpaceToBatch`, `BatchToSpace` and `Pad` op. E.g, turning
392     // 'SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND ->
393     // BiasAdd' to 'Expand -> Conv2D ->Squeeze -> BiasAdd'.
394 
395     // Connect `expand_op` with the input of `stb_op`.
396     expand_op.setOperand(0, stb_op.input());
397     // Calculate the shape for expand.
398     auto input_shape = stb_op.input().getType().cast<ShapedType>().getShape();
399     SmallVector<int64_t, 4> expand_shape(input_shape.begin(),
400                                          input_shape.end());
401     expand_shape.insert(expand_shape.begin() + expand_axis, 1);
402 
403     auto expand_result_type = RankedTensorType::get(
404         expand_shape, getElementTypeOrSelf(stb_op.input()));
405     expand_op.getResult().setType(expand_result_type);
406 
407     // Update the conv op's output shape.
408     auto bts_output_shape =
409         bts_op.output().getType().cast<ShapedType>().getShape();
410     SmallVector<int64_t, 4> conv_result_shape(bts_output_shape.begin(),
411                                               bts_output_shape.end());
412     conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1);
413     auto conv_result_type = RankedTensorType::get(
414         conv_result_shape, getElementTypeOrSelf(stb_op.input()));
415     op.getResult().setType(conv_result_type);
416 
417     squeeze_op.getResult().setType(bts_op.output().getType());
418 
419     // Connect `biasadd_op` with the output of `squeeze_op`.
420     if (biasadd_op) {
421       biasadd_op.setOperand(0, squeeze_op.output());
422       biasadd_op.output().setType(squeeze_op.output().getType());
423     }
424   } else {
425     if (biasadd_op) biasadd_op.setOperand(0, op.output());
426     op.setOperand(0, stb_op.input());
427     op.getResult().setType(bts_op.getResult().getType());
428   }
429 
430   if (final_op_is_bts) {
431     if (bts_op.input().getDefiningOp<TF::PadOp>()) {
432       bts_op.getResult().replaceAllUsesWith(pad_op.input());
433     } else {
434       bts_op.getResult().replaceAllUsesWith(bts_op.input());
435     }
436   }
437 
438   stb_op.getResult().dropAllUses();
439   return success();
440 }
441 
442 template <typename Conv2dOpTy>
443 llvm::Optional<ArrayAttr>
ExtractDilationsAttrFromBlockShape(Value stb_block_shape,Value bts_block_shape,int64_t expand_axis,PatternRewriter & rewriter)444 ConvertTFDilatedConvOp<Conv2dOpTy>::ExtractDilationsAttrFromBlockShape(
445     Value stb_block_shape, Value bts_block_shape, int64_t expand_axis,
446     PatternRewriter& rewriter) const {
447   ElementsAttr stb_bs_attr, bts_bs_attr;
448   if (!matchPattern(stb_block_shape, m_Constant(&stb_bs_attr)) ||
449       !matchPattern(bts_block_shape, m_Constant(&bts_bs_attr))) {
450     // Returns failure status if block_shape is not a constant.
451     return {};
452   }
453   // Check that the block_shape of `stb_op` and `bts_op` are equal.
454   if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {};
455   for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) {
456     if (stb_bs_attr.getValues<Attribute>()[i] !=
457         bts_bs_attr.getValues<Attribute>()[i])
458       return {};
459   }
460 
461   int dilation_h_factor = -1, dilation_w_factor = -1;
462   // Set dilation factor.
463   if (stb_bs_attr.getNumElements() >= 2) {
464     dilation_h_factor = stb_bs_attr.getValues<APInt>()[0].getSExtValue();
465     dilation_w_factor = stb_bs_attr.getValues<APInt>()[1].getSExtValue();
466   } else if (stb_bs_attr.getNumElements() == 1) {
467     // For 1d conv, `tf.nn.convolution` expands NWC to NHWC format after
468     // `SpaceToBatchND`. Therefore, `block_shape` of `stb_op` only has one
469     // dilation factor of W dim, and dilation factor of H dim is set to 1.
470     if (expand_axis == 1) {
471       // NWC -> NHWC
472       dilation_h_factor = 1;
473       dilation_w_factor = stb_bs_attr.getValues<APInt>()[0].getSExtValue();
474     } else if (expand_axis == 2) {
475       // NHC -> NHWC
476       dilation_h_factor = stb_bs_attr.getValues<APInt>()[0].getSExtValue();
477       dilation_w_factor = 1;
478     }
479   }
480 
481   if (dilation_h_factor == -1 || dilation_w_factor == -1) {
482     return {};
483   }
484 
485   return rewriter.getI64ArrayAttr({1, dilation_h_factor, dilation_w_factor, 1});
486 }
487 
488 }  // namespace TFL
489 }  // namespace mlir
490 
491 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DILATED_CONV_H_
492