1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Tests the reduce-window XLA operation.
17
18 #include <limits>
19 #include <memory>
20
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/str_join.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/compiler/xla/array2d.h"
25 #include "tensorflow/compiler/xla/array3d.h"
26 #include "tensorflow/compiler/xla/array4d.h"
27 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/padding.h"
30 #include "tensorflow/compiler/xla/client/xla_builder.h"
31 #include "tensorflow/compiler/xla/client/xla_computation.h"
32 #include "tensorflow/compiler/xla/reference_util.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
37 #include "tensorflow/compiler/xla/tests/test_macros.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/lib/core/status_test_util.h"
41 #include "tensorflow/core/platform/test.h"
42
43 namespace xla {
44 namespace {
45
// Values for the boolean "use bfloat16" test parameter. The list is read-only,
// so it is declared constexpr rather than as a mutable global; `static` is not
// needed because the declaration already has internal linkage.
#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16
// Tests both F32 and BF16.
constexpr std::array<bool, 2> use_bfloat16_params{false, true};
#else
// Only tests F32.
constexpr std::array<bool, 1> use_bfloat16_params{false};
#endif
53
54 class ReduceWindowTestBase : public ClientLibraryTestBase {
55 public:
DefaultErrorSpec() const56 ErrorSpec DefaultErrorSpec() const {
57 if (use_bfloat16()) {
58 return ErrorSpec(2e-1, 6e-2);
59 } else {
60 return ErrorSpec(1e-3, 1e-3);
61 }
62 }
63 };
64
65 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
66 public ReduceWindowTestBase {
67 public:
ReduceWindowTest()68 ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
69
ReduceWindowAdd(const XlaOp input,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)70 void ReduceWindowAdd(const XlaOp input,
71 absl::Span<const int64_t> window_dimensions,
72 absl::Span<const int64_t> window_strides,
73 Padding padding) {
74 auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
75 &builder_);
76 ReduceWindow(input, init,
77 CreateScalarAddComputation(FloatType(), &builder_),
78 window_dimensions, window_strides, padding);
79 }
80
ReduceWindowMax(const XlaOp input,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)81 void ReduceWindowMax(const XlaOp input,
82 absl::Span<const int64_t> window_dimensions,
83 absl::Span<const int64_t> window_strides,
84 Padding padding) {
85 auto init =
86 CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
87 ReduceWindow(input, init,
88 CreateScalarMaxComputation(FloatType(), &builder_),
89 window_dimensions, window_strides, padding);
90 }
91
ReduceWindowMin(const XlaOp input,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)92 void ReduceWindowMin(const XlaOp input,
93 absl::Span<const int64_t> window_dimensions,
94 absl::Span<const int64_t> window_strides,
95 Padding padding) {
96 auto init =
97 CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
98 ReduceWindow(input, init,
99 CreateScalarMinComputation(FloatType(), &builder_),
100 window_dimensions, window_strides, padding);
101 }
102
103 XlaBuilder builder_;
104 };
105
// A window specification whose sizes disagree with the operand must leave an
// INVALID_ARGUMENT error on the builder.
XLA_TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
  const auto zero =
      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
  // Building the two constants must not have recorded an error.
  TF_ASSERT_OK(builder_.first_error());

  // Two window dimensions and one stride are given for a rank-1 operand.
  ReduceWindow(operand, zero,
               CreateScalarAddComputation(FloatType(), &builder_),
               /*window_dimensions=*/{1, 2},
               /*window_strides=*/{1}, Padding::kValid);

  const auto status = builder_.first_error();
  ASSERT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT) << status;
  ASSERT_THAT(status.error_message(),
              ::testing::HasSubstr("Want input dimensions size"));
}
121
122 // Regression test for b/68964348.
XLA_TEST_P(ReduceWindowTest,R0ReduceWindow)123 XLA_TEST_P(ReduceWindowTest, R0ReduceWindow) {
124 const auto input =
125 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
126 const auto init =
127 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0), &builder_);
128 ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
129 /*window_dimensions=*/{},
130 /*window_strides=*/{}, Padding::kSame);
131 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(42.0), {},
132 ErrorSpec(0.00001));
133 }
134
// Minimum over windows of 3 with stride 2 on a 5-element vector, kValid
// padding.
XLA_TEST_P(ReduceWindowTest, Min3In5Stride2) {
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
  ReduceWindowMin(operand, /*window_dimensions=*/{3}, /*window_strides=*/{2},
                  Padding::kValid);

  const auto expected = LiteralUtil::CreateR1<float>({100, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
142
// Minimum over windows of 3 with stride 2 on a 5-element vector, kSame
// padding.
XLA_TEST_P(ReduceWindowTest, Min3In5Stride2Same) {
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
  ReduceWindowMin(operand, /*window_dimensions=*/{3}, /*window_strides=*/{2},
                  Padding::kSame);

  const auto expected = LiteralUtil::CreateR1<float>({1000, 10, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
151
// Minimum over windows of 3 with unit stride on a 5-element vector, kSame
// padding; the output has as many elements as the input.
XLA_TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1}), &builder_);
  ReduceWindowMin(operand, /*window_dimensions=*/{3}, /*window_strides=*/{1},
                  Padding::kSame);

  const auto expected = LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
161
// Windowed add over an operand that has a zero-sized dimension, checked
// against the reference implementation.
XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
  const std::vector<int64_t> window = {1, 1, 2, 1};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(1, 0, 2, 1);
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
174
// Windowed add on a small 1x2x2x1 operand with a window of 2 along the third
// dimension, checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, NonSquareSmall) {
  const std::vector<int64_t> window = {1, 1, 2, 1};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(1, 2, 2, 1);
  operand_values.FillRandom(2.f, 2.f);
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
189
// Unit windows with stride 2 along the two middle dimensions of a 1x3x3x1
// operand, checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, MiddleDimsSmall) {
  const std::vector<int64_t> window = {1, 1, 1, 1};
  const std::vector<int64_t> strides = {1, 2, 2, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(1, 3, 3, 1);
  operand_values.FillRandom(2.f, 2.f);
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
203
// Windowed add along the third dimension only, checked against the reference
// implementation.
XLA_TEST_P(ReduceWindowTest, Along2ndMinorDim) {
  // The parameters of this reduction mimic feature norm (e.g. LRN).
  // diameter = 2*radius + 1 --> must be odd.
  constexpr int64_t kLrnDiameter = 7;
  const std::vector<int64_t> window = {1, 1, kLrnDiameter, 1};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(3, 6, 7, 32);
  operand_values.FillRandom(2.f, 2.f);
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
220
// Windowed add over the first two dimensions only (3x3 window, unit stride),
// checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsAdd) {
  constexpr int64_t kWinLen = 3;
  constexpr int64_t kWinStride = 1;
  // Reduce only along the x and y dimensions, according to kWinLen.
  const std::vector<int64_t> window = {kWinLen, kWinLen, 1, 1};
  const std::vector<int64_t> strides = {kWinStride, kWinStride, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(4, 4, 6, 8);
  operand_values.FillWithMinorDimNum();
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
242
// Windowed max over the first two dimensions only (2x2 window, unit stride,
// kValid padding); the result is checked through ComputeAndCompare.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMax) {
  constexpr int64_t kWinLen = 2;
  constexpr int64_t kWinStride = 1;
  // Reduce only along the x and y dimensions, according to kWinLen.
  const std::vector<int64_t> window = {kWinLen, kWinLen, 1, 1};
  const std::vector<int64_t> strides = {kWinStride, kWinStride, 1, 1};
  const Padding padding = Padding::kValid;

  Array4D<float> operand_values(3, 3, 2, 1);
  operand_values.FillWithMinorDimNum();
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowMax(operand, window, strides, padding);

  ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
}
256
// Windowed add over the first two dimensions of a 9x12x4x89 operand (3x3
// window, stride 2), checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
  constexpr int64_t kWinLen = 3;
  constexpr int64_t kWinStride = 2;
  // Reduce only along the x and y dimensions, according to kWinLen.
  const std::vector<int64_t> window = {kWinLen, kWinLen, 1, 1};
  const std::vector<int64_t> strides = {kWinStride, kWinStride, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(9, 12, 4, 89);
  operand_values.FillRandom(2.f, 2.f);
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
279
280 // Tests the super windowing logic w.r.t handling prime number of windows in a
281 // major dimension with reduction.
XLA_TEST_P(ReduceWindowTest,PrimeWindowsInReductionDimension)282 XLA_TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
283 Array4D<float> input_array(15, 15, 4, 128);
284 input_array.FillRandom(2.f, 4.f);
285
286 int win_len = 3;
287 int win_stride = 2;
288
289 const auto input_data_handle =
290 CreateConstantFromArray(input_array, &builder_);
291
292 Padding padding = Padding::kSame;
293 // Reduce only along the x and y dimensions, according to the win_len.
294 ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
295 {win_stride, win_stride, 1, 1}, padding);
296
297 auto result = ReferenceUtil::ReduceWindow4DAdd(
298 input_array, 0.0f, {win_len, win_len, 1, 1},
299 {win_stride, win_stride, 1, 1}, padding);
300
301 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
302 DefaultErrorSpec());
303 }
304
// Windowed add along the last dimension only (window of 11, unit stride),
// checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
  const std::vector<int64_t> window = {1, 1, 1, 11};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(19, 17, 8, 256);
  operand_values.FillWithMinorDimNum();
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
321
322 // Tests a reduction function that is not a simple add/min/max/etc.
XLA_TEST_P(ReduceWindowTest,NonstandardReduceFunction)323 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
324 Array4D<float> input_array(1, 2, 2, 1);
325 input_array(0, 0, 0, 0) = 1;
326 input_array(0, 0, 1, 0) = 2;
327 input_array(0, 1, 0, 0) = 3;
328 input_array(0, 1, 1, 0) = 4;
329 const auto input = CreateConstantFromArray(input_array, &builder_);
330
331 Padding padding = Padding::kValid;
332 const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
333 auto b = builder_.CreateSubBuilder("unusual");
334 auto lhs = Parameter(b.get(), 0, scalar, "lhs");
335 auto rhs = Parameter(b.get(), 1, scalar, "rhs");
336 Min(Add(lhs, rhs),
337 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
338 XlaComputation reduce_fn = b->BuildAndNoteError();
339
340 ReduceWindow(
341 input,
342 CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
343 reduce_fn,
344 /*window_dimensions=*/{1, 1, 2, 1},
345 /*window_strides=*/{1, 1, 1, 1}, padding);
346
347 const auto reduce_func = [](float arg1, float arg2) {
348 return std::min<float>(arg1 + arg2, 8.0f);
349 };
350
351 auto expected =
352 ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
353 /*window=*/{1, 1, 2, 1},
354 /*stride=*/{1, 1, 1, 1}, padding);
355
356 ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
357 {}, DefaultErrorSpec());
358 }
359
// Windowed add on a rank-4 parameter transferred with layout {0, 3, 2, 1},
// checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4UnitWindow) {
  const std::vector<int64_t> window = {1, 1, 7, 1};
  const std::vector<int64_t> strides = {1, 4, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(13, 12, 8, 15);
  operand_values.FillRandom(2.f, 2.f);
  Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand_values, LayoutUtil::MakeLayout({0, 3, 2, 1}));
  XlaOp operand;
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_data,
      CreateParameterAndTransferLiteral(0, operand_literal, "parameter",
                                        &builder_, &operand));

  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {device_data.get()}, DefaultErrorSpec());
}
379
// Windowed add on a rank-6 operand of ones with a window of 3 along three
// dimensions; every output element is expected to be 3*3*3 = 27.
XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
  // Six dimensions of size 8, all filled with 1.
  const std::vector<int64_t> operand_dims(6, 8);
  Literal operand_literal(ShapeUtil::MakeShape(F32, operand_dims));
  operand_literal.PopulateWithValue(1.0f);
  const auto operand = CreateConstantFromLiteral(operand_literal, &builder_);

  const Padding padding = Padding::kValid;
  ReduceWindowAdd(operand, /*window_dimensions=*/{3, 1, 3, 3, 1, 1},
                  /*window_strides=*/{1, 1, 1, 1, 1, 1}, padding);

  // The expected literal is built with an explicit, non-default layout.
  const std::vector<int64_t> expected_dims = {6, 8, 6, 6, 8, 8};
  const std::vector<int64_t> expected_layout = {1, 5, 3, 2, 0, 4};
  Literal expected(
      ShapeUtil::MakeShapeWithLayout(F32, expected_dims, expected_layout));
  expected.PopulateWithValue(27.0f);
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
399
// Windowed add on a rank-6 operand of ones with a 3x3 window over dimensions
// 2 and 3; every output element is expected to be 3*3 = 9.
XLA_TEST_P(ReduceWindowTest, R6Add) {
  // Six dimensions of size 8. (The unused `shape` local that used to be built
  // from these dimensions has been removed.)
  std::vector<int64_t> input_dims(6, 8);

  Literal arg_literal =
      LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);

  const auto input = CreateConstantFromLiteral(arg_literal, &builder_);

  Padding padding = Padding::kValid;
  ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding);

  // kValid padding shrinks the two windowed dimensions from 8 to 6.
  std::vector<int64_t> output_dims = {8, 8, 6, 6, 8, 8};
  Literal expected =
      LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);

  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
418
// Unit window with stride 8 along the third dimension of a rank-4 parameter,
// checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
  constexpr int64_t kWinLen = 1;
  constexpr int64_t kStride = 8;
  const std::vector<int64_t> window = {1, 1, kWinLen, 1};
  const std::vector<int64_t> strides = {1, 1, kStride, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(2, 1, 27, 119);
  operand_values.FillRandom(2.0f);
  Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand_values, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp operand;
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_data,
      CreateParameterAndTransferLiteral(0, operand_literal, "parameter",
                                        &builder_, &operand));

  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {device_data.get()}, DefaultErrorSpec());
}
440
// Window of 3 with unit stride along the third dimension of a rank-4
// parameter, checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
  constexpr int64_t kWinLen = 3;
  constexpr int64_t kStride = 1;
  const std::vector<int64_t> window = {1, 1, kWinLen, 1};
  const std::vector<int64_t> strides = {1, 1, kStride, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(3, 2, 4, 64);
  operand_values.FillRandom(2.0f);
  Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand_values, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp operand;
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_data,
      CreateParameterAndTransferLiteral(0, operand_literal, "parameter",
                                        &builder_, &operand));

  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {device_data.get()}, DefaultErrorSpec());
}
462
// Window of 8 with stride 5 along the third dimension of a rank-4 parameter,
// checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
  constexpr int64_t kWinLen = 8;
  constexpr int64_t kStride = 5;
  const std::vector<int64_t> window = {1, 1, kWinLen, 1};
  const std::vector<int64_t> strides = {1, 1, kStride, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(1, 3, 12, 200);
  operand_values.FillRandom(2.0f);
  Literal operand_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      operand_values, LayoutUtil::MakeLayout({3, 2, 1, 0}));
  XlaOp operand;
  TF_ASSERT_OK_AND_ASSIGN(
      auto device_data,
      CreateParameterAndTransferLiteral(0, operand_literal, "parameter",
                                        &builder_, &operand));

  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {device_data.get()}, DefaultErrorSpec());
}
484
// Windowed add over the first two dimensions of a 6x4x10x130 operand (3x3
// window, stride 2), checked against the reference implementation.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
  constexpr int64_t kWinLen = 3;
  constexpr int64_t kWinStride = 2;
  // Reduce only along the x and y dimensions, according to kWinLen.
  const std::vector<int64_t> window = {kWinLen, kWinLen, 1, 1};
  const std::vector<int64_t> strides = {kWinStride, kWinStride, 1, 1};
  const Padding padding = Padding::kSame;

  Array4D<float> operand_values(6, 4, 10, 130);
  operand_values.FillRandom(2.0f);
  const auto operand = CreateConstantFromArray(operand_values, &builder_);
  ReduceWindowAdd(operand, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      operand_values, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
505
// Sums non-overlapping windows (window 32, stride 128) of a 1152-element
// vector of ones; each of the 9 windows is expected to sum to 32.
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
  const std::vector<float> ones(128 * 9, 1);
  const auto operand =
      CreateConstantFromLiteral(LiteralUtil::CreateR1<float>(ones), &builder_);
  ReduceWindowAdd(operand, /*window_dimensions=*/{32},
                  /*window_strides=*/{128}, Padding::kValid);

  const std::vector<float> window_sums(9, 32);
  ComputeAndCompareLiteral(&builder_,
                           LiteralUtil::CreateR1<float>(window_sums), {},
                           DefaultErrorSpec());
}
516
// Sums a single window of 128 (stride 128) covering the whole 128-element
// input; the expected total is 8 * (1 + ... + 16) = 1088.
XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
  // Eight repetitions of the sequence 1..16.
  std::vector<float> input_vector;
  input_vector.reserve(128);
  for (int repetition = 0; repetition < 8; ++repetition) {
    for (int value = 1; value <= 16; ++value) {
      input_vector.push_back(static_cast<float>(value));
    }
  }
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(input_vector), &builder_);
  ReduceWindowAdd(operand, /*window_dimensions=*/{128},
                  /*window_strides=*/{128}, Padding::kValid);

  const auto expected = LiteralUtil::CreateR1<float>({1088});
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
533
// Sums a single window of 128 (unit stride) covering the whole 128-element
// input; the expected total is 8 * (1 + ... + 16) = 1088.
XLA_TEST_P(ReduceWindowTest, Add128In128) {
  // Eight repetitions of the sequence 1..16.
  std::vector<float> input_vector;
  input_vector.reserve(128);
  for (int repetition = 0; repetition < 8; ++repetition) {
    for (int value = 1; value <= 16; ++value) {
      input_vector.push_back(static_cast<float>(value));
    }
  }
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(input_vector), &builder_);
  ReduceWindowAdd(operand, /*window_dimensions=*/{128},
                  /*window_strides=*/{1}, Padding::kValid);

  const auto expected = LiteralUtil::CreateR1<float>({1088});
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
550
551 // Regression test for a bug that appeared in Inception (b/34784899).
XLA_TEST_P(ReduceWindowTest,R2ReduceWindowInceptionFromBroadcast)552 XLA_TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
553 Array2D<float> input_array(14, 14, 1.0f);
554 const auto input = CreateConstantFromArray(input_array, &builder_);
555 int win_len = 3;
556 int stride = 1;
557 Padding padding = Padding::kSame;
558 ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
559 ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
560 }
561
// Windowed add (window 4x2, stride 3x3) over a 6x4 operand produced by
// broadcasting the scalar 1. The operand comes from Broadcast, so the unused
// host-side Array2D that used to be constructed here has been removed.
XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
  XlaOp input = Broadcast(
      CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
  Padding padding = Padding::kSame;
  ReduceWindowAdd(input, {4, 2}, {3, 3}, padding);
  ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
}
570
571 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
572 ::testing::ValuesIn(use_bfloat16_params));
573
// Reduction applied by a parameterized rank-4 reduce-window test case.
enum Reducer {
  kAdd,
  kMax
};

// Describes one rank-4 reduce-window test case. Each array holds one entry
// per dimension of the operand.
struct R4ReduceWindowTestData {
  int64_t base_bounds[4];    // Dimensions of the input array.
  int64_t window_bounds[4];  // Window size per dimension.
  int64_t strides[4];        // Window stride per dimension.
  int64_t pad_low[4];        // Low-side padding per dimension.
  int64_t pad_high[4];       // High-side padding per dimension.
  int64_t layout[4];         // Passed to LayoutUtil::MakeLayout.

  Reducer reducer;  // Which reduction to apply.
};
586
R4ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R4ReduceWindowTestData,bool>> & data)587 std::string R4ReduceWindowTestDataToString(
588 const ::testing::TestParamInfo<
589 ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
590 const auto& param = ::testing::get<0>(data.param);
591 std::string str = absl::StrCat(
592 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
593 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
594 "__strides_", absl::StrJoin(param.strides, "x"), //
595 "__pad_low_", absl::StrJoin(param.pad_low, "x"), //
596 "__pad_high_", absl::StrJoin(param.pad_high, "x"), //
597 "__layout_", absl::StrJoin(param.layout, "_"), //
598 (param.reducer == kAdd) ? "_add" : "_max");
599 CHECK(param.reducer == kAdd || param.reducer == kMax);
600
601 // Test names are not allowed to contain the '-' character.
602 std::replace(str.begin(), str.end(), '-', 'n');
603 if (::testing::get<1>(data.param)) {
604 absl::StrAppend(&str, "_bfloat16");
605 }
606 return str;
607 }
608
609 class R4ReduceWindowTest : public ReduceWindowTestBase,
610 public ::testing::WithParamInterface<
611 ::testing::tuple<R4ReduceWindowTestData, bool>> {
612 protected:
R4ReduceWindowTest()613 R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
614
DoIt()615 void DoIt() {
616 XlaBuilder b(TestName());
617 const auto& param = ::testing::get<0>(GetParam());
618
619 const float kInitValue = 0.0f;
620
621 Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
622 param.base_bounds[2], param.base_bounds[3]);
623 // Choose a prime iota length so that each window sees a unique set of
624 // values. (Technically, the requirement is that the iota length is
625 // relatively prime to all of the dimensions involved in the reduce-window.)
626 input.FillRepeatedIota(0, 137);
627 // Floating point sum reduction requires higher localized precision. We need
628 // the following normalization in order to enable testing of kAdd on large
629 // windows.
630 input.Each([&](absl::Span<const int64_t> /*indices*/, float* value) {
631 *value = *value / 10000000000.f;
632 });
633 Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
634 input, LayoutUtil::MakeLayout(param.layout));
635 XlaOp parameter;
636 TF_ASSERT_OK_AND_ASSIGN(auto input_arg,
637 CreateParameterAndTransferLiteral(
638 0, input_literal, "p0", &b, ¶meter));
639
640 std::vector<std::pair<int64_t, int64_t>> padding(4);
641 for (int i = 0; i < 4; ++i) {
642 padding[i] = {param.pad_low[i], param.pad_high[i]};
643 }
644
645 auto init_value =
646 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
647 CHECK(param.reducer == kAdd || param.reducer == kMax);
648 auto reducer = param.reducer;
649 auto computation = reducer == kAdd
650 ? CreateScalarAddComputation(FloatType(), &b)
651 : CreateScalarMaxComputation(FloatType(), &b);
652 ReduceWindowWithGeneralPadding(
653 /*operand=*/parameter,
654 /*init_value=*/init_value,
655 /*computation=*/computation,
656 /*window_dimensions=*/param.window_bounds,
657 /*window_strides=*/param.strides,
658 /*base_dilations=*/{},
659 /*window_dilations=*/{},
660 /*padding=*/padding);
661
662 CHECK(reducer == kAdd || reducer == kMax);
663 auto reduce_func = reducer == kAdd
664 ? +[](float a, float b) { return a + b; }
665 : +[](float a, float b) { return std::max(a, b); };
666 std::unique_ptr<Array4D<float>> expected =
667 ReferenceUtil::ReduceWindow4DGeneric(
668 /*operand=*/input,
669 /*init=*/kInitValue,
670 /*reduce_func=*/reduce_func,
671 /*window=*/param.window_bounds,
672 /*stride=*/param.strides,
673 /*padding=*/padding);
674 Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
675 const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
676 input_literal.shape().element_type(),
677 expected_literal.shape().dimensions(), param.layout);
678 ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
679 DefaultErrorSpec(), &expected_shape_with_layout);
680 }
681 };
682
// Runs the fixture's shared check for the current parameter tuple.
XLA_TEST_P(R4ReduceWindowTest, DoIt) {
  this->DoIt();
}
684
// Fields, in order: base_bounds, window_bounds, strides, pad_low, pad_high,
// layout, reducer.
686 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
687 // Minimal edge case.
688 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
689 /*window_bounds=*/{1, 1, 1, 1},
690 /*strides=*/{1, 1, 1, 1},
691 /*pad_low=*/{0, 0, 0, 0},
692 /*pad_high=*/{0, 0, 0, 0},
693 /*layout=*/{3, 2, 1, 0},
694 /*reducer=*/kAdd},
695
696 // Arbitrary padding (not kSame or kValid).
697 R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
698 /*window_bounds=*/{3, 3, 1, 1},
699 /*strides=*/{2, 2, 1, 1},
700 /*pad_low=*/{4, 4, 0, 0},
701 /*pad_high=*/{4, 4, 0, 0},
702 /*layout=*/{3, 2, 1, 0},
703 /*reducer=*/kAdd},
704
705 // Zero base bound edge case.
706 R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
707 /*window_bounds=*/{1, 1, 1, 1},
708 /*strides=*/{1, 1, 1, 1},
709 /*pad_low=*/{0, 0, 0, 0},
710 /*pad_high=*/{0, 0, 0, 0},
711 /*layout=*/{3, 2, 1, 0},
712 /*reducer=*/kAdd},
713
714 // With max instead of add.
715 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
716 /*window_bounds=*/{2, 3, 1, 1},
717 /*strides=*/{1, 1, 1, 1},
718 /*pad_low=*/{0, 0, 0, 0},
719 /*pad_high=*/{0, 0, 0, 0},
720 /*layout=*/{3, 2, 1, 0},
721 /*reducer=*/kMax},
722
723 // With stride.
724 R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
725 /*window_bounds=*/{3, 2, 1, 1},
726 /*strides=*/{2, 4, 1, 1},
727 /*pad_low=*/{0, 0, 0, 0},
728 /*pad_high=*/{0, 0, 0, 0},
729 /*layout=*/{3, 2, 1, 0},
730 /*reducer=*/kAdd},
731
732 // With low padding.
733 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
734 /*window_bounds=*/{3, 2, 1, 1},
735 /*strides=*/{2, 2, 1, 1},
736 /*pad_low=*/{3, 2, 0, 0},
737 /*pad_high=*/{0, 0, 0, 0},
738 /*layout=*/{3, 2, 1, 0},
739 /*reducer=*/kAdd},
740
741 // With high padding.
742 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
743 /*window_bounds=*/{3, 2, 1, 1},
744 /*strides=*/{2, 2, 1, 1},
745 /*pad_low=*/{0, 0, 0, 0},
746 /*pad_high=*/{2, 3, 0, 0},
747 /*layout=*/{3, 2, 1, 0},
748 /*reducer=*/kAdd},
749
750 // With negative padding on both ends.
751 R4ReduceWindowTestData{/*base_bounds=*/{10, 10, 17, 140},
752 /*window_bounds=*/{3, 2, 1, 1},
753 /*strides=*/{2, 2, 1, 1},
754 /*pad_low=*/{-3, -2, 0, 0},
755 /*pad_high=*/{-2, -3, 0, 0},
756 /*layout=*/{3, 2, 1, 0},
757 /*reducer=*/kAdd},
758
759 // With negative low padding and positive high padding.
760 R4ReduceWindowTestData{/*base_bounds=*/{10, 10, 17, 140},
761 /*window_bounds=*/{3, 2, 1, 1},
762 /*strides=*/{2, 2, 1, 1},
763 /*pad_low=*/{-3, -2, 0, 0},
764 /*pad_high=*/{2, 3, 0, 0},
765 /*layout=*/{3, 2, 1, 0},
766 /*reducer=*/kAdd},
767
768 // With positive low padding and negative high padding.
769 R4ReduceWindowTestData{/*base_bounds=*/{10, 10, 17, 140},
770 /*window_bounds=*/{3, 2, 1, 1},
771 /*strides=*/{2, 2, 1, 1},
772 /*pad_low=*/{3, 2, 0, 0},
773 /*pad_high=*/{-2, -3, 0, 0},
774 /*layout=*/{3, 2, 1, 0},
775 /*reducer=*/kAdd},
776
777 // Window touches both sides of the padding simultaneously.
778 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
779 /*window_bounds=*/{3, 3, 1, 1},
780 /*strides=*/{1, 1, 1, 1},
781 /*pad_low=*/{1, 1, 0, 0},
782 /*pad_high=*/{1, 1, 0, 0},
783 /*layout=*/{3, 2, 1, 0},
784 /*reducer=*/kAdd},
785
786 // Window is entirely in the padding for some positions.
787 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
788 /*window_bounds=*/{3, 3, 1, 1},
789 /*strides=*/{1, 1, 1, 1},
790 /*pad_low=*/{4, 4, 0, 0},
791 /*pad_high=*/{4, 4, 0, 0},
792 /*layout=*/{3, 2, 1, 0},
793 /*reducer=*/kAdd},
794
795 // Zero base bound with padding edge case.
796 R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
797 /*window_bounds=*/{1, 1, 1, 1},
798 /*strides=*/{1, 1, 1, 1},
799 /*pad_low=*/{0, 1, 0, 0},
800 /*pad_high=*/{0, 0, 0, 0},
801 /*layout=*/{3, 2, 1, 0},
802 /*reducer=*/kAdd},
803
804 // With stride, low padding and high padding.
805 R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
806 /*window_bounds=*/{3, 4, 1, 1},
807 /*strides=*/{3, 1, 1, 1},
808 /*pad_low=*/{10, 1, 0, 0},
809 /*pad_high=*/{2, 3, 0, 0},
810 /*layout=*/{3, 2, 1, 0},
811 /*reducer=*/kAdd},
812
813 // With minor dimension == 129.
814 R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
815 /*window_bounds=*/{1, 1, 1, 1},
816 /*strides=*/{1, 1, 1, 1},
817 /*pad_low=*/{0, 0, 0, 0},
818 /*pad_high=*/{0, 0, 0, 0},
819 /*layout=*/{3, 2, 1, 0},
820 /*reducer=*/kAdd},
821
822 // With minor dims reduction and non-overlapped stride.
823 R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
824 /*window_bounds=*/{1, 1, 2, 2},
825 /*strides=*/{1, 1, 2, 2},
826 /*pad_low=*/{0, 0, 0, 0},
827 /*pad_high=*/{0, 0, 0, 0},
828 /*layout=*/{3, 2, 1, 0},
829 /*reducer=*/kAdd},
830
831 // With minor dims reduction and overlapped stride.
832 R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
833 /*window_bounds=*/{1, 1, 4, 4},
834 /*strides=*/{1, 1, 2, 2},
835 /*pad_low=*/{0, 0, 0, 0},
836 /*pad_high=*/{1, 0, 0, 0},
837 /*layout=*/{3, 2, 1, 0},
838 /*reducer=*/kAdd},
839
840 R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
841 /*window_bounds=*/{1, 64, 64, 1},
842 /*strides=*/{1, 64, 64, 1},
843 /*pad_low=*/{0, 0, 0, 0},
844 /*pad_high=*/{0, 0, 0, 0},
845 /*layout=*/{3, 0, 2, 1},
846 /*reducer=*/kAdd},
847
848 R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
849 /*window_bounds=*/{112, 112, 1, 8},
850 /*strides=*/{112, 112, 1, 8},
851 /*pad_low=*/{0, 0, 0, 0},
852 /*pad_high=*/{0, 0, 0, 0},
853 /*layout=*/{3, 2, 1, 0},
854 /*reducer=*/kMax},
855
856 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
857 /*window_bounds=*/{2, 3, 4, 5},
858 /*strides=*/{1, 1, 1, 1},
859 /*pad_low=*/{0, 0, 0, 0},
860 /*pad_high=*/{0, 0, 0, 0},
861 /*layout=*/{3, 2, 1, 0},
862 /*reducer=*/kAdd},
863
864 // With 0321 layout.
865 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
866 /*window_bounds=*/{2, 3, 4, 5},
867 /*strides=*/{1, 2, 3, 4},
868 /*pad_low=*/{0, 0, 0, 0},
869 /*pad_high=*/{0, 0, 0, 0},
870 /*layout=*/{0, 3, 2, 1},
871 /*reducer=*/kAdd},
872
873 // With 0123 layout.
874 R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
875 /*window_bounds=*/{2, 3, 7, 9},
876 /*strides=*/{1, 2, 5, 8},
877 /*pad_low=*/{0, 0, 0, 0},
878 /*pad_high=*/{0, 0, 0, 0},
879 /*layout=*/{0, 1, 2, 3},
880 /*reducer=*/kAdd},
881 };
882
883 INSTANTIATE_TEST_CASE_P(
884 R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
885 ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
886 ::testing::ValuesIn(use_bfloat16_params)),
887 R4ReduceWindowTestDataToString);
888
889 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
890
// Delegates to the inherited DoIt(). The test name is wrapped in
// DISABLED_ON_INTERPRETER.
XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
892
893 // Test cases that are large/slow/failed.
894 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
895 R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
896 /*window_bounds=*/{3, 3, 1, 5},
897 /*strides=*/{1, 1, 1, 5},
898 /*pad_low=*/{1, 1, 0, 0},
899 /*pad_high=*/{1, 1, 0, 0},
900 /*layout=*/{3, 2, 1, 0},
901 /*reducer=*/kMax},
902
903 R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
904 /*window_bounds=*/{3, 3, 1, 1},
905 /*strides=*/{2, 2, 1, 1},
906 /*pad_low=*/{0, 0, 0, 0},
907 /*pad_high=*/{1, 1, 0, 0},
908 /*layout=*/{3, 2, 1, 0},
909 /*reducer=*/kAdd},
910
911 R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
912 /*window_bounds=*/{1, 1, 4, 1},
913 /*strides=*/{1, 1, 4, 1},
914 /*pad_low=*/{0, 0, 1, 0},
915 /*pad_high=*/{0, 0, 2, 0},
916 /*layout=*/{3, 2, 1, 0},
917 /*reducer=*/kMax},
918
919 // Patterns generated by cumsum/cumprod.
920 R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
921 /*window_bounds=*/{1021, 1, 1, 1},
922 /*strides=*/{1, 1, 1, 1},
923 /*pad_low=*/{1020, 0, 0, 0},
924 /*pad_high=*/{0, 0, 0, 0},
925 /*layout=*/{3, 2, 1, 0},
926 /*reducer=*/kAdd},
927
928 R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
929 /*window_bounds=*/{1, 1, 1021, 1},
930 /*strides=*/{1, 1, 1, 1},
931 /*pad_low=*/{0, 0, 1020, 0},
932 /*pad_high=*/{0, 0, 0, 0},
933 /*layout=*/{3, 2, 1, 0},
934 /*reducer=*/kAdd},
935
936 R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
937 /*window_bounds=*/{1, 1, 1, 1021},
938 /*strides=*/{1, 1, 1, 1},
939 /*pad_low=*/{0, 0, 0, 1020},
940 /*pad_high=*/{0, 0, 0, 0},
941 /*layout=*/{3, 2, 1, 0},
942 /*reducer=*/kAdd},
943
944 R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
945 /*window_bounds=*/{1021, 1, 1, 1},
946 /*strides=*/{1, 1, 1, 1},
947 /*pad_low=*/{1021, 0, 0, 0},
948 /*pad_high=*/{0, 0, 0, 0},
949 /*layout=*/{3, 2, 1, 0},
950 /*reducer=*/kAdd},
951
952 R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16},
953 /*window_bounds=*/{1, 1, 1021, 1},
954 /*strides=*/{1, 1, 1, 1},
955 /*pad_low=*/{0, 0, 1021, 0},
956 /*pad_high=*/{0, 0, 0, 0},
957 /*layout=*/{3, 2, 1, 0},
958 /*reducer=*/kAdd},
959
960 R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
961 /*window_bounds=*/{1, 1, 1, 1021},
962 /*strides=*/{1, 1, 1, 1},
963 /*pad_low=*/{0, 0, 0, 1021},
964 /*pad_high=*/{0, 0, 0, 0},
965 /*layout=*/{3, 2, 1, 0},
966 /*reducer=*/kAdd},
967 };
968
969 INSTANTIATE_TEST_CASE_P(
970 R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
971 ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
972 ::testing::ValuesIn(use_bfloat16_params)),
973 R4ReduceWindowTestDataToString);
974
975 struct R3ReduceWindowTestData {
976 int64_t base_bounds[3];
977 int64_t window_bounds[3];
978 int64_t strides[3];
979 int64_t layout[3];
980 Padding padding;
981 Reducer reducer;
982 } kR3TestCases[] = {
983 {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
984 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
985 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
986 {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
987 /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
988 /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
989 {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
990 /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
991 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
992 {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
993 /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
994 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
995 {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
996 /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
997 /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
998 {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
999 /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
1000 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1001 {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
1002 /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
1003 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1004 {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
1005 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1006 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
1007 {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
1008 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1009 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1010 {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
1011 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1012 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
1013 {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
1014 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1015 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
1016 {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
1017 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1018 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1019 {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
1020 /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1021 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1022 {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
1023 /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
1024 /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1025 };
1026
R3ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R3ReduceWindowTestData,bool>> & data)1027 std::string R3ReduceWindowTestDataToString(
1028 const ::testing::TestParamInfo<
1029 ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
1030 const auto& param = ::testing::get<0>(data.param);
1031 std::string str = absl::StrCat(
1032 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
1033 absl::StrJoin(param.window_bounds, "x"), "__strides_",
1034 absl::StrJoin(param.strides, "x"), "__padding_",
1035 param.padding == Padding::kSame ? "same" : "valid", "__layout_",
1036 param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
1037 param.reducer == kAdd ? "add" : "max");
1038 if (::testing::get<1>(data.param)) {
1039 absl::StrAppend(&str, "_bfloat16");
1040 }
1041 return str;
1042 }
1043
1044 class R3ReduceWindowTest : public ReduceWindowTestBase,
1045 public ::testing::WithParamInterface<
1046 ::testing::tuple<R3ReduceWindowTestData, bool>> {
1047 protected:
R3ReduceWindowTest()1048 R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1049 };
1050
// Builds a ReduceWindow over a rank-3 parameter described by the test case
// and checks it with ComputeAndCompare under DefaultErrorSpec().
XLA_TEST_P(R3ReduceWindowTest, DoIt) {
  XlaBuilder b(TestName());
  const auto& param = ::testing::get<0>(GetParam());

  const float kInitValue = 0.0f;
  Array3D<float> input(param.base_bounds[0], param.base_bounds[1],
                       param.base_bounds[2]);
  // Choose a prime iota length so that each window sees a unique set of values.
  // (Technically, the requirement is that the iota length is relatively prime
  // to all of the dimensions involved in the reduce-window.)
  input.FillRepeatedIota(0, 137);
  Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
      input, LayoutUtil::MakeLayout(param.layout));
  auto reducer = param.reducer;
  if (use_bfloat16()) {
    input_literal = LiteralUtil::ConvertF32ToBF16(input_literal);

    // To avoid numerical issues, force the reducer to be kMax for bf16
    // inputs.
    reducer = kMax;
  }

  XlaOp parameter = Parameter(&b, 0, input_literal.shape(), "input");
  auto init_value =
      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);

  // Note that `reducer`, not param.reducer, picks the computation: it may have
  // been overridden for bf16 above.
  auto computation = reducer == kAdd
                         ? CreateScalarAddComputation(FloatType(), &b)
                         : CreateScalarMaxComputation(FloatType(), &b);

  ReduceWindow(/*operand=*/parameter,
               /*init_value=*/init_value,
               /*computation=*/computation,
               /*window_dimensions=*/param.window_bounds,
               /*window_strides=*/param.strides, /*padding=*/param.padding);

  ComputeAndCompare(&b, {std::move(input_literal)}, DefaultErrorSpec());
}
1089
1090 INSTANTIATE_TEST_CASE_P(
1091 R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
1092 ::testing::Combine(::testing::ValuesIn(kR3TestCases),
1093 ::testing::ValuesIn(use_bfloat16_params)),
1094 R3ReduceWindowTestDataToString);
1095
1096 struct R2ReduceWindowTestData {
1097 int64_t base_bounds[2];
1098 int64_t window_bounds[2];
1099 int64_t strides[2];
1100 int64_t base_dilation[2];
1101 int64_t window_dilation[2];
1102 int64_t pad_low[2];
1103 int64_t pad_high[2];
1104 int64_t layout[2];
1105 Reducer reducer;
1106 } kR2TestCases[] = {
1107 {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1108 /*strides=*/{1, 2},
1109 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1110 /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
1111 /*layout=*/{0, 1},
1112 /*reducer=*/Reducer::kAdd},
1113 {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
1114 /*strides=*/{1, 1},
1115 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1116 /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2},
1117 /*layout=*/{0, 1},
1118 /*reducer=*/Reducer::kAdd},
1119 {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
1120 /*strides=*/{1, 1},
1121 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1122 /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
1123 /*layout=*/{0, 1},
1124 /*reducer=*/Reducer::kAdd},
1125 {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
1126 /*strides=*/{2, 99},
1127 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1128 /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35},
1129 /*layout=*/{0, 1},
1130 /*reducer=*/Reducer::kAdd},
1131 // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a
1132 // ptxas bug.
1133 #ifndef XLA_TEST_BACKEND_GPU
1134 {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
1135 /*strides=*/{5, 4},
1136 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1137 /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11},
1138 /*layout=*/{0, 1},
1139 /*reducer=*/Reducer::kAdd},
1140 #endif
1141 {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
1142 /*strides=*/{3, 3},
1143 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1144 /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1},
1145 /*layout=*/{0, 1},
1146 /*reducer=*/Reducer::kAdd},
1147 {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
1148 /*strides=*/{4, 5},
1149 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1150 /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17},
1151 /*layout=*/{1, 0},
1152 /*reducer=*/Reducer::kAdd},
1153 {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
1154 /*strides=*/{1, 1},
1155 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1156 /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46},
1157 /*layout=*/{1, 0},
1158 /*reducer=*/Reducer::kAdd},
1159 // Regression test for a bug that appeared in Inception (b/34784899).
1160 {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
1161 /*strides=*/{1, 1},
1162 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1163 /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1},
1164 /*layout=*/{1, 0},
1165 /*reducer=*/Reducer::kAdd},
1166 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1167 /*strides=*/{1, 1},
1168 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1169 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1170 /*layout=*/{1, 0},
1171 /*reducer=*/Reducer::kMax},
1172 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1173 /*strides=*/{1, 1},
1174 /*base_dilation=*/{1, 1}, /*window_dilation=*/{2, 2},
1175 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1176 /*layout=*/{1, 0},
1177 /*reducer=*/Reducer::kMax},
1178 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1179 /*strides=*/{1, 1},
1180 /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1},
1181 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1182 /*layout=*/{1, 0},
1183 /*reducer=*/Reducer::kMax},
1184 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1185 /*strides=*/{2, 2},
1186 /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1},
1187 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1188 /*layout=*/{1, 0},
1189 /*reducer=*/Reducer::kMax},
1190 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1191 /*strides=*/{2, 2},
1192 /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1},
1193 /*pad_low=*/{3, 3}, /*pad_high=*/{3, 3},
1194 /*layout=*/{1, 0},
1195 /*reducer=*/Reducer::kMax},
1196 {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1197 /*strides=*/{2, 2},
1198 /*base_dilation=*/{2, 2}, /*window_dilation=*/{2, 2},
1199 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1200 /*layout=*/{1, 0},
1201 /*reducer=*/Reducer::kMax},
1202 // Regression test for a bug that appeared in Inception (b/34784899).
1203 {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
1204 /*strides=*/{2, 2},
1205 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1206 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1207 /*layout=*/{1, 0},
1208 /*reducer=*/Reducer::kAdd},
1209 // Regression test for b/73903312: bf16 lacks precision to store result of
1210 // very large windows. Testing with a reasonable window larger than 128.
1211 {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130},
1212 /*strides=*/{1, 1},
1213 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1214 /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0},
1215 /*layout=*/{1, 0},
1216 /*reducer=*/Reducer::kAdd},
1217 {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4},
1218 /*strides=*/{1, 64},
1219 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1220 /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1221 /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1222 {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4},
1223 /*strides=*/{1, 1024},
1224 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1225 /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
1226 /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1227 // Regression test for b/72234705: bf16 lacks precision to store incremental
1228 // results on very large windows. Using smaller window with minor dim 128.
1229 {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128},
1230 /*strides=*/{1, 1},
1231 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1232 /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
1233 /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1234 // With negative padding on both ends.
1235 {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1236 /*strides=*/{1, 2},
1237 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1238 /*pad_low=*/{0, -1}, /*pad_high=*/{0, -1},
1239 /*layout=*/{0, 1},
1240 /*reducer=*/Reducer::kAdd},
1241 // With negative low padding and positive high padding.
1242 {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1243 /*strides=*/{1, 2},
1244 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1245 /*pad_low=*/{0, -1}, /*pad_high=*/{0, 1},
1246 /*layout=*/{0, 1},
1247 /*reducer=*/Reducer::kAdd},
1248 // With positive low padding and negative high padding.
1249 {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1250 /*strides=*/{1, 2},
1251 /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1252 /*pad_low=*/{0, 1}, /*pad_high=*/{0, -1},
1253 /*layout=*/{0, 1},
1254 /*reducer=*/Reducer::kAdd},
1255 };
1256
R2ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R2ReduceWindowTestData,bool>> & data)1257 std::string R2ReduceWindowTestDataToString(
1258 const ::testing::TestParamInfo<
1259 ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
1260 const auto& param = ::testing::get<0>(data.param);
1261 std::string str = absl::StrCat(
1262 "base_bounds_", absl::StrJoin(param.base_bounds, "x"), //
1263 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"), //
1264 "__strides_", absl::StrJoin(param.strides, "x"), //
1265 "__base_dilation_", absl::StrJoin(param.base_dilation, "x"), //
1266 "__window_dilation_", absl::StrJoin(param.window_dilation, "x"), //
1267 "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
1268 absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
1269 param.layout[1], //
1270 "__reducer_", param.reducer == kAdd ? "add" : "max");
1271
1272 // Test names are not allowed to contain the '-' character.
1273 std::replace(str.begin(), str.end(), '-', 'n');
1274 if (::testing::get<1>(data.param)) {
1275 absl::StrAppend(&str, "_bfloat16");
1276 }
1277 return str;
1278 }
1279
1280 class R2ReduceWindowTest : public ReduceWindowTestBase,
1281 public ::testing::WithParamInterface<
1282 ::testing::tuple<R2ReduceWindowTestData, bool>> {
1283 protected:
R2ReduceWindowTest()1284 R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1285
DoIt()1286 void DoIt() {
1287 XlaBuilder b(TestName());
1288 const auto& param = ::testing::get<0>(GetParam());
1289
1290 Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
1291 if (!::testing::get<1>(GetParam())) {
1292 // We only do this in F32 mode, to avoid precision issues with BF16.
1293 input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0],
1294 param.base_bounds[1]);
1295 }
1296 Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
1297 input, LayoutUtil::MakeLayout(param.layout));
1298
1299 XlaOp parameter;
1300 TF_ASSERT_OK(CreateParameterAndTransferLiteral(0, input_literal, "p0", &b,
1301 ¶meter)
1302 .status());
1303
1304 std::vector<std::pair<int64_t, int64_t>> padding(2);
1305 for (int i = 0; i < 2; ++i) {
1306 padding[i] = {param.pad_low[i], param.pad_high[i]};
1307 }
1308 auto computation = param.reducer == kAdd
1309 ? CreateScalarAddComputation(FloatType(), &b)
1310 : CreateScalarMaxComputation(FloatType(), &b);
1311 const float kInitValue = 0.0f;
1312 auto init_value =
1313 CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
1314 ReduceWindowWithGeneralPadding(
1315 /*operand=*/parameter,
1316 /*init_value=*/init_value,
1317 /*computation=*/computation,
1318 /*window_dimensions=*/param.window_bounds,
1319 /*window_strides=*/param.strides,
1320 /*base_dilations=*/param.base_dilation,
1321 /*window_dilations=*/param.window_dilation,
1322 /*padding=*/padding);
1323
1324 ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)},
1325 DefaultErrorSpec());
1326 }
1327 };
1328
// Runs the fixture's DoIt() for each instantiated parameter.
XLA_TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); }
1330
1331 INSTANTIATE_TEST_CASE_P(
1332 R2ReduceWindowTestInstantiation, R2ReduceWindowTest,
1333 ::testing::Combine(::testing::ValuesIn(kR2TestCases),
1334 ::testing::ValuesIn(use_bfloat16_params)),
1335 R2ReduceWindowTestDataToString);
1336
1337 struct R1ReduceWindowTestData {
1338 int64_t base_bounds[1];
1339 int64_t window_bounds[1];
1340 int64_t strides[1];
1341 int64_t pad_low[1];
1342 int64_t pad_high[1];
1343 Reducer reducer;
1344 } kR1TestCases[] = {
1345 {/*base_bounds=*/{1}, /*window_bounds=*/{1},
1346 /*strides=*/{1},
1347 /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
1348 /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
1349 /*reducer=*/Reducer::kAdd},
1350
1351 {/*base_bounds=*/{3}, /*window_bounds=*/{3},
1352 /*strides=*/{1},
1353 /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
1354 /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
1355 /*reducer=*/Reducer::kAdd},
1356
1357 {/*base_bounds=*/{3}, /*window_bounds=*/{2},
1358 /*strides=*/{1},
1359 /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
1360 /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
1361 /*reducer=*/Reducer::kAdd},
1362
1363 {/*base_bounds=*/{5}, /*window_bounds=*/{1},
1364 /*strides=*/{1},
1365 /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
1366 /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
1367 /*reducer=*/Reducer::kMax},
1368
1369 {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1370 /*strides=*/{4},
1371 /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
1372 /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
1373 /*reducer=*/Reducer::kMax},
1374
1375 {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1376 /*strides=*/{3},
1377 /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
1378 /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
1379 /*reducer=*/Reducer::kAdd},
1380
1381 {/*base_bounds=*/{128 * 2},
1382 /*window_bounds=*/{30},
1383 /*strides=*/{27},
1384 /*pad_low=*/
1385 {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
1386 /*pad_high=*/
1387 {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
1388 /*reducer=*/Reducer::kAdd},
1389
1390 {/*base_bounds=*/{128 * 17},
1391 /*window_bounds=*/{7},
1392 /*strides=*/{64},
1393 /*pad_low=*/
1394 {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
1395 /*pad_high=*/
1396 {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
1397 /*reducer=*/Reducer::kAdd},
1398
1399 {/*base_bounds=*/{128 * 2},
1400 /*window_bounds=*/{32},
1401 /*strides=*/{56},
1402 /*pad_low=*/
1403 {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
1404 /*pad_high=*/
1405 {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
1406 /*reducer=*/Reducer::kAdd},
1407
1408 {/*base_bounds=*/{3}, /*window_bounds=*/{2},
1409 /*strides=*/{1},
1410 /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
1411 /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
1412 /*reducer=*/Reducer::kAdd},
1413
1414 {/*base_bounds=*/{5}, /*window_bounds=*/{3},
1415 /*strides=*/{2},
1416 /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
1417 /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
1418 /*reducer=*/Reducer::kAdd},
1419
1420 {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1421 /*strides=*/{3},
1422 /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
1423 /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
1424 /*reducer=*/Reducer::kAdd},
1425
1426 {/*base_bounds=*/{5}, /*window_bounds=*/{5},
1427 /*strides=*/{1},
1428 /*pad_low=*/{0},
1429 /*pad_high=*/{5},
1430 /*reducer=*/Reducer::kAdd},
1431
1432 {/*base_bounds=*/{5}, /*window_bounds=*/{5},
1433 /*strides=*/{1},
1434 /*pad_low=*/{5},
1435 /*pad_high=*/{0},
1436 /*reducer=*/Reducer::kAdd},
1437
1438 // With negative padding on both ends.
1439 {/*base_bounds=*/{15}, /*window_bounds=*/{5},
1440 /*strides=*/{1},
1441 /*pad_low=*/{-5},
1442 /*pad_high=*/{-5},
1443 /*reducer=*/Reducer::kAdd},
1444
1445 // With negative low padding and positive high padding.
1446 {/*base_bounds=*/{15}, /*window_bounds=*/{5},
1447 /*strides=*/{1},
1448 /*pad_low=*/{-5},
1449 /*pad_high=*/{5},
1450 /*reducer=*/Reducer::kAdd},
1451
1452 // With positive low padding and negative high padding.
1453 {/*base_bounds=*/{15}, /*window_bounds=*/{5},
1454 /*strides=*/{1},
1455 /*pad_low=*/{5},
1456 /*pad_high=*/{-5},
1457 /*reducer=*/Reducer::kAdd},
1458
1459 // The pattern generated by inclusive scan (cumsum/cumprod).
1460 {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
1461 /*strides=*/{1},
1462 /*pad_low=*/{4095},
1463 /*pad_high=*/{0},
1464 /*reducer=*/Reducer::kMax},
1465
1466 // The pattern generated by exclusive scan (cumsum/cumprod).
1467 {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
1468 /*strides=*/{1},
1469 /*pad_low=*/{4095},
1470 /*pad_high=*/{0},
1471 /*reducer=*/Reducer::kMax},
1472
1473 // The pattern generated by inclusive reverse scan (cumsum/cumprod).
1474 {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
1475 /*strides=*/{1},
1476 /*pad_low=*/{0},
1477 /*pad_high=*/{4095},
1478 /*reducer=*/Reducer::kMax},
1479
1480 // The pattern generated by exclusive reverse scan (cumsum/cumprod).
1481 {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
1482 /*strides=*/{1},
1483 /*pad_low=*/{0},
1484 /*pad_high=*/{4095},
1485 /*reducer=*/Reducer::kMax},
1486 };
1487
R1ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R1ReduceWindowTestData,bool>> & data)1488 std::string R1ReduceWindowTestDataToString(
1489 const ::testing::TestParamInfo<
1490 ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
1491 const auto& param = ::testing::get<0>(data.param);
1492 std::string str =
1493 absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
1494 "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
1495 "__strides_", absl::StrJoin(param.strides, "x"),
1496 "__pad_low_", absl::StrJoin(param.pad_low, "x"),
1497 "__pad_high_", absl::StrJoin(param.pad_high, "x"),
1498 "__reducer_", param.reducer == kAdd ? "add" : "max");
1499
1500 // Test names are not allowed to contain the '-' character.
1501 std::replace(str.begin(), str.end(), '-', 'n');
1502 if (::testing::get<1>(data.param)) {
1503 absl::StrAppend(&str, "_bfloat16");
1504 }
1505 return str;
1506 }
1507
1508 class R1ReduceWindowTest : public ReduceWindowTestBase,
1509 public ::testing::WithParamInterface<
1510 ::testing::tuple<R1ReduceWindowTestData, bool>> {
1511 protected:
R1ReduceWindowTest()1512 R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1513 };
1514
// Builds a ReduceWindowWithGeneralPadding over an iota vector and compares
// the result with ReferenceUtil::ReduceWindow1DGeneric run on the same input.
XLA_TEST_P(R1ReduceWindowTest, DoIt) {
  XlaBuilder b(TestName());
  const auto& param = ::testing::get<0>(GetParam());
  CHECK(param.reducer == kAdd || param.reducer == kMax);

  const float kInitValue = 0.0f;
  // Input is 0, 1, 2, ... up to base_bounds[0] - 1.
  std::vector<float> input_vector(param.base_bounds[0]);
  std::iota(std::begin(input_vector), std::end(input_vector), 0);
  Literal input_literal =
      LiteralUtil::CreateR1(absl::Span<const float>(input_vector));
  XlaOp parameter;
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_arg, CreateParameterAndTransferLiteral(0, input_literal, "p0",
                                                        &b, &parameter));

  std::vector<std::pair<int64_t, int64_t>> padding(1);
  padding[0] = {param.pad_low[0], param.pad_high[0]};

  auto computation = param.reducer == kAdd
                         ? CreateScalarAddComputation(FloatType(), &b)
                         : CreateScalarMaxComputation(FloatType(), &b);
  auto init_value =
      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
  ReduceWindowWithGeneralPadding(
      /*operand=*/parameter,
      /*init_value=*/init_value,
      /*computation=*/computation,
      /*window_dimensions=*/param.window_bounds,
      /*window_strides=*/param.strides,
      /*base_dilations=*/{},
      /*window_dilations=*/{},
      /*padding=*/padding);

  // The unary '+' converts each lambda to a plain function pointer so both
  // branches of the conditional have the same type.
  auto reduce_func = param.reducer == kAdd
                         ? +[](float a, float b) { return a + b; }
                         : +[](float a, float b) { return std::max(a, b); };
  auto expected = ReferenceUtil::ReduceWindow1DGeneric(
      /*operand=*/absl::Span<const float>(input_vector),
      /*init=*/kInitValue,
      /*reduce_func=*/reduce_func,
      /*window=*/param.window_bounds,
      /*stride=*/param.strides,
      /*padding=*/padding);

  ComputeAndCompareLiteral(&b, LiteralUtil::CreateR1<float>(*expected),
                           {input_arg.get()}, DefaultErrorSpec());
}
1562
1563 INSTANTIATE_TEST_CASE_P(
1564 R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
1565 ::testing::Combine(::testing::ValuesIn(kR1TestCases),
1566 ::testing::ValuesIn(use_bfloat16_params)),
1567 R1ReduceWindowTestDataToString);
1568
1569 // Test class for text-based test cases. Note that this compares with the
1570 // results on the interpreter backend.
1571 class ReduceWindowTextTest : public HloTestBase {};
1572
// Multiplicative reduce-window over f32[256,384]{1,0} with a 2x3 window and
// padding 0_1x1_1, checked by RunAndCompare with ErrorSpec{0.001}.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384) {
  const std::string hlo_string = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[256,384]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
1589
// Same computation as R2General256x384, but the operand and result use the
// {0,1} layout.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
  const std::string hlo_string = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[256,384]{0,1} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
1606
// Multiplicative reduce-window over f32[2,5] with a 2x1 window and padding
// 0_2x0_0, producing f32[3,5]; checked by RunAndCompare with ErrorSpec{0.001}.
XLA_TEST_F(ReduceWindowTextTest, R2General2x5) {
  const std::string hlo_string = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[2,5]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
}
)";
  EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
}
1623
// Single-element rank-2 case: an f32[1,1] operand is negated and then fed to a
// multiplicative reduce-window with a 1x1 window and no padding.
XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
  static constexpr char kHloText[] = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[1,1]{1,0} parameter(0)
  negate = f32[1,1]{1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1641
// Single-element rank-3 case: an f32[1,1,1] operand is negated and then fed to
// a multiplicative reduce-window with a 1x1x1 window and no padding.
XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
  static constexpr char kHloText[] = R"(
HloModule R3Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R3Window {
  operand = f32[1,1,1]{2,1,0} parameter(0)
  negate = f32[1,1,1]{2,1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1,1]{2,1,0}
    reduce-window(negate, constant),
    window={size=1x1x1 pad=0_0x0_0x0_0},
    to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1662
// Reduce-window whose reduction computation simply returns its second
// parameter. The 1x1x1 window is combined with low padding of one on
// dimension 1, taking f32[1,32,64] to f32[1,33,64]. No ErrorSpec is supplied
// to RunAndCompare.
XLA_TEST_F(HloTestBase, ReduceWindowIdentity) {
  static constexpr char kHloText[] = R"(
HloModule ReduceWindowIdentity
identity.pad_to_reduce_window {
  param0 = f32[] parameter(0)
  ROOT param1 = f32[] parameter(1)
}
ENTRY reduce-window-identity {
  operand = f32[1,32,64]{2,1,0} parameter(0)
  constant.4466 = f32[] constant(0)
  ROOT reduce-window = f32[1,33,64]{2,1,0}
    reduce-window(operand, constant.4466),
    window={size=1x1x1 pad=0_0x1_0x0_0},
    to_apply=identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1682
// Variant of ReduceWindowIdentity with all padding set to zero, so the
// f32[1,32,64] operand and the result have the same declared shape. No
// ErrorSpec is supplied to RunAndCompare.
XLA_TEST_F(HloTestBase, ReduceWindowIdentityNoPadding) {
  static constexpr char kHloText[] = R"(
HloModule ReduceWindowIdentity
identity.pad_to_reduce_window {
  param0 = f32[] parameter(0)
  ROOT param1 = f32[] parameter(1)
}
ENTRY reduce-window-identity {
  operand = f32[1,32,64]{2,1,0} parameter(0)
  constant.4466 = f32[] constant(0)
  ROOT reduce-window = f32[1,32,64]{2,1,0}
    reduce-window(operand, constant.4466),
    window={size=1x1x1 pad=0_0x0_0x0_0},
    to_apply=identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1702
// s32 reduce-window with a reduction computation that returns its second
// parameter. A 1x1 window with high padding of one on dimension 0 maps
// s32[81,8] to s32[82,8]; the init value comes from the second entry
// parameter. No ErrorSpec is supplied to RunAndCompare.
XLA_TEST_F(HloTestBase, ReduceWindowS32) {
  static constexpr char kHloText[] = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
  %param0 = s32[] parameter(0)
  ROOT %param1 = s32[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
  %parameter.0 = s32[81,8]{1,0} parameter(0)
  %parameter.1 = s32[] parameter(1)
  ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1721
// s64 counterpart of ReduceWindowS32: the same 1x1 window with high padding of
// one on dimension 0, mapping s64[81,8] to s64[82,8]. No ErrorSpec is supplied
// to RunAndCompare.
XLA_TEST_F(HloTestBase, ReduceWindowS64) {
  static constexpr char kHloText[] = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] {
  %param0 = s64[] parameter(0)
  ROOT %param1 = s64[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] {
  %parameter.0 = s64[81,8]{1,0} parameter(0)
  %parameter.1 = s64[] parameter(1)
  ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1740
// f16 counterpart of ReduceWindowS32: the same 1x1 window with high padding of
// one on dimension 0, mapping f16[81,8] to f16[82,8]. No ErrorSpec is supplied
// to RunAndCompare.
XLA_TEST_F(HloTestBase, ReduceWindowF16) {
  static constexpr char kHloText[] = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
  %param0 = f16[] parameter(0)
  ROOT %param1 = f16[] parameter(1)
}

ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
  %parameter.0 = f16[81,8]{1,0} parameter(0)
  %parameter.1 = f16[] parameter(1)
  ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1759
// Rank-4 multiplicative reduce-window whose only non-trivial window attribute
// is lhs_dilate=2x2x2x2 (window 1x1x1x1, zero padding); it maps f32[2,2,2,2]
// to f32[3,3,3,3].
XLA_TEST_F(ReduceWindowTextTest, R4OnlyDilation) {
  static constexpr char kHloText[] = R"(
HloModule R4OnlyDilation
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R4OnlyDilation {
  operand = f32[2,2,2,2]{3,2,1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[3,3,3,3]{3,2,1,0}
    reduce-window(operand, constant),
    window={size=1x1x1x1 pad=0_0x0_0x0_0x0_0 lhs_dilate=2x2x2x2},
    to_apply=mul
}
)";
  const ErrorSpec tolerance{0.001};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1779
// Variadic reduce-window over two f32[4,2] constant operands with two f32 init
// values. The reduction adds the paired elements and returns a tuple; the
// window is 5x1 with stride 3x1 and padding 2_2x0_0. The entry root is a copy
// of the tuple result. The test name is wrapped in DISABLED_ON_GPU.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport)) {
  static constexpr char kHloText[] = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = f32[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = f32[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = f32[] add(a1, b1)
  ROOT sum2 = (f32[], f32[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = f32[] constant(0)
  reduce-window = (f32[2,2]{1,0}, f32[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
  ROOT copy = (f32[2,2]{1,0}, f32[2,2]{1,0}) copy(reduce-window)
})";
  const ErrorSpec tolerance{1e-4, 1e-4};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1806
// Variadic reduce-window with mixed element types: one f32[4,2] operand and
// one s32[4,2] operand, each summed with its own init value of 0. The window
// is 5x1 with stride 3x1 and padding 2_2x0_0. The test name is wrapped in
// DISABLED_ON_GPU.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport2)) {
  static constexpr char kHloText[] = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = s32[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = s32[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = s32[] add(a1, b1)
  ROOT sum2 = (f32[], s32[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = s32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = s32[] constant(0)
  ROOT reduce-window = (f32[2,2]{1,0}, s32[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
  const ErrorSpec tolerance{1e-4, 1e-4};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1832
// Variadic reduce-window with one f32[4,2] operand and one bf16[4,2] operand,
// each summed with its own init value of 0. The window is 5x1 with stride 3x1
// and padding 2_2x0_0. The test name is wrapped in DISABLED_ON_GPU.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport3)) {
  static constexpr char kHloText[] = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = bf16[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = bf16[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = bf16[] add(a1, b1)
  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = bf16[] constant(0)
  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
  const ErrorSpec tolerance{1e-4, 1e-4};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1858
// Variadic reduce-window that applies a different operation to each operand:
// the f32[4,2] operand is summed (init 0) while the bf16[4,2] operand is
// multiplied (init 1). The window is 5x1 with stride 3x1 and padding
// 2_2x0_0. The test name is wrapped in DISABLED_ON_GPU.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport4)) {
  static constexpr char kHloText[] = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = bf16[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = bf16[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = bf16[] multiply(a1, b1)
  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = bf16[] constant(1)
  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
  const ErrorSpec tolerance{1e-4, 1e-4};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1884
// Additive reduce-window over an s64[8,10,12] parameter using a 3x2x1 window,
// padding 1_1x1_1x0_0 and rhs_dilate=1x2x2; the result is returned wrapped in
// a one-element tuple. The test name is wrapped in DISABLED_ON_GPU.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowS64Support)) {
  static constexpr char kHloText[] = R"(
HloModule jit_dilated_window_sum.10

%primitive_computation_add.4 (parameter.5: s64[], parameter.6: s64[]) -> s64[] {
  %parameter.5 = s64[] parameter(0), metadata={op_type="add" op_name="add"}
  %parameter.6 = s64[] parameter(1), metadata={op_type="add" op_name="add"}
  ROOT %add.7 = s64[] add(s64[] %parameter.5, s64[] %parameter.6), metadata={op_type="add" op_name="add"}
}

ENTRY %jit_dilated_window_sum.10 (parameter.1: s64[8,10,12]) -> (s64[8,10,12]) {
  %constant.2 = pred[] constant(false)
  %parameter.1 = s64[8,10,12]{2,1,0} parameter(0)
  %constant.3 = s64[] constant(0), metadata={op_type="reduce_window_sum" op_name="jit(dilated_window_sum)/reduce_window_sum[ base_dilation=(1, 1, 1)\n padding=((1, 1), (1, 1), (0, 0))\n window_dilation=(1, 2, 2)\n window_dimensions=(3, 2, 1)\n window_strides=(1, 1, 1) ]" source_file="<ipython-input-1-315291761729>" source_line=9}
  %reduce-window.8 = s64[8,10,12]{2,1,0} reduce-window(s64[8,10,12]{2,1,0} %parameter.1, s64[] %constant.3), window={size=3x2x1 pad=1_1x1_1x0_0 rhs_dilate=1x2x2}, to_apply=%primitive_computation_add.4, metadata={op_type="reduce_window_sum" op_name="jit(dilated_window_sum)/reduce_window_sum[ base_dilation=(1, 1, 1)\n padding=((1, 1), (1, 1), (0, 0))\n window_dilation=(1, 2, 2)\n window_dimensions=(3, 2, 1)\n window_strides=(1, 1, 1) ]" source_file="<ipython-input-1-315291761729>" source_line=9}
  ROOT %tuple.9 = (s64[8,10,12]{2,1,0}) tuple(s64[8,10,12]{2,1,0} %reduce-window.8)
})";
  const ErrorSpec tolerance{1e-4, 1e-4};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1904
// Runs two multiplicative s64 reduce-windows (window 3x1, padding 2_0x0_0)
// over the same pred[3,3] parameter: one converted directly to s64 and one
// converted via s32. The graph then takes the absolute difference of the two
// results, max-reduces it, converts it to f32 and compares it (LE) against the
// scalar f32 parameter. The test name is wrapped in DISABLED_ON_GPU.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowS64Support2)) {
  static constexpr char kHloText[] = R"(
HloModule SyncTensorsGraph.43

%MulComputation.18 (x.19: s64[], y.20: s64[]) -> s64[] {
  %x.19 = s64[] parameter(0)
  %y.20 = s64[] parameter(1)
  ROOT %multiply.21 = s64[] multiply(s64[] %x.19, s64[] %y.20)
}

%MulComputation.26 (x.27: s64[], y.28: s64[]) -> s64[] {
  %x.27 = s64[] parameter(0)
  %y.28 = s64[] parameter(1)
  ROOT %multiply.29 = s64[] multiply(s64[] %x.27, s64[] %y.28)
}

%MaxComputation.34 (x.35: s64[], y.36: s64[]) -> s64[] {
  %x.35 = s64[] parameter(0)
  %y.36 = s64[] parameter(1)
  ROOT %maximum.37 = s64[] maximum(s64[] %x.35, s64[] %y.36)
}

ENTRY %SyncTensorsGraph.43 (p0.1: f32[], p1.7: pred[3,3]) -> (pred[]) {
  %constant.8 = pred[] constant(false)
  %reshape.9 = pred[1,1]{1,0} reshape(pred[] %constant.8)
  %broadcast.10 = pred[1,1]{1,0} broadcast(pred[1,1]{1,0} %reshape.9), dimensions={0,1}
  %reshape.11 = pred[] reshape(pred[1,1]{1,0} %broadcast.10)
  %broadcast.12 = pred[3,3]{1,0} broadcast(pred[] %reshape.11), dimensions={}
  %p1.7 = pred[3,3]{1,0} parameter(1)
  %reshape.13 = pred[3,3]{1,0} reshape(pred[3,3]{1,0} %p1.7)
  %reshape.14 = pred[3,3]{1,0} reshape(pred[3,3]{1,0} %reshape.13)
  %convert.24 = s64[3,3]{1,0} convert(pred[3,3]{1,0} %reshape.14)
  %constant.25 = s64[] constant(1)
  %reduce-window.30 = s64[3,3]{1,0} reduce-window(s64[3,3]{1,0} %convert.24, s64[] %constant.25), window={size=3x1 pad=2_0x0_0}, to_apply=%MulComputation.26
  %convert.15 = s32[3,3]{1,0} convert(pred[3,3]{1,0} %reshape.14)
  %convert.16 = s64[3,3]{1,0} convert(s32[3,3]{1,0} %convert.15)
  %constant.17 = s64[] constant(1)
  %reduce-window.22 = s64[3,3]{1,0} reduce-window(s64[3,3]{1,0} %convert.16, s64[] %constant.17), window={size=3x1 pad=2_0x0_0}, to_apply=%MulComputation.18
  %constant.2 = s64[] constant(1)
  %reshape.3 = s64[1,1]{1,0} reshape(s64[] %constant.2)
  %broadcast.4 = s64[1,1]{1,0} broadcast(s64[1,1]{1,0} %reshape.3), dimensions={0,1}
  %reshape.5 = s64[] reshape(s64[1,1]{1,0} %broadcast.4)
  %broadcast.6 = s64[3,3]{1,0} broadcast(s64[] %reshape.5), dimensions={}
  %multiply.23 = s64[3,3]{1,0} multiply(s64[3,3]{1,0} %reduce-window.22, s64[3,3]{1,0} %broadcast.6)
  %subtract.31 = s64[3,3]{1,0} subtract(s64[3,3]{1,0} %reduce-window.30, s64[3,3]{1,0} %multiply.23)
  %abs.32 = s64[3,3]{1,0} abs(s64[3,3]{1,0} %subtract.31)
  %constant.33 = s64[] constant(-9223372036854775808)
  %reduce.38 = s64[] reduce(s64[3,3]{1,0} %abs.32, s64[] %constant.33), dimensions={0,1}, to_apply=%MaxComputation.34
  %reshape.39 = s64[] reshape(s64[] %reduce.38)
  %convert.40 = f32[] convert(s64[] %reshape.39)
  %p0.1 = f32[] parameter(0)
  %compare.41 = pred[] compare(f32[] %convert.40, f32[] %p0.1), direction=LE
  ROOT %tuple.42 = (pred[]) tuple(pred[] %compare.41)
})";
  const ErrorSpec tolerance{1e-4, 1e-4};
  EXPECT_TRUE(RunAndCompare(kHloText, tolerance));
}
1961
1962 } // namespace
1963 } // namespace xla
1964