xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/aot/tests/tfcompile_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 #define EIGEN_USE_CUSTOM_THREAD_POOL
18 
#include <algorithm>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
26 
27 // The header files for the tests using mlir_bridge have the _mlir_bridge suffix
28 // inherited from the tf_library target names.
29 #if defined(ENABLE_MLIR_BRIDGE_TEST)
30 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_mlir_bridge.h"
31 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mlir_bridge.h"
32 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mlir_bridge.h"
33 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mlir_bridge.h"
34 #include "tensorflow/compiler/aot/tests/test_graph_tfcond_mlir_bridge.h"
35 #include "tensorflow/compiler/aot/tests/test_graph_tffunction_mlir_bridge.h"
36 #include "tensorflow/compiler/aot/tests/test_graph_tfgather_mlir_bridge.h"
37 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mlir_bridge.h"
38 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mlir_bridge.h"
39 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling_mlir_bridge.h"
40 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mlir_bridge.h"
41 #include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mlir_bridge.h"
42 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mlir_bridge.h"
43 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mlir_bridge.h"
44 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mlir_bridge.h"
45 // Similarly, there are files for testing the MLIR based lowering of HLO to
46 // object code for XLA:CPU
47 #elif defined(MHLO_LOWERING_TEST)
48 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_mhlo_lowering.h"
49 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_mhlo_lowering.h"
50 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver_mhlo_lowering.h"
51 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq_mhlo_lowering.h"
52 #include "tensorflow/compiler/aot/tests/test_graph_tfcond_mhlo_lowering.h"
53 #include "tensorflow/compiler/aot/tests/test_graph_tffunction_mhlo_lowering.h"
54 #include "tensorflow/compiler/aot/tests/test_graph_tfgather_mhlo_lowering.h"
55 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul_mhlo_lowering.h"
56 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_mhlo_lowering.h"
57 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits_mhlo_lowering.h"
58 #include "tensorflow/compiler/aot/tests/test_graph_tftop_k_mhlo_lowering.h"
59 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_mhlo_lowering.h"
60 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly_mhlo_lowering.h"
61 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates_mhlo_lowering.h"
62 #else
63 #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
64 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
65 #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
66 #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
67 #include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
68 #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
69 #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
70 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
71 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
72 #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
73 #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
74 #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
75 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h"
76 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h"
77 #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h"
78 #endif
79 
80 namespace tensorflow {
81 namespace tfcompile {
82 namespace {
83 
84 using ::testing::HasSubstr;
85 using ::testing::IsSupersetOf;
86 
// Runs the compiled Add computation twice with internally allocated buffers
// and checks that the typed accessors (argN(), argN_data(), result0(),
// result0_data()) agree with the positional ones (arg_data(i), results()[i]),
// both through a mutable object and through a const reference to it.
TEST(TFCompileTest, Add) {
  AddComp add;
  EXPECT_EQ(add.arg0_data(), add.arg_data(0));
  EXPECT_EQ(add.arg1_data(), add.arg_data(1));

  // First run: feed the inputs through the argN() reference accessors.
  add.arg0() = 1;
  add.arg1() = 2;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 3);
  EXPECT_EQ(add.result0_data()[0], 3);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // Second run: feed the inputs through the raw argN_data() pointers.
  add.arg0_data()[0] = 123;
  add.arg1_data()[0] = 456;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 579);
  EXPECT_EQ(add.result0_data()[0], 579);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // The const overloads must observe the same args and results.
  const AddComp& add_const = add;
  EXPECT_EQ(add_const.error_msg(), "");
  EXPECT_EQ(add_const.arg0(), 123);
  EXPECT_EQ(add_const.arg0_data()[0], 123);
  EXPECT_EQ(add_const.arg0_data(), add.arg_data(0));
  EXPECT_EQ(add_const.arg1(), 456);
  EXPECT_EQ(add_const.arg1_data()[0], 456);
  EXPECT_EQ(add_const.arg1_data(), add.arg_data(1));
  EXPECT_EQ(add_const.result0(), 579);
  EXPECT_EQ(add_const.result0_data()[0], 579);
  EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
120 
121 // Run tests that use set_argN_data separately, to avoid accidentally re-using
122 // non-existent buffers.
TEST(TFCompileTest,Add_SetArg)123 TEST(TFCompileTest, Add_SetArg) {
124   AddComp add(
125       XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
126 
127   int32 arg_x = 10;
128   int32 arg_y = 32;
129   add.set_arg0_data(&arg_x);
130   add.set_arg1_data(&arg_y);
131   EXPECT_EQ(add.arg0_data(), add.arg_data(0));
132   EXPECT_EQ(add.arg1_data(), add.arg_data(1));
133 
134   EXPECT_TRUE(add.Run());
135   EXPECT_EQ(add.error_msg(), "");
136   EXPECT_EQ(add.result0(), 42);
137   EXPECT_EQ(add.result0_data()[0], 42);
138   EXPECT_EQ(add.result0_data(), add.results()[0]);
139 }
140 
// Runs the single-argument AddWithCkpt computation. The expectations below
// show the result is the argument plus 42 (1 -> 43, 111 -> 153).
TEST(TFCompileTest, AddWithCkpt) {
  AddWithCkptComp add;
  EXPECT_EQ(add.arg0_data(), add.arg_data(0));

  // First run: set the input through the arg0() reference accessor.
  add.arg0() = 1;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 43);
  EXPECT_EQ(add.result0_data()[0], 43);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // Second run: set the input through the raw arg0_data() pointer.
  add.arg0_data()[0] = 111;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 153);
  EXPECT_EQ(add.result0_data()[0], 153);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // The const overloads must observe the same arg and result.
  const AddWithCkptComp& add_const = add;
  EXPECT_EQ(add_const.error_msg(), "");
  EXPECT_EQ(add_const.arg0(), 111);
  EXPECT_EQ(add_const.arg0_data()[0], 111);
  EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
  EXPECT_EQ(add_const.result0(), 153);
  EXPECT_EQ(add_const.result0_data()[0], 153);
  EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
168 
// Same checks as AddWithCkpt, against the AddWithCkptSaver computation: the
// result is expected to be the argument plus 42 (1 -> 43, 111 -> 153).
TEST(TFCompileTest, AddWithCkptSaver) {
  AddWithCkptSaverComp add;
  EXPECT_EQ(add.arg0_data(), add.arg_data(0));

  // First run: set the input through the arg0() reference accessor.
  add.arg0() = 1;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 43);
  EXPECT_EQ(add.result0_data()[0], 43);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // Second run: set the input through the raw arg0_data() pointer.
  add.arg0_data()[0] = 111;
  EXPECT_TRUE(add.Run());
  EXPECT_EQ(add.error_msg(), "");
  EXPECT_EQ(add.result0(), 153);
  EXPECT_EQ(add.result0_data()[0], 153);
  EXPECT_EQ(add.result0_data(), add.results()[0]);

  // The const overloads must observe the same arg and result.
  const AddWithCkptSaverComp& add_const = add;
  EXPECT_EQ(add_const.error_msg(), "");
  EXPECT_EQ(add_const.arg0(), 111);
  EXPECT_EQ(add_const.arg0_data()[0], 111);
  EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
  EXPECT_EQ(add_const.result0(), 153);
  EXPECT_EQ(add_const.result0_data()[0], 153);
  EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
196 
// Runs the Cond computation with the predicate (arg0) set to true and then to
// false, expecting the result to equal arg1 and arg2 respectively.
TEST(TFCompileTest, Cond) {
  CondComp cond;
  EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
  EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
  EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
  cond.arg1() = 10;
  cond.arg2() = 20;
  {
    // Predicate true: the result is expected to be arg1.
    cond.arg0() = true;
    const int32 expected_result = cond.arg1();
    EXPECT_TRUE(cond.Run());
    EXPECT_EQ(cond.result0(), expected_result);
    EXPECT_EQ(cond.result0_data()[0], expected_result);
    EXPECT_EQ(cond.result0_data(), cond.results()[0]);
  }
  {
    // Predicate false: the result is expected to be arg2.
    cond.arg0() = false;
    const int32 expected_result = cond.arg2();
    EXPECT_TRUE(cond.Run());
    EXPECT_EQ(cond.result0(), expected_result);
    EXPECT_EQ(cond.result0_data()[0], expected_result);
    EXPECT_EQ(cond.result0_data(), cond.results()[0]);
  }
}
221 
// Runs the Gather computation on a 4-element float params buffer and a
// 2-element int32 indices buffer, then checks results and argument accessors
// through both the mutable object and a const reference.
TEST(TFCompileTest, Gather) {
  GatherComp gather;
  EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
  EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));

  // Successful gather.
  {
    const float params[4] = {1, 2, 3, 4};
    std::copy(params + 0, params + 4, gather.arg0_data());
    const int32 indices[2] = {1, 3};
    std::copy(indices + 0, indices + 2, gather.arg1_data());
    EXPECT_TRUE(gather.Run());
    EXPECT_EQ(gather.error_msg(), "");
    // params[1] and params[3].
    const float results[2] = {2, 4};
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ(gather.result0(i), results[i]);
      EXPECT_EQ(gather.result0_data()[i], results[i]);
    }
    EXPECT_EQ(gather.result0_data(), gather.results()[0]);

    // The const overloads must observe the same args and results.
    const GatherComp& gather_const = gather;
    EXPECT_EQ(gather_const.error_msg(), "");
    for (int i = 0; i < 4; ++i) {
      EXPECT_EQ(gather_const.arg0(i), params[i]);
      EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
    }
    EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ(gather_const.arg1(i), indices[i]);
      EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
    }
    EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
    for (int i = 0; i < 2; ++i) {
      EXPECT_EQ(gather_const.result0(i), results[i]);
      EXPECT_EQ(gather_const.result0_data()[i], results[i]);
    }
    EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
  }
}
261 
TEST(TFCompileTest,MatMul2)262 TEST(TFCompileTest, MatMul2) {
263   Eigen::ThreadPool tp(2);
264   Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
265 
266   foo::bar::MatMulComp matmul;
267   matmul.set_thread_pool(&device);
268   EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
269   EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
270 
271   // Test using the argN() methods.
272   {
273     matmul.arg0(0, 0) = 1;
274     matmul.arg0(0, 1) = 2;
275     matmul.arg0(0, 2) = 3;
276     matmul.arg0(1, 0) = 4;
277     matmul.arg0(1, 1) = 5;
278     matmul.arg0(1, 2) = 6;
279 
280     matmul.arg1(0, 0) = 7;
281     matmul.arg1(0, 1) = 8;
282     matmul.arg1(1, 0) = 9;
283     matmul.arg1(1, 1) = 10;
284     matmul.arg1(2, 0) = 11;
285     matmul.arg1(2, 1) = 12;
286 
287     EXPECT_TRUE(matmul.Run());
288     EXPECT_EQ(matmul.error_msg(), "");
289     const float results[4] = {58, 64, 139, 154};
290     for (int i = 0; i < 4; ++i) {
291       EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
292       EXPECT_EQ(matmul.result0_data()[i], results[i]);
293     }
294     EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
295   }
296 
297   // Test using the argN_data() methods.
298   {
299     const float args[12] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120};
300     std::copy(args + 0, args + 6, matmul.arg0_data());
301     std::copy(args + 6, args + 12, matmul.arg1_data());
302     EXPECT_TRUE(matmul.Run());
303     EXPECT_EQ(matmul.error_msg(), "");
304     const float results[4] = {5800, 6400, 13900, 15400};
305     for (int i = 0; i < 4; ++i) {
306       EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
307       EXPECT_EQ(matmul.result0_data()[i], results[i]);
308     }
309     EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
310 
311     const foo::bar::MatMulComp& matmul_const = matmul;
312     EXPECT_EQ(matmul_const.error_msg(), "");
313     for (int i = 0; i < 6; ++i) {
314       EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
315       EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
316     }
317     EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0));
318     for (int i = 0; i < 6; ++i) {
319       EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
320       EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
321     }
322     EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1));
323     for (int i = 0; i < 4; ++i) {
324       EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
325       EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
326     }
327     EXPECT_EQ(matmul_const.result0_data(), matmul.results()[0]);
328   }
329 }
330 
331 // Run tests that use set_argN_data separately, to avoid accidentally re-using
332 // non-existent buffers.
TEST(TFCompileTest,MatMul2_SetArg)333 TEST(TFCompileTest, MatMul2_SetArg) {
334   Eigen::ThreadPool tp(2);
335   Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
336 
337   foo::bar::MatMulComp matmul(
338       XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
339   matmul.set_thread_pool(&device);
340 
341   // Test using the set_argN_data() methods.
342   float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}};
343   float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
344   matmul.set_arg0_data(&arg0);
345   matmul.set_arg1_data(&arg1);
346   EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
347   EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
348 
349   EXPECT_TRUE(matmul.Run());
350   EXPECT_EQ(matmul.error_msg(), "");
351   const float results[4] = {58, 64, 139, 154};
352   for (int i = 0; i < 4; ++i) {
353     EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
354     EXPECT_EQ(matmul.result0_data()[i], results[i]);
355   }
356   EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
357 }
358 
TEST(TFCompileTest,MatMulAndAdd1)359 TEST(TFCompileTest, MatMulAndAdd1) {
360   Eigen::ThreadPool tp(1);
361   Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
362 
363   ::foo::bar::MatMulAndAddComp muladd;
364   muladd.set_thread_pool(&device);
365   EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
366   EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));
367 
368   // Test methods with positional args and results.
369   {
370     const float args[8] = {1, 2, 3, 4, 5, 6, 7, 8};
371     std::copy(args + 0, args + 4, muladd.arg0_data());
372     std::copy(args + 4, args + 8, muladd.arg1_data());
373     EXPECT_TRUE(muladd.Run());
374     EXPECT_EQ(muladd.error_msg(), "");
375     const float results0[4] = {19, 22, 43, 50};
376     const float results1[4] = {6, 8, 10, 12};
377     for (int i = 0; i < 4; ++i) {
378       EXPECT_EQ(muladd.result0(i / 2, i % 2), results0[i]);
379       EXPECT_EQ(muladd.result0_data()[i], results0[i]);
380       EXPECT_EQ(muladd.result1(i / 2, i % 2), results1[i]);
381       EXPECT_EQ(muladd.result1_data()[i], results1[i]);
382     }
383     EXPECT_EQ(muladd.result0_data(), muladd.results()[0]);
384     EXPECT_EQ(muladd.result1_data(), muladd.results()[1]);
385 
386     const ::foo::bar::MatMulAndAddComp& muladd_const = muladd;
387     EXPECT_EQ(muladd_const.error_msg(), "");
388     for (int i = 0; i < 4; ++i) {
389       EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
390       EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
391     }
392     EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0));
393     for (int i = 0; i < 4; ++i) {
394       EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
395       EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
396     }
397     EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1));
398     for (int i = 0; i < 4; ++i) {
399       EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
400       EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
401       EXPECT_EQ(muladd_const.result1(i / 2, i % 2), results1[i]);
402       EXPECT_EQ(muladd_const.result1_data()[i], results1[i]);
403     }
404     EXPECT_EQ(muladd_const.result0_data(), muladd.results()[0]);
405     EXPECT_EQ(muladd_const.result1_data(), muladd.results()[1]);
406   }
407 
408   // Test methods with named args and results.
409   {
410     const float args[8] = {10, 20, 30, 40, 50, 60, 70, 80};
411     std::copy(args + 0, args + 4, muladd.arg_x_data());
412     std::copy(args + 4, args + 8, muladd.arg_y_data());
413     EXPECT_TRUE(muladd.Run());
414     EXPECT_EQ(muladd.error_msg(), "");
415     const float results0[4] = {1900, 2200, 4300, 5000};
416     const float results1[4] = {60, 80, 100, 120};
417     for (int i = 0; i < 4; ++i) {
418       EXPECT_EQ(muladd.result_x_y_prod(i / 2, i % 2), results0[i]);
419       EXPECT_EQ(muladd.result_x_y_prod_data()[i], results0[i]);
420       EXPECT_EQ(muladd.result_x_y_sum(i / 2, i % 2), results1[i]);
421       EXPECT_EQ(muladd.result_x_y_sum_data()[i], results1[i]);
422     }
423     EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]);
424     EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]);
425 
426     // Test const methods.
427     const ::foo::bar::MatMulAndAddComp& muladd_const = muladd;
428     EXPECT_EQ(muladd_const.error_msg(), "");
429     for (int i = 0; i < 4; ++i) {
430       EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
431       EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
432     }
433     EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0));
434     for (int i = 0; i < 4; ++i) {
435       EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
436       EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
437     }
438     EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1));
439     for (int i = 0; i < 4; ++i) {
440       EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
441       EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
442       EXPECT_EQ(muladd_const.result_x_y_sum(i / 2, i % 2), results1[i]);
443       EXPECT_EQ(muladd_const.result_x_y_sum_data()[i], results1[i]);
444     }
445     EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.results()[0]);
446     EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.results()[1]);
447   }
448 }
449 
// Runs the Function computation once and checks that its single result is the
// sum of its two arguments.
TEST(TFCompileTest, Function) {
  // The function is equivalent to an addition
  FunctionComp add_fn;
  EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0));
  EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1));

  add_fn.arg0() = 1;
  add_fn.arg1() = 2;
  EXPECT_TRUE(add_fn.Run());
  EXPECT_EQ(add_fn.error_msg(), "");
  EXPECT_EQ(add_fn.result0(), 3);
  EXPECT_EQ(add_fn.result0_data()[0], 3);
  EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]);
}
464 
// Runs the Splits computation on two 2x2 inputs and compares the 2x2 result
// against precomputed reference values. The results are on the order of 1e10
// to 1e12, so the comparison uses an absolute tolerance of 1e4.
TEST(TFCompileTest, Splits) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  SplitsComp fn;

  fn.set_thread_pool(&device);
  // x = [[1, 2], [3, 4]]
  fn.arg0(0, 0) = 1;
  fn.arg0(0, 1) = 2;
  fn.arg0(1, 0) = 3;
  fn.arg0(1, 1) = 4;

  // y = [[10, 20], [30, 40]]
  fn.arg1(0, 0) = 10;
  fn.arg1(0, 1) = 20;
  fn.arg1(1, 0) = 30;
  fn.arg1(1, 1) = 40;
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  const float expected[] = {7.86375557e+10, 1.34274679e+11, 1.92741717e+12,
                            3.29964742e+12};
  EXPECT_NEAR(expected[0], fn.result0(0, 0), 1e4);
  EXPECT_NEAR(expected[1], fn.result0(0, 1), 1e4);
  EXPECT_NEAR(expected[2], fn.result0(1, 0), 1e4);
  EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
492 
// Runs the TopK computation on a 5-element input and checks the two returned
// values (result0) and their indices (result1). The input contains three
// equal maxima; the expected indices {0, 2} are the two lowest of them.
TEST(TFCompileTest, TopK) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  TopKComp fn;

  fn.set_thread_pool(&device);
  // x = [4, 1, 4, 4, 3]
  fn.arg0(0) = 4;
  fn.arg0(1) = 1;
  fn.arg0(2) = 4;
  fn.arg0(3) = 4;
  fn.arg0(4) = 3;

  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.error_msg(), "");
  const int32 expected_values[] = {4, 4};
  const int32 expected_indices[] = {0, 2};
  EXPECT_EQ(expected_values[0], fn.result0(0));
  EXPECT_EQ(expected_values[1], fn.result0(1));
  EXPECT_EQ(expected_indices[0], fn.result1(0));
  EXPECT_EQ(expected_indices[1], fn.result1(1));
}
516 
// Runs a computation that reads variable x without updating it: the result is
// expected to be x + 42 while x itself keeps its original value.
TEST(TFCompileTest, VariableReadonly) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  VariableReadonlyComp fn;
  // Caller-owned variable storage; must outlive the Run() call below.
  float x = 23;
  fn.set_var_x_data(&x);

  fn.set_thread_pool(&device);
  // Check the run status so a failed run is reported directly instead of
  // surfacing only as mismatched values below.
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.result0(), 65);
  EXPECT_EQ(fn.var_x(), 23);
}
530 
// Runs a computation that updates variable x in place. Per the expectations
// below, each run adds 42 to x, with result0(0, 0) holding the value before
// the update and result0(1, 0) the value after it. The update is written
// through to the caller-provided buffer.
TEST(TFCompileTest, Variable) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  VariableComp fn;
  // Caller-owned variable storage; must outlive both Run() calls below.
  float x = 23;
  fn.set_var_x_data(&x);

  fn.set_thread_pool(&device);
  // Check the run status so a failed run is reported directly instead of
  // surfacing only as mismatched values below.
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.result0(0, 0), 23);
  EXPECT_EQ(fn.result0(1, 0), 65);
  EXPECT_EQ(fn.var_x(), 65);

  // The variable buffer is still the caller's, and it was updated in place.
  EXPECT_EQ(fn.var_x_data(), &x);
  EXPECT_EQ(x, 65);
  EXPECT_TRUE(fn.Run());
  EXPECT_EQ(fn.result0(0, 0), 65);
  EXPECT_EQ(fn.result0(1, 0), 107);
  EXPECT_EQ(fn.var_x(), 107);
}
552 
// Runs a computation that updates variable x several times per run, using the
// internally allocated variable buffers for the first run and a caller-owned
// const buffer for y in the second.
TEST(TFCompileTest, VariableSequentialUpdates) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  // This implements the recursion:
  // x[0] = 2.0
  // x[n+1] = x[n] - 0.1*(x[n] + y)
  // With y = 1 this gives 2 -> 1.7 -> 1.43 -> 1.187, i.e. each Run() advances
  // the recursion by three steps.
  VariableSequentialUpdatesComp fn;
  fn.var_x() = 2;
  // var_y_data() yields a pointer-to-const here; the buffer was allocated by
  // fn itself, so writing through it is fine at this point.
  *const_cast<float*>(fn.var_y_data()) = 1;

  fn.set_thread_pool(&device);
  // First calculate x[3]
  EXPECT_TRUE(fn.Run());
  EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6);

  const float y = 1;
  fn.set_var_y_data(&y);

  // Now const_cast<float*>(fn.var_y_data()) is no longer legal since we've set
  // the buffer to point to a constant location.

  // Then calculate x[6]
  EXPECT_TRUE(fn.Run());
  EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6);
}
579 
// Same recursion as VariableSequentialUpdates, but the computation is
// constructed with RESULTS_PROFILES_AND_TEMPS_ONLY and both variable buffers
// are owned by the test, so the updated x is read from the local directly.
TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) {
  Eigen::ThreadPool tp(1);
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

  // This implements the recursion:
  // x[0] = 2.0
  // x[n+1] = x[n] - 0.1*(x[n] + 1.0)
  // This gives 2 -> 1.7 -> 1.43 -> 1.187, i.e. each Run() advances the
  // recursion by three steps.
  VariableSequentialUpdatesComp fn(
      XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
  // Caller-owned variable storage; must outlive both Run() calls below.
  float x = 2;
  float y = 1;
  fn.set_var_x_data(&x);
  fn.set_var_y_data(&y);

  fn.set_thread_pool(&device);
  // First calculate x[3]
  EXPECT_TRUE(fn.Run());
  EXPECT_NEAR(x, 1.187f, 1e-6);

  // Then calculate x[6]
  EXPECT_TRUE(fn.Run());
  EXPECT_NEAR(x, 0.594322f, 1e-6);
}
603 
// Runs the Assert computation with two different arguments and checks that
// the run succeeds and returns arg0 - arg1.
TEST(TFCompileTest, AssertEqAndReturnDiff) {
  // Assert is converted into a no-op in XLA, so there is no failure even if the
  // two args are different.
  // The local is not called `assert`, to stay clear of the <cassert> macro.
  AssertComp assert_fn;
  EXPECT_EQ(assert_fn.arg0_data(), assert_fn.arg_data(0));
  EXPECT_EQ(assert_fn.arg1_data(), assert_fn.arg_data(1));

  assert_fn.arg0() = 2;
  assert_fn.arg1() = 1;
  const int32 expected_result = assert_fn.arg0() - assert_fn.arg1();
  EXPECT_TRUE(assert_fn.Run());
  EXPECT_EQ(assert_fn.error_msg(), "");
  EXPECT_EQ(assert_fn.result0(), expected_result);
  EXPECT_EQ(assert_fn.result0_data()[0], expected_result);
  EXPECT_EQ(assert_fn.result0_data(), assert_fn.results()[0]);
}
620 
// Checks name-to-index lookup: a computation without names reports no name
// indices, while one with named feeds and fetches maps each name to its
// position and returns -1 for empty, unknown, or wrong-kind names.
TEST(TFCompileTest, LookupNameIndex) {
  // add doesn't have any names defined in its config.
  AddComp add;
  EXPECT_FALSE(add.HasNameIndices());

  // muladd has names defined for all feeds and fetches.
  ::foo::bar::MatMulAndAddComp muladd;
  EXPECT_TRUE(muladd.HasNameIndices());

  // Argument lookup: only "x" and "y" resolve; result names do not.
  EXPECT_EQ(muladd.LookupArgIndex("x"), 0);
  EXPECT_EQ(muladd.LookupArgIndex("y"), 1);
  EXPECT_EQ(muladd.LookupArgIndex(""), -1);
  EXPECT_EQ(muladd.LookupArgIndex("x_hold"), -1);
  EXPECT_EQ(muladd.LookupArgIndex("y_hold"), -1);
  EXPECT_EQ(muladd.LookupArgIndex("x_y_prod"), -1);
  EXPECT_EQ(muladd.LookupArgIndex("x_y_sum"), -1);

  // Result lookup: only the two result names resolve; arg names do not.
  EXPECT_EQ(muladd.LookupResultIndex("x_y_prod"), 0);
  EXPECT_EQ(muladd.LookupResultIndex("x_y_sum"), 1);
  EXPECT_EQ(muladd.LookupResultIndex(""), -1);
  EXPECT_EQ(muladd.LookupResultIndex("x"), -1);
  EXPECT_EQ(muladd.LookupResultIndex("y"), -1);
  EXPECT_EQ(muladd.LookupResultIndex("x_hold"), -1);
  EXPECT_EQ(muladd.LookupResultIndex("y_hold"), -1);
}
646 
// Checks the program shape exposed by the generated classes: absent for add,
// and for muladd two f32[2,2] parameters plus a two-element tuple result of
// f32[2,2] shapes.
TEST(TFCompileTest, ProgramShape) {
  using xla::ShapeUtil;
  const xla::Shape f32_2x2 = ShapeUtil::MakeShape(xla::F32, {2, 2});

  // add doesn't have the program shape defined.
  AddComp add;
  ASSERT_TRUE(add.ProgramShape() == nullptr);

  // muladd has the program shape defined.
  ::foo::bar::MatMulAndAddComp muladd;
  const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
  // ASSERT (not EXPECT): the pointer is dereferenced below.
  ASSERT_TRUE(muladd_shape != nullptr);
  ASSERT_EQ(muladd_shape->parameters_size(), 2);
  EXPECT_TRUE(
      ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
  EXPECT_TRUE(
      ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));

  const xla::Shape muladd_result(muladd_shape->result());
  // ASSERT (not EXPECT): the tuple elements are accessed by index below.
  ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
  ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
  const xla::Shape& muladd_result0 =
      ShapeUtil::GetTupleElementShape(muladd_result, 0);
  EXPECT_TRUE(ShapeUtil::Compatible(muladd_result0, f32_2x2));
  const xla::Shape& muladd_result1 =
      ShapeUtil::GetTupleElementShape(muladd_result, 1);
  EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2));
}
675 
676 // tf_compile with mlir_component=HloLowering does not currently support
677 // profiling, so we disable the test case here rather than creating a new test
678 // target that could allow more divergence.
679 #if !defined(MHLO_LOWERING_TEST)
TEST(TFCompileTest,HloProfiling)680 TEST(TFCompileTest, HloProfiling) {
681   Eigen::ThreadPool tp(1);
682   Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
683 
684   MatMulAndAddCompWithProfiling fn;
685   ASSERT_TRUE(fn.hlo_profiling_enabled());
686 
687   fn.set_thread_pool(&device);
688 
689   // x = [[1, 2], [3, 4]]
690   fn.arg0(0, 0) = 1;
691   fn.arg0(0, 1) = 2;
692   fn.arg0(1, 0) = 3;
693   fn.arg0(1, 1) = 4;
694 
695   // y = [[10, 20], [30, 40]]
696   fn.arg1(0, 0) = 10;
697   fn.arg1(0, 1) = 20;
698   fn.arg1(1, 0) = 30;
699   fn.arg1(1, 1) = 40;
700 
701   EXPECT_TRUE(fn.Run());
702 
703   string hlo_profile_as_string =
704       xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
705                            /*clock_rate_ghz=*/1.0);
706   VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
707 
708   // Replace Arg_n with argn when the MLIR bridge is used.
709 #if defined(ENABLE_MLIR_BRIDGE_TEST)
710   RE2::GlobalReplace(&hlo_profile_as_string, "(Arg_)([0-9].)", "arg\\2");
711 #endif
712 
713   // Strip away identifier details from the profile string to avoid this test
714   // being a change detector for xla internals. Identifiers such as '%dot.0.7'
715   // just become '%dot'.
716   RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
717   VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
718 
719   std::vector<string> hlo_profile_lines =
720       absl::StrSplit(hlo_profile_as_string, '\n');
721 
722   auto header = HasSubstr("Execution profile for");
723   auto total_cycles_profile_line = HasSubstr("[total]");
724   auto dot_profile_line = HasSubstr(
725       "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
726   auto add_profile_line = HasSubstr(
727       "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
728   auto tuple_profile_line = HasSubstr(
729       "%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
730       "f32[2,2]{1,0} %add)");
731   auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
732   auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
733 
734   EXPECT_THAT(hlo_profile_lines,
735               IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
736                             add_profile_line, tuple_profile_line}));
737 }
738 #endif
739 
740 }  // namespace
741 }  // namespace tfcompile
742 }  // namespace tensorflow
743