xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/client.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/client.h"
17 
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/strings/str_cat.h"
24 #include "tensorflow/compiler/xla/client/xla_computation.h"
25 #include "tensorflow/compiler/xla/debug_options_flags.h"
26 #include "tensorflow/compiler/xla/execution_options_util.h"
27 #include "tensorflow/compiler/xla/literal.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/types.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/protobuf.h"
33 
34 namespace xla {
35 
Client(ServiceInterface * stub)36 Client::Client(ServiceInterface* stub) : stub_(stub) {}
37 
38 Client::~Client() = default;
39 
Transfer(const GlobalData & data,const Shape * shape_with_layout)40 StatusOr<Literal> Client::Transfer(const GlobalData& data,
41                                    const Shape* shape_with_layout) {
42   TransferToClientRequest request;
43   *request.mutable_data() = data.handle();
44   if (shape_with_layout != nullptr) {
45     *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
46   }
47   TransferToClientResponse response;
48 
49   VLOG(1) << "making transfer request";
50   VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
51   Status s = stub_->TransferToClient(&request, &response);
52   VLOG(1) << "done with request";
53 
54   if (!s.ok()) {
55     return s;
56   }
57   VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";
58 
59   if (!response.has_literal()) {
60     return FailedPrecondition(
61         "server provided response without a literal in "
62         "TransferToClient request");
63   }
64   return Literal::CreateFromProto(*response.mutable_literal());
65 }
66 
TransferToServer(const LiteralSlice & literal,const DeviceHandle * device_handle)67 StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
68     const LiteralSlice& literal, const DeviceHandle* device_handle) {
69   TransferToServerRequest request;
70   *request.mutable_literal() = literal.ToProto();
71   if (device_handle) {
72     *request.mutable_device_handle() = *device_handle;
73   }
74   TransferToServerResponse response;
75 
76   VLOG(1) << "making transfer to server request";
77   VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
78   Status s = stub_->TransferToServer(&request, &response);
79   VLOG(1) << "done with request";
80 
81   if (!s.ok()) {
82     return s;
83   }
84   VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";
85 
86   if (!response.has_data()) {
87     return FailedPrecondition(
88         "server provided response without a data handle in "
89         "TransferToServer request");
90   }
91 
92   return std::make_unique<GlobalData>(stub_, response.data());
93 }
94 
TransferToInfeed(const LiteralSlice & literal,int64_t replica_id,const DeviceHandle * device_handle)95 Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id,
96                                 const DeviceHandle* device_handle) {
97   TransferToInfeedRequest request;
98   *request.mutable_literal() = literal.ToProto();
99   if (device_handle) {
100     *request.mutable_device_handle() = *device_handle;
101   }
102   request.set_replica_id(replica_id);
103   TransferToInfeedResponse response;
104 
105   VLOG(1) << "making transfer to infeed request";
106   VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
107   Status s = stub_->TransferToInfeed(&request, &response);
108   VLOG(1) << "done with request";
109 
110   if (!s.ok()) {
111     return s;
112   }
113   VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
114   return OkStatus();
115 }
116 
TransferFromOutfeed(const Shape * shape_with_layout,int64_t replica_id,const DeviceHandle * device_handle)117 StatusOr<Literal> Client::TransferFromOutfeed(
118     const Shape* shape_with_layout, int64_t replica_id,
119     const DeviceHandle* device_handle) {
120   TransferFromOutfeedRequest request;
121   if (device_handle) {
122     *request.mutable_device_handle() = *device_handle;
123   }
124   request.set_replica_id(replica_id);
125   if (shape_with_layout != nullptr) {
126     *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
127   }
128   TransferFromOutfeedResponse response;
129 
130   VLOG(1) << "making transfer from outfeed request";
131   VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
132   Status s = stub_->TransferFromOutfeed(&request, &response);
133   VLOG(1) << "done with request";
134 
135   if (!s.ok()) {
136     return s;
137   }
138   VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";
139 
140   if (!response.has_literal()) {
141     return FailedPrecondition(
142         "server provided response without a literal in "
143         "TransferToClient request");
144   }
145 
146   return Literal::CreateFromProto(response.literal());
147 }
148 
ResetDevice()149 Status Client::ResetDevice() {
150   ResetDeviceRequest request;
151   ResetDeviceResponse response;
152 
153   VLOG(1) << "making reset device request";
154   VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
155   Status s = stub_->ResetDevice(&request, &response);
156   VLOG(1) << "done with request";
157 
158   if (!s.ok()) {
159     return s;
160   }
161   VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
162   return OkStatus();
163 }
164 
ExecuteAndTransfer(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const ExecutionOptions * execution_options,ExecutionProfile * execution_profile)165 StatusOr<Literal> Client::ExecuteAndTransfer(
166     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
167     const ExecutionOptions* execution_options,
168     ExecutionProfile* execution_profile) {
169   TF_ASSIGN_OR_RETURN(
170       std::unique_ptr<GlobalData> data,
171       Execute(computation, arguments, execution_options, execution_profile));
172 
173   std::optional<Shape> shape_with_output_layout;
174   if (execution_options && execution_options->has_shape_with_output_layout()) {
175     shape_with_output_layout =
176         Shape(execution_options->shape_with_output_layout());
177   }
178   return Transfer(*data, shape_with_output_layout.has_value()
179                              ? &(*shape_with_output_layout)
180                              : nullptr);
181 }
182 
ComputeConstant(const XlaComputation & computation,const Layout * output_layout) const183 StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
184                                           const Layout* output_layout) const {
185   ComputeConstantGraphRequest request;
186   *request.mutable_computation() = computation.proto();
187   if (output_layout != nullptr) {
188     *request.mutable_output_layout() = output_layout->ToProto();
189   }
190 
191   ComputeConstantResponse response;
192 
193   VLOG(2) << "making compute-constant-graph request";
194   Status s = stub_->ComputeConstantGraph(&request, &response);
195   VLOG(2) << "done with request";
196 
197   if (!s.ok()) {
198     return s;
199   }
200 
201   VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
202 
203   if (!response.has_literal()) {
204     return InternalError(
205         "no computed literal in the provided response in ComputeConstantGraph "
206         "request");
207   }
208   return Literal::CreateFromProto(response.literal());
209 }
210 
LoadSnapshot(const HloSnapshot & module)211 StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
212   TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
213   return XlaComputation(module.hlo().hlo_module());
214 }
215 
Compile(const XlaComputation & computation,absl::Span<const Shape> argument_shapes,const ExecutionOptions * execution_options)216 StatusOr<ExecutionHandle> Client::Compile(
217     const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
218     const ExecutionOptions* execution_options) {
219   CompileRequest request;
220   *request.mutable_computation() = computation.proto();
221 
222   if (execution_options == nullptr) {
223     *request.mutable_execution_options() = CreateDefaultExecutionOptions();
224   } else {
225     *request.mutable_execution_options() = *execution_options;
226   }
227   if (request.execution_options().device_handles_size() > 1) {
228     return InvalidArgument(
229         "Compiling with multiple device handles is not supported. Use "
230         "'Execute' instead.");
231   }
232 
233   // The argument shapes affect how the computation is compiled.
234   for (const auto& arg_shape : argument_shapes) {
235     *request.add_input_shape_with_layout() = arg_shape.ToProto();
236   }
237 
238   CompileResponse response;
239   VLOG(1) << "making compile request: " << request.ShortDebugString();
240   Status s = stub_->Compile(&request, &response);
241   VLOG(1) << "done with request";
242 
243   if (!s.ok()) {
244     return s;
245   }
246   TF_RET_CHECK(response.has_handle());
247   return response.handle();
248 }
249 
Execute(const ExecutionHandle & handle,absl::Span<GlobalData * const> arguments,ExecutionProfile * execution_profile)250 StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
251     const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
252     ExecutionProfile* execution_profile) {
253   ExecuteRequest request;
254   *request.mutable_handle() = handle;
255   for (GlobalData* argument : arguments) {
256     CHECK(argument != nullptr) << "Argument pointers must not be null.";
257     *request.add_arguments() = argument->handle();
258   }
259 
260   ExecuteResponse response;
261   VLOG(1) << "making execute request: " << request.ShortDebugString();
262   Status s = stub_->Execute(&request, &response);
263   VLOG(1) << "done with request";
264 
265   if (!s.ok()) {
266     return s;
267   }
268 
269   if (execution_profile != nullptr) {
270     *execution_profile = response.profile();
271   }
272 
273   return std::make_unique<GlobalData>(stub_, response.output());
274 }
275 
Execute(const XlaComputation & computation,absl::Span<GlobalData * const> arguments,const ExecutionOptions * execution_options,ExecutionProfile * execution_profile)276 StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
277     const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
278     const ExecutionOptions* execution_options,
279     ExecutionProfile* execution_profile) {
280   // Create an ExecutionOptions if necessary, or set its DeviceHandles.
281   std::optional<ExecutionOptions> options_storage;
282   if (!execution_options || execution_options->device_handles().empty()) {
283     if (execution_options) {
284       options_storage.emplace(*execution_options);
285     } else {
286       options_storage.emplace(CreateDefaultExecutionOptions());
287     }
288     execution_options = &*options_storage;
289 
290     TF_ASSIGN_OR_RETURN(auto device_handles,
291                         GetDeviceHandles(/*device_count=*/1));
292     TF_RET_CHECK(!device_handles.empty());
293     *options_storage->add_device_handles() = std::move(device_handles[0]);
294   }
295 
296   std::vector<XlaComputationInstance> computation_instances = {
297       XlaComputationInstance{
298           computation,
299           std::vector<GlobalData*>(arguments.begin(), arguments.end()),
300           *execution_options, execution_profile}};
301 
302   // Instead of invoking Compile() and Execute(), invoke
303   // Service::ExecuteParallel() to execute our one computation.  Compile()
304   // caches the executable forever, which isn't what we want.
305   VLOG(1) << "Making ExecuteParallel request: "
306           << execution_options->DebugString();
307   TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
308   VLOG(1) << "ExecuteParallel request done.";
309 
310   // The result selection is a bit hacky, but better than assuming it is
311   // device 0.
312   //
313   // TODO(b/118493728): Allow Execute to return one result per computation.
314   for (int64_t i = 0, end = results.size(); i < end; i++) {
315     TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
316     if (!ShapeUtil::IsEmptyTuple(shape)) {
317       VLOG(3) << "Fetching result from device " << i << ": "
318               << ShapeUtil::HumanString(shape);
319       return std::move(results[i]);
320     }
321   }
322   TF_RET_CHECK(!results.empty());
323   VLOG(1) << "Defaulting to device 0 result";
324   return std::move(results[0]);
325 }
326 
ExecuteParallel(absl::Span<const XlaComputationInstance> computations)327 StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
328     absl::Span<const XlaComputationInstance> computations) {
329   ExecuteGraphParallelRequest request;
330 
331   for (const XlaComputationInstance& computation : computations) {
332     ExecuteGraphRequest single_request;
333     *single_request.mutable_computation() = computation.computation.proto();
334     for (GlobalData* argument : computation.arguments) {
335       *single_request.add_arguments() = argument->handle();
336     }
337     *single_request.mutable_execution_options() = computation.execution_options;
338     *request.add_requests() = single_request;
339   }
340 
341   ExecuteParallelResponse response;
342   VLOG(1) << "making execute-graph-parallel request: "
343           << request.ShortDebugString();
344   Status s = stub_->ExecuteGraphParallel(&request, &response);
345   VLOG(1) << "done with request";
346 
347   if (!s.ok()) {
348     return s;
349   }
350 
351   std::vector<std::unique_ptr<GlobalData>> outputs;
352   for (size_t i = 0, end = response.responses_size(); i < end; ++i) {
353     outputs.push_back(
354         std::make_unique<GlobalData>(stub_, response.responses(i).output()));
355     if (i < computations.size() &&
356         computations[i].execution_profile != nullptr) {
357       *computations[i].execution_profile = response.responses(i).profile();
358     }
359   }
360 
361   return std::move(outputs);
362 }
363 
GetDeviceHandles(int64_t device_count)364 StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
365     int64_t device_count) {
366   if (device_count < 1) {
367     return InvalidArgument("device_count must be greater than 0");
368   }
369   GetDeviceHandlesRequest request;
370   request.set_device_count(device_count);
371 
372   GetDeviceHandlesResponse response;
373   VLOG(1) << "making get device request: " << request.ShortDebugString();
374   Status s = stub_->GetDeviceHandles(&request, &response);
375   VLOG(1) << "done with request";
376 
377   if (!s.ok()) {
378     return s;
379   }
380 
381   std::vector<DeviceHandle> device_handles;
382   const auto& response_device_handles = response.device_handles();
383   device_handles.reserve(response_device_handles.size());
384   for (const DeviceHandle& device_handle : response_device_handles) {
385     device_handles.push_back(device_handle);
386   }
387 
388   return device_handles;
389 }
390 
Unregister(const GlobalData & data)391 Status Client::Unregister(const GlobalData& data) {
392   UnregisterRequest request;
393   *request.add_data() = data.handle();
394   UnregisterResponse response;
395 
396   VLOG(1) << "making unregister request";
397   Status s = stub_->Unregister(&request, &response);
398   VLOG(1) << "done with request";
399 
400   return s;
401 }
402 
DeconstructTuple(const GlobalData & data)403 StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
404     const GlobalData& data) {
405   DeconstructTupleRequest request;
406   *request.mutable_tuple_handle() = data.handle();
407   DeconstructTupleResponse response;
408 
409   VLOG(1) << "making DestructTuple request";
410   Status s = stub_->DeconstructTuple(&request, &response);
411   VLOG(1) << "done with request";
412 
413   if (!s.ok()) {
414     return s;
415   }
416 
417   std::vector<std::unique_ptr<GlobalData>> handles;
418   for (auto& handle : response.element_handles()) {
419     handles.push_back(std::make_unique<GlobalData>(stub_, handle));
420   }
421   return std::move(handles);
422 }
423 
GetComputationStats(const XlaComputation & computation,const DebugOptions & debug_options) const424 StatusOr<ComputationStats> Client::GetComputationStats(
425     const XlaComputation& computation,
426     const DebugOptions& debug_options) const {
427   ComputationGraphStatsRequest request;
428 
429   // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
430   *request.mutable_computation() = computation.proto();
431   *request.mutable_debug_options() = debug_options;
432   ComputationStatsResponse response;
433 
434   VLOG(1) << "making computation graph stats request";
435   Status s = stub_->GetComputationGraphStats(&request, &response);
436   VLOG(1) << "done with request";
437 
438   if (!s.ok()) {
439     return s;
440   }
441   CHECK(response.has_stats());
442   return response.stats();
443 }
444 
GetComputationShape(const XlaComputation & computation)445 StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
446     const XlaComputation& computation) {
447   TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape());
448   return std::make_unique<ProgramShape>(result);
449 }
450 
GetShape(const GlobalData & data)451 StatusOr<Shape> Client::GetShape(const GlobalData& data) {
452   GetShapeRequest request;
453   *request.mutable_data() = data.handle();
454   GetShapeResponse response;
455 
456   VLOG(1) << "making get shape request";
457   Status s = stub_->GetShape(&request, &response);
458   VLOG(1) << "done with request";
459 
460   if (!s.ok()) {
461     return s;
462   }
463 
464   return Shape(response.shape());
465 }
466 
ExecutionStatsAsString(const XlaComputation & computation,const ExecutionProfile & profile)467 StatusOr<std::string> Client::ExecutionStatsAsString(
468     const XlaComputation& computation, const ExecutionProfile& profile) {
469   TF_ASSIGN_OR_RETURN(
470       auto computation_stats,
471       GetComputationStats(computation, GetDebugOptionsFromFlags()));
472   int64_t total_flops =
473       computation_stats.flop_count() + computation_stats.transcendental_count();
474   if (profile.compute_time_ns() > 0) {
475     int64_t nanoseconds = profile.compute_time_ns();
476     int64_t cycle_count = profile.compute_cycle_count();
477     double gflops = total_flops / nanoseconds;
478     return absl::StrCat(
479         "[Execution Statistics] flop count: ", computation_stats.flop_count(),
480         ", transcendental count: ", computation_stats.transcendental_count(),
481         ", compute execution time: ", nanoseconds, " nsec",
482         ", compute cycles: ", cycle_count, ", performance: ", gflops,
483         "gflop/s");
484   }
485   return std::string("[Execution Statistics] not available.");
486 }
487 
CreateChannelHandleByType(ChannelHandle::ChannelType type)488 StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
489     ChannelHandle::ChannelType type) {
490   CreateChannelHandleRequest request;
491   request.set_channel_type(type);
492   CreateChannelHandleResponse response;
493 
494   VLOG(1) << "making create channel handle request";
495   Status s = stub_->CreateChannelHandle(&request, &response);
496   VLOG(1) << "done with request";
497 
498   if (!s.ok()) {
499     return s;
500   }
501 
502   return response.channel();
503 }
504 
CreateChannelHandle()505 StatusOr<ChannelHandle> Client::CreateChannelHandle() {
506   return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_DEVICE);
507 }
508 
CreateHostToDeviceChannelHandle()509 StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
510   return CreateChannelHandleByType(ChannelHandle::HOST_TO_DEVICE);
511 }
512 
CreateDeviceToHostChannelHandle()513 StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
514   return CreateChannelHandleByType(ChannelHandle::DEVICE_TO_HOST);
515 }
516 
517 }  // namespace xla
518