/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/client.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

Client::Client(ServiceInterface* stub) : stub_(stub) {}

Client::~Client() = default;

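// Fetches the value referred to by `data` from the service and returns it as
// a Literal. If `shape_with_layout` is non-null, the service is asked to
// return the literal with that layout.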
StatusOr<Literal> Client::Transfer(const GlobalData& data,
                                   const Shape* shape_with_layout) {
  TransferToClientRequest request;
  *request.mutable_data() = data.handle();
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }
  TransferToClientResponse response;

  VLOG(1) << "making transfer request";
  VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToClient(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferToClient request");
  }
  return Literal::CreateFromProto(*response.mutable_literal());
}

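// Uploads `literal` to the service (optionally targeting `device_handle`) and
// returns a GlobalData handle referring to the uploaded value.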
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
    const LiteralSlice& literal, const DeviceHandle* device_handle) {
  TransferToServerRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  TransferToServerResponse response;

  VLOG(1) << "making transfer to server request";
  VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToServer(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";

  if (!response.has_data()) {
    return FailedPrecondition(
        "server provided response without a data handle in "
        "TransferToServer request");
  }

  return std::make_unique<GlobalData>(stub_, response.data());
}

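// Sends `literal` to the infeed of the replica identified by `replica_id`
// (and, when provided, `device_handle`).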
Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id,
                                const DeviceHandle* device_handle) {
  TransferToInfeedRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  TransferToInfeedResponse response;

  VLOG(1) << "making transfer to infeed request";
  VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferToInfeed(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
  return OkStatus();
}

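// Retrieves a literal from the outfeed of the replica identified by
// `replica_id`. If `shape_with_layout` is non-null, it specifies the layout of
// the returned literal.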
StatusOr<Literal> Client::TransferFromOutfeed(
    const Shape* shape_with_layout, int64_t replica_id,
    const DeviceHandle* device_handle) {
  TransferFromOutfeedRequest request;
  if (device_handle) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }
  TransferFromOutfeedResponse response;

  VLOG(1) << "making transfer from outfeed request";
  VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferFromOutfeed(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferFromOutfeed request");
  }

  return Literal::CreateFromProto(response.literal());
}

Status Client::ResetDevice() {
  ResetDeviceRequest request;
  ResetDeviceResponse response;

  VLOG(1) << "making reset device request";
  VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
  Status s = stub_->ResetDevice(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
  return OkStatus();
}

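// Convenience wrapper that executes `computation` and immediately transfers
// the result back to the client, honoring any output layout requested in
// `execution_options`.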
StatusOr<Literal> Client::ExecuteAndTransfer(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<GlobalData> data,
      Execute(computation, arguments, execution_options, execution_profile));

  std::optional<Shape> shape_with_output_layout;
  if (execution_options && execution_options->has_shape_with_output_layout()) {
    shape_with_output_layout =
        Shape(execution_options->shape_with_output_layout());
  }
  return Transfer(*data, shape_with_output_layout.has_value()
                             ? &(*shape_with_output_layout)
                             : nullptr);
}

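// Asks the service to evaluate `computation` as a constant (via a
// ComputeConstantGraph request) and returns the resulting literal, optionally
// in `output_layout`.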
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
                                          const Layout* output_layout) const {
  ComputeConstantGraphRequest request;
  *request.mutable_computation() = computation.proto();
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = output_layout->ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant-graph request";
  Status s = stub_->ComputeConstantGraph(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the provided response in ComputeConstantGraph "
        "request");
  }
  return Literal::CreateFromProto(response.literal());
}

StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
  TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
  return XlaComputation(module.hlo().hlo_module());
}

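// Compiles `computation` for the given argument shapes and returns an
// ExecutionHandle for later execution. Requests with more than one device
// handle are rejected; multi-device execution goes through Execute() instead.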
StatusOr<ExecutionHandle> Client::Compile(
    const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
    const ExecutionOptions* execution_options) {
  CompileRequest request;
  *request.mutable_computation() = computation.proto();

  if (execution_options == nullptr) {
    *request.mutable_execution_options() = CreateDefaultExecutionOptions();
  } else {
    *request.mutable_execution_options() = *execution_options;
  }
  if (request.execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "Compiling with multiple device handles is not supported. Use "
        "'Execute' instead.");
  }

  // The argument shapes affect how the computation is compiled.
  for (const auto& arg_shape : argument_shapes) {
    *request.add_input_shape_with_layout() = arg_shape.ToProto();
  }

  CompileResponse response;
  VLOG(1) << "making compile request: " << request.ShortDebugString();
  Status s = stub_->Compile(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  TF_RET_CHECK(response.has_handle());
  return response.handle();
}

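// Runs a previously compiled executable identified by `handle` with the given
// arguments, filling `execution_profile` if it is non-null.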
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
    ExecutionProfile* execution_profile) {
  ExecuteRequest request;
  *request.mutable_handle() = handle;
  for (GlobalData* argument : arguments) {
    CHECK(argument != nullptr) << "Argument pointers must not be null.";
    *request.add_arguments() = argument->handle();
  }

  ExecuteResponse response;
  VLOG(1) << "making execute request: " << request.ShortDebugString();
  Status s = stub_->Execute(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  if (execution_profile != nullptr) {
    *execution_profile = response.profile();
  }

  return std::make_unique<GlobalData>(stub_, response.output());
}

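// Compiles and runs `computation` in a single call. The work is routed through
// ExecuteParallel() with a single XlaComputationInstance; see the inline
// comments below for why Compile() + Execute() is not used here.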
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  // Create an ExecutionOptions if necessary, or set its DeviceHandles.
  std::optional<ExecutionOptions> options_storage;
  if (!execution_options || execution_options->device_handles().empty()) {
    if (execution_options) {
      options_storage.emplace(*execution_options);
    } else {
      options_storage.emplace(CreateDefaultExecutionOptions());
    }
    execution_options = &*options_storage;

    TF_ASSIGN_OR_RETURN(auto device_handles,
                        GetDeviceHandles(/*device_count=*/1));
    TF_RET_CHECK(!device_handles.empty());
    *options_storage->add_device_handles() = std::move(device_handles[0]);
  }

  std::vector<XlaComputationInstance> computation_instances = {
      XlaComputationInstance{
          computation,
          std::vector<GlobalData*>(arguments.begin(), arguments.end()),
          *execution_options, execution_profile}};

  // Instead of invoking Compile() and Execute(), invoke
  // Service::ExecuteParallel() to execute our one computation. Compile()
  // caches the executable forever, which isn't what we want.
  VLOG(1) << "Making ExecuteParallel request: "
          << execution_options->DebugString();
  TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
  VLOG(1) << "ExecuteParallel request done.";

  // The result selection is a bit hacky, but better than assuming it is
  // device 0.
  //
  // TODO(b/118493728): Allow Execute to return one result per computation.
  for (int64_t i = 0, end = results.size(); i < end; i++) {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
    if (!ShapeUtil::IsEmptyTuple(shape)) {
      VLOG(3) << "Fetching result from device " << i << ": "
              << ShapeUtil::HumanString(shape);
      return std::move(results[i]);
    }
  }
  TF_RET_CHECK(!results.empty());
  VLOG(1) << "Defaulting to device 0 result";
  return std::move(results[0]);
}

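// Executes several computations in parallel: one ExecuteGraphRequest is built
// per XlaComputationInstance, and one GlobalData handle is returned per
// response.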
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
    absl::Span<const XlaComputationInstance> computations) {
  ExecuteGraphParallelRequest request;

  for (const XlaComputationInstance& computation : computations) {
    ExecuteGraphRequest single_request;
    *single_request.mutable_computation() = computation.computation.proto();
    for (GlobalData* argument : computation.arguments) {
      *single_request.add_arguments() = argument->handle();
    }
    *single_request.mutable_execution_options() =
        computation.execution_options;
    *request.add_requests() = single_request;
  }

  ExecuteParallelResponse response;
  VLOG(1) << "making execute-graph-parallel request: "
          << request.ShortDebugString();
  Status s = stub_->ExecuteGraphParallel(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> outputs;
  for (size_t i = 0, end = response.responses_size(); i < end; ++i) {
    outputs.push_back(
        std::make_unique<GlobalData>(stub_, response.responses(i).output()));
    if (i < computations.size() &&
        computations[i].execution_profile != nullptr) {
      *computations[i].execution_profile = response.responses(i).profile();
    }
  }

  return std::move(outputs);
}

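// Requests `device_count` device handles from the service; fails if
// `device_count` is not positive.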
StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
    int64_t device_count) {
  if (device_count < 1) {
    return InvalidArgument("device_count must be greater than 0");
  }
  GetDeviceHandlesRequest request;
  request.set_device_count(device_count);

  GetDeviceHandlesResponse response;
  VLOG(1) << "making get device request: " << request.ShortDebugString();
  Status s = stub_->GetDeviceHandles(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<DeviceHandle> device_handles;
  const auto& response_device_handles = response.device_handles();
  device_handles.reserve(response_device_handles.size());
  for (const DeviceHandle& device_handle : response_device_handles) {
    device_handles.push_back(device_handle);
  }

  return device_handles;
}

Status Client::Unregister(const GlobalData& data) {
  UnregisterRequest request;
  *request.add_data() = data.handle();
  UnregisterResponse response;

  VLOG(1) << "making unregister request";
  Status s = stub_->Unregister(&request, &response);
  VLOG(1) << "done with request";

  return s;
}

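// Splits a tuple-shaped GlobalData into one GlobalData handle per top-level
// tuple element.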
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
    const GlobalData& data) {
  DeconstructTupleRequest request;
  *request.mutable_tuple_handle() = data.handle();
  DeconstructTupleResponse response;

  VLOG(1) << "making DeconstructTuple request";
  Status s = stub_->DeconstructTuple(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> handles;
  for (auto& handle : response.element_handles()) {
    handles.push_back(std::make_unique<GlobalData>(stub_, handle));
  }
  return std::move(handles);
}

StatusOr<ComputationStats> Client::GetComputationStats(
    const XlaComputation& computation,
    const DebugOptions& debug_options) const {
  ComputationGraphStatsRequest request;

  // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
  *request.mutable_computation() = computation.proto();
  *request.mutable_debug_options() = debug_options;
  ComputationStatsResponse response;

  VLOG(1) << "making computation graph stats request";
  Status s = stub_->GetComputationGraphStats(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }
  CHECK(response.has_stats());
  return response.stats();
}

StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
  return std::make_unique<ProgramShape>(result);
}

StatusOr<Shape> Client::GetShape(const GlobalData& data) {
  GetShapeRequest request;
  *request.mutable_data() = data.handle();
  GetShapeResponse response;

  VLOG(1) << "making get shape request";
  Status s = stub_->GetShape(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  return Shape(response.shape());
}

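// Renders a human-readable summary of `profile` together with the flop counts
// of `computation`. Flops-per-nanosecond is numerically equal to GFLOP/s,
// which is why the quotient below is reported directly as gflop/s.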
StatusOr<std::string> Client::ExecutionStatsAsString(
    const XlaComputation& computation, const ExecutionProfile& profile) {
  TF_ASSIGN_OR_RETURN(
      auto computation_stats,
      GetComputationStats(computation, GetDebugOptionsFromFlags()));
  int64_t total_flops =
      computation_stats.flop_count() + computation_stats.transcendental_count();
  if (profile.compute_time_ns() > 0) {
    int64_t nanoseconds = profile.compute_time_ns();
    int64_t cycle_count = profile.compute_cycle_count();
    // Divide in floating point; integer division would truncate the result.
    double gflops = static_cast<double>(total_flops) / nanoseconds;
    return absl::StrCat(
        "[Execution Statistics] flop count: ", computation_stats.flop_count(),
        ", transcendental count: ", computation_stats.transcendental_count(),
        ", compute execution time: ", nanoseconds, " nsec",
        ", compute cycles: ", cycle_count, ", performance: ", gflops,
        " gflop/s");
  }
  return std::string("[Execution Statistics] not available.");
}

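// Allocates a new channel handle of the requested type on the service; the
// typed helpers below forward to this method.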
StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
    ChannelHandle::ChannelType type) {
  CreateChannelHandleRequest request;
  request.set_channel_type(type);
  CreateChannelHandleResponse response;

  VLOG(1) << "making create channel handle request";
  Status s = stub_->CreateChannelHandle(&request, &response);
  VLOG(1) << "done with request";

  if (!s.ok()) {
    return s;
  }

  return response.channel();
}

StatusOr<ChannelHandle> Client::CreateChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE);
}

StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE);
}

StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
  return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST);
}

}  // namespace xla