xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/local_client_execute_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <initializer_list>
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/sharding_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/local_service.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/transfer_manager.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/test_benchmark.h"
44 #include "tensorflow/stream_executor/device_memory_allocator.h"
45 
46 namespace xla {
47 namespace {
48 
49 using ::testing::ContainsRegex;
50 
51 class LocalClientExecuteTest : public LocalClientTestBase {
52  protected:
53   ErrorSpec error_spec_{0.0001};
54 };
55 
// Executes a computation made of a single scalar constant and checks the
// value that comes back.
XLA_TEST_F(LocalClientExecuteTest, Constant) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);
  auto computation = builder.Build().ValueOrDie();

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {});
  const Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR0Near<float>(123.f, output_literal, error_spec_);
}
65 
// Adds a scalar parameter to a scalar constant: 42 + 123 == 165.
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
  XlaBuilder builder(TestName());
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const XlaOp param = Parameter(&builder, 0, scalar_shape, "x");
  const XlaOp addend = ConstantR0<float>(&builder, 123.0f);
  Add(param, addend);

  // Device-side argument for parameter 0.
  auto arg = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
  auto computation = builder.Build().ValueOrDie();
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&arg});

  const Literal sum_literal = ShapedBufferToLiteral(sum);
  LiteralTestUtil::ExpectR0Near<float>(165.f, sum_literal, error_spec_);
}
78 
// Adds two rank-1 operands that each have zero elements and expects an empty
// rank-1 result.
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
  XlaBuilder builder(TestName());
  const Shape empty_vector_shape = ShapeUtil::MakeShape(F32, {0});
  const XlaOp param = Parameter(&builder, 0, empty_vector_shape, "x");
  const XlaOp addend = ConstantR1<float>(&builder, {});
  Add(param, addend);

  auto arg = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
  auto computation = builder.Build().ValueOrDie();
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&arg});

  const Literal sum_literal = ShapedBufferToLiteral(sum);
  LiteralTestUtil::ExpectR1Near<float>({}, sum_literal, error_spec_);
}
91 
// Adds a three-element vector parameter to a three-element constant vector.
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
  XlaBuilder builder(TestName());
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  const XlaOp param = Parameter(&builder, 0, vector_shape, "x");
  const XlaOp addend = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(param, addend);

  auto arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  auto computation = builder.Build().ValueOrDie();
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(computation, {&arg});

  const Literal sum_literal = ShapedBufferToLiteral(sum);
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f}, sum_literal,
                                       error_spec_);
}
105 
// Same vector addition as AddVectors, but run with an ExecutionProfile
// attached to the run options. Besides the numeric result, the test expects
// the profile to report a positive compute-and-transfer time.
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
  XlaBuilder builder(TestName());
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  const XlaOp param = Parameter(&builder, 0, vector_shape, "x");
  const XlaOp addend = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(param, addend);

  auto arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));

  ExecutionProfile profile;
  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_execution_profile(&profile);

  auto computation = builder.Build().ValueOrDie();
  ScopedShapedBuffer sum = ExecuteLocallyOrDie(
      computation, {&arg}, DefaultExecutableBuildOptions(), run_options);

  const Literal sum_literal = ShapedBufferToLiteral(sum);
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f}, sum_literal,
                                       error_spec_);
  EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
123 
// Adds two 2x2 arrays whose argument buffers use different layouts (one
// {0, 1}, one {1, 0}) and checks the sum for both argument orders.
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const XlaOp lhs = Parameter(&builder, 0, matrix_shape, "x");
  const XlaOp rhs = Parameter(&builder, 1, matrix_shape, "y");
  Add(lhs, rhs);
  auto computation = builder.Build().value();

  const Layout col_major = LayoutUtil::MakeLayout({0, 1});
  const Layout row_major = LayoutUtil::MakeLayout({1, 0});

  // First argument: column-major data. Confirm the device buffer kept that
  // minor-to-major order.
  auto col_major_arg = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, col_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      col_major_arg.on_device_shape().layout(), col_major));

  // Second argument: row-major data, checked the same way.
  auto row_major_arg = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{10.0f, 20.0f}, {30.0f, 40.0f}}, row_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      row_major_arg.on_device_shape().layout(), row_major));

  ScopedShapedBuffer sum =
      ExecuteLocallyOrDie(computation, {&col_major_arg, &row_major_arg});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(sum),
                                       error_spec_);

  // Swap the argument order; the expected sum is the same.
  ScopedShapedBuffer swapped_sum =
      ExecuteLocallyOrDie(computation, {&row_major_arg, &col_major_arg});
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(swapped_sum),
                                       error_spec_);
}
156 
// Runs the same 2x2 addition twice, requesting a {0, 1} result layout and
// then a {1, 0} result layout through the build options. Each run checks
// both the layout of the returned buffer and the numeric result.
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const XlaOp lhs = Parameter(&builder, 0, matrix_shape, "x");
  const XlaOp rhs = Parameter(&builder, 1, matrix_shape, "y");
  Add(lhs, rhs);
  auto computation = builder.Build().value();

  auto lhs_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto rhs_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  // Column-major result.
  ExecutableBuildOptions col_major_options = DefaultExecutableBuildOptions();
  col_major_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1}));
  ScopedShapedBuffer col_major_sum =
      ExecuteLocallyOrDie(computation, {&lhs_arg, &rhs_arg},
                          col_major_options, DefaultExecutableRunOptions());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      col_major_sum.on_device_shape().layout(),
      LayoutUtil::MakeLayout({0, 1})));
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(col_major_sum),
                                       error_spec_);

  // Row-major result.
  ExecutableBuildOptions row_major_options = DefaultExecutableBuildOptions();
  row_major_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0}));
  ScopedShapedBuffer row_major_sum =
      ExecuteLocallyOrDie(computation, {&lhs_arg, &rhs_arg},
                          row_major_options, DefaultExecutableRunOptions());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      row_major_sum.on_device_shape().layout(),
      LayoutUtil::MakeLayout({1, 0})));
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       ShapedBufferToLiteral(row_major_sum),
                                       error_spec_);
}
195 
// Returns the tuple (x, y, x) built from two 2x2 parameters and checks the
// tuple arity and each element.
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const XlaOp first = Parameter(&builder, 0, matrix_shape, "x");
  const XlaOp second = Parameter(&builder, 1, matrix_shape, "y");
  Tuple(&builder, {first, second, first});
  auto computation = builder.Build().value();

  auto first_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto second_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(computation, {&first_arg, &second_arg});

  const Shape& host_shape = output.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(3, ShapeUtil::TupleElementCount(host_shape));

  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(output_literal, {1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {2}));
}
222 
// Returns the nested tuple ((x, y, x), x) and checks the outer arity plus
// every leaf array.
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const XlaOp first = Parameter(&builder, 0, matrix_shape, "x");
  const XlaOp second = Parameter(&builder, 1, matrix_shape, "y");
  const XlaOp inner = Tuple(&builder, {first, second, first});
  Tuple(&builder, {inner, first});
  auto computation = builder.Build().value();

  auto first_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto second_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(computation, {&first_arg, &second_arg});

  const Shape& host_shape = output.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal output_literal = ShapedBufferToLiteral(output);
  // Outer element 1 is the bare x array.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {1}));
  // Outer element 0 is the inner (x, y, x) tuple.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0, 0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(output_literal, {0, 1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0, 2}));
}
252 
// Sets a tuple-shaped result layout (element 0 as {0, 1}, element 1 as
// {1, 0}) on a computation returning (x, y), then checks the values of both
// tuple elements. The same argument buffer is supplied for both parameters.
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const XlaOp first = Parameter(&builder, 0, matrix_shape, "x");
  const XlaOp second = Parameter(&builder, 1, matrix_shape, "y");
  Tuple(&builder, {first, second});

  auto arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));

  const Shape result_layout = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1}),
       ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})});
  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_result_layout(result_layout);

  auto computation = builder.Build().ValueOrDie();
  ScopedShapedBuffer output = ExecuteLocallyOrDie(
      computation, {&arg, &arg}, build_options, DefaultExecutableRunOptions());

  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {1}));
}
280 
// Passes two tuple arguments, (array, vector) and (vector, array). The
// computation returns (x.array + y.array, x.vector - y.vector) as a tuple.
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape matrix_then_vec =
      ShapeUtil::MakeTupleShape({matrix_shape, vec_shape});
  const Shape vec_then_matrix =
      ShapeUtil::MakeTupleShape({vec_shape, matrix_shape});

  XlaBuilder builder(TestName());
  const XlaOp x = Parameter(&builder, 0, matrix_then_vec, "x");
  const XlaOp y = Parameter(&builder, 1, vec_then_matrix, "y");
  const XlaOp x_matrix = GetTupleElement(x, 0);
  const XlaOp x_vec = GetTupleElement(x, 1);
  const XlaOp y_vec = GetTupleElement(y, 0);
  const XlaOp y_matrix = GetTupleElement(y, 1);
  const XlaOp matrix_sum = Add(x_matrix, y_matrix);
  const XlaOp vec_diff = Sub(x_vec, y_vec);
  Tuple(&builder, {matrix_sum, vec_diff});
  auto computation = builder.Build().value();

  auto x_arg = LiteralToShapedBuffer(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}));
  auto y_arg = LiteralToShapedBuffer(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})}));

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&x_arg, &y_arg});

  const Shape& host_shape = output.on_host_shape();
  EXPECT_TRUE(host_shape.IsTuple());
  EXPECT_EQ(2, ShapeUtil::TupleElementCount(host_shape));

  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
                                        LiteralSlice(output_literal, {1}));
}
326 
// Passes one nested tuple argument of shape ((array, vector), vector). The
// computation returns (-array, inner vector + outer vector) as a tuple.
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, vec_shape});
  const Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, vec_shape});

  XlaBuilder builder(TestName());
  const XlaOp param = Parameter(&builder, 0, outer_shape, "param");
  const XlaOp inner = GetTupleElement(param, 0);
  const XlaOp outer_vec = GetTupleElement(param, 1);
  const XlaOp inner_matrix = GetTupleElement(inner, 0);
  const XlaOp inner_vec = GetTupleElement(inner, 1);

  const XlaOp negated = Neg(inner_matrix);
  const XlaOp vec_sum = Add(inner_vec, outer_vec);
  Tuple(&builder, {negated, vec_sum});
  auto computation = builder.Build().value();

  auto inner_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
  auto outer_literal = LiteralUtil::MakeTupleFromSlices(
      {inner_literal, LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
  auto arg = LiteralToShapedBuffer(outer_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&arg});

  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
                                        LiteralSlice(output_literal, {1}));
}
365 
// Builds a computation whose parameter and result have the same tuple shape,
// runs it once on a literal, then feeds that result buffer straight back in
// as the argument of a second run. This gives extra confidence that the
// returned tuple is well formed.
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, matrix_shape});

  // (a, b) -> (-a, b + b)
  XlaBuilder builder(TestName());
  const XlaOp param = Parameter(&builder, 0, pair_shape, "param");
  const XlaOp a = GetTupleElement(param, 0);
  const XlaOp b = GetTupleElement(param, 1);
  Tuple(&builder, {Neg(a), Add(b, b)});
  auto computation = builder.Build().value();

  auto seed_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
  auto seed_arg = LiteralToShapedBuffer(seed_literal);

  // First pass over the seed argument.
  ScopedShapedBuffer first_pass = ExecuteLocallyOrDie(computation, {&seed_arg});
  Literal first_literal = ShapedBufferToLiteral(first_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
                                        LiteralSlice(first_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
                                        LiteralSlice(first_literal, {1}));

  // Second pass consumes the first pass's result buffer directly.
  ScopedShapedBuffer second_pass =
      ExecuteLocallyOrDie(computation, {&first_pass});
  Literal second_literal = ShapedBufferToLiteral(second_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
                                        LiteralSlice(second_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
                                        LiteralSlice(second_literal, {1}));
}
401 
// Runs a computation whose single parameter is a flat tuple with many
// elements. Element k of the argument is the vector {k, -k}; the computation
// adds k to it, so element k of the result should be {2k, 0}.
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
  // More elements would stress the system harder, but:
  // TODO(b/66959878): On cpu a large number of elements results in long
  //   compilation time.
  // TODO(b/66954197): On gpu a large number of elements OOMs.
  const int kElementCount = 100;

  // Every tuple element is a two-element F32 vector.
  const Shape pair_shape = ShapeUtil::MakeShape(F32, {2});
  std::vector<Shape> member_shapes(kElementCount, pair_shape);
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(member_shapes);

  XlaBuilder builder(TestName());
  const XlaOp param = Parameter(&builder, 0, tuple_shape, "param");

  // Result element k = argument element k + k.
  std::vector<XlaOp> outputs;
  outputs.reserve(kElementCount);
  for (int k = 0; k < kElementCount; ++k) {
    const XlaOp member = GetTupleElement(param, k);
    const XlaOp offset = ConstantR0<float>(&builder, k);
    outputs.push_back(Add(member, offset));
  }
  Tuple(&builder, outputs);
  auto computation = builder.Build().value();

  // Argument element k = {k, -k}.
  std::vector<Literal> members;
  members.reserve(kElementCount);
  for (int k = 0; k < kElementCount; ++k) {
    members.push_back(LiteralUtil::CreateR1<float>({1.0f * k, -1.0f * k}));
  }
  Literal arg_literal = LiteralUtil::MakeTupleOwned(std::move(members));
  auto arg = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&arg});
  Literal output_literal = ShapedBufferToLiteral(output);

  for (int k = 0; k < kElementCount; ++k) {
    LiteralTestUtil::ExpectR1Near<float>(
        {2.0f * k, 0.0f}, LiteralSlice(output_literal, {k}), error_spec_);
  }
}
449 
XLA_TEST_F(LocalClientExecuteTest,LargeNestedTuple)450 XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
451   // Construct and run a computation which takes a two-level nested tuple
452   // parameter with a large fanout.
453   const int kFanout = 40;
454 
455   // Tuple shape is full two-level tree with the given fanout.
456   const Shape element_shape = ShapeUtil::MakeShape(F32, {});
457   std::vector<Shape> element_shapes(kFanout, element_shape);
458   const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
459   std::vector<Shape> inner_tuple_shapes(kFanout, inner_tuple_shape);
460   const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes);
461 
462   XlaBuilder builder(TestName());
463   auto param = Parameter(&builder, 0, tuple_shape, "param");
464 
465   // The computation increments each leaf value by an amount equal to the leaf's
466   // ordinal position in a traversal of the tuple.
467   std::vector<XlaOp> result_elements;
468   result_elements.reserve(kFanout);
469   for (int i = 0; i < kFanout; ++i) {
470     auto outer_element = GetTupleElement(param, i);
471     std::vector<XlaOp> inner_result_elements;
472     inner_result_elements.reserve(kFanout);
473     for (int j = 0; j < kFanout; ++j) {
474       auto inner_element = GetTupleElement(outer_element, j);
475       inner_result_elements.push_back(
476           Add(inner_element, ConstantR0<float>(&builder, i * kFanout + j)));
477     }
478     result_elements.push_back(Tuple(&builder, inner_result_elements));
479   }
480   Tuple(&builder, result_elements);
481   auto computation = builder.Build().value();
482 
483   // Construct the argument to pass to the computation.
484   std::vector<Literal> outer_tuple_elements;
485   outer_tuple_elements.reserve(kFanout);
486   for (int i = 0; i < kFanout; ++i) {
487     std::vector<Literal> inner_tuple_elements;
488     inner_tuple_elements.reserve(kFanout);
489     for (int j = 0; j < kFanout; ++j) {
490       inner_tuple_elements.push_back(LiteralUtil::CreateR0<float>(i + j));
491     }
492     outer_tuple_elements.push_back(
493         LiteralUtil::MakeTupleOwned(std::move(inner_tuple_elements)));
494   }
495   auto arg_literal =
496       LiteralUtil::MakeTupleOwned(std::move(outer_tuple_elements));
497   auto arg_buffer = LiteralToShapedBuffer(arg_literal);
498 
499   ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer});
500   Literal result_literal = ShapedBufferToLiteral(result);
501 
502   for (int i = 0; i < kFanout; ++i) {
503     for (int j = 0; j < kFanout; ++j) {
504       LiteralTestUtil::ExpectR0Near<float>(i + j + i * kFanout + j,
505                                            LiteralSlice(result_literal, {i, j}),
506                                            error_spec_);
507     }
508   }
509 }
510 
// Runs a computation on a very deeply nested tuple: kTupleDepth levels of
// single-element tuples wrapped around one scalar. The computation unwraps
// the scalar, adds 42, and wraps the sum back up to the same depth.
XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
  const int kTupleDepth = 100;

  // Start from a scalar shape and wrap it in a one-element tuple per level.
  Shape nested_shape = ShapeUtil::MakeShape(F32, {});
  for (int level = 0; level < kTupleDepth; ++level) {
    nested_shape = ShapeUtil::MakeTupleShape({nested_shape});
  }

  XlaBuilder builder(TestName());
  // Peel off every tuple level to reach the scalar.
  auto scalar = Parameter(&builder, 0, nested_shape, "param");
  for (int level = 0; level < kTupleDepth; ++level) {
    scalar = GetTupleElement(scalar, 0);
  }

  // Add the constant, then re-wrap to the original depth.
  auto wrapped = Add(scalar, ConstantR0<float>(&builder, 42.0));
  for (int level = 0; level < kTupleDepth; ++level) {
    wrapped = Tuple(&builder, {wrapped});
  }
  auto computation = builder.Build().value();

  // Build the argument literal by the same wrap-per-level scheme.
  Literal arg_literal = LiteralUtil::CreateR0<float>(123.0);
  for (int level = 0; level < kTupleDepth; ++level) {
    std::vector<Literal> single;
    single.push_back(std::move(arg_literal));
    arg_literal = LiteralUtil::MakeTupleOwned(std::move(single));
  }
  auto arg = LiteralToShapedBuffer(arg_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&arg});
  Literal output_literal = ShapedBufferToLiteral(output);

  // Index of the scalar in the result: element 0 at every level.
  ShapeIndex leaf_index;
  for (int level = 0; level < kTupleDepth; ++level) {
    leaf_index.push_back(0);
  }
  LiteralTestUtil::ExpectR0Equal<float>(
      165.0, LiteralSlice(output_literal, leaf_index));
}
553 
// Supplies one argument to a computation that declares two parameters and
// expects execution to fail with an "Invalid number of arguments" message.
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
  XlaBuilder builder(TestName());
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  const XlaOp lhs = Parameter(&builder, 0, vector_shape, "x");
  const XlaOp rhs = Parameter(&builder, 1, vector_shape, "y");
  Add(lhs, rhs);

  auto only_arg =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
  auto computation = builder.Build().ValueOrDie();
  auto execute_status = ExecuteLocally(computation, {&only_arg});

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid number of arguments"));
}
570 
// Supplies a 2x2 argument to a computation whose parameter is a
// three-element vector and expects an "Invalid argument shape" failure.
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
  XlaBuilder builder(TestName());
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});
  const XlaOp param = Parameter(&builder, 0, vector_shape, "x");
  Neg(param);

  auto mismatched_arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  auto computation = builder.Build().ValueOrDie();
  auto execute_status = ExecuteLocally(computation, {&mismatched_arg});

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid argument shape"))
      << execute_status.status();
}
587 
// Requests a rank-4 result layout for a computation whose result is a 2x2
// array and expects a "not compatible with result shape" failure.
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const XlaOp param = Parameter(&builder, 0, matrix_shape, "x");
  Neg(param);

  auto arg = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));

  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32,
                                     /*dimensions=*/{1, 2, 3, 4},
                                     /*minor_to_major=*/{0, 1, 2, 3}));

  auto computation = builder.Build().ValueOrDie();
  auto execute_status = ExecuteLocally(computation, {&arg}, build_options,
                                       DefaultExecutableRunOptions());

  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("not compatible with result shape"))
      << execute_status.status();
}
609 
// Runs a trivial constant computation on every device ordinal the client
// reports. Supported ordinals must return the constant from that device;
// unsupported ordinals must fail with a "device ... not supported" message.
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().value();

  const int device_count = local_client_->device_count();
  for (int ordinal = 0; ordinal < device_count; ++ordinal) {
    // Target the same ordinal for both compilation and execution.
    ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
    build_options.set_device_ordinal(ordinal);
    ExecutableRunOptions run_options = DefaultExecutableRunOptions();
    run_options.set_device_ordinal(ordinal);

    if (local_client_->device_ordinal_supported(ordinal)) {
      auto output =
          ExecuteLocallyOrDie(computation, {}, build_options, run_options);
      EXPECT_EQ(ordinal, output.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(output));
    } else {
      auto execute_status =
          ExecuteLocally(computation, {}, build_options, run_options);
      EXPECT_FALSE(execute_status.ok());
      EXPECT_THAT(execute_status.status().error_message(),
                  ContainsRegex("device .* not supported"));
    }
  }
}
636 
// Targets a device ordinal equal to the client's device count, which is one
// past the last ordinal, and expects an "Invalid device ordinal value"
// failure.
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().value();

  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_device_ordinal(local_client_->device_count());
  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_device_ordinal(local_client_->device_count());

  auto execute_status =
      ExecuteLocally(computation, {}, build_options, run_options);
  EXPECT_FALSE(execute_status.ok());
  EXPECT_THAT(execute_status.status().error_message(),
              ContainsRegex("Invalid device ordinal value"));
}
654 
// Runs a constant computation on an explicitly supplied stream for every
// supported device on the system.
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);
  auto computation = builder.Build().value();

  for (int ordinal = 0; ordinal < local_client_->device_count(); ++ordinal) {
    // Unsupported devices are skipped entirely.
    if (local_client_->device_ordinal_supported(ordinal)) {
      se::StreamExecutor* device_executor =
          local_client_->platform()->ExecutorForDevice(ordinal).ValueOrDie();
      se::Stream device_stream(device_executor);
      device_stream.Init();

      auto run_options = DefaultExecutableRunOptions();
      run_options.set_stream(&device_stream);
      auto output = ExecuteLocallyOrDie(
          computation, {}, DefaultExecutableBuildOptions(), run_options);

      // Verify that the computation ran on the device associated with the
      // stream. This is a weak check, but stronger verification is hard.
      EXPECT_EQ(ordinal, output.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(output));
    }
  }
}
679 
680 // Disable this test on CPU because we're using the CPU as the platform
681 // which does not match the service platform.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_CPU (RunOnStreamForWrongPlatform))682 XLA_TEST_F(LocalClientExecuteTest,
683            DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
684   // Try to run a computation on a stream for a platform (CPU) which does not
685   // match the platform of the service (!= CPU).
686   se::Platform* wrong_platform =
687       se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
688           .ValueOrDie();
689   se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
690   wrong_stream.Init();
691 
692   XlaBuilder builder(TestName());
693   ConstantR0<float>(&builder, 42.0f);
694   auto execute_status = ExecuteLocally(
695       builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
696       DefaultExecutableRunOptions().set_stream(&wrong_stream));
697   EXPECT_FALSE(execute_status.ok());
698   EXPECT_THAT(execute_status.status().error_message(),
699               ContainsRegex("stream is for platform .*, but service targets"));
700 }
701 
// Checks that execution fails when the run options carry an allocator created
// for the host platform instead of the platform of the service.
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
  se::Platform* host_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  TestAllocator host_allocator(host_platform);

  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 123.0f);
  auto computation = builder.Build().ValueOrDie();

  auto status_or_result = ExecuteLocally(
      computation, {}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_allocator(&host_allocator));
  EXPECT_FALSE(status_or_result.ok());
  EXPECT_THAT(status_or_result.status().error_message(),
              ContainsRegex("allocator platform .* does not match service"));
}
719 
// Checks that running a computation on a stream that was never initialized is
// reported as an error.
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 42.0f);

  const int default_ordinal = local_client_->default_device_ordinal();
  LOG(INFO) << "default device = " << default_ordinal;
  se::StreamExecutor* default_executor =
      local_client_->platform()->ExecutorForDevice(default_ordinal)
          .ValueOrDie();

  // Init() is deliberately not called on this stream.
  se::Stream uninitialized_stream(default_executor);

  auto computation = builder.Build().ValueOrDie();
  auto status_or_result = ExecuteLocally(
      computation, {}, DefaultExecutableBuildOptions(),
      DefaultExecutableRunOptions().set_stream(&uninitialized_stream));
  EXPECT_FALSE(status_or_result.ok());
  EXPECT_THAT(status_or_result.status().error_message(),
              ContainsRegex("stream is uninitialized or in an error state"));
}
740 
// Compiles x + {2, 3, 4} through LocalClient::Compile, runs the resulting
// executable on the argument {0, 1, 2}, and checks the result is {2, 4, 6}.
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  // The argument layout is passed through the backend compiler's
  // DefaultDeviceShapeRepresentation before being handed to Compile.
  Shape argument_layout =
      local_client_->backend().compiler()->DefaultDeviceShapeRepresentation(
          ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}));
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // Fatal assertion: executables[0] is dereferenced below, so the test must
  // stop here if compilation did not yield exactly one executable.
  ASSERT_EQ(1, executables.size());

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      executables[0]->Run({&x_array}, DefaultExecutableRunOptions()).value();
  // Block on a stream borrowed for device 0 before reading the result back.
  ASSERT_IS_OK(local_client_->mutable_backend()
                   ->BorrowStream(0)
                   .ValueOrDie()
                   ->BlockHostUntilDone());

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
768 
// Compiles a computation with num_partitions set to 2, in which the final Add
// is assigned to device 1, and expects Compile to return two executables.
// Skipped when fewer than two devices are available.
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
  if (local_client_->device_count() < 2) {
    // Use the public skip macro rather than the internal GTEST_SKIP_ helper.
    GTEST_SKIP() << "requires two devices";
  }

  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto z = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
  auto r = Add(x, y);
  // Only the Add built between SetSharding and ClearSharding carries the
  // device-1 assignment.
  builder.SetSharding(sharding_builder::AssignDevice(1));
  Add(r, z);
  builder.ClearSharding();

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  ExecutableBuildOptions build_options;
  build_options.set_num_partitions(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             build_options));
  EXPECT_EQ(2, executables.size());
}
793 
// Compiles a computation that adds a scalar parameter to a large random
// constant and checks the size reported by SizeOfGeneratedCodeInBytes() for
// the compiled executable against the size of that constant.
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_INTERPRETER(SizeOfGeneratedCodeInBytes)) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
  constexpr int size = 100000;
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, {size}), 0.0, 1.0));
  auto y = ConstantLiteral(&builder, literal);
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{}, {});
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // Fatal assertion: executables.front() is dereferenced below, so the test
  // must stop here if compilation did not yield exactly one executable.
  ASSERT_EQ(1, executables.size());
  // The executable is expected to be larger than the constant it contains.
  EXPECT_GT(executables.front()->executable()->SizeOfGeneratedCodeInBytes(),
            int64_t{sizeof(float) * size});
}
816 
// Round-trips literals of several shapes through the device: each Literal is
// copied to the device as a ShapedBuffer and then copied back to a Literal,
// which must compare equal to the original.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
  auto round_trip = [this](const Literal& original) {
    TF_ASSERT_OK_AND_ASSIGN(
        auto device_buffer,
        local_client_->LiteralToShapedBuffer(
            original, local_client_->default_device_ordinal(), allocator_));
    TF_ASSERT_OK_AND_ASSIGN(
        auto read_back, local_client_->ShapedBufferToLiteral(device_buffer));
    EXPECT_EQ(original, read_back);
  };

  // Array shapes.
  round_trip(LiteralUtil::CreateR0<float>(42.0));
  round_trip(LiteralUtil::CreateR0<bool>(true));
  round_trip(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
  round_trip(LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  round_trip(LiteralUtil::CreateR2<int32_t>({{2, 1}, {4444, 56}}));

  // Null shape (empty tuple).
  round_trip(LiteralUtil::MakeTuple({}));

  // Non-nested tuples.
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(12223.0)}));
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)}));

  // Nested tuple.
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<float>({1.0, -42.0}),
            LiteralUtil::CreateR0<float>(123456.0)}),
       LiteralUtil::CreateR0<bool>(false)}));
}
856 
// Round-trips literals holding 64-bit values through the device: each Literal
// is copied to the device as a ShapedBuffer and then copied back to a Literal,
// which must compare equal to the original.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
  auto round_trip = [this](const Literal& original) {
    TF_ASSERT_OK_AND_ASSIGN(
        auto device_buffer,
        local_client_->LiteralToShapedBuffer(
            original, local_client_->default_device_ordinal(), allocator_));
    TF_ASSERT_OK_AND_ASSIGN(
        auto read_back, local_client_->ShapedBufferToLiteral(device_buffer));
    EXPECT_EQ(original, read_back);
  };

  // double, int64_t and uint64_t arrays.
  round_trip(LiteralUtil::CreateR2<double>(
      {{1.0, 2.0, 3.0}, {44.0, 0.099999999999999978, -3}}));
  round_trip(LiteralUtil::CreateR2<int64_t>({{2, 1}, {4444, 56}}));
  round_trip(
      LiteralUtil::CreateR2<uint64_t>({{20000000000ULL, 1}, {4444, 56}}));

  // Tuple mixing a double array with an int64_t scalar.
  round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
       LiteralUtil::CreateR0<int64_t>(123456789000LL)}));
}
880 
881 // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedTest))882 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
883   XlaBuilder builder(TestName());
884   const Shape shape = ShapeUtil::MakeShape(F32, {3});
885   auto in = Infeed(&builder, shape);
886   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
887   Add(in, constant);
888 
889   Literal result;
890   std::unique_ptr<tensorflow::Thread> thread(
891       tensorflow::Env::Default()->StartThread(
892           tensorflow::ThreadOptions(), "execute_thread", [&] {
893             result = ShapedBufferToLiteral(ExecuteLocallyOrDie(
894                 builder.Build().ValueOrDie(), /*arguments=*/{}));
895           }));
896 
897   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
898       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
899       local_client_->default_device_ordinal()));
900 
901   // Join the thread.
902   thread.reset();
903 
904   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
905 }
906 
907 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest,DISABLED_ON_INTERPRETER (InfeedOutfeedTest))908 XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
909   XlaBuilder builder(TestName());
910   const Shape shape = ShapeUtil::MakeShape(F32, {3});
911   auto in = Infeed(&builder, shape);
912   auto constant = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
913   auto sum = Add(in, constant);
914   Outfeed(sum, shape, /*outfeed_config=*/"");
915 
916   std::unique_ptr<tensorflow::Thread> thread(
917       tensorflow::Env::Default()->StartThread(
918           tensorflow::ThreadOptions(), "execute_thread",
919           [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); }));
920 
921   ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
922       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
923       local_client_->default_device_ordinal()));
924 
925   Literal result(shape);
926   ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
927       local_client_->default_device_ordinal(), &result));
928 
929   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
930 }
931 
932 // Benchmark that measures the overhead of the LocalClient API when running a
933 // trivial computation
BM_LocalClientOverhead(::testing::benchmark::State & state)934 void BM_LocalClientOverhead(::testing::benchmark::State& state) {
935   se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
936   auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
937   se::StreamExecutorMemoryAllocator allocator(platform, executors);
938   LocalClient* client =
939       ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
940   auto* transfer_manager =
941       TransferManager::GetForPlatform(platform).ValueOrDie();
942   int device_ordinal = client->default_device_ordinal();
943 
944   // Use a tiny add operation as the computation.
945   XlaBuilder builder("Add");
946   auto shape = ShapeUtil::MakeShape(F32, {2, 3});
947   auto x = Parameter(&builder, 0, shape, "x");
948   Add(x, x);
949   auto computation = builder.Build().value();
950 
951   auto buffer =
952       transfer_manager
953           ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
954           .value();
955   auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
956   auto stream =
957       client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
958   ASSERT_IS_OK(
959       transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
960 
961   const int kWarmups = 2;
962 
963   TF_ASSERT_OK_AND_ASSIGN(
964       auto executables, client->Compile(computation, {&buffer.on_host_shape()},
965                                         ExecutableBuildOptions()));
966   std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
967 
968   ExecutableRunOptions run_options;
969   run_options.set_allocator(&allocator).set_stream(stream.get());
970 
971   for (int i = 0; i < kWarmups; ++i) {
972     auto result = executable->Run({&buffer}, run_options);
973     ASSERT_IS_OK(result);
974   }
975 
976   for (auto s : state) {
977     auto result = executable->Run({&buffer}, run_options);
978     ASSERT_IS_OK(result);
979   }
980 }
981 
982 BENCHMARK(BM_LocalClientOverhead);
983 
984 }  // namespace
985 }  // namespace xla
986