1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <initializer_list>
17 #include <memory>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/client/client_library.h"
21 #include "tensorflow/compiler/xla/client/local_client.h"
22 #include "tensorflow/compiler/xla/client/sharding_builder.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/local_service.h"
27 #include "tensorflow/compiler/xla/service/platform_util.h"
28 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
29 #include "tensorflow/compiler/xla/service/transfer_manager.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/test_helpers.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
36 #include "tensorflow/compiler/xla/tests/test_macros.h"
37 #include "tensorflow/compiler/xla/tests/test_utils.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
42 #include "tensorflow/core/platform/test.h"
43 #include "tensorflow/core/platform/test_benchmark.h"
44 #include "tensorflow/stream_executor/device_memory_allocator.h"
45
46 namespace xla {
47 namespace {
48
49 using ::testing::ContainsRegex;
50
51 class LocalClientExecuteTest : public LocalClientTestBase {
52 protected:
53 ErrorSpec error_spec_{0.0001};
54 };
55
XLA_TEST_F(LocalClientExecuteTest, Constant) {
  // Runs a parameterless computation whose root is the F32 scalar 123.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 123.0f);
  const auto computation = b.Build().value();

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {});
  const Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR0Near<float>(123.f, output_literal, error_spec_);
}
65
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
  // Root of the computation is (parameter 0) + 123 on F32 scalars.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
  auto addend = ConstantR0<float>(&b, 123.0f);
  Add(param, addend);

  auto input = LiteralToShapedBuffer(LiteralUtil::CreateR0<float>(42.0f));
  const auto computation = b.Build().value();
  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input});

  // 42 + 123 == 165.
  const Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR0Near<float>(165.f, output_literal, error_spec_);
}
78
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
  // Adds a zero-element F32 vector parameter to a zero-element constant; the
  // result is expected to be a zero-element vector as well.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {0}), "x");
  auto empty_constant = ConstantR1<float>(&b, {});
  Add(param, empty_constant);

  auto input = LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({}));
  const auto computation = b.Build().value();
  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input});

  const Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR1Near<float>({}, output_literal, error_spec_);
}
91
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
  // Root of the computation is (parameter 0) + {2, 3, 4} on f32[3].
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto addend = ConstantR1<float>(&b, {2.0f, 3.0f, 4.0f});
  Add(param, addend);

  auto input =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  const auto computation = b.Build().value();
  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input});

  // {0, 1, 2} + {2, 3, 4} == {2, 4, 6}.
  const Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f}, output_literal,
                                       error_spec_);
}
105
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
  // Same vector addition as AddVectors, but run with an ExecutionProfile
  // attached to the run options so that timing information is collected.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto addend = ConstantR1<float>(&b, {2.0f, 3.0f, 4.0f});
  Add(param, addend);

  auto input =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));

  ExecutionProfile profile;
  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_execution_profile(&profile);

  const auto computation = b.Build().value();
  ScopedShapedBuffer output = ExecuteLocallyOrDie(
      computation, {&input}, DefaultExecutableBuildOptions(), run_options);

  const Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR1Near<float>({2.0f, 4.0f, 6.0f}, output_literal,
                                       error_spec_);
  // The profile must have recorded a non-zero duration.
  EXPECT_GT(profile.compute_and_transfer_time_ns(), 0);
}
123
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
  // Adds two 2x2 arrays whose device buffers use different layouts and checks
  // that the sum is the same regardless of which layout goes to which
  // parameter.
  XlaBuilder b(TestName());
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Add(lhs, rhs);
  auto computation = b.Build().value();

  const Layout col_major = LayoutUtil::MakeLayout({0, 1});
  const Layout row_major = LayoutUtil::MakeLayout({1, 0});

  // First operand lives in a col-major buffer.
  auto col_major_buffer = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{1.0f, 2.0f}, {3.0f, 4.0f}}, col_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      col_major_buffer.on_device_shape().layout(), col_major));

  // Second operand lives in a row-major buffer.
  auto row_major_buffer = LiteralToShapedBuffer(LiteralUtil::CreateR2WithLayout(
      {{10.0f, 20.0f}, {30.0f, 40.0f}}, row_major));
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      row_major_buffer.on_device_shape().layout(), row_major));

  // (col-major, row-major) argument order.
  ScopedShapedBuffer sum =
      ExecuteLocallyOrDie(computation, {&col_major_buffer, &row_major_buffer});
  const Literal sum_literal = ShapedBufferToLiteral(sum);
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       sum_literal, error_spec_);

  // (row-major, col-major) argument order; addition commutes, so the expected
  // values do not change.
  ScopedShapedBuffer swapped_sum =
      ExecuteLocallyOrDie(computation, {&row_major_buffer, &col_major_buffer});
  const Literal swapped_sum_literal = ShapedBufferToLiteral(swapped_sum);
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       swapped_sum_literal, error_spec_);
}
156
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
  // Runs the same 2x2 addition twice, requesting a different result layout
  // through the build options each time, and checks both the layout of the
  // returned buffer and its values.
  XlaBuilder b(TestName());
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Add(lhs, rhs);
  auto computation = b.Build().value();

  auto lhs_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto rhs_buffer = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  // Col-major result layout.
  ExecutableBuildOptions col_major_options = DefaultExecutableBuildOptions();
  col_major_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1}));
  ScopedShapedBuffer col_major_result =
      ExecuteLocallyOrDie(computation, {&lhs_buffer, &rhs_buffer},
                          col_major_options, DefaultExecutableRunOptions());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      col_major_result.on_device_shape().layout(),
      LayoutUtil::MakeLayout({0, 1})));
  const Literal col_major_literal = ShapedBufferToLiteral(col_major_result);
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       col_major_literal, error_spec_);

  // Row-major result layout.
  ExecutableBuildOptions row_major_options = DefaultExecutableBuildOptions();
  row_major_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0}));
  ScopedShapedBuffer row_major_result =
      ExecuteLocallyOrDie(computation, {&lhs_buffer, &rhs_buffer},
                          row_major_options, DefaultExecutableRunOptions());
  EXPECT_TRUE(Layout::Equal().MinorToMajorOnly()(
      row_major_result.on_device_shape().layout(),
      LayoutUtil::MakeLayout({1, 0})));
  const Literal row_major_literal = ShapedBufferToLiteral(row_major_result);
  LiteralTestUtil::ExpectR2Near<float>({{11.0f, 22.0f}, {33.0f, 44.0f}},
                                       row_major_literal, error_spec_);
}
195
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
  // The computation returns its two 2x2 parameters packed into the 3-tuple
  // (x, y, x).
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&builder, {x, y, x});
  auto computation = builder.Build().value();

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_array, &y_array});

  // These checks are fatal on purpose: the element-count query and the
  // tuple-indexed slices below rely on the result being a 3-element tuple, so
  // there is no point in continuing (and risking a crash) if it is not.
  ASSERT_TRUE(result.on_host_shape().IsTuple());
  ASSERT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape()));

  // Elements 0 and 2 are both x; element 1 is y.
  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(result_literal, {1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {2}));
}
222
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
  // The computation returns the nested tuple ((x, y, x), x).
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto y = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  auto inner_tuple = Tuple(&builder, {x, y, x});
  Tuple(&builder, {inner_tuple, x});
  auto computation = builder.Build().value();

  auto x_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));
  auto y_array = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{10.0f, 20.0f}, {30.0f, 40.0f}}));

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_array, &y_array});

  // These checks are fatal on purpose: the element-count query and the
  // tuple-indexed slices below rely on the result being a 2-element tuple, so
  // there is no point in continuing (and risking a crash) if it is not.
  ASSERT_TRUE(result.on_host_shape().IsTuple());
  ASSERT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));

  Literal result_literal = ShapedBufferToLiteral(result);
  // Outer element 1 is x.
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {1}));
  // Outer element 0 is the inner tuple (x, y, x).
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0, 0}));
  LiteralTestUtil::ExpectR2Equal<float>({{10.0f, 20.0f}, {30.0f, 40.0f}},
                                        LiteralSlice(result_literal, {0, 1}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(result_literal, {0, 2}));
}
252
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
  // Checks that a result layout can be requested for a computation whose
  // output is a tuple: element 0 is requested col-major, element 1 row-major.
  XlaBuilder b(TestName());
  auto first = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  auto second = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
  Tuple(&b, {first, second});

  // The same buffer is passed for both parameters.
  auto input = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}}));

  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  const Shape requested_layout = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{0, 1}),
       ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2},
                                      /*minor_to_major=*/{1, 0})});
  build_options.set_result_layout(requested_layout);

  const auto computation = b.Build().value();
  ScopedShapedBuffer output =
      ExecuteLocallyOrDie(computation, {&input, &input}, build_options,
                          DefaultExecutableRunOptions());

  // Both tuple elements must hold the input values irrespective of layout.
  Literal output_literal = ShapedBufferToLiteral(output);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{1.0f, 2.0f}, {3.0f, 4.0f}},
                                        LiteralSlice(output_literal, {1}));
}
280
XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
  const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {3});

  // Parameter 0 is (array, vector); parameter 1 is (vector, array).
  const Shape tuple_shape0 =
      ShapeUtil::MakeTupleShape({array_shape, vector_shape});
  const Shape tuple_shape1 =
      ShapeUtil::MakeTupleShape({vector_shape, array_shape});

  // The computation adds the two array elements, subtracts the vector element
  // of y from the vector element of x, and returns (array_sum, vector_diff).
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, tuple_shape0, "x");
  auto y = Parameter(&builder, 1, tuple_shape1, "y");
  auto x_0 = GetTupleElement(x, 0);
  auto x_1 = GetTupleElement(x, 1);
  auto y_0 = GetTupleElement(y, 0);
  auto y_1 = GetTupleElement(y, 1);
  auto array_sum = Add(x_0, y_1);
  auto vector_diff = Sub(x_1, y_0);
  Tuple(&builder, {array_sum, vector_diff});
  auto computation = builder.Build().value();

  auto x_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})});
  auto y_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({2.0, 4.0, 6.0}),
       LiteralUtil::CreateR2<float>({{55.0, 44.0}, {33.0, 22.0}})});

  auto x_buffer = LiteralToShapedBuffer(x_literal);
  auto y_buffer = LiteralToShapedBuffer(y_literal);

  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer});

  // These checks are fatal on purpose: the element-count query and the
  // tuple-indexed slices below rely on the result being a 2-element tuple, so
  // there is no point in continuing (and risking a crash) if it is not.
  ASSERT_TRUE(result.on_host_shape().IsTuple());
  ASSERT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape()));

  Literal result_literal = ShapedBufferToLiteral(result);
  LiteralTestUtil::ExpectR2Equal<float>({{56.0f, 46.0f}, {36.0f, 26.0f}},
                                        LiteralSlice(result_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({40.0f, 71.0f, 117.0f},
                                        LiteralSlice(result_literal, {1}));
}
326
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
  // The single parameter has shape ((f32[2,2], f32[3]), f32[3]).
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {3});
  const Shape inner_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, vec_shape});
  const Shape outer_shape = ShapeUtil::MakeTupleShape({inner_shape, vec_shape});

  // The computation returns (-inner_matrix, inner_vector + outer_vector).
  XlaBuilder b(TestName());
  auto arg = Parameter(&b, 0, outer_shape, "param");
  auto inner = GetTupleElement(arg, 0);
  auto inner_matrix = GetTupleElement(inner, 0);
  auto inner_vec = GetTupleElement(inner, 1);
  auto outer_vec = GetTupleElement(arg, 1);
  auto negated_matrix = Neg(inner_matrix);
  auto summed_vec = Add(inner_vec, outer_vec);
  Tuple(&b, {negated_matrix, summed_vec});
  auto computation = b.Build().value();

  auto input_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
            LiteralUtil::CreateR1<float>({42.0, 75.0, 123.0})}),
       LiteralUtil::CreateR1<float>({222.0, -2.0, 10.0})});
  auto input_buffer = LiteralToShapedBuffer(input_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input_buffer});
  Literal output_literal = ShapedBufferToLiteral(output);

  // Element 0: negated matrix. Element 1: element-wise sum of the vectors.
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4}},
                                        LiteralSlice(output_literal, {0}));
  LiteralTestUtil::ExpectR1Equal<float>({264.0, 73.0, 133.0},
                                        LiteralSlice(output_literal, {1}));
}
365
XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
  // The computation's parameter and result share one tuple shape, so the
  // result buffer of a first run can be used directly as the argument of a
  // second run. This gives extra confidence that the returned tuple buffer is
  // well formed.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({matrix_shape, matrix_shape});

  // (a, b) -> (-a, b + b).
  XlaBuilder b(TestName());
  auto arg = Parameter(&b, 0, pair_shape, "param");
  auto first = GetTupleElement(arg, 0);
  auto second = GetTupleElement(arg, 1);
  Tuple(&b, {Neg(first), Add(second, second)});
  auto computation = b.Build().value();

  auto input_literal = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
       LiteralUtil::CreateR2<float>({{11.0, 3.0}, {4.0, 5.0}})});
  auto input_buffer = LiteralToShapedBuffer(input_literal);

  // First application.
  ScopedShapedBuffer first_pass =
      ExecuteLocallyOrDie(computation, {&input_buffer});
  Literal first_pass_literal = ShapedBufferToLiteral(first_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{-1.0, -2.0}, {-3.0, -4.0}},
                                        LiteralSlice(first_pass_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{22.0, 6.0}, {8.0, 10}},
                                        LiteralSlice(first_pass_literal, {1}));

  // Second application, fed with the device buffer from the first one: the
  // negation cancels out and the second element doubles again.
  ScopedShapedBuffer second_pass =
      ExecuteLocallyOrDie(computation, {&first_pass});
  Literal second_pass_literal = ShapedBufferToLiteral(second_pass);
  LiteralTestUtil::ExpectR2Equal<float>({{1.0, 2.0}, {3.0, 4.0}},
                                        LiteralSlice(second_pass_literal, {0}));
  LiteralTestUtil::ExpectR2Equal<float>({{44.0, 12.0}, {16.0, 20}},
                                        LiteralSlice(second_pass_literal, {1}));
}
401
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
  // Exercises a tuple parameter (and tuple result) with many elements.
  //
  // More elements would make this a more strenuous test, but:
  // TODO(b/66959878): On cpu a large number of elements results in long
  // compilation time.
  // TODO(b/66954197): On gpu a large number of elements OOMs.
  const int kElementCount = 100;

  // The parameter is a tuple of kElementCount f32[2] vectors.
  const std::vector<Shape> leaf_shapes(kElementCount,
                                       ShapeUtil::MakeShape(F32, {2}));
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(leaf_shapes);

  XlaBuilder b(TestName());
  auto arg = Parameter(&b, 0, tuple_shape, "param");

  // Output element k is input element k plus the scalar k.
  std::vector<XlaOp> outputs;
  outputs.reserve(kElementCount);
  for (int k = 0; k < kElementCount; ++k) {
    auto leaf = GetTupleElement(arg, k);
    outputs.push_back(Add(leaf, ConstantR0<float>(&b, k)));
  }
  Tuple(&b, outputs);
  auto computation = b.Build().value();

  // Input element k is the vector {k, -k}.
  std::vector<Literal> input_leaves;
  input_leaves.reserve(kElementCount);
  for (int k = 0; k < kElementCount; ++k) {
    input_leaves.push_back(LiteralUtil::CreateR1<float>({1.0f * k, -1.0f * k}));
  }
  Literal input_literal = LiteralUtil::MakeTupleOwned(std::move(input_leaves));
  auto input_buffer = LiteralToShapedBuffer(input_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input_buffer});
  Literal output_literal = ShapedBufferToLiteral(output);

  // {k, -k} + k == {2k, 0}.
  for (int k = 0; k < kElementCount; ++k) {
    LiteralTestUtil::ExpectR1Near<float>(
        {2.0f * k, 0.0f}, LiteralSlice(output_literal, {k}), error_spec_);
  }
}
449
XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
  // Runs a computation over a two-level nested tuple parameter with a large
  // fanout at both levels.
  const int kFanout = 40;

  // Shape: a tuple of kFanout tuples, each holding kFanout F32 scalars.
  const std::vector<Shape> leaf_shapes(kFanout, ShapeUtil::MakeShape(F32, {}));
  const Shape inner_shape = ShapeUtil::MakeTupleShape(leaf_shapes);
  const std::vector<Shape> inner_shapes(kFanout, inner_shape);
  const Shape outer_shape = ShapeUtil::MakeTupleShape(inner_shapes);

  XlaBuilder b(TestName());
  auto arg = Parameter(&b, 0, outer_shape, "param");

  // Leaf (row, col) is incremented by its ordinal position in a traversal of
  // the tuple, i.e. by row * kFanout + col.
  std::vector<XlaOp> outer_outputs;
  outer_outputs.reserve(kFanout);
  for (int row = 0; row < kFanout; ++row) {
    auto inner = GetTupleElement(arg, row);
    std::vector<XlaOp> inner_outputs;
    inner_outputs.reserve(kFanout);
    for (int col = 0; col < kFanout; ++col) {
      auto leaf = GetTupleElement(inner, col);
      inner_outputs.push_back(
          Add(leaf, ConstantR0<float>(&b, row * kFanout + col)));
    }
    outer_outputs.push_back(Tuple(&b, inner_outputs));
  }
  Tuple(&b, outer_outputs);
  auto computation = b.Build().value();

  // Build the argument: leaf (row, col) holds the value row + col.
  std::vector<Literal> outer_leaves;
  outer_leaves.reserve(kFanout);
  for (int row = 0; row < kFanout; ++row) {
    std::vector<Literal> inner_leaves;
    inner_leaves.reserve(kFanout);
    for (int col = 0; col < kFanout; ++col) {
      inner_leaves.push_back(LiteralUtil::CreateR0<float>(row + col));
    }
    outer_leaves.push_back(
        LiteralUtil::MakeTupleOwned(std::move(inner_leaves)));
  }
  auto input_literal = LiteralUtil::MakeTupleOwned(std::move(outer_leaves));
  auto input_buffer = LiteralToShapedBuffer(input_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input_buffer});
  Literal output_literal = ShapedBufferToLiteral(output);

  // Expected leaf value: input (row + col) plus increment (row * kFanout +
  // col).
  for (int row = 0; row < kFanout; ++row) {
    for (int col = 0; col < kFanout; ++col) {
      LiteralTestUtil::ExpectR0Near<float>(
          row + col + row * kFanout + col,
          LiteralSlice(output_literal, {row, col}), error_spec_);
    }
  }
}
510
XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
  // Runs a computation over a very deeply nested tuple. There is no fanout:
  // every level is a one-element tuple, with a single scalar at the bottom.
  const int kTupleDepth = 100;

  // Wrap an F32 scalar shape in kTupleDepth levels of one-element tuples.
  Shape nested_shape = ShapeUtil::MakeShape(F32, {});
  for (int level = 0; level < kTupleDepth; ++level) {
    nested_shape = ShapeUtil::MakeTupleShape({nested_shape});
  }

  // Unwrap the parameter down to its scalar...
  XlaBuilder b(TestName());
  auto unwrapped = Parameter(&b, 0, nested_shape, "param");
  for (int level = 0; level < kTupleDepth; ++level) {
    unwrapped = GetTupleElement(unwrapped, 0);
  }

  // ...add 42 to it, and wrap the sum back up to the same depth.
  auto wrapped = Add(unwrapped, ConstantR0<float>(&b, 42.0));
  for (int level = 0; level < kTupleDepth; ++level) {
    wrapped = Tuple(&b, {wrapped});
  }
  auto computation = b.Build().value();

  // Build the argument: the scalar 123 nested kTupleDepth levels deep.
  Literal input_literal = LiteralUtil::CreateR0<float>(123.0);
  for (int level = 0; level < kTupleDepth; ++level) {
    std::vector<Literal> single_element;
    single_element.push_back(std::move(input_literal));
    input_literal = LiteralUtil::MakeTupleOwned(std::move(single_element));
  }
  auto input_buffer = LiteralToShapedBuffer(input_literal);

  ScopedShapedBuffer output = ExecuteLocallyOrDie(computation, {&input_buffer});
  Literal output_literal = ShapedBufferToLiteral(output);

  // The scalar sits at index {0, 0, ..., 0} (kTupleDepth zeros).
  ShapeIndex leaf_index;
  for (int level = 0; level < kTupleDepth; ++level) {
    leaf_index.push_back(0);
  }
  // 123 + 42 == 165.
  LiteralTestUtil::ExpectR0Equal<float>(
      165.0, LiteralSlice(output_literal, leaf_index));
}
553
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
  // A two-parameter computation is run with a single argument; execution must
  // be rejected with a descriptive error.
  XlaBuilder b(TestName());
  auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3}), "y");
  Add(lhs, rhs);

  auto only_argument =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f}));
  const auto computation = b.Build().value();
  auto result_or = ExecuteLocally(computation, {&only_argument});

  EXPECT_FALSE(result_or.ok());
  const auto& status = result_or.status();
  EXPECT_THAT(status.error_message(),
              ContainsRegex("Invalid number of arguments"));
}
570
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
  // The computation expects an f32[3] parameter but is handed a 2x2 array;
  // execution must be rejected with a descriptive error.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  Neg(param);

  auto wrong_shape_argument = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));
  const auto computation = b.Build().value();
  auto result_or = ExecuteLocally(computation, {&wrong_shape_argument});

  EXPECT_FALSE(result_or.ok());
  const auto& status = result_or.status();
  // Print the full status on mismatch to ease debugging.
  EXPECT_THAT(status.error_message(), ContainsRegex("Invalid argument shape"))
      << status;
}
587
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
  // Requests a rank-4 result layout for a computation that produces a 2x2
  // array; execution must be rejected with a descriptive error.
  XlaBuilder b(TestName());
  auto param = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
  Neg(param);

  auto input = LiteralToShapedBuffer(
      LiteralUtil::CreateR2<float>({{0.0f, 1.0f}, {2.0f, 3.0f}}));

  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_result_layout(
      ShapeUtil::MakeShapeWithLayout(F32,
                                     /*dimensions=*/{1, 2, 3, 4},
                                     /*minor_to_major=*/{0, 1, 2, 3}));

  const auto computation = b.Build().value();
  auto result_or = ExecuteLocally(computation, {&input}, build_options,
                                  DefaultExecutableRunOptions());

  EXPECT_FALSE(result_or.ok());
  const auto& status = result_or.status();
  // Print the full status on mismatch to ease debugging.
  EXPECT_THAT(status.error_message(),
              ContainsRegex("not compatible with result shape"))
      << status;
}
609
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
  // Runs a trivial computation on every device ordinal the client reports.
  // Supported ordinals must produce the right value on the right device;
  // unsupported ones must fail with the expected error.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 42.0f);
  auto computation = b.Build().value();

  for (int ordinal = 0; ordinal < local_client_->device_count(); ++ordinal) {
    // Target the same ordinal at both compile time and run time.
    ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
    build_options.set_device_ordinal(ordinal);
    ExecutableRunOptions run_options = DefaultExecutableRunOptions();
    run_options.set_device_ordinal(ordinal);

    if (local_client_->device_ordinal_supported(ordinal)) {
      auto output =
          ExecuteLocallyOrDie(computation, {}, build_options, run_options);
      EXPECT_EQ(ordinal, output.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(output));
      continue;
    }

    // Unsupported device: execution has to report an error.
    auto result_or =
        ExecuteLocally(computation, {}, build_options, run_options);
    EXPECT_FALSE(result_or.ok());
    EXPECT_THAT(result_or.status().error_message(),
                ContainsRegex("device .* not supported"));
  }
}
636
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
  // Targets a device ordinal that does not exist and expects execution to be
  // rejected.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 42.0f);
  auto computation = b.Build().value();

  // Ordinals are iterated as [0, device_count()) elsewhere in this file, so
  // device_count() itself is used as the out-of-range value.
  ExecutableBuildOptions build_options = DefaultExecutableBuildOptions();
  build_options.set_device_ordinal(local_client_->device_count());
  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_device_ordinal(local_client_->device_count());

  auto result_or = ExecuteLocally(computation, {}, build_options, run_options);
  EXPECT_FALSE(result_or.ok());
  EXPECT_THAT(result_or.status().error_message(),
              ContainsRegex("Invalid device ordinal value"));
}
654
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
  // For every supported device, runs a trivial computation on a stream that
  // was explicitly created for that device.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 42.0f);
  auto computation = b.Build().value();

  for (int ordinal = 0; ordinal < local_client_->device_count(); ++ordinal) {
    if (local_client_->device_ordinal_supported(ordinal)) {
      se::StreamExecutor* device_executor =
          local_client_->platform()->ExecutorForDevice(ordinal).ValueOrDie();
      se::Stream device_stream(device_executor);
      device_stream.Init();

      ExecutableRunOptions run_options = DefaultExecutableRunOptions();
      run_options.set_stream(&device_stream);
      auto output = ExecuteLocallyOrDie(
          computation, {}, DefaultExecutableBuildOptions(), run_options);

      // Check that the computation ran on the device associated with the
      // stream. This is a weak check, but stronger verification is hard.
      EXPECT_EQ(ordinal, output.device_ordinal());
      LiteralTestUtil::ExpectR0Equal<float>(42.0f,
                                            ShapedBufferToLiteral(output));
    }
  }
}
679
680 // Disable this test on CPU because we're using the CPU as the platform
681 // which does not match the service platform.
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(RunOnStreamForWrongPlatform)) {
  // Hands the service a stream that belongs to the host (CPU) platform while
  // the service itself targets a different platform; execution must be
  // rejected.
  se::Platform* host_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  se::Stream host_stream(host_platform->ExecutorForDevice(0).ValueOrDie());
  host_stream.Init();

  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 42.0f);
  const auto computation = b.Build().value();

  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_stream(&host_stream);
  auto result_or = ExecuteLocally(computation, {},
                                  DefaultExecutableBuildOptions(), run_options);

  EXPECT_FALSE(result_or.ok());
  EXPECT_THAT(result_or.status().error_message(),
              ContainsRegex("stream is for platform .*, but service targets"));
}
701
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_CPU(AllocatorDoesNotMatchPlatform)) {
  // Supplies an allocator constructed for the host platform through the run
  // options and expects execution to be rejected because it does not match
  // the service's platform.
  se::Platform* host_platform =
      se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
          .ValueOrDie();
  TestAllocator host_allocator(host_platform);

  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 123.0f);
  const auto computation = b.Build().value();

  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_allocator(&host_allocator);
  auto result_or = ExecuteLocally(computation, {},
                                  DefaultExecutableBuildOptions(), run_options);

  EXPECT_FALSE(result_or.ok());
  EXPECT_THAT(result_or.status().error_message(),
              ContainsRegex("allocator platform .* does not match service"));
}
719
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
  // Passes a stream on which Init() was never called and expects execution to
  // be rejected.
  XlaBuilder b(TestName());
  ConstantR0<float>(&b, 42.0f);

  LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
  se::StreamExecutor* default_executor =
      local_client_->platform()
          ->ExecutorForDevice(local_client_->default_device_ordinal())
          .ValueOrDie();
  // Deliberately left uninitialized: no call to Init().
  se::Stream uninitialized_stream(default_executor);

  ExecutableRunOptions run_options = DefaultExecutableRunOptions();
  run_options.set_stream(&uninitialized_stream);
  const auto computation = b.Build().value();
  auto result_or = ExecuteLocally(computation, {},
                                  DefaultExecutableBuildOptions(), run_options);

  EXPECT_FALSE(result_or.ok());
  EXPECT_THAT(result_or.status().error_message(),
              ContainsRegex("stream is uninitialized or in an error state"));
}
740
// Compiles `x + {2, 3, 4}` through LocalClient::Compile, runs the resulting
// executable on the input {0, 1, 2}, and checks the result is {2, 4, 6}.
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  Add(x, y);

  Shape argument_layout =
      local_client_->backend().compiler()->DefaultDeviceShapeRepresentation(
          ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0}));
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // Fatal assertion: executables[0] is dereferenced below, so the test must
  // stop here rather than index into an unexpectedly empty result.
  ASSERT_EQ(1, executables.size());

  auto x_array =
      LiteralToShapedBuffer(LiteralUtil::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      executables[0]->Run({&x_array}, DefaultExecutableRunOptions()).value();
  // Block on the stream borrowed for device 0 before reading the result back.
  ASSERT_IS_OK(local_client_->mutable_backend()
                   ->BorrowStream(0)
                   .ValueOrDie()
                   ->BlockHostUntilDone());

  LiteralTestUtil::ExpectR1Near<float>(
      {2.0f, 4.0f, 6.0f}, ShapedBufferToLiteral(result), error_spec_);
}
768
// Compiles a computation with num_partitions set to 2, where the final Add is
// assigned to device 1 via sharding, and expects two executables back.
// Skipped when the client reports fewer than two devices.
XLA_TEST_F(LocalClientExecuteTest, CompilePartitionedExecutable) {
  if (local_client_->device_count() < 2) {
    GTEST_SKIP_("requires two devices");
  }

  XlaBuilder builder(TestName());
  auto input = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto first_addend = ConstantR1<float>(&builder, {2.0f, 3.0f, 4.0f});
  auto second_addend = ConstantR1<float>(&builder, {5.0f, 6.0f, 7.0f});
  auto partial_sum = Add(input, first_addend);
  // Only the final Add is built while the device-1 sharding is active.
  builder.SetSharding(sharding_builder::AssignDevice(1));
  Add(partial_sum, second_addend);
  builder.ClearSharding();

  ExecutableBuildOptions partitioned_options;
  partitioned_options.set_num_partitions(2);
  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{3}, {0});
  TF_ASSERT_OK_AND_ASSIGN(
      auto compiled,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             partitioned_options));
  EXPECT_EQ(2, compiled.size());
}
793
// Compiles a computation that embeds a 100000-element F32 constant and checks
// that the reported generated-code size exceeds the byte size of that
// constant. Disabled on the interpreter backend.
XLA_TEST_F(LocalClientExecuteTest,
           DISABLED_ON_INTERPRETER(SizeOfGeneratedCodeInBytes)) {
  XlaBuilder builder(TestName());
  auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {}), "x");
  constexpr int size = 100000;
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, {size}), 0.0, 1.0));
  auto y = ConstantLiteral(&builder, literal);
  Add(x, y);

  Shape argument_layout =
      ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{}, {});
  TF_ASSERT_OK_AND_ASSIGN(
      auto executables,
      local_client_->Compile(builder.Build().ValueOrDie(), {&argument_layout},
                             ExecutableBuildOptions()));
  // Fatal assertion: executables.front() is dereferenced below, so the test
  // must stop here rather than touch an unexpectedly empty result.
  ASSERT_EQ(1, executables.size());
  // The executable should be at least as large as the constant it contains.
  EXPECT_GT(executables.front()->executable()->SizeOfGeneratedCodeInBytes(),
            int64_t{sizeof(float) * size});
}
816
// Round-trips literals of several shapes (arrays, the empty tuple, flat and
// nested tuples) through LiteralToShapedBuffer and ShapedBufferToLiteral, and
// expects each literal to come back equal to the original.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) {
  // Transfers `original` to the default device and back, then compares.
  auto expect_round_trip = [this](const Literal& original) {
    const int device_ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(auto on_device,
                            local_client_->LiteralToShapedBuffer(
                                original, device_ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto round_tripped,
                            local_client_->ShapedBufferToLiteral(on_device));
    EXPECT_EQ(original, round_tripped);
  };

  // Scalars.
  expect_round_trip(LiteralUtil::CreateR0<float>(42.0));
  expect_round_trip(LiteralUtil::CreateR0<bool>(true));

  // Rank-1 and rank-2 arrays.
  expect_round_trip(LiteralUtil::CreateR1<float>({1.0, 42.0, 744.4}));
  expect_round_trip(
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}}));
  expect_round_trip(LiteralUtil::CreateR2<int32_t>({{2, 1}, {4444, 56}}));

  // Empty tuple (null shape).
  expect_round_trip(LiteralUtil::MakeTuple({}));

  // Flat tuples with one and two elements.
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR0<float>(12223.0)}));
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<float>({1.0, -42.0}),
       LiteralUtil::CreateR0<float>(123456.0)}));

  // A tuple whose first element is itself a tuple.
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::MakeTupleFromSlices(
           {LiteralUtil::CreateR1<float>({1.0, -42.0}),
            LiteralUtil::CreateR0<float>(123456.0)}),
       LiteralUtil::CreateR0<bool>(false)}));
}
856
// Round-trips literals holding 64-bit element types (double, int64_t,
// uint64_t, and a tuple mixing them) through LiteralToShapedBuffer and
// ShapedBufferToLiteral, and expects each to come back equal to the original.
XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
  // Transfers `original` to the default device and back, then compares.
  auto expect_round_trip = [this](const Literal& original) {
    const int device_ordinal = local_client_->default_device_ordinal();
    TF_ASSERT_OK_AND_ASSIGN(auto on_device,
                            local_client_->LiteralToShapedBuffer(
                                original, device_ordinal, allocator_));
    TF_ASSERT_OK_AND_ASSIGN(auto round_tripped,
                            local_client_->ShapedBufferToLiteral(on_device));
    EXPECT_EQ(original, round_tripped);
  };

  // Double-precision matrix.
  expect_round_trip(LiteralUtil::CreateR2<double>(
      {{1.0, 2.0, 3.0}, {44.0, 0.099999999999999978, -3}}));

  // Signed and unsigned 64-bit integer matrices.
  expect_round_trip(LiteralUtil::CreateR2<int64_t>({{2, 1}, {4444, 56}}));
  expect_round_trip(
      LiteralUtil::CreateR2<uint64_t>({{20000000000ULL, 1}, {4444, 56}}));

  // Tuple combining a double vector with an int64 scalar.
  expect_round_trip(LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR1<double>({1.0, -42.0}),
       LiteralUtil::CreateR0<int64_t>(123456789000LL)}));
}
880
881 // Disabled on interpreter backend since infeed HLO is unsupported.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedTest)) {
  // Computation: infeed a 3-element F32 vector and add {1, 2, 3} to it.
  const Shape infeed_shape = ShapeUtil::MakeShape(F32, {3});
  XlaBuilder builder(TestName());
  auto infed = Infeed(&builder, infeed_shape);
  auto addend = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
  Add(infed, addend);

  // Execute on a separate thread while this thread supplies the infeed data.
  Literal output;
  auto run_computation = [&] {
    output = ShapedBufferToLiteral(
        ExecuteLocallyOrDie(builder.Build().ValueOrDie(), /*arguments=*/{}));
  };
  std::unique_ptr<tensorflow::Thread> execute_thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", run_computation));

  ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
      LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
      local_client_->default_device_ordinal()));

  // Destroying the thread object joins it, so `output` is safe to read below.
  execute_thread.reset();

  LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, output);
}
906
907 // Disabled on interpreter backend since infeed/outfeed HLOs are unsupported.
XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
  // Computation: infeed a 3-element F32 vector, add {1, 2, 3}, and outfeed
  // the sum.
  const Shape feed_shape = ShapeUtil::MakeShape(F32, {3});
  XlaBuilder builder(TestName());
  auto infed = Infeed(&builder, feed_shape);
  auto addend = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f});
  auto total = Add(infed, addend);
  Outfeed(total, feed_shape, /*outfeed_config=*/"");

  // Execute on a separate thread while this thread drives infeed and outfeed.
  // The thread object is released (presumably joining the thread) when it
  // goes out of scope at the end of the test.
  auto run_computation = [&] {
    ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {});
  };
  std::unique_ptr<tensorflow::Thread> execute_thread(
      tensorflow::Env::Default()->StartThread(
          tensorflow::ThreadOptions(), "execute_thread", run_computation));

  const int device_ordinal = local_client_->default_device_ordinal();
  ASSERT_IS_OK(local_client_->TransferToInfeedLocal(
      LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}), device_ordinal));

  Literal outfed(feed_shape);
  ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(device_ordinal, &outfed));

  LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, outfed);
}
931
// Benchmark that measures the overhead of the LocalClient API when running a
// trivial computation.
BM_LocalClientOverhead(::testing::benchmark::State & state)934 void BM_LocalClientOverhead(::testing::benchmark::State& state) {
935 se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
936 auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
937 se::StreamExecutorMemoryAllocator allocator(platform, executors);
938 LocalClient* client =
939 ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
940 auto* transfer_manager =
941 TransferManager::GetForPlatform(platform).ValueOrDie();
942 int device_ordinal = client->default_device_ordinal();
943
944 // Use a tiny add operation as the computation.
945 XlaBuilder builder("Add");
946 auto shape = ShapeUtil::MakeShape(F32, {2, 3});
947 auto x = Parameter(&builder, 0, shape, "x");
948 Add(x, x);
949 auto computation = builder.Build().value();
950
951 auto buffer =
952 transfer_manager
953 ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0)
954 .value();
955 auto literal = LiteralUtil::CreateR2<float>({{0, 0, 0}, {0, 0, 0}});
956 auto stream =
957 client->mutable_backend()->BorrowStream(device_ordinal).ValueOrDie();
958 ASSERT_IS_OK(
959 transfer_manager->TransferLiteralToDevice(stream.get(), literal, buffer));
960
961 const int kWarmups = 2;
962
963 TF_ASSERT_OK_AND_ASSIGN(
964 auto executables, client->Compile(computation, {&buffer.on_host_shape()},
965 ExecutableBuildOptions()));
966 std::unique_ptr<LocalExecutable> executable = std::move(executables[0]);
967
968 ExecutableRunOptions run_options;
969 run_options.set_allocator(&allocator).set_stream(stream.get());
970
971 for (int i = 0; i < kWarmups; ++i) {
972 auto result = executable->Run({&buffer}, run_options);
973 ASSERT_IS_OK(result);
974 }
975
976 for (auto s : state) {
977 auto result = executable->Run({&buffer}, run_options);
978 ASSERT_IS_OK(result);
979 }
980 }
981
982 BENCHMARK(BM_LocalClientOverhead);
983
984 } // namespace
985 } // namespace xla
986