1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17 #define EIGEN_USE_CUSTOM_THREAD_POOL
18
19 #include "absl/strings/str_split.h"
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/core/platform/regexp.h"
25 #include "tensorflow/core/platform/test.h"
26
27 // The header files for the tests using mlir_bridge have the _mlir_bridge suffix
28 // inherited from the tf_library target names.
29 #if defined(ENABLE_MLIR_BRIDGE_TEST)
30 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
31 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
32 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
33 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
34 #include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
35 #include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
36 #include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
37 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
38 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
39 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
40 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
41 #include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
42 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
43 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
44 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
45 // Similarly, there are files for testing the MLIR based lowering of HLO to
46 // object code for XLA:CPU
47 #elif defined(MHLO_LOWERING_TEST)
48 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_mhlo_lowering.h"
49 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mhlo_lowering.h"
50 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mhlo_lowering.h"
51 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mhlo_lowering.h"
52 #include "tensorflow/compiler/aot/tests/test_graph_tfcond_mhlo_lowering.h"
53 #include "tensorflow/compiler/aot/tests/test_graph_tffunction_mhlo_lowering.h"
54 #include "tensorflow/compiler/aot/tests/test_graph_tfgather_mhlo_lowering.h"
55 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mhlo_lowering.h"
56 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mhlo_lowering.h"
57 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mhlo_lowering.h"
58 #include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mhlo_lowering.h"
59 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mhlo_lowering.h"
60 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mhlo_lowering.h"
61 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mhlo_lowering.h"
62 #else
63 #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
64 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
65 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
66 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
67 #include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
68 #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
69 #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
70 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
71 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
72 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
73 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
74 #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
75 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
76 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
77 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
78 #endif
79
80 namespace tensorflow {
81 namespace tfcompile {
82 namespace {
83
84 using ::testing::HasSubstr;
85 using ::testing::IsSupersetOf;
86
// Exercises the generated AddComp accessors: scalar argN()/result0()
// references, raw argN_data()/result0_data() pointers, and their const
// overloads, across two consecutive Run() calls on the same instance.
TEST(TFCompileTest, Add) {
  AddComp add;
  EXPECT_EQ(add.arg0_data(), add.arg_data(0));
  EXPECT_EQ(add.arg1_data(), add.arg_data(1));

  // First run: write the inputs through the scalar reference accessors.
  add.arg0() = 1;
  add.arg1() = 2;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 3);
  EXPECT_EQ(add.result0_data()[0], 3);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // Second run: write the inputs through the raw data pointers instead.
  add.arg0_data()[0] = 123;
  add.arg1_data()[0] = 456;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 579);
  EXPECT_EQ(add.result0_data()[0], 579);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // The const overloads must observe the same buffers and values. Every
  // comparison goes through `add_const` so that the const overloads of
  // arg_data() and results() are exercised as well.
  const AddComp& add_const = add;
  EXPECT_EQ(add_const.error_msg(), "");
  EXPECT_EQ(add_const.arg0(), 123);
  EXPECT_EQ(add_const.arg0_data()[0], 123);
  EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
  EXPECT_EQ(add_const.arg1(), 456);
  EXPECT_EQ(add_const.arg1_data()[0], 456);
  EXPECT_EQ(add_const.arg1_data(), add_const.arg_data(1));
  EXPECT_EQ(add_const.result0(), 579);
  EXPECT_EQ(add_const.result0_data()[0], 579);
  EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
120
121 // Run tests that use set_argN_data separately, to avoid accidentally re-using
122 // non-existent buffers.
// Supplies caller-owned argument buffers through set_argN_data(). As its name
// says, the RESULTS_PROFILES_AND_TEMPS_ONLY mode does not allocate the
// arguments itself, so they must be provided before Run().
TEST(TFCompileTest, Add_SetArg) {
  AddComp add(
      XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);

  int32 lhs = 10;
  int32 rhs = 32;
  add.set_arg0_data(&lhs);
  add.set_arg1_data(&rhs);
  // The positional and generic accessors must agree on the injected buffers.
  EXPECT_EQ(add.arg0_data(), add.arg_data(0));
  EXPECT_EQ(add.arg1_data(), add.arg_data(1));

  constexpr int kExpectedSum = 42;  // 10 + 32
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), kExpectedSum);
  EXPECT_EQ(add.result0_data()[0], kExpectedSum);
  EXPECT_EQ(add.result0_data(), add.results()[0]);
}
140
// Adds a constant to the single argument. The expectations imply the constant
// is 42 (43 - 1 == 153 - 111); per the test-graph name it presumably comes
// from a checkpoint.
TEST(TFCompileTest, AddWithCkpt) {
  AddWithCkptComp add;
  EXPECT_EQ(add.arg0_data(), add.arg_data(0));

  // Verifies the outcome of the most recent Run() through every result
  // accessor.
  const auto expect_result = [&add](int expected) {
    EXPECT_EQ(add.error_msg(), "");
    EXPECT_EQ(add.result0(), expected);
    EXPECT_EQ(add.result0_data()[0], expected);
    EXPECT_EQ(add.result0_data(), add.results()[0]);
  };

  // Input via the scalar reference accessor.
  add.arg0() = 1;
  EXPECT_TRUE(add.Run());
  expect_result(43);

  // Input via the raw data pointer.
  add.arg0_data()[0] = 111;
  EXPECT_TRUE(add.Run());
  expect_result(153);

  // Read everything back through the const interface.
  const AddWithCkptComp& add_const = add;
  EXPECT_EQ(add_const.error_msg(), "");
  EXPECT_EQ(add_const.arg0(), 111);
  EXPECT_EQ(add_const.arg0_data()[0], 111);
  EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
  EXPECT_EQ(add_const.result0(), 153);
  EXPECT_EQ(add_const.result0_data()[0], 153);
  EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
168
// Same contract as AddWithCkpt, but for the graph variant built with a saver:
// each run yields the argument plus 42.
TEST(TFCompileTest, AddWithCkptSaver) {
  AddWithCkptSaverComp comp;
  EXPECT_EQ(comp.arg0_data(), comp.arg_data(0));

  // Runs the computation and checks result0 through all of its accessors.
  const auto run_and_check = [&comp](int want) {
    EXPECT_TRUE(comp.Run());
    EXPECT_EQ(comp.error_msg(), "");
    EXPECT_EQ(comp.result0(), want);
    EXPECT_EQ(comp.result0_data()[0], want);
    EXPECT_EQ(comp.result0_data(), comp.results()[0]);
  };

  comp.arg0() = 1;
  run_and_check(43);

  comp.arg0_data()[0] = 111;
  run_and_check(153);

  // The const accessors must report the state left by the last run.
  const AddWithCkptSaverComp& view = comp;
  EXPECT_EQ(view.error_msg(), "");
  EXPECT_EQ(view.arg0(), 111);
  EXPECT_EQ(view.arg0_data()[0], 111);
  EXPECT_EQ(view.arg0_data(), view.arg_data(0));
  EXPECT_EQ(view.result0(), 153);
  EXPECT_EQ(view.result0_data()[0], 153);
  EXPECT_EQ(view.result0_data(), view.results()[0]);
}
196
// arg0 is a predicate that selects arg1 (when true) or arg2 (when false) as
// the single result.
TEST(TFCompileTest, Cond) {
  CondComp cond;
  EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
  EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
  EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
  cond.arg1() = 10;
  cond.arg2() = 20;

  // Take the "then" branch first, then the "else" branch.
  for (const bool predicate : {true, false}) {
    cond.arg0() = predicate;
    const int32 want = predicate ? cond.arg1() : cond.arg2();
    EXPECT_TRUE(cond.Run());
    EXPECT_EQ(cond.result0(), want);
    EXPECT_EQ(cond.result0_data()[0], want);
    EXPECT_EQ(cond.result0_data(), cond.results()[0]);
  }
}
221
// Gathers elements of a 4-element params vector (arg0) at the positions given
// by a 2-element index vector (arg1), then re-reads everything through the
// const accessors.
TEST(TFCompileTest, Gather) {
  GatherComp gather;
  EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
  EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));

  // Successful gather.
  {
    const float params[4] = {1, 2, 3, 4};
    std::copy(params + 0, params + 4, gather.arg0_data());
    const int32 indices[2] = {1, 3};
    std::copy(indices + 0, indices + 2, gather.arg1_data());
    EXPECT_TRUE(gather.Run());
    EXPECT_EQ(gather.error_msg(), "");
    const float results[2] = {2, 4};
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ(gather.result0(i), results[i]);
      EXPECT_EQ(gather.result0_data()[i], results[i]);
    }
    EXPECT_EQ(gather.result0_data(), gather.results()[0]);

    // Const interface. All pointer comparisons go through `gather_const`
    // (including results(), which previously used the non-const object) so
    // that the const overloads are what is being checked.
    const GatherComp& gather_const = gather;
    EXPECT_EQ(gather_const.error_msg(), "");
    for (int i = 0; i < 4; ++i) {
      EXPECT_EQ(gather_const.arg0(i), params[i]);
      EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
    }
    EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ(gather_const.arg1(i), indices[i]);
      EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
    }
    EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ(gather_const.result0(i), results[i]);
      EXPECT_EQ(gather_const.result0_data()[i], results[i]);
    }
    EXPECT_EQ(gather_const.result0_data(), gather_const.results()[0]);
  }
}
261
// Multiplies a 2x3 matrix (arg0) by a 3x2 matrix (arg1) on a two-thread Eigen
// pool. Inputs are supplied first via the element accessors, then via the raw
// data pointers, whose row-major layout is checked against the accessors.
TEST(TFCompileTest, MatMul2) {
  Eigen::ThreadPool pool(2);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  foo::bar::MatMulComp matmul;
  matmul.set_thread_pool(&device);
  EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
  EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));

  // Element accessors: arg0 = [[1,2,3],[4,5,6]], arg1 = [[7,8],[9,10],[11,12]].
  {
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 3; ++c) {
        matmul.arg0(r, c) = 1 + 3 * r + c;
      }
    }
    for (int r = 0; r < 3; ++r) {
      for (int c = 0; c < 2; ++c) {
        matmul.arg1(r, c) = 7 + 2 * r + c;
      }
    }

    EXPECT_TRUE(matmul.Run());
    EXPECT_EQ(matmul.error_msg(), "");
    const float expected[2][2] = {{58, 64}, {139, 154}};
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        EXPECT_EQ(matmul.result0(r, c), expected[r][c]);
        EXPECT_EQ(matmul.result0_data()[2 * r + c], expected[r][c]);
      }
    }
    EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
  }

  // Raw data pointers: the same matrices scaled by 10.
  {
    const float args[12] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120};
    std::copy(args, args + 6, matmul.arg0_data());
    std::copy(args + 6, args + 12, matmul.arg1_data());
    EXPECT_TRUE(matmul.Run());
    EXPECT_EQ(matmul.error_msg(), "");
    const float expected[2][2] = {{5800, 6400}, {13900, 15400}};
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        EXPECT_EQ(matmul.result0(r, c), expected[r][c]);
        EXPECT_EQ(matmul.result0_data()[2 * r + c], expected[r][c]);
      }
    }
    EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);

    // Const accessors see the same data in row-major order.
    const foo::bar::MatMulComp& view = matmul;
    EXPECT_EQ(view.error_msg(), "");
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 3; ++c) {
        EXPECT_EQ(view.arg0(r, c), args[3 * r + c]);
        EXPECT_EQ(view.arg0_data()[3 * r + c], args[3 * r + c]);
      }
    }
    EXPECT_EQ(view.arg0_data(), matmul.arg_data(0));
    for (int r = 0; r < 3; ++r) {
      for (int c = 0; c < 2; ++c) {
        EXPECT_EQ(view.arg1(r, c), args[6 + 2 * r + c]);
        EXPECT_EQ(view.arg1_data()[2 * r + c], args[6 + 2 * r + c]);
      }
    }
    EXPECT_EQ(view.arg1_data(), matmul.arg_data(1));
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        EXPECT_EQ(view.result0(r, c), expected[r][c]);
        EXPECT_EQ(view.result0_data()[2 * r + c], expected[r][c]);
      }
    }
    EXPECT_EQ(view.result0_data(), matmul.results()[0]);
  }
}
330
331 // Run tests that use set_argN_data separately, to avoid accidentally re-using
332 // non-existent buffers.
// Multiplies caller-owned 2x3 and 3x2 matrices that are handed to the
// computation through set_argN_data(). The allocation mode's name says only
// results, profiles and temps are allocated internally.
TEST(TFCompileTest, MatMul2_SetArg) {
  Eigen::ThreadPool pool(2);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  foo::bar::MatMulComp matmul(
      XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
  matmul.set_thread_pool(&device);

  float lhs[2][3] = {{1, 2, 3}, {4, 5, 6}};
  float rhs[3][2] = {{7, 8}, {9, 10}, {11, 12}};
  matmul.set_arg0_data(&lhs);
  matmul.set_arg1_data(&rhs);
  EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
  EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));

  EXPECT_TRUE(matmul.Run());
  EXPECT_EQ(matmul.error_msg(), "");
  const float expected[2][2] = {{58, 64}, {139, 154}};
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 2; ++c) {
      EXPECT_EQ(matmul.result0(r, c), expected[r][c]);
      EXPECT_EQ(matmul.result0_data()[2 * r + c], expected[r][c]);
    }
  }
  EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
}
358
// Computes both the product (result0 / x_y_prod) and the sum (result1 /
// x_y_sum) of two 2x2 matrices. The test checks the positional accessors and
// the accessors named after the config's feeds and fetches, in both mutable
// and const form.
TEST(TFCompileTest, MatMulAndAdd1) {
  Eigen::ThreadPool pool(1);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  ::foo::bar::MatMulAndAddComp muladd;
  muladd.set_thread_pool(&device);
  EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
  EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));

  // Positional accessors.
  {
    const float inputs[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    std::copy(inputs, inputs + 4, muladd.arg0_data());
    std::copy(inputs + 4, inputs + 8, muladd.arg1_data());
    EXPECT_TRUE(muladd.Run());
    EXPECT_EQ(muladd.error_msg(), "");
    const float product[4] = {19, 22, 43, 50};
    const float sum[4] = {6, 8, 10, 12};
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        const int k = 2 * r + c;
        EXPECT_EQ(muladd.result0(r, c), product[k]);
        EXPECT_EQ(muladd.result0_data()[k], product[k]);
        EXPECT_EQ(muladd.result1(r, c), sum[k]);
        EXPECT_EQ(muladd.result1_data()[k], sum[k]);
      }
    }
    EXPECT_EQ(muladd.result0_data(), muladd.results()[0]);
    EXPECT_EQ(muladd.result1_data(), muladd.results()[1]);

    const ::foo::bar::MatMulAndAddComp& view = muladd;
    EXPECT_EQ(view.error_msg(), "");
    for (int k = 0; k < 4; ++k) {
      EXPECT_EQ(view.arg0(k / 2, k % 2), inputs[k]);
      EXPECT_EQ(view.arg0_data()[k], inputs[k]);
    }
    EXPECT_EQ(view.arg0_data(), muladd.arg_data(0));
    for (int k = 0; k < 4; ++k) {
      EXPECT_EQ(view.arg1(k / 2, k % 2), inputs[k + 4]);
      EXPECT_EQ(view.arg1_data()[k], inputs[k + 4]);
    }
    EXPECT_EQ(view.arg1_data(), muladd.arg_data(1));
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        const int k = 2 * r + c;
        EXPECT_EQ(view.result0(r, c), product[k]);
        EXPECT_EQ(view.result0_data()[k], product[k]);
        EXPECT_EQ(view.result1(r, c), sum[k]);
        EXPECT_EQ(view.result1_data()[k], sum[k]);
      }
    }
    EXPECT_EQ(view.result0_data(), muladd.results()[0]);
    EXPECT_EQ(view.result1_data(), muladd.results()[1]);
  }

  // Named accessors.
  {
    const float inputs[8] = {10, 20, 30, 40, 50, 60, 70, 80};
    std::copy(inputs, inputs + 4, muladd.arg_x_data());
    std::copy(inputs + 4, inputs + 8, muladd.arg_y_data());
    EXPECT_TRUE(muladd.Run());
    EXPECT_EQ(muladd.error_msg(), "");
    const float product[4] = {1900, 2200, 4300, 5000};
    const float sum[4] = {60, 80, 100, 120};
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        const int k = 2 * r + c;
        EXPECT_EQ(muladd.result_x_y_prod(r, c), product[k]);
        EXPECT_EQ(muladd.result_x_y_prod_data()[k], product[k]);
        EXPECT_EQ(muladd.result_x_y_sum(r, c), sum[k]);
        EXPECT_EQ(muladd.result_x_y_sum_data()[k], sum[k]);
      }
    }
    EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]);
    EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]);

    // Const overloads of the named accessors.
    const ::foo::bar::MatMulAndAddComp& view = muladd;
    EXPECT_EQ(view.error_msg(), "");
    for (int k = 0; k < 4; ++k) {
      EXPECT_EQ(view.arg_x(k / 2, k % 2), inputs[k]);
      EXPECT_EQ(view.arg_x_data()[k], inputs[k]);
    }
    EXPECT_EQ(view.arg_x_data(), muladd.arg_data(0));
    for (int k = 0; k < 4; ++k) {
      EXPECT_EQ(view.arg_y(k / 2, k % 2), inputs[k + 4]);
      EXPECT_EQ(view.arg_y_data()[k], inputs[k + 4]);
    }
    EXPECT_EQ(view.arg_y_data(), muladd.arg_data(1));
    for (int r = 0; r < 2; ++r) {
      for (int c = 0; c < 2; ++c) {
        const int k = 2 * r + c;
        EXPECT_EQ(view.result_x_y_prod(r, c), product[k]);
        EXPECT_EQ(view.result_x_y_prod_data()[k], product[k]);
        EXPECT_EQ(view.result_x_y_sum(r, c), sum[k]);
        EXPECT_EQ(view.result_x_y_sum_data()[k], sum[k]);
      }
    }
    EXPECT_EQ(view.result_x_y_prod_data(), muladd.results()[0]);
    EXPECT_EQ(view.result_x_y_sum_data(), muladd.results()[1]);
  }
}
449
// The compiled graph calls a function that is equivalent to adding its two
// arguments.
TEST(TFCompileTest, Function) {
  FunctionComp comp;
  EXPECT_EQ(comp.arg0_data(), comp.arg_data(0));
  EXPECT_EQ(comp.arg1_data(), comp.arg_data(1));

  comp.arg0() = 1;
  comp.arg1() = 2;
  constexpr int kExpectedSum = 3;  // 1 + 2
  EXPECT_TRUE(comp.Run());
  EXPECT_EQ(comp.error_msg(), "");
  EXPECT_EQ(comp.result0(), kExpectedSum);
  EXPECT_EQ(comp.result0_data()[0], kExpectedSum);
  EXPECT_EQ(comp.result0_data(), comp.results()[0]);
}
464
// Runs the Splits graph on two 2x2 inputs. result0 is compared against
// precomputed values with an absolute tolerance of 1e4.
TEST(TFCompileTest, Splits) {
  Eigen::ThreadPool pool(1);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  SplitsComp fn;
  fn.set_thread_pool(&device);

  // x = [[1, 2], [3, 4]] and y = [[10, 20], [30, 40]].
  const float x[2][2] = {{1, 2}, {3, 4}};
  const float y[2][2] = {{10, 20}, {30, 40}};
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 2; ++c) {
      fn.arg0(r, c) = x[r][c];
    }
  }
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 2; ++c) {
      fn.arg1(r, c) = y[r][c];
    }
  }

  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  const float expected[2][2] = {{7.86375557e+10, 1.34274679e+11},
                                {1.92741717e+12, 3.29964742e+12}};
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 2; ++c) {
      EXPECT_NEAR(expected[r][c], fn.result0(r, c), 1e4);
    }
  }
}
492
// Top-2 of x = [4, 1, 4, 4, 3]. result0 holds the values and result1 their
// indices; the expected indices {0, 2} imply that ties resolve to the lower
// index.
TEST(TFCompileTest, TopK) {
  Eigen::ThreadPool pool(1);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  TopKComp fn;
  fn.set_thread_pool(&device);

  const int32 x[5] = {4, 1, 4, 4, 3};
  for (int i = 0; i < 5; ++i) {
    fn.arg0(i) = x[i];
  }

  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  const int32 expected_values[] = {4, 4};
  const int32 expected_indices[] = {0, 2};
  for (int k = 0; k < 2; ++k) {
    EXPECT_EQ(expected_values[k], fn.result0(k));
  }
  for (int k = 0; k < 2; ++k) {
    EXPECT_EQ(expected_indices[k], fn.result1(k));
  }
}
516
// Runs a graph that only reads the caller-provided variable `x`. result0 is
// derived from it (the expectations imply 23 -> 65), and `x` must be left
// untouched.
TEST(TFCompileTest, VariableReadonly) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  VariableReadonlyComp fn;
  float x = 23;
  fn.set_var_x_data(&x);

  fn.set_thread_pool(&device);
  // Check the run status like the other tests do instead of discarding it;
  // otherwise a failed run only shows up indirectly as wrong results.
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_EQ(fn.result0(), 65);
  EXPECT_EQ(fn.var_x(), 23);
}
530
// Runs a graph that updates the caller-provided variable `x` in place. The
// expected values imply that each run adds 42 to it (23 -> 65 -> 107) and
// returns the old and new values in result0 rows 0 and 1.
TEST(TFCompileTest, Variable) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  VariableComp fn;
  float x = 23;
  fn.set_var_x_data(&x);

  fn.set_thread_pool(&device);
  // Check each run's status rather than discarding it.
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_EQ(fn.result0(0, 0), 23);
  EXPECT_EQ(fn.result0(1, 0), 65);
  EXPECT_EQ(fn.var_x(), 65);

  // The update is written through to the caller-owned buffer.
  EXPECT_EQ(fn.var_x_data(), &x);
  EXPECT_EQ(x, 65);
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_EQ(fn.result0(0, 0), 65);
  EXPECT_EQ(fn.result0(1, 0), 107);
  EXPECT_EQ(fn.var_x(), 107);
}
552
TEST(TFCompileTest, VariableSequentialUpdates) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  // Each Run() applies three steps of the recursion:
  //   x[0] = 2.0
  //   x[n+1] = x[n] - 0.1*(x[n] + y)
  // With y = 1 this gives x[1] = 1.7, x[2] = 1.43, x[3] = 1.187, ...
  VariableSequentialUpdatesComp fn;
  fn.var_x() = 2;
  // var_y_data() returns a const pointer. Writing through it is only
  // acceptable while it still points at a buffer owned by `fn` (see below).
  *const_cast<float*>(fn.var_y_data()) = 1;

  fn.set_thread_pool(&device);
  // First calculate x[3].
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6);

  const float y = 1;
  fn.set_var_y_data(&y);

  // Now const_cast<float*>(fn.var_y_data()) is no longer legal since we've set
  // the buffer to point to a constant location.

  // Then calculate x[6].
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6);
}
579
TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  // Each Run() applies three steps of the recursion:
  //   x[0] = 2.0
  //   x[n+1] = x[n] - 0.1*(x[n] + 1.0)
  // Both variables live in caller-owned storage, so the updates land in `x`.
  VariableSequentialUpdatesComp fn(
      XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
  float x = 2;
  float y = 1;
  fn.set_var_x_data(&x);
  fn.set_var_y_data(&y);

  fn.set_thread_pool(&device);
  // First calculate x[3].
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_NEAR(x, 1.187f, 1e-6);

  // Then calculate x[6].
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  EXPECT_NEAR(x, 0.594322f, 1e-6);
}
603
// Assert is converted into a no-op in XLA, so running with two different
// arguments still succeeds and simply returns arg0 - arg1.
TEST(TFCompileTest, AssertEqAndReturnDiff) {
  AssertComp comp;
  EXPECT_EQ(comp.arg0_data(), comp.arg_data(0));
  EXPECT_EQ(comp.arg1_data(), comp.arg_data(1));

  comp.arg0() = 2;
  comp.arg1() = 1;
  const int32 difference = comp.arg0() - comp.arg1();
  EXPECT_TRUE(comp.Run());
  EXPECT_EQ(comp.error_msg(), "");
  EXPECT_EQ(comp.result0(), difference);
  EXPECT_EQ(comp.result0_data()[0], difference);
  EXPECT_EQ(comp.result0_data(), comp.results()[0]);
}
620
// Name-to-index lookup: argument names resolve only among feeds, result
// names only among fetches, and unknown names yield -1.
TEST(TFCompileTest, LookupNameIndex) {
  // AddComp's config defines no feed/fetch names.
  AddComp add;
  EXPECT_FALSE(add.HasNameIndices());

  // MatMulAndAddComp names every feed (x, y) and fetch (x_y_prod, x_y_sum).
  ::foo::bar::MatMulAndAddComp muladd;
  EXPECT_TRUE(muladd.HasNameIndices());

  EXPECT_EQ(muladd.LookupArgIndex("x"), 0);
  EXPECT_EQ(muladd.LookupArgIndex("y"), 1);
  for (const char* name : {"", "x_hold", "y_hold", "x_y_prod", "x_y_sum"}) {
    EXPECT_EQ(muladd.LookupArgIndex(name), -1);
  }

  EXPECT_EQ(muladd.LookupResultIndex("x_y_prod"), 0);
  EXPECT_EQ(muladd.LookupResultIndex("x_y_sum"), 1);
  for (const char* name : {"", "x", "y", "x_hold", "y_hold"}) {
    EXPECT_EQ(muladd.LookupResultIndex(name), -1);
  }
}
646
// Checks the embedded program shape: absent for AddComp, and
// (f32[2,2], f32[2,2]) -> (f32[2,2], f32[2,2]) for MatMulAndAddComp.
TEST(TFCompileTest, ProgramShape) {
  using xla::ShapeUtil;
  const xla::Shape f32_2x2 = ShapeUtil::MakeShape(xla::F32, {2, 2});

  AddComp add;
  ASSERT_TRUE(add.ProgramShape() == nullptr);

  ::foo::bar::MatMulAndAddComp muladd;
  const xla::ProgramShapeProto* shape_proto = muladd.ProgramShape();
  ASSERT_TRUE(shape_proto != nullptr);
  ASSERT_EQ(shape_proto->parameters_size(), 2);
  for (int i = 0; i < 2; ++i) {
    EXPECT_TRUE(ShapeUtil::Compatible(xla::Shape(shape_proto->parameters(i)),
                                      f32_2x2));
  }

  // The result is a 2-tuple of f32[2,2] values.
  const xla::Shape result_shape(shape_proto->result());
  ASSERT_EQ(result_shape.element_type(), xla::TUPLE);
  ASSERT_EQ(ShapeUtil::TupleElementCount(result_shape), 2);
  for (int i = 0; i < 2; ++i) {
    const xla::Shape& element = ShapeUtil::GetTupleElementShape(result_shape, i);
    EXPECT_TRUE(ShapeUtil::Compatible(element, f32_2x2));
  }
}
675
676 // tf_compile with mlir_component=HloLowering does not currently support
677 // profiling, so we disable the test case here rather than creating a new test
678 // target that could allow more divergence.
679 #if !defined(MHLO_LOWERING_TEST)
// Runs the profiling-enabled MatMulAndAdd computation and checks that the
// printed HLO profile contains the expected header, total and per-instruction
// lines once identifier suffixes are stripped.
TEST(TFCompileTest, HloProfiling) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  MatMulAndAddCompWithProfiling fn;
  ASSERT_TRUE(fn.hlo_profiling_enabled());

  fn.set_thread_pool(&device);

  // x = [[1, 2], [3, 4]]
  fn.arg0(0, 0) = 1;
  fn.arg0(0, 1) = 2;
  fn.arg0(1, 0) = 3;
  fn.arg0(1, 1) = 4;

  // y = [[10, 20], [30, 40]]
  fn.arg1(0, 0) = 10;
  fn.arg1(0, 1) = 20;
  fn.arg1(1, 0) = 30;
  fn.arg1(1, 1) = 40;

  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");

  string hlo_profile_as_string =
      xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
                           /*clock_rate_ghz=*/1.0);
  VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;

  // Replace Arg_n with argn when the MLIR bridge is used. The second group
  // also captures the one character after the digit and writes it back
  // unchanged.
#if defined(ENABLE_MLIR_BRIDGE_TEST)
  RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
#endif

  // Strip away identifier details from the profile string to avoid this test
  // being a change detector for xla internals. Identifiers such as '%dot.0.7'
  // just become '%dot'.
  RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
  VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;

  std::vector<string> hlo_profile_lines =
      absl::StrSplit(hlo_profile_as_string, '\n');

  // Lines that must all appear somewhere in the stripped profile.
  // NOTE(review): the parameter instructions (%arg0/%arg1 ... parameter(N))
  // are not required here. Add matchers for them only after confirming that
  // the profile lists parameters.
  auto header = HasSubstr("Execution profile for");
  auto total_cycles_profile_line = HasSubstr("[total]");
  auto dot_profile_line = HasSubstr(
      "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
  auto add_profile_line = HasSubstr(
      "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
  auto tuple_profile_line = HasSubstr(
      "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
      "f32[2,2]{1,0} %add)");

  EXPECT_THAT(hlo_profile_lines,
              IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
                            add_profile_line, tuple_profile_line}));
}
738 #endif
739
740 } // namespace
741 } // namespace tfcompile
742 } // namespace tensorflow
743