/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"

#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace {

// We implement bilinear interpolation by upsampling followed by convolution.
// The basic idea is as follows. To scale from NxN to RxR:
//
//    1. S := (N - 1) / gcd(N-1, R-1)
//    2. k := (R - 1) / gcd(N-1, R-1)
//    3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
//
// For example, to scale from 7x7 -> 15x15:
//
//    1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
//    2. k := (15-1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
//    3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)

//
// The 7x7 -> 15x15 case is much too large to write out in full as an
// example. The smallest interesting example is 3x3 -> 4x4.
//
// S := 2
// k := 3
//
//   00 03 06    00 00 00 00 00 00 00 00 00 00 00    00 02 04 06
//   09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12
//   18 21 24    00 00 00 00 00 03 00 00 06 00 00    12 14 16 18
//               00 00 00 00 00 00 00 00 00 00 00    18 20 22 24
//               00 00 00 00 00 00 00 00 00 00 00
//               00 00 09 00 00 12 00 00 15 00 00
//               00 00 00 00 00 00 00 00 00 00 00
//               00 00 00 00 00 00 00 00 00 00 00
//               00 00 18 00 00 21 00 00 24 00 00
//               00 00 00 00 00 00 00 00 00 00 00
//               00 00 00 00 00 00 00 00 00 00 00
//
// with the following convolutional kernel, with stride [2, 2]:
//         1 2 3 2 1
//         2 4 6 4 2
// 1/9  *  3 6 9 6 3
//         2 4 6 4 2
//         1 2 3 2 1
// Note that the convolution kernel matrix is separable, so instead of a single
// 2D convolution we can apply two consecutive 1D kernels of length 2k-1, one
// along each axis.
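// For instance, applying the 1D kernel 1/3 * [1 2 3 2 1] with stride 2 to the
// dilated row that carries the first input row,
//   00 00 00 00 00 03 00 00 06 00 00,
// yields 00 02 04 06, the first row of the output above.
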
// Computes the size of the convolutional kernel and stride to use when
// resizing from in_size to out_size.
struct ResizeConvolutionDims {
  // Size of the kernel to use.
  std::vector<int64_t> kernel_size;  // k

  // Stride of the convolution to use.
  std::vector<int64_t> stride;  // S
};

ResizeConvolutionDims ComputeResizeConvolutionParameters(
    absl::Span<const int64_t> in_size, absl::Span<const int64_t> out_size,
    bool align_corners) {
  CHECK_EQ(in_size.size(), out_size.size());
  int num_spatial_dims = in_size.size();
  ResizeConvolutionDims dims;
  dims.kernel_size.resize(num_spatial_dims);
  dims.stride.resize(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] == 1) {
      // We must handle input size 1 specially because XLA convolution does
      // not allow stride 0.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else if (out_size[i] == 1) {
      // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
      // entry before resizing.
      dims.stride[i] = dims.kernel_size[i] = 1;
    } else {
      // The scaling factor changes depending on the alignment of corners.
      const int64_t in_size_factor =
          align_corners ? in_size[i] - 1 : in_size[i];
      const int64_t out_size_factor =
          align_corners ? out_size[i] - 1 : out_size[i];

      int64_t gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
                                  static_cast<uint64>(out_size_factor));
      dims.stride[i] = in_size_factor / gcd;
      dims.kernel_size[i] = out_size_factor / gcd;
    }
  }
  return dims;
}
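
// For example, ComputeResizeConvolutionParameters for a 3x3 -> 4x4 resize with
// align_corners=true gives in_size_factor = 2 and out_size_factor = 3 in each
// spatial dimension, so gcd = 1, stride = {2, 2} and kernel_size = {3, 3},
// matching S and k in the example at the top of this file.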

// The upper padding of the input needed by ConvGeneralDilated calls is
// determined by solving two related relationships (assuming rhs_dilation == 1,
// i.e. no kernel dilation):
// 1. dilated_input_dim = lower_padding + upper_padding
//                        + lhs_dilation * (in_size - 1) + 1
// 2. dilated_input_dim = (2 * dims.kernel_size - 1)
//                        + dims.stride * (out_size - 1)
int64_t CalculateUpperPadding(int64_t in_size, int64_t out_size,
                              int64_t kernel_size, int64_t stride) {
  int64_t padding = (2 * kernel_size - 1) + (out_size - 1) * stride -
                    (kernel_size - 1) - 1 - (kernel_size * (in_size - 1));

  return padding;
}
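
// Numeric check for CalculateUpperPadding: with in_size = 3, out_size = 4,
// kernel_size = 3 and stride = 2 (the 1x3 -> 1x4 picture further below),
//   padding = (2*3 - 1) + (4 - 1)*2 - (3 - 1) - 1 - 3*(3 - 1) = 2,
// and the dilated-and-padded width becomes 2 + 2 + 3*(3 - 1) + 1 = 11, the
// length of the padded row shown in that picture.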

// Form a 2D convolution kernel like:
//       1 2 3 2 1
//       2 4 6 4 2
// 1/9 * 3 6 9 6 3
//       2 4 6 4 2
//       1 2 3 2 1
// by multiplying two 1D kernels of the form:
// 1/3 * [1 2 3 2 1]
// If the 2D kernel would be very large, the 1D kernel can instead be applied
// once along each dimension, since the kernel is symmetric along all axes;
// this reduces the computational cost.
xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder,
                                xla::PrimitiveType type, int64_t n) {
  std::vector<float> kernel(n * 2 - 1);
  for (int64_t i = 0; i < n; ++i) {
    float v = (i + 1.0f) / n;
    kernel[i] = v;
    kernel[n * 2 - 2 - i] = v;
  }
  return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
}
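
// For n = 3, MakeBilinear1DKernel produces 1/3 * [1 2 3 2 1], i.e.
// {1/3, 2/3, 1, 2/3, 1/3}.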

// Unlike the bilinear kernel, which is triangular, the nearest neighbor
// kernel is a square. For example, a 1D kernel with n=3 would look like
//   [0 1 1 1 0]
// and n=4 would look like
//   [0 0 1 1 1 1 0].
// Note that in the second case, the kernel is not symmetric and we default
// to the right (because an existing non TPU kernel
// for nearest neighbor resize already chose to default to the right,
// so we want to be consistent).
xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder,
                                       xla::PrimitiveType type, int64_t n) {
  std::vector<float> kernel(n * 2 - 1, 0.0f);
  std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f);

  return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
}
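
// The std::fill above writes 1.0f to indices [n/2, 3n/2): for n = 3 that is
// indices 1..3, giving [0 1 1 1 0]; for n = 4 it is indices 2..5, giving
// [0 0 1 1 1 1 0], as described in the comment above.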

// Kernels whose spatial element count (kernel_size[0] * kernel_size[1])
// reaches 16 are considered expensive, and the 1D kernel should be applied to
// each dimension independently instead.
const int64_t kMax2DKernelSize = 16;

xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder,
                                   xla::PrimitiveType type,
                                   absl::Span<const int64_t> kernel_size,
                                   int64_t channels, bool is_kernel_bilinear) {
  auto make_kernel_func =
      is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;

  std::vector<int64_t> depthwise_kernel_sizes = {
      (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1};
  auto depthwise_kernel =
      xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]),
                          depthwise_kernel_sizes, /*broadcast_dimensions=*/{1});

  return xla::Mul(depthwise_kernel,
                  make_kernel_func(builder, type, kernel_size[0]),
                  /*broadcast_dimensions=*/{0});
}
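
// MakeGeneralResizeKernel returns a kernel of shape
// [2*kernel_size[0] - 1, 2*kernel_size[1] - 1, channels, 1]: the 1D kernel for
// dimension 1 is broadcast along dimension 1 and then multiplied by the 1D
// kernel for dimension 0 along dimension 0, which is exactly the outer product
// that forms the 2D kernel pictured above (e.g. the 1/9 matrix for
// kernel_size = {3, 3}).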

xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder,
                                        xla::PrimitiveType type,
                                        absl::Span<const int64_t> kernel_size,
                                        int64_t channels, int64_t dim,
                                        bool is_kernel_bilinear) {
  auto make_kernel_func =
      is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;

  std::vector<int64_t> depthwise_kernel_sizes = {
      dim == 0 ? (2 * kernel_size[0] - 1) : 1,
      dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1};
  return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]),
                             depthwise_kernel_sizes,
                             /*broadcast_dimensions=*/{dim});
}

xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder,
                                      const xla::XlaOp& input,
                                      int32_t spatial_dimensions_offset,
                                      absl::Span<const int64_t> in_size,
                                      absl::Span<const int64_t> out_size) {
  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  auto broadcast_shape_or_status = builder->GetShape(input);
  if (!broadcast_shape_or_status.ok()) {
    return builder->ReportError(broadcast_shape_or_status.status());
  }
  xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie();
  for (int32_t i = 0; i < in_size.size(); ++i) {
    if (in_size[i] == 1 && out_size[i] > 1) {
      broadcast_shape.set_dimensions(spatial_dimensions_offset + i,
                                     out_size[i]);
    }
  }
  return xla::BroadcastInDim(input, broadcast_shape.dimensions(),
                             /*broadcast_dimensions=*/{0, 1, 2, 3});
}

xla::XlaOp ResizeUsingDilationAndConvolution(
    xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type,
    const int num_spatial_dims, absl::Span<const int64_t> in_size,
    absl::Span<const int64_t> out_size, const int64_t channels,
    const bool align_corners, bool is_kernel_bilinear) {
  // Picture for a 1x3 to 1x4 bilinear resize:
  // stride = 2, kernel size = 3
  // Input:
  //   3 6 9
  // Input with dilation and padding:
  //   0 0 3 0 0 6 0 0 9 0 0
  // Convolution kernel:
  //   1/3 * [1 2 3 2 1]
  // Output:
  //   3 5 7 9
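  //
  // For example, the second output value comes from the window [3 0 0 6 0]:
  //   (1*3 + 2*0 + 3*0 + 2*6 + 1*0) / 3 = 5.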
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(1 + i);
    dimension_numbers.add_output_spatial_dimensions(1 + i);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);

  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, out_size, align_corners);

  if (dims.kernel_size[0] * dims.kernel_size[1] >
      kMax2DKernelSize * kMax2DKernelSize) {
    BroadcastOptimizationRemark(
        XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
        absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
        .IgnoreError();
  }

  xla::XlaOp output;

  // Concatenation and padding below currently assumes num_spatial_dims is 2 to
  // prevent needless code complexity.
  CHECK_EQ(num_spatial_dims, 2)
      << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
  std::vector<int64_t> upper_padding(num_spatial_dims);
  for (int i = 0; i < num_spatial_dims; ++i) {
    upper_padding[i] = dims.kernel_size[i] - 1;
  }
  xla::XlaOp input_data = input;

  if (!align_corners) {
    // When Tensorflow does not align_corners, the resize indexing can access
    // beyond the upper bound and is instead clamped to prevent out of bounds
    // reads. This is conceptually the same as extending the edges of the input.
    // We emulate this by copying the last row/column of the input.
    // Calculate what padding would be needed then determine how far to extend
    // the border before lhs dilation.
    std::vector<int64_t> num_extended(num_spatial_dims);
    upper_padding[0] = CalculateUpperPadding(
        in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
    upper_padding[1] = CalculateUpperPadding(
        in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
    num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
    num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);

    const int64_t batch_dim_size =
        builder->GetShape(input).ValueOrDie().dimensions(0);
    if (num_extended[0] > 0) {
      auto slice = xla::Slice(
          input_data, {0, in_size[0] - 1, 0, 0},
          {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
      for (int i = 0; i < num_extended[0]; i++) {
        input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
      }
    }

    if (num_extended[1] > 0) {
      auto slice = xla::Slice(
          input_data, {0, 0, in_size[1] - 1, 0},
          {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
          {1, 1, 1, 1});
      for (int i = 0; i < num_extended[1]; i++) {
        input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
      }
    }

    // Setting in_size to (in_size + num_extended) due to the above Slice and
    // ConcatInDim. Recalculate needed padding after the above Slice/Concat.
    upper_padding[0] =
        CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
                              dims.kernel_size[0], dims.stride[0]);
    upper_padding[1] =
        CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
                              dims.kernel_size[1], dims.stride[1]);
  }

  // Split convolutions into independent dimensions if they would be a very
  // large kernel or if one or more of the dimensions are already equal.
  bool decompose_resize =
      in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
      dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
  if (!decompose_resize) {
    xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
                                                channels, is_kernel_bilinear);
    output =
        xla::ConvGeneralDilated(input_data, kernel, dims.stride,
                                /*padding=*/
                                {{dims.kernel_size[0] - 1, upper_padding[0]},
                                 {dims.kernel_size[1] - 1, upper_padding[1]}},
                                /*lhs_dilation=*/dims.kernel_size,
                                /*rhs_dilation=*/{1, 1}, dimension_numbers,
                                /*feature_group_count=*/channels);
  } else {
    output = input_data;
    if (in_size[0] != out_size[0]) {
      xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
          builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
      output = xla::ConvGeneralDilated(
          output, kernel0, {dims.stride[0], 1},
          /*padding=*/
          {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
          /*lhs_dilation=*/{dims.kernel_size[0], 1},
          /*rhs_dilation=*/{1, 1}, dimension_numbers,
          /*feature_group_count=*/channels);
    }

    if (in_size[1] != out_size[1]) {
      xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
          builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
      output = xla::ConvGeneralDilated(
          output, kernel1, {1, dims.stride[1]},
          /*padding=*/
          {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
          /*lhs_dilation=*/{1, dims.kernel_size[1]},
          /*rhs_dilation=*/{1, 1}, dimension_numbers,
          /*feature_group_count=*/channels);
    }
  }

  // Add broadcasts to handle expanding from a size == 1 dimension to a
  // size > 1 dimension.
  return BroadcastSpatialDimensions(
      builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size);
}

xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(
    xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type,
    const int num_spatial_dims, absl::Span<const int64_t> in_size,
    absl::Span<const int64_t> grad_size, const int64_t channels,
    const bool align_corners, bool is_kernel_bilinear) {
  ResizeConvolutionDims dims =
      ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);

  // To form the backward convolution, we keep the kernel unchanged (it is
  // already symmetric) and swap the roles of strides and LHS dilation.
  xla::ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(0);
  dimension_numbers.set_output_batch_dimension(0);
  dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 1);
    dimension_numbers.add_output_spatial_dimensions(i + 1);
    dimension_numbers.add_kernel_spatial_dimensions(i);
  }
  dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
  dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
  xla::XlaOp output;
  if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
    xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
                                                channels, is_kernel_bilinear);

    // Broadcast the input kernel where the forward op expanded from a
    // size == 1 dimension to a size > 1 dimension. This has the effect of
    // summing the gradient contributions in that dimension.
    kernel = BroadcastSpatialDimensions(
        builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size);

    output = xla::ConvGeneralDilated(
        grad, kernel, /*window_strides=*/dims.kernel_size,
        /*padding=*/
        {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
         {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
        /*lhs_dilation=*/dims.stride,
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);
  } else {
    xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
        builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
    xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
        builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);

    // Broadcast the input kernel where the forward op expanded from a
    // size == 1 dimension to a size > 1 dimension. This has the effect of
    // summing the gradient contributions in that dimension.
    if (in_size[0] == 1 && grad_size[0] > 1) {
      kernel0 = BroadcastSpatialDimensions(builder, kernel0,
                                           /*spatial_dimensions_offset=*/0, {1},
                                           {grad_size[0]});
    }
    if (in_size[1] == 1 && grad_size[1] > 1) {
      kernel1 = BroadcastSpatialDimensions(builder, kernel1,
                                           /*spatial_dimensions_offset=*/0,
                                           in_size, grad_size);
    }

    output = xla::ConvGeneralDilated(
        grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
        /*padding=*/
        {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
        /*lhs_dilation=*/{dims.stride[0], 1},
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);

    output = xla::ConvGeneralDilated(
        output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
        /*padding=*/
        {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
        /*lhs_dilation=*/{1, dims.stride[1]},
        /*rhs_dilation=*/{1, 1}, dimension_numbers,
        /*feature_group_count=*/channels);
  }

  // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
  // Opposite of the slice performed by the forward op.
  xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
  bool pad_output = false;
  for (int i = 0; i < num_spatial_dims; ++i) {
    if (in_size[i] > 1 && grad_size[i] == 1) {
      pad_output = true;
      padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
    }
  }
  if (pad_output) {
    output = xla::Pad(output, xla::Zero(builder, type), padding);
  }
  return output;
}

void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_,
                    bool half_pixel_centers_, bool is_kernel_bilinear_) {
  // We implement bilinear interpolation and nearest neighbor with a Gather op.
  // For each output pixel, we gather the necessary slices of the input.
  // We then construct the weights that are necessary to calculate the weighted
  // sum for each output pixel. We do this with a DotGeneral op.
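  //
  // Shape walkthrough for an input of shape [B, H, W, C] resized to
  // [out_h, out_w]: the gathered start indices have shape [out_h, out_w, 2];
  // the Gather below produces [B, h_span, w_span, C, out_h, out_w]; the
  // weights have shapes [out_h, h_span] and [out_w, w_span], and their outer
  // product (the inner DotGeneral) has shape [out_w, w_span, out_h, h_span];
  // the outer DotGeneral contracts the span dimensions and batches over
  // (out_h, out_w), giving [out_h, out_w, B, C], which the final Transpose
  // rearranges to [B, out_h, out_w, C].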
  xla::XlaBuilder* b = ctx->builder();

  TensorShape input_shape = ctx->InputShape(0);
  OP_REQUIRES(ctx, input_shape.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional",
                                      input_shape.DebugString()));
  // First dimension always assumed to be batch
  const int64_t batch = input_shape.dim_size(0);
  std::vector<int64_t> in_size = {input_shape.dim_size(1),
                                  input_shape.dim_size(2)};
  // Last/4th dimension always assumed to be num channels
  const int64_t channels = input_shape.dim_size(3);
  OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
              errors::InvalidArgument("input size must be positive, got [",
                                      in_size[0], ",", in_size[1], "]"));

  std::vector<int64_t> out_size;
  OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
  OP_REQUIRES(ctx, out_size.size() == 2,
              errors::InvalidArgument("output size must be length 2, got ",
                                      out_size.size()));
  OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
              errors::InvalidArgument("output size must be positive, got [",
                                      out_size[0], ",", out_size[1], "]"));

  xla::XlaOp input = ctx->Input(0);
  xla::PrimitiveType input_type = ctx->input_xla_type(0);
  xla::PrimitiveType original_input_type = input_type;
  if (is_kernel_bilinear_ || xla::primitive_util::IsIntegralType(input_type)) {
    input = xla::ConvertElementType(input, xla::F32);
    input_type = xla::F32;
  }
  DataType output_dtype =
      EncodePrimitiveTypeAsDataType(input_type).ValueOrDie();

  xla::XlaOp scalar_one_op =
      xla::ConvertElementType(xla::ConstantR0(b, 1), input_type);
  xla::XlaOp scalar_half_op =
      xla::ConvertElementType(xla::ConstantR0(b, 0.5), input_type);
  xla::XlaOp scalar_zero_op =
      xla::ConvertElementType(xla::ConstantR0(b, 0), input_type);
  float h_scale;
  if (align_corners_ && out_size[0] > 1) {
    h_scale = (in_size[0] - 1) / static_cast<float>(out_size[0] - 1);
  } else {
    h_scale = in_size[0] / static_cast<float>(out_size[0]);
  }
  xla::XlaOp h_span_start =
      xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {out_size[0]}), 0);
  if (half_pixel_centers_) {
    h_span_start = xla::Add(h_span_start, scalar_half_op);
  }
  xla::XlaOp h_scale_op =
      xla::ConvertElementType(xla::ConstantR0(b, h_scale), input_type);
  xla::XlaOp h_sample_f = xla::Mul(h_span_start, h_scale_op);

  if (is_kernel_bilinear_) {
    h_span_start = xla::Sub(h_sample_f, scalar_one_op);
    if (half_pixel_centers_) {
      h_span_start = xla::Sub(h_span_start, scalar_half_op);
    }
    h_span_start = xla::Ceil(h_span_start);
  } else {
    h_span_start =
        align_corners_ ? xla::Round(h_sample_f) : xla::Floor(h_sample_f);
  }
  const int64_t h_span_size =
      is_kernel_bilinear_ ? std::min(static_cast<int64_t>(3), in_size[0]) : 1;
  xla::XlaOp h_upper_bound = xla::ConvertElementType(
      xla::ConstantR0(b, in_size[0] - h_span_size), input_type);
  if (!is_kernel_bilinear_ && !half_pixel_centers_) {
    h_span_start = xla::Min(h_span_start, h_upper_bound);
  } else {
    h_span_start = xla::Clamp(scalar_zero_op, h_span_start, h_upper_bound);
  }
  xla::XlaOp broadcasted_h_span_start =
      xla::BroadcastInDim(h_span_start, {out_size[0], out_size[1], 1}, {0});

  float w_scale;
  if (align_corners_ && out_size[1] > 1) {
    w_scale = (in_size[1] - 1) / static_cast<float>(out_size[1] - 1);
  } else {
    w_scale = in_size[1] / static_cast<float>(out_size[1]);
  }
  xla::XlaOp w_span_start =
      xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {out_size[1]}), 0);
  if (half_pixel_centers_) {
    w_span_start = xla::Add(w_span_start, scalar_half_op);
  }
  xla::XlaOp w_scale_op =
      xla::ConvertElementType(xla::ConstantR0(b, w_scale), input_type);
  xla::XlaOp w_sample_f = xla::Mul(w_span_start, w_scale_op);
  if (is_kernel_bilinear_) {
    w_span_start = xla::Sub(w_sample_f, scalar_one_op);
    if (half_pixel_centers_) {
      w_span_start = xla::Sub(w_span_start, scalar_half_op);
    }
    w_span_start = xla::Ceil(w_span_start);
  } else {
    w_span_start =
        align_corners_ ? xla::Round(w_sample_f) : xla::Floor(w_sample_f);
  }
  const int64_t w_span_size =
      is_kernel_bilinear_ ? std::min(static_cast<int64_t>(3), in_size[1]) : 1;
  xla::XlaOp w_upper_bound = xla::ConvertElementType(
      xla::ConstantR0(b, in_size[1] - w_span_size), input_type);
  if (!is_kernel_bilinear_ && !half_pixel_centers_) {
    w_span_start = xla::Min(w_span_start, w_upper_bound);
  } else {
    w_span_start = xla::Clamp(scalar_zero_op, w_span_start, w_upper_bound);
  }
  xla::XlaOp broadcasted_w_span_start =
      xla::BroadcastInDim(w_span_start, {out_size[0], out_size[1], 1}, {1});

  xla::XlaOp concatted = xla::ConvertElementType(
      xla::ConcatInDim(b, {broadcasted_h_span_start, broadcasted_w_span_start},
                       2),
      xla::S32);

  absl::InlinedVector<int64_t, 4> slice_sizes = {batch, h_span_size,
                                                 w_span_size, channels};
  xla::GatherDimensionNumbers dimension_numbers;
  dimension_numbers.add_offset_dims(0);
  dimension_numbers.add_offset_dims(1);
  dimension_numbers.add_offset_dims(2);
  dimension_numbers.add_offset_dims(3);
  dimension_numbers.add_start_index_map(1);
  dimension_numbers.add_start_index_map(2);
  dimension_numbers.set_index_vector_dim(2);
  input = xla::Gather(input, concatted, dimension_numbers, slice_sizes, false);

  xla::XlaOp w_weight;
  if (is_kernel_bilinear_) {
    xla::XlaOp w_sub = xla::Sub(w_span_start, w_sample_f);
    w_sub = xla::BroadcastInDim(w_sub, {out_size[1], w_span_size}, {0});
    xla::XlaOp w_offset =
        xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {w_span_size}), 0);
    xla::XlaOp w_kernel_pos = xla::Add(w_sub, w_offset, {1});
    if (half_pixel_centers_) {
      w_kernel_pos = xla::Add(w_kernel_pos, scalar_half_op);
    }
    w_weight = xla::Max(scalar_zero_op,
                        xla::Sub(scalar_one_op, xla::Abs(w_kernel_pos)));
  } else {
    w_weight = xla::Broadcast(scalar_one_op, {out_size[1], w_span_size});
  }
  xla::XlaOp w_weight_sum = xla::Reduce(
      w_weight, scalar_zero_op, *ctx->GetOrCreateAdd(output_dtype), {1});
  w_weight = xla::Div(w_weight, w_weight_sum, {0});

  xla::XlaOp h_weight;
  if (is_kernel_bilinear_) {
    xla::XlaOp h_sub = xla::Sub(h_span_start, h_sample_f);
    h_sub = xla::BroadcastInDim(h_sub, {out_size[0], h_span_size}, {0});
    xla::XlaOp h_offset =
        xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {h_span_size}), 0);
    xla::XlaOp h_kernel_pos = xla::Add(h_sub, h_offset, {1});
    if (half_pixel_centers_) {
      h_kernel_pos = xla::Add(h_kernel_pos, scalar_half_op);
    }
    h_weight = xla::Max(scalar_zero_op,
                        xla::Sub(scalar_one_op, xla::Abs(h_kernel_pos)));
  } else {
    h_weight = xla::Broadcast(scalar_one_op, {out_size[0], h_span_size});
  }
  xla::XlaOp h_weight_sum = xla::Reduce(
      h_weight, scalar_zero_op, *ctx->GetOrCreateAdd(output_dtype), {1});
  h_weight = xla::Div(h_weight, h_weight_sum, {0});

  xla::DotDimensionNumbers dot_dnum;
  dot_dnum.add_lhs_contracting_dimensions(3);
  dot_dnum.add_lhs_contracting_dimensions(1);
  dot_dnum.add_rhs_contracting_dimensions(1);
  dot_dnum.add_rhs_contracting_dimensions(2);
  dot_dnum.add_lhs_batch_dimensions(2);
  dot_dnum.add_lhs_batch_dimensions(0);
  dot_dnum.add_rhs_batch_dimensions(4);
  dot_dnum.add_rhs_batch_dimensions(5);
  input = xla::DotGeneral(
      xla::DotGeneral(w_weight, h_weight, xla::DotDimensionNumbers()), input,
      dot_dnum);

  absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
  input = xla::Transpose(input, perm);

  if (!is_kernel_bilinear_ && original_input_type != input_type) {
    input = xla::ConvertElementType(input, original_input_type);
  }
  ctx->SetOutput(0, input);
}
}  // namespace

ResizeNearestNeighborOp::ResizeNearestNeighborOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(ctx, !half_pixel_centers_ || !align_corners_,
              errors::Unimplemented("If half_pixel_centers is True, "
                                    "align_corners must be False."));
}

void ResizeNearestNeighborOp::Compile(XlaOpKernelContext* ctx) {
  GeneralCompile(ctx, align_corners_, half_pixel_centers_, is_kernel_bilinear_);
}

REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"),
                ResizeNearestNeighborOp);

ResizeBilinearOp::ResizeBilinearOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
  OP_REQUIRES(ctx, !half_pixel_centers_ || !align_corners_,
              errors::Unimplemented("If half_pixel_centers is True, "
                                    "align_corners must be False."));
}

void ResizeBilinearOp::Compile(XlaOpKernelContext* ctx) {
  GeneralCompile(ctx, align_corners_, half_pixel_centers_, is_kernel_bilinear_);
}

REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"),
                ResizeBilinearOp);

ResizeBilinearGradOp::ResizeBilinearGradOp(OpKernelConstruction* ctx)
    : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));

  if (!align_corners_ || half_pixel_centers_) {
    if (ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
      // Use light outside compilation on GPU only.
      fallback_tf_kernel_.emplace(ctx);
      return;
#endif
    }

    OP_REQUIRES(ctx, false,
                errors::Unimplemented(
                    "ResizeBilinearGrad with align_corners=False or "
                    "half_pixel_centers=True is not yet implemented"));
  }

  DataType output_dtype;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
  OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
}

void ResizeBilinearGradOp::Compile(XlaOpKernelContext* ctx) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  if (fallback_tf_kernel_.has_value()) {
    fallback_tf_kernel_->Compile(ctx);
    return;
  }
#endif

  xla::XlaBuilder* b = ctx->builder();
  TensorShape input_shape = ctx->InputShape(1);
  OP_REQUIRES(ctx, input_shape.dims() == 4,
              errors::InvalidArgument("input must be 4-dimensional",
                                      input_shape.DebugString()));
  const int64_t batch = input_shape.dim_size(0);
  std::vector<int64_t> in_size = {input_shape.dim_size(1),
                                  input_shape.dim_size(2)};
  const int64_t channels = input_shape.dim_size(3);
  OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
              errors::InvalidArgument("input size must be positive, got [",
                                      in_size[0], ",", in_size[1], "]"));

  TensorShape grad_shape = ctx->InputShape(0);
  OP_REQUIRES(ctx, grad_shape.dims() == 4,
              errors::InvalidArgument("gradient must be 4-dimensional",
                                      grad_shape.DebugString()));
  const int64_t grad_batch = grad_shape.dim_size(0);
  const std::vector<int64_t> grad_size = {grad_shape.dim_size(1),
                                          grad_shape.dim_size(2)};
  const int64_t grad_channels = grad_shape.dim_size(3);
  OP_REQUIRES(ctx, batch == grad_batch,
              errors::InvalidArgument(
                  "activations and gradients must have the same batch size (",
                  batch, " vs. ", grad_batch, ")"));
  OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
              errors::InvalidArgument("gradient size must be positive, got [",
                                      grad_size[0], ",", grad_size[1], "]"));
  OP_REQUIRES(
      ctx, channels == grad_channels,
      errors::InvalidArgument(
          "activations and gradients must have the same number of channels (",
          channels, " vs. ", grad_channels, ")"));

  const int num_spatial_dims = 2;

  xla::XlaOp grad = ctx->Input(0);

  xla::XlaOp output = grad;
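  // Walk in_size up toward grad_size: when the remaining scale factor splits
  // into exact 2x steps (the k values below are integral and greater than 1),
  // propagate the gradient through an intermediate size of 2*(in_size - 1) + 1
  // per step; otherwise handle the remaining resize with a single
  // ResizeUsingDilationAndConvolutionGradOp call.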
  while (in_size != grad_size) {
    if (in_size[0] != 1 && in_size[1] != 1) {
      std::vector<float> k = {
          (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
          (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
      if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
          k[0] > 1 && k[1] > 1) {
        std::vector<int64_t> next_grad_size = {(in_size[0] - 1) * 2 + 1,
                                               (in_size[1] - 1) * 2 + 1};
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size,
            channels, align_corners_, true);
        grad = output;
        in_size = next_grad_size;
      } else {
        output = ResizeUsingDilationAndConvolutionGradOp(
            b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
            align_corners_, true);
        in_size = grad_size;
      }
    } else {
      output = ResizeUsingDilationAndConvolutionGradOp(
          b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
          align_corners_, true);
      in_size = grad_size;
    }
  }

  output = xla::ConvertElementType(output, output_type_);
  ctx->SetOutput(0, output);
}

REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);

}  // namespace tensorflow