xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_evaluator_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16 
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/permutation_util.h"
28 #include "tensorflow/compiler/xla/reference_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test.h"
37 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
39 #include "tensorflow/compiler/xla/tests/test_utils.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/test.h"
46 #include "tensorflow/core/platform/test_benchmark.h"
47 
48 namespace xla {
49 namespace {
50 
// Parameter values for the bf16-parameterized tests: each such test is
// instantiated once with bf16 conversion enabled (true) and once without
// (false). Declared constexpr because nothing needs to mutate it; `static` is
// unnecessary inside the enclosing anonymous namespace.
constexpr std::array<bool, 2> use_bf16_params{true, false};
52 
53 // Test fixture for the HloEvaluator.
54 //
55 // In bf16 mode, all f32 shapes are converted to bf16 before running.
56 class HloEvaluatorTest : public HloTestBase {
57  public:
HloEvaluatorTest()58   HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); }
59 
Evaluate(absl::Span<const Literal * const> arg_literals={})60   StatusOr<Literal> Evaluate(
61       absl::Span<const Literal* const> arg_literals = {}) {
62     if (use_bfloat16_) {
63       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
64     }
65     return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
66   }
67 
68   // Evaluate function that takes in a local module instead of using m_
69   // that is in HloTestBase. Once m_ in HloTestBase is
70   // removed, this should be the default Evaluate function.
EvaluateWithModule(HloModule * module,absl::Span<const Literal * const> arg_literals={})71   Literal EvaluateWithModule(
72       HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
73     if (use_bfloat16_) {
74       HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
75     }
76     return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
77         .value();
78   }
79 
TestUnaryOp(HloOpcode opcode,Literal expected,Literal input,float aabs=0)80   void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
81                    float aabs = 0) {
82     HloComputation::Builder b(TestName());
83     auto c1 =
84         b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
85     b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
86     m_->AddEntryComputation(b.Build());
87 
88     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
89 
90     auto element_type = expected.shape().element_type();
91     if (element_type == F32 || element_type == F64) {
92       ErrorSpec error(aabs);
93       EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
94     } else {
95       EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
96     }
97   }
98 
TestBinaryOp(HloOpcode opcode,Literal expected,Literal lhs,Literal rhs)99   void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
100                     Literal rhs) {
101     HloComputation::Builder b(TestName());
102     auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
103     auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
104     b.AddInstruction(
105         HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
106     m_->AddEntryComputation(b.Build());
107 
108     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
109 
110     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
111   }
112 
TestTernaryOp(HloOpcode opcode,Literal expected,Literal src0,Literal src1,Literal src2)113   void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
114                      Literal src1, Literal src2) {
115     HloComputation::Builder b(TestName());
116     auto operand0 =
117         b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
118     auto operand1 =
119         b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
120     auto operand2 =
121         b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
122     b.AddInstruction(HloInstruction::CreateTernary(
123         expected.shape(), opcode, operand0, operand1, operand2));
124     m_->AddEntryComputation(b.Build());
125 
126     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
127 
128     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
129   }
130 
TestEvaluateInstruction(HloInstruction * instruction,const Literal & expected)131   void TestEvaluateInstruction(HloInstruction* instruction,
132                                const Literal& expected) {
133     TF_ASSERT_OK_AND_ASSIGN(Literal result, evaluator_.Evaluate(instruction));
134     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
135   }
136 
TestEvaluationFailure(HloInstruction * instruction)137   void TestEvaluationFailure(HloInstruction* instruction) {
138     StatusOr<Literal> result = evaluator_.Evaluate(instruction);
139     EXPECT_TRUE(!result.ok());
140   }
141 
TestRecursivelyEvaluateInstruction(HloInstruction * instruction,const Literal & expected)142   void TestRecursivelyEvaluateInstruction(HloInstruction* instruction,
143                                           const Literal& expected) {
144     TF_ASSERT_OK_AND_ASSIGN(
145         Literal result,
146         evaluator_.Evaluate(
147             instruction,
148             /*recursively_evaluate_nonconstant_operands=*/true));
149     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
150   }
151 
TestRecursiveEvaluationFailure(HloInstruction * instruction)152   void TestRecursiveEvaluationFailure(HloInstruction* instruction) {
153     StatusOr<Literal> result = evaluator_.Evaluate(
154         instruction, /*recursively_evaluate_nonconstant_operands=*/true);
155     EXPECT_TRUE(!result.ok());
156   }
157 
MaxComputationScalarF32()158   std::unique_ptr<HloComputation> MaxComputationScalarF32() {
159     HloComputation::Builder max_computation("max");
160     Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
161     auto param_lhs = max_computation.AddInstruction(
162         HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
163     auto param_rhs = max_computation.AddInstruction(
164         HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
165     max_computation.AddInstruction(HloInstruction::CreateBinary(
166         scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
167     return max_computation.Build();
168   }
169 
ReduceWindowMaxIotaTest(int window_size,int padding,int stride,int window_dilation,int base_dilation,const Literal & expected)170   void ReduceWindowMaxIotaTest(int window_size, int padding, int stride,
171                                int window_dilation, int base_dilation,
172                                const Literal& expected) {
173     HloComputation::Builder b(TestName());
174 
175     // arg:
176     // f32[4,4] {
177     //  {  0,  1,  2,  3 },
178     //  {  4,  5,  6,  7 },
179     //  {  8,  9, 10, 11 },
180     //  { 12, 13, 14, 15 }
181     // }
182     auto arg_array = std::make_unique<Array2D<float>>(4, 4);
183     arg_array->FillIota(0);
184     auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
185 
186     HloInstruction* arg_instruction = b.AddInstruction(
187         HloInstruction::CreateConstant(std::move(arg_literal)));
188     auto init_value = b.AddInstruction(
189         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
190     auto max_func = m_->AddEmbeddedComputation(MaxComputationScalarF32());
191 
192     Window window;
193     WindowDimension dim;
194     dim.set_size(window_size);
195     dim.set_stride(stride);
196     dim.set_padding_low(padding);
197     dim.set_padding_high(padding);
198     dim.set_window_dilation(window_dilation);
199     dim.set_base_dilation(base_dilation);
200     *window.add_dimensions() = dim;
201     *window.add_dimensions() = dim;
202 
203     int dim0 = expected.shape().dimensions(0);
204     int dim1 = expected.shape().dimensions(1);
205     Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
206     b.AddInstruction(HloInstruction::CreateReduceWindow(
207         shape, arg_instruction, init_value, window, max_func));
208 
209     m_->AddEntryComputation(b.Build());
210     TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
211     EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
212   }
213 
214  protected:
HloEvaluatorTest(bool use_bfloat16)215   explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {
216     InitializeFftData();
217   }
218 
219   // Initializes data sets used in FFT tests below.
220   void InitializeFftData();
221 
222   HloEvaluator evaluator_;
223 
224   const bool use_bfloat16_;
225   std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
226 
227   // Data sets used in FFT tests below.
228   ErrorSpec fft_error_ = ErrorSpec(1e-4, 1e-5);
229   Literal fft_c64x2x4x8_;
230   Literal fft_c64x2x4x8_1d_;
231   Literal fft_c64x2x4x8_2d_;
232   Literal fft_c64x2x4x8_3d_;
233 };
234 
235 // Lets you write TEST_Ps that run twice, once with and once without bf16.
236 class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
237                              public HloEvaluatorTest {
238  protected:
HloEvaluatorBf16Test()239   HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
240 };
241 
242 INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
243                          ::testing::ValuesIn(use_bf16_params));
244 
245 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
246 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesClamp)247 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
248   auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
249   auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
250   auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
251 
252   Shape shape = low.shape();
253   HloComputation::Builder b(TestName());
254   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
255   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
256   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
257   b.AddInstruction(
258       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
259   m_->AddEntryComputation(b.Build());
260 
261   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
262 
263   auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
264 
265   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
266 }
267 
268 // Verifies that clamping of int64_t does not cause loss of precision
TEST_P(HloEvaluatorBf16Test,DoesClampInt64)269 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
270   auto ones = [](int bits) { return (int64_t{1} << bits) - 1; };
271 
272   auto low =
273       LiteralUtil::CreateR2<int64_t>({{0, ones(54)}, {ones(54), ones(58)}});
274   auto value = LiteralUtil::CreateR2<int64_t>({{0, ones(56)}, {0, ones(58)}});
275   auto high = LiteralUtil::CreateR2<int64_t>(
276       {{ones(54), ones(55)}, {ones(56), ones(58)}});
277 
278   Shape shape = low.shape();
279   HloComputation::Builder b(TestName());
280   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
281   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
282   auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
283   b.AddInstruction(
284       HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
285   m_->AddEntryComputation(b.Build());
286 
287   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
288 
289   auto expected =
290       LiteralUtil::CreateR2<int64_t>({{0, ones(55)}, {ones(54), ones(58)}});
291 
292   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
293 }
294 
// Clamp with scalar low/high bounds and a 2x2 value operand. The DISABLED_
// prefix keeps gtest from running this test by default.
TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
  auto scalar_low = LiteralUtil::CreateR0<float>(0.f);
  auto operand = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
  auto scalar_high = LiteralUtil::CreateR0<float>(1.f);

  // The result takes the shape of the (non-scalar) value operand.
  Shape result_shape = operand.shape();
  HloComputation::Builder builder(TestName());
  auto low_inst = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(scalar_low)));
  auto value_inst = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand)));
  auto high_inst = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(scalar_high)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      result_shape, HloOpcode::kClamp, low_inst, value_inst, high_inst));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  auto want = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
315 
316 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
317 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesSelect)318 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
319   auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
320   auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
321   auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
322 
323   Shape shape = on_true.shape();
324   HloComputation::Builder b(TestName());
325   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
326   auto c2 =
327       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
328   auto c3 =
329       b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
330   b.AddInstruction(
331       HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
332   m_->AddEntryComputation(b.Build());
333 
334   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
335 
336   auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
337 
338   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
339 }
340 
341 // Verifies that HloEvaluator evaluates a HLO instruction that performs
342 // element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest,DoesAdd)343 TEST_F(HloEvaluatorTest, DoesAdd) {
344   auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
345   auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
346   auto expected = LiteralUtil::CreateR2<int64_t>({{3, 4}, {-96, 8}});
347   TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
348                std::move(rhs));
349 }
350 // Verifies that HloEvaluator evaluates a HLO instruction that performs
351 // element-wise and with 2 operands.
TEST_P(HloEvaluatorBf16Test,DoesAnd)352 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
353   auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
354   auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
355   auto expected = LiteralUtil::CreateR2<int64_t>({{0, 0}, {4, 4}});
356   TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
357                std::move(rhs));
358 }
359 // Verifies that HloEvaluator evaluates a HLO instruction that performs
360 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesOr)361 TEST_F(HloEvaluatorTest, DoesOr) {
362   auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
363   auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
364   auto expected = LiteralUtil::CreateR2<int64_t>({{3, 4}, {-100, 4}});
365   TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
366                std::move(rhs));
367 }
368 // Verifies that HloEvaluator evaluates a HLO instruction that performs
369 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesXor)370 TEST_F(HloEvaluatorTest, DoesXor) {
371   auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
372   auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
373   auto expected = LiteralUtil::CreateR2<int64_t>({{3, 4}, {-104, 0}});
374   TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
375                std::move(rhs));
376 }
377 // Verifies that HloEvaluator evaluates a HLO instruction that performs
378 // element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest,DoesMultiply)379 TEST_F(HloEvaluatorTest, DoesMultiply) {
380   auto lhs = LiteralUtil::CreateR2<int32_t>({{-1, 0}, {-100, 4}});
381   auto rhs = LiteralUtil::CreateR2<int32_t>(
382       {{std::numeric_limits<int32_t>::min(), 4}, {4, 4}});
383   auto expected = LiteralUtil::CreateR2<int32_t>(
384       {{std::numeric_limits<int32_t>::min(), 0}, {-400, 16}});
385   TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
386                std::move(rhs));
387 }
388 // Verifies that HloEvaluator evaluates a HLO instruction that performs
389 // element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest,DoesDivideInt64)390 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
391   auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
392   auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
393   auto expected = LiteralUtil::CreateR2<int64_t>({{0, 0}, {-25, 1}});
394   TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
395                std::move(rhs));
396 }
397 
// Checks kClamp on rank-1 int64_t operands with large-magnitude values.
TEST_F(HloEvaluatorTest, DoesClampS64) {
  TestTernaryOp(
      HloOpcode::kClamp,
      /*expected=*/
      LiteralUtil::CreateR1<int64_t>({-6780561065411491190LL,
                                      6780561065411491190LL,
                                      3832151243857508051LL}),
      /*src0 (low)=*/
      LiteralUtil::CreateR1<int64_t>({-8616761059752331528LL,
                                      6780561065411491190LL,
                                      -8616761059752331528LL}),
      /*src1 (value)=*/
      LiteralUtil::CreateR1<int64_t>({-6780561065411491190LL,
                                      6780561065411491180LL,
                                      4241131823772864090LL}),
      /*src2 (high)=*/
      LiteralUtil::CreateR1<int64_t>({-6780561065411491180LL,
                                      8616761059752331528LL,
                                      3832151243857508051LL}));
}
410 
// Checks element-wise kDivide of two 2x2 double operands; the result is
// compared for exact equality.
TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
  TestBinaryOp(
      HloOpcode::kDivide,
      /*expected=*/
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}}),
      /*lhs=*/LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}}),
      /*rhs=*/LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}}));
}
419 
420 // Verifies that HloEvaluator evaluates a HLO instruction that performs
421 // element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest,DoesAbsR2)422 TEST_F(HloEvaluatorTest, DoesAbsR2) {
423   auto operand = LiteralUtil::CreateR2<int64_t>({{1, -20}, {-100, 4}});
424   auto expected = LiteralUtil::CreateR2<int64_t>({{1, 20}, {100, 4}});
425   TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
426 }
// Checks kAbs on a scalar float operand.
TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR0<float>(1.0f),
              /*input=*/LiteralUtil::CreateR0<float>(-1.0f));
}
// Checks kAbs on an empty rank-1 float operand.
TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
  TestUnaryOp(HloOpcode::kAbs,
              /*expected=*/LiteralUtil::CreateR1<float>({}),
              /*input=*/LiteralUtil::CreateR1<float>({}));
}
437 
// Checks kAbs on a scalar complex128 operand (1 + 2i); the double result is
// compared against 2.23607 with an absolute tolerance of 3e-06.
TEST_F(HloEvaluatorTest, DoesAbsC128) {
  auto input = LiteralUtil::CreateR0<complex128>({1, 2});
  auto want_magnitude = LiteralUtil::CreateR0<double>(2.23607);
  TestUnaryOp(HloOpcode::kAbs, std::move(want_magnitude), std::move(input),
              /*aabs=*/3e-06);
}
443 
// Checks element-wise kNegate on a 2x2 int32_t operand. Negating the minimum
// int32_t is expected to yield the minimum int32_t again.
TEST_F(HloEvaluatorTest, DoesNegateR2) {
  // Use int32_t limits throughout so the bounds match the literal's element
  // type rather than depending on the width of `int`.
  constexpr int32_t kMinS32 = std::numeric_limits<int32_t>::min();
  auto operand = LiteralUtil::CreateR2<int32_t>({{0, kMinS32}, {-1, 4}});
  auto expected = LiteralUtil::CreateR2<int32_t>({{0, kMinS32}, {1, -4}});
  TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
// Checks element-wise kCos at 0, pi, -pi and 2*pi. A looser absolute
// tolerance is used when running in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  auto input = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  auto want = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
  TestUnaryOp(HloOpcode::kCos, std::move(want), std::move(input), tolerance);
}
// Checks element-wise kSin at 0, pi, -pi and 2*pi. A looser absolute
// tolerance is used when running in bf16 mode.
TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  auto input = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  auto want = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
  TestUnaryOp(HloOpcode::kSin, std::move(want), std::move(input), tolerance);
}
// Checks element-wise kNot on a 2x2 int32_t operand, including the minimum
// and maximum int32_t values.
TEST_F(HloEvaluatorTest, DoesNotR2) {
  // Use int32_t limits so the bounds match the literal's element type rather
  // than depending on the width of `int`.
  constexpr int32_t kMinS32 = std::numeric_limits<int32_t>::min();
  constexpr int32_t kMaxS32 = std::numeric_limits<int32_t>::max();
  auto operand = LiteralUtil::CreateR2<int32_t>({{0, kMinS32}, {-1, kMaxS32}});
  auto expected = LiteralUtil::CreateR2<int32_t>({{-1, kMaxS32}, {0, kMinS32}});
  TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
472 
// Checks that kReal extracts the real parts of a rank-1 complex128 operand.
TEST_F(HloEvaluatorTest, DoesRealC128) {
  auto input = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  auto want_real = LiteralUtil::CreateR1<double>({1, -100});
  TestUnaryOp(HloOpcode::kReal, std::move(want_real), std::move(input));
}
478 
// Checks that kImag extracts the imaginary parts of a rank-1 complex128
// operand.
TEST_F(HloEvaluatorTest, DoesImagC128) {
  auto input = LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  auto want_imag = LiteralUtil::CreateR1<double>({0, 4});
  TestUnaryOp(HloOpcode::kImag, std::move(want_imag), std::move(input));
}
484 
// Checks that kImag of a real-valued float operand is expected to be all
// zeros (run in both f32 and bf16 modes).
TEST_P(HloEvaluatorBf16Test, DoesImagF32AndBf16) {
  auto input = LiteralUtil::CreateR1<float>({1, -100});
  auto want_imag = LiteralUtil::CreateR1<float>({0, 0});
  TestUnaryOp(HloOpcode::kImag, std::move(want_imag), std::move(input));
}
490 
// Checks that kImag of a real-valued double operand is expected to be all
// zeros.
TEST_F(HloEvaluatorTest, DoesImagF64) {
  auto input = LiteralUtil::CreateR1<double>({1, -100});
  auto want_imag = LiteralUtil::CreateR1<double>({0, 0});
  TestUnaryOp(HloOpcode::kImag, std::move(want_imag), std::move(input));
}
496 
497 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
498 // constant operands.
TEST_F(HloEvaluatorTest,DoesTraverseInstructions)499 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
500   auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
501   auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
502   auto rhs2 = LiteralUtil::CreateR2<int64_t>({{1, -20}, {-100, 4}});
503   std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
504 
505   Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
506 
507   HloComputation::Builder b(TestName());
508   auto param_lhs =
509       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
510   auto param_rhs =
511       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
512   auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
513       shape, HloOpcode::kAdd, param_lhs, param_rhs));
514 
515   auto param_rhs2 =
516       b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
517   b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
518                                                 lhs_instruction, param_rhs2));
519   m_->AddEntryComputation(b.Build());
520 
521   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
522 
523   auto expected = LiteralUtil::CreateR2<int64_t>({{4, -16}, {-196, 12}});
524 
525   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
526 }
527 
528 // Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesReshape)529 TEST_F(HloEvaluatorTest, DoesReshape) {
530   HloComputation::Builder b(TestName());
531   const int64_t dimensions[] = {11, 8, 7, 5, 9};
532   TF_ASSERT_OK_AND_ASSIGN(auto literal,
533                           LiteralUtil::CreateRandomLiteral<F32>(
534                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
535   auto literal_clone = literal.Clone();
536   HloInstruction* literal_instruction =
537       b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
538 
539   Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
540   const int64_t permutation[] = {1, 2, 0, 4, 3};
541   b.AddInstruction(
542       HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
543   m_->AddEntryComputation(b.Build());
544 
545   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
546 
547   using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
548   result.EachCell<NativeT>(
549       [&](absl::Span<const int64_t> indices, NativeT value) {
550         std::vector<int64_t> rindexes = PermuteInverse(indices, permutation);
551         EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
552       });
553 }
554 
555 // Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesBroadcast)556 TEST_F(HloEvaluatorTest, DoesBroadcast) {
557   HloComputation::Builder b(TestName());
558   auto input_literal = LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}, {5, 6}});
559   auto output_literal = LiteralUtil::CreateR3<int32_t>(
560       {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
561   HloInstruction* literal_instruction = b.AddInstruction(
562       HloInstruction::CreateConstant(std::move(input_literal)));
563   b.AddInstruction(HloInstruction::CreateBroadcast(
564       output_literal.shape(), literal_instruction, {1, 2}));
565   m_->AddEntryComputation(b.Build());
566 
567   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
568 
569   EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
570 }
571 
// Checks that broadcasting a scalar int32_t into a 6x2 result is correctly
// evaluated.
TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
  HloComputation::Builder builder(TestName());
  auto scalar = LiteralUtil::CreateR0<int32_t>(111);
  auto want = LiteralUtil::CreateR2<int32_t>(
      {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});

  HloInstruction* scalar_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(scalar)));
  // Broadcast dimension should be empty in the case of scalars.
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(want.shape(), scalar_inst,
                                      /*broadcast_dimensions=*/{}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({}));

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
590 
// Checks concatenation of two 2x2 int64_t operands along dimension 0.
TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
  HloComputation::Builder builder(TestName());

  HloInstruction* first = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int64_t>({{-1, -2}, {100, 200}})));
  HloInstruction* second =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int64_t>({{-2, -3}, {-100, -200}})));

  std::vector<HloInstruction*> pieces = {first, second};

  Shape concat_shape = ShapeUtil::MakeShape(S64, {4, 2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  auto want = LiteralUtil::CreateR2<int64_t>(
      {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
612 
// Checks that concatenating a two-element operand with an empty operand
// yields the two-element operand unchanged.
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
  HloComputation::Builder builder(TestName());

  HloInstruction* non_empty =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int64_t>({100, 200})));
  HloInstruction* empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64_t>({})));

  std::vector<HloInstruction*> pieces = {non_empty, empty};

  Shape concat_shape = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  auto want = LiteralUtil::CreateR1<int64_t>({100, 200});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
633 
// Checks an int32_t -> float kConvert when the operand and result literals
// have equal layouts.
TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
  HloComputation::Builder builder(TestName());

  auto source = LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}, {5, 6}});
  auto want =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  // Precondition of this test: both literals share the same layout.
  ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(source.shape(), want.shape()));

  HloInstruction* source_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(want.shape(), source_inst));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
652 
// Checks an int32_t -> float kConvert when the operand literal uses layout
// {0, 1} and the result literal uses layout {1, 0}.
TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
  HloComputation::Builder builder(TestName());

  auto source = LiteralUtil::CreateR2WithLayout<int32_t>(
      {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
  auto want = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
  // Precondition of this test: the two literals have different layouts.
  ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(source.shape(), want.shape()));

  HloInstruction* source_inst =
      builder.AddInstruction(HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(want.shape(), source_inst));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(actual, want));
}
672 
CreatePaddingConfig(std::initializer_list<std::array<int64_t,3>> padding_dimensions)673 PaddingConfig CreatePaddingConfig(
674     std::initializer_list<std::array<int64_t, 3>> padding_dimensions) {
675   PaddingConfig padding_config;
676 
677   for (auto& paddings_per_dim : padding_dimensions) {
678     auto dimension = padding_config.add_dimensions();
679     dimension->set_edge_padding_low(paddings_per_dim[0]);
680     dimension->set_edge_padding_high(paddings_per_dim[1]);
681     dimension->set_interior_padding(paddings_per_dim[2]);
682   }
683   return padding_config;
684 }
685 
// Pads an s32 operand built from two empty rows. Because the operand holds no
// elements, every element of the s32[5,2] result must be the pad value.
TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
  constexpr int32_t kPadValue = 10;

  HloComputation::Builder builder(TestName());

  Literal empty_rows = LiteralUtil::CreateR2<int32_t>({{}, {}});
  HloInstruction* operand_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(empty_rows)));

  Literal pad_scalar = LiteralUtil::CreateR0<int32_t>(kPadValue);
  HloInstruction* pad_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(pad_scalar)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig pad_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
  Shape padded_shape = ShapeUtil::MakeShape(S32, {5, 2});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, operand_constant, pad_constant, pad_config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR2<int32_t>(
      {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
710 
// Pads an f32[3,2,1,1] operand on its first two dimensions using both edge
// and interior padding, and checks where the six input values land in the
// f32[8,5,1,1] result.
TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
  constexpr float kPadValue = 1.5;

  HloComputation::Builder builder(TestName());

  Array4D<float> operand_values(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
  Literal operand_literal =
      LiteralUtil::CreateR4FromArray4D<float>(operand_values);
  HloInstruction* operand_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));

  Literal pad_scalar = LiteralUtil::CreateR0<float>(kPadValue);
  HloInstruction* pad_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(pad_scalar)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  // Only dimensions 0 and 1 are padded.
  PaddingConfig pad_config =
      CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, operand_constant, pad_constant, pad_config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Start from an all-padding array, then place the operand values at their
  // padded positions.
  Array4D<float> expected_values(8, 5, 1, 1);
  expected_values.Fill(kPadValue);
  expected_values(1, 0, 0, 0) = 1.0f;
  expected_values(1, 2, 0, 0) = 2.0f;
  expected_values(4, 0, 0, 0) = 3.0f;
  expected_values(4, 2, 0, 0) = 4.0f;
  expected_values(7, 0, 0, 0) = 5.0f;
  expected_values(7, 2, 0, 0) = 6.0f;

  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
745 
// Pads an f32[4,3] operand with negative edge padding, which trims elements
// from the operand, and compares the f32[1,5] result within a tolerance.
TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
  HloComputation::Builder builder(TestName());

  // Operand contents:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> operand_values(4, 3);
  operand_values.FillUnique(1.0f);
  Literal operand_literal =
      LiteralUtil::CreateR2FromArray2D<float>(operand_values);
  HloInstruction* operand_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));

  HloInstruction* pad_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  PaddingConfig pad_config = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {1, 5});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, operand_constant, pad_constant, pad_config));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Expected: f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
  Array2D<float> expected_values(1, 5);
  expected_values(0, 0) = 7.0f;
  expected_values(0, 1) = 2.718f;
  expected_values(0, 2) = 2.718f;
  expected_values(0, 3) = 2.718f;
  expected_values(0, 4) = 2.718f;
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
787 
// Pads an f32[4,3] operand with a mix of interior padding and negative edge
// padding. The requested result shape is f32[0,9], so the evaluated literal
// must be an empty array with that shape.
TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
  HloComputation::Builder b(TestName());

  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto input_array = std::make_unique<Array2D<float>>(4, 3);
  input_array->FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Negative padding that results in zero dimensions.
  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  auto r2_padding_on_dim0_dim1 =
      CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});

  Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_value_instruction,
                                             r2_padding_on_dim0_dim1));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // An empty 0x9 array: there are no element values to compare, only shape.
  auto expected_array = std::make_unique<Array2D<float>>(0, 9);
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
826 
// Dot of an f32[4,1] lhs with an rhs literal holding the values {1, 2},
// contracting lhs dimension 1 against rhs dimension 0.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[4,1] {
  //  { 1 },
  //  { 2 },
  //  { 3 },
  //  { 4 },
  // }
  Array2D<float> lhs_values(4, 1);
  lhs_values.FillUnique(1.0f);
  Literal lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(lhs_values);
  HloInstruction* lhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_literal)));

  // rhs (built as a rank-2 literal with a single row):
  // f32[1,2] {
  //  { 1, 2 },
  // }
  Literal rhs_literal = LiteralUtil::CreateR2<float>({{1, 2}});
  HloInstruction* rhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_literal)));

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(1);
  dimension_numbers.add_rhs_contracting_dimensions(0);

  Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_constant, rhs_constant, dimension_numbers,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  Array2D<float> expected_values({
      {1.f, 2.f},
      {2.f, 4.f},
      {3.f, 6.f},
      {4.f, 8.f},
  });
  // clang-format on
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
872 
// Dot of an f32[3] vector with an f32[3,2] matrix, contracting dimension 0 of
// both operands; the result is an f32[2] vector.
TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[3]
  //  { 1, 2, 3 },
  Literal lhs_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
  HloInstruction* lhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_literal)));

  // rhs:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  Literal rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(rhs_values);
  HloInstruction* rhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_literal)));

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(0);
  dimension_numbers.add_rhs_contracting_dimensions(0);

  Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_constant, rhs_constant, dimension_numbers,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // 1*1 + 2*3 + 3*5 = 22 and 1*2 + 2*4 + 3*6 = 28.
  Literal expected = LiteralUtil::CreateR1<float>({22.f, 28.f});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
910 
// Matrix product of an f32[4,3] lhs and an f32[3,2] rhs, contracting lhs
// dimension 1 against rhs dimension 0; the result is f32[4,2].
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> lhs_values(4, 3);
  lhs_values.FillUnique(1.0f);
  Literal lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(lhs_values);
  HloInstruction* lhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_literal)));

  // rhs:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  Literal rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(rhs_values);
  HloInstruction* rhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_literal)));

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(1);
  dimension_numbers.add_rhs_contracting_dimensions(0);

  Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_constant, rhs_constant, dimension_numbers,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array2D<float> expected_values({
      {22.f, 28.f},
      {58.f, 76.f},
      {94.f, 124.f},
      {130.f, 172.f},
  });
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
960 
// Dot of two f32[2,2,3,1] operands with dimension 0 as the batch dimension and
// dimensions 1 and 2 contracted on both sides; the result is f32[2,1,1].
TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
  HloComputation::Builder builder(TestName());

  Array4D<float> lhs_values(2, 2, 3, 1);
  lhs_values.FillIota(1.0f);
  Literal lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_values);
  HloInstruction* lhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lhs_literal)));

  Array4D<float> rhs_values(2, 2, 3, 1);
  rhs_values.FillIota(2.0f);
  Literal rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_values);
  HloInstruction* rhs_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(rhs_literal)));

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_batch_dimensions(0);
  dimension_numbers.add_rhs_batch_dimensions(0);
  dimension_numbers.add_lhs_contracting_dimensions(1);
  dimension_numbers.add_lhs_contracting_dimensions(2);
  dimension_numbers.add_rhs_contracting_dimensions(1);
  dimension_numbers.add_rhs_contracting_dimensions(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_constant, rhs_constant, dimension_numbers,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Accumulates v * v + v for v = first, first + 1, ... while v < limit.
  auto sum_of_products = [](float first, float limit) {
    float total = 0;
    for (float v = first; v < limit; ++v) {
      total += v * v + v;
    }
    return total;
  };
  // One expected value per batch entry.
  float expected_batch0 = sum_of_products(1.0f, 7.0f);
  float expected_batch1 = sum_of_products(7.0f, 13.0f);

  Array3D<float> expected_values({{{expected_batch0}}, {{expected_batch1}}});
  Literal expected = LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1005 
// 1D convolution of the input {1, 2, 3} with the kernel {3, 4}, using stride
// 1 and one element of high padding so the output keeps three elements.
TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
  HloComputation::Builder builder(TestName());

  Array3D<float> input_values = {{{1, 2, 3}}};
  Literal input_literal = LiteralUtil::CreateR3FromArray3D<float>(input_values);
  HloInstruction* input_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));

  Array3D<float> kernel_values = {{{3.f, 4.f}}};
  Literal kernel_literal =
      LiteralUtil::CreateR3FromArray3D<float>(kernel_values);
  HloInstruction* kernel_constant = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(kernel_literal)));

  // A single spatial window dimension.
  WindowDimension spatial_dim;
  spatial_dim.set_size(2);
  spatial_dim.set_stride(1);
  spatial_dim.set_padding_low(0);
  spatial_dim.set_padding_high(1);
  spatial_dim.set_window_dilation(1);
  spatial_dim.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial_dim;

  // Input and output: batch = 0, feature = 1, spatial = 2.
  ConvolutionDimensionNumbers conv_dims;
  conv_dims.set_input_batch_dimension(0);
  conv_dims.set_input_feature_dimension(1);
  conv_dims.add_input_spatial_dimensions(2);
  conv_dims.set_output_batch_dimension(0);
  conv_dims.set_output_feature_dimension(1);
  conv_dims.add_output_spatial_dimensions(2);
  // Kernel: output feature = 0, input feature = 1, spatial = 2.
  conv_dims.set_kernel_output_feature_dimension(0);
  conv_dims.set_kernel_input_feature_dimension(1);
  conv_dims.add_kernel_spatial_dimensions(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input_constant, kernel_constant,
      /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, conv_dims, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array3D<float> expected_values = {{{11.f, 18.f, 9.f}}};
  Literal expected = LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1054 
TEST_P(HloEvaluatorBf16Test,Simple4x4Conv2DWith2x2Kernel)1055 TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
1056   HloComputation::Builder b(TestName());
1057 
1058   Array4D<float> lhs_array(1, 1, 4, 4);
1059   // clang-format off
1060   lhs_array.FillWithYX(Array2D<float>({
1061     {1,  2,  3,  4 },
1062     {5,  6,  7,  8 },
1063     {9,  10, 11, 12},
1064     {13, 14, 15, 16},
1065   }));
1066   // clang-format on
1067   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1068   HloInstruction* lhs_instruction =
1069       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1070 
1071   Array4D<float> rhs_array(1, 1, 2, 2);
1072   // clang-format off
1073   rhs_array.FillWithYX(Array2D<float>({
1074     {5, 6},
1075     {7, 8},
1076   }));
1077   // clang-format on
1078   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1079   HloInstruction* rhs_instruction =
1080       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1081 
1082   Window window;
1083   WindowDimension dim;
1084   dim.set_size(2);
1085   dim.set_stride(1);
1086   dim.set_padding_low(0);
1087   dim.set_padding_high(1);
1088   dim.set_window_dilation(1);
1089   dim.set_base_dilation(1);
1090   *window.add_dimensions() = dim;
1091   *window.add_dimensions() = dim;
1092 
1093   ConvolutionDimensionNumbers dnums =
1094       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1095 
1096   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
1097   b.AddInstruction(HloInstruction::CreateConvolve(
1098       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1099       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1100   m_->AddEntryComputation(b.Build());
1101 
1102   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1103 
1104   Array4D<float> expected_array(1, 1, 4, 4);
1105   // clang-format off
1106   expected_array.FillWithYX(Array2D<float>({
1107     {100, 126, 152,  76},
1108     {204, 230, 256, 124},
1109     {308, 334, 360, 172},
1110     {149, 160, 171,  80},
1111   }));
1112   // clang-format on
1113   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1114 
1115   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1116 }
1117 
TEST_P(HloEvaluatorBf16Test,Conv2DGeneralDimensionsReversed)1118 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
1119   HloComputation::Builder b(TestName());
1120 
1121   // clang-format off
1122   // Input dimensions: [feature=2, height=3, batch=1, width=4]
1123   Array4D<float> input({
1124     {{{1, 2, 3, 4}},
1125      {{5, 6, 7, 8}},
1126      {{9, 10, 11, 12}}},
1127     {{{13, 14, 15, 16}},
1128      {{17, 18, 19, 20}},
1129      {{21, 22, 23, 24}}}
1130   });
1131   // Weight dimensions:
1132   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
1133   Array4D<float> weight({{
1134     {{1, 7, 13},
1135      {4, 10, 16}},
1136     {{2, 8, 14},
1137      {5, 11, 17}},
1138     {{3, 9, 15},
1139      {6, 12, 18}}
1140   }});
1141   // clang-format on
1142 
1143   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
1144   HloInstruction* lhs_instruction =
1145       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1146 
1147   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
1148   HloInstruction* rhs_instruction =
1149       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1150   rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse(
1151       rhs_instruction->shape(), rhs_instruction, {3, 1}));
1152 
1153   Window window;
1154   WindowDimension dim;
1155   dim.set_size(3);
1156   dim.set_stride(1);
1157   dim.set_padding_low(0);
1158   dim.set_padding_high(0);
1159   dim.set_window_dilation(1);
1160   dim.set_base_dilation(1);
1161   dim.set_window_reversal(true);
1162   *window.add_dimensions() = dim;
1163   *window.add_dimensions() = dim;
1164 
1165   ConvolutionDimensionNumbers dnums;
1166   dnums.set_input_batch_dimension(2);
1167   dnums.set_output_batch_dimension(2);
1168   dnums.set_input_feature_dimension(0);
1169   dnums.set_output_feature_dimension(0);
1170   dnums.add_input_spatial_dimensions(1);
1171   dnums.add_output_spatial_dimensions(1);
1172   dnums.add_input_spatial_dimensions(3);
1173   dnums.add_output_spatial_dimensions(3);
1174 
1175   dnums.set_kernel_output_feature_dimension(0);
1176   dnums.set_kernel_input_feature_dimension(2);
1177   dnums.add_kernel_spatial_dimensions(3);
1178   dnums.add_kernel_spatial_dimensions(1);
1179 
1180   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1181   b.AddInstruction(HloInstruction::CreateConvolve(
1182       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1183       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1184   m_->AddEntryComputation(b.Build());
1185 
1186   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1187 
1188   // clang-format off
1189   // Result dimensions: [feature=1, height=1, batch=1, width=2]
1190   Array4D<float> expected_array({{{{2514, 2685}}}});
1191   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
1192   // clang-format on
1193   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
1194       use_bfloat16_ ? expected_array_bf16 : expected_array);
1195 
1196   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1197 }
1198 
TEST_P(HloEvaluatorBf16Test,Conv2DGeneralDimensions)1199 TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
1200   HloComputation::Builder b(TestName());
1201 
1202   // clang-format off
1203   // Input dimensions: [feature=2, height=3, batch=1, width=4]
1204   Array4D<float> input({
1205     {{{1, 2, 3, 4}},
1206      {{5, 6, 7, 8}},
1207      {{9, 10, 11, 12}}},
1208     {{{13, 14, 15, 16}},
1209      {{17, 18, 19, 20}},
1210      {{21, 22, 23, 24}}}
1211   });
1212   // Weight dimensions:
1213   // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
1214   Array4D<float> weight({{
1215     {{1, 7, 13},
1216      {4, 10, 16}},
1217     {{2, 8, 14},
1218      {5, 11, 17}},
1219     {{3, 9, 15},
1220      {6, 12, 18}}
1221   }});
1222   // clang-format on
1223 
1224   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(input);
1225   HloInstruction* lhs_instruction =
1226       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1227 
1228   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(weight);
1229   HloInstruction* rhs_instruction =
1230       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1231 
1232   Window window;
1233   WindowDimension dim;
1234   dim.set_size(3);
1235   dim.set_stride(1);
1236   dim.set_padding_low(0);
1237   dim.set_padding_high(0);
1238   dim.set_window_dilation(1);
1239   dim.set_base_dilation(1);
1240   *window.add_dimensions() = dim;
1241   *window.add_dimensions() = dim;
1242 
1243   ConvolutionDimensionNumbers dnums;
1244   dnums.set_input_batch_dimension(2);
1245   dnums.set_output_batch_dimension(2);
1246   dnums.set_input_feature_dimension(0);
1247   dnums.set_output_feature_dimension(0);
1248   dnums.add_input_spatial_dimensions(1);
1249   dnums.add_output_spatial_dimensions(1);
1250   dnums.add_input_spatial_dimensions(3);
1251   dnums.add_output_spatial_dimensions(3);
1252 
1253   dnums.set_kernel_output_feature_dimension(0);
1254   dnums.set_kernel_input_feature_dimension(2);
1255   dnums.add_kernel_spatial_dimensions(3);
1256   dnums.add_kernel_spatial_dimensions(1);
1257 
1258   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
1259   b.AddInstruction(HloInstruction::CreateConvolve(
1260       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1261       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1262   m_->AddEntryComputation(b.Build());
1263 
1264   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1265 
1266   // clang-format off
1267   // Result dimensions: [feature=1, height=1, batch=1, width=2]
1268   Array4D<float> expected_array({{{{2514, 2685}}}});
1269   Array4D<float> expected_array_bf16({{{{2512, 2688}}}});
1270   // clang-format on
1271   auto expected = LiteralUtil::CreateR4FromArray4D<float>(
1272       use_bfloat16_ ? expected_array_bf16 : expected_array);
1273 
1274   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1275 }
1276 
TEST_P(HloEvaluatorBf16Test,DilatedBaseConv2DWithHighPadding)1277 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
1278   HloComputation::Builder b(TestName());
1279 
1280   Array4D<float> lhs_array(1, 1, 4, 4);
1281   // clang-format off
1282   lhs_array.FillWithYX(Array2D<float>({
1283     {1,  2,  3,  4 },
1284     {5,  6,  7,  8 },
1285     {9,  10, 11, 12},
1286     {13, 14, 15, 16},
1287   }));
1288   // clang-format on
1289   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1290   HloInstruction* lhs_instruction =
1291       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1292 
1293   Array4D<float> rhs_array(1, 1, 2, 2);
1294   // clang-format off
1295   rhs_array.FillWithYX(Array2D<float>({
1296     {5, 6},
1297     {7, 8},
1298   }));
1299   // clang-format on
1300   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1301   HloInstruction* rhs_instruction =
1302       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1303 
1304   Window window;
1305   WindowDimension dim;
1306   dim.set_size(2);
1307   dim.set_stride(1);
1308   dim.set_padding_low(0);
1309   dim.set_padding_high(1);
1310   dim.set_window_dilation(1);
1311   dim.set_base_dilation(2);
1312   *window.add_dimensions() = dim;
1313   *window.add_dimensions() = dim;
1314 
1315   ConvolutionDimensionNumbers dnums =
1316       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1317 
1318   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
1319   b.AddInstruction(HloInstruction::CreateConvolve(
1320       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1321       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1322   m_->AddEntryComputation(b.Build());
1323 
1324   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1325 
1326   Array4D<float> expected_array(1, 1, 7, 7);
1327   expected_array.FillWithYX(Array2D<float>({
1328       {5, 12, 10, 18, 15, 24, 20},
1329       {35, 48, 42, 56, 49, 64, 56},
1330       {25, 36, 30, 42, 35, 48, 40},
1331       {63, 80, 70, 88, 77, 96, 84},
1332       {45, 60, 50, 66, 55, 72, 60},
1333       {91, 112, 98, 120, 105, 128, 112},
1334       {65, 84, 70, 90, 75, 96, 80},
1335   }));
1336   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1337 
1338   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1339 }
1340 
TEST_P(HloEvaluatorBf16Test,DilatedBaseConv2DWithLowAndHighPadding)1341 TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
1342   HloComputation::Builder b(TestName());
1343 
1344   Array4D<float> lhs_array(1, 1, 4, 4);
1345   // clang-format off
1346   lhs_array.FillWithYX(Array2D<float>({
1347     {1,  2,  3,  4 },
1348     {5,  6,  7,  8 },
1349     {9,  10, 11, 12},
1350     {13, 14, 15, 16},
1351   }));
1352   // clang-format on
1353   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1354   HloInstruction* lhs_instruction =
1355       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1356 
1357   Array4D<float> rhs_array(1, 1, 2, 2);
1358   // clang-format off
1359   rhs_array.FillWithYX(Array2D<float>({
1360     {5, 6},
1361     {7, 8},
1362   }));
1363   // clang-format on
1364   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1365   HloInstruction* rhs_instruction =
1366       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1367 
1368   Window window;
1369   WindowDimension dim;
1370   dim.set_size(2);
1371   dim.set_stride(1);
1372   dim.set_padding_low(1);
1373   dim.set_padding_high(1);
1374   dim.set_window_dilation(1);
1375   dim.set_base_dilation(2);
1376   *window.add_dimensions() = dim;
1377   *window.add_dimensions() = dim;
1378 
1379   ConvolutionDimensionNumbers dnums =
1380       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1381 
1382   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
1383   b.AddInstruction(HloInstruction::CreateConvolve(
1384       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1385       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1386   m_->AddEntryComputation(b.Build());
1387 
1388   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1389 
1390   Array4D<float> expected_array(1, 1, 8, 8);
1391   expected_array.FillWithYX(Array2D<float>({
1392       {8, 7, 16, 14, 24, 21, 32, 28},
1393       {6, 5, 12, 10, 18, 15, 24, 20},
1394       {40, 35, 48, 42, 56, 49, 64, 56},
1395       {30, 25, 36, 30, 42, 35, 48, 40},
1396       {72, 63, 80, 70, 88, 77, 96, 84},
1397       {54, 45, 60, 50, 66, 55, 72, 60},
1398       {104, 91, 112, 98, 120, 105, 128, 112},
1399       {78, 65, 84, 70, 90, 75, 96, 80},
1400   }));
1401   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1402 
1403   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1404 }
1405 
TEST_P(HloEvaluatorBf16Test,DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides)1406 TEST_P(HloEvaluatorBf16Test,
1407        DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
1408   HloComputation::Builder b(TestName());
1409 
1410   Array4D<float> lhs_array(1, 1, 4, 4);
1411   // clang-format off
1412   lhs_array.FillWithYX(Array2D<float>({
1413     {1,  2,  3,  4 },
1414     {5,  6,  7,  8 },
1415     {9,  10, 11, 12},
1416     {13, 14, 15, 16},
1417   }));
1418   // clang-format on
1419   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs_array);
1420   HloInstruction* lhs_instruction =
1421       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
1422 
1423   Array4D<float> rhs_array(1, 1, 2, 3);
1424   // clang-format off
1425   rhs_array.FillWithYX(Array2D<float>({
1426     {5, 6, 7},
1427     {8, 9, 10},
1428   }));
1429   // clang-format on
1430   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs_array);
1431   HloInstruction* rhs_instruction =
1432       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
1433 
1434   Window window;
1435   WindowDimension dim;
1436   dim.set_size(2);
1437   dim.set_stride(1);
1438   dim.set_padding_low(2);
1439   dim.set_padding_high(2);
1440   dim.set_window_dilation(2);
1441   dim.set_base_dilation(2);
1442   *window.add_dimensions() = dim;
1443   dim.set_size(3);
1444   dim.set_stride(3);
1445   dim.set_padding_low(2);
1446   dim.set_padding_high(-1);
1447   dim.set_window_dilation(1);
1448   dim.set_base_dilation(3);
1449   *window.add_dimensions() = dim;
1450 
1451   ConvolutionDimensionNumbers dnums =
1452       XlaBuilder::CreateDefaultConvDimensionNumbers(2);
1453 
1454   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
1455   b.AddInstruction(HloInstruction::CreateConvolve(
1456       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
1457       /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
1458   m_->AddEntryComputation(b.Build());
1459 
1460   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1461 
1462   Array4D<float> expected_array(1, 1, 9, 3);
1463   expected_array.FillWithYX(Array2D<float>({
1464       {10, 20, 30},
1465       {0, 0, 0},
1466       {57, 74, 91},
1467       {0, 0, 0},
1468       {125, 142, 159},
1469       {0, 0, 0},
1470       {193, 210, 227},
1471       {0, 0, 0},
1472       {91, 98, 105},
1473   }));
1474   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1475 
1476   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1477 }
1478 
TEST_P(HloEvaluatorBf16Test,Conv2DGroupedConvolution)1479 TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
1480   HloComputation::Builder b(TestName());
1481   std::vector<int64_t> input_dims = {1, 2, 2, 4};
1482   std::vector<int64_t> filter_dims = {2, 2, 2, 8};
1483   Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
1484   Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);
1485   // Tensorflow dimension numbers for 2D convolution.
1486   ConvolutionDimensionNumbers dnums;
1487   dnums.set_input_batch_dimension(0);
1488   dnums.set_output_batch_dimension(0);
1489   dnums.add_input_spatial_dimensions(1);
1490   dnums.add_output_spatial_dimensions(1);
1491   dnums.add_input_spatial_dimensions(2);
1492   dnums.add_output_spatial_dimensions(2);
1493   dnums.set_input_feature_dimension(3);
1494   dnums.set_output_feature_dimension(3);
1495   dnums.add_kernel_spatial_dimensions(0);
1496   dnums.add_kernel_spatial_dimensions(1);
1497   dnums.set_kernel_input_feature_dimension(2);
1498   dnums.set_kernel_output_feature_dimension(3);
1499 
1500   Window window;
1501   WindowDimension dim;
1502   dim.set_size(2);
1503   dim.set_stride(1);
1504   dim.set_padding_low(0);
1505   dim.set_padding_high(0);
1506   dim.set_window_dilation(1);
1507   dim.set_base_dilation(1);
1508   *window.add_dimensions() = dim;
1509   *window.add_dimensions() = dim;
1510 
1511   std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
1512   std::iota(input_elems.begin(), input_elems.end(), -7);
1513   auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
1514   auto input_r4 = input_r1.Reshape(input_dims).value();
1515   HloInstruction* lhs_instruction =
1516       b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4)));
1517 
1518   std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
1519   std::iota(filter_elems.begin(), filter_elems.end(), -31);
1520   auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
1521   auto filter_r4 = filter_r1.Reshape(filter_dims).value();
1522   HloInstruction* rhs_instruction =
1523       b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4)));
1524 
1525   Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
1526   b.AddInstruction(HloInstruction::CreateConvolve(
1527       shape, lhs_instruction, rhs_instruction,
1528       /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums,
1529       DefaultPrecisionConfig(2)));
1530   m_->AddEntryComputation(b.Build());
1531 
1532   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
1533 
1534   Array4D<float> expected_array(1, 1, 1, 8);
1535   expected_array.FillWithYX(
1536       Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
1537   auto expected = LiteralUtil::CreateR4FromArray4D<float>(expected_array);
1538   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
1539 }
1540 
1541 // Initialization of data sets for FFT tests:
1542 
InitializeFftData()1543 void HloEvaluatorTest::InitializeFftData() {
1544   // clang-format off
1545   fft_c64x2x4x8_ = LiteralUtil::CreateR3<complex64>({
1546     {{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0},
1547       {4.0, 0.0}, {5.0, 0.0}, {6.0, 0.0}, {7.0, 0.0}},
1548      {{0.0, 0.0}, {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0},
1549       {0.0, 4.0}, {0.0, 5.0}, {0.0, 6.0}, {0.0, 7.0}},
1550      {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
1551       {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
1552      {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
1553       {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}},
1554     {{{-4.0, 0.0}, {-3.0, 0.0}, {-2.0, 0.0}, {-1.0, 0.0},
1555       {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}},
1556      {{0.0, -4.0}, {0.0, -3.0}, {0.0, -2.0}, {0.0, -1.0},
1557       {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, {0.0, 4.0}},
1558      {{3.5, 3.5}, {-1.707107, -0.707107}, {-1.0, -0.0}, {-0.707107, 0.292893},
1559       {-0.5, 0.5}, {-0.292893, 0.707107}, {0.0, 1.0}, {0.707107, 1.707107}},
1560      {{3.5, 3.5}, {1.707107, 0.707107}, {1.0, 0.0}, {0.707107, -0.292893},
1561       {0.5, -0.5}, {0.292893, -0.707107}, {-0.0, -1.0}, {-0.707107, -1.707107}}}
1562   });
1563   fft_c64x2x4x8_1d_ = LiteralUtil::CreateR3<complex64>({
1564     {{{28.0, 0.0}, {-4.0, 9.656854}, {-4.0, 4.0}, {-4.0, 1.656854},
1565       {-4.0, 0.0}, {-4.0, -1.656854}, {-4.0, -4.0}, {-4.0, -9.656854}},
1566      {{0.0, 28.0}, {-9.656854, -4.0}, {-4.0, -4.0}, {-1.656854, -4.0},
1567       {0.0, -4.0}, {1.656854, -4.0}, {4.0, -4.0}, {9.656854, -4.0}},
1568      {{28.0, 28.0}, {5.656854, 13.656854}, {0.0, 8.0}, {-2.343146, 5.656854},
1569       {-4.0, 4.0}, {-5.656854, 2.343146}, {-8.0, -0.0}, {-13.656854, -5.656854}},  // NOLINT
1570      {{28.0, 28.0}, {-5.656854, -13.656854}, {-0.0, -8.0}, {2.343146, -5.656854},  // NOLINT
1571       {4.0, -4.0}, {5.656854, -2.343146}, {8.0, 0.0}, {13.656854, 5.656854}}},
1572     {{{0.0, 0.0}, {-5.0, 12.071068}, {-4.0, 4.0}, {-5.0, 2.071068},
1573       {-4.0, 0.0}, {-5.0, -2.071068}, {-4.0, -4.0}, {-5.0, -12.071068}},
1574      {{0.0, 0.0}, {-12.071068, -5.0}, {-4.0, -4.0}, {-2.071068, -5.0},
1575       {0.0, -4.0}, {2.071068, -5.0}, {4.0, -4.0}, {12.071068, -5.0}},
1576      {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
1577       {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
1578      {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
1579       {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}}
1580   });
1581   fft_c64x2x4x8_2d_ = LiteralUtil::CreateR3<complex64>({
1582     {{{84.0, 84.0}, {-13.656854, 5.656854}, {-8.0, 0.0}, {-5.656854, -2.343146},
1583       {-4.0, -4.0}, {-2.343146, -5.656854}, {0.0, -8.0}, {5.656854, -13.656854}},  // NOLINT
1584      {{0.0, 0.0}, {0.0, -0.0}, {0.0, 0.0}, {0.0, 0.0},
1585       {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
1586      {{28.0, -28.0}, {16.970562, 40.970562}, {0.0, 24.0}, {-7.029438, 16.970562},      // NOLINT
1587       {-12.0, 12.0}, {-16.970562, 7.029438}, {-24.0, 0.0}, {-40.970562, -16.970562}},  // NOLINT
1588      {{0.0, -56.0}, {-19.313708, -8.0}, {-8.0, -8.0}, {-3.313708, -8.0},
1589       {0.0, -8.0}, {3.313708, -8.0}, {8.0, -8.0}, {19.313708, -8.0}}},
1590     {{{7.0, 7.0}, {-10.071068, 14.071068}, {-1.0, 7.0}, {-0.071068, 4.071068},
1591       {3.0, 3.0}, {4.071068, -0.071068}, {7.0, -1.0}, {14.071068, -10.071068}},
1592      {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
1593       {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
1594      {{-7.0, 7.0}, {2.071068, 22.071068}, {-3.0, 11.0}, {-3.928932, 8.071068},
1595       {-3.0, 3.0}, {-4.071068, -0.071068}, {-3.0, -5.0}, {-10.071068, -14.071068}},  // NOLINT
1596      {{0.0, -14.0}, {0.0, -12.0}, {0.0, -10.0}, {0.0, -8.0},
1597       {0.0, -6.0}, {0.0, -4.0}, {0.0, -2.0}, {0.0, 0.0}}}
1598   });
1599   fft_c64x2x4x8_3d_ = LiteralUtil::CreateR3<complex64>({
1600     {{{91.0, 91.0}, {-23.727922, 19.727922}, {-9.0, 7.0}, {-5.727922, 1.727922},
1601       {-1.0, -1.0}, {1.727922, -5.727922}, {7.0, -9}, {19.727922, -23.727922}},
1602      {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
1603       {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
1604      {{21.0, -21.0}, {19.041630, 63.041630}, {-3.0, 35.0}, {-10.958370, 25.041630},     // NOLINT
1605       {-15.0, 15.0}, {-21.041630, 6.958370}, {-27.0, -5.0}, {-51.041630, -31.041630}},  // NOLINT
1606      {{0.0, -70.0}, {-19.313708, -20.0}, {-8.0, -18.0}, {-3.313708, -16.0},
1607       {0.0, -14.0}, {3.313708, -12.0}, {8.0, -10.0}, {19.313708, -8.0}}},
1608     {{{77.0, 77.0}, {-3.585786, -8.414214}, {-7.0, -7.0}, {-5.585786, -6.414214},   // NOLINT
1609       {-7.0, -7.0}, {-6.414214, -5.585786}, {-7.0, -7.0}, {-8.414214, -3.585786}},  // NOLINT
1610      {{0.0, 0.0}, {12.0, -24.142136}, {12.0, -8.0}, {16.0, -4.142136},
1611       {16.0, 0.0}, {20.0, 4.142136}, {20.0, 8.0}, {24.0, 24.142136}},
1612      {{35.0, -35.0}, {14.899494, 18.899494}, {3.0, 13.0}, {-3.100506, 8.899494},
1613       {-9.0, 9.0}, {-12.899494, 7.100506}, {-21.0, 5.0}, {-30.899494, -2.899494}},  // NOLINT
1614      {{0.0, -42.0}, {-19.313708, 4.0}, {-8.0, 2.0}, {-3.313708, 0.0},
1615       {0.0, -2.0}, {3.313708, -4.0}, {8.0, -6.0}, {19.313708, -8.0}}}
1616   });
1617   // clang-format on
1618 }
1619 
1620 // Simple FFT tests:
1621 
1622 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_c64x4) {
1623   const char* hlo_text = R"(
1624 HloModule Fft
1625 
1626 ENTRY main {
1627   operand = c64[4] parameter(0)
1628   ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
1629 }
1630 )";
1631   auto input = LiteralUtil::CreateR1<complex64>(
1632       {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1633   auto expected = LiteralUtil::CreateR1<complex64>(
1634       {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1635   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1636   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1637   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1638   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1639 }
1640 
1641 TEST_F(HloEvaluatorTest, 1D_IFFT_4_on_c64x4) {
1642   const char* hlo_text = R"(
1643 HloModule Fft
1644 
1645 ENTRY main {
1646   operand = c64[4] parameter(0)
1647   ROOT ifft = c64[4] fft(operand), fft_type=IFFT, fft_length={4}
1648 }
1649 )";
1650   auto input = LiteralUtil::CreateR1<complex64>(
1651       {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1652   auto expected = LiteralUtil::CreateR1<complex64>(
1653       {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1654   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1655   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1656   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1657   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1658 }
1659 
1660 TEST_F(HloEvaluatorTest, 1D_RFFT_4_on_f32x4) {
1661   const char* hlo_text = R"(
1662 HloModule Fft
1663 
1664 ENTRY main {
1665   operand = f32[4] parameter(0)
1666   ROOT rfft = c64[3] fft(operand), fft_type=RFFT, fft_length={4}
1667 }
1668 )";
1669   auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1670   auto expected =
1671       LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1672   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1673   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1674   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1675   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1676 }
1677 
1678 TEST_F(HloEvaluatorTest, 1D_IRFFT_4_on_c64x3) {
1679   const char* hlo_text = R"(
1680 HloModule Fft
1681 
1682 ENTRY main {
1683   operand = c64[3] parameter(0)
1684   ROOT irfft = f32[4] fft(operand), fft_type=IRFFT, fft_length={4}
1685 }
1686 )";
1687   auto input =
1688       LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1689   auto expected = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1690   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1691   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1692   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1693   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1694 }
1695 
1696 // 1D FFT tests:
1697 
1698 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8) {
1699   const char* hlo_text = R"(
1700 HloModule Fft
1701 
1702 ENTRY main {
1703   operand = c64[2, 4, 8] parameter(0)
1704   ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={8}
1705 }
1706 )";
1707   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1708   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1709   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
1710   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
1711 }
1712 
1713 TEST_F(HloEvaluatorTest, 1D_IFFT_8_on_c64x2x4x8) {
1714   const char* hlo_text = R"(
1715 HloModule Fft
1716 
1717 ENTRY main {
1718   operand = c64[2, 4, 8] parameter(0)
1719   ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={8}
1720 }
1721 )";
1722   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1723   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_1d_}));
1724   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1725   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1726 }
1727 
1728 TEST_F(HloEvaluatorTest, 1D_RFFT_8_on_f32x8) {
1729   const char* hlo_text = R"(
1730 HloModule Fft
1731 
1732 ENTRY main {
1733   operand = f32[8] parameter(0)
1734   ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={8}
1735 }
1736 )";
1737   auto input =
1738       LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1739   auto expected = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1740                                                     {-3.6, 8.691169},
1741                                                     {-3.6, 3.6},
1742                                                     {-3.6, 1.491169},
1743                                                     {-3.6, 0.0}});
1744   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1745   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1746   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1747   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1748 }
1749 
1750 TEST_F(HloEvaluatorTest, 1D_IRFFT_8_on_c64x5) {
1751   const char* hlo_text = R"(
1752 HloModule Fft
1753 
1754 ENTRY main {
1755   operand = c64[5] parameter(0)
1756   ROOT irfft = f32[8] fft(operand), fft_type=IRFFT, fft_length={8}
1757 }
1758 )";
1759   auto input = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1760                                                  {-3.6, 8.691169},
1761                                                  {-3.6, 3.6},
1762                                                  {-3.6, 1.491169},
1763                                                  {-3.6, 0.0}});
1764   auto expected =
1765       LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1766   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1767   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1768   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1769   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1770 }
1771 
1772 TEST_F(HloEvaluatorTest, 1D_RFFT_9_on_f32x9) {
1773   const char* hlo_text = R"(
1774 HloModule Fft
1775 
1776 ENTRY main {
1777   operand = f32[9] parameter(0)
1778   ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={9}
1779 }
1780 )";
1781   auto input = LiteralUtil::CreateR1<float>(
1782       {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1783   auto expected = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1784                                                     {-3.360560, 11.705792},
1785                                                     {-3.893717, 5.712929},
1786                                                     {-4.5, 3.117691},
1787                                                     {-4.895723, 1.021942}});
1788   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1789   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1790   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1791   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1792 }
1793 
1794 TEST_F(HloEvaluatorTest, 1D_IRFFT_9_on_c64x5) {
1795   const char* hlo_text = R"(
1796 HloModule Fft
1797 
1798 ENTRY main {
1799   operand = c64[5] parameter(0)
1800   ROOT irfft = f32[9] fft(operand), fft_type=IRFFT, fft_length={9}
1801 }
1802 )";
1803   auto input = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1804                                                  {-3.360560, 11.705792},
1805                                                  {-3.893717, 5.712929},
1806                                                  {-4.5, 3.117691},
1807                                                  {-4.895723, 1.021942}});
1808   auto expected = LiteralUtil::CreateR1<float>(
1809       {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1810   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1811   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1812   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1813   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1814 }
1815 
1816 // 2D FFT tests:
1817 
1818 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8) {
1819   const char* hlo_text = R"(
1820 HloModule Fft
1821 
1822 ENTRY main {
1823   operand = c64[2, 4, 8] parameter(0)
1824   ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={4, 8}
1825 }
1826 )";
1827   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1828   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1829   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
1830   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
1831 }
1832 
1833 TEST_F(HloEvaluatorTest, 2D_IFFT_4x8_on_c64x2x4x8) {
1834   const char* hlo_text = R"(
1835 HloModule Fft
1836 
1837 ENTRY main {
1838   operand = c64[2, 4, 8] parameter(0)
1839   ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={4, 8}
1840 }
1841 )";
1842   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1843   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_2d_}));
1844   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1845   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1846 }
1847 
1848 TEST_F(HloEvaluatorTest, 2D_RFFT_3x8_on_f32x3x8) {
1849   const char* hlo_text = R"(
1850 HloModule Fft
1851 
1852 ENTRY main {
1853   operand = f32[3, 8] parameter(0)
1854   ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 8}
1855 }
1856 )";
1857   auto input =
1858       LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1859                                     {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1860                                     {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1861   auto expected = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1862                                                      {-4.4, 10.622540},
1863                                                      {-4.4, 4.4},
1864                                                      {-4.4, 1.822540},
1865                                                      {-4.4, 0.0}},
1866                                                     {{0.0, 0.0},
1867                                                      {-19.926162, 0.797280},
1868                                                      {-10.128203, -3.728203},
1869                                                      {-6.069756, -5.602720},
1870                                                      {-3.2, -6.928203}},
1871                                                     {{0.0, 0.0},
1872                                                      {13.526162, 14.653687},
1873                                                      {3.728203, 10.128203},
1874                                                      {-0.330244, 8.253687},
1875                                                      {-3.2, 6.928203}}});
1876   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1877   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1878   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1879   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1880 }
1881 
1882 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x8_on_c64x3x5) {
1883   const char* hlo_text = R"(
1884 HloModule Fft
1885 
1886 ENTRY main {
1887   operand = c64[3, 5] parameter(0)
1888   ROOT irfft = f32[3, 8] fft(operand), fft_type=IRFFT, fft_length={3, 8}
1889 }
1890 )";
1891   auto input = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1892                                                   {-4.4, 10.622540},
1893                                                   {-4.4, 4.4},
1894                                                   {-4.4, 1.822540},
1895                                                   {-4.4, 0.0}},
1896                                                  {{0.0, 0.0},
1897                                                   {-19.926162, 0.797280},
1898                                                   {-10.128203, -3.728203},
1899                                                   {-6.069756, -5.602720},
1900                                                   {-3.2, -6.928203}},
1901                                                  {{0.0, 0.0},
1902                                                   {13.526162, 14.653687},
1903                                                   {3.728203, 10.128203},
1904                                                   {-0.330244, 8.253687},
1905                                                   {-3.2, 6.928203}}});
1906   auto expected =
1907       LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1908                                     {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1909                                     {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1910   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1911   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1912   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1913   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1914 }
1915 
1916 TEST_F(HloEvaluatorTest, 2D_RFFT_3x9_on_f32x3x9) {
1917   const char* hlo_text = R"(
1918 HloModule Fft
1919 
1920 ENTRY main {
1921   operand = f32[3, 9] parameter(0)
1922   ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 9}
1923 }
1924 )";
1925   auto input = LiteralUtil::CreateR2<float>(
1926       {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1927        {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1928        {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1929   auto expected = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1930                                                      {-4.95, 13.600013},
1931                                                      {-4.95, 5.899180},
1932                                                      {-4.95, 2.857884},
1933                                                      {-4.95, 0.872819}},
1934                                                     {{0.0, 0.0},
1935                                                      {-25.014467, 2.096690},
1936                                                      {-12.888800, -3.503916},
1937                                                      {-8.1, -5.715768},
1938                                                      {-4.974333, -7.159452}},
1939                                                     {{0.0, 0.0},
1940                                                      {17.814467, 17.685147},
1941                                                      {5.688800, 12.084542},
1942                                                      {0.9, 9.872690},
1943                                                      {-2.225667, 8.429006}}});
1944   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1945   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1946   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1947   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1948 }
1949 
1950 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x9_on_c64x3x5) {
1951   const char* hlo_text = R"(
1952 HloModule Fft
1953 
1954 ENTRY main {
1955   operand = c64[3, 5] parameter(0)
1956   ROOT irfft = f32[3, 9] fft(operand), fft_type=IRFFT, fft_length={3, 9}
1957 }
1958 )";
1959   auto input = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1960                                                   {-4.95, 13.600013},
1961                                                   {-4.95, 5.899180},
1962                                                   {-4.95, 2.857884},
1963                                                   {-4.95, 0.872819}},
1964                                                  {{0.0, 0.0},
1965                                                   {-25.014467, 2.096690},
1966                                                   {-12.888800, -3.503916},
1967                                                   {-8.1, -5.715768},
1968                                                   {-4.974333, -7.159452}},
1969                                                  {{0.0, 0.0},
1970                                                   {17.814467, 17.685147},
1971                                                   {5.688800, 12.084542},
1972                                                   {0.9, 9.872690},
1973                                                   {-2.225667, 8.429006}}});
1974   auto expected = LiteralUtil::CreateR2<float>(
1975       {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1976        {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1977        {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1978   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1979   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1980   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1981   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1982 }
1983 
1984 // 3D FFT tests:
1985 
1986 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8) {
1987   const char* hlo_text = R"(
1988 HloModule Fft
1989 
1990 ENTRY main {
1991   operand = c64[2, 4, 8] parameter(0)
1992   ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={2, 4, 8}
1993 }
1994 )";
1995   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1996   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1997   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
1998   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
1999 }
2000 
2001 TEST_F(HloEvaluatorTest, 3D_IFFT_2x4x8_on_c64x2x4x8) {
2002   const char* hlo_text = R"(
2003 HloModule Fft
2004 
2005 ENTRY main {
2006   operand = c64[2, 4, 8] parameter(0)
2007   ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={2, 4, 8}
2008 }
2009 )";
2010   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2011   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_3d_}));
2012   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
2013   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
2014 }
2015 
2016 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_f32x3x3x4) {
2017   const char* hlo_text = R"(
2018 HloModule Fft
2019 
2020 ENTRY main {
2021   operand = f32[3, 3, 4] parameter(0)
2022   ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
2023 }
2024 )";
2025   auto input = LiteralUtil::CreateR3<float>(
2026       {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
2027        {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
2028        {{-1.8, -2.7, -3.6, -4.5},
2029         {-5.4, -6.3, -7.2, -8.1},
2030         {1.9, 2.9, 3.9, 4.9}}});
2031   auto expected = LiteralUtil::CreateR3<complex64>(
2032       {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
2033         {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
2034         {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
2035        {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
2036         {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
2037         {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2038        {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2039         {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2040         {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2041   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2042   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2043   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2044   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2045 }
2046 
2047 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_c64x3x3x3) {
2048   const char* hlo_text = R"(
2049 HloModule Fft
2050 
2051 ENTRY main {
2052   operand = c64[3, 3, 3] parameter(0)
2053   ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2054 }
2055 )";
2056   auto input = LiteralUtil::CreateR3<complex64>(
2057       {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
2058         {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
2059         {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
2060        {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
2061         {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
2062         {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2063        {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2064         {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2065         {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2066   auto expected = LiteralUtil::CreateR3<float>(
2067       {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
2068        {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
2069        {{-1.8, -2.7, -3.6, -4.5},
2070         {-5.4, -6.3, -7.2, -8.1},
2071         {1.9, 2.9, 3.9, 4.9}}});
2072   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2073   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2074   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2075   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2076 }
2077 
2078 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x5_on_f32x3x3x5) {
2079   const char* hlo_text = R"(
2080 HloModule Fft
2081 
2082 ENTRY main {
2083   operand = f32[3, 3, 5] parameter(0)
2084   ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 5}
2085 }
2086 )";
2087   auto input = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2088                                               {8.1, 7.2, 6.3, 5.4, 4.5},
2089                                               {1.1, 2.2, 3.3, 4.4, 5.5}},
2090                                              {{5.4, 6.3, 7.2, 8.1, 9.0},
2091                                               {4.5, 3.6, 2.7, 1.8, 0.9},
2092                                               {5.5, 6.6, 7.7, 8.8, 9.9}},
2093                                              {{-1.8, -2.7, -3.6, -4.5, -5.4},
2094                                               {-5.4, -6.3, -7.2, -8.1, -9.0},
2095                                               {1.9, 2.9, 3.9, 4.9, 5.9}}});
2096   auto expected = LiteralUtil::CreateR3<complex64>(
2097       {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2098         {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2099         {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2100        {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2101         {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2102         {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2103        {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2104         {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2105         {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2106   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2107   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2108   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2109   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2110 }
2111 
2112 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x5_on_c64x3x3x3) {
2113   const char* hlo_text = R"(
2114 HloModule Fft
2115 
2116 ENTRY main {
2117   operand = c64[3, 3, 3] parameter(0)
2118   ROOT irfft = f32[3, 3, 5] fft(operand), fft_type=IRFFT, fft_length={3, 3, 5}
2119 }
2120 )";
2121   auto input = LiteralUtil::CreateR3<complex64>(
2122       {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2123         {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2124         {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2125        {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2126         {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2127         {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2128        {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2129         {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2130         {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2131   auto expected = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2132                                                  {8.1, 7.2, 6.3, 5.4, 4.5},
2133                                                  {1.1, 2.2, 3.3, 4.4, 5.5}},
2134                                                 {{5.4, 6.3, 7.2, 8.1, 9.0},
2135                                                  {4.5, 3.6, 2.7, 1.8, 0.9},
2136                                                  {5.5, 6.6, 7.7, 8.8, 9.9}},
2137                                                 {{-1.8, -2.7, -3.6, -4.5, -5.4},
2138                                                  {-5.4, -6.3, -7.2, -8.1, -9.0},
2139                                                  {1.9, 2.9, 3.9, 4.9, 5.9}}});
2140   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2141   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2142   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2143   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2144 }
2145 
2146 // FFT tests with non-default data layout:
2147 
2148 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8_with_layout) {
2149   const char* hlo_text = R"(
2150 HloModule Fft
2151 
2152 ENTRY main {
2153   operand = c64[2, 4, 8]{0, 2, 1} parameter(0)
2154   ROOT fft = c64[2, 4, 8]{1, 2, 0} fft(operand), fft_type=FFT, fft_length={8}
2155 }
2156 )";
2157   auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({0, 2, 1}));
2158   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2159   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2160   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
2161   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
2162 }
2163 
2164 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8_with_layout) {
2165   const char* hlo_text = R"(
2166 HloModule Fft
2167 
2168 ENTRY main {
2169   operand = c64[2, 4, 8]{2, 0, 1} parameter(0)
2170   ROOT fft = c64[2, 4, 8]{1, 0, 2} fft(operand), fft_type=FFT, fft_length={4, 8}
2171 }
2172 )";
2173   auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({2, 0, 1}));
2174   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2175   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2176   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
2177   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
2178 }
2179 
2180 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8_with_layout) {
2181   const char* hlo_text = R"(
2182 HloModule Fft
2183 
2184 ENTRY main {
2185   operand = c64[2, 4, 8]{1, 2, 0} parameter(0)
2186   ROOT fft =
2187     c64[2, 4, 8]{0, 2, 1} fft(operand), fft_type=FFT, fft_length={2, 4, 8}
2188 }
2189 )";
2190   auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({1, 2, 0}));
2191   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2192   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2193   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
2194   EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
2195 }
2196 
2197 // FFT tests with unusual parameters:
2198 
2199 // Zero-length transform.
2200 TEST_F(HloEvaluatorTest, 1D_FFT_0_on_c64x1x1x1x1) {
2201   const char* hlo_text = R"(
2202 HloModule Fft
2203 
2204 ENTRY main {
2205   operand = c64[1, 1, 1, 1] parameter(0)
2206   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={0}
2207 }
2208 )";
2209   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2210   auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2211   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2212   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2213   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2214   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2215 }
2216 
2217 // Zero-length axis.
2218 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x0) {
2219   const char* hlo_text = R"(
2220 HloModule Fft
2221 
2222 ENTRY main {
2223   operand = c64[1, 1, 1, 0] parameter(0)
2224   ROOT fft = c64[1, 1, 1, 0] fft(operand), fft_type=FFT, fft_length={1}
2225 }
2226 )";
2227   TF_ASSERT_OK_AND_ASSIGN(
2228       auto input,
2229       LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({1, 1, 1, 0}));
2230   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2231   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2232   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2233   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2234 }
2235 
2236 // Some/all dimensions have length 1.
2237 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x1) {
2238   const char* hlo_text = R"(
2239 HloModule Fft
2240 
2241 ENTRY main {
2242   operand = c64[1, 1, 1, 1] parameter(0)
2243   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1}
2244 }
2245 )";
2246   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2247   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2248   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2249   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2250   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2251 }
2252 
2253 // Zero-length transform.
2254 TEST_F(HloEvaluatorTest, 3D_FFT_1x0x1_on_c64x1x1x1x1) {
2255   const char* hlo_text = R"(
2256 HloModule Fft
2257 
2258 ENTRY main {
2259   operand = c64[1, 1, 1, 1] parameter(0)
2260   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 0, 1}
2261 }
2262 )";
2263   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2264   auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2265   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2266   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2267   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2268   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2269 }
2270 
2271 // Zero-length axis.
2272 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x0x1x0x1) {
2273   const char* hlo_text = R"(
2274 HloModule Fft
2275 
2276 ENTRY main {
2277   operand = c64[0, 1, 0, 1] parameter(0)
2278   ROOT fft = c64[0, 1, 0, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2279 }
2280 )";
2281   TF_ASSERT_OK_AND_ASSIGN(
2282       auto input,
2283       LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({0, 1, 0, 1}));
2284   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2285   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2286   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2287   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2288 }
2289 
2290 // Some/all dimensions have length 1.
2291 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x1x1x1x1) {
2292   const char* hlo_text = R"(
2293 HloModule Fft
2294 
2295 ENTRY main {
2296   operand = c64[1, 1, 1, 1] parameter(0)
2297   ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2298 }
2299 )";
2300   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2301   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2302   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2303   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2304   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2305 }
2306 
2307 // Some/all dimensions have length 1.
2308 TEST_F(HloEvaluatorTest, 3D_FFT_3x1x1_on_c64x1x3x1x1) {
2309   const char* hlo_text = R"(
2310 HloModule Fft
2311 
2312 ENTRY main {
2313   operand = c64[1, 3, 1, 1] parameter(0)
2314   ROOT fft = c64[1, 3, 1, 1] fft(operand), fft_type=FFT, fft_length={3, 1, 1}
2315 }
2316 )";
2317   auto input = LiteralUtil::CreateR4<complex64>(
2318       {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2319   auto expected =
2320       LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2321                                          {{{84.5367, 97.5818}}},
2322                                          {{{-0.0566792, -48.7418}}}}});
2323   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2324   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2325   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2326   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2327 }
2328 
2329 // Some/all dimensions have length 1.
2330 TEST_F(HloEvaluatorTest, 3D_IFFT_3x1x1_on_c64x1x3x1x1) {
2331   const char* hlo_text = R"(
2332 HloModule Fft
2333 
2334 ENTRY main {
2335   operand = c64[1, 3, 1, 1] parameter(0)
2336   ROOT ifft = c64[1, 3, 1, 1] fft(operand), fft_type=IFFT, fft_length={3, 1, 1}
2337 }
2338 )";
2339   auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2340                                                   {{{84.5367, 97.5818}}},
2341                                                   {{{-0.0566792, -48.7418}}}}});
2342   auto expected = LiteralUtil::CreateR4<complex64>(
2343       {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2344   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2345   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2346   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2347   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2348 }
2349 
2350 // Odd transform length.
2351 TEST_F(HloEvaluatorTest, 1D_FFT_5_on_c64x5) {
2352   const char* hlo_text = R"(
2353 HloModule Fft
2354 
2355 ENTRY main {
2356   operand = c64[5] parameter(0)
2357   ROOT fft = c64[5] fft(operand), fft_type=FFT, fft_length={5}
2358 }
2359 )";
2360   auto input = LiteralUtil::CreateR1<complex64>(
2361       {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2362   auto expected = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2363                                                     {0.940955, 5.94095},
2364                                                     {-1.6877, 3.3123},
2365                                                     {-3.3123, 1.6877},
2366                                                     {-5.94095, -0.940955}});
2367   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2368   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2369   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2370   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2371 }
2372 
2373 // Odd transform length.
2374 TEST_F(HloEvaluatorTest, 1D_IFFT_5_on_c64x5) {
2375   const char* hlo_text = R"(
2376 HloModule Fft
2377 
2378 ENTRY main {
2379   operand = c64[5] parameter(0)
2380   ROOT ifft = c64[5] fft(operand), fft_type=IFFT, fft_length={5}
2381 }
2382 )";
2383   auto input = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2384                                                  {0.940955, 5.94095},
2385                                                  {-1.6877, 3.3123},
2386                                                  {-3.3123, 1.6877},
2387                                                  {-5.94095, -0.940955}});
2388   auto expected = LiteralUtil::CreateR1<complex64>(
2389       {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2390   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2391   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2392   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2393   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2394 }
2395 
2396 // All input values are zero.
2397 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_zero_c64x4) {
2398   const char* hlo_text = R"(
2399 HloModule Fft
2400 
2401 ENTRY main {
2402   operand = c64[4] parameter(0)
2403   ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
2404 }
2405 )";
2406   auto input = LiteralUtil::CreateR1<complex64>(
2407       {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}});
2408   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2409   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2410   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2411   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2412 }
2413 
2414 // All input values are zero.
2415 TEST_F(HloEvaluatorTest, 3D_FFT_3x3x4_on_zero_c64x3x3x4) {
2416   const char* hlo_text = R"(
2417 HloModule Fft
2418 
2419 ENTRY main {
2420   operand = c64[3, 3, 4] parameter(0)
2421   ROOT fft = c64[3, 3, 4] fft(operand), fft_type=FFT, fft_length={3, 3, 4}
2422 }
2423 )";
2424   auto input = LiteralUtil::CreateR3<complex64>(
2425       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2426         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2427         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2428        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2429         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2430         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2431        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2432         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2433         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2434   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2435   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2436   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2437   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2438 }
2439 
2440 // All input values are zero.
2441 TEST_F(HloEvaluatorTest, 3D_IFFT_3x3x4_on_zero_c64x3x3x4) {
2442   const char* hlo_text = R"(
2443 HloModule Fft
2444 
2445 ENTRY main {
2446   operand = c64[3, 3, 4] parameter(0)
2447   ROOT ifft = c64[3, 3, 4] fft(operand), fft_type=IFFT, fft_length={3, 3, 4}
2448 }
2449 )";
2450   auto input = LiteralUtil::CreateR3<complex64>(
2451       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2452         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2453         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2454        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2455         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2456         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2457        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2458         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2459         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2460   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2461   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2462   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2463   EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2464 }
2465 
2466 // All input values are zero.
2467 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_zero_f32x3x3x4) {
2468   const char* hlo_text = R"(
2469 HloModule Fft
2470 
2471 ENTRY main {
2472   operand = f32[3, 3, 4] parameter(0)
2473   ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
2474 }
2475 )";
2476   auto input = LiteralUtil::CreateR3<float>(
2477       {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2478        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2479        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2480   auto expected = LiteralUtil::CreateR3<complex64>(
2481       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2482         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2483         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2484        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2485         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2486         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2487        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2488         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2489         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2490   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2491   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2492   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2493   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2494 }
2495 
2496 // All input values are zero.
2497 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_zero_c64x3x3x3) {
2498   const char* hlo_text = R"(
2499 HloModule Fft
2500 
2501 ENTRY main {
2502   operand = c64[3, 3, 3] parameter(0)
2503   ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2504 }
2505 )";
2506   auto input = LiteralUtil::CreateR3<complex64>(
2507       {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2508         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2509         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2510        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2511         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2512         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2513        {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2514         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2515         {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2516   auto expected = LiteralUtil::CreateR3<float>(
2517       {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2518        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2519        {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2520   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2521   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2522   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2523   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2524 }
2525 
2526 // Input values, for which IRFFT discards non-zero imaginary parts.
2527 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x4_on_c64x3x3) {
2528   const char* hlo_text = R"(
2529 HloModule Fft
2530 
2531 ENTRY main {
2532   operand = c64[3, 3] parameter(0)
2533   ROOT irfft = f32[3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 4}
2534 }
2535 )";
2536   auto input =
2537       LiteralUtil::CreateR2<complex64>({{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}},
2538                                         {{3.0, 0.0}, {4.0, 0.0}, {5.0, 0.0}},
2539                                         {{6.0, 0.0}, {7.0, 0.0}, {8.0, 0.0}}});
2540   auto expected =
2541       LiteralUtil::CreateR2<float>({{4.0, -0.5, 0.0, -0.5},
2542                                     {-1.5, 0.433013, 0.0, -0.433013},
2543                                     {-1.5, -0.433013, 0.0, 0.433013}});
2544   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2545   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2546   EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2547   EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2548 }
2549 
2550 class HloEvaluatorPreciseReduceTest : public HloTestBase {};
2551 
2552 // Tests that Reduce doesn't lose precision when adding many numbers (because
2553 // it accumulates its result in a double).
TEST_F(HloEvaluatorPreciseReduceTest,AddReductionPrecisionTest)2554 TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
2555   auto m = CreateNewVerifiedModule();
2556   HloComputation::Builder b(TestName());
2557 
2558   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
2559   std::vector<float> v(kNumElements, 1.0f);
2560   HloInstruction* arg_instruction = b.AddInstruction(
2561       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2562   HloInstruction* init_value = b.AddInstruction(
2563       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2564 
2565   HloComputation::Builder add_computation("add");
2566   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2567   auto param_lhs = add_computation.AddInstruction(
2568       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2569   auto param_rhs = add_computation.AddInstruction(
2570       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2571   add_computation.AddInstruction(HloInstruction::CreateBinary(
2572       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2573   auto add_func = m->AddEmbeddedComputation(add_computation.Build());
2574 
2575   HloInstruction* reduce_instruction = b.AddInstruction(
2576       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2577                                    /*dimensions_to_reduce=*/{0}, add_func));
2578   m->AddEntryComputation(b.Build());
2579 
2580   HloEvaluator hlo_eval;
2581   Literal result = hlo_eval.Evaluate(reduce_instruction).value();
2582   LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
2583 }
2584 
2585 // Reducing many numbers should be fast because it doesn't create
2586 // intermediate Literals; the microbenchmark should finish in < 1 msec.
BM_ReducePrecisely(::testing::benchmark::State & state)2587 void BM_ReducePrecisely(::testing::benchmark::State& state) {
2588   HloComputation::Builder b("BM_ReducePrecisely");
2589   HloModuleConfig config;
2590   config.set_debug_options(GetDebugOptionsFromFlags());
2591   HloModule module("BM_ReducePrecisely", config);
2592 
2593   constexpr int kNumElements = 1 << 25;  // float += 1 saturates at 1<<24
2594   std::vector<float> v(kNumElements, 1.0f);
2595   HloInstruction* arg_instruction = b.AddInstruction(
2596       HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2597   auto init_value = b.AddInstruction(
2598       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2599 
2600   HloComputation::Builder add_computation("add");
2601   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2602   auto param_lhs = add_computation.AddInstruction(
2603       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2604   auto param_rhs = add_computation.AddInstruction(
2605       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2606   add_computation.AddInstruction(HloInstruction::CreateBinary(
2607       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2608   auto add_func = module.AddEmbeddedComputation(add_computation.Build());
2609 
2610   HloInstruction* reduce_instruction = b.AddInstruction(
2611       HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2612                                    /*dimensions_to_reduce=*/{0}, add_func));
2613   module.AddEntryComputation(b.Build());
2614 
2615   // Benchmark loop
2616   for (auto s : state) {
2617     HloEvaluator hlo_eval;
2618     hlo_eval.Evaluate(reduce_instruction).value();
2619   }
2620 }
2621 
2622 BENCHMARK(BM_ReducePrecisely);
2623 
// Reduces a 2x3 f32 constant over dimension 1 with scalar addition and checks
// the resulting f32[2] literal.
TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
  HloComputation::Builder builder(TestName());

  // arg:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  Literal operand_literal =
      LiteralUtil::CreateR2FromArray2D<float>(operand_values);

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar f32 addition used as the reduction function.
  HloComputation::Builder add_builder("add");
  Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloInstruction* lhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "lhs"));
  HloInstruction* rhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "rhs"));
  add_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add_fn = m_->AddEmbeddedComputation(add_builder.Build());

  Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(
      HloInstruction::CreateReduce(result_shape, operand, zero,
                                   /*dimensions_to_reduce=*/{1}, add_fn));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR1<float>({6, 18});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2665 
// Applies a 2x2 max ReduceWindow (stride 1, no padding, no dilation) to a 2x3
// f32 constant and checks the f32[1,2] result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
  HloComputation::Builder builder(TestName());

  // arg:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  Literal operand_literal =
      LiteralUtil::CreateR2FromArray2D<float>(operand_values);

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
  HloComputation* max_fn =
      m_->AddEmbeddedComputation(MaxComputationScalarF32());

  // Both window dimensions share the same configuration.
  Window window;
  for (int i = 0; i < 2; ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(2);
    dim->set_stride(1);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operand, init, window, max_fn));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2707 
// Runs ReduceWindowMaxIotaTest with window dilation 2 (stride 1, base
// dilation 1) and checks the 2x2 expected result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaWindowDilation) {
  Literal want = LiteralUtil::CreateR2<float>({{10, 11}, {14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/1,
                          /*window_dilation=*/2, /*base_dilation=*/1,
                          /*expected=*/want);
}
2718 
// Runs ReduceWindowMaxIotaTest with stride 2 and window dilation 2 (base
// dilation 1) and checks the 1x1 expected result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideWindowDilation) {
  Literal want = LiteralUtil::CreateR2<float>({{10}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/2, /*base_dilation=*/1,
                          /*expected=*/want);
}
2729 
// Runs ReduceWindowMaxIotaTest with base dilation 2 (stride 1, window
// dilation 1) and checks the 6x6 expected result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaBaseDilation) {
  Literal want = LiteralUtil::CreateR2<float>(
      {{0, 1, 1, 2, 2, 3},
       {4, 5, 5, 6, 6, 7},
       {4, 5, 5, 6, 6, 7},
       {8, 9, 9, 10, 10, 11},
       {8, 9, 9, 10, 10, 11},
       {12, 13, 13, 14, 14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/1,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2745 
// Runs ReduceWindowMaxIotaTest with stride 2 and base dilation 2 (window
// dilation 1) and checks the 3x3 expected result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBaseDilation) {
  Literal want =
      LiteralUtil::CreateR2<float>({{0, 1, 2}, {4, 5, 6}, {8, 9, 10}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2757 
// Runs ReduceWindowMaxIotaTest with stride 2, window dilation 2 and base
// dilation 2, and checks the 3x3 expected result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBothDilation) {
  Literal want =
      LiteralUtil::CreateR2<float>({{5, 6, 7}, {9, 10, 11}, {13, 14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/2, /*base_dilation=*/2,
                          /*expected=*/want);
}
2769 
// Runs ReduceWindowMaxIotaTest with window size 3, padding 1, stride 3 and
// base dilation 2 (window dilation 1), and checks the 3x3 expected result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaPaddingStrideBaseDilation) {
  // The base is dilated first, and then padding is applied, hence this result.
  Literal want =
      LiteralUtil::CreateR2<float>({{0, 2, 3}, {8, 10, 11}, {12, 14, 15}});
  ReduceWindowMaxIotaTest(/*window_size=*/3, /*padding=*/1, /*stride=*/3,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2782 
// Applies an additive ReduceWindow to a 2x3 f32 constant. The window has size
// 1 in dimension 0 and size 2 with low padding 1 in dimension 1. The f32[2,3]
// result is checked against expected values.
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
  HloComputation::Builder builder(TestName());

  // arg:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> operand_values(2, 3);
  operand_values.FillUnique(1.0f);
  Literal operand_literal =
      LiteralUtil::CreateR2FromArray2D<float>(operand_values);

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar f32 addition used as the reduction function.
  HloComputation::Builder add_builder("add");
  Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloInstruction* lhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "lhs"));
  HloInstruction* rhs = add_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "rhs"));
  add_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, lhs, rhs));
  HloComputation* add_fn = m_->AddEmbeddedComputation(add_builder.Build());

  // Appends one window dimension; only size and low padding differ between
  // the two dimensions, every other field is fixed.
  Window window;
  auto append_dim = [&window](int64_t size, int64_t padding_low) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(1);
    dim->set_padding_low(padding_low);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_dim(/*size=*/1, /*padding_low=*/0);
  append_dim(/*size=*/2, /*padding_low=*/1);

  Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operand, zero, window, add_fn));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
2839 
TEST_P(HloEvaluatorBf16Test,ReduceWindowAdd6D)2840 TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
2841   HloComputation::Builder b(TestName());
2842 
2843   // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time.
2844   std::vector<int64_t> input_dims(6, 4);
2845   Literal arg_literal =
2846       LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims, 1.0f);
2847 
2848   HloInstruction* arg_instruction =
2849       b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
2850 
2851   auto init_value = b.AddInstruction(
2852       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2853 
2854   HloComputation::Builder add_computation("add");
2855   Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2856   auto param_lhs = add_computation.AddInstruction(
2857       HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2858   auto param_rhs = add_computation.AddInstruction(
2859       HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2860   add_computation.AddInstruction(HloInstruction::CreateBinary(
2861       scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2862   auto add_func = m_->AddEmbeddedComputation(add_computation.Build());
2863 
2864   Window window;
2865 
2866   WindowDimension trivial_dim;
2867   trivial_dim.set_size(1);
2868   trivial_dim.set_stride(1);
2869   trivial_dim.set_padding_low(0);
2870   trivial_dim.set_padding_high(0);
2871   trivial_dim.set_window_dilation(1);
2872   trivial_dim.set_base_dilation(1);
2873 
2874   WindowDimension active_dim;
2875   active_dim.set_size(2);
2876   active_dim.set_stride(1);
2877   active_dim.set_padding_low(0);
2878   active_dim.set_padding_high(0);
2879   active_dim.set_window_dilation(1);
2880   active_dim.set_base_dilation(1);
2881 
2882   *window.add_dimensions() = trivial_dim;
2883   *window.add_dimensions() = active_dim;
2884   *window.add_dimensions() = active_dim;
2885   *window.add_dimensions() = active_dim;
2886   *window.add_dimensions() = trivial_dim;
2887   *window.add_dimensions() = trivial_dim;
2888 
2889   Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 3, 3, 4, 4});
2890   b.AddInstruction(HloInstruction::CreateReduceWindow(
2891       shape, arg_instruction, init_value, window, add_func));
2892 
2893   m_->AddEntryComputation(b.Build());
2894 
2895   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
2896 
2897   std::vector<int64_t> output_dims = {4, 3, 3, 3, 4, 4};
2898   Literal result_literal =
2899       LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
2900   EXPECT_TRUE(LiteralTestUtil::Equal(result_literal, result));
2901 }
2902 
// Two-operand reduce-window: both f32[5] inputs are min-reduced together by a
// tuple-returning computation; the window is inferred from {3} and {2}.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) {
  HloComputation::Builder builder("main");
  HloInstruction* lhs_input =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* rhs_input =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder reducer("ComputeFunction");
  const Shape scalar0 = ShapeUtil::MakeShape(F32, {});
  const Shape scalar1 = ShapeUtil::MakeShape(F32, {});
  HloInstruction* x0 =
      reducer.AddInstruction(HloInstruction::CreateParameter(0, scalar0, "x0"));
  HloInstruction* x1 =
      reducer.AddInstruction(HloInstruction::CreateParameter(1, scalar1, "x1"));
  HloInstruction* y0 =
      reducer.AddInstruction(HloInstruction::CreateParameter(2, scalar0, "y0"));
  HloInstruction* y1 =
      reducer.AddInstruction(HloInstruction::CreateParameter(3, scalar1, "y1"));
  HloInstruction* min0 = reducer.AddInstruction(
      HloInstruction::CreateBinary(scalar0, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min1 = reducer.AddInstruction(
      HloInstruction::CreateBinary(scalar1, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> reducer_outputs = {min0, min1};
  reducer.AddInstruction(HloInstruction::CreateTuple(reducer_outputs));
  auto* min_pair_fn = m_->AddEmbeddedComputation(reducer.Build());

  std::vector<HloInstruction*> operands = {lhs_input, rhs_input};
  HloInstruction* lhs_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* rhs_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  std::vector<HloInstruction*> inits = {lhs_init, rhs_init};

  std::pair<int64_t, int64_t> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));
  std::vector<const Shape*> operand_shapes = {&lhs_input->shape(),
                                              &rhs_input->shape()};
  std::vector<const Shape*> init_shapes = {&lhs_init->shape(),
                                           &rhs_init->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape result_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_shapes, window,
                              min_pair_fn->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operands, inits, window, min_pair_fn));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Both tuple elements are expected to hold the same minima.
  Literal minima = LiteralUtil::CreateR1<float>({100, 1});
  Literal want = LiteralUtil::MakeTuple({&minima, &minima});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
2953 
// Two-operand reduce-window with mixed element types: an f32[5] and an s32[5]
// input are min-reduced together by a tuple-returning computation.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) {
  HloComputation::Builder builder("main");
  HloInstruction* float_input =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* int_input =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)), where the
  // "0" values are f32 scalars and the "1" values are s32 scalars.
  HloComputation::Builder reducer("ComputeFunction");
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
  HloInstruction* x0 = reducer.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "x0"));
  HloInstruction* x1 = reducer.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_s32, "x1"));
  HloInstruction* y0 = reducer.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_f32, "y0"));
  HloInstruction* y1 = reducer.AddInstruction(
      HloInstruction::CreateParameter(3, scalar_s32, "y1"));
  HloInstruction* min_f32 = reducer.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min_s32 = reducer.AddInstruction(
      HloInstruction::CreateBinary(scalar_s32, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> reducer_outputs = {min_f32, min_s32};
  reducer.AddInstruction(HloInstruction::CreateTuple(reducer_outputs));
  auto* min_pair_fn = m_->AddEmbeddedComputation(reducer.Build());

  std::vector<HloInstruction*> operands = {float_input, int_input};
  HloInstruction* float_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* int_init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32)));
  std::vector<HloInstruction*> inits = {float_init, int_init};

  std::pair<int64_t, int64_t> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));
  std::vector<const Shape*> operand_shapes = {&float_input->shape(),
                                              &int_input->shape()};
  std::vector<const Shape*> init_shapes = {&float_init->shape(),
                                           &int_init->shape()};
  TF_ASSERT_OK_AND_ASSIGN(Shape result_shape,
                          ShapeInference::InferReduceWindowShape(
                              operand_shapes, init_shapes, window,
                              min_pair_fn->ComputeProgramShape()));
  builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operands, inits, window, min_pair_fn));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal float_minima = LiteralUtil::CreateR1<float>({100, 1});
  Literal int_minima = LiteralUtil::CreateR1<int>({15, 12});
  Literal want = LiteralUtil::MakeTuple({&float_minima, &int_minima});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3005 
// Slices a 3x5 operand from start {0, 2} to limit {3, 5} using strides {2, 3}
// and checks the resulting 2x1 literal.
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
  HloComputation::Builder b(TestName());

  // Operand:
  // f32[3,5] {
  //  { 1, 2, 3, 4, 5 },
  //  { 9, 10, 11, 12, 13 },
  //  { 17, 18, 19, 20, 21 },
  // }
  Array2D<float> operand_values(3, 5);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1});
  b.AddInstruction(HloInstruction::CreateSlice(result_shape, operand,
                                               /*start_indices=*/{0, 2},
                                               /*limit_indices=*/{3, 5},
                                               /*strides=*/{2, 3}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{3}, {19}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3039 
// Takes a 2x3 dynamic slice of a 2x4 operand starting at indices (0, 1).
TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
  HloComputation::Builder b(TestName());

  // Operand:
  // f32[2,4] {
  //  { 1, 2, 3, 4 },
  //  { 5, 6, 7, 8 },
  // }
  Array2D<float> operand_values(2, 4);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  // Scalar s32 start indices, one per operand dimension.
  HloInstruction* row_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
  HloInstruction* col_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  b.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3075 
3076 // Verifies that the HloEvaluator's implementation goes along with existing
3077 // backends' behavior, although this is not required by the spec.
TEST_P(HloEvaluatorBf16Test,DynamicSliceModSlice)3078 TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
3079   HloComputation::Builder b(TestName());
3080 
3081   // arg:
3082   // f32[2,4] {
3083   //  { 1, 2, 3, 4 },
3084   //  { 5, 6, 7, 8 },
3085   // }
3086   auto operand_array = std::make_unique<Array2D<float>>(2, 4);
3087   operand_array->FillUnique(1.0f);
3088   auto operand_literal =
3089       LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
3090 
3091   HloInstruction* operand = b.AddInstruction(
3092       HloInstruction::CreateConstant(std::move(operand_literal)));
3093 
3094   auto two = b.AddInstruction(
3095       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
3096   auto one = b.AddInstruction(
3097       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
3098 
3099   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3100   b.AddInstruction(
3101       HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3}));
3102   m_->AddEntryComputation(b.Build());
3103 
3104   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
3105 
3106   auto expected = LiteralUtil::CreateR2<float>({
3107       {2, 3, 4},
3108       {6, 7, 8},
3109   });
3110 
3111   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
3112 }
3113 
// Writes a 2x2 update into a 2x3 f64 operand at start indices (0, 1) via
// dynamic-update-slice and checks the evaluated literal.
TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
  HloComputation::Builder b(TestName());

  // Operand:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> operand_values(2, 3);
  operand_values.FillUnique(1.0);
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<double>(operand_values)));

  // Scalar s32 start indices, one per operand dimension.
  HloInstruction* row_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
  HloInstruction* col_start = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));

  HloInstruction* patch = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));

  const Shape result_shape = ShapeUtil::MakeShape(F64, {2, 3});
  b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      result_shape, operand, patch, {row_start, col_start}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal want = LiteralUtil::CreateR2<double>({{1, -2, -3}, {5, -6, -7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3152 
// Packs an s64[2] vector and an f64[2,3] matrix into a tuple, then reads the
// matrix back out with get-tuple-element index 1.
TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
  HloComputation::Builder b(TestName());

  // Matrix element:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  // The matrix constant is added before the vector constant.
  HloInstruction* matrix = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vec = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64_t>({0, 1})));

  HloInstruction* pair =
      b.AddInstruction(HloInstruction::CreateTuple({vec, matrix}));

  const Shape element_shape = ShapeUtil::MakeShape(F64, {2, 3});
  b.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, pair, 1));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal want = LiteralUtil::CreateR2<double>({{1, 2, 3}, {5, 6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3188 
// Builds a tuple of two tuples and extracts the second inner tuple, which is
// made of the same f64[2,3] matrix twice.
TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
  HloComputation::Builder b(TestName());

  // Matrix element:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  // The matrix constant is added before the vector constant.
  HloInstruction* matrix = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vec = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64_t>({0, 1})));

  HloInstruction* mixed_pair =
      b.AddInstruction(HloInstruction::CreateTuple({vec, matrix}));
  HloInstruction* matrix_pair =
      b.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  HloInstruction* nested =
      b.AddInstruction(HloInstruction::CreateTuple({mixed_pair, matrix_pair}));

  b.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_pair->shape(), nested, 1));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values);
  Literal want = LiteralUtil::MakeTuple({&matrix_literal, &matrix_literal});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3227 
// Reverses a float[4x3x2x1] operand along dimensions 0 and 1.
TEST_P(HloEvaluatorBf16Test, Reverse) {
  HloComputation::Builder b(TestName());

  // clang-format off
  Array4D<float> source({
    {{{1.0f}, {2.0f}},
     {{3.0f}, {4.0f}},
     {{5.0f}, {6.0f}}},
    {{{7.0f}, {8.0f}},
     {{9.0f}, {10.0f}},
     {{11.0f}, {12.0f}}},
    {{{13.0f}, {14.0f}},
     {{15.0f}, {16.0f}},
     {{17.0f}, {18.0f}}},
    {{{19.0f}, {20.0f}},
     {{21.0f}, {22.0f}},
     {{23.0f}, {24.0f}}},
  });
  // clang-format on
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(source)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
  b.AddInstruction(
      HloInstruction::CreateReverse(result_shape, operand, {0, 1}));
  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Same values with the order of the two outermost dimensions flipped.
  // clang-format off
  Array4D<float> flipped({
    {{{23.0f}, {24.0f}},
     {{21.0f}, {22.0f}},
     {{19.0f}, {20.0f}}},

    {{{17.0f}, {18.0f}},
     {{15.0f}, {16.0f}},
     {{13.0f}, {14.0f}}},

    {{{11.0f}, {12.0f}},
     {{9.0f}, {10.0f}},
     {{7.0f}, {8.0f}}},

    {{{5.0f}, {6.0f}},
     {{3.0f}, {4.0f}},
     {{1.0f}, {2.0f}}},
  });
  // clang-format on
  Literal want = LiteralUtil::CreateR4FromArray4D<float>(flipped);

  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3280 
// Evaluates `param0 + square` while substituting literals for both operands,
// so the multiply that defines `square` is not what supplies its value.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
  HloComputation::Builder b(TestName());
  const Shape vec4 = ShapeUtil::MakeShape(F32, {4});

  HloInstruction* param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* square = b.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, param0, param0));
  HloInstruction* sum = b.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, param0, square));

  // Substitute param0 = {1, 2, 3, 4} and square = {10, 20, 30, 40}.
  Literal param0_value = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
  Literal square_value = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      evaluator.EvaluateWithSubstitutions(
          sum, {{param0, &param0_value}, {square, &square_value}}));

  Literal want = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3303 
3304 // Check that EvaluateWithSubstitutions works if one of the operands to the op
3305 // we're evaluating is a constant.
TEST_P(HloEvaluatorBf16Test,EvaluateWithSubstitutionsWithConstantOperand)3306 TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
3307   HloComputation::Builder b(TestName());
3308   Shape shape = ShapeUtil::MakeShape(F32, {4});
3309 
3310   HloInstruction* param0 =
3311       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
3312   HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
3313       shape, HloOpcode::kMultiply, param0, param0));
3314   HloInstruction* constant = b.AddInstruction(HloInstruction::CreateConstant(
3315       LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
3316   HloInstruction* add = b.AddInstruction(
3317       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square));
3318 
3319   // Evaluate add with square = {10, 20, 30, 40}.
3320   HloEvaluator evaluator;
3321   Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
3322   TF_ASSERT_OK_AND_ASSIGN(
3323       Literal result,
3324       evaluator.EvaluateWithSubstitutions(add, {{square, &square_literal}}));
3325   EXPECT_TRUE(LiteralTestUtil::Equal(
3326       LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
3327 }
3328 
// Gathers whole rows {0, 2} of a 3x3 operand (slice_sizes={1, 3} with
// dimension 0 collapsed).
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal table =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&table, &row_ids}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {7, 8, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3352 
// Gathers whole columns {0, 2} of a 3x3 operand (slice_sizes={3, 1} with
// dimension 1 collapsed).
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
  const char* module_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal table =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&table, &column_ids}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{1, 3}, {4, 6}, {7, 9}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3376 
// Column gather from a 3x3 operand where the indices form a 2x2 batch,
// giving an s32[2,3,2] result.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
  const char* module_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,3,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal table =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR2<int32_t>({{0, 2}, {2, 1}});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&table, &column_ids}));

  Literal want = LiteralUtil::CreateR3<int32_t>(
      {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3402 
// Gather where each row of `indices` addresses the first two dimensions of a
// 3x3x2 operand and the full last dimension is returned.
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal cube =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal coords = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&cube, &coords}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{-1, 1}, {-4, 4}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3428 
// Same operand and indices as the GatherNd case but with index_vector_dim=0,
// which changes which elements are selected (see the expected literal).
TEST_F(HloEvaluatorTest,
       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal cube =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal coords = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&cube, &coords}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{-2, 2}, {-1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3455 
// Gather with no collapsed dimensions and a single index vector: picks the
// 1x1 slice at (1, 1) of a 3x3 operand.
TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
  const char* module_text = R"(
HloModule DynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[1,1] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal table =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_start = LiteralUtil::CreateR1<int32_t>({1, 1});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&table, &slice_start}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{5}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3479 
// Batched form of the dynamic-slice gather: two 1x1 slices are taken from a
// 3x3 operand, with index_vector_dim=0 on the 2x2 indices.
TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
  const char* module_text = R"(
HloModule BatchDynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,1,1] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal table =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal slice_starts = LiteralUtil::CreateR2<int32_t>({{2, 1}, {1, 1}});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&table, &slice_starts}));

  Literal want = LiteralUtil::CreateR3<int32_t>({{{8}}, {{5}}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3503 
// Row gather from an operand whose second dimension has size zero; the
// result is an empty s32[2,0] literal.
TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,0] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal empty_table = LiteralUtil::CreateR2<int32_t>({{}, {}, {}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&empty_table, &row_ids}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{}, {}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3526 
// Gather with empty offset_dims: every index selects one scalar from a rank-1
// operand, so the result has the batch shape of the indices, s32[2,2].
TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
  const std::string module_text = R"(
HloModule GatherXd

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal values = LiteralUtil::CreateR1<int32_t>({0, 1, 2});
  Literal positions =
      LiteralUtil::CreateR3<int32_t>({{{0}, {1}}, {{2}, {1}}});
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&values, &positions}));

  Literal want = LiteralUtil::CreateR2<int32_t>({{0, 1}, {2, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, result));
}
3551 
// Row-wise scatter with an "overwrite" combiner (the combiner returns its
// second parameter): rows 0 and 2 of the operand are replaced by the updates.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
  const char* module_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  const Literal new_rows =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &new_rows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3585 
// Column-wise variant of the overwrite scatter: the indices address operand
// dimension 1, so columns 0 and 2 are replaced by the columns of `updates`.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
  const char* module_text = R"(
HloModule TensorFlowScatterV2

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal column_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  const Literal new_columns =
      LiteralUtil::CreateR2<int32_t>({{10, 30}, {40, 60}, {70, 90}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &column_ids, &new_columns}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3619 
// Scatter with an additive combiner: the update rows are added element-wise
// onto operand rows 0 and 2; row 1 is left untouched.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  const Literal addends =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3654 
// Scatter with a multiplicative combiner: operand rows 0 and 2 are multiplied
// element-wise by the corresponding update rows.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
  const char* module_text = R"(
HloModule TensorFlowScatter

mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=mul_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  const Literal factors =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  const Literal want = LiteralUtil::CreateR2<int32_t>(
      {{10, 40, 90}, {4, 5, 6}, {490, 640, 810}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &factors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3690 
// Floating-point additive scatter. The result is compared with a tolerance
// (ErrorSpec{0.1, 0.01}) rather than for exact equality.
TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=add_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  // Update row 0 targets operand row 2, update row 1 targets operand row 1.
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({2, 1});
  const Literal addends =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
  const Literal want = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Near(want, got, ErrorSpec{0.1, 0.01}));
}
3726 
// Both scatter indices point at row 1, so with the additive combiner both
// update rows must accumulate into that single operand row.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({1, 1});
  const Literal addends =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  // Row 1: {4 + 10 + 70, 5 + 20 + 80, 6 + 30 + 90}.
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &row_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3761 
// Additive scatter whose indices form a 2x2 batch (index_vector_dim=2 on a
// rank-2 index tensor), so the updates carry two scatter dimensions.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
  const char* module_text = R"(
HloModule TensorFlowScatterMultipleBatchDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal column_ids =
      LiteralUtil::CreateR2<int32_t>({{0, 2}, {2, 1}});
  const Literal addends = LiteralUtil::CreateR3<int32_t>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
  const Literal want = LiteralUtil::CreateR2<int32_t>(
      {{11, 7, 38}, {44, 10, 71}, {77, 13, 104}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &column_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3797 
// ScatterNd-style overwrite: each index row is a 2-D coordinate into the two
// leading operand dimensions, and the matching update row replaces the
// innermost vector stored there.
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
  const char* module_text = R"(
HloModule TensorFlowScatterNd

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  const Literal coords = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  const Literal new_vectors =
      LiteralUtil::CreateR2<int32_t>({{-10, 10}, {-40, 40}});
  // Positions (0,0) and (1,0) are overwritten; everything else is unchanged.
  const Literal want =
      LiteralUtil::CreateR3<int32_t>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                      {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &coords, &new_vectors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3834 
// Same ScatterNd-style overwrite, but with index_vector_dim=0. The expected
// literal places update 0 at operand position (0,1) and update 1 at (0,0).
TEST_F(HloEvaluatorTest,
       EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  const char* module_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  const Literal coords = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  const Literal new_vectors =
      LiteralUtil::CreateR2<int32_t>({{-10, 10}, {-20, 20}});
  const Literal want =
      LiteralUtil::CreateR3<int32_t>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &coords, &new_vectors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3872 
// Scatter used as a dynamic-update-slice: a single 1x1 update window is
// written at the position given by the rank-1 index vector {1, 1}.
TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
  const char* module_text = R"(
HloModule DynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0,1},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal start = LiteralUtil::CreateR1<int32_t>({1, 1});
  const Literal patch = LiteralUtil::CreateR2<int32_t>({{10}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &start, &patch}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3905 
// Batched dynamic-update-slice via scatter with index_vector_dim=0. The
// expected literal has 10 written at position (2,1) and 20 at (1,1).
TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
  const char* module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  const Literal starts = LiteralUtil::CreateR2<int32_t>({{2, 1}, {1, 1}});
  const Literal patches = LiteralUtil::CreateR3<int32_t>({{{10}}, {{20}}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&base, &starts, &patches}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
3938 
// Scatter where operand and updates both have a zero-sized dimension: there
// is nothing to write, so the result must equal the (empty) operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
  const char* module_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,0] parameter(2)
  ROOT scatter = s32[3,0] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base = LiteralUtil::CreateR2<int32_t>({{}, {}, {}});
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  const Literal empty_rows = LiteralUtil::CreateR2<int32_t>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &row_ids, &empty_rows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(base, got));
}
3968 
// Additive scatter with an empty update_window_dims list: every update is a
// scalar added to the operand element addressed by its index.
TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
  const std::string module_text = R"(
HloModule Scatter_NoUpdateWindowDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal base = LiteralUtil::CreateR1<int32_t>({0, 1, 2});
  const Literal positions =
      LiteralUtil::CreateR3<int32_t>({{{0}, {1}}, {{2}, {1}}});
  const Literal addends = LiteralUtil::CreateR2<int32_t>({{10, 20}, {30, 40}});
  // Element 0: 0 + 10; element 1: 1 + 20 + 40; element 2: 2 + 30.
  const Literal want = LiteralUtil::CreateR1<int32_t>({10, 61, 32});

  TF_ASSERT_OK_AND_ASSIGN(Literal got,
                          Evaluate({&base, &positions, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4002 
// A negative scatter index must be skipped: only the update addressed to
// row 2 is applied, the one addressed to row -1 is dropped.
TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter_NegativeIndices

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  // This test evaluates a locally owned module rather than the fixture's m_.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> local_module,
                          ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // Index -1 is negative and must not cause any update.
  const Literal row_ids = LiteralUtil::CreateR1<int32_t>({-1, 2});
  const Literal addends =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}});

  const Literal got =
      EvaluateWithModule(local_module.get(), {&base, &row_ids, &addends});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4038 
// Out-of-bounds scatter indices must be ignored while in-bounds ones in the
// same batch are still applied.
TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
  const std::string module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3]{1,0} parameter(0)
  indices = s32[6,2]{1,0} parameter(1)
  updates = s32[6,1,1]{2,1,0} parameter(2)
  ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  // This test evaluates a locally owned module rather than the fixture's m_.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> local_module,
                          ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  // {2,7}, {5,1} and {2147483647,1} fall outside the 3x3 operand and must not
  // produce updates; {2,1}, {1,1} and {1,2} are in bounds.
  const Literal coords = LiteralUtil::CreateR2<int32_t>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
  const Literal patches = LiteralUtil::CreateR3<int32_t>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  const Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}});

  const Literal got =
      EvaluateWithModule(local_module.get(), {&base, &coords, &patches});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4074 
// An update window that would extend past the operand bounds must be dropped
// entirely, leaving the operand unchanged.
TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
  const char* module_text = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[1,2] parameter(1)
  updates = s32[1,2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  // This test evaluates a locally owned module rather than the fixture's m_.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> local_module,
                          ParseAndReturnVerifiedModule(module_text));

  const Literal base =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  const Literal coords = LiteralUtil::CreateR2<int32_t>({{0, 2}});
  const Literal window =
      LiteralUtil::CreateR3<int32_t>({{{-10, 10}, {-40, 40}}});
  // With a 2x2 update window starting at index (0,2), the window runs out of
  // bounds, so no element may be modified.
  const Literal want = base.Clone();

  const Literal got =
      EvaluateWithModule(local_module.get(), {&base, &coords, &window});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4111 
// Variadic scatter: two operands (s32 and f32) are updated in lockstep through
// a shared index tensor, and the result is a tuple holding both outputs.
TEST_F(HloEvaluatorTest, EvaluateScatter_Multioutput) {
  const char* module_text = R"(
HloModule MultioutputScatter

update {
  lhs0 = s32[] parameter(0)
  lhs1 = f32[] parameter(1)
  rhs0 = s32[] parameter(2)
  rhs1 = f32[] parameter(3)
  ROOT tuple = (s32[], f32[]) tuple(rhs0, rhs1)
}

ENTRY main {
  operand0 = s32[3,3,2] parameter(0)
  operand1 = f32[3,3,2] parameter(1)
  indices = s32[2,2] parameter(2)
  updates0 = s32[2,2] parameter(3)
  updates1 = f32[2,2] parameter(4)
  ROOT scatter = (s32[3,3,2], f32[3,3,2]) scatter(operand0, operand1, indices, updates0, updates1),
      to_apply=update,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  // Inputs, in entry-parameter order.
  const Literal int_base =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  const Literal float_base =
      LiteralUtil::CreateR3<float>({{{-2, 2}, {-3, 3}, {-4, 4}},  //
                                    {{-5, 5}, {-6, 6}, {-7, 7}},  //
                                    {{-8, 8}, {-9, 9}, {-10, 10}}});
  const Literal coords = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  const Literal int_updates =
      LiteralUtil::CreateR2<int32_t>({{-10, 10}, {-40, 40}});
  const Literal float_updates =
      LiteralUtil::CreateR2<float>({{-11, 11}, {-41, 41}});

  // Positions (0,0) and (1,0) are overwritten in both outputs.
  const Literal want = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR3<int32_t>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                      {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}}),
      LiteralUtil::CreateR3<float>({{{-11, 11}, {-3, 3}, {-4, 4}},  //
                                    {{-41, 41}, {-6, 6}, {-7, 7}},  //
                                    {{-8, 8}, {-9, 9}, {-10, 10}}}));

  TF_ASSERT_OK_AND_ASSIGN(
      Literal got,
      Evaluate(
          {&int_base, &float_base, &coords, &int_updates, &float_updates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4162 
4163 // Verifies that HloEvaluator evaluates a HLO instruction that performs
4164 // element-wise comparison with 2 bfloat16 operands.
TEST_F(HloEvaluatorTest,DoesCompareBF16)4165 TEST_F(HloEvaluatorTest, DoesCompareBF16) {
4166   // lhs >= rhs
4167   auto lhs = LiteralUtil::CreateR2<bfloat16>(
4168       {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
4169        {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
4170   auto rhs = LiteralUtil::CreateR2<bfloat16>(
4171       {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
4172        {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
4173   auto expected =
4174       LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
4175 
4176   HloComputation::Builder b(TestName());
4177   auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
4178   auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
4179   b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
4180                                                  ComparisonDirection::kGe));
4181   m_->AddEntryComputation(b.Build());
4182 
4183   TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
4184   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
4185 }
4186 
// Sum-reduces a four-element bf16 vector to a bf16 scalar and checks the
// exact result (1 + 3 - 2 + 42 = 44).
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
  const std::string module_text = R"(
HloModule Bf16Reduction

add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
  lhs = bf16[] parameter(0)
  rhs = bf16[] parameter(1)
  ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
}

ENTRY main {
  arg0 = bf16[4]{0} parameter(0)
  init = bf16[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal values = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
  const Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&values}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4211 
// Reduce whose inputs and combiner are f32 but whose declared result type is
// bf16: the evaluator must produce a bf16 scalar holding the sum.
TEST_F(HloEvaluatorTest, MixedPrecisionReduction) {
  const std::string module_text = R"(
HloModule MixedPrecisionReduction

add_f32 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  arg0 = f32[4]{0} parameter(0)
  init = f32[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal values =
      LiteralUtil::CreateR1<float>({1.0f, 3.0f, -2.0f, 42.0f});
  const Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));

  TF_ASSERT_OK_AND_ASSIGN(Literal got, Evaluate({&values}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4235 
// Evaluates a `call` whose callee contains an outfeed. Despite the test name,
// the assertion below expects evaluation to return a non-OK status.
TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
  // The outfeed inside the called computation is what is expected to make
  // evaluation of the call fail.
  const std::string module_text = R"(
HloModule DontFailOnCall

call {
  token0 = token[] after-all()
  constant = u32[3]{0} constant({1,2,3})
  ROOT  outfeed = token[] outfeed(constant, token0), outfeed_shape=u32[3]{0}
}

ENTRY main {
  ROOT result = token[] call(), to_apply=call
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  EXPECT_FALSE(Evaluate().status().ok());
}
4256 
// Evaluates a fusion whose fused computation contains an outfeed. Despite the
// test name, the assertion below expects evaluation to return a non-OK status.
TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
  // The outfeed inside the fused computation is what is expected to make
  // evaluation of the fusion fail.
  const std::string module_text = R"(
HloModule DontFailOnFusion

fused_computation {
  token0 = token[] after-all()
  constant = u32[3]{0} constant({1,2,3})
  ROOT  outfeed = token[] outfeed(constant, token0), outfeed_shape=u32[3]{0}
}

ENTRY main {
  ROOT result = token[] fusion(), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  EXPECT_FALSE(Evaluate().status().ok());
}
4277 
// Regression test for b/114735354: a full-extent slice whose result layout
// ({1,0,2}) differs from the operand layout ({0,1,2}) must yield a literal
// that compares equal to the input.
TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
  const std::string module_text = R"(
HloModule SliceWithDifferentLayout

ENTRY main {
  arg = f32[2,2,2]{0,1,2} parameter(0)
  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  const Literal input = LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      LayoutUtil::MakeLayout({0, 1, 2}));

  TF_ASSERT_OK_AND_ASSIGN(Literal sliced, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(input, sliced));
}
4296 
// Regression test for b/114735354.
// Evaluates a bitcast from [32,121]{1,0} to [121,32,1]{0,1,2} and checks that
// the flat data of the result matches the flat data of the argument. The
// element type is bf16 when use_bfloat16_ is set and f32 otherwise.
TEST_P(HloEvaluatorBf16Test, Bitcast) {
  const absl::string_view hlo_text_base = R"(
HloModule Bitcast

ENTRY main {
  param = %s[32,121]{1,0} parameter(0)
  ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
}
)";
  // All three %s placeholders receive the same element type name.
  const char* const element_type = use_bfloat16_ ? "bf16" : "f32";
  const std::string hlo_text =
      absl::StrFormat(hlo_text_base, element_type, element_type, element_type);
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  auto fake_args = MakeFakeArguments(m_.get()).value();
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fake_args[0]}));

  // Compare element-by-element using the accessor for the active type.
  const bool data_matches =
      use_bfloat16_ ? absl::c_equal(fake_args[0].data<bfloat16>(),
                                    result.data<bfloat16>())
                    : absl::c_equal(fake_args[0].data<float>(),
                                    result.data<float>());
  EXPECT_TRUE(data_matches);
}
4323 
4324 // Check that s32 under/overflow doesn't trigger a ubsan failure.
TEST_F(HloEvaluatorTest,Int32Overflow)4325 TEST_F(HloEvaluatorTest, Int32Overflow) {
4326   const absl::string_view hlo_text = R"(
4327 HloModule Test
4328 
4329 ENTRY main {
4330   c1 = s32[] constant(1073741824)  // 2^30
4331   sum = s32[] add(c1, c1)  // 2^31, i.e. INT_MIN
4332 
4333   c2 = s32[] constant(-2147483648)  // -2^31
4334   sub = s32[] subtract(c2, c1)  // -2^31 - 2^30, underflows
4335 
4336   mul = s32[] multiply(c1, c1)
4337   ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
4338 }
4339 )";
4340   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4341   TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
4342   std::vector<Literal> actual = literal.DecomposeTuple();
4343   ASSERT_EQ(actual.size(), 3);
4344 
4345   uint32_t pow30 = uint32_t{1} << 30;
4346   uint32_t pow31 = uint32_t{1} << 31;
4347   EXPECT_EQ(actual[0].GetFirstElement<int32_t>(), static_cast<int32_t>(pow31));
4348   EXPECT_EQ(actual[1].GetFirstElement<int32_t>(),
4349             static_cast<int32_t>(-(pow31 + pow30)));
4350   EXPECT_EQ(actual[2].GetFirstElement<int32_t>(),
4351             static_cast<int32_t>(pow31 * pow31));
4352 }
4353 
// Evaluates get-dimension-size on a value derived from parameter 1, after
// binding scalar parameter 0 as the dynamic size of dimension 0 of parameter 1.
// With the size argument set to 3, the expected result is 3.
TEST_F(HloEvaluatorTest, GetDimensionSize) {
  const absl::string_view module_text = R"(
HloModule Test

ENTRY main {
  size = s32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  // Set up dynamic parameter binding: parameter 0 supplies the size of
  // dimension 0 of parameter 1.
  const DynamicParameterBinding::DynamicParameter size_param{0, {}};
  const DynamicParameterBinding::DynamicDimension bound_dimension{1, {}, 0};
  TF_CHECK_OK(
      m_->dynamic_parameter_binding().Bind(size_param, bound_dimension));

  // The evaluator is handed a pointer to the inference result, so the result
  // is kept in a local that lives until the end of the test.
  TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference inference,
                          DynamicDimensionInference::Run(m_.get()));
  evaluator_.set_dynamic_dimension_inference(&inference);

  Literal size_literal = LiteralUtil::CreateR0<int32_t>(3);
  Literal data_literal = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4});
  TF_ASSERT_OK_AND_ASSIGN(Literal evaluated,
                          Evaluate({&size_literal, &data_literal}));

  const int32_t expected_size = 3;
  EXPECT_EQ(evaluated.GetFirstElement<int32_t>(), expected_size);
}
4386 
4387 // Check that we get a useful error if we pass inputs of the wrong shape.
TEST_F(HloEvaluatorTest,EvaluateWithWrongInputShapes)4388 TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
4389   const absl::string_view hlo_text = R"(
4390 HloModule Test
4391 
4392 ENTRY main {
4393   p0 = s32[1] parameter(0)
4394   ROOT sum = s32[1] add(p0, p0)
4395 }
4396 )";
4397   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4398   Literal input_wrong_shape = LiteralUtil::CreateR1<int32_t>({0, 1});
4399 
4400   EXPECT_EQ(HloEvaluator()
4401                 .Evaluate(*m_, {&input_wrong_shape})
4402                 .status()
4403                 .error_message(),
4404             "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
4405             "but arg was s32[2]{0}.");
4406   EXPECT_EQ(HloEvaluator()
4407                 .Evaluate(*m_->entry_computation(), {&input_wrong_shape})
4408                 .status()
4409                 .error_message(),
4410             "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
4411             "but arg was s32[2]{0}.");
4412 }
4413 
4414 // Check that we get a useful error if we pass too many or too few inputs.
TEST_F(HloEvaluatorTest,EvaluateWithWrongNumberOfInputs)4415 TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
4416   const absl::string_view hlo_text = R"(
4417 HloModule Test
4418 
4419 ENTRY main {
4420   p0 = s32[1] parameter(0)
4421   ROOT sum = s32[1] add(p0, p0)
4422 }
4423 )";
4424   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4425   Literal input = LiteralUtil::CreateR1<int32_t>({0});
4426 
4427   EXPECT_EQ(
4428       HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(),
4429       "Expected 1 argument, but got 2.");
4430   EXPECT_EQ(HloEvaluator()
4431                 .Evaluate(*m_->entry_computation(), {&input, &input})
4432                 .status()
4433                 .error_message(),
4434             "Expected 1 argument, but got 2.");
4435 }
4436 
// A loop fusion takes a {0,1}-layout parameter and bitcasts it to {1,0}.
// Checks that the flat data of the result matches the flat data of the
// argument.
TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
  const absl::string_view fusion_module_text = R"(
    HloModule FusionInputLayout

    fused_computation {
      param_0 = f32[20,20]{0,1} parameter(0)
      ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{0,1} parameter(0)
      ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(m_,
                          ParseAndReturnVerifiedModule(fusion_module_text));

  auto fake_args = MakeFakeArguments(m_.get()).value();
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fake_args[0]}));

  const bool data_matches =
      absl::c_equal(fake_args[0].data<float>(), result.data<float>());
  EXPECT_TRUE(data_matches);
}
4458 
// A loop fusion takes a {1,0}-layout parameter and bitcasts it to {0,1}.
// Checks that the flat data of the result matches the flat data of the
// argument.
TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
  const absl::string_view fusion_module_text = R"(
    HloModule FusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";
  TF_ASSERT_OK_AND_ASSIGN(m_,
                          ParseAndReturnVerifiedModule(fusion_module_text));

  auto fake_args = MakeFakeArguments(m_.get()).value();
  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fake_args[0]}));

  const bool data_matches =
      absl::c_equal(fake_args[0].data<float>(), result.data<float>());
  EXPECT_TRUE(data_matches);
}
4479 
// A loop fusion whose root is a one-element tuple wrapping a bitcast from
// layout {1,0} to {0,1}. Checks that the single tuple element carries the same
// flat data as the argument.
TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
  const absl::string_view hlo_text = R"(
    HloModule MOFusionOutputLayout

    fused_computation {
      param_0 = f32[20,20]{1,0} parameter(0)
      bitcast = f32[20,20]{0,1} bitcast(param_0)
      ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
    }

    ENTRY kernel_entry {
      parameter.0 = f32[20,20]{1,0} parameter(0)
      ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
        kind=kLoop, calls=fused_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  auto args = MakeFakeArguments(m_.get()).value();
  TF_ASSERT_OK_AND_ASSIGN(Literal actual_tuple, Evaluate({&args[0]}));
  std::vector<Literal> actual_literals = actual_tuple.DecomposeTuple();
  // Guard the index below: an unexpected element count should fail the test
  // instead of indexing out of bounds.
  ASSERT_EQ(actual_literals.size(), 1);
  EXPECT_TRUE(
      absl::c_equal(args[0].data<float>(), actual_literals[0].data<float>()));
}
4503 
4504 // Tests that custom_calls fail to evaluate when no handler is specified.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_NoHandler)4505 TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
4506   const absl::string_view hlo_text = R"(
4507     HloModule EvaluateCustomCall_NoHandler
4508     ENTRY kernel_entry {
4509       parameter.0 = u32[2,2]{1,0} parameter(0)
4510       ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
4511           custom_call_target="_my_custom_call"
4512     }
4513   )";
4514 
4515   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4516   auto args = MakeFakeArguments(m_.get()).value();
4517   EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(),
4518             ::tensorflow::error::UNIMPLEMENTED);
4519 }
4520 
4521 // Tests when a custom_call handler returns an error.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_HandlerError)4522 TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
4523   const absl::string_view hlo_text = R"(
4524     HloModule EvaluateCustomCall_HandlerError
4525     ENTRY kernel_entry {
4526       parameter.0 = u32[2,2]{1,0} parameter(0)
4527       ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
4528           custom_call_target="_my_custom_call"
4529     }
4530   )";
4531 
4532   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4533   auto args = MakeFakeArguments(m_.get()).value();
4534   HloEvaluator evaluator;
4535   evaluator.set_custom_call_handler(
4536       [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
4537         return InternalError("Test error");
4538       });
4539   EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(),
4540             ::tensorflow::error::INTERNAL);
4541 }
4542 
4543 // Tests the custom_call handler on calls with many inputs.
4544 // We sum the operands so that we can verify the operand and output literals
4545 // are properly mapped for access.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_ManyInputs)4546 TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
4547   const absl::string_view hlo_text = R"(
4548     HloModule EvaluateCustomCall_ManyInputs
4549     ENTRY kernel_entry {
4550       parameter.0 = u32[1]{0} parameter(0)
4551       parameter.1 = u32[1]{0} parameter(1)
4552       ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
4553           custom_call_target="_my_custom_call"
4554     }
4555   )";
4556 
4557   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4558   auto args = MakeFakeArguments(m_.get()).value();
4559   HloEvaluator evaluator;
4560   evaluator.set_custom_call_handler(
4561       [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
4562         EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode());
4563         EXPECT_EQ("_my_custom_call", custom_call->custom_call_target());
4564         EXPECT_EQ(2, custom_call->operand_count());
4565         EXPECT_EQ(2, operands.size());
4566         auto output = Literal::CreateFromShape(custom_call->shape());
4567         auto operand0_data = operands[0]->data<uint32_t>();
4568         auto operand1_data = operands[1]->data<uint32_t>();
4569         auto output_data = output.data<uint32_t>();
4570         output_data[0] = operand0_data[0] + operand1_data[0];
4571         return output;
4572       });
4573   TF_ASSERT_OK_AND_ASSIGN(
4574       Literal actual_literal,
4575       evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]}));
4576   auto arg0_data = args[0].data<uint32_t>();
4577   auto arg1_data = args[1].data<uint32_t>();
4578   std::vector<uint32_t> expected_data = {arg0_data[0] + arg1_data[0]};
4579   EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data<uint32_t>()));
4580 }
4581 
// Evaluates is-finite on an f16 constant containing NaN, finite and infinite
// values; only the finite entries (7 and -1) are expected to yield true.
TEST_F(HloEvaluatorTest, IsFiniteF16) {
  const absl::string_view is_finite_module_text = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";
  TF_ASSERT_OK_AND_ASSIGN(m_,
                          ParseAndReturnVerifiedModule(is_finite_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal finite_mask,
                          evaluator.Evaluate(*m_->entry_computation(), {}));
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4598 
// Evaluates is-finite on a bf16 constant containing NaN, finite and infinite
// values; only the finite entries (7 and -1) are expected to yield true.
TEST_F(HloEvaluatorTest, IsFiniteBf16) {
  const absl::string_view is_finite_module_text = R"(
  HloModule test

  ENTRY IsFiniteTest {
    c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
    ROOT is-finite = pred[6] is-finite(c)
  })";
  TF_ASSERT_OK_AND_ASSIGN(m_,
                          ParseAndReturnVerifiedModule(is_finite_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal finite_mask,
                          evaluator.Evaluate(*m_->entry_computation(), {}));
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4615 
4616 // Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
4617 // array!).
TEST_F(HloEvaluatorTest,ZeroSizedIotaWithHugeDimension)4618 TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
4619   const absl::string_view hlo_text = R"(
4620   HloModule test
4621   ENTRY t {
4622     ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
4623   })";
4624   TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4625   TF_ASSERT_OK_AND_ASSIGN(
4626       Literal actual_literal,
4627       HloEvaluator().Evaluate(*m_->entry_computation(), {}));
4628   EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
4629 }
4630 
// Evaluates a copy-start / copy-done pair applied to the constant 42.0 and
// expects the value to come through unchanged.
TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
  const absl::string_view copy_module_text = R"(
  HloModule test
  ENTRY CopyStartCopyDone {
    init = f32[] constant(42.0)
    copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init)
    ROOT copy-done = f32[] copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(copy_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal copied,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want = LiteralUtil::CreateR0<float>(42.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, copied));
}
4646 
// Evaluates a negate expressed as a negate-start / negate-update / negate-done
// chain on the constant 42.0 and expects -42.0.
TEST_F(HloEvaluatorTest, AsyncOps) {
  const absl::string_view async_module_text = R"(
  HloModule test
  ENTRY AsyncOps {
    init = f32[] constant(42.0)
    async-start = ((f32[]), f32[], u32[]) negate-start(init)
    async-update = ((f32[]), f32[], u32[]) negate-update(async-start)
    ROOT async-done = f32[] negate-done(async-update)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(async_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal negated,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want = LiteralUtil::CreateR0<float>(-42.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, negated));
}
4663 
// Maps a computation over a bf16[3] constant {1, 2, 3}; the computation
// doubles its bf16 input and converts the sum to f32, so {2, 4, 6} is
// expected.
TEST_F(HloEvaluatorTest, MapBF16) {
  const absl::string_view map_module_text = R"(
  HloModule test

  map_computation {
    p = bf16[] parameter(0)
    add = bf16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = bf16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(map_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4685 
// Maps a computation over an s16[3] constant {1, 2, 3}; the computation
// doubles its s16 input and converts the sum to f32, so {2, 4, 6} is expected.
TEST_F(HloEvaluatorTest, MapS16) {
  const absl::string_view map_module_text = R"(
  HloModule test

  map_computation {
    p = s16[] parameter(0)
    add = s16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = s16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(map_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4707 
// Maps a computation over a u16[3] constant {1, 2, 3}; the computation
// doubles its u16 input and converts the sum to f32, so {2, 4, 6} is expected.
TEST_F(HloEvaluatorTest, MapU16) {
  const absl::string_view map_module_text = R"(
  HloModule test

  map_computation {
    p = u16[] parameter(0)
    add = u16[] add(p, p)
    ROOT conv = f32[] convert(add)
  }

  ENTRY CopyStartCopyDone {
    c = u16[3] constant({1, 2, 3})
    ROOT map = f32[3] map(c), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(map_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4729 
// Maps a two-operand computation over operands of different element types
// (u16 and f32). The computation converts the u16 value to f32 and adds the
// f32 value, so {1, 2, 3} and {1.5, 2.5, 3.5} are expected to give
// {2.5, 4.5, 6.5}.
TEST_F(HloEvaluatorTest, MapMixed) {
  const absl::string_view map_module_text = R"(
  HloModule test

  map_computation {
    p0 = u16[] parameter(0)
    p1 = f32[] parameter(1)
    c0 = f32[] convert(p0)
    ROOT add = f32[] add(c0, p1)
  }

  ENTRY CopyStartCopyDone {
    c0 = u16[3] constant({1, 2, 3})
    c1 = f32[3] constant({1.5, 2.5, 3.5})
    ROOT map = f32[3] map(c0, c1), to_apply=map_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(map_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want = LiteralUtil::CreateR1<float>({2.5f, 4.5f, 6.5f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4753 
// Evaluates a dot of an s16[4,3] matrix with an s8[3,2] matrix whose declared
// result type is s32, and compares against a precomputed s32[4,2] product.
TEST_F(HloEvaluatorTest, DotUpcast) {
  const absl::string_view dot_module_text = R"(
  HloModule test
  ENTRY DotUpcast {
    l = s16[4,3]{1,0} parameter(0)
    r = s8[3,2]{1,0} parameter(1)
    ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1},
                                      rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(dot_module_text));

  // lhs:
  // s16[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<int16_t> lhs_values(4, 3);
  lhs_values.FillUnique(1);
  auto lhs_literal = LiteralUtil::CreateR2FromArray2D<int16_t>(lhs_values);

  // rhs:
  // s8[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<int8_t> rhs_values(3, 2);
  rhs_values.FillUnique(1);
  auto rhs_literal = LiteralUtil::CreateR2FromArray2D<int8_t>(rhs_values);

  TF_ASSERT_OK_AND_ASSIGN(Literal product,
                          Evaluate({&lhs_literal, &rhs_literal}));

  // Row-by-row product of the two matrices listed above.
  auto want_values =
      Array2D<int32_t>({{22, 28}, {58, 76}, {94, 124}, {130, 172}});
  auto want = LiteralUtil::CreateR2FromArray2D<int32_t>(want_values);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, product));
}
4794 
// Sorts a c64[3] constant with a comparator that orders elements by their real
// part (direction LT). The input {2, 4, 6} already satisfies that order, so
// the same values are expected back.
TEST_F(HloEvaluatorTest, SortC64) {
  const absl::string_view sort_module_text = R"(
  HloModule m

  sort_lt_comparator {
    parameter.0 = c64[] parameter(0)
    real.0 = f32[] real(parameter.0)
    parameter.1 = c64[] parameter(1)
    real.1 = f32[] real(parameter.1)
    ROOT compare = pred[] compare(real.0, real.1), direction=LT
  }

  ENTRY main {
    c = c64[3] constant({(2, 0), (4, 0), (6, 0)})
    ROOT sort = c64[3]{0} sort(c), dimensions={0}, to_apply=sort_lt_comparator
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(sort_module_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal sorted,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  const Literal want =
      LiteralUtil::CreateR1<std::complex<float>>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, sorted));
}
4819 
4820 // Tests that HloEvaluator can evaluate an instruction even when its operands
4821 // are not constant.
TEST_F(HloEvaluatorTest,RecursivelyEvaluateNonConstantOperands)4822 TEST_F(HloEvaluatorTest, RecursivelyEvaluateNonConstantOperands) {
4823   Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4824   Literal c1_literal = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
4825   Literal c2_literal = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
4826 
4827   Shape shape = c0_literal.shape();
4828   HloComputation::Builder b(TestName());
4829   HloInstruction* c0 =
4830       b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4831   HloInstruction* c1 =
4832       b.AddInstruction(HloInstruction::CreateConstant(std::move(c1_literal)));
4833   HloInstruction* c2 =
4834       b.AddInstruction(HloInstruction::CreateConstant(std::move(c2_literal)));
4835 
4836   HloInstruction* add0 = b.AddInstruction(
4837       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c0, c1));
4838   HloInstruction* add1 = b.AddInstruction(
4839       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2));
4840   HloInstruction* add2 = b.AddInstruction(
4841       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, add1));
4842 
4843   m_->AddEntryComputation(b.Build());
4844   Literal expected = LiteralUtil::CreateR2<float>({{2, 16}, {6, 16}});
4845   TestRecursivelyEvaluateInstruction(add2, expected);
4846 }
4847 
4848 // Tests that HloEvaluator can evaluate a GetTupleElement even when its operand
4849 // Tuple instruction cannot be fully evaluated. Note that this requires that the
4850 //  tuple element at the given tuple index can be evaluated.
TEST_F(HloEvaluatorTest,GetTupleElementOnPartiallyKnownTupleSucceeds)4851 TEST_F(HloEvaluatorTest, GetTupleElementOnPartiallyKnownTupleSucceeds) {
4852   Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4853 
4854   Shape shape = c0_literal.shape();
4855   HloComputation::Builder b(TestName());
4856   HloInstruction* c0 =
4857       b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4858   HloInstruction* p0 =
4859       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4860   HloInstruction* p1 =
4861       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4862 
4863   HloInstruction* tuple =
4864       b.AddInstruction(HloInstruction::CreateTuple({p0, p1, c0}));
4865   HloInstruction* gte =
4866       b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple, 2));
4867 
4868   m_->AddEntryComputation(b.Build());
4869   Literal expected = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4870   TestRecursivelyEvaluateInstruction(gte, expected);
4871 }
4872 
4873 // Tests that Infeed cannot be evaluated.
TEST_F(HloEvaluatorTest,InfeedFailure)4874 TEST_F(HloEvaluatorTest, InfeedFailure) {
4875   HloComputation::Builder b(TestName());
4876   HloInstruction* token = b.AddInstruction(HloInstruction::CreateToken());
4877   HloInstruction* infeed = b.AddInstruction(HloInstruction::CreateInfeed(
4878       ShapeUtil::MakeShape(F32, {4, 4}), token, ""));
4879 
4880   m_->AddEntryComputation(b.Build());
4881   TestRecursiveEvaluationFailure(infeed);
4882 }
4883 
4884 // Tests that GetTupleElement cannot be evaluated if the corresponding tuple
4885 // element cannot be evaluated.
TEST_F(HloEvaluatorTest,GetUnknownTupleElementFails)4886 TEST_F(HloEvaluatorTest, GetUnknownTupleElementFails) {
4887   Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4888 
4889   Shape shape = c0_literal.shape();
4890   HloComputation::Builder b(TestName());
4891   HloInstruction* c0 =
4892       b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4893   HloInstruction* p0 =
4894       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4895   HloInstruction* p1 =
4896       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4897 
4898   HloInstruction* tuple =
4899       b.AddInstruction(HloInstruction::CreateTuple({p0, p1, c0}));
4900   HloInstruction* gte =
4901       b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple, 0));
4902 
4903   m_->AddEntryComputation(b.Build());
4904   TestRecursiveEvaluationFailure(gte);
4905 }
4906 
4907 // Tests that partial evaluation works for nested tuples.
TEST_F(HloEvaluatorTest,GetTupleElementFromNestedTupleSucceeds)4908 TEST_F(HloEvaluatorTest, GetTupleElementFromNestedTupleSucceeds) {
4909   Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4910 
4911   Shape shape = c0_literal.shape();
4912   HloComputation::Builder b(TestName());
4913   HloInstruction* c0 =
4914       b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4915   HloInstruction* p0 =
4916       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4917   HloInstruction* p1 =
4918       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4919 
4920   HloInstruction* tuple0 =
4921       b.AddInstruction(HloInstruction::CreateTuple({p0, c0}));
4922   HloInstruction* tuple1 =
4923       b.AddInstruction(HloInstruction::CreateTuple({tuple0, p1}));
4924   HloInstruction* gte0 =
4925       b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple1, 0));
4926   HloInstruction* gte1 =
4927       b.AddInstruction(HloInstruction::CreateGetTupleElement(gte0, 1));
4928 
4929   m_->AddEntryComputation(b.Build());
4930   Literal expected = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4931   TestRecursivelyEvaluateInstruction(gte1, expected);
4932 }
4933 
4934 // Tests that partial evaluation works when the GetTupleElement is interleaved
4935 // with other Tuple instructions.
TEST_F(HloEvaluatorTest,GetTupleElementInterleavedWithTupleSucceeds)4936 TEST_F(HloEvaluatorTest, GetTupleElementInterleavedWithTupleSucceeds) {
4937   Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4938 
4939   Shape shape = c0_literal.shape();
4940   HloComputation::Builder b(TestName());
4941   HloInstruction* c0 =
4942       b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4943   HloInstruction* p0 =
4944       b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4945   HloInstruction* p1 =
4946       b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4947   HloInstruction* p2 =
4948       b.AddInstruction(HloInstruction::CreateParameter(2, shape, "param.2"));
4949 
4950   HloInstruction* tuple0 =
4951       b.AddInstruction(HloInstruction::CreateTuple({p0, c0}));
4952   HloInstruction* tuple1 =
4953       b.AddInstruction(HloInstruction::CreateTuple({tuple0, p1}));
4954   HloInstruction* gte0 =
4955       b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple1, 0));
4956   HloInstruction* tuple2 =
4957       b.AddInstruction(HloInstruction::CreateTuple({gte0, p2}));
4958   HloInstruction* gte1 =
4959       b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple2, 0));
4960   HloInstruction* gte2 =
4961       b.AddInstruction(HloInstruction::CreateGetTupleElement(gte1, 1));
4962 
4963   m_->AddEntryComputation(b.Build());
4964   Literal expected = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4965   TestRecursivelyEvaluateInstruction(gte2, expected);
4966 }
4967 
4968 class PatternMatchParseWhileLoopTest : public HloTestBase {};
4969 
// Parses a while loop whose bound (5) is a constant defined inside the
// condition computation. The induction variable is tuple element 0, starts at
// 0 and is incremented by 1 in the body, so the loop is expected to be
// recognised as static with a trip count of 5.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) {
  constexpr absl::string_view kHloModule = R"(
    HloModule accumulated_all_reduce

    %while_condition {
      %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %loop_bound = s32[] constant(5)
      ROOT result = pred[] compare(%gte.0, %loop_bound), direction=LT
    }

    %while_body {
      %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2)
      %constant = s32[] constant(1)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %constant.0 = s32[] constant(0)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
      ROOT %result = f32[1024, 1024] get-tuple-element((s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(kHloModule));

  // The entry root is a get-tuple-element whose operand 0 is the while op.
  HloInstruction* entry_root =
      parsed_module->entry_computation()->root_instruction();
  HloInstruction* while_instr = entry_root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed_loop =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed_loop.has_value());
  EXPECT_FALSE(parsed_loop->is_dynamic());

  const auto& static_loop = parsed_loop->static_while_loop;
  EXPECT_EQ(static_loop->trip_count, 5);
  EXPECT_EQ(static_loop->induction_var_index, 0);
  EXPECT_EQ(static_loop->induction_var_init_value, 0);
  EXPECT_EQ(static_loop->step_size, 1);
  EXPECT_EQ(static_loop->loop_bound, 5);
}
5016 
// The loop bound (constant 10) is created in the entry computation and reaches
// the loop condition through tuple element 1 of the while operand. The matcher
// is expected to report a static loop running from 0 to 10 in steps of 1.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedOutsideOfCond) {
  constexpr absl::string_view kHloModule = R"(
    HloModule accumulated_all_reduce

    %while_condition {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
    }

    %while_body {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
      %constant = s32[] constant(1)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %constant.0 = s32[] constant(0)
      %constant.1 = s32[] constant(10)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %constant.1, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
      ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloModule));
  // The entry root is a get-tuple-element whose only operand is the while op.
  HloInstruction* while_op =
      hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
  std::optional<ParsedWhileLoop> parsed_while_loop =
      PatternMatchParseWhileLoop(while_op);
  ASSERT_TRUE(parsed_while_loop.has_value());
  // ASSERT rather than EXPECT: everything below reads through
  // static_while_loop, so stop here if the loop was not recognized as static.
  ASSERT_FALSE(parsed_while_loop->is_dynamic());
  const auto& static_loop = *parsed_while_loop->static_while_loop;
  EXPECT_EQ(static_loop.trip_count, 10);
  EXPECT_EQ(static_loop.induction_var_index, 0);
  EXPECT_EQ(static_loop.induction_var_init_value, 0);
  EXPECT_EQ(static_loop.step_size, 1);
  EXPECT_EQ(static_loop.loop_bound, 10);
}
5065 
// The loop bound is not a literal constant: the entry computation computes it
// as multiply(10, 4) and passes it through tuple element 1 of the while
// operand. The matcher is expected to report the bound as 40.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundComputedOutsideOfCond) {
  constexpr absl::string_view kHloModule = R"(
    HloModule accumulated_all_reduce

    %while_condition {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
    }

    %while_body {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
      %constant = s32[] constant(1)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %constant.0 = s32[] constant(0)
      %constant.1 = s32[] constant(10)
      %constant.2 = s32[] constant(4)
      %loop_bound = s32[] multiply(s32[] %constant.1, s32[] %constant.2)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
      ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloModule));
  // The entry root is a get-tuple-element whose only operand is the while op.
  HloInstruction* while_op =
      hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
  std::optional<ParsedWhileLoop> parsed_while_loop =
      PatternMatchParseWhileLoop(while_op);
  ASSERT_TRUE(parsed_while_loop.has_value());
  // ASSERT rather than EXPECT: everything below reads through
  // static_while_loop, so stop here if the loop was not recognized as static.
  ASSERT_FALSE(parsed_while_loop->is_dynamic());
  const auto& static_loop = *parsed_while_loop->static_while_loop;
  EXPECT_EQ(static_loop.trip_count, 40);
  EXPECT_EQ(static_loop.induction_var_index, 0);
  EXPECT_EQ(static_loop.induction_var_init_value, 0);
  EXPECT_EQ(static_loop.step_size, 1);
  EXPECT_EQ(static_loop.loop_bound, 40);
}
5116 
// The body advances the induction variable by 4 instead of 1, and the bound is
// multiply(10, 4) = 40. The matcher is expected to report a step size of 4 and
// a trip count of 10.
TEST_F(PatternMatchParseWhileLoopTest, StepSizeNotOne) {
  constexpr absl::string_view kHloModule = R"(
    HloModule accumulated_all_reduce

    %while_condition {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
    }

    %while_body {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
      %constant = s32[] constant(4)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %constant.0 = s32[] constant(0)
      %constant.1 = s32[] constant(10)
      %constant.2 = s32[] constant(4)
      %loop_bound = s32[] multiply(s32[] %constant.1, s32[] %constant.2)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
      ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloModule));
  // The entry root is a get-tuple-element whose only operand is the while op.
  HloInstruction* while_op =
      hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
  std::optional<ParsedWhileLoop> parsed_while_loop =
      PatternMatchParseWhileLoop(while_op);
  ASSERT_TRUE(parsed_while_loop.has_value());
  // ASSERT rather than EXPECT: everything below reads through
  // static_while_loop, so stop here if the loop was not recognized as static.
  ASSERT_FALSE(parsed_while_loop->is_dynamic());
  const auto& static_loop = *parsed_while_loop->static_while_loop;
  EXPECT_EQ(static_loop.trip_count, 10);
  EXPECT_EQ(static_loop.induction_var_index, 0);
  EXPECT_EQ(static_loop.induction_var_init_value, 0);
  EXPECT_EQ(static_loop.step_size, 4);
  EXPECT_EQ(static_loop.loop_bound, 40);
}
5167 
5168 // The loop condition comparison is computed by a call to another computation.
TEST_F(PatternMatchParseWhileLoopTest,RecursiveCond)5169 TEST_F(PatternMatchParseWhileLoopTest, RecursiveCond) {
5170   constexpr absl::string_view kHloModule = R"(
5171     HloModule accumulated_all_reduce
5172 
5173     %compute_pred {
5174       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5175       %gte.0 = s32[] get-tuple-element(%param), index=0
5176       %gte.1 = s32[] get-tuple-element(%param), index=1
5177       %compare = pred[] compare(gte.0, %gte.1), direction=LT
5178       ROOT %tuple = (pred[]) tuple(pred[] %compare)
5179     }
5180 
5181     %while_condition {
5182       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5183       %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred
5184       ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
5185     }
5186 
5187     %while_body {
5188       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5189       %gte.0 = s32[] get-tuple-element(%param), index=0
5190       %gte.1 = s32[] get-tuple-element(%param), index=1
5191       %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
5192       %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
5193       %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
5194       %constant = s32[] constant(1)
5195       %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
5196       ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
5197     }
5198 
5199     ENTRY accumulated_all_reduce {
5200       %param.1 = f32[1024, 1024] parameter(0)
5201       %constant.0 = s32[] constant(0)
5202       %loop_bound = s32[] constant(10)
5203       %accumulation_buffer_init = f32[] constant(0)
5204       %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
5205       %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
5206       %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
5207       ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
5208     }
5209   )";
5210   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
5211                           ParseAndReturnVerifiedModule(kHloModule));
5212   HloInstruction* while_op =
5213       hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
5214   std::optional<ParsedWhileLoop> parsed_while_loop =
5215       PatternMatchParseWhileLoop(while_op);
5216   ASSERT_TRUE(parsed_while_loop.has_value());
5217   EXPECT_FALSE(parsed_while_loop->is_dynamic());
5218   EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 10);
5219   EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0);
5220   EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0);
5221   EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1);
5222   EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 10);
5223 }
5224 
5225 // The loop condition comparison is computed by a call to another computation.
5226 // The called computation could be calling another computation and could use
5227 // get-tuple-element to extract the result.
TEST_F(PatternMatchParseWhileLoopTest,RecursiveCondGetTupleElement)5228 TEST_F(PatternMatchParseWhileLoopTest, RecursiveCondGetTupleElement) {
5229   constexpr absl::string_view kHloModule = R"(
5230     HloModule accumulated_all_reduce
5231 
5232     %compute_pred {
5233       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5234       %gte.0 = s32[] get-tuple-element(%param), index=0
5235       %gte.1 = s32[] get-tuple-element(%param), index=1
5236       %compare = pred[] compare(gte.0, %gte.1), direction=LT
5237       ROOT %tuple = (pred[]) tuple(pred[] %compare)
5238     }
5239 
5240     %get_tuple_element {
5241       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5242       %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred
5243       %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
5244       ROOT %tuple.1 = (pred[]) tuple(pred[] %gte.4)
5245     }
5246     %while_condition {
5247       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5248       %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%get_tuple_element
5249       ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
5250     }
5251 
5252     %while_body {
5253       %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5254       %gte.0 = s32[] get-tuple-element(%param), index=0
5255       %gte.1 = s32[] get-tuple-element(%param), index=1
5256       %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
5257       %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
5258       %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
5259       %constant = s32[] constant(1)
5260       %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
5261       ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
5262     }
5263 
5264     ENTRY accumulated_all_reduce {
5265       %param.1 = f32[1024, 1024] parameter(0)
5266       %constant.0 = s32[] constant(0)
5267       %loop_bound = s32[] constant(10)
5268       %accumulation_buffer_init = f32[] constant(0)
5269       %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
5270       %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
5271       %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
5272       ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
5273     }
5274   )";
5275   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
5276                           ParseAndReturnVerifiedModule(kHloModule));
5277   HloInstruction* while_op =
5278       hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
5279   std::optional<ParsedWhileLoop> parsed_while_loop =
5280       PatternMatchParseWhileLoop(while_op);
5281   ASSERT_TRUE(parsed_while_loop.has_value());
5282   EXPECT_FALSE(parsed_while_loop->is_dynamic());
5283   EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 10);
5284   EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0);
5285   EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0);
5286   EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1);
5287   EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 10);
5288 }
5289 
// Two while loops run back to back. The second loop (%while.1, the one matched
// here) takes its loop bound from tuple element 0 of the first loop's result,
// i.e. the first loop's final induction variable. The first loop counts from 0
// up to a bound of 10 in steps of 1, and the matcher is expected to report a
// bound of 10 for the second loop.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDependsOnAnotherLoop) {
  constexpr absl::string_view kHloModule = R"(
    HloModule accumulated_all_reduce

    %compute_pred.0 {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %compare = pred[] compare(gte.0, %gte.1), direction=LT
      ROOT %tuple = (pred[]) tuple(pred[] %compare)
    }

    %while_condition.0 {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred.0
      ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
    }

    %while_body.0 {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
      %constant = s32[] constant(1)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
    }

    %compute_pred.1 {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %compare = pred[] compare(gte.0, %gte.1), direction=LT
      ROOT %tuple = (pred[]) tuple(pred[] %compare)
    }

    %while_condition.1 {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred.1
      ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
    }

    %while_body.1 {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
      %constant = s32[] constant(1)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %param.2 = f32[1024, 1024] parameter(1)
      %constant.0 = s32[] constant(0)
      %loop_bound = s32[] constant(10)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition.0, body=%while_body.0
      %result.0 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.0), index=3
      %new_loop_bound = s32[] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.0), index=0
      %while_init.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %new_loop_bound, f32[1024, 1024] %param.2, f32[1024, 1024] %result.0)
      %while.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.1), condition=%while_condition.1, body=%while_body.1
      ROOT %result.1 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.1), index=3
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloModule));
  // The entry root reads from %while.1, so operand 0 is the second loop.
  HloInstruction* while_op =
      hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
  std::optional<ParsedWhileLoop> parsed_while_loop =
      PatternMatchParseWhileLoop(while_op);
  ASSERT_TRUE(parsed_while_loop.has_value());
  // ASSERT rather than EXPECT: everything below reads through
  // static_while_loop, so stop here if the loop was not recognized as static.
  ASSERT_FALSE(parsed_while_loop->is_dynamic());
  const auto& static_loop = *parsed_while_loop->static_while_loop;
  EXPECT_EQ(static_loop.trip_count, 10);
  EXPECT_EQ(static_loop.induction_var_index, 0);
  EXPECT_EQ(static_loop.induction_var_init_value, 0);
  EXPECT_EQ(static_loop.step_size, 1);
  EXPECT_EQ(static_loop.loop_bound, 10);
}
5376 
// The induction variable is initialized from an entry parameter (%param.2)
// rather than a constant, so its starting value is not known from the module
// alone. The matcher is expected to still recognize the loop, but to report it
// as dynamic.
TEST_F(PatternMatchParseWhileLoopTest, DynamicLoop) {
  constexpr absl::string_view kHloModule = R"(
    HloModule accumulated_all_reduce

    %while_condition {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
    }

    %while_body {
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = s32[] get-tuple-element(%param), index=0
      %gte.1 = s32[] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
      %constant = s32[] constant(1)
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %param.2 = s32[] parameter(1)
      %loop_bound = s32[] constant(10)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %param.2, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
      ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloModule));
  // The entry root is a get-tuple-element whose only operand is the while op.
  HloInstruction* while_op =
      hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
  std::optional<ParsedWhileLoop> parsed_while_loop =
      PatternMatchParseWhileLoop(while_op);
  ASSERT_TRUE(parsed_while_loop.has_value());
  EXPECT_TRUE(parsed_while_loop->is_dynamic());
}
5420 
5421 // The loop condition comparison is computed by a call to another computation.
TEST_F(PatternMatchParseWhileLoopTest,BooleanCond)5422 TEST_F(PatternMatchParseWhileLoopTest, BooleanCond) {
5423   constexpr absl::string_view kHloModule = R"(
5424     HloModule accumulated_all_reduce
5425     %while_condition {
5426       %param = (pred[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5427        ROOT %gte.0 = pred[] get-tuple-element(%param), index=0
5428     }
5429 
5430     %while_body {
5431       %param = (pred[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5432       %gte.0 = pred[] get-tuple-element(%param), index=0
5433       %gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1
5434       %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
5435       %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2)
5436       %new_loop_cond = pred[] constant(false)
5437       ROOT %loop_result = (pred[], f32[1024, 1024], f32[1024, 1024]) tuple(%new_loop_cond, %gte.1, %accumulation)
5438     }
5439 
5440     ENTRY accumulated_all_reduce {
5441       %param.1 = f32[1024, 1024] parameter(0)
5442       %constant.0 = pred[] constant(true)
5443       %accumulation_buffer_init = f32[] constant(0)
5444       %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
5445       %while_init = (pred[], f32[1024, 1024], f32[1024, 1024]) tuple(pred[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
5446       %while = (pred[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
5447       ROOT %result = f32[1024, 1024] get-tuple-element((pred[], f32[1024, 1024], f32[1024, 1024]) %while), index=2
5448     }
5449   )";
5450   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
5451                           ParseAndReturnVerifiedModule(kHloModule));
5452   HloInstruction* while_op =
5453       hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
5454   std::optional<ParsedWhileLoop> parsed_while_loop =
5455       PatternMatchParseWhileLoop(while_op);
5456   ASSERT_TRUE(parsed_while_loop.has_value());
5457   EXPECT_FALSE(parsed_while_loop->is_dynamic());
5458   EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 1);
5459   EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0);
5460   EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0);
5461   EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1);
5462   EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 1);
5463 }
5464 
5465 }  // namespace
5466 }  // namespace xla
5467