1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
16
17 #include <initializer_list>
18 #include <memory>
19 #include <string>
20 #include <tuple>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/literal.h"
27 #include "tensorflow/compiler/xla/permutation_util.h"
28 #include "tensorflow/compiler/xla/reference_util.h"
29 #include "tensorflow/compiler/xla/service/hlo_computation.h"
30 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/test.h"
37 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
38 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
39 #include "tensorflow/compiler/xla/tests/test_utils.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/compiler/xla/xla_data.pb.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/platform/test.h"
46 #include "tensorflow/core/platform/test_benchmark.h"
47
48 namespace xla {
49 namespace {
50
// Parameter values fed to the HloEvaluatorBf16Test instantiation below: `true`
// runs a test in bf16 mode, `false` runs it unmodified. The table is never
// mutated, so it is a compile-time constant (which also gives it internal
// linkage without needing `static`).
constexpr std::array<bool, 2> use_bf16_params{true, false};
52
53 // Test fixture for the HloEvaluator.
54 //
55 // In bf16 mode, all f32 shapes are converted to bf16 before running.
56 class HloEvaluatorTest : public HloTestBase {
57 public:
HloEvaluatorTest()58 HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); }
59
Evaluate(absl::Span<const Literal * const> arg_literals={})60 StatusOr<Literal> Evaluate(
61 absl::Span<const Literal* const> arg_literals = {}) {
62 if (use_bfloat16_) {
63 HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
64 }
65 return evaluator_.Evaluate(*m_->entry_computation(), arg_literals);
66 }
67
68 // Evaluate function that takes in a local module instead of using m_
69 // that is in HloTestBase. Once m_ in HloTestBase is
70 // removed, this should be the default Evaluate function.
EvaluateWithModule(HloModule * module,absl::Span<const Literal * const> arg_literals={})71 Literal EvaluateWithModule(
72 HloModule* module, absl::Span<const Literal* const> arg_literals = {}) {
73 if (use_bfloat16_) {
74 HloElementTypeConverter(F32, BF16).Run(m_.get()).ValueOrDie();
75 }
76 return evaluator_.Evaluate(*module->entry_computation(), arg_literals)
77 .value();
78 }
79
TestUnaryOp(HloOpcode opcode,Literal expected,Literal input,float aabs=0)80 void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input,
81 float aabs = 0) {
82 HloComputation::Builder b(TestName());
83 auto c1 =
84 b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));
85 b.AddInstruction(HloInstruction::CreateUnary(expected.shape(), opcode, c1));
86 m_->AddEntryComputation(b.Build());
87
88 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
89
90 auto element_type = expected.shape().element_type();
91 if (element_type == F32 || element_type == F64) {
92 ErrorSpec error(aabs);
93 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, error));
94 } else {
95 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
96 }
97 }
98
TestBinaryOp(HloOpcode opcode,Literal expected,Literal lhs,Literal rhs)99 void TestBinaryOp(HloOpcode opcode, Literal expected, Literal lhs,
100 Literal rhs) {
101 HloComputation::Builder b(TestName());
102 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
103 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
104 b.AddInstruction(
105 HloInstruction::CreateBinary(expected.shape(), opcode, c1, c2));
106 m_->AddEntryComputation(b.Build());
107
108 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
109
110 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
111 }
112
TestTernaryOp(HloOpcode opcode,Literal expected,Literal src0,Literal src1,Literal src2)113 void TestTernaryOp(HloOpcode opcode, Literal expected, Literal src0,
114 Literal src1, Literal src2) {
115 HloComputation::Builder b(TestName());
116 auto operand0 =
117 b.AddInstruction(HloInstruction::CreateConstant(std::move(src0)));
118 auto operand1 =
119 b.AddInstruction(HloInstruction::CreateConstant(std::move(src1)));
120 auto operand2 =
121 b.AddInstruction(HloInstruction::CreateConstant(std::move(src2)));
122 b.AddInstruction(HloInstruction::CreateTernary(
123 expected.shape(), opcode, operand0, operand1, operand2));
124 m_->AddEntryComputation(b.Build());
125
126 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
127
128 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
129 }
130
TestEvaluateInstruction(HloInstruction * instruction,const Literal & expected)131 void TestEvaluateInstruction(HloInstruction* instruction,
132 const Literal& expected) {
133 TF_ASSERT_OK_AND_ASSIGN(Literal result, evaluator_.Evaluate(instruction));
134 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
135 }
136
TestEvaluationFailure(HloInstruction * instruction)137 void TestEvaluationFailure(HloInstruction* instruction) {
138 StatusOr<Literal> result = evaluator_.Evaluate(instruction);
139 EXPECT_TRUE(!result.ok());
140 }
141
TestRecursivelyEvaluateInstruction(HloInstruction * instruction,const Literal & expected)142 void TestRecursivelyEvaluateInstruction(HloInstruction* instruction,
143 const Literal& expected) {
144 TF_ASSERT_OK_AND_ASSIGN(
145 Literal result,
146 evaluator_.Evaluate(
147 instruction,
148 /*recursively_evaluate_nonconstant_operands=*/true));
149 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
150 }
151
TestRecursiveEvaluationFailure(HloInstruction * instruction)152 void TestRecursiveEvaluationFailure(HloInstruction* instruction) {
153 StatusOr<Literal> result = evaluator_.Evaluate(
154 instruction, /*recursively_evaluate_nonconstant_operands=*/true);
155 EXPECT_TRUE(!result.ok());
156 }
157
MaxComputationScalarF32()158 std::unique_ptr<HloComputation> MaxComputationScalarF32() {
159 HloComputation::Builder max_computation("max");
160 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
161 auto param_lhs = max_computation.AddInstruction(
162 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
163 auto param_rhs = max_computation.AddInstruction(
164 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
165 max_computation.AddInstruction(HloInstruction::CreateBinary(
166 scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
167 return max_computation.Build();
168 }
169
ReduceWindowMaxIotaTest(int window_size,int padding,int stride,int window_dilation,int base_dilation,const Literal & expected)170 void ReduceWindowMaxIotaTest(int window_size, int padding, int stride,
171 int window_dilation, int base_dilation,
172 const Literal& expected) {
173 HloComputation::Builder b(TestName());
174
175 // arg:
176 // f32[4,4] {
177 // { 0, 1, 2, 3 },
178 // { 4, 5, 6, 7 },
179 // { 8, 9, 10, 11 },
180 // { 12, 13, 14, 15 }
181 // }
182 auto arg_array = std::make_unique<Array2D<float>>(4, 4);
183 arg_array->FillIota(0);
184 auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
185
186 HloInstruction* arg_instruction = b.AddInstruction(
187 HloInstruction::CreateConstant(std::move(arg_literal)));
188 auto init_value = b.AddInstruction(
189 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
190 auto max_func = m_->AddEmbeddedComputation(MaxComputationScalarF32());
191
192 Window window;
193 WindowDimension dim;
194 dim.set_size(window_size);
195 dim.set_stride(stride);
196 dim.set_padding_low(padding);
197 dim.set_padding_high(padding);
198 dim.set_window_dilation(window_dilation);
199 dim.set_base_dilation(base_dilation);
200 *window.add_dimensions() = dim;
201 *window.add_dimensions() = dim;
202
203 int dim0 = expected.shape().dimensions(0);
204 int dim1 = expected.shape().dimensions(1);
205 Shape shape = ShapeUtil::MakeShape(F32, {dim0, dim1});
206 b.AddInstruction(HloInstruction::CreateReduceWindow(
207 shape, arg_instruction, init_value, window, max_func));
208
209 m_->AddEntryComputation(b.Build());
210 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
211 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
212 }
213
214 protected:
HloEvaluatorTest(bool use_bfloat16)215 explicit HloEvaluatorTest(bool use_bfloat16) : use_bfloat16_(use_bfloat16) {
216 InitializeFftData();
217 }
218
219 // Initializes data sets used in FFT tests below.
220 void InitializeFftData();
221
222 HloEvaluator evaluator_;
223
224 const bool use_bfloat16_;
225 std::unique_ptr<HloModule> m_ = CreateNewVerifiedModule();
226
227 // Data sets used in FFT tests below.
228 ErrorSpec fft_error_ = ErrorSpec(1e-4, 1e-5);
229 Literal fft_c64x2x4x8_;
230 Literal fft_c64x2x4x8_1d_;
231 Literal fft_c64x2x4x8_2d_;
232 Literal fft_c64x2x4x8_3d_;
233 };
234
235 // Lets you write TEST_Ps that run twice, once with and once without bf16.
236 class HloEvaluatorBf16Test : public ::testing::WithParamInterface<bool>,
237 public HloEvaluatorTest {
238 protected:
HloEvaluatorBf16Test()239 HloEvaluatorBf16Test() : HloEvaluatorTest(/*use_bfloat16=*/GetParam()) {}
240 };
241
// Instantiates every TEST_P of HloEvaluatorBf16Test once per value in
// use_bf16_params.
INSTANTIATE_TEST_SUITE_P(HloEvaluatorTest_Instantiation, HloEvaluatorBf16Test,
                         ::testing::ValuesIn(use_bf16_params));
244
245 // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp
246 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesClamp)247 TEST_P(HloEvaluatorBf16Test, DoesClamp) {
248 auto low = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
249 auto value = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
250 auto high = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
251
252 Shape shape = low.shape();
253 HloComputation::Builder b(TestName());
254 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
255 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
256 auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
257 b.AddInstruction(
258 HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
259 m_->AddEntryComputation(b.Build());
260
261 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
262
263 auto expected = LiteralUtil::CreateR2<float>({{0, 4}, {2, 4}});
264
265 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
266 }
267
268 // Verifies that clamping of int64_t does not cause loss of precision
TEST_P(HloEvaluatorBf16Test,DoesClampInt64)269 TEST_P(HloEvaluatorBf16Test, DoesClampInt64) {
270 auto ones = [](int bits) { return (int64_t{1} << bits) - 1; };
271
272 auto low =
273 LiteralUtil::CreateR2<int64_t>({{0, ones(54)}, {ones(54), ones(58)}});
274 auto value = LiteralUtil::CreateR2<int64_t>({{0, ones(56)}, {0, ones(58)}});
275 auto high = LiteralUtil::CreateR2<int64_t>(
276 {{ones(54), ones(55)}, {ones(56), ones(58)}});
277
278 Shape shape = low.shape();
279 HloComputation::Builder b(TestName());
280 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low)));
281 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value)));
282 auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high)));
283 b.AddInstruction(
284 HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3));
285 m_->AddEntryComputation(b.Build());
286
287 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
288
289 auto expected =
290 LiteralUtil::CreateR2<int64_t>({{0, ones(55)}, {ones(54), ones(58)}});
291
292 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
293 }
294
// Clamp with scalar (R0) bounds applied to a rank-2 value. The DISABLED_
// prefix keeps googletest from running this test by default.
TEST_P(HloEvaluatorBf16Test, DISABLED_DoesClampSpecialBroadcast) {
  Literal lower = LiteralUtil::CreateR0<float>(0.f);
  Literal operand = LiteralUtil::CreateR2<float>({{-1.f, 0.f}, {1.f, 2.f}});
  Literal upper = LiteralUtil::CreateR0<float>(1.f);

  // The result takes the shape of the (non-scalar) value operand; read it
  // before the literal is moved into a constant.
  const Shape clamp_shape = operand.shape();
  HloComputation::Builder builder(TestName());
  HloInstruction* lower_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(lower)));
  HloInstruction* operand_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand)));
  HloInstruction* upper_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(upper)));
  builder.AddInstruction(HloInstruction::CreateTernary(
      clamp_shape, HloOpcode::kClamp, lower_const, operand_const,
      upper_const));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Literal expected = LiteralUtil::CreateR2<float>({{0, 0}, {1, 1}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
315
316 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
317 // with 3 operands.
TEST_P(HloEvaluatorBf16Test,DoesSelect)318 TEST_P(HloEvaluatorBf16Test, DoesSelect) {
319 auto pred = LiteralUtil::CreateR2<bool>({{true, false}, {false, true}});
320 auto on_true = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
321 auto on_false = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
322
323 Shape shape = on_true.shape();
324 HloComputation::Builder b(TestName());
325 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(pred)));
326 auto c2 =
327 b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true)));
328 auto c3 =
329 b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false)));
330 b.AddInstruction(
331 HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3));
332 m_->AddEntryComputation(b.Build());
333
334 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
335
336 auto expected = LiteralUtil::CreateR2<float>({{2, 5}, {0, 4}});
337
338 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
339 }
340
341 // Verifies that HloEvaluator evaluates a HLO instruction that performs
342 // element-wise addition with 2 operands.
TEST_F(HloEvaluatorTest,DoesAdd)343 TEST_F(HloEvaluatorTest, DoesAdd) {
344 auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
345 auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
346 auto expected = LiteralUtil::CreateR2<int64_t>({{3, 4}, {-96, 8}});
347 TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs),
348 std::move(rhs));
349 }
350 // Verifies that HloEvaluator evaluates a HLO instruction that performs
351 // element-wise and with 2 operands.
TEST_P(HloEvaluatorBf16Test,DoesAnd)352 TEST_P(HloEvaluatorBf16Test, DoesAnd) {
353 auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
354 auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
355 auto expected = LiteralUtil::CreateR2<int64_t>({{0, 0}, {4, 4}});
356 TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs),
357 std::move(rhs));
358 }
359 // Verifies that HloEvaluator evaluates a HLO instruction that performs
360 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesOr)361 TEST_F(HloEvaluatorTest, DoesOr) {
362 auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
363 auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
364 auto expected = LiteralUtil::CreateR2<int64_t>({{3, 4}, {-100, 4}});
365 TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs),
366 std::move(rhs));
367 }
368 // Verifies that HloEvaluator evaluates a HLO instruction that performs
369 // element-wise or with 2 operands.
TEST_F(HloEvaluatorTest,DoesXor)370 TEST_F(HloEvaluatorTest, DoesXor) {
371 auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
372 auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
373 auto expected = LiteralUtil::CreateR2<int64_t>({{3, 4}, {-104, 0}});
374 TestBinaryOp(HloOpcode::kXor, std::move(expected), std::move(lhs),
375 std::move(rhs));
376 }
377 // Verifies that HloEvaluator evaluates a HLO instruction that performs
378 // element-wise multiply with 2 operands.
TEST_F(HloEvaluatorTest,DoesMultiply)379 TEST_F(HloEvaluatorTest, DoesMultiply) {
380 auto lhs = LiteralUtil::CreateR2<int32_t>({{-1, 0}, {-100, 4}});
381 auto rhs = LiteralUtil::CreateR2<int32_t>(
382 {{std::numeric_limits<int32_t>::min(), 4}, {4, 4}});
383 auto expected = LiteralUtil::CreateR2<int32_t>(
384 {{std::numeric_limits<int32_t>::min(), 0}, {-400, 16}});
385 TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs),
386 std::move(rhs));
387 }
388 // Verifies that HloEvaluator evaluates a HLO instruction that performs
389 // element-wise divide with 2 operands.
TEST_F(HloEvaluatorTest,DoesDivideInt64)390 TEST_F(HloEvaluatorTest, DoesDivideInt64) {
391 auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
392 auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
393 auto expected = LiteralUtil::CreateR2<int64_t>({{0, 0}, {-25, 1}});
394 TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs),
395 std::move(rhs));
396 }
397
// Checks kClamp on rank-1 s64 operands with large-magnitude values. The three
// elements cover a value inside its bounds, one below the lower bound and one
// above the upper bound.
TEST_F(HloEvaluatorTest, DoesClampS64) {
  Literal lower = LiteralUtil::CreateR1<int64_t>(
      {-8616761059752331528LL, 6780561065411491190LL, -8616761059752331528LL});
  Literal operand = LiteralUtil::CreateR1<int64_t>(
      {-6780561065411491190LL, 6780561065411491180LL, 4241131823772864090LL});
  Literal upper = LiteralUtil::CreateR1<int64_t>(
      {-6780561065411491180LL, 8616761059752331528LL, 3832151243857508051LL});
  Literal clamped = LiteralUtil::CreateR1<int64_t>(
      {-6780561065411491190LL, 6780561065411491190LL, 3832151243857508051LL});
  TestTernaryOp(HloOpcode::kClamp, /*expected=*/std::move(clamped),
                /*src0=*/std::move(lower), /*src1=*/std::move(operand),
                /*src2=*/std::move(upper));
}
410
// Checks element-wise kDivide on two f64 2x2 operands; TestBinaryOp compares
// the result with the expected literal for exact equality.
TEST_P(HloEvaluatorBf16Test, DoesDivideDouble) {
  Literal dividend = LiteralUtil::CreateR2<double>({{1.0, 0.0}, {-100.0, 4.0}});
  Literal divisor = LiteralUtil::CreateR2<double>({{2.2, 4.0}, {4.0, 4.0}});
  Literal quotient =
      LiteralUtil::CreateR2<double>({{0.45454545454545453, 0}, {-25, 1}});
  TestBinaryOp(HloOpcode::kDivide, /*expected=*/std::move(quotient),
               /*lhs=*/std::move(dividend), /*rhs=*/std::move(divisor));
}
419
420 // Verifies that HloEvaluator evaluates a HLO instruction that performs
421 // element-wise abs op with 1 operand.
TEST_F(HloEvaluatorTest,DoesAbsR2)422 TEST_F(HloEvaluatorTest, DoesAbsR2) {
423 auto operand = LiteralUtil::CreateR2<int64_t>({{1, -20}, {-100, 4}});
424 auto expected = LiteralUtil::CreateR2<int64_t>({{1, 20}, {100, 4}});
425 TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand));
426 }
// Checks kAbs on an f32 scalar.
TEST_P(HloEvaluatorBf16Test, DoesAbsR0) {
  Literal input = LiteralUtil::CreateR0<float>(-1.0f);
  Literal magnitude = LiteralUtil::CreateR0<float>(1.0f);
  TestUnaryOp(HloOpcode::kAbs, /*expected=*/std::move(magnitude),
              /*input=*/std::move(input));
}
// Checks kAbs on an empty rank-1 f32 literal.
TEST_P(HloEvaluatorBf16Test, DoesAbsR1WithZeroSize) {
  Literal empty_input = LiteralUtil::CreateR1<float>({});
  Literal empty_output = LiteralUtil::CreateR1<float>({});
  TestUnaryOp(HloOpcode::kAbs, /*expected=*/std::move(empty_output),
              /*input=*/std::move(empty_input));
}
437
// Checks kAbs on a complex128 scalar: the result is an f64 scalar, compared
// against 2.23607 with an absolute tolerance of 3e-06.
TEST_F(HloEvaluatorTest, DoesAbsC128) {
  Literal complex_input = LiteralUtil::CreateR0<complex128>({1, 2});
  Literal modulus = LiteralUtil::CreateR0<double>(2.23607);
  TestUnaryOp(HloOpcode::kAbs, /*expected=*/std::move(modulus),
              /*input=*/std::move(complex_input), /*aabs=*/3e-06);
}
443
// Checks element-wise kNegate on an s32 2x2 operand. The minimum int32_t is
// expected to negate to itself. Both literals take their limit from
// numeric_limits<int32_t>, matching the literal element type (the expected
// literal previously used numeric_limits<int>).
TEST_F(HloEvaluatorTest, DoesNegateR2) {
  auto operand = LiteralUtil::CreateR2<int32_t>(
      {{0, std::numeric_limits<int32_t>::min()}, {-1, 4}});
  auto expected = LiteralUtil::CreateR2<int32_t>(
      {{0, std::numeric_limits<int32_t>::min()}, {1, -4}});
  TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand));
}
// Checks element-wise kCos at 0, pi, -pi and 2*pi. The bf16 run is compared
// with a much looser absolute tolerance than the f32 run.
TEST_P(HloEvaluatorBf16Test, DoesCosR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  Literal angles = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal cosines = LiteralUtil::CreateR2<float>({{1, -1}, {-1, 1}});
  TestUnaryOp(HloOpcode::kCos, /*expected=*/std::move(cosines),
              /*input=*/std::move(angles), /*aabs=*/tolerance);
}
// Checks element-wise kSin at 0, pi, -pi and 2*pi, all of which are expected
// to be 0 within tolerance. The bf16 run is compared with a much looser
// absolute tolerance than the f32 run.
TEST_P(HloEvaluatorBf16Test, DoesSinR2) {
  const float tolerance = use_bfloat16_ ? 0.031250 : 9.5367431640625E-7;
  Literal angles = LiteralUtil::CreateR2<float>({{0, M_PI}, {-M_PI, 2 * M_PI}});
  Literal sines = LiteralUtil::CreateR2<float>({{0, 0}, {0, 0}});
  TestUnaryOp(HloOpcode::kSin, /*expected=*/std::move(sines),
              /*input=*/std::move(angles), /*aabs=*/tolerance);
}
// Checks element-wise (bitwise) kNot on an s32 2x2 operand, including the
// extreme int32_t values. The limits come from numeric_limits<int32_t> so
// they match the literal element type (previously numeric_limits<int>).
TEST_F(HloEvaluatorTest, DoesNotR2) {
  constexpr int32_t kInt32Min = std::numeric_limits<int32_t>::min();
  constexpr int32_t kInt32Max = std::numeric_limits<int32_t>::max();

  auto operand =
      LiteralUtil::CreateR2<int32_t>({{0, kInt32Min}, {-1, kInt32Max}});
  auto expected =
      LiteralUtil::CreateR2<int32_t>({{-1, kInt32Max}, {0, kInt32Min}});
  TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand));
}
472
// Checks kReal on a rank-1 complex128 operand; the result is f64.
TEST_F(HloEvaluatorTest, DoesRealC128) {
  Literal complex_input =
      LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal real_parts = LiteralUtil::CreateR1<double>({1, -100});
  TestUnaryOp(HloOpcode::kReal, /*expected=*/std::move(real_parts),
              /*input=*/std::move(complex_input));
}
478
// Checks kImag on a rank-1 complex128 operand; the result is f64.
TEST_F(HloEvaluatorTest, DoesImagC128) {
  Literal complex_input =
      LiteralUtil::CreateR1<complex128>({{1, 0}, {-100, 4}});
  Literal imag_parts = LiteralUtil::CreateR1<double>({0, 4});
  TestUnaryOp(HloOpcode::kImag, /*expected=*/std::move(imag_parts),
              /*input=*/std::move(complex_input));
}
484
// Checks kImag on a real-valued f32 operand (and its bf16 variant): the
// expected result is all zeros.
TEST_P(HloEvaluatorBf16Test, DoesImagF32AndBf16) {
  Literal real_input = LiteralUtil::CreateR1<float>({1, -100});
  Literal zeros = LiteralUtil::CreateR1<float>({0, 0});
  TestUnaryOp(HloOpcode::kImag, /*expected=*/std::move(zeros),
              /*input=*/std::move(real_input));
}
490
// Checks kImag on a real-valued f64 operand: the expected result is all
// zeros.
TEST_F(HloEvaluatorTest, DoesImagF64) {
  Literal real_input = LiteralUtil::CreateR1<double>({1, -100});
  Literal zeros = LiteralUtil::CreateR1<double>({0, 0});
  TestUnaryOp(HloOpcode::kImag, /*expected=*/std::move(zeros),
              /*input=*/std::move(real_input));
}
496
497 // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor
498 // constant operands.
TEST_F(HloEvaluatorTest,DoesTraverseInstructions)499 TEST_F(HloEvaluatorTest, DoesTraverseInstructions) {
500 auto lhs = LiteralUtil::CreateR2<int64_t>({{1, 0}, {-100, 4}});
501 auto rhs = LiteralUtil::CreateR2<int64_t>({{2, 4}, {4, 4}});
502 auto rhs2 = LiteralUtil::CreateR2<int64_t>({{1, -20}, {-100, 4}});
503 std::vector<const Literal*> args = {&lhs, &rhs, &rhs2};
504
505 Shape shape = ShapeUtil::MakeShape(S64, {2, 2});
506
507 HloComputation::Builder b(TestName());
508 auto param_lhs =
509 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
510 auto param_rhs =
511 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
512 auto lhs_instruction = b.AddInstruction(HloInstruction::CreateBinary(
513 shape, HloOpcode::kAdd, param_lhs, param_rhs));
514
515 auto param_rhs2 =
516 b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2"));
517 b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd,
518 lhs_instruction, param_rhs2));
519 m_->AddEntryComputation(b.Build());
520
521 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate(args));
522
523 auto expected = LiteralUtil::CreateR2<int64_t>({{4, -16}, {-196, 12}});
524
525 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
526 }
527
528 // Verifies Reshape operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesReshape)529 TEST_F(HloEvaluatorTest, DoesReshape) {
530 HloComputation::Builder b(TestName());
531 const int64_t dimensions[] = {11, 8, 7, 5, 9};
532 TF_ASSERT_OK_AND_ASSIGN(auto literal,
533 LiteralUtil::CreateRandomLiteral<F32>(
534 ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
535 auto literal_clone = literal.Clone();
536 HloInstruction* literal_instruction =
537 b.AddInstruction(HloInstruction::CreateConstant(std::move(literal)));
538
539 Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
540 const int64_t permutation[] = {1, 2, 0, 4, 3};
541 b.AddInstruction(
542 HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
543 m_->AddEntryComputation(b.Build());
544
545 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
546
547 using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
548 result.EachCell<NativeT>(
549 [&](absl::Span<const int64_t> indices, NativeT value) {
550 std::vector<int64_t> rindexes = PermuteInverse(indices, permutation);
551 EXPECT_NEAR(value, literal_clone.Get<NativeT>(rindexes), 0.031250);
552 });
553 }
554
555 // Verifies Broadcast operation is correctly evaluated.
TEST_F(HloEvaluatorTest,DoesBroadcast)556 TEST_F(HloEvaluatorTest, DoesBroadcast) {
557 HloComputation::Builder b(TestName());
558 auto input_literal = LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}, {5, 6}});
559 auto output_literal = LiteralUtil::CreateR3<int32_t>(
560 {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}});
561 HloInstruction* literal_instruction = b.AddInstruction(
562 HloInstruction::CreateConstant(std::move(input_literal)));
563 b.AddInstruction(HloInstruction::CreateBroadcast(
564 output_literal.shape(), literal_instruction, {1, 2}));
565 m_->AddEntryComputation(b.Build());
566
567 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({}));
568
569 EXPECT_TRUE(LiteralTestUtil::Equal(result, output_literal));
570 }
571
// Checks Broadcast of an s32 scalar into an s32[6,2] literal.
TEST_F(HloEvaluatorTest, DoesBroadcastScalar) {
  HloComputation::Builder builder(TestName());
  Literal scalar = LiteralUtil::CreateR0<int32_t>(111);
  const Literal broadcasted = LiteralUtil::CreateR2<int32_t>(
      {{111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}, {111, 111}});

  HloInstruction* scalar_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(scalar)));
  // A scalar operand has no dimensions to map, so the broadcast dimension
  // list is empty.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      broadcasted.shape(), scalar_const,
      /*broadcast_dimensions=*/{}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, broadcasted));
}
590
// Checks Concatenate of two s64[2,2] constants along dimension 0, giving an
// s64[4,2] result.
TEST_F(HloEvaluatorTest, DoesConcatenateSimple) {
  HloComputation::Builder builder(TestName());

  HloInstruction* first = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<int64_t>({{-1, -2}, {100, 200}})));
  HloInstruction* second =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<int64_t>({{-2, -3}, {-100, -200}})));

  std::vector<HloInstruction*> pieces = {first, second};

  const Shape concat_shape = ShapeUtil::MakeShape(S64, {4, 2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Literal expected = LiteralUtil::CreateR2<int64_t>(
      {{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
612
// Checks Concatenate along dimension 0 when one operand is an empty rank-1
// literal: the result equals the non-empty operand.
TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
  HloComputation::Builder builder(TestName());

  HloInstruction* non_empty =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int64_t>({100, 200})));
  HloInstruction* empty = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64_t>({})));

  std::vector<HloInstruction*> pieces = {non_empty, empty};

  const Shape concat_shape = ShapeUtil::MakeShape(S64, {2});
  builder.AddInstruction(
      HloInstruction::CreateConcatenate(concat_shape, pieces, 0));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Literal expected = LiteralUtil::CreateR1<int64_t>({100, 200});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
633
// Checks Convert from s32 to f32 when the operand and result shapes have
// equal layouts (asserted up front).
TEST_P(HloEvaluatorBf16Test, ConvertWithSameLayout) {
  HloComputation::Builder builder(TestName());

  Literal source = LiteralUtil::CreateR2<int32_t>({{1, 2}, {3, 4}, {5, 6}});
  const Literal converted =
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}});
  ASSERT_TRUE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), converted.shape()));

  HloInstruction* source_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(converted.shape(), source_const));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, converted));
}
652
// Checks Convert from s32 to f32 when the operand uses layout {0, 1} and the
// result uses layout {1, 0} (the layouts are asserted to differ up front).
TEST_P(HloEvaluatorBf16Test, ConvertWithDifferentLayout) {
  HloComputation::Builder builder(TestName());

  Literal source = LiteralUtil::CreateR2WithLayout<int32_t>(
      {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1}));
  const Literal converted = LiteralUtil::CreateR2WithLayout<float>(
      {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0}));
  ASSERT_FALSE(
      LayoutUtil::LayoutsInShapesEqual(source.shape(), converted.shape()));

  HloInstruction* source_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(source)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(converted.shape(), source_const));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  EXPECT_TRUE(LiteralTestUtil::Equal(result, converted));
}
672
CreatePaddingConfig(std::initializer_list<std::array<int64_t,3>> padding_dimensions)673 PaddingConfig CreatePaddingConfig(
674 std::initializer_list<std::array<int64_t, 3>> padding_dimensions) {
675 PaddingConfig padding_config;
676
677 for (auto& paddings_per_dim : padding_dimensions) {
678 auto dimension = padding_config.add_dimensions();
679 dimension->set_edge_padding_low(paddings_per_dim[0]);
680 dimension->set_edge_padding_high(paddings_per_dim[1]);
681 dimension->set_interior_padding(paddings_per_dim[2]);
682 }
683 return padding_config;
684 }
685
// Checks Pad on an s32 operand built from two empty rows: the s32[5,2] result
// consists solely of the padding value.
TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
  constexpr int32_t kPadValue = 10;

  HloComputation::Builder builder(TestName());
  Literal empty_rows = LiteralUtil::CreateR2<int32_t>({{}, {}});
  HloInstruction* operand_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(empty_rows)));

  Literal pad_scalar = LiteralUtil::CreateR0<int32_t>(kPadValue);
  HloInstruction* pad_const = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(pad_scalar)));

  // Per dimension: {edge_padding_low, edge_padding_high, interior_padding}.
  const PaddingConfig padding = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}});
  const Shape padded_shape = ShapeUtil::MakeShape(S32, {5, 2});
  builder.AddInstruction(HloInstruction::CreatePad(
      padded_shape, operand_const, pad_const, padding));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  const Literal expected = LiteralUtil::CreateR2<int32_t>(
      {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
710
// Pads a 3x2x1x1 float operand on its first two dimensions with edge and
// interior padding and compares against an 8x5x1x1 array that is the pad
// value everywhere except six hand-placed operand values.
TEST_P(HloEvaluatorBf16Test, Pad4DFloatArrayWithInteriorPadding) {
  constexpr float kPadValue = 1.5;

  HloComputation::Builder builder(TestName());

  Array4D<float> operand_values(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(operand_values)));
  HloInstruction* pad_scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(kPadValue)));

  // {low, high, interior} per dimension; the last two dims are left unpadded.
  PaddingConfig config =
      CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {8, 5, 1, 1});
  builder.AddInstruction(
      HloInstruction::CreatePad(padded_shape, operand, pad_scalar, config));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(8, 5, 1, 1);
  expected_values.Fill(kPadValue);
  // Non-pad entries: the values 1..6 sit at rows {1, 4, 7} and columns
  // {0, 2}, i.e. (1,0)=1, (1,2)=2, (4,0)=3, (4,2)=4, (7,0)=5, (7,2)=6.
  for (int64_t i = 0; i < 3; ++i) {
    for (int64_t j = 0; j < 2; ++j) {
      expected_values(1 + 3 * i, 2 * j, 0, 0) =
          static_cast<float>(2 * i + j + 1);
    }
  }
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
745
// Applies negative edge padding on both dimensions of a 4x3 operand (plus
// positive high padding on the second one) and checks the f32[1,5] result
// within a tolerance.
TEST_P(HloEvaluatorBf16Test, NegativePadding2D) {
  HloComputation::Builder builder(TestName());

  // Operand contents, per the original test's note:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> operand_values(4, 3);
  operand_values.FillUnique(1.0f);
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(operand_values)));

  HloInstruction* pad_scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // {low, high, interior} per dimension.
  PaddingConfig config = CreatePaddingConfig({{{-1, -2, 0}}, {{-2, 4, 0}}});
  Shape padded_shape = ShapeUtil::MakeShape(F32, {1, 5});
  builder.AddInstruction(
      HloInstruction::CreatePad(padded_shape, operand, pad_scalar, config));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 }
  Array2D<float> expected_values({{7.0f, 2.718f, 2.718f, 2.718f, 2.718f}});
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Near(expected, result, ErrorSpec(0.031250)));
}
787
// Combines negative edge padding with interior padding on a 4x3 operand. The
// pad instruction is declared with shape f32[0,9], and the evaluated result
// is compared against an empty 0x9 literal.
TEST_P(HloEvaluatorBf16Test, NegativeAndInteriorPadding2D) {
  HloComputation::Builder b(TestName());

  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  auto input_array = std::make_unique<Array2D<float>>(4, 3);
  input_array->FillUnique(1.0f);
  auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
  HloInstruction* input_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(input)));

  auto pad_value_instruction = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.718f)));

  // Negative padding that results in zero dimensions. Each triple is
  // {edge_padding_low, edge_padding_high, interior_padding}.
  // (An unused `MakeNoPaddingConfig(2)` local used to live here; it was never
  // read, so it has been dropped.)
  auto r2_padding_on_dim0_dim1 =
      CreatePaddingConfig({{{-2, -5, 1}}, {{-2, 4, 2}}});

  Shape shape = ShapeUtil::MakeShape(F32, {0, 9});
  b.AddInstruction(HloInstruction::CreatePad(shape, input_instruction,
                                             pad_value_instruction,
                                             r2_padding_on_dim0_dim1));

  m_->AddEntryComputation(b.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // The expected literal has no elements; only its 0x9 shape matters.
  auto expected_array = std::make_unique<Array2D<float>>(0, 9);
  auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
826
// Contracts dimension 1 of a 4x1 lhs with dimension 0 of a one-row rhs and
// expects the 4x2 outer-product-like result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank1) {
  HloComputation::Builder builder(TestName());

  // lhs, per the original test's note:
  // f32[4,1] {
  //  { 1 },
  //  { 2 },
  //  { 3 },
  //  { 4 },
  // }
  Array2D<float> lhs_values(4, 1);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // rhs: despite the test name, this literal is created with CreateR2 from a
  // single row, i.e. { { 1, 2 } }.
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1, 2}})));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  Array2D<float> expected_values({
      {1.f, 2.f},
      {2.f, 4.f},
      {3.f, 6.f},
      {4.f, 8.f},
  });
  // clang-format on
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
872
// Contracts a rank-1 lhs with dimension 0 of a 3x2 rhs and expects the
// two-element result {22, 28}.
TEST_P(HloEvaluatorBf16Test, DotRank1AndRank2) {
  HloComputation::Builder builder(TestName());

  // lhs:
  // f32[3]
  //  { 1, 2, 3 },
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3})));

  // rhs, per the original test's note:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  Shape result_shape = ShapeUtil::MakeShape(F32, {2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Literal expected = LiteralUtil::CreateR1<float>({22.f, 28.f});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
910
// Plain matrix product: a 4x3 lhs contracted with a 3x2 rhs, checked against
// a hand-written 4x2 result.
TEST_P(HloEvaluatorBf16Test, DotRank2AndRank2) {
  HloComputation::Builder builder(TestName());

  // lhs, per the original test's note:
  // f32[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<float> lhs_values(4, 3);
  lhs_values.FillUnique(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(lhs_values)));

  // rhs, per the original test's note:
  // f32[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<float> rhs_values(3, 2);
  rhs_values.FillUnique(1.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(rhs_values)));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  Shape result_shape = ShapeUtil::MakeShape(F32, {4, 2});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array2D<float> expected_values({
      {22.f, 28.f},
      {58.f, 76.f},
      {94.f, 124.f},
      {130.f, 172.f},
  });
  Literal expected = LiteralUtil::CreateR2FromArray2D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
960
// Batched dot over two 2x2x3x1 operands: dimension 0 is the batch dimension
// and dimensions 1 and 2 are contracted on both sides, giving a 2x1x1 result.
TEST_P(HloEvaluatorBf16Test, DotRank4AndRank4) {
  HloComputation::Builder builder(TestName());

  Array4D<float> lhs_values(2, 2, 3, 1);
  lhs_values.FillIota(1.0f);
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(lhs_values)));

  Array4D<float> rhs_values(2, 2, 3, 1);
  rhs_values.FillIota(2.0f);
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(rhs_values)));

  DotDimensionNumbers dnums;
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(2);
  Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1, 1});
  builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // Sums i * i + i over the integers in [first, last), accumulating in float
  // exactly as the expected values have always been computed.
  const auto sum_of_products = [](int first, int last) {
    float total = 0;
    for (int i = first; i < last; ++i) {
      const float value = static_cast<float>(i);
      total += value * value + value;
    }
    return total;
  };
  const float expected_batch0 = sum_of_products(1, 7);
  const float expected_batch1 = sum_of_products(7, 13);

  Array3D<float> expected_values({{{expected_batch0}}, {{expected_batch1}}});
  Literal expected = LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1005
// One-dimensional convolution of {1, 2, 3} with the kernel {3, 4}, using one
// element of high padding so the output keeps three spatial positions.
TEST_P(HloEvaluatorBf16Test, SimpleConv1D) {
  HloComputation::Builder builder(TestName());

  Array3D<float> input_values = {{{1, 2, 3}}};
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_values)));

  Array3D<float> kernel_values = {{{3.f, 4.f}}};
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR3FromArray3D<float>(kernel_values)));

  // Single spatial window dimension.
  WindowDimension spatial;
  spatial.set_size(2);
  spatial.set_stride(1);
  spatial.set_padding_low(0);
  spatial.set_padding_high(1);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial;

  // Input/output layout: batch, feature, spatial.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);

  // Kernel layout: output feature, input feature, spatial.
  dnums.set_kernel_output_feature_dimension(0);
  dnums.set_kernel_input_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array3D<float> expected_values = {{{11.f, 18.f, 9.f}}};
  Literal expected = LiteralUtil::CreateR3FromArray3D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1054
// Convolves a 4x4 input with a 2x2 kernel using default 2D dimension numbers
// and one element of high padding per spatial dimension, so the output is
// also 4x4.
TEST_P(HloEvaluatorBf16Test, Simple4x4Conv2DWith2x2Kernel) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 2);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6},
    {7, 8},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // The same window settings are used for both spatial dimensions.
  WindowDimension spatial;
  spatial.set_size(2);
  spatial.set_stride(1);
  spatial.set_padding_low(0);
  spatial.set_padding_high(1);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial;
  *window.add_dimensions() = spatial;

  ConvolutionDimensionNumbers dnums =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 4, 4);
  // clang-format off
  expected_values.FillWithYX(Array2D<float>({
    {100, 126, 152,  76},
    {204, 230, 256, 124},
    {308, 334, 360, 172},
    {149, 160, 171,  80},
  }));
  // clang-format on
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1117
// Same data and dimension numbers as Conv2DGeneralDimensions, but the kernel
// is first reversed along dimensions {3, 1} and the window has
// window_reversal enabled. The expected values are the same as in the
// non-reversed test.
TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensionsReversed) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  // Input dimensions: [feature=2, height=3, batch=1, width=4]
  Array4D<float> input_values({
    {{{1, 2, 3, 4}},
     {{5, 6, 7, 8}},
     {{9, 10, 11, 12}}},
    {{{13, 14, 15, 16}},
     {{17, 18, 19, 20}},
     {{21, 22, 23, 24}}}
  });
  // Weight dimensions:
  // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
  Array4D<float> weight_values({{
    {{1, 7, 13},
     {4, 10, 16}},
    {{2, 8, 14},
     {5, 11, 17}},
    {{3, 9, 15},
     {6, 12, 18}}
  }});
  // clang-format on

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  HloInstruction* kernel_constant =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(weight_values)));
  // Flip the kernel along the dimensions listed in {3, 1}.
  HloInstruction* reversed_kernel =
      builder.AddInstruction(HloInstruction::CreateReverse(
          kernel_constant->shape(), kernel_constant, {3, 1}));

  // The same reversed window settings apply to both spatial dimensions.
  WindowDimension spatial;
  spatial.set_size(3);
  spatial.set_stride(1);
  spatial.set_padding_low(0);
  spatial.set_padding_high(0);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(1);
  spatial.set_window_reversal(true);
  Window window;
  *window.add_dimensions() = spatial;
  *window.add_dimensions() = spatial;

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(2);
  dnums.set_output_batch_dimension(2);
  dnums.set_input_feature_dimension(0);
  dnums.set_output_feature_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  dnums.set_kernel_output_feature_dimension(0);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(1);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, reversed_kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  // Result dimensions: [feature=1, height=1, batch=1, width=2]
  Array4D<float> expected_f32({{{{2514, 2685}}}});
  Array4D<float> expected_bf16({{{{2512, 2688}}}});
  // clang-format on
  // A separate expectation is selected when the test runs in bfloat16 mode.
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(
      use_bfloat16_ ? expected_bf16 : expected_f32);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1198
// 2D convolution with non-default dimension numbers: the input/output use
// [feature, height, batch, width] and the kernel uses
// [output_feature, width, input_feature, height].
TEST_P(HloEvaluatorBf16Test, Conv2DGeneralDimensions) {
  HloComputation::Builder builder(TestName());

  // clang-format off
  // Input dimensions: [feature=2, height=3, batch=1, width=4]
  Array4D<float> input_values({
    {{{1, 2, 3, 4}},
     {{5, 6, 7, 8}},
     {{9, 10, 11, 12}}},
    {{{13, 14, 15, 16}},
     {{17, 18, 19, 20}},
     {{21, 22, 23, 24}}}
  });
  // Weight dimensions:
  // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
  Array4D<float> weight_values({{
    {{1, 7, 13},
     {4, 10, 16}},
    {{2, 8, 14},
     {5, 11, 17}},
    {{3, 9, 15},
     {6, 12, 18}}
  }});
  // clang-format on

  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(weight_values)));

  // The same window settings apply to both spatial dimensions.
  WindowDimension spatial;
  spatial.set_size(3);
  spatial.set_stride(1);
  spatial.set_padding_low(0);
  spatial.set_padding_high(0);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial;
  *window.add_dimensions() = spatial;

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(2);
  dnums.set_output_batch_dimension(2);
  dnums.set_input_feature_dimension(0);
  dnums.set_output_feature_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  dnums.set_kernel_output_feature_dimension(0);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(1);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  // clang-format off
  // Result dimensions: [feature=1, height=1, batch=1, width=2]
  Array4D<float> expected_f32({{{{2514, 2685}}}});
  Array4D<float> expected_bf16({{{{2512, 2688}}}});
  // clang-format on
  // A separate expectation is selected when the test runs in bfloat16 mode.
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(
      use_bfloat16_ ? expected_bf16 : expected_f32);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1276
// Convolves a 4x4 input with a 2x2 kernel using base dilation 2 and one
// element of high padding on each spatial dimension; the declared output is
// 7x7.
TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithHighPadding) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 2);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6},
    {7, 8},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // The same window settings apply to both spatial dimensions.
  WindowDimension spatial;
  spatial.set_size(2);
  spatial.set_stride(1);
  spatial.set_padding_low(0);
  spatial.set_padding_high(1);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(2);
  Window window;
  *window.add_dimensions() = spatial;
  *window.add_dimensions() = spatial;

  ConvolutionDimensionNumbers dnums =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 7, 7);
  expected_values.FillWithYX(Array2D<float>({
      {5, 12, 10, 18, 15, 24, 20},
      {35, 48, 42, 56, 49, 64, 56},
      {25, 36, 30, 42, 35, 48, 40},
      {63, 80, 70, 88, 77, 96, 84},
      {45, 60, 50, 66, 55, 72, 60},
      {91, 112, 98, 120, 105, 128, 112},
      {65, 84, 70, 90, 75, 96, 80},
  }));
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1340
// Convolves a 4x4 input with a 2x2 kernel using base dilation 2 and one
// element of both low and high padding on each spatial dimension; the
// declared output is 8x8.
TEST_P(HloEvaluatorBf16Test, DilatedBaseConv2DWithLowAndHighPadding) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 2);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6},
    {7, 8},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // The same window settings apply to both spatial dimensions.
  WindowDimension spatial;
  spatial.set_size(2);
  spatial.set_stride(1);
  spatial.set_padding_low(1);
  spatial.set_padding_high(1);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(2);
  Window window;
  *window.add_dimensions() = spatial;
  *window.add_dimensions() = spatial;

  ConvolutionDimensionNumbers dnums =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 8, 8);
  expected_values.FillWithYX(Array2D<float>({
      {8, 7, 16, 14, 24, 21, 32, 28},
      {6, 5, 12, 10, 18, 15, 24, 20},
      {40, 35, 48, 42, 56, 49, 64, 56},
      {30, 25, 36, 30, 42, 35, 48, 40},
      {72, 63, 80, 70, 88, 77, 96, 84},
      {54, 45, 60, 50, 66, 55, 72, 60},
      {104, 91, 112, 98, 120, 105, 128, 112},
      {78, 65, 84, 70, 90, 75, 96, 80},
  }));
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1405
// Convolves a 4x4 input with a 2x3 kernel where the two spatial dimensions
// use different sizes, strides, paddings (including a negative high padding)
// and dilations; the declared output is 9x3.
TEST_P(HloEvaluatorBf16Test,
       DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) {
  HloComputation::Builder builder(TestName());

  Array4D<float> input_values(1, 1, 4, 4);
  // clang-format off
  input_values.FillWithYX(Array2D<float>({
    {1,  2,  3,  4 },
    {5,  6,  7,  8 },
    {9,  10, 11, 12},
    {13, 14, 15, 16},
  }));
  // clang-format on
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR4FromArray4D<float>(input_values)));

  Array4D<float> kernel_values(1, 1, 2, 3);
  // clang-format off
  kernel_values.FillWithYX(Array2D<float>({
    {5, 6, 7},
    {8, 9, 10},
  }));
  // clang-format on
  HloInstruction* kernel =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(kernel_values)));

  // Window settings for the first spatial dimension.
  WindowDimension first_spatial;
  first_spatial.set_size(2);
  first_spatial.set_stride(1);
  first_spatial.set_padding_low(2);
  first_spatial.set_padding_high(2);
  first_spatial.set_window_dilation(2);
  first_spatial.set_base_dilation(2);

  // Window settings for the second spatial dimension.
  WindowDimension second_spatial;
  second_spatial.set_size(3);
  second_spatial.set_stride(3);
  second_spatial.set_padding_low(2);
  second_spatial.set_padding_high(-1);
  second_spatial.set_window_dilation(1);
  second_spatial.set_base_dilation(3);

  Window window;
  *window.add_dimensions() = first_spatial;
  *window.add_dimensions() = second_spatial;

  ConvolutionDimensionNumbers dnums =
      XlaBuilder::CreateDefaultConvDimensionNumbers(2);

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 9, 3);
  expected_values.FillWithYX(Array2D<float>({
      {10, 20, 30},
      {0, 0, 0},
      {57, 74, 91},
      {0, 0, 0},
      {125, 142, 159},
      {0, 0, 0},
      {193, 210, 227},
      {0, 0, 0},
      {91, 98, 105},
  }));
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1478
// Grouped 2D convolution (feature_group_count = 2) over iota-filled input and
// filter tensors, using TensorFlow-style dimension numbers.
TEST_P(HloEvaluatorBf16Test, Conv2DGroupedConvolution) {
  HloComputation::Builder builder(TestName());

  const std::vector<int64_t> input_dims = {1, 2, 2, 4};
  const std::vector<int64_t> filter_dims = {2, 2, 2, 8};
  Shape input_shape = ShapeUtil::MakeShapeWithType<float>(input_dims);
  Shape filter_shape = ShapeUtil::MakeShapeWithType<float>(filter_dims);

  // Tensorflow dimension numbers for 2D convolution.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.set_input_feature_dimension(3);
  dnums.set_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);

  // The same window settings apply to both spatial dimensions.
  WindowDimension spatial;
  spatial.set_size(2);
  spatial.set_stride(1);
  spatial.set_padding_low(0);
  spatial.set_padding_high(0);
  spatial.set_window_dilation(1);
  spatial.set_base_dilation(1);
  Window window;
  *window.add_dimensions() = spatial;
  *window.add_dimensions() = spatial;

  // Input: consecutive values starting at -7, reshaped to `input_dims`.
  std::vector<float> input_data(ShapeUtil::ElementsIn(input_shape));
  std::iota(input_data.begin(), input_data.end(), -7);
  Literal input_flat = LiteralUtil::CreateR1<float>(input_data);
  Literal input_literal = input_flat.Reshape(input_dims).value();
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(input_literal)));

  // Filter: consecutive values starting at -31, reshaped to `filter_dims`.
  std::vector<float> filter_data(ShapeUtil::ElementsIn(filter_shape));
  std::iota(filter_data.begin(), filter_data.end(), -31);
  Literal filter_flat = LiteralUtil::CreateR1<float>(filter_data);
  Literal filter_literal = filter_flat.Reshape(filter_dims).value();
  HloInstruction* filter = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(filter_literal)));

  Shape result_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8});
  builder.AddInstruction(HloInstruction::CreateConvolve(
      result_shape, input, filter,
      /*feature_group_count=*/2, /*batch_group_count=*/1, window, dnums,
      DefaultPrecisionConfig(2)));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());

  Array4D<float> expected_values(1, 1, 1, 8);
  expected_values.FillWithYX(
      Array2D<float>({{668, 664, 660, 656, 668, 680, 692, 704}}));
  Literal expected = LiteralUtil::CreateR4FromArray4D<float>(expected_values);
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
1540
1541 // Initialization of data sets for FFT tests:
1542
// Fills in the four c64[2, 4, 8] literals shared by the FFT tests below:
//   fft_c64x2x4x8_     - the untransformed data set;
//   fft_c64x2x4x8_1d_  - the value the tests expect from fft_length={8};
//   fft_c64x2x4x8_2d_  - the value the tests expect from fft_length={4, 8};
//   fft_c64x2x4x8_3d_  - the value the tests expect from fft_length={2, 4, 8}.
// The FFT tests feed fft_c64x2x4x8_ in and compare against a transformed
// literal; the IFFT tests feed a transformed literal in and compare against
// fft_c64x2x4x8_.
void HloEvaluatorTest::InitializeFftData() {
  // The tables are hand-aligned, so automatic formatting is switched off.
  // clang-format off
  // Untransformed data; each innermost pair is {real, imaginary}.
  fft_c64x2x4x8_ = LiteralUtil::CreateR3<complex64>({
    {{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0},
      {4.0, 0.0}, {5.0, 0.0}, {6.0, 0.0}, {7.0, 0.0}},
     {{0.0, 0.0}, {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0},
      {0.0, 4.0}, {0.0, 5.0}, {0.0, 6.0}, {0.0, 7.0}},
     {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
      {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
     {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
      {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}},
    {{{-4.0, 0.0}, {-3.0, 0.0}, {-2.0, 0.0}, {-1.0, 0.0},
      {1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}},
     {{0.0, -4.0}, {0.0, -3.0}, {0.0, -2.0}, {0.0, -1.0},
      {0.0, 1.0}, {0.0, 2.0}, {0.0, 3.0}, {0.0, 4.0}},
     {{3.5, 3.5}, {-1.707107, -0.707107}, {-1.0, -0.0}, {-0.707107, 0.292893},
      {-0.5, 0.5}, {-0.292893, 0.707107}, {0.0, 1.0}, {0.707107, 1.707107}},
     {{3.5, 3.5}, {1.707107, 0.707107}, {1.0, 0.0}, {0.707107, -0.292893},
      {0.5, -0.5}, {0.292893, -0.707107}, {-0.0, -1.0}, {-0.707107, -1.707107}}}
  });
  // Expected result of the length-8 transform (fft_length={8}).
  fft_c64x2x4x8_1d_ = LiteralUtil::CreateR3<complex64>({
    {{{28.0, 0.0}, {-4.0, 9.656854}, {-4.0, 4.0}, {-4.0, 1.656854},
      {-4.0, 0.0}, {-4.0, -1.656854}, {-4.0, -4.0}, {-4.0, -9.656854}},
     {{0.0, 28.0}, {-9.656854, -4.0}, {-4.0, -4.0}, {-1.656854, -4.0},
      {0.0, -4.0}, {1.656854, -4.0}, {4.0, -4.0}, {9.656854, -4.0}},
     {{28.0, 28.0}, {5.656854, 13.656854}, {0.0, 8.0}, {-2.343146, 5.656854},
      {-4.0, 4.0}, {-5.656854, 2.343146}, {-8.0, -0.0}, {-13.656854, -5.656854}},  // NOLINT
     {{28.0, 28.0}, {-5.656854, -13.656854}, {-0.0, -8.0}, {2.343146, -5.656854},  // NOLINT
      {4.0, -4.0}, {5.656854, -2.343146}, {8.0, 0.0}, {13.656854, 5.656854}}},
    {{{0.0, 0.0}, {-5.0, 12.071068}, {-4.0, 4.0}, {-5.0, 2.071068},
      {-4.0, 0.0}, {-5.0, -2.071068}, {-4.0, -4.0}, {-5.0, -12.071068}},
     {{0.0, 0.0}, {-12.071068, -5.0}, {-4.0, -4.0}, {-2.071068, -5.0},
      {0.0, -4.0}, {2.071068, -5.0}, {4.0, -4.0}, {12.071068, -5.0}},
     {{0.0, 7.0}, {1.0, 6.0}, {2.0, 5.0}, {3.0, 4.0},
      {4.0, 3.0}, {5.0, 2.0}, {6.0, 1.0}, {7.0, 0.0}},
     {{7.0, 0.0}, {6.0, 1.0}, {5.0, 2.0}, {4.0, 3.0},
      {3.0, 4.0}, {2.0, 5.0}, {1.0, 6.0}, {0.0, 7.0}}}
  });
  // Expected result of the 4x8 transform (fft_length={4, 8}).
  fft_c64x2x4x8_2d_ = LiteralUtil::CreateR3<complex64>({
    {{{84.0, 84.0}, {-13.656854, 5.656854}, {-8.0, 0.0}, {-5.656854, -2.343146},
      {-4.0, -4.0}, {-2.343146, -5.656854}, {0.0, -8.0}, {5.656854, -13.656854}},  // NOLINT
     {{0.0, 0.0}, {0.0, -0.0}, {0.0, 0.0}, {0.0, 0.0},
      {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
     {{28.0, -28.0}, {16.970562, 40.970562}, {0.0, 24.0}, {-7.029438, 16.970562},  // NOLINT
      {-12.0, 12.0}, {-16.970562, 7.029438}, {-24.0, 0.0}, {-40.970562, -16.970562}},  // NOLINT
     {{0.0, -56.0}, {-19.313708, -8.0}, {-8.0, -8.0}, {-3.313708, -8.0},
      {0.0, -8.0}, {3.313708, -8.0}, {8.0, -8.0}, {19.313708, -8.0}}},
    {{{7.0, 7.0}, {-10.071068, 14.071068}, {-1.0, 7.0}, {-0.071068, 4.071068},
      {3.0, 3.0}, {4.071068, -0.071068}, {7.0, -1.0}, {14.071068, -10.071068}},
     {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
      {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
     {{-7.0, 7.0}, {2.071068, 22.071068}, {-3.0, 11.0}, {-3.928932, 8.071068},
      {-3.0, 3.0}, {-4.071068, -0.071068}, {-3.0, -5.0}, {-10.071068, -14.071068}},  // NOLINT
     {{0.0, -14.0}, {0.0, -12.0}, {0.0, -10.0}, {0.0, -8.0},
      {0.0, -6.0}, {0.0, -4.0}, {0.0, -2.0}, {0.0, 0.0}}}
  });
  // Expected result of the 2x4x8 transform (fft_length={2, 4, 8}).
  fft_c64x2x4x8_3d_ = LiteralUtil::CreateR3<complex64>({
    {{{91.0, 91.0}, {-23.727922, 19.727922}, {-9.0, 7.0}, {-5.727922, 1.727922},
      {-1.0, -1.0}, {1.727922, -5.727922}, {7.0, -9}, {19.727922, -23.727922}},
     {{0.0, 0.0}, {-12.0, 24.142136}, {-12.0, 8.0}, {-16.0, 4.142136},
      {-16.0, 0.0}, {-20.0, -4.142136}, {-20.0, -8.0}, {-24.0, -24.142136}},
     {{21.0, -21.0}, {19.041630, 63.041630}, {-3.0, 35.0}, {-10.958370, 25.041630},  // NOLINT
      {-15.0, 15.0}, {-21.041630, 6.958370}, {-27.0, -5.0}, {-51.041630, -31.041630}},  // NOLINT
     {{0.0, -70.0}, {-19.313708, -20.0}, {-8.0, -18.0}, {-3.313708, -16.0},
      {0.0, -14.0}, {3.313708, -12.0}, {8.0, -10.0}, {19.313708, -8.0}}},
    {{{77.0, 77.0}, {-3.585786, -8.414214}, {-7.0, -7.0}, {-5.585786, -6.414214},  // NOLINT
      {-7.0, -7.0}, {-6.414214, -5.585786}, {-7.0, -7.0}, {-8.414214, -3.585786}},  // NOLINT
     {{0.0, 0.0}, {12.0, -24.142136}, {12.0, -8.0}, {16.0, -4.142136},
      {16.0, 0.0}, {20.0, 4.142136}, {20.0, 8.0}, {24.0, 24.142136}},
     {{35.0, -35.0}, {14.899494, 18.899494}, {3.0, 13.0}, {-3.100506, 8.899494},
      {-9.0, 9.0}, {-12.899494, 7.100506}, {-21.0, 5.0}, {-30.899494, -2.899494}},  // NOLINT
     {{0.0, -42.0}, {-19.313708, 4.0}, {-8.0, 2.0}, {-3.313708, 0.0},
      {0.0, -2.0}, {3.313708, -4.0}, {8.0, -6.0}, {19.313708, -8.0}}}
  });
  // clang-format on
}
1619
1620 // Simple FFT tests:
1621
1622 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_c64x4) {
1623 const char* hlo_text = R"(
1624 HloModule Fft
1625
1626 ENTRY main {
1627 operand = c64[4] parameter(0)
1628 ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
1629 }
1630 )";
1631 auto input = LiteralUtil::CreateR1<complex64>(
1632 {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1633 auto expected = LiteralUtil::CreateR1<complex64>(
1634 {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1635 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1636 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1637 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1638 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1639 }
1640
1641 TEST_F(HloEvaluatorTest, 1D_IFFT_4_on_c64x4) {
1642 const char* hlo_text = R"(
1643 HloModule Fft
1644
1645 ENTRY main {
1646 operand = c64[4] parameter(0)
1647 ROOT ifft = c64[4] fft(operand), fft_type=IFFT, fft_length={4}
1648 }
1649 )";
1650 auto input = LiteralUtil::CreateR1<complex64>(
1651 {{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}, {-2.0, -2.0}});
1652 auto expected = LiteralUtil::CreateR1<complex64>(
1653 {{1.0, 0.0}, {2.0, 0.0}, {3.0, 0.0}, {4.0, 0.0}});
1654 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1655 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1656 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1657 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1658 }
1659
1660 TEST_F(HloEvaluatorTest, 1D_RFFT_4_on_f32x4) {
1661 const char* hlo_text = R"(
1662 HloModule Fft
1663
1664 ENTRY main {
1665 operand = f32[4] parameter(0)
1666 ROOT rfft = c64[3] fft(operand), fft_type=RFFT, fft_length={4}
1667 }
1668 )";
1669 auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1670 auto expected =
1671 LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1672 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1673 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1674 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1675 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1676 }
1677
1678 TEST_F(HloEvaluatorTest, 1D_IRFFT_4_on_c64x3) {
1679 const char* hlo_text = R"(
1680 HloModule Fft
1681
1682 ENTRY main {
1683 operand = c64[3] parameter(0)
1684 ROOT irfft = f32[4] fft(operand), fft_type=IRFFT, fft_length={4}
1685 }
1686 )";
1687 auto input =
1688 LiteralUtil::CreateR1<complex64>({{10.0, 0.0}, {-2.0, 2.0}, {-2.0, 0.0}});
1689 auto expected = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
1690 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1691 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1692 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1693 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1694 }
1695
1696 // 1D FFT tests:
1697
1698 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8) {
1699 const char* hlo_text = R"(
1700 HloModule Fft
1701
1702 ENTRY main {
1703 operand = c64[2, 4, 8] parameter(0)
1704 ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={8}
1705 }
1706 )";
1707 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1708 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1709 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
1710 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
1711 }
1712
1713 TEST_F(HloEvaluatorTest, 1D_IFFT_8_on_c64x2x4x8) {
1714 const char* hlo_text = R"(
1715 HloModule Fft
1716
1717 ENTRY main {
1718 operand = c64[2, 4, 8] parameter(0)
1719 ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={8}
1720 }
1721 )";
1722 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1723 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_1d_}));
1724 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1725 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1726 }
1727
1728 TEST_F(HloEvaluatorTest, 1D_RFFT_8_on_f32x8) {
1729 const char* hlo_text = R"(
1730 HloModule Fft
1731
1732 ENTRY main {
1733 operand = f32[8] parameter(0)
1734 ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={8}
1735 }
1736 )";
1737 auto input =
1738 LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1739 auto expected = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1740 {-3.6, 8.691169},
1741 {-3.6, 3.6},
1742 {-3.6, 1.491169},
1743 {-3.6, 0.0}});
1744 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1745 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1746 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1747 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1748 }
1749
1750 TEST_F(HloEvaluatorTest, 1D_IRFFT_8_on_c64x5) {
1751 const char* hlo_text = R"(
1752 HloModule Fft
1753
1754 ENTRY main {
1755 operand = c64[5] parameter(0)
1756 ROOT irfft = f32[8] fft(operand), fft_type=IRFFT, fft_length={8}
1757 }
1758 )";
1759 auto input = LiteralUtil::CreateR1<complex64>({{39.6, 0.0},
1760 {-3.6, 8.691169},
1761 {-3.6, 3.6},
1762 {-3.6, 1.491169},
1763 {-3.6, 0.0}});
1764 auto expected =
1765 LiteralUtil::CreateR1<float>({1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1});
1766 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1767 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1768 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1769 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1770 }
1771
1772 TEST_F(HloEvaluatorTest, 1D_RFFT_9_on_f32x9) {
1773 const char* hlo_text = R"(
1774 HloModule Fft
1775
1776 ENTRY main {
1777 operand = f32[9] parameter(0)
1778 ROOT rfft = c64[5] fft(operand), fft_type=RFFT, fft_length={9}
1779 }
1780 )";
1781 auto input = LiteralUtil::CreateR1<float>(
1782 {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1783 auto expected = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1784 {-3.360560, 11.705792},
1785 {-3.893717, 5.712929},
1786 {-4.5, 3.117691},
1787 {-4.895723, 1.021942}});
1788 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1789 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1790 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1791 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1792 }
1793
1794 TEST_F(HloEvaluatorTest, 1D_IRFFT_9_on_c64x5) {
1795 const char* hlo_text = R"(
1796 HloModule Fft
1797
1798 ENTRY main {
1799 operand = c64[5] parameter(0)
1800 ROOT irfft = f32[9] fft(operand), fft_type=IRFFT, fft_length={9}
1801 }
1802 )";
1803 auto input = LiteralUtil::CreateR1<complex64>({{49.5, 0.0},
1804 {-3.360560, 11.705792},
1805 {-3.893717, 5.712929},
1806 {-4.5, 3.117691},
1807 {-4.895723, 1.021942}});
1808 auto expected = LiteralUtil::CreateR1<float>(
1809 {1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.9});
1810 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1811 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1812 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1813 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1814 }
1815
1816 // 2D FFT tests:
1817
1818 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8) {
1819 const char* hlo_text = R"(
1820 HloModule Fft
1821
1822 ENTRY main {
1823 operand = c64[2, 4, 8] parameter(0)
1824 ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={4, 8}
1825 }
1826 )";
1827 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1828 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1829 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
1830 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
1831 }
1832
1833 TEST_F(HloEvaluatorTest, 2D_IFFT_4x8_on_c64x2x4x8) {
1834 const char* hlo_text = R"(
1835 HloModule Fft
1836
1837 ENTRY main {
1838 operand = c64[2, 4, 8] parameter(0)
1839 ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={4, 8}
1840 }
1841 )";
1842 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1843 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_2d_}));
1844 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
1845 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
1846 }
1847
1848 TEST_F(HloEvaluatorTest, 2D_RFFT_3x8_on_f32x3x8) {
1849 const char* hlo_text = R"(
1850 HloModule Fft
1851
1852 ENTRY main {
1853 operand = f32[3, 8] parameter(0)
1854 ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 8}
1855 }
1856 )";
1857 auto input =
1858 LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1859 {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1860 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1861 auto expected = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1862 {-4.4, 10.622540},
1863 {-4.4, 4.4},
1864 {-4.4, 1.822540},
1865 {-4.4, 0.0}},
1866 {{0.0, 0.0},
1867 {-19.926162, 0.797280},
1868 {-10.128203, -3.728203},
1869 {-6.069756, -5.602720},
1870 {-3.2, -6.928203}},
1871 {{0.0, 0.0},
1872 {13.526162, 14.653687},
1873 {3.728203, 10.128203},
1874 {-0.330244, 8.253687},
1875 {-3.2, 6.928203}}});
1876 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1877 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1878 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1879 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1880 }
1881
1882 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x8_on_c64x3x5) {
1883 const char* hlo_text = R"(
1884 HloModule Fft
1885
1886 ENTRY main {
1887 operand = c64[3, 5] parameter(0)
1888 ROOT irfft = f32[3, 8] fft(operand), fft_type=IRFFT, fft_length={3, 8}
1889 }
1890 )";
1891 auto input = LiteralUtil::CreateR2<complex64>({{{118.8, 0.0},
1892 {-4.4, 10.622540},
1893 {-4.4, 4.4},
1894 {-4.4, 1.822540},
1895 {-4.4, 0.0}},
1896 {{0.0, 0.0},
1897 {-19.926162, 0.797280},
1898 {-10.128203, -3.728203},
1899 {-6.069756, -5.602720},
1900 {-3.2, -6.928203}},
1901 {{0.0, 0.0},
1902 {13.526162, 14.653687},
1903 {3.728203, 10.128203},
1904 {-0.330244, 8.253687},
1905 {-3.2, 6.928203}}});
1906 auto expected =
1907 LiteralUtil::CreateR2<float>({{1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1},
1908 {8.1, 7.2, 6.3, 5.4, 4.5, 3.6, 2.7, 1.8},
1909 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8}});
1910 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1911 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1912 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1913 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1914 }
1915
1916 TEST_F(HloEvaluatorTest, 2D_RFFT_3x9_on_f32x3x9) {
1917 const char* hlo_text = R"(
1918 HloModule Fft
1919
1920 ENTRY main {
1921 operand = f32[3, 9] parameter(0)
1922 ROOT rfft = c64[3, 5] fft(operand), fft_type=RFFT, fft_length={3, 9}
1923 }
1924 )";
1925 auto input = LiteralUtil::CreateR2<float>(
1926 {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1927 {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1928 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1929 auto expected = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1930 {-4.95, 13.600013},
1931 {-4.95, 5.899180},
1932 {-4.95, 2.857884},
1933 {-4.95, 0.872819}},
1934 {{0.0, 0.0},
1935 {-25.014467, 2.096690},
1936 {-12.888800, -3.503916},
1937 {-8.1, -5.715768},
1938 {-4.974333, -7.159452}},
1939 {{0.0, 0.0},
1940 {17.814467, 17.685147},
1941 {5.688800, 12.084542},
1942 {0.9, 9.872690},
1943 {-2.225667, 8.429006}}});
1944 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1945 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1946 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1947 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1948 }
1949
1950 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x9_on_c64x3x5) {
1951 const char* hlo_text = R"(
1952 HloModule Fft
1953
1954 ENTRY main {
1955 operand = c64[3, 5] parameter(0)
1956 ROOT irfft = f32[3, 9] fft(operand), fft_type=IRFFT, fft_length={3, 9}
1957 }
1958 )";
1959 auto input = LiteralUtil::CreateR2<complex64>({{{148.5, 0.0},
1960 {-4.95, 13.600013},
1961 {-4.95, 5.899180},
1962 {-4.95, 2.857884},
1963 {-4.95, 0.872819}},
1964 {{0.0, 0.0},
1965 {-25.014467, 2.096690},
1966 {-12.888800, -3.503916},
1967 {-8.1, -5.715768},
1968 {-4.974333, -7.159452}},
1969 {{0.0, 0.0},
1970 {17.814467, 17.685147},
1971 {5.688800, 12.084542},
1972 {0.9, 9.872690},
1973 {-2.225667, 8.429006}}});
1974 auto expected = LiteralUtil::CreateR2<float>(
1975 {{1.9, 2.8, 3.7, 4.6, 5.5, 6.4, 7.3, 8.2, 9.1},
1976 {9.1, 8.2, 7.3, 6.4, 5.5, 4.6, 3.7, 2.8, 1.9},
1977 {1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9}});
1978 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1979 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
1980 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
1981 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
1982 }
1983
1984 // 3D FFT tests:
1985
1986 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8) {
1987 const char* hlo_text = R"(
1988 HloModule Fft
1989
1990 ENTRY main {
1991 operand = c64[2, 4, 8] parameter(0)
1992 ROOT fft = c64[2, 4, 8] fft(operand), fft_type=FFT, fft_length={2, 4, 8}
1993 }
1994 )";
1995 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
1996 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_}));
1997 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
1998 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
1999 }
2000
2001 TEST_F(HloEvaluatorTest, 3D_IFFT_2x4x8_on_c64x2x4x8) {
2002 const char* hlo_text = R"(
2003 HloModule Fft
2004
2005 ENTRY main {
2006 operand = c64[2, 4, 8] parameter(0)
2007 ROOT ifft = c64[2, 4, 8] fft(operand), fft_type=IFFT, fft_length={2, 4, 8}
2008 }
2009 )";
2010 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2011 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&fft_c64x2x4x8_3d_}));
2012 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_.shape()));
2013 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_, result, fft_error_));
2014 }
2015
2016 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_f32x3x3x4) {
2017 const char* hlo_text = R"(
2018 HloModule Fft
2019
2020 ENTRY main {
2021 operand = f32[3, 3, 4] parameter(0)
2022 ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
2023 }
2024 )";
2025 auto input = LiteralUtil::CreateR3<float>(
2026 {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
2027 {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
2028 {{-1.8, -2.7, -3.6, -4.5},
2029 {-5.4, -6.3, -7.2, -8.1},
2030 {1.9, 2.9, 3.9, 4.9}}});
2031 auto expected = LiteralUtil::CreateR3<complex64>(
2032 {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
2033 {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
2034 {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
2035 {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
2036 {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
2037 {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2038 {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2039 {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2040 {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2041 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2042 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2043 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2044 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2045 }
2046
2047 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_c64x3x3x3) {
2048 const char* hlo_text = R"(
2049 HloModule Fft
2050
2051 ENTRY main {
2052 operand = c64[3, 3, 3] parameter(0)
2053 ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2054 }
2055 )";
2056 auto input = LiteralUtil::CreateR3<complex64>(
2057 {{{{92.8, 0.0}, {-2.8, 2.8}, {-2.8, 0.0}},
2058 {{-5.9, 35.160631}, {-11.519100, -8.919100}, {-1.3, -10.219100}},
2059 {{-5.9, -35.160631}, {8.919100, 11.519100}, {-1.3, 10.219100}}},
2060 {{{29.5, -81.579593}, {1.390897, 5.190897}, {-1.9, 3.290897}},
2061 {{-25.1, -49.017038}, {1.044486, 4.844486}, {-1.9, 2.944486}},
2062 {{11.8, 27.712813}, {1.517691, 4.717691}, {-1.6, 3.117691}}},
2063 {{{29.5, 81.579593}, {-5.190897, -1.390897}, {-1.9, -3.290897}},
2064 {{11.8, -27.712813}, {-4.717691, -1.517691}, {-1.6, -3.117691}},
2065 {{-25.1, 49.017038}, {-4.844486, -1.044486}, {-1.9, -2.944486}}}});
2066 auto expected = LiteralUtil::CreateR3<float>(
2067 {{{1.8, 2.7, 3.6, 4.5}, {8.1, 7.2, 6.3, 5.4}, {1.1, 2.2, 3.3, 4.4}},
2068 {{5.4, 6.3, 7.2, 8.1}, {4.5, 3.6, 2.7, 1.8}, {5.5, 6.6, 7.7, 8.8}},
2069 {{-1.8, -2.7, -3.6, -4.5},
2070 {-5.4, -6.3, -7.2, -8.1},
2071 {1.9, 2.9, 3.9, 4.9}}});
2072 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2073 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2074 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2075 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2076 }
2077
2078 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x5_on_f32x3x3x5) {
2079 const char* hlo_text = R"(
2080 HloModule Fft
2081
2082 ENTRY main {
2083 operand = f32[3, 3, 5] parameter(0)
2084 ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 5}
2085 }
2086 )";
2087 auto input = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2088 {8.1, 7.2, 6.3, 5.4, 4.5},
2089 {1.1, 2.2, 3.3, 4.4, 5.5}},
2090 {{5.4, 6.3, 7.2, 8.1, 9.0},
2091 {4.5, 3.6, 2.7, 1.8, 0.9},
2092 {5.5, 6.6, 7.7, 8.8, 9.9}},
2093 {{-1.8, -2.7, -3.6, -4.5, -5.4},
2094 {-5.4, -6.3, -7.2, -8.1, -9.0},
2095 {1.9, 2.9, 3.9, 4.9, 5.9}}});
2096 auto expected = LiteralUtil::CreateR3<complex64>(
2097 {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2098 {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2099 {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2100 {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2101 {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2102 {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2103 {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2104 {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2105 {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2106 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2107 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2108 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2109 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2110 }
2111
2112 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x5_on_c64x3x3x3) {
2113 const char* hlo_text = R"(
2114 HloModule Fft
2115
2116 ENTRY main {
2117 operand = c64[3, 3, 3] parameter(0)
2118 ROOT irfft = f32[3, 3, 5] fft(operand), fft_type=IRFFT, fft_length={3, 3, 5}
2119 }
2120 )";
2121 auto input = LiteralUtil::CreateR3<complex64>(
2122 {{{{119.5, 0.0}, {-3.5, 4.817337}, {-3.5, 1.137219}},
2123 {{-5.75, 56.724664}, {-19.206730, -10.537254}, {-5.775483, -12.245880}},
2124 {{-5.75, -56.724664}, {15.956730, 15.010495}, {2.525483, 13.301869}}},
2125 {{{39.25, -106.088112}, {3.286913, 7.382528}, {-1.038404, 4.885305}},
2126 {{-29.0, -64.951905}, {2.690922, 6.949515}, {-1.179098, 4.452292}},
2127 {{16.75, 30.743902}, {3.363918, 6.649878}, {-0.733751, 4.546954}}},
2128 {{{39.25, 106.088112}, {-8.036913, -0.844714}, {-3.711596, -3.341936}},
2129 {{16.75, -30.743902}, {-7.363918, -1.144350}, {-3.266249, -3.247275}},
2130 {{-29.0, 64.951905}, {-7.440922, -0.411701}, {-3.570902, -2.908924}}}});
2131 auto expected = LiteralUtil::CreateR3<float>({{{1.8, 2.7, 3.6, 4.5, 5.4},
2132 {8.1, 7.2, 6.3, 5.4, 4.5},
2133 {1.1, 2.2, 3.3, 4.4, 5.5}},
2134 {{5.4, 6.3, 7.2, 8.1, 9.0},
2135 {4.5, 3.6, 2.7, 1.8, 0.9},
2136 {5.5, 6.6, 7.7, 8.8, 9.9}},
2137 {{-1.8, -2.7, -3.6, -4.5, -5.4},
2138 {-5.4, -6.3, -7.2, -8.1, -9.0},
2139 {1.9, 2.9, 3.9, 4.9, 5.9}}});
2140 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2141 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2142 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2143 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2144 }
2145
2146 // FFT tests with non-default data layout:
2147
2148 TEST_F(HloEvaluatorTest, 1D_FFT_8_on_c64x2x4x8_with_layout) {
2149 const char* hlo_text = R"(
2150 HloModule Fft
2151
2152 ENTRY main {
2153 operand = c64[2, 4, 8]{0, 2, 1} parameter(0)
2154 ROOT fft = c64[2, 4, 8]{1, 2, 0} fft(operand), fft_type=FFT, fft_length={8}
2155 }
2156 )";
2157 auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({0, 2, 1}));
2158 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2159 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2160 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_1d_.shape()));
2161 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_1d_, result, fft_error_));
2162 }
2163
2164 TEST_F(HloEvaluatorTest, 2D_FFT_4x8_on_c64x2x4x8_with_layout) {
2165 const char* hlo_text = R"(
2166 HloModule Fft
2167
2168 ENTRY main {
2169 operand = c64[2, 4, 8]{2, 0, 1} parameter(0)
2170 ROOT fft = c64[2, 4, 8]{1, 0, 2} fft(operand), fft_type=FFT, fft_length={4, 8}
2171 }
2172 )";
2173 auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({2, 0, 1}));
2174 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2175 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2176 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_2d_.shape()));
2177 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_2d_, result, fft_error_));
2178 }
2179
2180 TEST_F(HloEvaluatorTest, 3D_FFT_2x4x8_on_c64x2x4x8_with_layout) {
2181 const char* hlo_text = R"(
2182 HloModule Fft
2183
2184 ENTRY main {
2185 operand = c64[2, 4, 8]{1, 2, 0} parameter(0)
2186 ROOT fft =
2187 c64[2, 4, 8]{0, 2, 1} fft(operand), fft_type=FFT, fft_length={2, 4, 8}
2188 }
2189 )";
2190 auto input = fft_c64x2x4x8_.Relayout(LayoutUtil::MakeLayout({1, 2, 0}));
2191 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2192 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2193 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), fft_c64x2x4x8_3d_.shape()));
2194 EXPECT_TRUE(LiteralTestUtil::Near(fft_c64x2x4x8_3d_, result, fft_error_));
2195 }
2196
2197 // FFT tests with unusual parameters:
2198
2199 // Zero-length transform.
2200 TEST_F(HloEvaluatorTest, 1D_FFT_0_on_c64x1x1x1x1) {
2201 const char* hlo_text = R"(
2202 HloModule Fft
2203
2204 ENTRY main {
2205 operand = c64[1, 1, 1, 1] parameter(0)
2206 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={0}
2207 }
2208 )";
2209 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2210 auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2211 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2212 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2213 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2214 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2215 }
2216
2217 // Zero-length axis.
2218 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x0) {
2219 const char* hlo_text = R"(
2220 HloModule Fft
2221
2222 ENTRY main {
2223 operand = c64[1, 1, 1, 0] parameter(0)
2224 ROOT fft = c64[1, 1, 1, 0] fft(operand), fft_type=FFT, fft_length={1}
2225 }
2226 )";
2227 TF_ASSERT_OK_AND_ASSIGN(
2228 auto input,
2229 LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({1, 1, 1, 0}));
2230 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2231 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2232 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2233 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2234 }
2235
2236 // Some/all dimensions have length 1.
2237 TEST_F(HloEvaluatorTest, 1D_FFT_1_on_c64x1x1x1x1) {
2238 const char* hlo_text = R"(
2239 HloModule Fft
2240
2241 ENTRY main {
2242 operand = c64[1, 1, 1, 1] parameter(0)
2243 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1}
2244 }
2245 )";
2246 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2247 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2248 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2249 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2250 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2251 }
2252
2253 // Zero-length transform.
2254 TEST_F(HloEvaluatorTest, 3D_FFT_1x0x1_on_c64x1x1x1x1) {
2255 const char* hlo_text = R"(
2256 HloModule Fft
2257
2258 ENTRY main {
2259 operand = c64[1, 1, 1, 1] parameter(0)
2260 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 0, 1}
2261 }
2262 )";
2263 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2264 auto expected = LiteralUtil::CreateR4<complex64>({{{{{0.0, 0.0}}}}});
2265 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2266 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2267 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2268 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2269 }
2270
2271 // Zero-length axis.
2272 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x0x1x0x1) {
2273 const char* hlo_text = R"(
2274 HloModule Fft
2275
2276 ENTRY main {
2277 operand = c64[0, 1, 0, 1] parameter(0)
2278 ROOT fft = c64[0, 1, 0, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2279 }
2280 )";
2281 TF_ASSERT_OK_AND_ASSIGN(
2282 auto input,
2283 LiteralUtil::CreateR4<complex64>({{{{}}}}).Reshape({0, 1, 0, 1}));
2284 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2285 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2286 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2287 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2288 }
2289
2290 // Some/all dimensions have length 1.
2291 TEST_F(HloEvaluatorTest, 3D_FFT_1x1x1_on_c64x1x1x1x1) {
2292 const char* hlo_text = R"(
2293 HloModule Fft
2294
2295 ENTRY main {
2296 operand = c64[1, 1, 1, 1] parameter(0)
2297 ROOT fft = c64[1, 1, 1, 1] fft(operand), fft_type=FFT, fft_length={1, 1, 1}
2298 }
2299 )";
2300 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}}}});
2301 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2302 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2303 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2304 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2305 }
2306
2307 // Some/all dimensions have length 1.
2308 TEST_F(HloEvaluatorTest, 3D_FFT_3x1x1_on_c64x1x3x1x1) {
2309 const char* hlo_text = R"(
2310 HloModule Fft
2311
2312 ENTRY main {
2313 operand = c64[1, 3, 1, 1] parameter(0)
2314 ROOT fft = c64[1, 3, 1, 1] fft(operand), fft_type=FFT, fft_length={3, 1, 1}
2315 }
2316 )";
2317 auto input = LiteralUtil::CreateR4<complex64>(
2318 {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2319 auto expected =
2320 LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2321 {{{84.5367, 97.5818}}},
2322 {{{-0.0566792, -48.7418}}}}});
2323 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2324 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2325 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2326 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2327 }
2328
2329 // Some/all dimensions have length 1.
2330 TEST_F(HloEvaluatorTest, 3D_IFFT_3x1x1_on_c64x1x3x1x1) {
2331 const char* hlo_text = R"(
2332 HloModule Fft
2333
2334 ENTRY main {
2335 operand = c64[1, 3, 1, 1] parameter(0)
2336 ROOT ifft = c64[1, 3, 1, 1] fft(operand), fft_type=IFFT, fft_length={3, 1, 1}
2337 }
2338 )";
2339 auto input = LiteralUtil::CreateR4<complex64>({{{{{42.24, 24.42}}},
2340 {{{84.5367, 97.5818}}},
2341 {{{-0.0566792, -48.7418}}}}});
2342 auto expected = LiteralUtil::CreateR4<complex64>(
2343 {{{{{42.24, 24.42}}}, {{{-42.24, 24.42}}}, {{{42.24, -24.42}}}}});
2344 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2345 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2346 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2347 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2348 }
2349
2350 // Odd transform length.
2351 TEST_F(HloEvaluatorTest, 1D_FFT_5_on_c64x5) {
2352 const char* hlo_text = R"(
2353 HloModule Fft
2354
2355 ENTRY main {
2356 operand = c64[5] parameter(0)
2357 ROOT fft = c64[5] fft(operand), fft_type=FFT, fft_length={5}
2358 }
2359 )";
2360 auto input = LiteralUtil::CreateR1<complex64>(
2361 {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2362 auto expected = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2363 {0.940955, 5.94095},
2364 {-1.6877, 3.3123},
2365 {-3.3123, 1.6877},
2366 {-5.94095, -0.940955}});
2367 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2368 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2369 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2370 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2371 }
2372
2373 // Odd transform length.
2374 TEST_F(HloEvaluatorTest, 1D_IFFT_5_on_c64x5) {
2375 const char* hlo_text = R"(
2376 HloModule Fft
2377
2378 ENTRY main {
2379 operand = c64[5] parameter(0)
2380 ROOT ifft = c64[5] fft(operand), fft_type=IFFT, fft_length={5}
2381 }
2382 )";
2383 auto input = LiteralUtil::CreateR1<complex64>({{15.0, 15.0},
2384 {0.940955, 5.94095},
2385 {-1.6877, 3.3123},
2386 {-3.3123, 1.6877},
2387 {-5.94095, -0.940955}});
2388 auto expected = LiteralUtil::CreateR1<complex64>(
2389 {{1.0, 5.0}, {2.0, 4.0}, {3.0, 3.0}, {4.0, 2.0}, {5.0, 1.0}});
2390 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2391 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2392 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2393 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2394 }
2395
2396 // All input values are zero.
2397 TEST_F(HloEvaluatorTest, 1D_FFT_4_on_zero_c64x4) {
2398 const char* hlo_text = R"(
2399 HloModule Fft
2400
2401 ENTRY main {
2402 operand = c64[4] parameter(0)
2403 ROOT fft = c64[4] fft(operand), fft_type=FFT, fft_length={4}
2404 }
2405 )";
2406 auto input = LiteralUtil::CreateR1<complex64>(
2407 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}});
2408 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2409 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2410 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2411 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2412 }
2413
2414 // All input values are zero.
2415 TEST_F(HloEvaluatorTest, 3D_FFT_3x3x4_on_zero_c64x3x3x4) {
2416 const char* hlo_text = R"(
2417 HloModule Fft
2418
2419 ENTRY main {
2420 operand = c64[3, 3, 4] parameter(0)
2421 ROOT fft = c64[3, 3, 4] fft(operand), fft_type=FFT, fft_length={3, 3, 4}
2422 }
2423 )";
2424 auto input = LiteralUtil::CreateR3<complex64>(
2425 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2426 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2427 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2428 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2429 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2430 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2431 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2432 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2433 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2434 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2435 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2436 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2437 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2438 }
2439
2440 // All input values are zero.
2441 TEST_F(HloEvaluatorTest, 3D_IFFT_3x3x4_on_zero_c64x3x3x4) {
2442 const char* hlo_text = R"(
2443 HloModule Fft
2444
2445 ENTRY main {
2446 operand = c64[3, 3, 4] parameter(0)
2447 ROOT ifft = c64[3, 3, 4] fft(operand), fft_type=IFFT, fft_length={3, 3, 4}
2448 }
2449 )";
2450 auto input = LiteralUtil::CreateR3<complex64>(
2451 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2452 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2453 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2454 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2455 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2456 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2457 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2458 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2459 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2460 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2461 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2462 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), input.shape()));
2463 EXPECT_TRUE(LiteralTestUtil::Near(input, result, fft_error_));
2464 }
2465
2466 // All input values are zero.
2467 TEST_F(HloEvaluatorTest, 3D_RFFT_3x3x4_on_zero_f32x3x3x4) {
2468 const char* hlo_text = R"(
2469 HloModule Fft
2470
2471 ENTRY main {
2472 operand = f32[3, 3, 4] parameter(0)
2473 ROOT rfft = c64[3, 3, 3] fft(operand), fft_type=RFFT, fft_length={3, 3, 4}
2474 }
2475 )";
2476 auto input = LiteralUtil::CreateR3<float>(
2477 {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2478 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2479 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2480 auto expected = LiteralUtil::CreateR3<complex64>(
2481 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2482 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2483 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2484 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2485 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2486 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2487 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2488 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2489 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2490 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2491 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2492 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2493 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2494 }
2495
2496 // All input values are zero.
2497 TEST_F(HloEvaluatorTest, 3D_IRFFT_3x3x4_on_zero_c64x3x3x3) {
2498 const char* hlo_text = R"(
2499 HloModule Fft
2500
2501 ENTRY main {
2502 operand = c64[3, 3, 3] parameter(0)
2503 ROOT irfft = f32[3, 3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 3, 4}
2504 }
2505 )";
2506 auto input = LiteralUtil::CreateR3<complex64>(
2507 {{{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2508 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2509 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2510 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2511 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2512 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}},
2513 {{{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2514 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}},
2515 {{0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0}}}});
2516 auto expected = LiteralUtil::CreateR3<float>(
2517 {{{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2518 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}},
2519 {{0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0}}});
2520 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2521 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2522 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2523 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2524 }
2525
2526 // Input values, for which IRFFT discards non-zero imaginary parts.
2527 TEST_F(HloEvaluatorTest, 2D_IRFFT_3x4_on_c64x3x3) {
2528 const char* hlo_text = R"(
2529 HloModule Fft
2530
2531 ENTRY main {
2532 operand = c64[3, 3] parameter(0)
2533 ROOT irfft = f32[3, 4] fft(operand), fft_type=IRFFT, fft_length={3, 4}
2534 }
2535 )";
2536 auto input =
2537 LiteralUtil::CreateR2<complex64>({{{0.0, 0.0}, {1.0, 0.0}, {2.0, 0.0}},
2538 {{3.0, 0.0}, {4.0, 0.0}, {5.0, 0.0}},
2539 {{6.0, 0.0}, {7.0, 0.0}, {8.0, 0.0}}});
2540 auto expected =
2541 LiteralUtil::CreateR2<float>({{4.0, -0.5, 0.0, -0.5},
2542 {-1.5, 0.433013, 0.0, -0.433013},
2543 {-1.5, -0.433013, 0.0, 0.433013}});
2544 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
2545 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&input}));
2546 EXPECT_TRUE(ShapeUtil::Compatible(result.shape(), expected.shape()));
2547 EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_));
2548 }
2549
2550 class HloEvaluatorPreciseReduceTest : public HloTestBase {};
2551
2552 // Tests that Reduce doesn't lose precision when adding many numbers (because
2553 // it accumulates its result in a double).
TEST_F(HloEvaluatorPreciseReduceTest,AddReductionPrecisionTest)2554 TEST_F(HloEvaluatorPreciseReduceTest, AddReductionPrecisionTest) {
2555 auto m = CreateNewVerifiedModule();
2556 HloComputation::Builder b(TestName());
2557
2558 constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
2559 std::vector<float> v(kNumElements, 1.0f);
2560 HloInstruction* arg_instruction = b.AddInstruction(
2561 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2562 HloInstruction* init_value = b.AddInstruction(
2563 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2564
2565 HloComputation::Builder add_computation("add");
2566 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2567 auto param_lhs = add_computation.AddInstruction(
2568 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2569 auto param_rhs = add_computation.AddInstruction(
2570 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2571 add_computation.AddInstruction(HloInstruction::CreateBinary(
2572 scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2573 auto add_func = m->AddEmbeddedComputation(add_computation.Build());
2574
2575 HloInstruction* reduce_instruction = b.AddInstruction(
2576 HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2577 /*dimensions_to_reduce=*/{0}, add_func));
2578 m->AddEntryComputation(b.Build());
2579
2580 HloEvaluator hlo_eval;
2581 Literal result = hlo_eval.Evaluate(reduce_instruction).value();
2582 LiteralTestUtil::ExpectR0Equal<float>(kNumElements, result);
2583 }
2584
2585 // Reducing many numbers should be fast because it doesn't create
2586 // intermediate Literals; the microbenchmark should finish in < 1 msec.
BM_ReducePrecisely(::testing::benchmark::State & state)2587 void BM_ReducePrecisely(::testing::benchmark::State& state) {
2588 HloComputation::Builder b("BM_ReducePrecisely");
2589 HloModuleConfig config;
2590 config.set_debug_options(GetDebugOptionsFromFlags());
2591 HloModule module("BM_ReducePrecisely", config);
2592
2593 constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24
2594 std::vector<float> v(kNumElements, 1.0f);
2595 HloInstruction* arg_instruction = b.AddInstruction(
2596 HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(v)));
2597 auto init_value = b.AddInstruction(
2598 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
2599
2600 HloComputation::Builder add_computation("add");
2601 Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
2602 auto param_lhs = add_computation.AddInstruction(
2603 HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
2604 auto param_rhs = add_computation.AddInstruction(
2605 HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
2606 add_computation.AddInstruction(HloInstruction::CreateBinary(
2607 scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs));
2608 auto add_func = module.AddEmbeddedComputation(add_computation.Build());
2609
2610 HloInstruction* reduce_instruction = b.AddInstruction(
2611 HloInstruction::CreateReduce(scalar_shape, arg_instruction, init_value,
2612 /*dimensions_to_reduce=*/{0}, add_func));
2613 module.AddEntryComputation(b.Build());
2614
2615 // Benchmark loop
2616 for (auto s : state) {
2617 HloEvaluator hlo_eval;
2618 hlo_eval.Evaluate(reduce_instruction).value();
2619 }
2620 }
2621
2622 BENCHMARK(BM_ReducePrecisely);
2623
// Add-reduces a 2x3 array along dimension 1 and checks the per-row sums.
TEST_P(HloEvaluatorBf16Test, ReduceAdd) {
  HloComputation::Builder entry_builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Operand:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> values(2, 3);
  values.FillUnique(1.0f);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(values)));

  HloInstruction* zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar computation add(lhs, rhs) used as the reducer.
  HloComputation::Builder sum_builder("add");
  HloInstruction* lhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  sum_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  auto sum_computation = m_->AddEmbeddedComputation(sum_builder.Build());

  const Shape row_sums_shape = ShapeUtil::MakeShape(F32, {2});
  entry_builder.AddInstruction(HloInstruction::CreateReduce(
      row_sums_shape, operand, zero,
      /*dimensions_to_reduce=*/{1}, sum_computation));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR1<float>({6, 18});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2665
// Max reduce-window with an unpadded 2x2 window (stride 1, no dilation) over a
// 2x3 array, giving a 1x2 result.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMax) {
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> values(2, 3);
  values.FillUnique(1.0f);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(values)));

  HloInstruction* zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
  auto max_computation = m_->AddEmbeddedComputation(MaxComputationScalarF32());

  // Both window dimensions share the same configuration.
  WindowDimension window_dim;
  window_dim.set_size(2);
  window_dim.set_stride(1);
  window_dim.set_padding_low(0);
  window_dim.set_padding_high(0);
  window_dim.set_window_dilation(1);
  window_dim.set_base_dilation(1);

  Window window;
  for (int i = 0; i < 2; ++i) {
    *window.add_dimensions() = window_dim;
  }

  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operand, zero, window, max_computation));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2707
// Runs the shared ReduceWindowMaxIotaTest helper with window dilation 2.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaWindowDilation) {
  Literal want = LiteralUtil::CreateR2<float>({
      {10, 11},
      {14, 15},
  });
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/1,
                          /*window_dilation=*/2, /*base_dilation=*/1,
                          /*expected=*/want);
}
2718
// Runs the shared ReduceWindowMaxIotaTest helper with stride 2 and window
// dilation 2.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideWindowDilation) {
  Literal want = LiteralUtil::CreateR2<float>({{10}});
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/2, /*base_dilation=*/1,
                          /*expected=*/want);
}
2729
// Runs the shared ReduceWindowMaxIotaTest helper with base dilation 2.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaBaseDilation) {
  Literal want = LiteralUtil::CreateR2<float>({
      {0, 1, 1, 2, 2, 3},
      {4, 5, 5, 6, 6, 7},
      {4, 5, 5, 6, 6, 7},
      {8, 9, 9, 10, 10, 11},
      {8, 9, 9, 10, 10, 11},
      {12, 13, 13, 14, 14, 15},
  });
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/1,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2745
// Runs the shared ReduceWindowMaxIotaTest helper with stride 2 and base
// dilation 2.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBaseDilation) {
  Literal want = LiteralUtil::CreateR2<float>({
      {0, 1, 2},
      {4, 5, 6},
      {8, 9, 10},
  });
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2757
// Runs the shared ReduceWindowMaxIotaTest helper with stride 2 and both
// window and base dilation set to 2.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaStrideBothDilation) {
  Literal want = LiteralUtil::CreateR2<float>({
      {5, 6, 7},
      {9, 10, 11},
      {13, 14, 15},
  });
  ReduceWindowMaxIotaTest(/*window_size=*/2, /*padding=*/0, /*stride=*/2,
                          /*window_dilation=*/2, /*base_dilation=*/2,
                          /*expected=*/want);
}
2769
// Runs the shared ReduceWindowMaxIotaTest helper with padding 1, stride 3 and
// base dilation 2.
TEST_P(HloEvaluatorBf16Test, ReduceWindowMaxIotaPaddingStrideBaseDilation) {
  // Base dilation happens before padding is applied, which explains these
  // expected values.
  Literal want = LiteralUtil::CreateR2<float>({
      {0, 2, 3},
      {8, 10, 11},
      {12, 14, 15},
  });
  ReduceWindowMaxIotaTest(/*window_size=*/3, /*padding=*/1, /*stride=*/3,
                          /*window_dilation=*/1, /*base_dilation=*/2,
                          /*expected=*/want);
}
2782
// Add reduce-window over a 2x3 array: window size 1 along dimension 0, and
// window size 2 with low padding 1 along dimension 1.
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd) {
  HloComputation::Builder entry_builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Operand:
  // f32[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<float> values(2, 3);
  values.FillUnique(1.0f);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(values)));

  HloInstruction* zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar computation add(lhs, rhs) used as the reducer.
  HloComputation::Builder sum_builder("add");
  HloInstruction* lhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  sum_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  auto sum_computation = m_->AddEmbeddedComputation(sum_builder.Build());

  // Builds a stride-1, undilated window dimension with no high padding.
  auto make_dim = [](int64_t size, int64_t padding_low) {
    WindowDimension dim;
    dim.set_size(size);
    dim.set_stride(1);
    dim.set_padding_low(padding_low);
    dim.set_padding_high(0);
    dim.set_window_dilation(1);
    dim.set_base_dilation(1);
    return dim;
  };
  Window window;
  *window.add_dimensions() = make_dim(/*size=*/1, /*padding_low=*/0);
  *window.add_dimensions() = make_dim(/*size=*/2, /*padding_low=*/1);

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operand, zero, window, sum_computation));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({
      {1, 3, 5},
      {5, 11, 13},
  });
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2839
// Add reduce-window over a rank-6 array of ones. Dimensions 1-3 use a window
// of size 2 and the others a window of size 1, so every output element sums
// 2*2*2 = 8 ones.
TEST_P(HloEvaluatorBf16Test, ReduceWindowAdd6D) {
  HloComputation::Builder entry_builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Operand: f32[4,4,4,4,4,4] full of ones. Small dims keep run-time low.
  const std::vector<int64_t> input_dims(6, 4);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateFullWithDescendingLayout<float>(input_dims,
                                                             1.0f)));

  HloInstruction* zero = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));

  // Scalar computation add(lhs, rhs) used as the reducer.
  HloComputation::Builder sum_builder("add");
  HloInstruction* lhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "lhs"));
  HloInstruction* rhs = sum_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "rhs"));
  sum_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kAdd, lhs, rhs));
  auto sum_computation = m_->AddEmbeddedComputation(sum_builder.Build());

  // Per-dimension window sizes; every dimension is stride-1, unpadded and
  // undilated.
  Window window;
  for (const int64_t size : {1, 2, 2, 2, 1, 1}) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(1);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  }

  const std::vector<int64_t> output_dims = {4, 3, 3, 3, 4, 4};
  const Shape result_shape = ShapeUtil::MakeShape(F32, output_dims);
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, operand, zero, window, sum_computation));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want =
      LiteralUtil::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2902
// Variadic (tuple) reduce-window: element-wise minimum over two identical
// 5-element f32 inputs with window size 3 and stride 2.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2Tuple) {
  HloComputation::Builder entry_builder("main");
  HloInstruction* first_input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* second_input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder reducer_builder("ComputeFunction");
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* x0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "x0"));
  HloInstruction* x1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_scalar, "x1"));
  HloInstruction* y0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_scalar, "y0"));
  HloInstruction* y1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(3, f32_scalar, "y1"));
  HloInstruction* min0 = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kMinimum, x0, y0));
  HloInstruction* min1 = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> minima = {min0, min1};
  reducer_builder.AddInstruction(HloInstruction::CreateTuple(minima));
  auto reducer = m_->AddEmbeddedComputation(reducer_builder.Build());

  std::vector<HloInstruction*> inputs = {first_input, second_input};
  HloInstruction* first_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* second_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  std::vector<HloInstruction*> inits = {first_init, second_init};

  std::pair<int64_t, int64_t> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> input_shapes = {&first_input->shape(),
                                            &second_input->shape()};
  std::vector<const Shape*> init_shapes = {&first_init->shape(),
                                           &second_init->shape()};
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferReduceWindowShape(input_shapes, init_shapes, window,
                                             reducer->ComputeProgramShape()));
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, inputs, inits, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // Both tuple elements hold the minima of the two windows.
  Literal minima_literal = LiteralUtil::CreateR1<float>({100, 1});
  Literal want = LiteralUtil::MakeTuple({&minima_literal, &minima_literal});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
2953
// Variadic (tuple) reduce-window: element-wise minimum with window size 3 and
// stride 2, where the two 5-element inputs differ in element type (f32 and
// s32) and in values.
TEST_P(HloEvaluatorBf16Test, Min3In5Stride2TupleDiffInput) {
  HloComputation::Builder entry_builder("main");
  HloInstruction* float_input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({10000, 1000, 100, 10, 1})));
  HloInstruction* int_input =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<int>({15, 28, 300, 107, 12})));

  // Reducer: (x0, x1, y0, y1) -> (min(x0, y0), min(x1, y1)).
  HloComputation::Builder reducer_builder("ComputeFunction");
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const Shape s32_scalar = ShapeUtil::MakeShape(S32, {});
  HloInstruction* x0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "x0"));
  HloInstruction* x1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, s32_scalar, "x1"));
  HloInstruction* y0 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_scalar, "y0"));
  HloInstruction* y1 = reducer_builder.AddInstruction(
      HloInstruction::CreateParameter(3, s32_scalar, "y1"));
  HloInstruction* float_min = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_scalar, HloOpcode::kMinimum, x0, y0));
  HloInstruction* int_min = reducer_builder.AddInstruction(
      HloInstruction::CreateBinary(s32_scalar, HloOpcode::kMinimum, x1, y1));
  std::vector<HloInstruction*> minima = {float_min, int_min};
  reducer_builder.AddInstruction(HloInstruction::CreateTuple(minima));
  auto reducer = m_->AddEmbeddedComputation(reducer_builder.Build());

  std::vector<HloInstruction*> inputs = {float_input, int_input};
  HloInstruction* float_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(F32)));
  HloInstruction* int_init = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::MaxValue(S32)));
  std::vector<HloInstruction*> inits = {float_init, int_init};

  std::pair<int64_t, int64_t> no_padding(0, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto window,
                          ShapeInference::InferWindowFromDimensions(
                              {3}, {2}, absl::MakeSpan(&no_padding, 1),
                              /*lhs_dilation=*/{},
                              /*rhs_dilation=*/{}));

  std::vector<const Shape*> input_shapes = {&float_input->shape(),
                                            &int_input->shape()};
  std::vector<const Shape*> init_shapes = {&float_init->shape(),
                                           &int_init->shape()};
  TF_ASSERT_OK_AND_ASSIGN(
      Shape result_shape,
      ShapeInference::InferReduceWindowShape(input_shapes, init_shapes, window,
                                             reducer->ComputeProgramShape()));
  entry_builder.AddInstruction(HloInstruction::CreateReduceWindow(
      result_shape, inputs, inits, window, reducer));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want_floats = LiteralUtil::CreateR1<float>({100, 1});
  Literal want_ints = LiteralUtil::CreateR1<int>({15, 12});
  Literal want = LiteralUtil::MakeTuple({&want_floats, &want_ints});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
3005
// Slices a 3x5 array from start {0, 2} to limit {3, 5} with strides {2, 3}
// and checks the resulting 2x1 array.
TEST_P(HloEvaluatorBf16Test, StridedSlice) {
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[3,5] {
  //  { 1, 2, 3, 4, 5 },
  //  { 9, 10, 11, 12, 13 },
  //  { 17, 18, 19, 20, 21 },
  // }
  Array2D<float> values(3, 5);
  values.FillUnique(1.0f);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(values)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 1});
  entry_builder.AddInstruction(
      HloInstruction::CreateSlice(result_shape, operand,
                                  /*start_indices=*/{0, 2},
                                  /*limit_indices=*/{3, 5},
                                  /*strides=*/{2, 3}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{3}, {19}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
3039
// Dynamic-slices a 2x3 region starting at index {0, 1} out of a 2x4 array.
TEST_P(HloEvaluatorBf16Test, DynamicSlice) {
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f32[2,4] {
  //  { 1, 2, 3, 4 },
  //  { 5, 6, 7, 8 },
  // }
  Array2D<float> values(2, 4);
  values.FillUnique(1.0f);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<float>(values)));

  // Scalar start indices.
  HloInstruction* row_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
  HloInstruction* col_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));

  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3});
  entry_builder.AddInstruction(HloInstruction::CreateDynamicSlice(
      result_shape, operand, {row_start, col_start}, {2, 3}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<float>({{2, 3, 4}, {6, 7, 8}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
3075
3076 // Verifies that the HloEvaluator's implementation goes along with existing
3077 // backends' behavior, although this is not required by the spec.
TEST_P(HloEvaluatorBf16Test,DynamicSliceModSlice)3078 TEST_P(HloEvaluatorBf16Test, DynamicSliceModSlice) {
3079 HloComputation::Builder b(TestName());
3080
3081 // arg:
3082 // f32[2,4] {
3083 // { 1, 2, 3, 4 },
3084 // { 5, 6, 7, 8 },
3085 // }
3086 auto operand_array = std::make_unique<Array2D<float>>(2, 4);
3087 operand_array->FillUnique(1.0f);
3088 auto operand_literal =
3089 LiteralUtil::CreateR2FromArray2D<float>(*operand_array);
3090
3091 HloInstruction* operand = b.AddInstruction(
3092 HloInstruction::CreateConstant(std::move(operand_literal)));
3093
3094 auto two = b.AddInstruction(
3095 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
3096 auto one = b.AddInstruction(
3097 HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));
3098
3099 Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3100 b.AddInstruction(
3101 HloInstruction::CreateDynamicSlice(shape, operand, {two, one}, {2, 3}));
3102 m_->AddEntryComputation(b.Build());
3103
3104 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
3105
3106 auto expected = LiteralUtil::CreateR2<float>({
3107 {2, 3, 4},
3108 {6, 7, 8},
3109 });
3110
3111 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
3112 }
3113
// Dynamic-update-slices a 2x2 block into a 2x3 f64 array at index {0, 1}.
TEST_P(HloEvaluatorBf16Test, DynamicSliceUpdate) {
  HloComputation::Builder entry_builder(TestName());

  // Operand:
  // f64[2,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  // }
  Array2D<double> values(2, 3);
  values.FillUnique(1.0);
  HloInstruction* operand =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(values)));

  // Scalar start indices.
  HloInstruction* row_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(0)));
  HloInstruction* col_start = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(1)));

  HloInstruction* patch =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<double>({{-2.0, -3.0}, {-6.0, -7.0}})));

  const Shape result_shape = ShapeUtil::MakeShape(F64, {2, 3});
  entry_builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
      result_shape, operand, patch, {row_start, col_start}));
  m_->AddEntryComputation(entry_builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  Literal want = LiteralUtil::CreateR2<double>({{1, -2, -3}, {5, -6, -7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, actual));
}
3152
TEST_P(HloEvaluatorBf16Test, SetAndGetTuples) {
  HloComputation::Builder builder(TestName());

  // Matrix element of the tuple, f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 },
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vector = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64_t>({0, 1})));

  // tuple = (vector, matrix); the root extracts element 1, i.e. the matrix.
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({vector, matrix}));

  const Shape element_shape = ShapeUtil::MakeShape(F64, {2, 3});
  builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, pair, 1));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  const Literal expected =
      LiteralUtil::CreateR2<double>({{1, 2, 3}, {5, 6, 7}});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3188
TEST_P(HloEvaluatorBf16Test, SetAndGetNestedTuples) {
  HloComputation::Builder builder(TestName());

  // Matrix used as a tuple element, f64[2,3]:
  //   { 1, 2, 3 },
  //   { 5, 6, 7 },
  Array2D<double> matrix_values(2, 3);
  matrix_values.FillUnique(1.0);

  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR2FromArray2D<double>(matrix_values)));
  HloInstruction* vector = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<int64_t>({0, 1})));

  // outer = ((vector, matrix), (matrix, matrix)); the root extracts element 1,
  // i.e. the (matrix, matrix) inner tuple.
  HloInstruction* mixed_pair =
      builder.AddInstruction(HloInstruction::CreateTuple({vector, matrix}));
  HloInstruction* matrix_pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  HloInstruction* outer = builder.AddInstruction(
      HloInstruction::CreateTuple({mixed_pair, matrix_pair}));

  builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(matrix_pair->shape(), outer, 1));

  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // The expected tuple holds the same matrix literal twice.
  Literal matrix_literal =
      LiteralUtil::CreateR2FromArray2D<double>(matrix_values);
  Literal expected = LiteralUtil::MakeTuple({&matrix_literal, &matrix_literal});

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3227
TEST_P(HloEvaluatorBf16Test, Reverse) {
  HloComputation::Builder builder(TestName());

  // Source values, float[4x3x2x1], numbered 1..24 in row-major order.
  // clang-format off
  const Array4D<float> source_values({
    {{{1.0f}, {2.0f}},
     {{3.0f}, {4.0f}},
     {{5.0f}, {6.0f}}},
    {{{7.0f}, {8.0f}},
     {{9.0f}, {10.0f}},
     {{11.0f}, {12.0f}}},
    {{{13.0f}, {14.0f}},
     {{15.0f}, {16.0f}},
     {{17.0f}, {18.0f}}},
    {{{19.0f}, {20.0f}},
     {{21.0f}, {22.0f}},
     {{23.0f}, {24.0f}}},
  });
  // clang-format on
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR4FromArray4D<float>(source_values)));

  // Reverse along dimensions 0 and 1 only; dimensions 2 and 3 keep their
  // order, so each {a, a+1} pair stays intact in the expected output.
  const Shape result_shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1});
  builder.AddInstruction(
      HloInstruction::CreateReverse(result_shape, source, {0, 1}));
  m_->AddEntryComputation(builder.Build());

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate());

  // clang-format off
  const Array4D<float> reversed_values({
    {{{23.0f}, {24.0f}},
     {{21.0f}, {22.0f}},
     {{19.0f}, {20.0f}}},

    {{{17.0f}, {18.0f}},
     {{15.0f}, {16.0f}},
     {{13.0f}, {14.0f}}},

    {{{11.0f}, {12.0f}},
     {{9.0f}, {10.0f}},
     {{7.0f}, {8.0f}}},

    {{{5.0f}, {6.0f}},
     {{3.0f}, {4.0f}},
     {{1.0f}, {2.0f}}},
  });
  // clang-format on
  const Literal expected =
      LiteralUtil::CreateR4FromArray4D<float>(reversed_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3280
// Checks that EvaluateWithSubstitutions evaluates a single instruction using
// caller-supplied literals for its operands.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutions) {
  HloComputation::Builder b(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4});

  // Graph: add = param0 + (param0 * param0).
  HloInstruction* param0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0"));
  HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kMultiply, param0, param0));
  HloInstruction* add = b.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, square));

  // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}.
  // The substituted value for `square` is deliberately not param0 squared, so
  // the expected sum below can only be produced if the substitution is used.
  HloEvaluator evaluator;
  Literal param0_literal = LiteralUtil::CreateR1<float>({1, 2, 3, 4});
  Literal square_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result,
      evaluator.EvaluateWithSubstitutions(
          add, {{param0, &param0_literal}, {square, &square_literal}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR1<float>({11, 22, 33, 44}), result));
}
3303
3304 // Check that EvaluateWithSubstitutions works if one of the operands to the op
3305 // we're evaluating is a constant.
TEST_P(HloEvaluatorBf16Test, EvaluateWithSubstitutionsWithConstantOperand) {
  HloComputation::Builder builder(TestName());
  const Shape vec4 = ShapeUtil::MakeShape(F32, {4});

  // Graph: sum = constant + (param * param). Only the product is substituted;
  // the constant operand has no substitution entry.
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec4, "param0"));
  HloInstruction* product = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kMultiply, param, param));
  HloInstruction* fixed = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vec4, HloOpcode::kAdd, fixed, product));

  // Evaluate the sum with product = {10, 20, 30, 40}.
  HloEvaluator evaluator;
  Literal product_literal = LiteralUtil::CreateR1<float>({10, 20, 30, 40});
  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual,
      evaluator.EvaluateWithSubstitutions(sum, {{product, &product_literal}}));

  const Literal expected = LiteralUtil::CreateR1<float>({11, 22, 33, 44});
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3328
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
  // Gathers whole rows of a 3x3 operand, one row per index.
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  // Rows 0 and 2 of the operand.
  const Literal expected =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &row_ids}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3352
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
  // Gathers whole columns of a 3x3 operand, one column per index.
  const char* module_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  // Columns 0 and 2 of the operand, laid out side by side.
  const Literal expected =
      LiteralUtil::CreateR2<int32_t>({{1, 3}, {4, 6}, {7, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &column_ids}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3376
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
  // Column gather driven by a rank-2 index array, giving a rank-3 result.
  const char* module_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,3,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR2<int32_t>({{0, 2}, {2, 1}});
  // First slab: columns 0 and 2; second slab: columns 2 and 1.
  const Literal expected = LiteralUtil::CreateR3<int32_t>(
      {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &column_ids}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3402
TEST_F(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
  // Each row of `indices` is a 2-D coordinate that selects one length-2
  // vector from the 3x3x2 operand.
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=1,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal coordinates = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  // Coordinates (0,0) and (1,0).
  const Literal expected = LiteralUtil::CreateR2<int32_t>({{-1, 1}, {-4, 4}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &coordinates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3428
TEST_F(HloEvaluatorTest,
       EvaluateGather_TensorFlowGatherNdNonDefaultIndexVectorDim) {
  // Same operand and index literal as the GatherNd case, but with
  // index_vector_dim=0, so coordinates are read down the columns of `indices`.
  const char* module_text = R"(
HloModule TensorFlowGatherNd

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0,1},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal coordinates = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  // Column-wise coordinates are (0,1) and (0,0).
  const Literal expected = LiteralUtil::CreateR2<int32_t>({{-2, 2}, {-1, 1}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &coordinates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3455
TEST_F(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
  // A gather with no collapsed dimensions that extracts a single 1x1 window
  // at the position given by the index vector.
  const char* module_text = R"(
HloModule DynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[1,1] gather(operand, indices),
      offset_dims={0,1},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal window_start = LiteralUtil::CreateR1<int32_t>({1, 1});
  // The element at (1,1).
  const Literal expected = LiteralUtil::CreateR2<int32_t>({{5}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &window_start}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3479
TEST_F(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
  // Two 1x1 windows; with index_vector_dim=0 each column of `indices` is one
  // window start.
  const char* module_text = R"(
HloModule BatchDynamicSlice

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  ROOT gather = s32[2,1,1] gather(operand, indices),
      offset_dims={1,2},
      collapsed_slice_dims={},
      start_index_map={0,1},
      index_vector_dim=0,
      slice_sizes={1,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal window_starts = LiteralUtil::CreateR2<int32_t>({{2, 1}, {1, 1}});
  // Window starts are (2,1) and (1,1).
  const Literal expected = LiteralUtil::CreateR3<int32_t>({{{8}}, {{5}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &window_starts}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3503
TEST_F(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
  // Row gather over an operand whose second dimension has size zero; the
  // result is an empty s32[2,0] literal.
  const char* module_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,0] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal empty_rows = LiteralUtil::CreateR2<int32_t>({{}, {}, {}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  const Literal expected = LiteralUtil::CreateR2<int32_t>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&empty_rows, &row_ids}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3526
TEST_F(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
  // Scalar gather from a rank-1 operand: offset_dims is empty, so the result
  // shape comes entirely from the batch dimensions of `indices`.
  const std::string module_text = R"(
HloModule GatherXd

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  ROOT gather = s32[2,2] gather(operand, indices),
      offset_dims={},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=2,
      slice_sizes={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal data = LiteralUtil::CreateR1<int32_t>({0, 1, 2});
  Literal element_ids =
      LiteralUtil::CreateR3<int32_t>({{{0}, {1}}, {{2}, {1}}});
  // operand[i] == i here, so the result mirrors the index values.
  const Literal expected = LiteralUtil::CreateR2<int32_t>({{0, 1}, {2, 1}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&data, &element_ids}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3551
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
  // Row scatter whose combiner returns its second parameter, so each update
  // row replaces the addressed operand row.
  const char* module_text = R"(
HloModule TensorFlowScatterV1

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal new_rows =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  // Rows 0 and 2 are replaced; row 1 is untouched.
  const Literal expected =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {4, 5, 6}, {70, 80, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &row_ids, &new_rows}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3585
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV2_Update) {
  // Column scatter whose combiner returns its second parameter, so each
  // update column replaces the addressed operand column.
  const char* module_text = R"(
HloModule TensorFlowScatterV2

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal new_columns =
      LiteralUtil::CreateR2<int32_t>({{10, 30}, {40, 60}, {70, 90}});
  // Columns 0 and 2 are replaced; column 1 is untouched.
  const Literal expected =
      LiteralUtil::CreateR2<int32_t>({{10, 2, 30}, {40, 5, 60}, {70, 8, 90}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &column_ids, &new_columns}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3619
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Add) {
  // Row scatter with an additive combiner: update rows are added onto the
  // addressed operand rows.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal addends =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  // Row 0 += {10,20,30}; row 2 += {70,80,90}; row 1 is untouched.
  const Literal expected =
      LiteralUtil::CreateR2<int32_t>({{11, 22, 33}, {4, 5, 6}, {77, 88, 99}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &row_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3654
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_Mul) {
  // Row scatter with a multiplicative combiner: addressed operand rows are
  // multiplied element-wise by the update rows.
  const char* module_text = R"(
HloModule TensorFlowScatter

mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=mul_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal factors =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  // Row 0 *= {10,20,30}; row 2 *= {70,80,90}; row 1 is untouched.
  const Literal expected = LiteralUtil::CreateR2<int32_t>(
      {{10, 40, 90}, {4, 5, 6}, {490, 640, 810}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &row_ids, &factors}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3690
TEST_P(HloEvaluatorBf16Test, EvaluateScatter_TensorFlowScatter_F32) {
  // Floating-point variant of the additive row scatter; the result is
  // compared with a tolerance rather than exactly.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(f32[] lhs, f32[] rhs)
}

ENTRY main {
  operand = f32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = f32[2,3] parameter(2)
  ROOT scatter = f32[3,3] scatter(operand, indices, updates),
      to_apply=add_f32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({2, 1});
  Literal addends =
      LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
  // First update row lands on operand row 2, second on row 1.
  const Literal expected = LiteralUtil::CreateR2<float>(
      {{1.1, 2.2, 3.3}, {6.7, 8.6, 8.2}, {8.1, 9.9, 10.6}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &row_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, ErrorSpec{0.1, 0.01}));
}
3726
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_RepeatedIndices) {
  // Additive row scatter where both indices address the same operand row, so
  // both update rows accumulate into it.
  const char* module_text = R"(
HloModule TensorFlowScatter

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({1, 1});
  Literal addends =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  // Row 1 becomes {4+10+70, 5+20+80, 6+30+90}.
  const Literal expected =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {84, 105, 126}, {7, 8, 9}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &row_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3761
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatter_MultipleBatchDims) {
  // Additive column scatter driven by a rank-2 index array; some columns
  // receive contributions from more than one index.
  const char* module_text = R"(
HloModule TensorFlowScatterMultipleBatchDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,3,2] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={1},
      scatter_dims_to_operand_dims={1},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal column_ids = LiteralUtil::CreateR2<int32_t>({{0, 2}, {2, 1}});
  Literal addends = LiteralUtil::CreateR3<int32_t>(
      {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
  // Column 0 gets {10,40,70}; column 1 gets 5s; column 2 gets {30,60,90}
  // plus 5s.
  const Literal expected = LiteralUtil::CreateR2<int32_t>(
      {{11, 7, 38}, {44, 10, 71}, {77, 13, 104}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &column_ids, &addends}));
  EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual));
}
3797
TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterNd) {
  // Each row of `indices` is a 2-D coordinate; the length-2 vector stored
  // there is replaced by the matching update row.
  const char* module_text = R"(
HloModule TensorFlowScatterNd

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal coordinates = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  Literal replacements = LiteralUtil::CreateR2<int32_t>({{-10, 10}, {-40, 40}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &coordinates, &replacements}));

  // Positions (0,0) and (1,0) are overwritten; everything else is unchanged.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR3<int32_t>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                      {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}}),
      actual));
}
3834
TEST_F(HloEvaluatorTest,
       EvaluateScatter_TensorFlowScatterNd_NonDefaultIndexVectorDim) {
  // ScatterNd with index_vector_dim=0, so coordinates are read down the
  // columns of `indices` instead of along its rows.
  const char* module_text = R"(
HloModule TensorFlowScatterNdNonDefaultIndexVectorDim

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal coordinates = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  Literal replacements = LiteralUtil::CreateR2<int32_t>({{-10, 10}, {-20, 20}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &coordinates, &replacements}));

  // Column-wise coordinates are (0,1) and (0,0): the first update row lands
  // at (0,1) and the second at (0,0).
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR3<int32_t>({{{-20, 20}, {-10, 10}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},      //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}}),
      actual));
}
3872
TEST_F(HloEvaluatorTest, EvaluateScatter_DynamicUpdateSlice) {
  // A scatter with no inserted window dimensions that overwrites a single
  // 1x1 window at the position given by the index vector.
  const char* module_text = R"(
HloModule DynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={0,1},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal window_start = LiteralUtil::CreateR1<int32_t>({1, 1});
  Literal patch = LiteralUtil::CreateR2<int32_t>({{10}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &window_start, &patch}));

  // Only the element at (1,1) changes.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 10, 6}, {7, 8, 9}}),
      actual));
}
3905
TEST_F(HloEvaluatorTest, EvaluateScatter_BatchDynamicUpdateSlice) {
  // Two 1x1 window overwrites; with index_vector_dim=0 each column of
  // `indices` is one window start.
  const char* module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2,2] parameter(1)
  updates = s32[2,1,1] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal base =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal window_starts = LiteralUtil::CreateR2<int32_t>({{2, 1}, {1, 1}});
  Literal patches = LiteralUtil::CreateR3<int32_t>({{{10}}, {{20}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual,
                          Evaluate({&base, &window_starts, &patches}));

  // Window starts are (2,1) <- 10 and (1,1) <- 20.
  EXPECT_TRUE(LiteralTestUtil::Equal(
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 20, 6}, {7, 10, 9}}),
      actual));
}
3938
TEST_F(HloEvaluatorTest, EvaluateScatter_ZeroDimBounds) {
  // Row scatter where operand and updates both have a zero-sized second
  // dimension; the result is expected to equal the (empty) operand.
  const char* module_text = R"(
HloModule TensorFlowScatter_ZeroDimBounds

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,0] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,0] parameter(2)
  ROOT scatter = s32[3,0] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal empty_base = LiteralUtil::CreateR2<int32_t>({{}, {}, {}});
  Literal row_ids = LiteralUtil::CreateR1<int32_t>({0, 2});
  Literal empty_updates = LiteralUtil::CreateR2<int32_t>({{}, {}});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal actual, Evaluate({&empty_base, &row_ids, &empty_updates}));
  EXPECT_TRUE(LiteralTestUtil::Equal(empty_base, actual));
}
3968
// Scatter with an empty update_window_dims list, combined through add_s32.
TEST_F(HloEvaluatorTest, EvaluateScatter_NoUpdateWindowDims) {
  const std::string module_text = R"(
HloModule Scatter_NoUpdateWindowDims

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3] parameter(0)
  indices = s32[2,2,1] parameter(1)
  updates = s32[2,2] parameter(2)
  ROOT scatter = s32[3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal operand = LiteralUtil::CreateR1<int32_t>({0, 1, 2});
  Literal updates = LiteralUtil::CreateR2<int32_t>({{10, 20}, {30, 40}});
  Literal scatter_indices =
      LiteralUtil::CreateR3<int32_t>({{{0}, {1}}, {{2}, {1}}});

  TF_ASSERT_OK_AND_ASSIGN(Literal scattered,
                          Evaluate({&operand, &scatter_indices, &updates}));

  // Index 1 appears twice (updates 20 and 40): 1 + 20 + 40 = 61.
  // Index 0 receives 10 and index 2 receives 30.
  Literal want = LiteralUtil::CreateR1<int32_t>({10, 61, 32});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, scattered));
}
4002
// Scatter-add where one of the two scatter indices is negative. Only the row
// addressed by the non-negative index is expected to change.
TEST_F(HloEvaluatorTest, EvaluateScatter_NegativeIndices) {
  const char* module_text = R"(
HloModule TensorFlowScatter_NegativeIndices

add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(s32[] lhs, s32[] rhs)
}

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  updates = s32[2,3] parameter(2)
  ROOT scatter = s32[3,3] scatter(operand, indices, updates),
      to_apply=add_s32,
      update_window_dims={1},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal operand =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal updates =
      LiteralUtil::CreateR2<int32_t>({{10, 20, 30}, {70, 80, 90}});
  // No updates should happen for the negative indices.
  Literal scatter_indices = LiteralUtil::CreateR1<int32_t>({-1, 2});

  // Row 2 accumulates {70, 80, 90}; rows 0 and 1 keep their values.
  Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}});
  Literal got = EvaluateWithModule(module.get(),
                                   {&operand, &scatter_indices, &updates});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4038
// Scatter of six single-element updates into a 3x3 operand, where several of
// the index pairs fall outside the operand bounds.
TEST_F(HloEvaluatorTest, EvaluateScatter_OobIndices) {
  const std::string module_text = R"(
HloModule BatchDynamicUpdateSlice

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3]{1,0} parameter(0)
  indices = s32[6,2]{1,0} parameter(1)
  updates = s32[6,1,1]{2,1,0} parameter(2)
  ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal operand =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
  Literal updates = LiteralUtil::CreateR3<int32_t>(
      {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
  // No updates should happen for the OOB indices.
  Literal scatter_indices = LiteralUtil::CreateR2<int32_t>(
      {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});

  // Only {2,1} -> 20, {1,1} -> 30 and {1,2} -> 60 land inside the operand.
  Literal want =
      LiteralUtil::CreateR2<int32_t>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}});
  Literal got = EvaluateWithModule(module.get(),
                                   {&operand, &scatter_indices, &updates});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, got));
}
4074
// Scatter whose index is in range but whose update window extends past the
// operand; the result is expected to equal the original operand.
TEST_F(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) {
  const char* module_text = R"(
HloModule TensorFlowScatterNd_OobUpdateWindow

update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  ROOT rhs = s32[] parameter(1)
}

ENTRY main {
  operand = s32[3,3,2] parameter(0)
  indices = s32[1,2] parameter(1)
  updates = s32[1,2,2] parameter(2)
  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
      to_apply=update_s32,
      update_window_dims={1,2},
      inserted_window_dims={0},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(module_text));

  Literal scatter_indices = LiteralUtil::CreateR2<int32_t>({{0, 2}});
  Literal updates = LiteralUtil::CreateR3<int32_t>({{{-10, 10}, {-40, 40}}});
  Literal operand =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});

  // Given the update window size of 2,2 and the index of 0,2, the update window
  // will be OOB. So, nothing should be updated.
  Literal unchanged = operand.Clone();
  Literal got = EvaluateWithModule(module.get(),
                                   {&operand, &scatter_indices, &updates});
  EXPECT_TRUE(LiteralTestUtil::Equal(unchanged, got));
}
4111
// Variadic scatter: two operands (s32 and f32) are updated together through a
// combiner that returns the incoming update values as a tuple.
TEST_F(HloEvaluatorTest, EvaluateScatter_Multioutput) {
  const char* module_text = R"(
HloModule MultioutputScatter

update {
  lhs0 = s32[] parameter(0)
  lhs1 = f32[] parameter(1)
  rhs0 = s32[] parameter(2)
  rhs1 = f32[] parameter(3)
  ROOT tuple = (s32[], f32[]) tuple(rhs0, rhs1)
}

ENTRY main {
  operand0 = s32[3,3,2] parameter(0)
  operand1 = f32[3,3,2] parameter(1)
  indices = s32[2,2] parameter(2)
  updates0 = s32[2,2] parameter(3)
  updates1 = f32[2,2] parameter(4)
  ROOT scatter = (s32[3,3,2], f32[3,3,2]) scatter(operand0, operand1, indices, updates0, updates1),
      to_apply=update,
      update_window_dims={1},
      inserted_window_dims={0,1},
      scatter_dims_to_operand_dims={0,1},
      index_vector_dim=1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  // Operands.
  Literal operand0 =
      LiteralUtil::CreateR3<int32_t>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
                                      {{-4, 4}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}});
  Literal operand1 =
      LiteralUtil::CreateR3<float>({{{-2, 2}, {-3, 3}, {-4, 4}},  //
                                    {{-5, 5}, {-6, 6}, {-7, 7}},  //
                                    {{-8, 8}, {-9, 9}, {-10, 10}}});

  // Indices and per-operand updates.
  Literal scatter_indices = LiteralUtil::CreateR2<int32_t>({{0, 0}, {1, 0}});
  Literal updates0 = LiteralUtil::CreateR2<int32_t>({{-10, 10}, {-40, 40}});
  Literal updates1 = LiteralUtil::CreateR2<float>({{-11, 11}, {-41, 41}});

  TF_ASSERT_OK_AND_ASSIGN(
      Literal scattered,
      Evaluate({&operand0, &operand1, &scatter_indices, &updates0, &updates1}));

  // Positions [0][0] and [1][0] of each operand are overwritten by the updates.
  Literal want = LiteralUtil::MakeTupleOwned(
      LiteralUtil::CreateR3<int32_t>({{{-10, 10}, {-2, 2}, {-3, 3}},  //
                                      {{-40, 40}, {-5, 5}, {-6, 6}},  //
                                      {{-7, 7}, {-8, 8}, {-9, 9}}}),
      LiteralUtil::CreateR3<float>({{{-11, 11}, {-3, 3}, {-4, 4}},  //
                                    {{-41, 41}, {-6, 6}, {-7, 7}},  //
                                    {{-8, 8}, {-9, 9}, {-10, 10}}}));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, scattered));
}
4162
4163 // Verifies that HloEvaluator evaluates a HLO instruction that performs
4164 // element-wise comparison with 2 bfloat16 operands.
TEST_F(HloEvaluatorTest,DoesCompareBF16)4165 TEST_F(HloEvaluatorTest, DoesCompareBF16) {
4166 // lhs >= rhs
4167 auto lhs = LiteralUtil::CreateR2<bfloat16>(
4168 {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)},
4169 {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}});
4170 auto rhs = LiteralUtil::CreateR2<bfloat16>(
4171 {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)},
4172 {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}});
4173 auto expected =
4174 LiteralUtil::CreateR2<bool>({{false, true, true}, {false, true, true}});
4175
4176 HloComputation::Builder b(TestName());
4177 auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs)));
4178 auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs)));
4179 b.AddInstruction(HloInstruction::CreateCompare(expected.shape(), c1, c2,
4180 ComparisonDirection::kGe));
4181 m_->AddEntryComputation(b.Build());
4182
4183 TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate());
4184 EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
4185 }
4186
// Sum-reduces a four-element bf16 vector with a bf16 add computation.
TEST_P(HloEvaluatorBf16Test, Bf16Reduction) {
  const std::string module_text = R"(
HloModule Bf16Reduction

add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] {
  lhs = bf16[] parameter(0)
  rhs = bf16[] parameter(1)
  ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs)
}

ENTRY main {
  arg0 = bf16[4]{0} parameter(0)
  init = bf16[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_bf16
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<bfloat16>(
      {bfloat16(1.0f), bfloat16(3.0f), bfloat16(-2.0f), bfloat16(42.0f)});
  TF_ASSERT_OK_AND_ASSIGN(Literal reduced, Evaluate({&input}));

  // 1 + 3 - 2 + 42 = 44.
  Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, reduced));
}
4211
// Reduce whose input and combiner are f32 while the reduce instruction itself
// is declared with a bf16 result type.
TEST_F(HloEvaluatorTest, MixedPrecisionReduction) {
  const std::string module_text = R"(
HloModule MixedPrecisionReduction

add_f32 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY main {
  arg0 = f32[4]{0} parameter(0)
  init = f32[] constant(0)
  ROOT %reduce = bf16[] reduce(arg0, init), dimensions={0}, to_apply=add_f32
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR1<float>({1.0f, 3.0f, -2.0f, 42.0f});
  TF_ASSERT_OK_AND_ASSIGN(Literal reduced, Evaluate({&input}));

  // 1 + 3 - 2 + 42 = 44, expected as a bf16 scalar.
  Literal want = LiteralUtil::CreateR0<bfloat16>(bfloat16(44.0f));
  EXPECT_TRUE(LiteralTestUtil::Equal(want, reduced));
}
4235
// Evaluates a `call` whose callee contains an outfeed. Despite the test name,
// the assertion below expects evaluation to return a non-OK status.
TEST_F(HloEvaluatorTest, DontFailOnCallUnimplementedOps) {
  // Outfeed triggers unimplemented error within HandleCall, and we verify that
  // the Evaluator does fail in such case.
  const std::string module_text = R"(
HloModule DontFailOnCall

call {
  token0 = token[] after-all()
  constant = u32[3]{0} constant({1,2,3})
  ROOT outfeed = token[] outfeed(constant, token0), outfeed_shape=u32[3]{0}
}

ENTRY main {
  ROOT result = token[] call(), to_apply=call
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  const bool evaluation_ok = evaluation.status().ok();
  EXPECT_FALSE(evaluation_ok);
}
4256
// Evaluates a `fusion` whose fused computation contains an outfeed. Despite
// the test name, the assertion below expects a non-OK status.
TEST_F(HloEvaluatorTest, DontFailOnFusionWithUnimplementedOps) {
  // Outfeed triggers unimplemented error within HandleFusion, and we verify
  // that the Evaluator does fail in such case.
  const std::string module_text = R"(
HloModule DontFailOnFusion

fused_computation {
  token0 = token[] after-all()
  constant = u32[3]{0} constant({1,2,3})
  ROOT outfeed = token[] outfeed(constant, token0), outfeed_shape=u32[3]{0}
}

ENTRY main {
  ROOT result = token[] fusion(), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  auto evaluation = Evaluate();
  const bool evaluation_ok = evaluation.status().ok();
  EXPECT_FALSE(evaluation_ok);
}
4277
// Full-extent slice whose result layout ({1,0,2}) differs from the operand
// layout ({0,1,2}); the output must compare equal to the input literal.
TEST_P(HloEvaluatorBf16Test, SliceWithDifferentLayout) {
  // Regression test for b/114735354.
  const std::string module_text = R"(
HloModule SliceWithDifferentLayout

ENTRY main {
  arg = f32[2,2,2]{0,1,2} parameter(0)
  ROOT %slice = f32[2,2,2]{1,0,2} slice(f32[2,2,2]{0,1,2} %arg), slice={[0:2], [0:2], [0:2]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(module_text));

  Literal input = LiteralUtil::CreateR3WithLayout<float>(
      {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
      LayoutUtil::MakeLayout({0, 1, 2}));
  TF_ASSERT_OK_AND_ASSIGN(Literal sliced, Evaluate({&input}));
  EXPECT_TRUE(LiteralTestUtil::Equal(input, sliced));
}
4296
// Bitcast between two shapes/layouts; the flat element data of the result is
// compared against the flat element data of the generated argument. The
// element type is bf16 or f32 depending on use_bfloat16_.
TEST_P(HloEvaluatorBf16Test, Bitcast) {
  // Regression test for b/114735354.
  const absl::string_view hlo_text_base = R"(
HloModule Bitcast

ENTRY main {
  param = %s[32,121]{1,0} parameter(0)
  ROOT bitcast = %s[121,32,1]{0,1,2} bitcast(%s[32,121]{1,0} param)
}
)";
  // The same element-type name is substituted into all three placeholders.
  const char* element_type = use_bfloat16_ ? "bf16" : "f32";
  std::string hlo_text =
      absl::StrFormat(hlo_text_base, element_type, element_type, element_type);
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  auto fake_args = MakeFakeArguments(m_.get()).value();
  TF_ASSERT_OK_AND_ASSIGN(Literal bitcasted, Evaluate({&fake_args[0]}));

  if (!use_bfloat16_) {
    EXPECT_TRUE(
        absl::c_equal(fake_args[0].data<float>(), bitcasted.data<float>()));
  } else {
    EXPECT_TRUE(absl::c_equal(fake_args[0].data<bfloat16>(),
                              bitcasted.data<bfloat16>()));
  }
}
4323
4324 // Check that s32 under/overflow doesn't trigger a ubsan failure.
TEST_F(HloEvaluatorTest,Int32Overflow)4325 TEST_F(HloEvaluatorTest, Int32Overflow) {
4326 const absl::string_view hlo_text = R"(
4327 HloModule Test
4328
4329 ENTRY main {
4330 c1 = s32[] constant(1073741824) // 2^30
4331 sum = s32[] add(c1, c1) // 2^31, i.e. INT_MIN
4332
4333 c2 = s32[] constant(-2147483648) // -2^31
4334 sub = s32[] subtract(c2, c1) // -2^31 - 2^30, underflows
4335
4336 mul = s32[] multiply(c1, c1)
4337 ROOT tuple = (s32[], s32[], s32[]) tuple(sum, sub, mul)
4338 }
4339 )";
4340 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4341 TF_ASSERT_OK_AND_ASSIGN(auto literal, Evaluate({}));
4342 std::vector<Literal> actual = literal.DecomposeTuple();
4343 ASSERT_EQ(actual.size(), 3);
4344
4345 uint32_t pow30 = uint32_t{1} << 30;
4346 uint32_t pow31 = uint32_t{1} << 31;
4347 EXPECT_EQ(actual[0].GetFirstElement<int32_t>(), static_cast<int32_t>(pow31));
4348 EXPECT_EQ(actual[1].GetFirstElement<int32_t>(),
4349 static_cast<int32_t>(-(pow31 + pow30)));
4350 EXPECT_EQ(actual[2].GetFirstElement<int32_t>(),
4351 static_cast<int32_t>(pow31 * pow31));
4352 }
4353
// Evaluates get-dimension-size on a value derived from a parameter whose
// dimension 0 is bound to a scalar "size" parameter through the module's
// dynamic parameter binding. The result is expected to be the size argument
// (3), not the static bound of the data shape (4).
TEST_F(HloEvaluatorTest, GetDimensionSize) {
  const absl::string_view hlo_text = R"(
HloModule Test

ENTRY main {
  size = s32[] parameter(0)

  data = s32[4] parameter(1)

  sum = s32[4] add(data, data)

  ROOT dynamic_size = s32[] get-dimension-size(sum), dimensions={0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  // Set up dynamic parameter binding: parameter 0 carries the dynamic size of
  // dimension 0 of parameter 1. A failure here fails only this test (via
  // TF_ASSERT_OK) rather than aborting the whole test binary.
  TF_ASSERT_OK(m_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{0, {}},
      DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  TF_ASSERT_OK_AND_ASSIGN(DynamicDimensionInference dynamic_dimension_inference,
                          DynamicDimensionInference::Run(m_.get()));

  // The evaluator keeps only a pointer, so the inference object declared
  // above must stay alive until Evaluate() returns.
  evaluator_.set_dynamic_dimension_inference(&dynamic_dimension_inference);
  Literal size_arg = LiteralUtil::CreateR0<int32_t>(3);
  Literal data_arg = LiteralUtil::CreateR1<int32_t>({1, 2, 3, 4});

  TF_ASSERT_OK_AND_ASSIGN(Literal actual, Evaluate({&size_arg, &data_arg}));

  EXPECT_EQ(actual.GetFirstElement<int32_t>(), 3);
}
4386
4387 // Check that we get a useful error if we pass inputs of the wrong shape.
TEST_F(HloEvaluatorTest,EvaluateWithWrongInputShapes)4388 TEST_F(HloEvaluatorTest, EvaluateWithWrongInputShapes) {
4389 const absl::string_view hlo_text = R"(
4390 HloModule Test
4391
4392 ENTRY main {
4393 p0 = s32[1] parameter(0)
4394 ROOT sum = s32[1] add(p0, p0)
4395 }
4396 )";
4397 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4398 Literal input_wrong_shape = LiteralUtil::CreateR1<int32_t>({0, 1});
4399
4400 EXPECT_EQ(HloEvaluator()
4401 .Evaluate(*m_, {&input_wrong_shape})
4402 .status()
4403 .error_message(),
4404 "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
4405 "but arg was s32[2]{0}.");
4406 EXPECT_EQ(HloEvaluator()
4407 .Evaluate(*m_->entry_computation(), {&input_wrong_shape})
4408 .status()
4409 .error_message(),
4410 "Shape mismatch at parameter 0. Computation expected s32[1]{0}, "
4411 "but arg was s32[2]{0}.");
4412 }
4413
4414 // Check that we get a useful error if we pass too many or too few inputs.
TEST_F(HloEvaluatorTest,EvaluateWithWrongNumberOfInputs)4415 TEST_F(HloEvaluatorTest, EvaluateWithWrongNumberOfInputs) {
4416 const absl::string_view hlo_text = R"(
4417 HloModule Test
4418
4419 ENTRY main {
4420 p0 = s32[1] parameter(0)
4421 ROOT sum = s32[1] add(p0, p0)
4422 }
4423 )";
4424 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4425 Literal input = LiteralUtil::CreateR1<int32_t>({0});
4426
4427 EXPECT_EQ(
4428 HloEvaluator().Evaluate(*m_, {&input, &input}).status().error_message(),
4429 "Expected 1 argument, but got 2.");
4430 EXPECT_EQ(HloEvaluator()
4431 .Evaluate(*m_->entry_computation(), {&input, &input})
4432 .status()
4433 .error_message(),
4434 "Expected 1 argument, but got 2.");
4435 }
4436
// Fusion wrapping a bitcast from layout {0,1} to {1,0}. The flat element data
// of the result is expected to equal the flat element data of the argument.
TEST_F(HloEvaluatorTest, PreserveFusionInputLayout) {
  const absl::string_view hlo_text = R"(
HloModule FusionInputLayout

fused_computation {
  param_0 = f32[20,20]{0,1} parameter(0)
  ROOT bitcast = f32[20,20]{1,0} bitcast(param_0)
}

ENTRY kernel_entry {
  parameter.0 = f32[20,20]{0,1} parameter(0)
  ROOT fusion = f32[20,20]{1,0} fusion(parameter.0),
      kind=kLoop, calls=fused_computation
})";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  auto fake_args = MakeFakeArguments(m_.get()).value();

  TF_ASSERT_OK_AND_ASSIGN(Literal fused, Evaluate({&fake_args[0]}));
  const bool same_data =
      absl::c_equal(fake_args[0].data<float>(), fused.data<float>());
  EXPECT_TRUE(same_data);
}
4458
// Fusion wrapping a bitcast from layout {1,0} to {0,1}. The flat element data
// of the result is expected to equal the flat element data of the argument.
TEST_F(HloEvaluatorTest, PreserveFusionOutputLayout) {
  const absl::string_view hlo_text = R"(
HloModule FusionOutputLayout

fused_computation {
  param_0 = f32[20,20]{1,0} parameter(0)
  ROOT bitcast = f32[20,20]{0,1} bitcast(param_0)
}

ENTRY kernel_entry {
  parameter.0 = f32[20,20]{1,0} parameter(0)
  ROOT fusion = f32[20,20]{0,1} fusion(parameter.0),
      kind=kLoop, calls=fused_computation
})";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  auto fake_args = MakeFakeArguments(m_.get()).value();

  TF_ASSERT_OK_AND_ASSIGN(Literal fused, Evaluate({&fake_args[0]}));
  const bool same_data =
      absl::c_equal(fake_args[0].data<float>(), fused.data<float>());
  EXPECT_TRUE(same_data);
}
4479
// Multi-output-style fusion: the fused computation returns a one-element tuple
// holding a bitcast from layout {1,0} to {0,1}. The flat element data of that
// tuple element is expected to equal the flat element data of the argument.
TEST_F(HloEvaluatorTest, PreserveMOFusionOutputLayout) {
  const absl::string_view hlo_text = R"(
HloModule MOFusionOutputLayout

fused_computation {
  param_0 = f32[20,20]{1,0} parameter(0)
  bitcast = f32[20,20]{0,1} bitcast(param_0)
  ROOT tuple = (f32[20,20]{0,1}) tuple(bitcast)
}

ENTRY kernel_entry {
  parameter.0 = f32[20,20]{1,0} parameter(0)
  ROOT fusion = (f32[20,20]{0,1}) fusion(parameter.0),
      kind=kLoop, calls=fused_computation
})";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  auto args = MakeFakeArguments(m_.get()).value();
  TF_ASSERT_OK_AND_ASSIGN(Literal actual_tuple, Evaluate({&args[0]}));
  std::vector<Literal> actual_literals = actual_tuple.DecomposeTuple();
  // Guard the index below: stop with a test failure instead of reading out of
  // bounds if the tuple does not decompose into exactly one element.
  ASSERT_EQ(actual_literals.size(), 1);
  EXPECT_TRUE(
      absl::c_equal(args[0].data<float>(), actual_literals[0].data<float>()));
}
4503
4504 // Tests that custom_calls fail to evaluate when no handler is specified.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_NoHandler)4505 TEST_F(HloEvaluatorTest, EvaluateCustomCall_NoHandler) {
4506 const absl::string_view hlo_text = R"(
4507 HloModule EvaluateCustomCall_NoHandler
4508 ENTRY kernel_entry {
4509 parameter.0 = u32[2,2]{1,0} parameter(0)
4510 ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
4511 custom_call_target="_my_custom_call"
4512 }
4513 )";
4514
4515 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4516 auto args = MakeFakeArguments(m_.get()).value();
4517 EXPECT_EQ(HloEvaluator().Evaluate(*m_, {&args[0]}).status().code(),
4518 ::tensorflow::error::UNIMPLEMENTED);
4519 }
4520
4521 // Tests when a custom_call handler returns an error.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_HandlerError)4522 TEST_F(HloEvaluatorTest, EvaluateCustomCall_HandlerError) {
4523 const absl::string_view hlo_text = R"(
4524 HloModule EvaluateCustomCall_HandlerError
4525 ENTRY kernel_entry {
4526 parameter.0 = u32[2,2]{1,0} parameter(0)
4527 ROOT test_root = (u32[2,2]{1,0}) custom-call(parameter.0),
4528 custom_call_target="_my_custom_call"
4529 }
4530 )";
4531
4532 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4533 auto args = MakeFakeArguments(m_.get()).value();
4534 HloEvaluator evaluator;
4535 evaluator.set_custom_call_handler(
4536 [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
4537 return InternalError("Test error");
4538 });
4539 EXPECT_EQ(evaluator.Evaluate(*m_, {&args[0]}).status().code(),
4540 ::tensorflow::error::INTERNAL);
4541 }
4542
4543 // Tests the custom_call handler on calls with many inputs.
4544 // We sum the operands so that we can verify the operand and output literals
4545 // are properly mapped for access.
TEST_F(HloEvaluatorTest,EvaluateCustomCall_ManyInputs)4546 TEST_F(HloEvaluatorTest, EvaluateCustomCall_ManyInputs) {
4547 const absl::string_view hlo_text = R"(
4548 HloModule EvaluateCustomCall_ManyInputs
4549 ENTRY kernel_entry {
4550 parameter.0 = u32[1]{0} parameter(0)
4551 parameter.1 = u32[1]{0} parameter(1)
4552 ROOT test_root = u32[1]{0} custom-call(parameter.0, parameter.1),
4553 custom_call_target="_my_custom_call"
4554 }
4555 )";
4556
4557 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4558 auto args = MakeFakeArguments(m_.get()).value();
4559 HloEvaluator evaluator;
4560 evaluator.set_custom_call_handler(
4561 [](HloInstruction* custom_call, absl::Span<const Literal*> operands) {
4562 EXPECT_EQ(HloOpcode::kCustomCall, custom_call->opcode());
4563 EXPECT_EQ("_my_custom_call", custom_call->custom_call_target());
4564 EXPECT_EQ(2, custom_call->operand_count());
4565 EXPECT_EQ(2, operands.size());
4566 auto output = Literal::CreateFromShape(custom_call->shape());
4567 auto operand0_data = operands[0]->data<uint32_t>();
4568 auto operand1_data = operands[1]->data<uint32_t>();
4569 auto output_data = output.data<uint32_t>();
4570 output_data[0] = operand0_data[0] + operand1_data[0];
4571 return output;
4572 });
4573 TF_ASSERT_OK_AND_ASSIGN(
4574 Literal actual_literal,
4575 evaluator.Evaluate(*m_->entry_computation(), {&args[0], &args[1]}));
4576 auto arg0_data = args[0].data<uint32_t>();
4577 auto arg1_data = args[1].data<uint32_t>();
4578 std::vector<uint32_t> expected_data = {arg0_data[0] + arg1_data[0]};
4579 EXPECT_TRUE(absl::c_equal(expected_data, actual_literal.data<uint32_t>()));
4580 }
4581
// is-finite over an f16 constant containing NaN, finite and infinite values.
TEST_F(HloEvaluatorTest, IsFiniteF16) {
  const absl::string_view hlo_text = R"(
HloModule test

ENTRY IsFiniteTest {
  c = f16[6] constant({nan, 7, nan, -1, inf, -inf})
  ROOT is-finite = pred[6] is-finite(c)
})";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal finite_mask,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  // Only 7 and -1 are finite.
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4598
// is-finite over a bf16 constant containing NaN, finite and infinite values.
TEST_F(HloEvaluatorTest, IsFiniteBf16) {
  const absl::string_view hlo_text = R"(
HloModule test

ENTRY IsFiniteTest {
  c = bf16[6] constant({nan, 7, nan, -1, inf, -inf})
  ROOT is-finite = pred[6] is-finite(c)
})";

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal finite_mask,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  // Only 7 and -1 are finite.
  EXPECT_THAT(finite_mask.data<bool>(),
              ::testing::ElementsAre(false, true, false, true, false, false));
}
4615
4616 // Check that evaluating `f32[<huge>, 0] iota` doesn't oom (it's an empty
4617 // array!).
TEST_F(HloEvaluatorTest,ZeroSizedIotaWithHugeDimension)4618 TEST_F(HloEvaluatorTest, ZeroSizedIotaWithHugeDimension) {
4619 const absl::string_view hlo_text = R"(
4620 HloModule test
4621 ENTRY t {
4622 ROOT i = f32[1000000000000, 0] iota(), iota_dimension=0
4623 })";
4624 TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
4625 TF_ASSERT_OK_AND_ASSIGN(
4626 Literal actual_literal,
4627 HloEvaluator().Evaluate(*m_->entry_computation(), {}));
4628 EXPECT_THAT(actual_literal.data<float>(), ::testing::IsEmpty());
4629 }
4630
// A copy-start / copy-done pair applied to a scalar constant must yield that
// same scalar.
TEST_F(HloEvaluatorTest, CopyStartCopyDone) {
  const absl::string_view hlo_text = R"(
HloModule test
ENTRY CopyStartCopyDone {
  init = f32[] constant(42.0)
  copy-start = (f32[]{:S(1)}, f32[], u32[]) copy-start(init)
  ROOT copy-done = f32[] copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal copied,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR0<float>(42.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, copied));
}
4646
// A negate expressed as an async start / update / done chain must evaluate to
// the negated constant.
TEST_F(HloEvaluatorTest, AsyncOps) {
  const absl::string_view hlo_text = R"(
HloModule test
ENTRY AsyncOps {
  init = f32[] constant(42.0)
  async-start = ((f32[]), f32[], u32[]) negate-start(init)
  async-update = ((f32[]), f32[], u32[]) negate-update(async-start)
  ROOT async-done = f32[] negate-done(async-update)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal negated,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR0<float>(-42.0f);
  EXPECT_TRUE(LiteralTestUtil::Equal(want, negated));
}
4663
// Map over a bf16 vector with a computation that doubles each element and
// converts it to f32.
TEST_F(HloEvaluatorTest, MapBF16) {
  const absl::string_view hlo_text = R"(
HloModule test

map_computation {
  p = bf16[] parameter(0)
  add = bf16[] add(p, p)
  ROOT conv = f32[] convert(add)
}

ENTRY CopyStartCopyDone {
  c = bf16[3] constant({1, 2, 3})
  ROOT map = f32[3] map(c), to_apply=map_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  // Each input element doubled.
  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4685
// Map over an s16 vector with a computation that doubles each element and
// converts it to f32.
TEST_F(HloEvaluatorTest, MapS16) {
  const absl::string_view hlo_text = R"(
HloModule test

map_computation {
  p = s16[] parameter(0)
  add = s16[] add(p, p)
  ROOT conv = f32[] convert(add)
}

ENTRY CopyStartCopyDone {
  c = s16[3] constant({1, 2, 3})
  ROOT map = f32[3] map(c), to_apply=map_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  // Each input element doubled.
  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4707
// Map over a u16 vector with a computation that doubles each element and
// converts it to f32.
TEST_F(HloEvaluatorTest, MapU16) {
  const absl::string_view hlo_text = R"(
HloModule test

map_computation {
  p = u16[] parameter(0)
  add = u16[] add(p, p)
  ROOT conv = f32[] convert(add)
}

ENTRY CopyStartCopyDone {
  c = u16[3] constant({1, 2, 3})
  ROOT map = f32[3] map(c), to_apply=map_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  // Each input element doubled.
  Literal want = LiteralUtil::CreateR1<float>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4729
// Map taking two operands of different element types (u16 and f32); the
// mapped computation converts the u16 value to f32 and adds the two.
TEST_F(HloEvaluatorTest, MapMixed) {
  const absl::string_view hlo_text = R"(
HloModule test

map_computation {
  p0 = u16[] parameter(0)
  p1 = f32[] parameter(1)
  c0 = f32[] convert(p0)
  ROOT add = f32[] add(c0, p1)
}

ENTRY CopyStartCopyDone {
  c0 = u16[3] constant({1, 2, 3})
  c1 = f32[3] constant({1.5, 2.5, 3.5})
  ROOT map = f32[3] map(c0, c1), to_apply=map_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal mapped,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  // Element-wise: 1 + 1.5, 2 + 2.5, 3 + 3.5.
  Literal want = LiteralUtil::CreateR1<float>({2.5f, 4.5f, 6.5f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, mapped));
}
4753
// Dot of an s16[4,3] matrix with an s8[3,2] matrix producing an s32[4,2]
// result, i.e. the output element type is wider than either input type.
TEST_F(HloEvaluatorTest, DotUpcast) {
  const absl::string_view hlo_text = R"(
HloModule test
ENTRY DotUpcast {
  l = s16[4,3]{1,0} parameter(0)
  r = s8[3,2]{1,0} parameter(1)
  ROOT result = s32[4,2] dot(l, r), lhs_contracting_dims={1},
                                    rhs_contracting_dims={0}
}
)";
  // lhs:
  // s16[4,3] {
  //  { 1, 2, 3 },
  //  { 5, 6, 7 },
  //  { 9, 10, 11 },
  //  { 13, 14, 15 },
  // }
  Array2D<int16_t> lhs_values(4, 3);
  lhs_values.FillUnique(1);
  auto lhs_literal = LiteralUtil::CreateR2FromArray2D<int16_t>(lhs_values);

  // rhs:
  // s8[3,2] {
  //  { 1, 2 },
  //  { 3, 4 },
  //  { 5, 6 },
  // }
  Array2D<int8_t> rhs_values(3, 2);
  rhs_values.FillUnique(1);
  auto rhs_literal = LiteralUtil::CreateR2FromArray2D<int8_t>(rhs_values);

  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(Literal product,
                          Evaluate({&lhs_literal, &rhs_literal}));

  // Plain matrix product of the two tables above, e.g. 1*1 + 2*3 + 3*5 = 22.
  Array2D<int32_t> want_values({{22, 28}, {58, 76}, {94, 124}, {130, 172}});
  auto want = LiteralUtil::CreateR2FromArray2D<int32_t>(want_values);

  EXPECT_TRUE(LiteralTestUtil::Equal(want, product));
}
4794
// Sort of a c64 vector using a comparator that orders elements by their real
// part (less-than). The input is already in that order.
TEST_F(HloEvaluatorTest, SortC64) {
  const absl::string_view hlo_text = R"(
HloModule m

sort_lt_comparator {
  parameter.0 = c64[] parameter(0)
  real.0 = f32[] real(parameter.0)
  parameter.1 = c64[] parameter(1)
  real.1 = f32[] real(parameter.1)
  ROOT compare = pred[] compare(real.0, real.1), direction=LT
}

ENTRY main {
  c = c64[3] constant({(2, 0), (4, 0), (6, 0)})
  ROOT sort = c64[3]{0} sort(c), dimensions={0}, to_apply=sort_lt_comparator
}
)";
  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(Literal sorted,
                          evaluator.Evaluate(*m_->entry_computation(), {}));

  Literal want = LiteralUtil::CreateR1<std::complex<float>>({2.f, 4.f, 6.f});
  EXPECT_TRUE(LiteralTestUtil::Equal(want, sorted));
}
4819
4820 // Tests that HloEvaluator can evaluate an instruction even when its operands
4821 // are not constant.
TEST_F(HloEvaluatorTest,RecursivelyEvaluateNonConstantOperands)4822 TEST_F(HloEvaluatorTest, RecursivelyEvaluateNonConstantOperands) {
4823 Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4824 Literal c1_literal = LiteralUtil::CreateR2<float>({{0.f, 5.f}, {0.f, 4.f}});
4825 Literal c2_literal = LiteralUtil::CreateR2<float>({{2.f, 4.f}, {4.f, 4.f}});
4826
4827 Shape shape = c0_literal.shape();
4828 HloComputation::Builder b(TestName());
4829 HloInstruction* c0 =
4830 b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4831 HloInstruction* c1 =
4832 b.AddInstruction(HloInstruction::CreateConstant(std::move(c1_literal)));
4833 HloInstruction* c2 =
4834 b.AddInstruction(HloInstruction::CreateConstant(std::move(c2_literal)));
4835
4836 HloInstruction* add0 = b.AddInstruction(
4837 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c0, c1));
4838 HloInstruction* add1 = b.AddInstruction(
4839 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2));
4840 HloInstruction* add2 = b.AddInstruction(
4841 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, add1));
4842
4843 m_->AddEntryComputation(b.Build());
4844 Literal expected = LiteralUtil::CreateR2<float>({{2, 16}, {6, 16}});
4845 TestRecursivelyEvaluateInstruction(add2, expected);
4846 }
4847
4848 // Tests that HloEvaluator can evaluate a GetTupleElement even when its operand
4849 // Tuple instruction cannot be fully evaluated. Note that this requires that the
4850 // tuple element at the given tuple index can be evaluated.
TEST_F(HloEvaluatorTest,GetTupleElementOnPartiallyKnownTupleSucceeds)4851 TEST_F(HloEvaluatorTest, GetTupleElementOnPartiallyKnownTupleSucceeds) {
4852 Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4853
4854 Shape shape = c0_literal.shape();
4855 HloComputation::Builder b(TestName());
4856 HloInstruction* c0 =
4857 b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4858 HloInstruction* p0 =
4859 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4860 HloInstruction* p1 =
4861 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4862
4863 HloInstruction* tuple =
4864 b.AddInstruction(HloInstruction::CreateTuple({p0, p1, c0}));
4865 HloInstruction* gte =
4866 b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple, 2));
4867
4868 m_->AddEntryComputation(b.Build());
4869 Literal expected = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4870 TestRecursivelyEvaluateInstruction(gte, expected);
4871 }
4872
4873 // Tests that Infeed cannot be evaluated.
TEST_F(HloEvaluatorTest,InfeedFailure)4874 TEST_F(HloEvaluatorTest, InfeedFailure) {
4875 HloComputation::Builder b(TestName());
4876 HloInstruction* token = b.AddInstruction(HloInstruction::CreateToken());
4877 HloInstruction* infeed = b.AddInstruction(HloInstruction::CreateInfeed(
4878 ShapeUtil::MakeShape(F32, {4, 4}), token, ""));
4879
4880 m_->AddEntryComputation(b.Build());
4881 TestRecursiveEvaluationFailure(infeed);
4882 }
4883
4884 // Tests that GetTupleElement cannot be evaluated if the corresponding tuple
4885 // element cannot be evaluated.
TEST_F(HloEvaluatorTest,GetUnknownTupleElementFails)4886 TEST_F(HloEvaluatorTest, GetUnknownTupleElementFails) {
4887 Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4888
4889 Shape shape = c0_literal.shape();
4890 HloComputation::Builder b(TestName());
4891 HloInstruction* c0 =
4892 b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4893 HloInstruction* p0 =
4894 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4895 HloInstruction* p1 =
4896 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4897
4898 HloInstruction* tuple =
4899 b.AddInstruction(HloInstruction::CreateTuple({p0, p1, c0}));
4900 HloInstruction* gte =
4901 b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple, 0));
4902
4903 m_->AddEntryComputation(b.Build());
4904 TestRecursiveEvaluationFailure(gte);
4905 }
4906
4907 // Tests that partial evaluation works for nested tuples.
TEST_F(HloEvaluatorTest,GetTupleElementFromNestedTupleSucceeds)4908 TEST_F(HloEvaluatorTest, GetTupleElementFromNestedTupleSucceeds) {
4909 Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4910
4911 Shape shape = c0_literal.shape();
4912 HloComputation::Builder b(TestName());
4913 HloInstruction* c0 =
4914 b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4915 HloInstruction* p0 =
4916 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4917 HloInstruction* p1 =
4918 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4919
4920 HloInstruction* tuple0 =
4921 b.AddInstruction(HloInstruction::CreateTuple({p0, c0}));
4922 HloInstruction* tuple1 =
4923 b.AddInstruction(HloInstruction::CreateTuple({tuple0, p1}));
4924 HloInstruction* gte0 =
4925 b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple1, 0));
4926 HloInstruction* gte1 =
4927 b.AddInstruction(HloInstruction::CreateGetTupleElement(gte0, 1));
4928
4929 m_->AddEntryComputation(b.Build());
4930 Literal expected = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4931 TestRecursivelyEvaluateInstruction(gte1, expected);
4932 }
4933
4934 // Tests that partial evaluation works when the GetTupleElement is interleaved
4935 // with other Tuple instructions.
TEST_F(HloEvaluatorTest,GetTupleElementInterleavedWithTupleSucceeds)4936 TEST_F(HloEvaluatorTest, GetTupleElementInterleavedWithTupleSucceeds) {
4937 Literal c0_literal = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4938
4939 Shape shape = c0_literal.shape();
4940 HloComputation::Builder b(TestName());
4941 HloInstruction* c0 =
4942 b.AddInstruction(HloInstruction::CreateConstant(std::move(c0_literal)));
4943 HloInstruction* p0 =
4944 b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param.0"));
4945 HloInstruction* p1 =
4946 b.AddInstruction(HloInstruction::CreateParameter(1, shape, "param.1"));
4947 HloInstruction* p2 =
4948 b.AddInstruction(HloInstruction::CreateParameter(2, shape, "param.2"));
4949
4950 HloInstruction* tuple0 =
4951 b.AddInstruction(HloInstruction::CreateTuple({p0, c0}));
4952 HloInstruction* tuple1 =
4953 b.AddInstruction(HloInstruction::CreateTuple({tuple0, p1}));
4954 HloInstruction* gte0 =
4955 b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple1, 0));
4956 HloInstruction* tuple2 =
4957 b.AddInstruction(HloInstruction::CreateTuple({gte0, p2}));
4958 HloInstruction* gte1 =
4959 b.AddInstruction(HloInstruction::CreateGetTupleElement(tuple2, 0));
4960 HloInstruction* gte2 =
4961 b.AddInstruction(HloInstruction::CreateGetTupleElement(gte1, 1));
4962
4963 m_->AddEntryComputation(b.Build());
4964 Literal expected = LiteralUtil::CreateR2<float>({{0.f, 2.f}, {2.f, 4.f}});
4965 TestRecursivelyEvaluateInstruction(gte2, expected);
4966 }
4967
// Fixture for the PatternMatchParseWhileLoop tests below. It adds nothing to
// HloTestBase; it only gives those tests their own suite name.
class PatternMatchParseWhileLoopTest : public HloTestBase {};
4969
// The loop bound is a constant defined inside the condition computation. The
// loop counts from 0 up to (but excluding) 5 in steps of 1.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) {
  constexpr absl::string_view kModuleText = R"(
HloModule accumulated_all_reduce

%while_condition {
%param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%loop_bound = s32[] constant(5)
ROOT result = pred[] compare(%gte.0, %loop_bound), direction=LT
}

%while_body {
%param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.1 = f32[1024, 1024] parameter(0)
%constant.0 = s32[] constant(0)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
ROOT %result = f32[1024, 1024] get-tuple-element((s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=2
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root is a get-tuple-element whose operand is the while loop.
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* while_instr = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed.has_value());
  EXPECT_FALSE(parsed->is_dynamic());

  const auto& static_loop = parsed->static_while_loop;
  EXPECT_EQ(static_loop->trip_count, 5);
  EXPECT_EQ(static_loop->induction_var_index, 0);
  EXPECT_EQ(static_loop->induction_var_init_value, 0);
  EXPECT_EQ(static_loop->step_size, 1);
  EXPECT_EQ(static_loop->loop_bound, 5);
}
5016
// The loop bound is a constant (10) defined in the entry computation and
// handed to the loop through element 1 of the loop state tuple.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedOutsideOfCond) {
  constexpr absl::string_view kModuleText = R"(
HloModule accumulated_all_reduce

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.1 = f32[1024, 1024] parameter(0)
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(10)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %constant.1, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root is a get-tuple-element whose operand is the while loop.
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* while_instr = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed.has_value());
  EXPECT_FALSE(parsed->is_dynamic());

  const auto& static_loop = parsed->static_while_loop;
  EXPECT_EQ(static_loop->trip_count, 10);
  EXPECT_EQ(static_loop->induction_var_index, 0);
  EXPECT_EQ(static_loop->induction_var_init_value, 0);
  EXPECT_EQ(static_loop->step_size, 1);
  EXPECT_EQ(static_loop->loop_bound, 10);
}
5065
// The loop bound is not a literal constant: the entry computation computes it
// as 10 * 4 and passes the product through the loop state tuple.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundComputedOutsideOfCond) {
  constexpr absl::string_view kModuleText = R"(
HloModule accumulated_all_reduce

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.1 = f32[1024, 1024] parameter(0)
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(10)
%constant.2 = s32[] constant(4)
%loop_bound = s32[] multiply(s32[] %constant.1, s32[] %constant.2)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root is a get-tuple-element whose operand is the while loop.
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* while_instr = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed.has_value());
  EXPECT_FALSE(parsed->is_dynamic());

  const auto& static_loop = parsed->static_while_loop;
  EXPECT_EQ(static_loop->trip_count, 40);
  EXPECT_EQ(static_loop->induction_var_index, 0);
  EXPECT_EQ(static_loop->induction_var_init_value, 0);
  EXPECT_EQ(static_loop->step_size, 1);
  EXPECT_EQ(static_loop->loop_bound, 40);
}
5116
// The induction variable advances by 4 per iteration towards a bound of
// 10 * 4 = 40, so the expected trip count is 10 rather than the bound itself.
TEST_F(PatternMatchParseWhileLoopTest, StepSizeNotOne) {
  constexpr absl::string_view kModuleText = R"(
HloModule accumulated_all_reduce

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
%constant = s32[] constant(4)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.1 = f32[1024, 1024] parameter(0)
%constant.0 = s32[] constant(0)
%constant.1 = s32[] constant(10)
%constant.2 = s32[] constant(4)
%loop_bound = s32[] multiply(s32[] %constant.1, s32[] %constant.2)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root is a get-tuple-element whose operand is the while loop.
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* while_instr = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed.has_value());
  EXPECT_FALSE(parsed->is_dynamic());

  const auto& static_loop = parsed->static_while_loop;
  EXPECT_EQ(static_loop->trip_count, 10);
  EXPECT_EQ(static_loop->induction_var_index, 0);
  EXPECT_EQ(static_loop->induction_var_init_value, 0);
  EXPECT_EQ(static_loop->step_size, 4);
  EXPECT_EQ(static_loop->loop_bound, 40);
}
5167
5168 // The loop condition comparison is computed by a call to another computation.
TEST_F(PatternMatchParseWhileLoopTest,RecursiveCond)5169 TEST_F(PatternMatchParseWhileLoopTest, RecursiveCond) {
5170 constexpr absl::string_view kHloModule = R"(
5171 HloModule accumulated_all_reduce
5172
5173 %compute_pred {
5174 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5175 %gte.0 = s32[] get-tuple-element(%param), index=0
5176 %gte.1 = s32[] get-tuple-element(%param), index=1
5177 %compare = pred[] compare(gte.0, %gte.1), direction=LT
5178 ROOT %tuple = (pred[]) tuple(pred[] %compare)
5179 }
5180
5181 %while_condition {
5182 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5183 %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred
5184 ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
5185 }
5186
5187 %while_body {
5188 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5189 %gte.0 = s32[] get-tuple-element(%param), index=0
5190 %gte.1 = s32[] get-tuple-element(%param), index=1
5191 %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
5192 %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
5193 %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
5194 %constant = s32[] constant(1)
5195 %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
5196 ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
5197 }
5198
5199 ENTRY accumulated_all_reduce {
5200 %param.1 = f32[1024, 1024] parameter(0)
5201 %constant.0 = s32[] constant(0)
5202 %loop_bound = s32[] constant(10)
5203 %accumulation_buffer_init = f32[] constant(0)
5204 %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
5205 %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
5206 %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
5207 ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
5208 }
5209 )";
5210 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
5211 ParseAndReturnVerifiedModule(kHloModule));
5212 HloInstruction* while_op =
5213 hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
5214 std::optional<ParsedWhileLoop> parsed_while_loop =
5215 PatternMatchParseWhileLoop(while_op);
5216 ASSERT_TRUE(parsed_while_loop.has_value());
5217 EXPECT_FALSE(parsed_while_loop->is_dynamic());
5218 EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 10);
5219 EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0);
5220 EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0);
5221 EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1);
5222 EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 10);
5223 }
5224
5225 // The loop condition comparison is computed by a call to another computation.
5226 // The called computation could be calling another computation and could use
5227 // get-tuple-element to extract the result.
TEST_F(PatternMatchParseWhileLoopTest,RecursiveCondGetTupleElement)5228 TEST_F(PatternMatchParseWhileLoopTest, RecursiveCondGetTupleElement) {
5229 constexpr absl::string_view kHloModule = R"(
5230 HloModule accumulated_all_reduce
5231
5232 %compute_pred {
5233 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5234 %gte.0 = s32[] get-tuple-element(%param), index=0
5235 %gte.1 = s32[] get-tuple-element(%param), index=1
5236 %compare = pred[] compare(gte.0, %gte.1), direction=LT
5237 ROOT %tuple = (pred[]) tuple(pred[] %compare)
5238 }
5239
5240 %get_tuple_element {
5241 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5242 %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred
5243 %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
5244 ROOT %tuple.1 = (pred[]) tuple(pred[] %gte.4)
5245 }
5246 %while_condition {
5247 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5248 %call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%get_tuple_element
5249 ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
5250 }
5251
5252 %while_body {
5253 %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
5254 %gte.0 = s32[] get-tuple-element(%param), index=0
5255 %gte.1 = s32[] get-tuple-element(%param), index=1
5256 %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
5257 %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
5258 %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
5259 %constant = s32[] constant(1)
5260 %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
5261 ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
5262 }
5263
5264 ENTRY accumulated_all_reduce {
5265 %param.1 = f32[1024, 1024] parameter(0)
5266 %constant.0 = s32[] constant(0)
5267 %loop_bound = s32[] constant(10)
5268 %accumulation_buffer_init = f32[] constant(0)
5269 %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
5270 %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
5271 %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
5272 ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
5273 }
5274 )";
5275 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
5276 ParseAndReturnVerifiedModule(kHloModule));
5277 HloInstruction* while_op =
5278 hlo_module->entry_computation()->root_instruction()->mutable_operand(0);
5279 std::optional<ParsedWhileLoop> parsed_while_loop =
5280 PatternMatchParseWhileLoop(while_op);
5281 ASSERT_TRUE(parsed_while_loop.has_value());
5282 EXPECT_FALSE(parsed_while_loop->is_dynamic());
5283 EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 10);
5284 EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0);
5285 EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0);
5286 EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1);
5287 EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 10);
5288 }
5289
// Two loops run back to back. The loop under test (%while.1) takes its bound
// from element 0 of the first loop's result, i.e. from the first loop's
// induction variable after that loop has finished. The first loop counts
// from 0 to 10 in steps of 1, so the second loop is expected to parse as a
// static loop with bound 10.
TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDependsOnAnotherLoop) {
  constexpr absl::string_view kModuleText = R"(
HloModule accumulated_all_reduce

%compute_pred.0 {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%compare = pred[] compare(gte.0, %gte.1), direction=LT
ROOT %tuple = (pred[]) tuple(pred[] %compare)
}

%while_condition.0 {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred.0
ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
}

%while_body.0 {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

%compute_pred.1 {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%compare = pred[] compare(gte.0, %gte.1), direction=LT
ROOT %tuple = (pred[]) tuple(pred[] %compare)
}

%while_condition.1 {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%call = (pred[]) call((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %param), to_apply=%compute_pred.1
ROOT %gte.4 = pred[] get-tuple-element((pred[]) %call), index=0
}

%while_body.1 {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.1 = f32[1024, 1024] parameter(0)
%param.2 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(0)
%loop_bound = s32[] constant(10)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition.0, body=%while_body.0
%result.0 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.0), index=3
%new_loop_bound = s32[] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.0), index=0
%while_init.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %new_loop_bound, f32[1024, 1024] %param.2, f32[1024, 1024] %result.0)
%while.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.1), condition=%while_condition.1, body=%while_body.1
ROOT %result.1 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.1), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root reads from %while.1, the second of the two loops.
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* while_instr = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed.has_value());
  EXPECT_FALSE(parsed->is_dynamic());

  const auto& static_loop = parsed->static_while_loop;
  EXPECT_EQ(static_loop->trip_count, 10);
  EXPECT_EQ(static_loop->induction_var_index, 0);
  EXPECT_EQ(static_loop->induction_var_init_value, 0);
  EXPECT_EQ(static_loop->step_size, 1);
  EXPECT_EQ(static_loop->loop_bound, 10);
}
5376
// The induction variable's initial value comes from an entry parameter
// (%param.2) instead of a constant. The loop is expected to be recognized,
// but reported as dynamic.
TEST_F(PatternMatchParseWhileLoopTest, DynamicLoop) {
  constexpr absl::string_view kModuleText = R"(
HloModule accumulated_all_reduce

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.2, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.1 = f32[1024, 1024] parameter(0)
%param.2 = s32[] parameter(1)
%loop_bound = s32[] constant(10)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %param.2, s32[] %loop_bound, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
ROOT %result = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The entry root is a get-tuple-element whose operand is the while loop.
  HloInstruction* root = module->entry_computation()->root_instruction();
  HloInstruction* while_instr = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed =
      PatternMatchParseWhileLoop(while_instr);
  ASSERT_TRUE(parsed.has_value());
  EXPECT_TRUE(parsed->is_dynamic());
}
5420
// The loop condition is a pred[] element of the loop state tuple that the
// condition computation returns directly; the body resets it to false.
TEST_F(PatternMatchParseWhileLoopTest, BooleanCond) {
  // Loop state tuple, by index:
  //   0: pred[] flag returned as-is by the condition; starts as true and is
  //      overwritten with constant false by the body
  //   1: f32[1024, 1024] value forwarded unchanged on every iteration
  //   2: f32[1024, 1024] accumulator, replaced by element 1 + element 2
  constexpr absl::string_view kHloText = R"(
    HloModule accumulated_all_reduce
    %while_condition {
      %param = (pred[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      ROOT %gte.0 = pred[] get-tuple-element(%param), index=0
    }

    %while_body {
      %param = (pred[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
      %gte.0 = pred[] get-tuple-element(%param), index=0
      %gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2)
      %new_loop_cond = pred[] constant(false)
      ROOT %loop_result = (pred[], f32[1024, 1024], f32[1024, 1024]) tuple(%new_loop_cond, %gte.1, %accumulation)
    }

    ENTRY accumulated_all_reduce {
      %param.1 = f32[1024, 1024] parameter(0)
      %constant.0 = pred[] constant(true)
      %accumulation_buffer_init = f32[] constant(0)
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
      %while_init = (pred[], f32[1024, 1024], f32[1024, 1024]) tuple(pred[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
      %while = (pred[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
      ROOT %result = f32[1024, 1024] get-tuple-element((pred[], f32[1024, 1024], f32[1024, 1024]) %while), index=2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloText));

  // The entry root is a get-tuple-element whose only operand is the while op.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();
  HloInstruction* loop = root->mutable_operand(0);

  std::optional<ParsedWhileLoop> parsed = PatternMatchParseWhileLoop(loop);
  // Hard-assert presence first so the dereferences below are safe.
  ASSERT_TRUE(parsed.has_value());
  EXPECT_FALSE(parsed->is_dynamic());

  // The test expects a static description of a single-iteration loop keyed on
  // tuple element 0.
  const auto& static_loop = *parsed->static_while_loop;
  EXPECT_EQ(static_loop.trip_count, 1);
  EXPECT_EQ(static_loop.induction_var_index, 0);
  EXPECT_EQ(static_loop.induction_var_init_value, 0);
  EXPECT_EQ(static_loop.step_size, 1);
  EXPECT_EQ(static_loop.loop_bound, 1);
}
5464
5465 } // namespace
5466 } // namespace xla
5467