xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/reduce_window_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Tests the reduce-window XLA operation.
17 
#include <algorithm>
#include <array>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
42 
43 namespace xla {
44 namespace {
45 
// Values of the "use bfloat16" test parameter. Without backend BF16 support
// (XLA_BACKEND_SUPPORTS_BFLOAT16 undefined) only F32 (false) is tested;
// otherwise both F32 (false) and BF16 (true) are tested.
#ifndef XLA_BACKEND_SUPPORTS_BFLOAT16
static std::array<bool, 1> use_bfloat16_params{false};
#else
static std::array<bool, 2> use_bfloat16_params{false, true};
#endif
53 
54 class ReduceWindowTestBase : public ClientLibraryTestBase {
55  public:
DefaultErrorSpec() const56   ErrorSpec DefaultErrorSpec() const {
57     if (use_bfloat16()) {
58       return ErrorSpec(2e-1, 6e-2);
59     } else {
60       return ErrorSpec(1e-3, 1e-3);
61     }
62   }
63 };
64 
65 class ReduceWindowTest : public ::testing::WithParamInterface<bool>,
66                          public ReduceWindowTestBase {
67  public:
ReduceWindowTest()68   ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); }
69 
ReduceWindowAdd(const XlaOp input,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)70   void ReduceWindowAdd(const XlaOp input,
71                        absl::Span<const int64_t> window_dimensions,
72                        absl::Span<const int64_t> window_strides,
73                        Padding padding) {
74     auto init = CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f),
75                                           &builder_);
76     ReduceWindow(input, init,
77                  CreateScalarAddComputation(FloatType(), &builder_),
78                  window_dimensions, window_strides, padding);
79   }
80 
ReduceWindowMax(const XlaOp input,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)81   void ReduceWindowMax(const XlaOp input,
82                        absl::Span<const int64_t> window_dimensions,
83                        absl::Span<const int64_t> window_strides,
84                        Padding padding) {
85     auto init =
86         CreateConstantFromLiteral(LiteralUtil::MinValue(F32), &builder_);
87     ReduceWindow(input, init,
88                  CreateScalarMaxComputation(FloatType(), &builder_),
89                  window_dimensions, window_strides, padding);
90   }
91 
ReduceWindowMin(const XlaOp input,absl::Span<const int64_t> window_dimensions,absl::Span<const int64_t> window_strides,Padding padding)92   void ReduceWindowMin(const XlaOp input,
93                        absl::Span<const int64_t> window_dimensions,
94                        absl::Span<const int64_t> window_strides,
95                        Padding padding) {
96     auto init =
97         CreateConstantFromLiteral(LiteralUtil::MaxValue(F32), &builder_);
98     ReduceWindow(input, init,
99                  CreateScalarMinComputation(FloatType(), &builder_),
100                  window_dimensions, window_strides, padding);
101   }
102 
103   XlaBuilder builder_;
104 };
105 
// A window whose rank differs from the operand's rank must leave an
// INVALID_ARGUMENT error on the builder.
XLA_TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) {
  const auto operand = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1}), &builder_);
  const auto zero =
      CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0), &builder_);
  // Both constants are valid, so no error has been recorded yet.
  TF_ASSERT_OK(builder_.first_error());

  // Rank-1 operand, rank-2 window.
  const auto add = CreateScalarAddComputation(FloatType(), &builder_);
  ReduceWindow(operand, zero, add,
               /*window_dimensions=*/{1, 2},
               /*window_strides=*/{1}, Padding::kValid);

  const auto status = builder_.first_error();
  ASSERT_EQ(status.code(), tensorflow::error::INVALID_ARGUMENT) << status;
  ASSERT_THAT(status.error_message(),
              ::testing::HasSubstr("Want input dimensions size"));
}
121 
122 // Regression test for b/68964348.
XLA_TEST_P(ReduceWindowTest,R0ReduceWindow)123 XLA_TEST_P(ReduceWindowTest, R0ReduceWindow) {
124   const auto input =
125       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(42.0), &builder_);
126   const auto init =
127       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0), &builder_);
128   ReduceWindow(input, init, CreateScalarAddComputation(FloatType(), &builder_),
129                /*window_dimensions=*/{},
130                /*window_strides=*/{}, Padding::kSame);
131   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR0<float>(42.0), {},
132                            ErrorSpec(0.00001));
133 }
134 
// A width-3, stride-2 min window over five values with VALID padding forms
// two windows: {10000, 1000, 100} -> 100 and {100, 10, 1} -> 1.
XLA_TEST_P(ReduceWindowTest, Min3In5Stride2) {
  const Literal values =
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1});
  ReduceWindowMin(CreateConstantFromLiteral(values, &builder_),
                  /*window_dimensions=*/{3}, /*window_strides=*/{2},
                  Padding::kValid);

  const Literal expected = LiteralUtil::CreateR1<float>({100, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
142 
// Same as Min3In5Stride2 but with SAME padding. This gives three windows,
// padded by one element on each side:
// {pad, 10000, 1000} -> 1000, {1000, 100, 10} -> 10, {10, 1, pad} -> 1.
XLA_TEST_P(ReduceWindowTest, Min3In5Stride2Same) {
  const Literal values =
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1});
  ReduceWindowMin(CreateConstantFromLiteral(values, &builder_),
                  /*window_dimensions=*/{3}, /*window_strides=*/{2},
                  Padding::kSame);

  const Literal expected = LiteralUtil::CreateR1<float>({1000, 10, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
151 
// Width-3, stride-1 min window with SAME padding. The output has one element
// per input element: the minimum of each element and its neighbours.
XLA_TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) {
  const Literal values =
      LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1});
  ReduceWindowMin(CreateConstantFromLiteral(values, &builder_),
                  /*window_dimensions=*/{3}, /*window_strides=*/{1},
                  Padding::kSame);

  const Literal expected =
      LiteralUtil::CreateR1<float>({1000, 100, 10, 1, 1});
  ComputeAndCompareLiteral(&builder_, expected, {}, ErrorSpec(0.00001));
}
161 
// Sum reduce-window over an operand with a zero-sized dimension (1x0x2x1),
// compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) {
  Array4D<float> input_array(1, 0, 2, 1);
  const auto input = CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64_t> window = {1, 1, 2, 1};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
174 
// Sums 1x1x2x1 windows (SAME padding, unit strides) over a random 1x2x2x1
// operand, compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, NonSquareSmall) {
  Array4D<float> input_array(1, 2, 2, 1);
  input_array.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64_t> window = {1, 1, 2, 1};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
189 
// Unit windows with stride 2 on the two middle dimensions (SAME padding) over
// a random 1x3x3x1 operand, compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, MiddleDimsSmall) {
  Array4D<float> input_array(1, 3, 3, 1);
  input_array.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64_t> window = {1, 1, 1, 1};
  const std::vector<int64_t> strides = {1, 2, 2, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
203 
// Sums along dimension 2 only, with a 7-wide window (SAME padding, unit
// strides), compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, Along2ndMinorDim) {
  Array4D<float> input_array(3, 6, 7, 32);
  input_array.FillRandom(2.f, 2.f);
  const auto input = CreateConstantFromArray(input_array, &builder_);

  // The parameters of this reduction mimic feature norm (e.g. LRN), where
  // diameter = 2 * radius + 1, so it must be odd.
  const int lrn_window_size = 7;
  const std::vector<int64_t> window = {1, 1, lrn_window_size, 1};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
220 
// Sums 3x3 windows over the two major dimensions only (stride 1, SAME
// padding), compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsAdd) {
  Array4D<float> input_array(4, 4, 6, 8);
  input_array.FillWithMinorDimNum();
  const auto input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const int64_t win_size = 3;
  const int64_t step = 1;
  // Reduce only along the x and y dimensions.
  const std::vector<int64_t> window = {win_size, win_size, 1, 1};
  const std::vector<int64_t> strides = {step, step, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
242 
// Max-reduces 2x2 windows over the two major dimensions (stride 1, VALID
// padding). There is no explicit expected literal; the result is checked
// with ComputeAndCompare.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMax) {
  Array4D<float> input_array(3, 3, 2, 1);
  input_array.FillWithMinorDimNum();
  const auto input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const int64_t win_size = 2;
  const int64_t step = 1;
  // Reduce only along the x and y dimensions.
  ReduceWindowMax(input_data_handle, {win_size, win_size, 1, 1},
                  {step, step, 1, 1}, Padding::kValid);
  ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
}
256 
// Sums 3x3 windows with stride 2 over the two major dimensions of a
// 9x12x4x89 random operand (SAME padding), compared with
// ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) {
  Array4D<float> input_array(9, 12, 4, 89);
  input_array.FillRandom(2.f, 2.f);

  const int64_t win_size = 3;
  const int64_t step = 2;
  const std::vector<int64_t> window = {win_size, win_size, 1, 1};
  const std::vector<int64_t> strides = {step, step, 1, 1};
  const Padding padding = Padding::kSame;

  const auto input_data_handle =
      CreateConstantFromArray(input_array, &builder_);
  // Reduce only along the x and y dimensions.
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
279 
280 // Tests the super windowing logic w.r.t handling prime number of windows in a
281 // major dimension with reduction.
XLA_TEST_P(ReduceWindowTest,PrimeWindowsInReductionDimension)282 XLA_TEST_P(ReduceWindowTest, PrimeWindowsInReductionDimension) {
283   Array4D<float> input_array(15, 15, 4, 128);
284   input_array.FillRandom(2.f, 4.f);
285 
286   int win_len = 3;
287   int win_stride = 2;
288 
289   const auto input_data_handle =
290       CreateConstantFromArray(input_array, &builder_);
291 
292   Padding padding = Padding::kSame;
293   // Reduce only along the x and y dimensions, according to the win_len.
294   ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
295                   {win_stride, win_stride, 1, 1}, padding);
296 
297   auto result = ReferenceUtil::ReduceWindow4DAdd(
298       input_array, 0.0f, {win_len, win_len, 1, 1},
299       {win_stride, win_stride, 1, 1}, padding);
300 
301   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*result), {},
302                            DefaultErrorSpec());
303 }
304 
// Sums 11-wide windows along the most-minor dimension only (unit strides,
// SAME padding), compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, ReduceAlongLaneDimension) {
  Array4D<float> input_array(19, 17, 8, 256);
  input_array.FillWithMinorDimNum();
  const auto input_data_handle =
      CreateConstantFromArray(input_array, &builder_);

  const std::vector<int64_t> window = {1, 1, 1, 11};
  const std::vector<int64_t> strides = {1, 1, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
321 
322 // Tests a reduction function that is not a simple add/min/max/etc.
XLA_TEST_P(ReduceWindowTest,NonstandardReduceFunction)323 XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) {
324   Array4D<float> input_array(1, 2, 2, 1);
325   input_array(0, 0, 0, 0) = 1;
326   input_array(0, 0, 1, 0) = 2;
327   input_array(0, 1, 0, 0) = 3;
328   input_array(0, 1, 1, 0) = 4;
329   const auto input = CreateConstantFromArray(input_array, &builder_);
330 
331   Padding padding = Padding::kValid;
332   const Shape scalar = ShapeUtil::MakeShape(FloatType(), {});
333   auto b = builder_.CreateSubBuilder("unusual");
334   auto lhs = Parameter(b.get(), 0, scalar, "lhs");
335   auto rhs = Parameter(b.get(), 1, scalar, "rhs");
336   Min(Add(lhs, rhs),
337       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(8.0f), b.get()));
338   XlaComputation reduce_fn = b->BuildAndNoteError();
339 
340   ReduceWindow(
341       input,
342       CreateConstantFromLiteral(LiteralUtil::CreateR0<float>(0.0f), &builder_),
343       reduce_fn,
344       /*window_dimensions=*/{1, 1, 2, 1},
345       /*window_strides=*/{1, 1, 1, 1}, padding);
346 
347   const auto reduce_func = [](float arg1, float arg2) {
348     return std::min<float>(arg1 + arg2, 8.0f);
349   };
350 
351   auto expected =
352       ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func,
353                                            /*window=*/{1, 1, 2, 1},
354                                            /*stride=*/{1, 1, 1, 1}, padding);
355 
356   ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
357                            {}, DefaultErrorSpec());
358 }
359 
// Sum reduce-window over a parameter whose literal is laid out with
// LayoutUtil::MakeLayout({0, 3, 2, 1}). The window is 7 wide on dimension 2,
// the stride is 4 on dimension 1, and padding is SAME. The result is
// compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, R4UnitWindow) {
  Array4D<float> input_array(13, 12, 8, 15);
  input_array.FillRandom(2.f, 2.f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({0, 3, 2, 1}));

  XlaOp input;
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_data, CreateParameterAndTransferLiteral(
                           0, input_literal, "parameter", &builder_, &input));

  const std::vector<int64_t> window = {1, 1, 7, 1};
  const std::vector<int64_t> strides = {1, 4, 1, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
379 
// Sums {3, 1, 3, 3, 1, 1} windows over a rank-6 all-ones operand (every
// dimension is 8) with VALID padding, so each output element is 3 * 3 * 3 =
// 27. The expected literal is built with layout {1, 5, 3, 2, 0, 4}.
// NOTE(review): despite the test name, every window stride here is 1.
XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
  const std::vector<int64_t> input_dims(6, 8);
  const Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);

  Literal ones(input_shape);
  ones.PopulateWithValue(1.0f);
  const auto input = CreateConstantFromLiteral(ones, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{3, 1, 3, 3, 1, 1},
                  /*window_strides=*/{1, 1, 1, 1, 1, 1}, Padding::kValid);

  const std::vector<int64_t> result_dims = {6, 8, 6, 6, 8, 8};
  const std::vector<int64_t> result_layout = {1, 5, 3, 2, 0, 4};
  Literal expected(
      ShapeUtil::MakeShapeWithLayout(F32, result_dims, result_layout));
  expected.PopulateWithValue(27.0f);
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
399 
// Sums 3x3 windows over dimensions 2 and 3 of a rank-6 all-ones operand
// (every dimension is 8), with VALID padding and unit strides. Every output
// element is therefore 9, and the two windowed dimensions shrink to 6.
XLA_TEST_P(ReduceWindowTest, R6Add) {
  const std::vector<int64_t> input_dims(6, 8);
  const Literal arg_literal =
      LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
  const auto input = CreateConstantFromLiteral(arg_literal, &builder_);

  ReduceWindowAdd(input, /*window_dimensions=*/{1, 1, 3, 3, 1, 1},
                  /*window_strides=*/{1, 1, 1, 1, 1, 1}, Padding::kValid);

  const std::vector<int64_t> output_dims = {8, 8, 6, 6, 8, 8};
  const Literal expected =
      LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 9.0f);
  ComputeAndCompareLiteral(&builder_, expected, {}, DefaultErrorSpec());
}
418 
// Unit-width window with stride 8 on dimension 2 (SAME padding) over a
// parameter laid out with LayoutUtil::MakeLayout({3, 2, 1, 0}). The result
// is compared with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) {
  Array4D<float> input_array(2, 1, 27, 119);
  input_array.FillRandom(2.0f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));

  XlaOp input;
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_data, CreateParameterAndTransferLiteral(
                           0, input_literal, "parameter", &builder_, &input));

  const int64_t win_size = 1;
  const int64_t step = 8;
  const std::vector<int64_t> window = {1, 1, win_size, 1};
  const std::vector<int64_t> strides = {1, 1, step, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
440 
// 3-wide window with stride 1 on dimension 2 (SAME padding) over a parameter
// laid out with LayoutUtil::MakeLayout({3, 2, 1, 0}). The result is compared
// with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) {
  Array4D<float> input_array(3, 2, 4, 64);
  input_array.FillRandom(2.0f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));

  XlaOp input;
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_data, CreateParameterAndTransferLiteral(
                           0, input_literal, "parameter", &builder_, &input));

  const int64_t win_size = 3;
  const int64_t step = 1;
  const std::vector<int64_t> window = {1, 1, win_size, 1};
  const std::vector<int64_t> strides = {1, 1, step, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
462 
// 8-wide window with stride 5 on dimension 2 (SAME padding) over a parameter
// laid out with LayoutUtil::MakeLayout({3, 2, 1, 0}). The result is compared
// with ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) {
  Array4D<float> input_array(1, 3, 12, 200);
  input_array.FillRandom(2.0f);
  Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
      input_array, LayoutUtil::MakeLayout({3, 2, 1, 0}));

  XlaOp input;
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_data, CreateParameterAndTransferLiteral(
                           0, input_literal, "parameter", &builder_, &input));

  const int64_t win_size = 8;
  const int64_t step = 5;
  const std::vector<int64_t> window = {1, 1, win_size, 1};
  const std::vector<int64_t> strides = {1, 1, step, 1};
  const Padding padding = Padding::kSame;
  ReduceWindowAdd(input, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {input_data.get()}, DefaultErrorSpec());
}
484 
// Sums 3x3 windows with stride 2 over the two major dimensions of a
// 6x4x10x130 random operand (SAME padding), compared with
// ReferenceUtil::ReduceWindow4DAdd.
XLA_TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) {
  Array4D<float> input_array(6, 4, 10, 130);
  input_array.FillRandom(2.0f);

  const int64_t win_size = 3;
  const int64_t step = 2;
  const std::vector<int64_t> window = {win_size, win_size, 1, 1};
  const std::vector<int64_t> strides = {step, step, 1, 1};
  const Padding padding = Padding::kSame;

  const auto input_data_handle =
      CreateConstantFromArray(input_array, &builder_);
  // Reduce only along the x and y dimensions.
  ReduceWindowAdd(input_data_handle, window, strides, padding);

  const auto expected = ReferenceUtil::ReduceWindow4DAdd(
      input_array, 0.0f, window, strides, padding);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateFromArray(*expected),
                           {}, DefaultErrorSpec());
}
505 
// 1152 ones reduced with a 32-wide window and stride 128 (VALID padding).
// The windows never overlap, giving (1152 - 32) / 128 + 1 = 9 sums of 32.
XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) {
  const std::vector<float> ones(128 * 9, 1);
  const auto input =
      CreateConstantFromLiteral(LiteralUtil::CreateR1<float>(ones), &builder_);
  ReduceWindowAdd(input, /*window_dimensions=*/{32}, /*window_strides=*/{128},
                  Padding::kValid);

  const std::vector<float> expected(9, 32);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>(expected),
                           {}, DefaultErrorSpec());
}
516 
// One 128-wide, stride-128 VALID window over eight repetitions of 1..16.
// The single output is 8 * (1 + ... + 16) = 1088.
XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) {
  std::vector<float> input_vector;
  input_vector.reserve(128);
  for (int repeat = 0; repeat < 8; ++repeat) {
    for (int value = 1; value <= 16; ++value) {
      input_vector.push_back(static_cast<float>(value));
    }
  }
  const auto input = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(input_vector), &builder_);
  ReduceWindowAdd(input, /*window_dimensions=*/{128},
                  /*window_strides=*/{128}, Padding::kValid);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
                           DefaultErrorSpec());
}
533 
// One 128-wide, stride-1 VALID window over eight repetitions of 1..16.
// Exactly one window fits, so the output is 8 * (1 + ... + 16) = 1088.
XLA_TEST_P(ReduceWindowTest, Add128In128) {
  std::vector<float> input_vector;
  input_vector.reserve(128);
  for (int repeat = 0; repeat < 8; ++repeat) {
    for (int value = 1; value <= 16; ++value) {
      input_vector.push_back(static_cast<float>(value));
    }
  }
  const auto input = CreateConstantFromLiteral(
      LiteralUtil::CreateR1<float>(input_vector), &builder_);
  ReduceWindowAdd(input, /*window_dimensions=*/{128},
                  /*window_strides=*/{1}, Padding::kValid);
  ComputeAndCompareLiteral(&builder_, LiteralUtil::CreateR1<float>({1088}), {},
                           DefaultErrorSpec());
}
550 
551 // Regression test for a bug that appeared in Inception (b/34784899).
XLA_TEST_P(ReduceWindowTest,R2ReduceWindowInceptionFromBroadcast)552 XLA_TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) {
553   Array2D<float> input_array(14, 14, 1.0f);
554   const auto input = CreateConstantFromArray(input_array, &builder_);
555   int win_len = 3;
556   int stride = 1;
557   Padding padding = Padding::kSame;
558   ReduceWindowAdd(input, {win_len, win_len}, {stride, stride}, padding);
559   ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
560 }
561 
// Sums 4x2 windows with stride 3 (SAME padding) over a 6x4 operand produced
// by broadcasting the scalar 1. The result is checked with ComputeAndCompare.
XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) {
  // The operand comes from a Broadcast op rather than a constant array.
  const XlaOp input = Broadcast(
      CreateConstantFromLiteral(LiteralUtil::One(F32), &builder_), {6, 4});
  ReduceWindowAdd(input, /*window_dimensions=*/{4, 2},
                  /*window_strides=*/{3, 3}, Padding::kSame);
  ComputeAndCompare(&builder_, {}, DefaultErrorSpec());
}
570 
571 INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest,
572                         ::testing::ValuesIn(use_bfloat16_params));
573 
// Reduction applied within each window by R4ReduceWindowTest.
enum Reducer {
  kAdd = 0,  // Sum of the window.
  kMax = 1,  // Maximum of the window.
};
575 
576 struct R4ReduceWindowTestData {
577   int64_t base_bounds[4];
578   int64_t window_bounds[4];
579   int64_t strides[4];
580   int64_t pad_low[4];
581   int64_t pad_high[4];
582   int64_t layout[4];
583 
584   Reducer reducer;
585 };
586 
R4ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R4ReduceWindowTestData,bool>> & data)587 std::string R4ReduceWindowTestDataToString(
588     const ::testing::TestParamInfo<
589         ::testing::tuple<R4ReduceWindowTestData, bool>>& data) {
590   const auto& param = ::testing::get<0>(data.param);
591   std::string str = absl::StrCat(
592       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),        //
593       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),  //
594       "__strides_", absl::StrJoin(param.strides, "x"),              //
595       "__pad_low_", absl::StrJoin(param.pad_low, "x"),              //
596       "__pad_high_", absl::StrJoin(param.pad_high, "x"),            //
597       "__layout_", absl::StrJoin(param.layout, "_"),                //
598       (param.reducer == kAdd) ? "_add" : "_max");
599   CHECK(param.reducer == kAdd || param.reducer == kMax);
600 
601   // Test names are not allowed to contain the '-' character.
602   std::replace(str.begin(), str.end(), '-', 'n');
603   if (::testing::get<1>(data.param)) {
604     absl::StrAppend(&str, "_bfloat16");
605   }
606   return str;
607 }
608 
609 class R4ReduceWindowTest : public ReduceWindowTestBase,
610                            public ::testing::WithParamInterface<
611                                ::testing::tuple<R4ReduceWindowTestData, bool>> {
612  protected:
R4ReduceWindowTest()613   R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
614 
DoIt()615   void DoIt() {
616     XlaBuilder b(TestName());
617     const auto& param = ::testing::get<0>(GetParam());
618 
619     const float kInitValue = 0.0f;
620 
621     Array4D<float> input(param.base_bounds[0], param.base_bounds[1],
622                          param.base_bounds[2], param.base_bounds[3]);
623     // Choose a prime iota length so that each window sees a unique set of
624     // values. (Technically, the requirement is that the iota length is
625     // relatively prime to all of the dimensions involved in the reduce-window.)
626     input.FillRepeatedIota(0, 137);
627     // Floating point sum reduction requires higher localized precision. We need
628     // the following normalization in order to enable testing of kAdd on large
629     // windows.
630     input.Each([&](absl::Span<const int64_t> /*indices*/, float* value) {
631       *value = *value / 10000000000.f;
632     });
633     Literal input_literal = LiteralUtil::CreateR4FromArray4DWithLayout(
634         input, LayoutUtil::MakeLayout(param.layout));
635     XlaOp parameter;
636     TF_ASSERT_OK_AND_ASSIGN(auto input_arg,
637                             CreateParameterAndTransferLiteral(
638                                 0, input_literal, "p0", &b, &parameter));
639 
640     std::vector<std::pair<int64_t, int64_t>> padding(4);
641     for (int i = 0; i < 4; ++i) {
642       padding[i] = {param.pad_low[i], param.pad_high[i]};
643     }
644 
645     auto init_value =
646         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
647     CHECK(param.reducer == kAdd || param.reducer == kMax);
648     auto reducer = param.reducer;
649     auto computation = reducer == kAdd
650                            ? CreateScalarAddComputation(FloatType(), &b)
651                            : CreateScalarMaxComputation(FloatType(), &b);
652     ReduceWindowWithGeneralPadding(
653         /*operand=*/parameter,
654         /*init_value=*/init_value,
655         /*computation=*/computation,
656         /*window_dimensions=*/param.window_bounds,
657         /*window_strides=*/param.strides,
658         /*base_dilations=*/{},
659         /*window_dilations=*/{},
660         /*padding=*/padding);
661 
662     CHECK(reducer == kAdd || reducer == kMax);
663     auto reduce_func = reducer == kAdd
664                            ? +[](float a, float b) { return a + b; }
665                            : +[](float a, float b) { return std::max(a, b); };
666     std::unique_ptr<Array4D<float>> expected =
667         ReferenceUtil::ReduceWindow4DGeneric(
668             /*operand=*/input,
669             /*init=*/kInitValue,
670             /*reduce_func=*/reduce_func,
671             /*window=*/param.window_bounds,
672             /*stride=*/param.strides,
673             /*padding=*/padding);
674     Literal expected_literal = LiteralUtil::CreateFromArray(*expected);
675     const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
676         input_literal.shape().element_type(),
677         expected_literal.shape().dimensions(), param.layout);
678     ComputeAndCompareLiteral(&b, expected_literal, {input_arg.get()},
679                              DefaultErrorSpec(), &expected_shape_with_layout);
680   }
681 };
682 
// Runs R4ReduceWindowTest::DoIt() for the current test parameter.
XLA_TEST_P(R4ReduceWindowTest, DoIt) {
  DoIt();
}
684 
// Field order: base_bounds, window_bounds, strides, pad_low, pad_high, layout,
// reducer.
686 const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = {
687     // Minimal edge case.
688     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 1, 1},
689                            /*window_bounds=*/{1, 1, 1, 1},
690                            /*strides=*/{1, 1, 1, 1},
691                            /*pad_low=*/{0, 0, 0, 0},
692                            /*pad_high=*/{0, 0, 0, 0},
693                            /*layout=*/{3, 2, 1, 0},
694                            /*reducer=*/kAdd},
695 
696     // Arbitrary padding (not kSame or kValid).
697     R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89},
698                            /*window_bounds=*/{3, 3, 1, 1},
699                            /*strides=*/{2, 2, 1, 1},
700                            /*pad_low=*/{4, 4, 0, 0},
701                            /*pad_high=*/{4, 4, 0, 0},
702                            /*layout=*/{3, 2, 1, 0},
703                            /*reducer=*/kAdd},
704 
705     // Zero base bound edge case.
706     R4ReduceWindowTestData{/*base_bounds=*/{1, 0, 1, 1},
707                            /*window_bounds=*/{1, 1, 1, 1},
708                            /*strides=*/{1, 1, 1, 1},
709                            /*pad_low=*/{0, 0, 0, 0},
710                            /*pad_high=*/{0, 0, 0, 0},
711                            /*layout=*/{3, 2, 1, 0},
712                            /*reducer=*/kAdd},
713 
714     // With max instead of add.
715     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
716                            /*window_bounds=*/{2, 3, 1, 1},
717                            /*strides=*/{1, 1, 1, 1},
718                            /*pad_low=*/{0, 0, 0, 0},
719                            /*pad_high=*/{0, 0, 0, 0},
720                            /*layout=*/{3, 2, 1, 0},
721                            /*reducer=*/kMax},
722 
723     // With stride.
724     R4ReduceWindowTestData{/*base_bounds=*/{4, 10, 17, 140},
725                            /*window_bounds=*/{3, 2, 1, 1},
726                            /*strides=*/{2, 4, 1, 1},
727                            /*pad_low=*/{0, 0, 0, 0},
728                            /*pad_high=*/{0, 0, 0, 0},
729                            /*layout=*/{3, 2, 1, 0},
730                            /*reducer=*/kAdd},
731 
732     // With low padding.
733     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
734                            /*window_bounds=*/{3, 2, 1, 1},
735                            /*strides=*/{2, 2, 1, 1},
736                            /*pad_low=*/{3, 2, 0, 0},
737                            /*pad_high=*/{0, 0, 0, 0},
738                            /*layout=*/{3, 2, 1, 0},
739                            /*reducer=*/kAdd},
740 
741     // With high padding.
742     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
743                            /*window_bounds=*/{3, 2, 1, 1},
744                            /*strides=*/{2, 2, 1, 1},
745                            /*pad_low=*/{0, 0, 0, 0},
746                            /*pad_high=*/{2, 3, 0, 0},
747                            /*layout=*/{3, 2, 1, 0},
748                            /*reducer=*/kAdd},
749 
750     // With negative padding on both ends.
751     R4ReduceWindowTestData{/*base_bounds=*/{10, 10, 17, 140},
752                            /*window_bounds=*/{3, 2, 1, 1},
753                            /*strides=*/{2, 2, 1, 1},
754                            /*pad_low=*/{-3, -2, 0, 0},
755                            /*pad_high=*/{-2, -3, 0, 0},
756                            /*layout=*/{3, 2, 1, 0},
757                            /*reducer=*/kAdd},
758 
759     // With negative low padding and positive high padding.
760     R4ReduceWindowTestData{/*base_bounds=*/{10, 10, 17, 140},
761                            /*window_bounds=*/{3, 2, 1, 1},
762                            /*strides=*/{2, 2, 1, 1},
763                            /*pad_low=*/{-3, -2, 0, 0},
764                            /*pad_high=*/{2, 3, 0, 0},
765                            /*layout=*/{3, 2, 1, 0},
766                            /*reducer=*/kAdd},
767 
768     // With positive low padding and negative high padding.
769     R4ReduceWindowTestData{/*base_bounds=*/{10, 10, 17, 140},
770                            /*window_bounds=*/{3, 2, 1, 1},
771                            /*strides=*/{2, 2, 1, 1},
772                            /*pad_low=*/{3, 2, 0, 0},
773                            /*pad_high=*/{-2, -3, 0, 0},
774                            /*layout=*/{3, 2, 1, 0},
775                            /*reducer=*/kAdd},
776 
777     // Window touches both sides of the padding simultaneously.
778     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
779                            /*window_bounds=*/{3, 3, 1, 1},
780                            /*strides=*/{1, 1, 1, 1},
781                            /*pad_low=*/{1, 1, 0, 0},
782                            /*pad_high=*/{1, 1, 0, 0},
783                            /*layout=*/{3, 2, 1, 0},
784                            /*reducer=*/kAdd},
785 
786     // Window is entirely in the padding for some positions.
787     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 17, 140},
788                            /*window_bounds=*/{3, 3, 1, 1},
789                            /*strides=*/{1, 1, 1, 1},
790                            /*pad_low=*/{4, 4, 0, 0},
791                            /*pad_high=*/{4, 4, 0, 0},
792                            /*layout=*/{3, 2, 1, 0},
793                            /*reducer=*/kAdd},
794 
795     // Zero base bound with padding edge case.
796     R4ReduceWindowTestData{/*base_bounds=*/{2, 0, 3, 4},
797                            /*window_bounds=*/{1, 1, 1, 1},
798                            /*strides=*/{1, 1, 1, 1},
799                            /*pad_low=*/{0, 1, 0, 0},
800                            /*pad_high=*/{0, 0, 0, 0},
801                            /*layout=*/{3, 2, 1, 0},
802                            /*reducer=*/kAdd},
803 
804     // With stride, low padding and high padding.
805     R4ReduceWindowTestData{/*base_bounds=*/{4, 3, 17, 140},
806                            /*window_bounds=*/{3, 4, 1, 1},
807                            /*strides=*/{3, 1, 1, 1},
808                            /*pad_low=*/{10, 1, 0, 0},
809                            /*pad_high=*/{2, 3, 0, 0},
810                            /*layout=*/{3, 2, 1, 0},
811                            /*reducer=*/kAdd},
812 
813     // With minor dimension == 129.
814     R4ReduceWindowTestData{/*base_bounds=*/{3, 2, 7, 129},
815                            /*window_bounds=*/{1, 1, 1, 1},
816                            /*strides=*/{1, 1, 1, 1},
817                            /*pad_low=*/{0, 0, 0, 0},
818                            /*pad_high=*/{0, 0, 0, 0},
819                            /*layout=*/{3, 2, 1, 0},
820                            /*reducer=*/kAdd},
821 
822     // With minor dims reduction and non-overlapped stride.
823     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
824                            /*window_bounds=*/{1, 1, 2, 2},
825                            /*strides=*/{1, 1, 2, 2},
826                            /*pad_low=*/{0, 0, 0, 0},
827                            /*pad_high=*/{0, 0, 0, 0},
828                            /*layout=*/{3, 2, 1, 0},
829                            /*reducer=*/kAdd},
830 
831     // With minor dims reduction and overlapped stride.
832     R4ReduceWindowTestData{/*base_bounds=*/{2, 2, 4, 16},
833                            /*window_bounds=*/{1, 1, 4, 4},
834                            /*strides=*/{1, 1, 2, 2},
835                            /*pad_low=*/{0, 0, 0, 0},
836                            /*pad_high=*/{1, 0, 0, 0},
837                            /*layout=*/{3, 2, 1, 0},
838                            /*reducer=*/kAdd},
839 
840     R4ReduceWindowTestData{/*base_bounds=*/{8, 100, 100, 3},
841                            /*window_bounds=*/{1, 64, 64, 1},
842                            /*strides=*/{1, 64, 64, 1},
843                            /*pad_low=*/{0, 0, 0, 0},
844                            /*pad_high=*/{0, 0, 0, 0},
845                            /*layout=*/{3, 0, 2, 1},
846                            /*reducer=*/kAdd},
847 
848     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64},
849                            /*window_bounds=*/{112, 112, 1, 8},
850                            /*strides=*/{112, 112, 1, 8},
851                            /*pad_low=*/{0, 0, 0, 0},
852                            /*pad_high=*/{0, 0, 0, 0},
853                            /*layout=*/{3, 2, 1, 0},
854                            /*reducer=*/kMax},
855 
856     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
857                            /*window_bounds=*/{2, 3, 4, 5},
858                            /*strides=*/{1, 1, 1, 1},
859                            /*pad_low=*/{0, 0, 0, 0},
860                            /*pad_high=*/{0, 0, 0, 0},
861                            /*layout=*/{3, 2, 1, 0},
862                            /*reducer=*/kAdd},
863 
864     // With 0321 layout.
865     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140},
866                            /*window_bounds=*/{2, 3, 4, 5},
867                            /*strides=*/{1, 2, 3, 4},
868                            /*pad_low=*/{0, 0, 0, 0},
869                            /*pad_high=*/{0, 0, 0, 0},
870                            /*layout=*/{0, 3, 2, 1},
871                            /*reducer=*/kAdd},
872 
873     // With 0123 layout.
874     R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 13, 17},
875                            /*window_bounds=*/{2, 3, 7, 9},
876                            /*strides=*/{1, 2, 5, 8},
877                            /*pad_low=*/{0, 0, 0, 0},
878                            /*pad_high=*/{0, 0, 0, 0},
879                            /*layout=*/{0, 1, 2, 3},
880                            /*reducer=*/kAdd},
881 };
882 
883 INSTANTIATE_TEST_CASE_P(
884     R4ReduceWindowTestInstantiation, R4ReduceWindowTest,
885     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues),
886                        ::testing::ValuesIn(use_bfloat16_params)),
887     R4ReduceWindowTestDataToString);
888 
889 class R4ReduceWindowLargeTest : public R4ReduceWindowTest {};
890 
XLA_TEST_P(R4ReduceWindowLargeTest,DISABLED_ON_INTERPRETER (DoIt))891 XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); }
892 
893 // Test cases that are large/slow/failed.
894 const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = {
895     R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128},
896                            /*window_bounds=*/{3, 3, 1, 5},
897                            /*strides=*/{1, 1, 1, 5},
898                            /*pad_low=*/{1, 1, 0, 0},
899                            /*pad_high=*/{1, 1, 0, 0},
900                            /*layout=*/{3, 2, 1, 0},
901                            /*reducer=*/kMax},
902 
903     R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128},
904                            /*window_bounds=*/{3, 3, 1, 1},
905                            /*strides=*/{2, 2, 1, 1},
906                            /*pad_low=*/{0, 0, 0, 0},
907                            /*pad_high=*/{1, 1, 0, 0},
908                            /*layout=*/{3, 2, 1, 0},
909                            /*reducer=*/kAdd},
910 
911     R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2},
912                            /*window_bounds=*/{1, 1, 4, 1},
913                            /*strides=*/{1, 1, 4, 1},
914                            /*pad_low=*/{0, 0, 1, 0},
915                            /*pad_high=*/{0, 0, 2, 0},
916                            /*layout=*/{3, 2, 1, 0},
917                            /*reducer=*/kMax},
918 
919     // Patterns generated by cumsum/cumprod.
920     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
921                            /*window_bounds=*/{1021, 1, 1, 1},
922                            /*strides=*/{1, 1, 1, 1},
923                            /*pad_low=*/{1020, 0, 0, 0},
924                            /*pad_high=*/{0, 0, 0, 0},
925                            /*layout=*/{3, 2, 1, 0},
926                            /*reducer=*/kAdd},
927 
928     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
929                            /*window_bounds=*/{1, 1, 1021, 1},
930                            /*strides=*/{1, 1, 1, 1},
931                            /*pad_low=*/{0, 0, 1020, 0},
932                            /*pad_high=*/{0, 0, 0, 0},
933                            /*layout=*/{3, 2, 1, 0},
934                            /*reducer=*/kAdd},
935 
936     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
937                            /*window_bounds=*/{1, 1, 1, 1021},
938                            /*strides=*/{1, 1, 1, 1},
939                            /*pad_low=*/{0, 0, 0, 1020},
940                            /*pad_high=*/{0, 0, 0, 0},
941                            /*layout=*/{3, 2, 1, 0},
942                            /*reducer=*/kAdd},
943 
944     R4ReduceWindowTestData{/*base_bounds=*/{1021, 1, 16, 16},
945                            /*window_bounds=*/{1021, 1, 1, 1},
946                            /*strides=*/{1, 1, 1, 1},
947                            /*pad_low=*/{1021, 0, 0, 0},
948                            /*pad_high=*/{0, 0, 0, 0},
949                            /*layout=*/{3, 2, 1, 0},
950                            /*reducer=*/kAdd},
951 
952     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 1021, 16},
953                            /*window_bounds=*/{1, 1, 1021, 1},
954                            /*strides=*/{1, 1, 1, 1},
955                            /*pad_low=*/{0, 0, 1021, 0},
956                            /*pad_high=*/{0, 0, 0, 0},
957                            /*layout=*/{3, 2, 1, 0},
958                            /*reducer=*/kAdd},
959 
960     R4ReduceWindowTestData{/*base_bounds=*/{16, 1, 16, 1021},
961                            /*window_bounds=*/{1, 1, 1, 1021},
962                            /*strides=*/{1, 1, 1, 1},
963                            /*pad_low=*/{0, 0, 0, 1021},
964                            /*pad_high=*/{0, 0, 0, 0},
965                            /*layout=*/{3, 2, 1, 0},
966                            /*reducer=*/kAdd},
967 };
968 
969 INSTANTIATE_TEST_CASE_P(
970     R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest,
971     ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues),
972                        ::testing::ValuesIn(use_bfloat16_params)),
973     R4ReduceWindowTestDataToString);
974 
975 struct R3ReduceWindowTestData {
976   int64_t base_bounds[3];
977   int64_t window_bounds[3];
978   int64_t strides[3];
979   int64_t layout[3];
980   Padding padding;
981   Reducer reducer;
982 } kR3TestCases[] = {
983     {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2},
984      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
985      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
986     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
987      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
988      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
989     {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2},
990      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
991      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
992     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
993      /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0},
994      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
995     {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1},
996      /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0},
997      /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
998     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
999      /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2},
1000      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1001     {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2},
1002      /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2},
1003      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1004     {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251},
1005      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1006      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
1007     {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3},
1008      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1009      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1010     {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64},
1011      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1012      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
1013     {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257},
1014      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1015      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
1016     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
1017      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1018      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1019     {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1},
1020      /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0},
1021      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1022     {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3},
1023      /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0},
1024      /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
1025 };
1026 
R3ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R3ReduceWindowTestData,bool>> & data)1027 std::string R3ReduceWindowTestDataToString(
1028     const ::testing::TestParamInfo<
1029         ::testing::tuple<R3ReduceWindowTestData, bool>>& data) {
1030   const auto& param = ::testing::get<0>(data.param);
1031   std::string str = absl::StrCat(
1032       "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_",
1033       absl::StrJoin(param.window_bounds, "x"), "__strides_",
1034       absl::StrJoin(param.strides, "x"), "__padding_",
1035       param.padding == Padding::kSame ? "same" : "valid", "__layout_",
1036       param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_",
1037       param.reducer == kAdd ? "add" : "max");
1038   if (::testing::get<1>(data.param)) {
1039     absl::StrAppend(&str, "_bfloat16");
1040   }
1041   return str;
1042 }
1043 
1044 class R3ReduceWindowTest : public ReduceWindowTestBase,
1045                            public ::testing::WithParamInterface<
1046                                ::testing::tuple<R3ReduceWindowTestData, bool>> {
1047  protected:
R3ReduceWindowTest()1048   R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1049 };
1050 
// Builds a ReduceWindow with kSame/kValid padding over an iota-filled rank-3
// operand and checks the result with ComputeAndCompare.
XLA_TEST_P(R3ReduceWindowTest, DoIt) {
  XlaBuilder b(TestName());
  const auto& param = ::testing::get<0>(GetParam());
  const float kInitValue = 0.0f;

  // A prime iota period (137) gives each window a distinct set of values.
  // Strictly, the period only needs to be coprime with every dimension that
  // takes part in the reduce-window.
  Array3D<float> operand_values(param.base_bounds[0], param.base_bounds[1],
                                param.base_bounds[2]);
  operand_values.FillRepeatedIota(0, 137);
  Literal operand_literal = LiteralUtil::CreateR3FromArray3DWithLayout(
      operand_values, LayoutUtil::MakeLayout(param.layout));

  // In bf16 mode the operand is converted and the reducer is forced to kMax
  // to stay clear of numerical issues.
  const bool bf16_mode = use_bfloat16();
  if (bf16_mode) {
    operand_literal = LiteralUtil::ConvertF32ToBF16(operand_literal);
  }
  const auto reducer = bf16_mode ? kMax : param.reducer;

  XlaOp parameter = Parameter(&b, 0, operand_literal.shape(), "input");
  auto init_value =
      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
  auto computation = (reducer == kAdd)
                         ? CreateScalarAddComputation(FloatType(), &b)
                         : CreateScalarMaxComputation(FloatType(), &b);
  ReduceWindow(parameter, init_value, computation, param.window_bounds,
               param.strides, param.padding);

  ComputeAndCompare(&b, {std::move(operand_literal)}, DefaultErrorSpec());
}
1089 
1090 INSTANTIATE_TEST_CASE_P(
1091     R3ReduceWindowTestInstantiation, R3ReduceWindowTest,
1092     ::testing::Combine(::testing::ValuesIn(kR3TestCases),
1093                        ::testing::ValuesIn(use_bfloat16_params)),
1094     R3ReduceWindowTestDataToString);
1095 
1096 struct R2ReduceWindowTestData {
1097   int64_t base_bounds[2];
1098   int64_t window_bounds[2];
1099   int64_t strides[2];
1100   int64_t base_dilation[2];
1101   int64_t window_dilation[2];
1102   int64_t pad_low[2];
1103   int64_t pad_high[2];
1104   int64_t layout[2];
1105   Reducer reducer;
1106 } kR2TestCases[] = {
1107     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1108      /*strides=*/{1, 2},
1109      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1110      /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
1111      /*layout=*/{0, 1},
1112      /*reducer=*/Reducer::kAdd},
1113     {/*base_bounds=*/{2, 5}, /*window_bounds=*/{2, 4},
1114      /*strides=*/{1, 1},
1115      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1116      /*pad_low=*/{0, 1}, /*pad_high=*/{1, 2},
1117      /*layout=*/{0, 1},
1118      /*reducer=*/Reducer::kAdd},
1119     {/*base_bounds=*/{1, 3}, /*window_bounds=*/{2, 3},
1120      /*strides=*/{1, 1},
1121      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1122      /*pad_low=*/{0, 1}, /*pad_high=*/{1, 1},
1123      /*layout=*/{0, 1},
1124      /*reducer=*/Reducer::kAdd},
1125     {/*base_bounds=*/{3, 129}, /*window_bounds=*/{1, 100},
1126      /*strides=*/{2, 99},
1127      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1128      /*pad_low=*/{0, 0}, /*pad_high=*/{35, 35},
1129      /*layout=*/{0, 1},
1130      /*reducer=*/Reducer::kAdd},
1131 // TODO(b/74260408): This test last failed on GPU on 2018-03-08, likely due to a
1132 // ptxas bug.
1133 #ifndef XLA_TEST_BACKEND_GPU
1134     {/*base_bounds=*/{6, 152}, /*window_bounds=*/{2, 25},
1135      /*strides=*/{5, 4},
1136      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1137      /*pad_low=*/{0, 1}, /*pad_high=*/{10, 11},
1138      /*layout=*/{0, 1},
1139      /*reducer=*/Reducer::kAdd},
1140 #endif
1141     {/*base_bounds=*/{6, 4}, /*window_bounds=*/{4, 2},
1142      /*strides=*/{3, 3},
1143      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1144      /*pad_low=*/{0, 1}, /*pad_high=*/{0, 1},
1145      /*layout=*/{0, 1},
1146      /*reducer=*/Reducer::kAdd},
1147     {/*base_bounds=*/{5, 147}, /*window_bounds=*/{1, 36},
1148      /*strides=*/{4, 5},
1149      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1150      /*pad_low=*/{0, 0}, /*pad_high=*/{17, 17},
1151      /*layout=*/{1, 0},
1152      /*reducer=*/Reducer::kAdd},
1153     {/*base_bounds=*/{4, 153}, /*window_bounds=*/{2, 93},
1154      /*strides=*/{1, 1},
1155      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1156      /*pad_low=*/{0, 1}, /*pad_high=*/{46, 46},
1157      /*layout=*/{1, 0},
1158      /*reducer=*/Reducer::kAdd},
1159     // Regression test for a bug that appeared in Inception (b/34784899).
1160     {/*base_bounds=*/{28, 28}, /*window_bounds=*/{3, 3},
1161      /*strides=*/{1, 1},
1162      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1163      /*pad_low=*/{1, 1}, /*pad_high=*/{1, 1},
1164      /*layout=*/{1, 0},
1165      /*reducer=*/Reducer::kAdd},
1166     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1167      /*strides=*/{1, 1},
1168      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1169      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1170      /*layout=*/{1, 0},
1171      /*reducer=*/Reducer::kMax},
1172     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1173      /*strides=*/{1, 1},
1174      /*base_dilation=*/{1, 1}, /*window_dilation=*/{2, 2},
1175      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1176      /*layout=*/{1, 0},
1177      /*reducer=*/Reducer::kMax},
1178     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1179      /*strides=*/{1, 1},
1180      /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1},
1181      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1182      /*layout=*/{1, 0},
1183      /*reducer=*/Reducer::kMax},
1184     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1185      /*strides=*/{2, 2},
1186      /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1},
1187      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1188      /*layout=*/{1, 0},
1189      /*reducer=*/Reducer::kMax},
1190     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1191      /*strides=*/{2, 2},
1192      /*base_dilation=*/{2, 2}, /*window_dilation=*/{1, 1},
1193      /*pad_low=*/{3, 3}, /*pad_high=*/{3, 3},
1194      /*layout=*/{1, 0},
1195      /*reducer=*/Reducer::kMax},
1196     {/*base_bounds=*/{4, 4}, /*window_bounds=*/{2, 2},
1197      /*strides=*/{2, 2},
1198      /*base_dilation=*/{2, 2}, /*window_dilation=*/{2, 2},
1199      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1200      /*layout=*/{1, 0},
1201      /*reducer=*/Reducer::kMax},
1202     // Regression test for a bug that appeared in Inception (b/34784899).
1203     {/*base_bounds=*/{4, 32}, /*window_bounds=*/{2, 2},
1204      /*strides=*/{2, 2},
1205      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1206      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1207      /*layout=*/{1, 0},
1208      /*reducer=*/Reducer::kAdd},
1209     // Regression test for b/73903312: bf16 lacks precision to store result of
1210     // very large windows. Testing with a reasonable window larger than 128.
1211     {/*base_bounds=*/{8, 130}, /*window_bounds=*/{1, 130},
1212      /*strides=*/{1, 1},
1213      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1214      /*pad_low=*/{0, 130}, /*pad_high=*/{0, 0},
1215      /*layout=*/{1, 0},
1216      /*reducer=*/Reducer::kAdd},
1217     {/*base_bounds=*/{8, 256}, /*window_bounds=*/{1, 4},
1218      /*strides=*/{1, 64},
1219      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1220      /*pad_low=*/{0, 0}, /*pad_high=*/{0, 0},
1221      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1222     {/*base_bounds=*/{4096, 4096}, /*window_bounds=*/{1, 4},
1223      /*strides=*/{1, 1024},
1224      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1225      /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
1226      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1227     // Regression test for b/72234705: bf16 lacks precision to store incremental
1228     // results on very large windows. Using smaller window with minor dim 128.
1229     {/*base_bounds=*/{8, 128}, /*window_bounds=*/{2, 128},
1230      /*strides=*/{1, 1},
1231      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1232      /*pad_low=*/{0, 0}, /*pad-high=*/{0, 0},
1233      /*layout=*/{1, 0}, /*reducer=*/Reducer::kAdd},
1234     // With negative padding on both ends.
1235     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1236      /*strides=*/{1, 2},
1237      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1238      /*pad_low=*/{0, -1}, /*pad_high=*/{0, -1},
1239      /*layout=*/{0, 1},
1240      /*reducer=*/Reducer::kAdd},
1241     // With negative low padding and positive high padding.
1242     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1243      /*strides=*/{1, 2},
1244      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1245      /*pad_low=*/{0, -1}, /*pad_high=*/{0, 1},
1246      /*layout=*/{0, 1},
1247      /*reducer=*/Reducer::kAdd},
1248     // With positive low padding and negative high padding.
1249     {/*base_bounds=*/{4, 18}, /*window_bounds=*/{2, 4},
1250      /*strides=*/{1, 2},
1251      /*base_dilation=*/{1, 1}, /*window_dilation=*/{1, 1},
1252      /*pad_low=*/{0, 1}, /*pad_high=*/{0, -1},
1253      /*layout=*/{0, 1},
1254      /*reducer=*/Reducer::kAdd},
1255 };
1256 
R2ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R2ReduceWindowTestData,bool>> & data)1257 std::string R2ReduceWindowTestDataToString(
1258     const ::testing::TestParamInfo<
1259         ::testing::tuple<R2ReduceWindowTestData, bool>>& data) {
1260   const auto& param = ::testing::get<0>(data.param);
1261   std::string str = absl::StrCat(
1262       "base_bounds_", absl::StrJoin(param.base_bounds, "x"),            //
1263       "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),      //
1264       "__strides_", absl::StrJoin(param.strides, "x"),                  //
1265       "__base_dilation_", absl::StrJoin(param.base_dilation, "x"),      //
1266       "__window_dilation_", absl::StrJoin(param.window_dilation, "x"),  //
1267       "__pad_low_", absl::StrJoin(param.pad_low, "x"), "__pad_high_",
1268       absl::StrJoin(param.pad_high, "x"), "__layout_", param.layout[0], "_",
1269       param.layout[1],  //
1270       "__reducer_", param.reducer == kAdd ? "add" : "max");
1271 
1272   // Test names are not allowed to contain the '-' character.
1273   std::replace(str.begin(), str.end(), '-', 'n');
1274   if (::testing::get<1>(data.param)) {
1275     absl::StrAppend(&str, "_bfloat16");
1276   }
1277   return str;
1278 }
1279 
1280 class R2ReduceWindowTest : public ReduceWindowTestBase,
1281                            public ::testing::WithParamInterface<
1282                                ::testing::tuple<R2ReduceWindowTestData, bool>> {
1283  protected:
R2ReduceWindowTest()1284   R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1285 
DoIt()1286   void DoIt() {
1287     XlaBuilder b(TestName());
1288     const auto& param = ::testing::get<0>(GetParam());
1289 
1290     Array2D<float> input(param.base_bounds[0], param.base_bounds[1], 1.0f);
1291     if (!::testing::get<1>(GetParam())) {
1292       // We only do this in F32 mode, to avoid precision issues with BF16.
1293       input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0],
1294                                    param.base_bounds[1]);
1295     }
1296     Literal input_literal = LiteralUtil::CreateR2FromArray2DWithLayout(
1297         input, LayoutUtil::MakeLayout(param.layout));
1298 
1299     XlaOp parameter;
1300     TF_ASSERT_OK(CreateParameterAndTransferLiteral(0, input_literal, "p0", &b,
1301                                                    &parameter)
1302                      .status());
1303 
1304     std::vector<std::pair<int64_t, int64_t>> padding(2);
1305     for (int i = 0; i < 2; ++i) {
1306       padding[i] = {param.pad_low[i], param.pad_high[i]};
1307     }
1308     auto computation = param.reducer == kAdd
1309                            ? CreateScalarAddComputation(FloatType(), &b)
1310                            : CreateScalarMaxComputation(FloatType(), &b);
1311     const float kInitValue = 0.0f;
1312     auto init_value =
1313         CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &b);
1314     ReduceWindowWithGeneralPadding(
1315         /*operand=*/parameter,
1316         /*init_value=*/init_value,
1317         /*computation=*/computation,
1318         /*window_dimensions=*/param.window_bounds,
1319         /*window_strides=*/param.strides,
1320         /*base_dilations=*/param.base_dilation,
1321         /*window_dilations=*/param.window_dilation,
1322         /*padding=*/padding);
1323 
1324     ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)},
1325                       DefaultErrorSpec());
1326   }
1327 };
1328 
// Drives every rank-2 parameter tuple through the fixture's DoIt().
XLA_TEST_P(R2ReduceWindowTest, DoIt) {
  DoIt();
}

// Crosses every rank-2 case with each entry of use_bfloat16_params.
INSTANTIATE_TEST_CASE_P(
    R2ReduceWindowTestInstantiation,
    R2ReduceWindowTest,
    ::testing::Combine(
        ::testing::ValuesIn(kR2TestCases),
        ::testing::ValuesIn(use_bfloat16_params)),
    R2ReduceWindowTestDataToString);
1336 
1337 struct R1ReduceWindowTestData {
1338   int64_t base_bounds[1];
1339   int64_t window_bounds[1];
1340   int64_t strides[1];
1341   int64_t pad_low[1];
1342   int64_t pad_high[1];
1343   Reducer reducer;
1344 } kR1TestCases[] = {
1345     {/*base_bounds=*/{1}, /*window_bounds=*/{1},
1346      /*strides=*/{1},
1347      /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
1348      /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
1349      /*reducer=*/Reducer::kAdd},
1350 
1351     {/*base_bounds=*/{3}, /*window_bounds=*/{3},
1352      /*strides=*/{1},
1353      /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
1354      /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
1355      /*reducer=*/Reducer::kAdd},
1356 
1357     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
1358      /*strides=*/{1},
1359      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
1360      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
1361      /*reducer=*/Reducer::kAdd},
1362 
1363     {/*base_bounds=*/{5}, /*window_bounds=*/{1},
1364      /*strides=*/{1},
1365      /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
1366      /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
1367      /*reducer=*/Reducer::kMax},
1368 
1369     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1370      /*strides=*/{4},
1371      /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
1372      /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
1373      /*reducer=*/Reducer::kMax},
1374 
1375     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1376      /*strides=*/{3},
1377      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
1378      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
1379      /*reducer=*/Reducer::kAdd},
1380 
1381     {/*base_bounds=*/{128 * 2},
1382      /*window_bounds=*/{30},
1383      /*strides=*/{27},
1384      /*pad_low=*/
1385      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
1386      /*pad_high=*/
1387      {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
1388      /*reducer=*/Reducer::kAdd},
1389 
1390     {/*base_bounds=*/{128 * 17},
1391      /*window_bounds=*/{7},
1392      /*strides=*/{64},
1393      /*pad_low=*/
1394      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
1395      /*pad_high=*/
1396      {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
1397      /*reducer=*/Reducer::kAdd},
1398 
1399     {/*base_bounds=*/{128 * 2},
1400      /*window_bounds=*/{32},
1401      /*strides=*/{56},
1402      /*pad_low=*/
1403      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
1404      /*pad_high=*/
1405      {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
1406      /*reducer=*/Reducer::kAdd},
1407 
1408     {/*base_bounds=*/{3}, /*window_bounds=*/{2},
1409      /*strides=*/{1},
1410      /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
1411      /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
1412      /*reducer=*/Reducer::kAdd},
1413 
1414     {/*base_bounds=*/{5}, /*window_bounds=*/{3},
1415      /*strides=*/{2},
1416      /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
1417      /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
1418      /*reducer=*/Reducer::kAdd},
1419 
1420     {/*base_bounds=*/{16}, /*window_bounds=*/{4},
1421      /*strides=*/{3},
1422      /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
1423      /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
1424      /*reducer=*/Reducer::kAdd},
1425 
1426     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
1427      /*strides=*/{1},
1428      /*pad_low=*/{0},
1429      /*pad_high=*/{5},
1430      /*reducer=*/Reducer::kAdd},
1431 
1432     {/*base_bounds=*/{5}, /*window_bounds=*/{5},
1433      /*strides=*/{1},
1434      /*pad_low=*/{5},
1435      /*pad_high=*/{0},
1436      /*reducer=*/Reducer::kAdd},
1437 
1438     // With negative padding on both ends.
1439     {/*base_bounds=*/{15}, /*window_bounds=*/{5},
1440      /*strides=*/{1},
1441      /*pad_low=*/{-5},
1442      /*pad_high=*/{-5},
1443      /*reducer=*/Reducer::kAdd},
1444 
1445     // With negative low padding and positive high padding.
1446     {/*base_bounds=*/{15}, /*window_bounds=*/{5},
1447      /*strides=*/{1},
1448      /*pad_low=*/{-5},
1449      /*pad_high=*/{5},
1450      /*reducer=*/Reducer::kAdd},
1451 
1452     // With positive low padding and negative high padding.
1453     {/*base_bounds=*/{15}, /*window_bounds=*/{5},
1454      /*strides=*/{1},
1455      /*pad_low=*/{5},
1456      /*pad_high=*/{-5},
1457      /*reducer=*/Reducer::kAdd},
1458 
1459     // The pattern generated by inclusive scan (cumsum/cumprod).
1460     {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
1461      /*strides=*/{1},
1462      /*pad_low=*/{4095},
1463      /*pad_high=*/{0},
1464      /*reducer=*/Reducer::kMax},
1465 
1466     // The pattern generated by exclusive scan (cumsum/cumprod).
1467     {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
1468      /*strides=*/{1},
1469      /*pad_low=*/{4095},
1470      /*pad_high=*/{0},
1471      /*reducer=*/Reducer::kMax},
1472 
1473     // The pattern generated by inclusive reverse scan (cumsum/cumprod).
1474     {/*base_bounds=*/{4096}, /*window_bounds=*/{4096},
1475      /*strides=*/{1},
1476      /*pad_low=*/{0},
1477      /*pad_high=*/{4095},
1478      /*reducer=*/Reducer::kMax},
1479 
1480     // The pattern generated by exclusive reverse scan (cumsum/cumprod).
1481     {/*base_bounds=*/{4095}, /*window_bounds=*/{4095},
1482      /*strides=*/{1},
1483      /*pad_low=*/{0},
1484      /*pad_high=*/{4095},
1485      /*reducer=*/Reducer::kMax},
1486 };
1487 
R1ReduceWindowTestDataToString(const::testing::TestParamInfo<::testing::tuple<R1ReduceWindowTestData,bool>> & data)1488 std::string R1ReduceWindowTestDataToString(
1489     const ::testing::TestParamInfo<
1490         ::testing::tuple<R1ReduceWindowTestData, bool>>& data) {
1491   const auto& param = ::testing::get<0>(data.param);
1492   std::string str =
1493       absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"),
1494                    "__window_bounds_", absl::StrJoin(param.window_bounds, "x"),
1495                    "__strides_", absl::StrJoin(param.strides, "x"),
1496                    "__pad_low_", absl::StrJoin(param.pad_low, "x"),
1497                    "__pad_high_", absl::StrJoin(param.pad_high, "x"),
1498                    "__reducer_", param.reducer == kAdd ? "add" : "max");
1499 
1500   // Test names are not allowed to contain the '-' character.
1501   std::replace(str.begin(), str.end(), '-', 'n');
1502   if (::testing::get<1>(data.param)) {
1503     absl::StrAppend(&str, "_bfloat16");
1504   }
1505   return str;
1506 }
1507 
1508 class R1ReduceWindowTest : public ReduceWindowTestBase,
1509                            public ::testing::WithParamInterface<
1510                                ::testing::tuple<R1ReduceWindowTestData, bool>> {
1511  protected:
R1ReduceWindowTest()1512   R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); }
1513 };
1514 
// Reduces a 1-D operand holding 0, 1, 2, ... with the case's window, stride
// and padding, and compares against ReferenceUtil::ReduceWindow1DGeneric.
XLA_TEST_P(R1ReduceWindowTest, DoIt) {
  XlaBuilder builder(TestName());
  const auto& test_case = ::testing::get<0>(GetParam());
  CHECK(test_case.reducer == kAdd || test_case.reducer == kMax);
  const bool is_add = test_case.reducer == kAdd;

  const float kInitValue = 0.0f;

  // Operand values are 0, 1, 2, ..., length - 1.
  const int64_t length = test_case.base_bounds[0];
  std::vector<float> operand_values(length);
  for (int64_t i = 0; i < length; ++i) {
    operand_values[i] = static_cast<float>(i);
  }
  Literal input_literal =
      LiteralUtil::CreateR1(absl::Span<const float>(operand_values));
  XlaOp operand;
  TF_ASSERT_OK_AND_ASSIGN(
      auto input_arg, CreateParameterAndTransferLiteral(0, input_literal, "p0",
                                                        &builder, &operand));

  const std::vector<std::pair<int64_t, int64_t>> padding = {
      {test_case.pad_low[0], test_case.pad_high[0]}};

  auto computation = is_add
                         ? CreateScalarAddComputation(FloatType(), &builder)
                         : CreateScalarMaxComputation(FloatType(), &builder);
  auto init_value =
      CreateConstantFromLiteral(LiteralUtil::CreateR0(kInitValue), &builder);
  ReduceWindowWithGeneralPadding(
      /*operand=*/operand,
      /*init_value=*/init_value,
      /*computation=*/computation,
      /*window_dimensions=*/test_case.window_bounds,
      /*window_strides=*/test_case.strides,
      /*base_dilations=*/{},
      /*window_dilations=*/{},
      /*padding=*/padding);

  // Host-side reducer mirroring the XLA computation selected above.
  float (*reduce_func)(float, float) = nullptr;
  if (is_add) {
    reduce_func = [](float lhs, float rhs) { return lhs + rhs; };
  } else {
    reduce_func = [](float lhs, float rhs) { return std::max(lhs, rhs); };
  }
  const auto expected = ReferenceUtil::ReduceWindow1DGeneric(
      /*operand=*/absl::Span<const float>(operand_values),
      /*init=*/kInitValue,
      /*reduce_func=*/reduce_func,
      /*window=*/test_case.window_bounds,
      /*stride=*/test_case.strides,
      /*padding=*/padding);

  ComputeAndCompareLiteral(&builder, LiteralUtil::CreateR1<float>(*expected),
                           {input_arg.get()}, DefaultErrorSpec());
}
1562 
1563 INSTANTIATE_TEST_CASE_P(
1564     R1ReduceWindowTestInstantiation, R1ReduceWindowTest,
1565     ::testing::Combine(::testing::ValuesIn(kR1TestCases),
1566                        ::testing::ValuesIn(use_bfloat16_params)),
1567     R1ReduceWindowTestDataToString);
1568 
1569 // Test class for text-based test cases. Note that this compares with the
1570 // results on the interpreter backend.
1571 class ReduceWindowTextTest : public HloTestBase {};
1572 
// Multiplicative 2x3 window over a row-major 256x384 operand. Padding 0_1 on
// dimension 0 and 1_1 on dimension 1 keeps the output shape at 256x384.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384) {
  const char* const kHloText = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[256,384]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  const ErrorSpec kTolerance(0.001);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1589 
// Same computation as R2General256x384, but operand and result use the
// column-major {0,1} layout.
XLA_TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
  const char* const kHloText = R"(
HloModule R2Window
mul {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
operand = f32[256,384]{0,1} parameter(0)
constant = f32[] constant(1)
ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul
}
)";
  const ErrorSpec kTolerance(0.001);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1606 
// Multiplicative 2x1 window over a 2x5 operand. High padding of 2 on
// dimension 0 grows the result to 3x5.
XLA_TEST_F(ReduceWindowTextTest, R2General2x5) {
  const char* const kHloText = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[2,5]{1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul
}
)";
  const ErrorSpec kTolerance(0.001);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1623 
// Rank-2 reduce-window whose operand is a negated 1x1 array: a 1x1 window
// with no padding.
XLA_TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
  const char* const kHloText = R"(
HloModule R2Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R2Window {
  operand = f32[1,1]{1,0} parameter(0)
  negate = f32[1,1]{1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1]{1,0} reduce-window(negate, constant), window={size=1x1 pad=0_0x0_0}, to_apply=mul
}
)";
  const ErrorSpec kTolerance(0.001);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1641 
// Rank-3 analogue of R2EffectiveScalar: a negated 1x1x1 operand reduced with a
// 1x1x1 window and no padding.
XLA_TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
  const char* const kHloText = R"(
HloModule R3Window
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R3Window {
  operand = f32[1,1,1]{2,1,0} parameter(0)
  negate = f32[1,1,1]{2,1,0} negate(operand)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[1,1,1]{2,1,0}
    reduce-window(negate, constant),
    window={size=1x1x1 pad=0_0x0_0x0_0},
    to_apply=mul
}
)";
  const ErrorSpec kTolerance(0.001);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1662 
// Reducer returns its second parameter and the window is 1x1x1. One element of
// low padding on dimension 1 grows the f32[1,32,64] operand to f32[1,33,64].
// Compared without an error tolerance.
XLA_TEST_F(HloTestBase, ReduceWindowIdentity) {
  const char* const kHloText = R"(
HloModule ReduceWindowIdentity
identity.pad_to_reduce_window {
  param0 = f32[] parameter(0)
  ROOT param1 = f32[] parameter(1)
}
ENTRY reduce-window-identity {
  operand = f32[1,32,64]{2,1,0} parameter(0)
  constant.4466 = f32[] constant(0)
  ROOT reduce-window = f32[1,33,64]{2,1,0}
    reduce-window(operand, constant.4466),
      window={size=1x1x1 pad=0_0x1_0x0_0},
      to_apply=identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1682 
// Same second-parameter reducer and 1x1x1 window as ReduceWindowIdentity, but
// with no padding, so the output shape equals the operand shape.
XLA_TEST_F(HloTestBase, ReduceWindowIdentityNoPadding) {
  const char* const kHloText = R"(
HloModule ReduceWindowIdentity
identity.pad_to_reduce_window {
  param0 = f32[] parameter(0)
  ROOT param1 = f32[] parameter(1)
}
ENTRY reduce-window-identity {
  operand = f32[1,32,64]{2,1,0} parameter(0)
  constant.4466 = f32[] constant(0)
  ROOT reduce-window = f32[1,32,64]{2,1,0}
    reduce-window(operand, constant.4466),
      window={size=1x1x1 pad=0_0x0_0x0_0},
      to_apply=identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1702 
// s32 reduce-window with a second-parameter reducer. The 1x1 window and high
// padding of 1 on dimension 0 grow s32[81,8] to s32[82,8]; the init value is
// the second entry parameter.
XLA_TEST_F(HloTestBase, ReduceWindowS32) {
  const char* const kHloText = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
  %param0 = s32[] parameter(0)
  ROOT %param1 = s32[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
  %parameter.0 = s32[81,8]{1,0} parameter(0)
  %parameter.1 = s32[] parameter(1)
  ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1721 
// s64 variant of ReduceWindowS32: s64[81,8] padded to s64[82,8] by a 1x1
// window with a second-parameter reducer.
XLA_TEST_F(HloTestBase, ReduceWindowS64) {
  const char* const kHloText = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: s64[], param1: s64[]) -> s64[] {
  %param0 = s64[] parameter(0)
  ROOT %param1 = s64[] parameter(1)
}

ENTRY %reduce-window (parameter.0: s64[81,8], parameter.1: s64[]) -> s64[82,8] {
  %parameter.0 = s64[81,8]{1,0} parameter(0)
  %parameter.1 = s64[] parameter(1)
  ROOT %reduce-window = s64[82,8]{1,0} reduce-window(s64[81,8]{1,0} %parameter.0, s64[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1740 
// f16 variant of ReduceWindowS32: f16[81,8] padded to f16[82,8] by a 1x1
// window with a second-parameter reducer.
XLA_TEST_F(HloTestBase, ReduceWindowF16) {
  const char* const kHloText = R"(
HloModule reduce-window

%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
  %param0 = f16[] parameter(0)
  ROOT %param1 = f16[] parameter(1)
}

ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
  %parameter.0 = f16[81,8]{1,0} parameter(0)
  %parameter.1 = f16[] parameter(1)
  ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
}

)";
  EXPECT_TRUE(RunAndCompare(kHloText, /*error=*/std::nullopt));
}
1759 
// Rank-4 reduce-window that only applies base (lhs) dilation of 2 in every
// dimension with a 1x1x1x1 window, turning 2x2x2x2 into 3x3x3x3.
XLA_TEST_F(ReduceWindowTextTest, R4OnlyDilation) {
  const char* const kHloText = R"(
HloModule R4OnlyDilation
mul {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT mul = f32[] multiply(lhs, rhs)
}
ENTRY R4OnlyDilation {
  operand = f32[2,2,2,2]{3,2,1,0} parameter(0)
  constant = f32[] constant(1)
  ROOT reduce-window = f32[3,3,3,3]{3,2,1,0}
    reduce-window(operand, constant),
    window={size=1x1x1x1 pad=0_0x0_0x0_0x0_0 lhs_dilate=2x2x2x2},
    to_apply=mul
}
)";
  const ErrorSpec kTolerance(0.001);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1779 
// Two-operand (variadic) reduce-window that sums each f32[4,2] constant
// independently: 5x1 window, stride 3x1, padding 2_2 on dimension 0. The
// tuple result is copied before being returned.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport)) {
  const std::string kHloText = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = f32[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = f32[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = f32[] add(a1, b1)
  ROOT sum2 = (f32[], f32[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = f32[] constant(0)
  reduce-window = (f32[2,2]{1,0}, f32[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
  ROOT copy = (f32[2,2]{1,0}, f32[2,2]{1,0}) copy(reduce-window)
})";
  const ErrorSpec kTolerance(1e-4, 1e-4);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1806 
// Variadic reduce-window mixing element types: sums an f32 and an s32 operand
// with the same 5x1 / stride 3x1 / pad 2_2 window.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport2)) {
  const std::string kHloText = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = s32[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = s32[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = s32[] add(a1, b1)
  ROOT sum2 = (f32[], s32[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = s32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = s32[] constant(0)
  ROOT reduce-window = (f32[2,2]{1,0}, s32[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
  const ErrorSpec kTolerance(1e-4, 1e-4);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1832 
// Variadic reduce-window summing an f32 operand and a bf16 operand with the
// same 5x1 / stride 3x1 / pad 2_2 window.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport3)) {
  const std::string kHloText = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = bf16[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = bf16[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = bf16[] add(a1, b1)
  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = bf16[] constant(0)
  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
  const ErrorSpec kTolerance(1e-4, 1e-4);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1858 
// Variadic reduce-window with a different reduction per operand: the f32
// operand is summed (init 0) and the bf16 operand is multiplied (init 1).
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowVariadicSupport4)) {
  const std::string kHloText = R"(
HloModule module

sum {
  a0 = f32[] parameter(0)
  a1 = bf16[] parameter(1)
  b0 = f32[] parameter(2)
  b1 = bf16[] parameter(3)
  add0 = f32[] add(a0, b0)
  add1 = bf16[] multiply(a1, b1)
  ROOT sum2 = (f32[], bf16[]) tuple(add0, add1)
}

ENTRY entry {
  constant = f32[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.1 = f32[] constant(0)
  constant.2 = bf16[4,2]{1,0} constant({{1,1},{1,4},{2,1},{3,1}})
  constant.3 = bf16[] constant(1)
  ROOT reduce-window = (f32[2,2]{1,0}, bf16[2,2]{1,0})
    reduce-window(constant, constant.2, constant.1, constant.3),
    window={size=5x1 stride=3x1 pad=2_2x0_0}, to_apply=sum
})";
  const ErrorSpec kTolerance(1e-4, 1e-4);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1884 
// s64 windowed sum over s64[8,10,12]:
//   - 3x2x1 window with window (rhs) dilation 1x2x2;
//   - one element of padding on each side of dimensions 0 and 1;
//   - the result is wrapped in a one-element tuple.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowS64Support)) {
  const std::string kHloText = R"(
HloModule jit_dilated_window_sum.10

%primitive_computation_add.4 (parameter.5: s64[], parameter.6: s64[]) -> s64[] {
  %parameter.5 = s64[] parameter(0), metadata={op_type="add" op_name="add"}
  %parameter.6 = s64[] parameter(1), metadata={op_type="add" op_name="add"}
  ROOT %add.7 = s64[] add(s64[] %parameter.5, s64[] %parameter.6), metadata={op_type="add" op_name="add"}
}

ENTRY %jit_dilated_window_sum.10 (parameter.1: s64[8,10,12]) -> (s64[8,10,12]) {
  %constant.2 = pred[] constant(false)
  %parameter.1 = s64[8,10,12]{2,1,0} parameter(0)
  %constant.3 = s64[] constant(0), metadata={op_type="reduce_window_sum" op_name="jit(dilated_window_sum)/reduce_window_sum[ base_dilation=(1, 1, 1)\n                                           padding=((1, 1), (1, 1), (0, 0))\n                                           window_dilation=(1, 2, 2)\n                                           window_dimensions=(3, 2, 1)\n                                           window_strides=(1, 1, 1) ]" source_file="<ipython-input-1-315291761729>" source_line=9}
  %reduce-window.8 = s64[8,10,12]{2,1,0} reduce-window(s64[8,10,12]{2,1,0} %parameter.1, s64[] %constant.3), window={size=3x2x1 pad=1_1x1_1x0_0 rhs_dilate=1x2x2}, to_apply=%primitive_computation_add.4, metadata={op_type="reduce_window_sum" op_name="jit(dilated_window_sum)/reduce_window_sum[ base_dilation=(1, 1, 1)\n                                           padding=((1, 1), (1, 1), (0, 0))\n                                           window_dilation=(1, 2, 2)\n                                           window_dimensions=(3, 2, 1)\n                                           window_strides=(1, 1, 1) ]" source_file="<ipython-input-1-315291761729>" source_line=9}
  ROOT %tuple.9 = (s64[8,10,12]{2,1,0}) tuple(s64[8,10,12]{2,1,0} %reduce-window.8)
})";
  const ErrorSpec kTolerance(1e-4, 1e-4);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1904 
// Compares two s64 running products over the pred[3,3] parameter:
//   - each is a 3x1 multiplicative window with low padding 2 on dimension 0;
//   - one operand is converted pred->s64 directly, the other pred->s32->s64;
//   - the max absolute difference, converted to f32, is checked to be <= p0.
XLA_TEST_F(HloTestBase, DISABLED_ON_GPU(ReduceWindowS64Support2)) {
  const std::string kHloText = R"(
HloModule SyncTensorsGraph.43

%MulComputation.18 (x.19: s64[], y.20: s64[]) -> s64[] {
  %x.19 = s64[] parameter(0)
  %y.20 = s64[] parameter(1)
  ROOT %multiply.21 = s64[] multiply(s64[] %x.19, s64[] %y.20)
}

%MulComputation.26 (x.27: s64[], y.28: s64[]) -> s64[] {
  %x.27 = s64[] parameter(0)
  %y.28 = s64[] parameter(1)
  ROOT %multiply.29 = s64[] multiply(s64[] %x.27, s64[] %y.28)
}

%MaxComputation.34 (x.35: s64[], y.36: s64[]) -> s64[] {
  %x.35 = s64[] parameter(0)
  %y.36 = s64[] parameter(1)
  ROOT %maximum.37 = s64[] maximum(s64[] %x.35, s64[] %y.36)
}

ENTRY %SyncTensorsGraph.43 (p0.1: f32[], p1.7: pred[3,3]) -> (pred[]) {
  %constant.8 = pred[] constant(false)
  %reshape.9 = pred[1,1]{1,0} reshape(pred[] %constant.8)
  %broadcast.10 = pred[1,1]{1,0} broadcast(pred[1,1]{1,0} %reshape.9), dimensions={0,1}
  %reshape.11 = pred[] reshape(pred[1,1]{1,0} %broadcast.10)
  %broadcast.12 = pred[3,3]{1,0} broadcast(pred[] %reshape.11), dimensions={}
  %p1.7 = pred[3,3]{1,0} parameter(1)
  %reshape.13 = pred[3,3]{1,0} reshape(pred[3,3]{1,0} %p1.7)
  %reshape.14 = pred[3,3]{1,0} reshape(pred[3,3]{1,0} %reshape.13)
  %convert.24 = s64[3,3]{1,0} convert(pred[3,3]{1,0} %reshape.14)
  %constant.25 = s64[] constant(1)
  %reduce-window.30 = s64[3,3]{1,0} reduce-window(s64[3,3]{1,0} %convert.24, s64[] %constant.25), window={size=3x1 pad=2_0x0_0}, to_apply=%MulComputation.26
  %convert.15 = s32[3,3]{1,0} convert(pred[3,3]{1,0} %reshape.14)
  %convert.16 = s64[3,3]{1,0} convert(s32[3,3]{1,0} %convert.15)
  %constant.17 = s64[] constant(1)
  %reduce-window.22 = s64[3,3]{1,0} reduce-window(s64[3,3]{1,0} %convert.16, s64[] %constant.17), window={size=3x1 pad=2_0x0_0}, to_apply=%MulComputation.18
  %constant.2 = s64[] constant(1)
  %reshape.3 = s64[1,1]{1,0} reshape(s64[] %constant.2)
  %broadcast.4 = s64[1,1]{1,0} broadcast(s64[1,1]{1,0} %reshape.3), dimensions={0,1}
  %reshape.5 = s64[] reshape(s64[1,1]{1,0} %broadcast.4)
  %broadcast.6 = s64[3,3]{1,0} broadcast(s64[] %reshape.5), dimensions={}
  %multiply.23 = s64[3,3]{1,0} multiply(s64[3,3]{1,0} %reduce-window.22, s64[3,3]{1,0} %broadcast.6)
  %subtract.31 = s64[3,3]{1,0} subtract(s64[3,3]{1,0} %reduce-window.30, s64[3,3]{1,0} %multiply.23)
  %abs.32 = s64[3,3]{1,0} abs(s64[3,3]{1,0} %subtract.31)
  %constant.33 = s64[] constant(-9223372036854775808)
  %reduce.38 = s64[] reduce(s64[3,3]{1,0} %abs.32, s64[] %constant.33), dimensions={0,1}, to_apply=%MaxComputation.34
  %reshape.39 = s64[] reshape(s64[] %reduce.38)
  %convert.40 = f32[] convert(s64[] %reshape.39)
  %p0.1 = f32[] parameter(0)
  %compare.41 = pred[] compare(f32[] %convert.40, f32[] %p0.1), direction=LE
  ROOT %tuple.42 = (pred[]) tuple(pred[] %compare.41)
})";
  const ErrorSpec kTolerance(1e-4, 1e-4);
  EXPECT_TRUE(RunAndCompare(kHloText, kTolerance));
}
1961 
1962 }  // namespace
1963 }  // namespace xla
1964