1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h"
16 
17 #include <string>
18 
19 #include "absl/strings/str_format.h"
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/jit/xla_activity.pb.h"
22 #include "tensorflow/compiler/jit/xla_activity_listener.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/type_util.h"
25 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/array4d.h"
29 #include "tensorflow/compiler/xla/client/lib/constants.h"
30 #include "tensorflow/compiler/xla/client/xla_builder.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/framework/kernel_def_builder.h"
34 #include "tensorflow/core/framework/register_types.h"
35 #include "tensorflow/core/lib/math/math_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 
38 namespace tensorflow {
39 namespace {
40 
41 // We implement bilinear interpolation by upsampling followed by convolution.
42 // The basic idea is as follows. To scale from NxN to RxR:
43 //
44 //    1. S := (N - 1) /  gcd(N-1, R-1)
45 //    2. k := (R - 1) /  gcd(N-1, R-1)
46 //    3. Convolution((2k-1)x(2k-1), stride=S, lhs_dilation=k, padding=k-1)
47 //
48 // For example, to scale from 7x7 -> 15x15:
49 //
50 //    1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3
51 //    2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7
52 //    3. Convolution(13x13, stride=3, lhs_dilation=7, padding=6)
53 //
54 //
55 // The 7x7 -> 15x15 case is much too large to write out in full as an
56 // example. The smallest interesting example is 3x3 -> 4x4.
57 //
58 // S := 2
59 // k := 3
60 //
61 // 00 03 06    00 00 00 00 00 00 00 00 00 00 00      00 02 04 06
62 // 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00   -> 06 08 10 12
63 // 18 21 24    00 00 00 00 00 03 00 00 06 00 00      12 14 16 18
64 //             00 00 00 00 00 00 00 00 00 00 00      18 20 22 24
65 //             00 00 00 00 00 00 00 00 00 00 00
66 //             00 00 09 00 00 12 00 00 15 00 00
67 //             00 00 00 00 00 00 00 00 00 00 00
68 //             00 00 00 00 00 00 00 00 00 00 00
69 //             00 00 18 00 00 21 00 00 24 00 00
70 //             00 00 00 00 00 00 00 00 00 00 00
71 //             00 00 00 00 00 00 00 00 00 00 00
72 //
73 // with the following convolutional kernel, with stride [2, 2]:
74 //       1 2 3 2 1
75 //       2 4 6 4 2
76 // 1/9 * 3 6 9 6 3
77 //       2 4 6 4 2
78 //       1 2 3 2 1
79 // Note that the convolution kernel matrix is separable and thus we can instead
80 // apply two consecutive 1D kernels of dimension 2k-1, one along each axis.
81 
82 // Computes the size of the convolutional kernel and stride to use when resizing
83 // from in_size to out_size.
84 struct ResizeConvolutionDims {
85   // Size of the kernel to use.
86   std::vector<int64_t> kernel_size;  // k
87 
88   // Stride of the convolution to use.
89   std::vector<int64_t> stride;  // S
90 };
91 ResizeConvolutionDims ComputeResizeConvolutionParameters(
92     absl::Span<const int64_t> in_size, absl::Span<const int64_t> out_size,
93     bool align_corners) {
94   CHECK_EQ(in_size.size(), out_size.size());
95   int num_spatial_dims = in_size.size();
96   ResizeConvolutionDims dims;
97   dims.kernel_size.resize(num_spatial_dims);
98   dims.stride.resize(num_spatial_dims);
99   for (int i = 0; i < num_spatial_dims; ++i) {
100     if (in_size[i] == 1) {
101       // We must handle input size 1 specially because XLA convolution does
102       // not allow stride 0.
103       dims.stride[i] = dims.kernel_size[i] = 1;
104     } else if (out_size[i] == 1) {
105       // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first
106       // entry before resizing.
107       dims.stride[i] = dims.kernel_size[i] = 1;
108     } else {
109       // The scaling factor changes depending on the alignment of corners.
110       const int64_t in_size_factor =
111           align_corners ? in_size[i] - 1 : in_size[i];
112       const int64_t out_size_factor =
113           align_corners ? out_size[i] - 1 : out_size[i];
114 
115       int64_t gcd = MathUtil::GCD(static_cast<uint64>(in_size_factor),
116                                   static_cast<uint64>(out_size_factor));
117       dims.stride[i] = in_size_factor / gcd;
118       dims.kernel_size[i] = out_size_factor / gcd;
119     }
120   }
121   return dims;
122 }
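// For illustration, a hand-worked trace of the function above (values only,
// nothing here is executed): resizing 7x7 -> 15x15 with align_corners=true
// matches the example in the file header comment.
//
//   ResizeConvolutionDims d =
//       ComputeResizeConvolutionParameters(/*in_size=*/{7, 7},
//                                          /*out_size=*/{15, 15},
//                                          /*align_corners=*/true);
//   // in_size_factor = 6, out_size_factor = 14, gcd = 2
//   // => d.stride == {3, 3}, d.kernel_size == {7, 7}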
123 
124 // The upper padding of the input needed by ConvGeneralDilated calls is
125 // determined by solving two related relationships (assuming no rhs dilation):
126 // 1. dilated_input_dim = lower_padding + upper_padding
127 //                        + lhs_dilation * (in_size - 1) + 1
128 // 2. dilated_input_dim = (2 * dims.kernel_size - 1)
129 //                        + dims.stride * (out_size - 1)
130 int64_t CalculateUpperPadding(int64_t in_size, int64_t out_size,
131                               int64_t kernel_size, int64_t stride) {
132   int64_t padding = (2 * kernel_size - 1) + (out_size - 1) * stride -
133                     (kernel_size - 1) - 1 - (kernel_size * (in_size - 1));
134 
135   return padding;
136 }
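// For illustration, plugging the 3x3 -> 4x4 example from the file header
// comment (in_size = 3, out_size = 4, kernel_size = 3, stride = 2) into the
// formula above gives
//
//   padding = (2*3 - 1) + (4 - 1)*2 - (3 - 1) - 1 - 3*(3 - 1)
//           = 5 + 6 - 2 - 1 - 6 = 2,
//
// which is consistent with the 11-element dilated rows shown there:
//   lower_padding + upper_padding + lhs_dilation*(in_size - 1) + 1 = 2+2+6+1 = 11
//   (2*kernel_size - 1) + stride*(out_size - 1)                    = 5+6     = 11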
137 
138 // Form a 2D convolution kernel like:
139 //       1 2 3 2 1
140 //       2 4 6 4 2
141 // 1/9 * 3 6 9 6 3
142 //       2 4 6 4 2
143 //       1 2 3 2 1
144 // by multiplying two 1D kernels of the form:
145 // 1/3 * [1 2 3 2 1]
146 // If the 2D kernel would be very large, the 1D kernel can instead be applied
147 // once along each dimension; because the 2D kernel is separable, this gives
148 // the same result at a lower computational cost.
149 xla::XlaOp MakeBilinear1DKernel(xla::XlaBuilder* builder,
150                                 xla::PrimitiveType type, int64_t n) {
151   std::vector<float> kernel(n * 2 - 1);
152   for (int64_t i = 0; i < n; ++i) {
153     float v = (i + 1.0f) / n;
154     kernel[i] = v;
155     kernel[n * 2 - 2 - i] = v;
156   }
157   return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
158 }
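// For illustration, the kernel produced for n = 3 (before the final type
// conversion) is the 1D kernel quoted in the comments above:
//
//   MakeBilinear1DKernel(builder, xla::F32, /*n=*/3)
//   // => [1/3, 2/3, 3/3, 2/3, 1/3], i.e. 1/3 * [1 2 3 2 1]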
159 
160 // Unlike the bilinear kernel, which is triangular, the nearest neighbor
161 // kernel is a square. For example, a 1D kernel with n=3 would look like
162 // [0 1 1 1 0]
163 // and n=4 would look like
164 // [0 0 1 1 1 1 0].
165 // Note that in the second case, the kernel is not symmetric and we default
166 // to the right (because an existing non-TPU kernel
167 // for nearest neighbor resize already chose to default to the right,
168 // so we want to be consistent).
169 xla::XlaOp MakeNearestNeighbor1DKernel(xla::XlaBuilder* builder,
170                                        xla::PrimitiveType type, int64_t n) {
171   std::vector<float> kernel(n * 2 - 1, 0.0f);
172   std::fill(&kernel[n / 2], &kernel[(3 * n) / 2], 1.0f);
173 
174   return xla::ConvertElementType(xla::ConstantR1<float>(builder, kernel), type);
175 }
176 
177 // Kernels with more than 16 spatial elements are considered too expensive to
178 // build as a single 2D kernel, so each dimension is resized independently.
179 const int64_t kMax2DKernelSize = 16;
180 
181 xla::XlaOp MakeGeneralResizeKernel(xla::XlaBuilder* builder,
182                                    xla::PrimitiveType type,
183                                    absl::Span<const int64_t> kernel_size,
184                                    int64_t channels, bool is_kernel_bilinear) {
185   auto make_kernel_func =
186       is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;
187 
188   std::vector<int64_t> depthwise_kernel_sizes = {
189       (2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1};
190   auto depthwise_kernel =
191       xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[1]),
192                           depthwise_kernel_sizes, /*broadcast_dimensions=*/{1});
193 
194   return xla::Mul(depthwise_kernel,
195                   make_kernel_func(builder, type, kernel_size[0]),
196                   /*broadcast_dimensions=*/{0});
197 }
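// For illustration, with kernel_size = {3, 3}, channels = C and
// is_kernel_bilinear = true, the result has shape [5, 5, C, 1] and each
// channel slice is the outer product of 1/3 * [1 2 3 2 1] with itself:
//
//         1 2 3 2 1
//         2 4 6 4 2
//   1/9 * 3 6 9 6 3
//         2 4 6 4 2
//         1 2 3 2 1
//
// i.e. exactly the separable 2D kernel shown in the file header comment.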
198 
199 xla::XlaOp MakeGeneralResizeKernelInDim(xla::XlaBuilder* builder,
200                                         xla::PrimitiveType type,
201                                         absl::Span<const int64_t> kernel_size,
202                                         int64_t channels, int64_t dim,
203                                         bool is_kernel_bilinear) {
204   auto make_kernel_func =
205       is_kernel_bilinear ? MakeBilinear1DKernel : MakeNearestNeighbor1DKernel;
206 
207   std::vector<int64_t> depthwise_kernel_sizes = {
208       dim == 0 ? (2 * kernel_size[0] - 1) : 1,
209       dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels, 1};
210   return xla::BroadcastInDim(make_kernel_func(builder, type, kernel_size[dim]),
211                              depthwise_kernel_sizes,
212                              /*broadcast_dimensions=*/{dim});
213 }
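// For illustration, with kernel_size = {3, 5}, channels = C and dim = 0 the
// result has shape [2*3-1, 1, C, 1] = [5, 1, C, 1]; with dim = 1 it has shape
// [1, 9, C, 1]. Only the requested spatial dimension carries the 1D kernel;
// the other spatial dimension is left with extent 1.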
214 
215 xla::XlaOp BroadcastSpatialDimensions(xla::XlaBuilder* builder,
216                                       const xla::XlaOp& input,
217                                       int32_t spatial_dimensions_offset,
218                                       absl::Span<const int64_t> in_size,
219                                       absl::Span<const int64_t> out_size) {
220   // Add broadcasts to handle expanding from a size == 1 dimension to a
221   // size > 1 dimension.
222   auto broadcast_shape_or_status = builder->GetShape(input);
223   if (!broadcast_shape_or_status.ok()) {
224     return builder->ReportError(broadcast_shape_or_status.status());
225   }
226   xla::Shape broadcast_shape = broadcast_shape_or_status.ValueOrDie();
227   for (int32_t i = 0; i < in_size.size(); ++i) {
228     if (in_size[i] == 1 && out_size[i] > 1) {
229       broadcast_shape.set_dimensions(spatial_dimensions_offset + i,
230                                      out_size[i]);
231     }
232   }
233   return xla::BroadcastInDim(input, broadcast_shape.dimensions(),
234                              /*broadcast_dimensions=*/{0, 1, 2, 3});
235 }
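// For illustration, with spatial_dimensions_offset = 1, in_size = {1, 4} and
// out_size = {3, 4}, an input of shape [N, 1, 4, C] is broadcast to
// [N, 3, 4, C]; spatial dimensions whose input extent is already > 1 are left
// unchanged.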
236 
237 xla::XlaOp ResizeUsingDilationAndConvolution(
238     xla::XlaBuilder* builder, const xla::XlaOp& input, xla::PrimitiveType type,
239     const int num_spatial_dims, absl::Span<const int64_t> in_size,
240     absl::Span<const int64_t> out_size, const int64_t channels,
241     const bool align_corners, bool is_kernel_bilinear) {
242   // Picture for a 1x3 to 1x4 bilinear resize:
243   // stride = 2, kernel size = 3
244   // Input:
245   // 3 6 9
246   // Input with dilation and padding:
247   // 0 0 3 0 0 6 0 0 9 0 0
248   // Convolution kernel:
249   // 1/3 * [1 2 3 2 1]
250   // Output:
251   // 3 5 7 9
252   xla::ConvolutionDimensionNumbers dimension_numbers;
253   dimension_numbers.set_input_batch_dimension(0);
254   dimension_numbers.set_output_batch_dimension(0);
255   dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
256   dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
257   for (int i = 0; i < num_spatial_dims; ++i) {
258     dimension_numbers.add_input_spatial_dimensions(1 + i);
259     dimension_numbers.add_output_spatial_dimensions(1 + i);
260     dimension_numbers.add_kernel_spatial_dimensions(i);
261   }
262   dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
263   dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
264 
265   ResizeConvolutionDims dims =
266       ComputeResizeConvolutionParameters(in_size, out_size, align_corners);
267 
268   if (dims.kernel_size[0] * dims.kernel_size[1] >
269       kMax2DKernelSize * kMax2DKernelSize) {
270     BroadcastOptimizationRemark(
271         XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS,
272         absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1]))
273         .IgnoreError();
274   }
275 
276   xla::XlaOp output;
277 
278   // Concatenation and padding below currently assume num_spatial_dims is 2 to
279   // prevent needless code complexity.
280   CHECK_EQ(num_spatial_dims, 2)
281       << "ResizeUsingDilationAndConvolution pads only 2 dimensions currently.";
282   std::vector<int64_t> upper_padding(num_spatial_dims);
283   for (int i = 0; i < num_spatial_dims; ++i) {
284     upper_padding[i] = dims.kernel_size[i] - 1;
285   }
286   xla::XlaOp input_data = input;
287 
288   if (!align_corners) {
289     // When TensorFlow does not align_corners, the resize indexing can access
290     // beyond the upper bound and is instead clamped to prevent out-of-bounds
291     // reads. This is conceptually the same as extending the edges of the input.
292     // We emulate this by copying the last row/column of the input.
293     // Calculate what padding would be needed, then determine how far to extend
294     // the border before lhs dilation.
295     std::vector<int64_t> num_extended(num_spatial_dims);
296     upper_padding[0] = CalculateUpperPadding(
297         in_size[0], out_size[0], dims.kernel_size[0], dims.stride[0]);
298     upper_padding[1] = CalculateUpperPadding(
299         in_size[1], out_size[1], dims.kernel_size[1], dims.stride[1]);
300     num_extended[0] = upper_padding[0] / (dims.kernel_size[0]);
301     num_extended[1] = upper_padding[1] / (dims.kernel_size[1]);
302 
303     const int64_t batch_dim_size =
304         builder->GetShape(input).ValueOrDie().dimensions(0);
305     if (num_extended[0] > 0) {
306       auto slice = xla::Slice(
307           input_data, {0, in_size[0] - 1, 0, 0},
308           {batch_dim_size, in_size[0], in_size[1], channels}, {1, 1, 1, 1});
309       for (int i = 0; i < num_extended[0]; i++) {
310         input_data = xla::ConcatInDim(builder, {input_data, slice}, 1);
311       }
312     }
313 
314     if (num_extended[1] > 0) {
315       auto slice = xla::Slice(
316           input_data, {0, 0, in_size[1] - 1, 0},
317           {batch_dim_size, in_size[0] + num_extended[0], in_size[1], channels},
318           {1, 1, 1, 1});
319       for (int i = 0; i < num_extended[1]; i++) {
320         input_data = xla::ConcatInDim(builder, {input_data, slice}, 2);
321       }
322     }
323 
324     // The effective input size is now (in_size + num_extended) after the above
325     // Slice/ConcatInDim, so recalculate the padding that is needed.
326     upper_padding[0] =
327         CalculateUpperPadding(in_size[0] + num_extended[0], out_size[0],
328                               dims.kernel_size[0], dims.stride[0]);
329     upper_padding[1] =
330         CalculateUpperPadding(in_size[1] + num_extended[1], out_size[1],
331                               dims.kernel_size[1], dims.stride[1]);
332   }
333 
334   // Split the convolution into independent per-dimension convolutions if the
335   // combined kernel would be very large or if either dimension needs no resize.
336   bool decompose_resize =
337       in_size[0] == out_size[0] || in_size[1] == out_size[1] ||
338       dims.kernel_size[0] * dims.kernel_size[1] >= kMax2DKernelSize;
339   if (!decompose_resize) {
340     xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
341                                                 channels, is_kernel_bilinear);
342     output =
343         xla::ConvGeneralDilated(input_data, kernel, dims.stride,
344                                 /*padding=*/
345                                 {{dims.kernel_size[0] - 1, upper_padding[0]},
346                                  {dims.kernel_size[1] - 1, upper_padding[1]}},
347                                 /*lhs_dilation=*/dims.kernel_size,
348                                 /*rhs_dilation=*/{1, 1}, dimension_numbers,
349                                 /*feature_group_count=*/channels);
350   } else {
351     output = input_data;
352     if (in_size[0] != out_size[0]) {
353       xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
354           builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
355       output = xla::ConvGeneralDilated(
356           output, kernel0, {dims.stride[0], 1},
357           /*padding=*/
358           {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}},
359           /*lhs_dilation=*/{dims.kernel_size[0], 1},
360           /*rhs_dilation=*/{1, 1}, dimension_numbers,
361           /*feature_group_count=*/channels);
362     }
363 
364     if (in_size[1] != out_size[1]) {
365       xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
366           builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
367       output = xla::ConvGeneralDilated(
368           output, kernel1, {1, dims.stride[1]},
369           /*padding=*/
370           {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}},
371           /*lhs_dilation=*/{1, dims.kernel_size[1]},
372           /*rhs_dilation=*/{1, 1}, dimension_numbers,
373           /*feature_group_count=*/channels);
374     }
375   }
376 
377   // Add broadcasts to handle expanding from a size == 1 dimension to a
378   // size > 1 dimension.
379   return BroadcastSpatialDimensions(
380       builder, output, /*spatial_dimensions_offset=*/1, in_size, out_size);
381 }
382 
383 xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(
384     xla::XlaBuilder* builder, const xla::XlaOp& grad, xla::PrimitiveType type,
385     const int num_spatial_dims, absl::Span<const int64_t> in_size,
386     absl::Span<const int64_t> grad_size, const int64_t channels,
387     const bool align_corners, bool is_kernel_bilinear) {
388   ResizeConvolutionDims dims =
389       ComputeResizeConvolutionParameters(in_size, grad_size, align_corners);
390 
391   // To form the backward convolution, we keep the kernel unchanged (it is
392   // already symmetric) and swap the roles of strides and LHS dilation.
393   xla::ConvolutionDimensionNumbers dimension_numbers;
394   dimension_numbers.set_input_batch_dimension(0);
395   dimension_numbers.set_output_batch_dimension(0);
396   dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1);
397   dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1);
398   for (int i = 0; i < num_spatial_dims; ++i) {
399     dimension_numbers.add_input_spatial_dimensions(i + 1);
400     dimension_numbers.add_output_spatial_dimensions(i + 1);
401     dimension_numbers.add_kernel_spatial_dimensions(i);
402   }
403   dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1);
404   dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims);
405   xla::XlaOp output;
406   if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) {
407     xla::XlaOp kernel = MakeGeneralResizeKernel(builder, type, dims.kernel_size,
408                                                 channels, is_kernel_bilinear);
409 
410     // Broadcast the input kernel where the forward op expanded from a size == 1
411     // dimension to a size > 1 dimension. This has the effect of summing the
412     // gradient contributions in that dimension.
413     kernel = BroadcastSpatialDimensions(
414         builder, kernel, /*spatial_dimensions_offset=*/0, in_size, grad_size);
415 
416     output = xla::ConvGeneralDilated(
417         grad, kernel, /*window_strides=*/dims.kernel_size,
418         /*padding=*/
419         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1},
420          {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
421         /*lhs_dilation=*/dims.stride,
422         /*rhs_dilation=*/{1, 1}, dimension_numbers,
423         /*feature_group_count=*/channels);
424   } else {
425     xla::XlaOp kernel0 = MakeGeneralResizeKernelInDim(
426         builder, type, dims.kernel_size, channels, 0, is_kernel_bilinear);
427     xla::XlaOp kernel1 = MakeGeneralResizeKernelInDim(
428         builder, type, dims.kernel_size, channels, 1, is_kernel_bilinear);
429 
430     // Broadcast the input kernel where the forward op expanded from a
431     // size == 1 dimension to a size > 1 dimension. This has the effect of
432     // summing the gradient contributions in that dimension.
433     if (in_size[0] == 1 && grad_size[0] > 1) {
434       kernel0 = BroadcastSpatialDimensions(builder, kernel0,
435                                            /*spatial_dimensions_offset=*/0, {1},
436                                            {grad_size[0]});
437     }
438     if (in_size[1] == 1 && grad_size[1] > 1) {
439       kernel1 = BroadcastSpatialDimensions(builder, kernel1,
440                                            /*spatial_dimensions_offset=*/0,
441                                            in_size, grad_size);
442     }
443 
444     output = xla::ConvGeneralDilated(
445         grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1},
446         /*padding=*/
447         {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}},
448         /*lhs_dilation=*/{dims.stride[0], 1},
449         /*rhs_dilation=*/{1, 1}, dimension_numbers,
450         /*feature_group_count=*/channels);
451 
452     output = xla::ConvGeneralDilated(
453         output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]},
454         /*padding=*/
455         {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}},
456         /*lhs_dilation=*/{1, dims.stride[1]},
457         /*rhs_dilation=*/{1, 1}, dimension_numbers,
458         /*feature_group_count=*/channels);
459   }
460 
461   // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i.
462   // Opposite of the slice performed by the forward op.
463   xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4);
464   bool pad_output = false;
465   for (int i = 0; i < num_spatial_dims; ++i) {
466     if (in_size[i] > 1 && grad_size[i] == 1) {
467       pad_output = true;
468       padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1);
469     }
470   }
471   if (pad_output) {
472     output = xla::Pad(output, xla::Zero(builder, type), padding);
473   }
474   return output;
475 }
476 
477 void GeneralCompile(XlaOpKernelContext* ctx, bool align_corners_,
478                     bool half_pixel_centers_, bool is_kernel_bilinear_) {
479   // We implement bilinear interpolation and nearest neighbor with a Gather op.
480   // For each output pixel, we gather the necessary slices of the input.
481   // We then construct the weights that are necessary to calculate the weighted
482   // sum for each output pixel. We do this with a DotGeneral op.
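  //
  // For illustration, a hand-worked trace of the bilinear path below for one
  // spatial axis with in_size = 4, out_size = 8 and half_pixel_centers = true
  // (output pixel 2):
  //
  //   scale      = 4 / 8 = 0.5
  //   sample_f   = (2 + 0.5) * 0.5 = 1.25
  //   span_start = ceil(1.25 - 1 - 0.5) = 0   (clamped to [0, in_size - span])
  //   span_size  = min(3, 4) = 3
  //   kernel_pos = (0 - 1.25) + {0, 1, 2} + 0.5 = {-0.75, 0.25, 1.25}
  //   weights    = max(0, 1 - |kernel_pos|) = {0.25, 0.75, 0}   (sum == 1)
  //
  // so output pixel 2 receives 0.25 * input[0] + 0.75 * input[1], as expected
  // for a half-pixel-centers bilinear resize.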
483   xla::XlaBuilder* b = ctx->builder();
484 
485   TensorShape input_shape = ctx->InputShape(0);
486   OP_REQUIRES(ctx, input_shape.dims() == 4,
487               errors::InvalidArgument("input must be 4-dimensional",
488                                       input_shape.DebugString()));
489   // First dimension always assumed to be batch
490   const int64_t batch = input_shape.dim_size(0);
491   std::vector<int64_t> in_size = {input_shape.dim_size(1),
492                                   input_shape.dim_size(2)};
493   // Last/4th dimension always assumed to be num channels
494   const int64_t channels = input_shape.dim_size(3);
495   OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
496               errors::InvalidArgument("input size must be positive, got [",
497                                       in_size[0], ",", in_size[1], "]"));
498 
499   std::vector<int64_t> out_size;
500   OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size));
501   OP_REQUIRES(ctx, out_size.size() == 2,
502               errors::InvalidArgument("output size must be length 2, got ",
503                                       out_size.size()));
504   OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0,
505               errors::InvalidArgument("output size must be positive, got [",
506                                       out_size[0], ",", out_size[1], "]"));
507 
508   xla::XlaOp input = ctx->Input(0);
509   xla::PrimitiveType input_type = ctx->input_xla_type(0);
510   xla::PrimitiveType original_input_type = input_type;
511   if (is_kernel_bilinear_ || xla::primitive_util::IsIntegralType(input_type)) {
512     input = xla::ConvertElementType(input, xla::F32);
513     input_type = xla::F32;
514   }
515   DataType output_dtype =
516       EncodePrimitiveTypeAsDataType(input_type).ValueOrDie();
517 
518   xla::XlaOp scalar_one_op =
519       xla::ConvertElementType(xla::ConstantR0(b, 1), input_type);
520   xla::XlaOp scalar_half_op =
521       xla::ConvertElementType(xla::ConstantR0(b, 0.5), input_type);
522   xla::XlaOp scalar_zero_op =
523       xla::ConvertElementType(xla::ConstantR0(b, 0), input_type);
524   float h_scale;
525   if (align_corners_ && out_size[0] > 1) {
526     h_scale = (in_size[0] - 1) / static_cast<float>(out_size[0] - 1);
527   } else {
528     h_scale = in_size[0] / static_cast<float>(out_size[0]);
529   }
530   xla::XlaOp h_span_start =
531       xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {out_size[0]}), 0);
532   if (half_pixel_centers_) {
533     h_span_start = xla::Add(h_span_start, scalar_half_op);
534   }
535   xla::XlaOp h_scale_op =
536       xla::ConvertElementType(xla::ConstantR0(b, h_scale), input_type);
537   xla::XlaOp h_sample_f = xla::Mul(h_span_start, h_scale_op);
538 
539   if (is_kernel_bilinear_) {
540     h_span_start = xla::Sub(h_sample_f, scalar_one_op);
541     if (half_pixel_centers_) {
542       h_span_start = xla::Sub(h_span_start, scalar_half_op);
543     }
544     h_span_start = xla::Ceil(h_span_start);
545   } else {
546     h_span_start =
547         align_corners_ ? xla::Round(h_sample_f) : xla::Floor(h_sample_f);
548   }
549   const int64_t h_span_size =
550       is_kernel_bilinear_ ? std::min(static_cast<int64_t>(3), in_size[0]) : 1;
551   xla::XlaOp h_upper_bound = xla::ConvertElementType(
552       xla::ConstantR0(b, in_size[0] - h_span_size), input_type);
553   if (!is_kernel_bilinear_ && !half_pixel_centers_) {
554     h_span_start = xla::Min(h_span_start, h_upper_bound);
555   } else {
556     h_span_start = xla::Clamp(scalar_zero_op, h_span_start, h_upper_bound);
557   }
558   xla::XlaOp broadcasted_h_span_start =
559       xla::BroadcastInDim(h_span_start, {out_size[0], out_size[1], 1}, {0});
560 
561   float w_scale;
562   if (align_corners_ && out_size[1] > 1) {
563     w_scale = (in_size[1] - 1) / static_cast<float>(out_size[1] - 1);
564   } else {
565     w_scale = in_size[1] / static_cast<float>(out_size[1]);
566   }
567   xla::XlaOp w_span_start =
568       xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {out_size[1]}), 0);
569   if (half_pixel_centers_) {
570     w_span_start = xla::Add(w_span_start, scalar_half_op);
571   }
572   xla::XlaOp w_scale_op =
573       xla::ConvertElementType(xla::ConstantR0(b, w_scale), input_type);
574   xla::XlaOp w_sample_f = xla::Mul(w_span_start, w_scale_op);
575   if (is_kernel_bilinear_) {
576     w_span_start = xla::Sub(w_sample_f, scalar_one_op);
577     if (half_pixel_centers_) {
578       w_span_start = xla::Sub(w_span_start, scalar_half_op);
579     }
580     w_span_start = xla::Ceil(w_span_start);
581   } else {
582     w_span_start =
583         align_corners_ ? xla::Round(w_sample_f) : xla::Floor(w_sample_f);
584   }
585   const int64_t w_span_size =
586       is_kernel_bilinear_ ? std::min(static_cast<int64_t>(3), in_size[1]) : 1;
587   xla::XlaOp w_upper_bound = xla::ConvertElementType(
588       xla::ConstantR0(b, in_size[1] - w_span_size), input_type);
589   if (!is_kernel_bilinear_ && !half_pixel_centers_) {
590     w_span_start = xla::Min(w_span_start, w_upper_bound);
591   } else {
592     w_span_start = xla::Clamp(scalar_zero_op, w_span_start, w_upper_bound);
593   }
594   xla::XlaOp broadcasted_w_span_start =
595       xla::BroadcastInDim(w_span_start, {out_size[0], out_size[1], 1}, {1});
596 
597   xla::XlaOp concatted = xla::ConvertElementType(
598       xla::ConcatInDim(b, {broadcasted_h_span_start, broadcasted_w_span_start},
599                        2),
600       xla::S32);
601 
602   absl::InlinedVector<int64_t, 4> slice_sizes = {batch, h_span_size,
603                                                  w_span_size, channels};
604   xla::GatherDimensionNumbers dimension_numbers;
605   dimension_numbers.add_offset_dims(0);
606   dimension_numbers.add_offset_dims(1);
607   dimension_numbers.add_offset_dims(2);
608   dimension_numbers.add_offset_dims(3);
609   dimension_numbers.add_start_index_map(1);
610   dimension_numbers.add_start_index_map(2);
611   dimension_numbers.set_index_vector_dim(2);
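  // For every output pixel, the gather below extracts the h_span_size x
  // w_span_size window of input pixels that contributes to it. With the
  // dimension numbers above, the result has shape
  // [batch, h_span_size, w_span_size, channels, out_size[0], out_size[1]],
  // which is the layout the DotGeneral contraction further down expects.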
612   input = xla::Gather(input, concatted, dimension_numbers, slice_sizes, false);
613 
614   xla::XlaOp w_weight;
615   if (is_kernel_bilinear_) {
616     xla::XlaOp w_sub = xla::Sub(w_span_start, w_sample_f);
617     w_sub = xla::BroadcastInDim(w_sub, {out_size[1], w_span_size}, {0});
618     xla::XlaOp w_offset =
619         xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {w_span_size}), 0);
620     xla::XlaOp w_kernel_pos = xla::Add(w_sub, w_offset, {1});
621     if (half_pixel_centers_) {
622       w_kernel_pos = xla::Add(w_kernel_pos, scalar_half_op);
623     }
624     w_weight = xla::Max(scalar_zero_op,
625                         xla::Sub(scalar_one_op, xla::Abs(w_kernel_pos)));
626   } else {
627     w_weight = xla::Broadcast(scalar_one_op, {out_size[1], w_span_size});
628   }
629   xla::XlaOp w_weight_sum = xla::Reduce(
630       w_weight, scalar_zero_op, *ctx->GetOrCreateAdd(output_dtype), {1});
631   w_weight = xla::Div(w_weight, w_weight_sum, {0});
632 
633   xla::XlaOp h_weight;
634   if (is_kernel_bilinear_) {
635     xla::XlaOp h_sub = xla::Sub(h_span_start, h_sample_f);
636     h_sub = xla::BroadcastInDim(h_sub, {out_size[0], h_span_size}, {0});
637     xla::XlaOp h_offset =
638         xla::Iota(b, xla::ShapeUtil::MakeShape(input_type, {h_span_size}), 0);
639     xla::XlaOp h_kernel_pos = xla::Add(h_sub, h_offset, {1});
640     if (half_pixel_centers_) {
641       h_kernel_pos = xla::Add(h_kernel_pos, scalar_half_op);
642     }
643     h_weight = xla::Max(scalar_zero_op,
644                         xla::Sub(scalar_one_op, xla::Abs(h_kernel_pos)));
645   } else {
646     h_weight = xla::Broadcast(scalar_one_op, {out_size[0], h_span_size});
647   }
648   xla::XlaOp h_weight_sum = xla::Reduce(
649       h_weight, scalar_zero_op, *ctx->GetOrCreateAdd(output_dtype), {1});
650   h_weight = xla::Div(h_weight, h_weight_sum, {0});
651 
652   xla::DotDimensionNumbers dot_dnum;
653   dot_dnum.add_lhs_contracting_dimensions(3);
654   dot_dnum.add_lhs_contracting_dimensions(1);
655   dot_dnum.add_rhs_contracting_dimensions(1);
656   dot_dnum.add_rhs_contracting_dimensions(2);
657   dot_dnum.add_lhs_batch_dimensions(2);
658   dot_dnum.add_lhs_batch_dimensions(0);
659   dot_dnum.add_rhs_batch_dimensions(4);
660   dot_dnum.add_rhs_batch_dimensions(5);
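  // The inner DotGeneral (with empty dimension numbers) forms the outer
  // product of the per-axis weights, of shape
  // [out_size[1], w_span_size, out_size[0], h_span_size]. The outer DotGeneral
  // then contracts the two span dimensions against the gathered windows, with
  // the output pixel coordinates as batch dimensions, producing
  // [out_size[0], out_size[1], batch, channels]; the transpose below restores
  // the NHWC layout.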
661   input = xla::DotGeneral(
662       xla::DotGeneral(w_weight, h_weight, xla::DotDimensionNumbers()), input,
663       dot_dnum);
664 
665   absl::InlinedVector<int64_t, 4> perm = {2, 0, 1, 3};
666   input = xla::Transpose(input, perm);
667 
668   if (!is_kernel_bilinear_ && original_input_type != input_type) {
669     input = xla::ConvertElementType(input, original_input_type);
670   }
671   ctx->SetOutput(0, input);
672 }
673 }  // namespace
674 
675 ResizeNearestNeighborOp::ResizeNearestNeighborOp(OpKernelConstruction* ctx)
676     : XlaOpKernel(ctx) {
677   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
678   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
679   OP_REQUIRES(ctx, !half_pixel_centers_ || !align_corners_,
680               errors::Unimplemented("If half_pixel_centers is True, "
681                                     "align_corners must be False."));
682 }
683 
684 void ResizeNearestNeighborOp::Compile(XlaOpKernelContext* ctx) {
685   GeneralCompile(ctx, align_corners_, half_pixel_centers_, is_kernel_bilinear_);
686 }
687 
688 REGISTER_XLA_OP(Name("ResizeNearestNeighbor").CompileTimeConstantInput("size"),
689                 ResizeNearestNeighborOp);
690 
691 ResizeBilinearOp::ResizeBilinearOp(OpKernelConstruction* ctx)
692     : XlaOpKernel(ctx) {
693   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
694   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
695   OP_REQUIRES(ctx, !half_pixel_centers_ || !align_corners_,
696               errors::Unimplemented("If half_pixel_centers is True, "
697                                     "align_corners must be False."));
698 }
699 
700 void ResizeBilinearOp::Compile(XlaOpKernelContext* ctx) {
701   GeneralCompile(ctx, align_corners_, half_pixel_centers_, is_kernel_bilinear_);
702 }
703 
704 REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstantInput("size"),
705                 ResizeBilinearOp);
706 
707 ResizeBilinearGradOp::ResizeBilinearGradOp(OpKernelConstruction* ctx)
708     : XlaOpKernel(ctx) {
709   OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_));
710   OP_REQUIRES_OK(ctx, ctx->GetAttr("half_pixel_centers", &half_pixel_centers_));
711 
712   if ((!align_corners_ || half_pixel_centers_)) {
713     if (ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
714 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
715       // Use light outside compilation on GPU only.
716       fallback_tf_kernel_.emplace(ctx);
717       return;
718 #endif
719     }
720 
721     OP_REQUIRES(ctx, false,
722                 errors::Unimplemented(
723                     "ResizeBilinearGrad with align_corners=False or "
724                     "half_pixel_centers=True is not yet implemented"));
725   }
726 
727   DataType output_dtype;
728   OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
729   OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_));
730 }
731 
732 void ResizeBilinearGradOp::Compile(XlaOpKernelContext* ctx) {
733 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
734   if (fallback_tf_kernel_.has_value()) {
735     fallback_tf_kernel_->Compile(ctx);
736     return;
737   }
738 #endif
739 
740   xla::XlaBuilder* b = ctx->builder();
741   TensorShape input_shape = ctx->InputShape(1);
742   OP_REQUIRES(ctx, input_shape.dims() == 4,
743               errors::InvalidArgument("input must be 4-dimensional",
744                                       input_shape.DebugString()));
745   const int64_t batch = input_shape.dim_size(0);
746   std::vector<int64_t> in_size = {input_shape.dim_size(1),
747                                   input_shape.dim_size(2)};
748   const int64_t channels = input_shape.dim_size(3);
749   OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0,
750               errors::InvalidArgument("input size must be positive, got [",
751                                       in_size[0], ",", in_size[1], "]"));
752 
753   TensorShape grad_shape = ctx->InputShape(0);
754   OP_REQUIRES(ctx, grad_shape.dims() == 4,
755               errors::InvalidArgument("gradient must be 4-dimensional",
756                                       grad_shape.DebugString()));
757   const int64_t grad_batch = grad_shape.dim_size(0);
758   const std::vector<int64_t> grad_size = {grad_shape.dim_size(1),
759                                           grad_shape.dim_size(2)};
760   const int64_t grad_channels = grad_shape.dim_size(3);
761   OP_REQUIRES(ctx, batch == grad_batch,
762               errors::InvalidArgument(
763                   "activations and gradients must have the same batch size (",
764                   batch, " vs. ", grad_batch, ")"));
765   OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0,
766               errors::InvalidArgument("gradient size must be positive, got [",
767                                       grad_size[0], ",", grad_size[1], "]"));
768   OP_REQUIRES(
769       ctx, channels == grad_channels,
770       errors::InvalidArgument(
771           "activations and gradients must have the same number of channels (",
772           channels, " vs. ", grad_channels, ")"));
773 
774   const int num_spatial_dims = 2;
775 
776   xla::XlaOp grad = ctx->Input(0);
777 
778   xla::XlaOp output = grad;
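  // Large upscale factors are handled iteratively: while the remaining
  // (size - 1) ratio in both dimensions is an even integer greater than 2,
  // peel off a single exact 2x step (resizing the gradient to
  // (in_size - 1) * 2 + 1); otherwise handle the rest of the resize in one
  // call. This keeps the convolution kernels used in each step small.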
779   while (in_size != grad_size) {
780     if (in_size[0] != 1 && in_size[1] != 1) {
781       std::vector<float> k = {
782           (static_cast<float>(grad_size[0]) - 1) / ((in_size[0] - 1) * 2),
783           (static_cast<float>(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)};
784       if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) &&
785           k[0] > 1 && k[1] > 1) {
786         std::vector<int64_t> next_grad_size = {(in_size[0] - 1) * 2 + 1,
787                                                (in_size[1] - 1) * 2 + 1};
788         output = ResizeUsingDilationAndConvolutionGradOp(
789             b, grad, xla::F32, num_spatial_dims, in_size, next_grad_size,
790             channels, align_corners_, true);
791         grad = output;
792         in_size = next_grad_size;
793       } else {
794         output = ResizeUsingDilationAndConvolutionGradOp(
795             b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
796             align_corners_, true);
797         in_size = grad_size;
798       }
799     } else {
800       output = ResizeUsingDilationAndConvolutionGradOp(
801           b, grad, xla::F32, num_spatial_dims, in_size, grad_size, channels,
802           align_corners_, true);
803       in_size = grad_size;
804     }
805   }
806 
807   output = xla::ConvertElementType(output, output_type_);
808   ctx->SetOutput(0, output);
809 }
810 
811 REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp);
812 
813 }  // namespace tensorflow
814