1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/utils.h"
17 #include "tensorflow/compiler/xla/xla_data.pb.h"
18 #include "tensorflow/lite/kernels/padding.h"
19
20 namespace mlir::quant {
21
GetDimValue(OpBuilder & builder,Location loc,Value shape_value,int32_t dim)22 Value GetDimValue(OpBuilder &builder, Location loc, Value shape_value,
23 int32_t dim) {
24 Type attribute_type = builder.getI64Type();
25 return builder.create<TF::StridedSliceOp>(
26 loc,
27 RankedTensorType::get(
28 {},
29 shape_value.getType().template cast<ShapedType>().getElementType()),
30 /*input=*/shape_value,
31 /*begin=*/Create1DConstValue<int32_t>(builder, loc, {dim}),
32 /*end=*/Create1DConstValue<int32_t>(builder, loc, {dim + 1}),
33 /*strides=*/Create1DConstValue<int32_t>(builder, loc, {1}),
34 /*begin_mask=*/builder.getIntegerAttr(attribute_type, 0),
35 /*end_mask=*/builder.getIntegerAttr(attribute_type, 0),
36 /*ellipsis_mask=*/builder.getIntegerAttr(attribute_type, 0),
37 /*new_axis_mask=*/builder.getIntegerAttr(attribute_type, 0),
38 /*shrink_axis_mask=*/builder.getIntegerAttr(attribute_type, 1));
39 }
40
41 // Given Value input_size, and known numbers filter_sz, dilation_rate, stride,
42 // calculate padding_low and padding_high for SAME padding.
GetSamePaddingValues(OpBuilder & builder,Location loc,Value input_size,int64_t filter_sz,int64_t dilation_rate,int64_t stride,Value & padding_low,Value & padding_high)43 void GetSamePaddingValues(OpBuilder &builder, Location loc, Value input_size,
44 int64_t filter_sz, int64_t dilation_rate,
45 int64_t stride, Value &padding_low,
46 Value &padding_high) {
47 Value zero = CreateScalarConstValue<int32_t>(builder, loc, 0);
48 Value one = CreateScalarConstValue<int32_t>(builder, loc, 1);
49 Value two = CreateScalarConstValue<int32_t>(builder, loc, 2);
50 Value filter_size = CreateScalarConstValue<int32_t>(builder, loc, filter_sz);
51 Type int32_scalar_type = zero.getType();
52
53 auto scalar_add = [&](Value lhs, Value rhs) {
54 return builder.create<TF::AddOp>(loc, int32_scalar_type, lhs, rhs);
55 };
56 auto scalar_mul = [&](Value lhs, Value rhs) {
57 return builder.create<TF::MulOp>(loc, int32_scalar_type, lhs, rhs);
58 };
59 auto scalar_sub = [&](Value lhs, Value rhs) {
60 return builder.create<TF::SubOp>(loc, int32_scalar_type, lhs, rhs);
61 };
62 auto scalar_div = [&](Value lhs, Value rhs) {
63 return builder.create<TF::DivOp>(loc, int32_scalar_type, lhs, rhs);
64 };
65
66 // effective_filter_size = (filter_size - 1) * dilation_rate + 1
67 Value stride_value = CreateScalarConstValue<int32_t>(builder, loc, stride);
68 Value dilation_rate_value =
69 CreateScalarConstValue<int32_t>(builder, loc, dilation_rate);
70
71 Value effective_filter_size_op = scalar_add(
72 scalar_mul(dilation_rate_value, scalar_sub(filter_size, one)), one);
73
74 // output_size = (input_size + stride - 1) / stride
75 Value output_size = scalar_div(
76 scalar_add(input_size, scalar_sub(stride_value, one)), stride_value);
77 // padding_needed = std::max(
78 // 0,
79 // (output_size - 1) * stride + effective_filter_size - input_size)
80 Value padding_needed = scalar_sub(
81 scalar_add(effective_filter_size_op,
82 scalar_mul(stride_value, scalar_sub(output_size, one))),
83 input_size);
84 padding_needed = builder.create<TF::MaximumOp>(loc, padding_needed, zero);
85 padding_low = scalar_div(padding_needed, two);
86 padding_high = scalar_sub(padding_needed, padding_low);
87 }
88
PadForDynamicShapedInputSamePadding(OpBuilder & builder,Location loc,Value input,Value filter,int8_t input_zp_value,ArrayAttr strides,ArrayAttr dilations,StringAttr conv_padding,Value & padding)89 Value PadForDynamicShapedInputSamePadding(
90 OpBuilder &builder, Location loc, Value input, Value filter,
91 int8_t input_zp_value, ArrayAttr strides, ArrayAttr dilations,
92 StringAttr conv_padding, Value &padding) {
93 ShapedType filter_shape = filter.getType().template cast<ShapedType>();
94 const int stride_h = strides[1].cast<IntegerAttr>().getInt();
95 const int stride_w = strides[2].cast<IntegerAttr>().getInt();
96 const int dilation_h = dilations[1].cast<IntegerAttr>().getInt();
97 const int dilation_w = dilations[2].cast<IntegerAttr>().getInt();
98 const int filter_h = filter_shape.getDimSize(0);
99 const int filter_w = filter_shape.getDimSize(1);
100
101 Value input_shape_value = builder.create<TF::ShapeOp>(
102 loc, RankedTensorType::get({4}, builder.getI32Type()), input);
103 Value input_size_h = GetDimValue(builder, loc, input_shape_value, 1);
104 Value pad_h_low, pad_h_high;
105 GetSamePaddingValues(builder, loc, input_size_h, filter_h, dilation_h,
106 stride_h, pad_h_low, pad_h_high);
107 Value input_size_w = GetDimValue(builder, loc, input_shape_value, 2);
108 Value pad_w_low, pad_w_high;
109 GetSamePaddingValues(builder, loc, input_size_w, filter_w, dilation_w,
110 stride_w, pad_w_low, pad_w_high);
111 padding = CreateConstValue<int32_t>(builder, loc, {2, 2}, {0, 0, 0, 0});
112 Value zero = CreateScalarConstValue(builder, loc, 0);
113 Value zero_rank1 = CreateConstValue<int32_t>(builder, loc, {1}, {0});
114 auto reshape_op = [&](Value value, const SmallVector<int64_t> &shape) {
115 const int64_t rank = shape.size();
116 return builder.create<TF::ReshapeOp>(
117 loc, RankedTensorType::get(shape, builder.getI32Type()), value,
118 CreateConstValue<int64_t>(builder, loc, {rank}, shape));
119 };
120 auto scalar_to_rank1 = [&](Value value) { return reshape_op(value, {1}); };
121 Value temp_padding_rank1 = builder.create<TF::ConcatOp>(
122 loc, RankedTensorType::get({8}, builder.getI32Type()), zero,
123 ArrayRef<Value>({zero_rank1, zero_rank1, scalar_to_rank1(pad_h_low),
124 scalar_to_rank1(pad_h_high), scalar_to_rank1(pad_w_low),
125 scalar_to_rank1(pad_w_high), zero_rank1, zero_rank1}));
126 Value temp_padding = reshape_op(temp_padding_rank1, {4, 2});
127 return builder.create<TF::PadV2Op>(
128 loc, input.getType(), input, temp_padding,
129 CreateScalarConstValue<int8_t>(builder, loc, input_zp_value));
130 }
131
132 // If input spatial sizes are dynamic (unknown) and padding is same, add ops to
133 // dynamically calculate padding size and add input_zp value Pad op with the
134 // padding.
135 // Otherwise, calculates padding with known numbers, and only for non-zero
136 // padding (input_zp != 0), adds Pad op before convolution.
CalculatePaddingAndPadIfNeeded(OpBuilder & builder,Location loc,Value input,Value filter,int8_t input_zp_value,ArrayAttr strides,ArrayAttr dilations,StringAttr conv_padding,ArrayAttr explicit_paddings,Value & padding)137 Value CalculatePaddingAndPadIfNeeded(
138 OpBuilder &builder, Location loc, Value input, Value filter,
139 int8_t input_zp_value, ArrayAttr strides, ArrayAttr dilations,
140 StringAttr conv_padding, ArrayAttr explicit_paddings, Value &padding) {
141 ShapedType input_shape = input.getType().template cast<ShapedType>();
142
143 if (conv_padding.strref().equals("SAME") &&
144 (input_shape.isDynamicDim(1) || input_shape.isDynamicDim(2))) {
145 return PadForDynamicShapedInputSamePadding(
146 builder, loc, input, filter, input_zp_value, strides, dilations,
147 conv_padding, padding);
148 }
149
150 ShapedType filter_shape = filter.getType().template cast<ShapedType>();
151
152 int padding_h_before, padding_h_after, padding_w_before, padding_w_after;
153 if (conv_padding.strref().equals("EXPLICIT")) {
154 if (explicit_paddings.size() != 8) {
155 emitError(loc, "explicit_paddings are expected to be 8-element arrays");
156 return {};
157 }
158 padding_h_before = explicit_paddings[2].cast<IntegerAttr>().getInt();
159 padding_h_after = explicit_paddings[3].cast<IntegerAttr>().getInt();
160 padding_w_before = explicit_paddings[4].cast<IntegerAttr>().getInt();
161 padding_w_after = explicit_paddings[5].cast<IntegerAttr>().getInt();
162 } else if (conv_padding.strref().equals("VALID")) {
163 padding_h_before = 0;
164 padding_h_after = 0;
165 padding_w_before = 0;
166 padding_w_after = 0;
167 } else {
168 // conv_padding is "SAME".
169 int output_height, output_width;
170 const int stride_h = strides[1].cast<IntegerAttr>().getInt();
171 const int stride_w = strides[2].cast<IntegerAttr>().getInt();
172 const int dilation_h = dilations[1].cast<IntegerAttr>().getInt();
173 const int dilation_w = dilations[2].cast<IntegerAttr>().getInt();
174 TfLitePaddingValues padding_values = tflite::ComputePaddingHeightWidth(
175 stride_h, stride_w, dilation_h, dilation_w,
176 /*in_height=*/input_shape.getDimSize(1),
177 /*in_width=*/input_shape.getDimSize(2),
178 /*filter_height=*/filter_shape.getDimSize(0),
179 /*filter_width=*/filter_shape.getDimSize(1), kTfLitePaddingSame,
180 &output_height, &output_width);
181 padding_h_before = padding_values.height;
182 padding_h_after = padding_values.height + padding_values.height_offset;
183 padding_w_before = padding_values.width;
184 padding_w_after = padding_values.width + padding_values.width_offset;
185 }
186
187 if (input_zp_value == 0 || (padding_h_before == 0 && padding_h_after == 0 &&
188 padding_w_before == 0 && padding_w_after == 0)) {
189 padding = CreateConstValue<int32_t>(
190 builder, loc, {2, 2},
191 {padding_h_before, padding_h_after, padding_w_before, padding_w_after});
192 return input;
193 }
194 padding = CreateConstValue<int32_t>(builder, loc, {2, 2}, {0, 0, 0, 0});
195
196 Value temp_padding =
197 CreateConstValue<int32_t>(builder, loc, {4, 2},
198 {0, 0, padding_h_before, padding_h_after,
199 padding_w_before, padding_w_after, 0, 0});
200 SmallVector<int64_t> output_shape(input_shape.getShape().begin(),
201 input_shape.getShape().end());
202 output_shape[1] += padding_h_before + padding_h_after;
203 output_shape[2] += padding_w_before + padding_w_after;
204
205 return builder.create<TF::PadV2Op>(
206 loc, RankedTensorType::get(output_shape, builder.getI8Type()), input,
207 temp_padding,
208 CreateScalarConstValue<int8_t>(builder, loc, input_zp_value));
209 }
210
211 } // namespace mlir::quant
212