1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/service.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <memory>
21 #include <numeric>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "tensorflow/compiler/xla/debug_options_flags.h"
29 #include "tensorflow/compiler/xla/execution_options_util.h"
30 #include "tensorflow/compiler/xla/layout_util.h"
31 #include "tensorflow/compiler/xla/service/backend.h"
32 #include "tensorflow/compiler/xla/service/compiler.h"
33 #include "tensorflow/compiler/xla/service/computation_layout.h"
34 #include "tensorflow/compiler/xla/service/computation_placer.h"
35 #include "tensorflow/compiler/xla/service/dump.h"
36 #include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
37 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
38 #include "tensorflow/compiler/xla/service/executable.h"
39 #include "tensorflow/compiler/xla/service/hlo_computation.h"
40 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
41 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
42 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
43 #include "tensorflow/compiler/xla/service/hlo_module.h"
44 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
45 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
46 #include "tensorflow/compiler/xla/service/hlo_module_util.h"
47 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
48 #include "tensorflow/compiler/xla/service/platform_util.h"
49 #include "tensorflow/compiler/xla/service/source_map_util.h"
50 #include "tensorflow/compiler/xla/service/stream_pool.h"
51 #include "tensorflow/compiler/xla/service/transfer_manager.h"
52 #include "tensorflow/compiler/xla/shape.h"
53 #include "tensorflow/compiler/xla/shape_layout.h"
54 #include "tensorflow/compiler/xla/shape_util.h"
55 #include "tensorflow/compiler/xla/status_macros.h"
56 #include "tensorflow/compiler/xla/types.h"
57 #include "tensorflow/compiler/xla/util.h"
58 #include "tensorflow/compiler/xla/xla_data.pb.h"
59 #include "tensorflow/core/platform/env.h"
60 #include "tensorflow/core/platform/errors.h"
61 #include "tensorflow/core/platform/logging.h"
62 #include "tensorflow/core/platform/protobuf.h"
63 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
64 #include "tensorflow/core/util/ptr_util.h"
65 #include "tensorflow/stream_executor/device_memory_allocator.h"
66 
67 namespace xla {
68 namespace {
69 
70 using absl::StrCat;
71 using absl::StrFormat;
72 
73 // Argument used when calling DumpHloModuleIfEnabled before optimizations are
74 // performed on an HloModule.
75 constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";
76 
77 // Records the arguments used to invoke a computation in an HloSnapshot proto.
78 Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
79                        se::Stream* stream, TransferManager* transfer_manager,
80                        HloSnapshot* module) {
81   module->clear_arguments();
82   for (const ShapedBuffer* argument : arguments) {
83     TF_ASSIGN_OR_RETURN(
84         Literal literal,
85         transfer_manager->TransferLiteralFromDevice(stream, *argument));
86     *module->add_arguments() = literal.ToProto();
87   }
88   return OkStatus();
89 }
90 
91 // Records the result of a computation in a HloSnapshot proto.
92 Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
93                     TransferManager* transfer_manager, HloSnapshot* module) {
94   module->clear_result();
95   TF_ASSIGN_OR_RETURN(
96       Literal literal,
97       transfer_manager->TransferLiteralFromDevice(stream, result));
98   *module->mutable_result() = literal.ToProto();
99   return OkStatus();
100 }
101 
102 }  // namespace
103 
104 ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
105   platform_ = platform;
106   return *this;
107 }
108 
109 se::Platform* ServiceOptions::platform() const { return platform_; }
110 
111 ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
112   number_of_replicas_ = number_of_replicas;
113   return *this;
114 }
115 
116 int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
117 
118 ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
119     int num_threads) {
120   intra_op_parallelism_threads_ = num_threads;
121   return *this;
122 }
123 
124 int ServiceOptions::intra_op_parallelism_threads() const {
125   return intra_op_parallelism_threads_;
126 }
127 
128 ServiceOptions& ServiceOptions::set_allowed_devices(
129     const std::optional<std::set<int>>& allowed_devices) {
130   allowed_devices_ = allowed_devices;
131   return *this;
132 }
133 
134 const std::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
135   return allowed_devices_;
136 }
137 
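// Creates a new XLA service for the given platform, or for the default
// platform when `platform` is null, using default ServiceOptions.
// A minimal usage sketch (illustrative only; assumes the caller can propagate
// a Status):
//   TF_ASSIGN_OR_RETURN(se::Platform* platform,
//                       PlatformUtil::GetDefaultPlatform());
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
//                       Service::NewService(platform));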
138 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
139     se::Platform* platform) {
140   ServiceOptions default_options;
141   default_options.set_platform(platform);
142   return NewService(default_options);
143 }
144 
145 /* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
146     const ServiceOptions& options) {
147   se::Platform* platform = options.platform();
148   std::unique_ptr<Backend> execute_backend;
149   if (platform == nullptr) {
150     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
151   }
152   BackendOptions backend_options;
153   backend_options.set_platform(platform);
154   backend_options.set_allowed_devices(options.allowed_devices());
155   TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
156 
157   std::unique_ptr<Service> service(
158       new Service(options, std::move(execute_backend)));
159   return std::move(service);
160 }
161 
162 Service::Service(const ServiceOptions& options,
163                  std::unique_ptr<Backend> execute_backend)
164     : options_(options),
165       allocation_tracker_(execute_backend.get()),
166       execute_backend_(std::move(execute_backend)) {
167   CHECK_GT(options_.number_of_replicas(), 0);
168   if (execute_backend_) {
169     if (execute_backend_->device_count() > 0) {
170       CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
171           << "Requested more replicas than there are devices.";
172     }
173     LOG(INFO) << StrFormat(
174         "XLA service %p initialized for platform %s (this does not guarantee "
175         "that XLA will be used). Devices:",
176         this, execute_backend_->platform()->Name());
177     auto stream_executors = execute_backend_->stream_executors();
178     for (int i = 0; i < execute_backend_->device_count(); ++i) {
179       se::StreamExecutor* executor = stream_executors.at(i);
180       const auto& description = executor->GetDeviceDescription();
181       LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
182                              description.name(),
183                              description.platform_version());
184     }
185   } else {
186     VLOG(1) << "XLA compile-only service constructed";
187   }
188 }
189 
190 Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
191                                     CreateChannelHandleResponse* result) {
192   TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
193                       channel_tracker_.NewChannel(arg->channel_type()));
194   return OkStatus();
195 }
196 
197 Status Service::Unregister(const UnregisterRequest* arg,
198                            UnregisterResponse* result) {
199   Status status;
200   for (auto& data : arg->data()) {
201     Status unregister_status = allocation_tracker_.Unregister(data);
202     if (!unregister_status.ok() && status.ok()) {
203       status = unregister_status;
204     }
205   }
206   return status;
207 }
208 
209 // Deconstructs a previously-allocated global handle.
210 Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
211                                  DeconstructTupleResponse* result) {
212   TF_ASSIGN_OR_RETURN(
213       std::vector<GlobalDataHandle> elements,
214       allocation_tracker_.DeconstructTuple(arg->tuple_handle()));
215 
216   for (auto& element : elements) {
217     *result->add_element_handles() = element;
218   }
219   return OkStatus();
220 }
221 
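// Checks that the shape the client used to request the result layout is
// compatible with the shape actually produced by the computation.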
222 Status Service::ValidateResultShape(const Shape& client_shape,
223                                     const Shape& result_shape) const {
224   TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
225   if (!ShapeUtil::Compatible(client_shape, result_shape)) {
226     return InvalidArgument(
227         "Shape used to set computation result layout %s is not compatible "
228         "with result shape %s",
229         ShapeUtil::HumanStringWithLayout(client_shape),
230         ShapeUtil::HumanString(result_shape));
231   }
232   return OkStatus();
233 }
234 
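// Resolves each argument handle through the allocation tracker and returns
// the argument buffers grouped by replica: element [replica][i] is the buffer
// for argument i on that replica.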
235 StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
236 Service::ResolveAndValidateArguments(
237     absl::Span<const GlobalDataHandle* const> arguments,
238     absl::Span<se::StreamExecutor* const> stream_executors) const {
239   CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
240   std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
241   replicated_arguments.resize(options_.number_of_replicas());
242   for (size_t i = 0; i < arguments.size(); ++i) {
243     auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
244     if (!buffer_status.ok()) {
245       return tensorflow::errors::CreateWithUpdatedMessage(
246           buffer_status.status(),
247           StrCat(buffer_status.status().error_message(), ", ",
248                  "failed to resolve allocation for parameter ", i));
249     }
250     auto replicated_buffers = buffer_status.ValueOrDie();
251     CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
252     for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
253       const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
254       replicated_arguments[replica].push_back(shaped_buffer);
255     }
256   }
257   return replicated_arguments;
258 }
259 
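// Builds an HloModuleConfig for the given program shape and argument shapes,
// filling in the service's default replica count and intra-op thread count
// before delegating to xla::CreateModuleConfig.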
260 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
261     const ProgramShape& program_shape,
262     absl::Span<const Shape* const> argument_shapes,
263     const ExecutionOptions* execution_options,
264     const AotCompilationOptions* aot_options) {
265   int default_num_replicas = options_.number_of_replicas();
266   std::optional<int> num_threads;
267   if (execute_backend_ != nullptr &&
268       execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
269     num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
270   }
271 
272   return xla::CreateModuleConfig(program_shape, argument_shapes,
273                                  execution_options, default_num_replicas,
274                                  num_threads, aot_options);
275 }
276 
277 StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
278     const ProgramShape& program_shape,
279     absl::Span<const ShapedBuffer* const> arguments,
280     const ExecutionOptions& execution_options,
281     const AotCompilationOptions* aot_options) {
282   std::vector<const Shape*> argument_shapes;
283   for (const auto* arg : arguments) {
284     argument_shapes.push_back(&arg->on_device_shape());
285   }
286   return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
287                             aot_options);
288 }
289 
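// Compiles the given HLO module protos into executables on `backend`. When
// run_backend_only is true, the HLO optimization passes are skipped and only
// the backend (code generation) stage is run for each module.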
290 StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
291     const std::vector<const HloModuleProto*>& module_protos,
292     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
293     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
294     const Compiler::CompileOptions& options, bool run_backend_only) {
295   VLOG(1) << StrFormat("BuildExecutable on service %p", this);
296 
297   VLOG(1) << "Computations:";
298   for (const HloModuleProto* proto : module_protos) {
299     VLOG(1) << proto->name();
300   }
301 
302   CHECK_EQ(module_protos.size(), module_configs.size());
303   auto module_group =
304       std::make_unique<HloModuleGroup>(module_protos[0]->name());
305   for (int64_t i = 0, end = module_protos.size(); i < end; ++i) {
306     const HloModuleProto* proto = module_protos[i];
307     const HloModuleConfig& config = *module_configs[i];
308     TF_ASSIGN_OR_RETURN(
309         auto module, CreateModuleFromProto(*proto, config, run_backend_only));
310     UpdateEntryComputationLayout(
311         module.get(), std::bind(&Compiler::DefaultDeviceShapeRepresentation,
312                                 backend->compiler(), std::placeholders::_1));
313     DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
314     module_group->push_back(std::move(module));
315   }
316 
317   std::vector<std::unique_ptr<Executable>> executables;
318   if (!run_backend_only) {
319     TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
320                                          std::move(module_group),
321                                          std::move(executors), options));
322   } else {
323     auto modules = module_group->ConsumeModules();
324     for (std::unique_ptr<HloModule>& module : modules) {
325       TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
326                           backend->compiler()->RunBackend(
327                               std::move(module), executors[0][0], options));
328       executables.push_back(std::move(executable));
329     }
330   }
331 
332   return std::move(executables);
333 }
334 
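// Like BuildExecutables, but compiles the module group ahead of time and
// returns the resulting AotCompilationResults instead of loaded executables.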
335 StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
336 Service::BuildAotResults(
337     const std::vector<const HloModuleProto*>& module_protos,
338     std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
339     Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
340     const Compiler::CompileOptions& options, bool run_backend_only) {
341   VLOG(1) << StrFormat("BuildAotResults on service %p", this);
342 
343   VLOG(1) << "Computations:";
344   for (const HloModuleProto* proto : module_protos) {
345     VLOG(1) << proto->name();
346   }
347 
348   CHECK_EQ(module_protos.size(), module_configs.size());
349   auto module_group =
350       std::make_unique<HloModuleGroup>(module_protos[0]->name());
351   for (int64_t i = 0, end = module_protos.size(); i < end; ++i) {
352     const HloModuleProto* proto = module_protos[i];
353     const HloModuleConfig& config = *module_configs[i];
354     TF_ASSIGN_OR_RETURN(
355         auto module, CreateModuleFromProto(*proto, config, run_backend_only));
356     DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
357     module_group->push_back(std::move(module));
358   }
359 
360   AotCompilationOptions aot_options(backend->compiler()->PlatformId());
361   aot_options.set_executor(executors[0][0]);
362   aot_options.set_device_allocator(options.device_allocator);
363   aot_options.set_run_backend_only(run_backend_only);
364 
365   TF_ASSIGN_OR_RETURN(
366       std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
367       backend->compiler()->CompileAheadOfTime(std::move(module_group),
368                                               aot_options));
369   return std::move(aot_results);
370 }
371 
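// Executes the given executables in parallel, one per device handle, with
// options_.number_of_replicas() replicas each, registers the result buffers
// with the allocation tracker, and records timing in `profile` if requested.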
372 StatusOr<std::vector<GlobalDataHandle>>
373 Service::ExecuteParallelAndRegisterResult(
374     absl::Span<Executable* const> executables,
375     absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
376     Backend* backend, absl::Span<const DeviceHandle> device_handles,
377     absl::Span<const std::string> result_tags, ExecutionProfile* profile) {
378   // Streams on which the computations are launched, so we can wait for the
379   // streams to complete.
380   std::vector<StreamPool::Ptr> streams;
381   std::vector<std::unique_ptr<se::Timer>> timers;
382 
383   // Global data handles for the computation results, one for each computation.
384   std::vector<GlobalDataHandle> result_handles;
385 
386   // Device ID to stream executor, populated only with devices that are being
387   // profiled.
388   std::map<int64_t, se::Stream*> index_to_profiled_streams;
389 
390   // Build DeviceAssignment for all cores based on the provided device handles.
391   DeviceAssignment device_assignment(options_.number_of_replicas(),
392                                      executables.size());
393   for (int64_t i = 0; i < executables.size(); i++) {
394     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
395     CHECK_EQ(replicas.size(), arguments[i].size());
396     for (int64_t replica = 0, end = replicas.size(); replica < end; ++replica) {
397       device_assignment(replica, i) = replicas[replica]->device_ordinal();
398     }
399   }
400 
401   for (int64_t i = 0, end = executables.size(); i < end; i++) {
402     // Stream executors for the replicas of the current computation.
403     TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
404     CHECK_EQ(replicas.size(), arguments[i].size());
405     std::vector<ScopedShapedBuffer> result_buffers;
406     const int64_t n = replicas.size();
407     result_buffers.reserve(n);
408     for (int64_t replica = 0; replica < n; ++replica) {
409       TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
410                           backend->BorrowStream(replicas[replica]));
411       streams.push_back(std::move(stream));
412 
413       if (replica == 0 && profile != nullptr) {
414         timers.push_back(std::make_unique<se::Timer>(streams.back()->parent()));
415         streams.back()
416             ->InitTimer(timers.back().get())
417             .ThenStartTimer(timers.back().get());
418         CHECK(timers.front() != nullptr);
419       }
420 
421       if (replica == 0 &&
422           executables[i]->module_config().debug_options().xla_hlo_profile() &&
423           executables[i]->hlo_profiling_enabled()) {
424         index_to_profiled_streams[i] = streams.back().get();
425       }
426 
427       // Set up run options.
428       ExecutableRunOptions options;
429       options.set_stream(streams.back().get());
430       options.set_allocator(backend->memory_allocator());
431       options.set_intra_op_thread_pool(
432           backend->eigen_intra_op_thread_pool_device());
433       options.set_device_assignment(&device_assignment);
434       // Use run-time profile information from execution_profile on the 0th
435       // device.
436       if (i == 0) {
437         options.set_execution_profile(profile);
438       }
439       ServiceExecutableRunOptions run_options(options,
440                                               backend->StreamBorrower());
441 
442       // Asynchronously launch the computation.
443       TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
444                           executables[i]->ExecuteAsyncOnStream(
445                               &run_options, arguments[i][replica],
446                               /*hlo_execution_profile=*/nullptr));
447 
448       if (replica == 0 && profile != nullptr) {
449         streams.back()->ThenStopTimer(timers.back().get());
450       }
451 
452       result_buffers.push_back(std::move(result));
453     }
454     TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
455                         allocation_tracker_.RegisterReplicatedBuffers(
456                             std::move(result_buffers), result_tags[i]));
457     result_handles.push_back(handle);
458   }
459 
460   // Wait for all executions to complete.
461   for (int64_t i = 0, end = streams.size(); i < end; ++i) {
462     Status block_status = streams[i]->BlockHostUntilDone();
463     if (!block_status.ok()) {
464       return InternalError("failed to complete execution for stream %d: %s", i,
465                            block_status.error_message());
466     }
467   }
468 
469   if (profile != nullptr) {
470     CHECK(!timers.empty());
471     std::vector<uint64_t> timer_nanoseconds;
472     timer_nanoseconds.reserve(timers.size());
473     for (auto& timer : timers) {
474       timer_nanoseconds.push_back(timer->Nanoseconds());
475     }
476     uint64_t nanoseconds =
477         *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());
478 
479     // Overall execution time (in nanoseconds) from the executor timer.
480     profile->set_compute_and_transfer_time_ns(nanoseconds);
481 
482     // TODO(b/28123297): On GPU we end up including transfer time in
483     // the compute time this way. Instead, we should get the correct
484     // value by measuring it. Setting the field here at least lets
485     // benchmarks provide *some* value for GPU computations.
486     //
487     // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
488     // the compute time without the transfer time, so this way we get the
489     // correct compute time. We should instead have the correct value for
490     // compute_and_transfer_time and set compute_time to the compute time.
491     if (profile->compute_time_ns() == 0) {
492       profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
493     }
494   }
495 
496   return result_handles;
497 }
498 
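// Executes a single executable on all replicas of the given device handle and
// registers the (possibly replicated) result with the allocation tracker,
// returning the global handle for the result.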
499 StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
500     Executable* executable,
501     absl::Span<const std::vector<const ShapedBuffer*>> arguments,
502     Backend* backend, const DeviceHandle& device_handle,
503     const std::string& result_tag, ExecutionProfile* profile) {
504   // Set up streams.
505   std::vector<StreamPool::Ptr> streams;
506 
507   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
508   TF_RET_CHECK(!replicas.empty());
509   for (se::StreamExecutor* executor : replicas) {
510     TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
511                         backend->BorrowStream(executor));
512     streams.push_back(std::move(stream));
513   }
514 
515   DeviceAssignment device_assignment(options_.number_of_replicas(),
516                                      /*computation_count=*/1);
517   for (int64_t replica = 0; replica < replicas.size(); ++replica) {
518     device_assignment(replica, 0) = replicas[replica]->device_ordinal();
519   }
520 
521   // Set up run options.
522   std::vector<ServiceExecutableRunOptions> run_options;
523   run_options.reserve(streams.size());
524   for (const StreamPool::Ptr& stream : streams) {
525     ExecutableRunOptions options;
526     options.set_stream(stream.get());
527     options.set_device_ordinal(stream->parent()->device_ordinal());
528     options.set_allocator(backend->memory_allocator());
529     options.set_intra_op_thread_pool(
530         backend->eigen_intra_op_thread_pool_device());
531     options.set_device_assignment(&device_assignment);
532     options.set_execution_profile(profile);
533     run_options.emplace_back(options, backend->StreamBorrower());
534   }
535 
536   if (options_.number_of_replicas() == 1) {
537     TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
538                                          &run_options[0], arguments[0]));
539     return allocation_tracker_.Register(std::move(result), result_tag);
540   }
541 
542   // TODO(b/69985541): Support profiling also on this path.
543 
544   std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
545   for (const auto& arg : arguments) {
546     replicated_arguments.push_back(arg);
547   }
548 
549   TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
550                                         run_options, replicated_arguments));
551   TF_RET_CHECK(!results.empty());
552   return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
553                                                        result_tag);
554 }
555 
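// Returns one stream executor per device handle in the execution options (the
// executor of replica 0 for each handle), rejecting the combination of
// multiple parallel requests with multiple device handles.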
556 StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
557     const ExecutionOptions& execution_options, int64_t requests_size,
558     int64_t request_index) const {
559   if (execution_options.device_handles().empty()) {
560     return FailedPrecondition(
561         "device handles must be given to execute parallel computations");
562   }
563   if (requests_size > 1 && execution_options.device_handles_size() > 1) {
564     return InvalidArgument(
565         "Parallel requests with multiple device handles is not supported. "
566         "Found %d parallel requests, with request %d containing %d device "
567         "handles.",
568         requests_size, request_index, execution_options.device_handles_size());
569   }
570   std::vector<se::StreamExecutor*> executors;
571   for (const auto& device_handle : execution_options.device_handles()) {
572     TF_ASSIGN_OR_RETURN(auto replicas,
573                         Replicas(*execute_backend_, device_handle));
574     se::StreamExecutor* executor = replicas[0];
575     CHECK(executor != nullptr);
576     executors.push_back(executor);
577   }
578   return executors;
579 }
580 
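// Resolves the argument handles into per-replica argument buffers, using the
// replicas of the first device handle in the execution options.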
581 StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
582     const ExecutionOptions& execution_options,
583     absl::Span<const GlobalDataHandle* const> arguments) const {
584   // Resolve the allocations for the arguments of the computation, and create
585   // a vector of device memory offsets for the arguments from the allocations.
586   // In the case of partitioned computations, assume all arguments go on the
587   // zeroth core.
588   TF_ASSIGN_OR_RETURN(
589       auto replicas,
590       Replicas(*execute_backend_, execution_options.device_handles(0)));
591   TF_ASSIGN_OR_RETURN(
592       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
593       ResolveAndValidateArguments(arguments, replicas));
594   return replicated_arguments;
595 }
596 
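// Handles an ExecuteGraphParallel request: compiles each requested
// computation, executes them (in parallel when there is more than one), and
// returns one ExecuteResponse per result, dumping HLO snapshots for
// executables that request them.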
597 Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
598                                      ExecuteParallelResponse* result) {
599   VLOG(1) << "running execute-graph-parallel request";
600 
601   std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
602   std::vector<std::vector<se::StreamExecutor*>> all_executors;
603   std::vector<const HloModuleProto*> module_protos;
604   std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
605   std::vector<std::string> computation_names;
606   std::vector<DeviceHandle> device_handles;
607 
608   int num_requested_devices =
609       std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
610                       [](int a, const ExecuteGraphRequest& r) -> int {
611                         return a + r.execution_options().device_handles_size();
612                       });
613   if (num_requested_devices * options_.number_of_replicas() >
614       execute_backend_->device_count()) {
615     return FailedPrecondition(
616         "there are not enough stream executors to execute %d computations",
617         num_requested_devices);
618   }
619 
620   for (int64_t i = 0; i < arg->requests_size(); ++i) {
621     // Get the stream executor for the i'th computation. This stream executor
622     // is one of the executors to run the replicated computation.
623     const ExecutionOptions& execution_options =
624         arg->requests(i).execution_options();
625     const ExecuteGraphRequest& request = arg->requests(i);
626     TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
627     TF_RET_CHECK(request.computation().has_host_program_shape())
628         << "program shape may not be empty";
629 
630     // Get the executors.
631     TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
632                                                      arg->requests_size(), i));
633 
634     // Get the replicated arguments.
635     TF_ASSIGN_OR_RETURN(auto replicated_arguments,
636                         GetArguments(execution_options, request.arguments()));
637 
638     for (auto& args : replicated_arguments) {
639       for (auto& arg : args) {
640         auto update_shape_with_empty_tiles = [this](
641                                                  Shape* subshape,
642                                                  const xla::ShapeIndex& index) {
643           if (subshape->IsArray() && subshape->layout().tiles().empty()) {
644             *subshape =
645                 execute_backend_->transfer_manager()->HostShapeToDeviceShape(
646                     *subshape);
647           }
648         };
649         ShapeUtil::ForEachMutableSubshape(
650             const_cast<Shape*>(&arg->on_device_shape()),
651             update_shape_with_empty_tiles);
652       }
653     }
654 
655     // Create an HloModuleConfig object for the computation, given the shape of
656     // the program and the argument allocations. Here, we care only about the
657     // shapes of the arguments, so, it is sufficient to use the arguments of
658     // replica 0.
659     TF_ASSIGN_OR_RETURN(
660         std::unique_ptr<HloModuleConfig> module_config,
661         CreateModuleConfig(
662             ProgramShape{request.computation().host_program_shape()},
663             replicated_arguments.front(), request.execution_options()));
664     VLOG(3)
665         << "ExecuteGraphParallel created HloModuleConfig computation layout: "
666         << module_config->entry_computation_layout().ToString();
667 
668     // Add to the vectors used to build and execute the computations after the loop.
669     all_arguments.push_back(replicated_arguments);
670     all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
671     module_protos.push_back(&request.computation());
672     module_configs.push_back(std::move(module_config));
673     computation_names.insert(computation_names.end(), executors.size(),
674                              request.computation().name());
675     all_executors.push_back(executors);
676     device_handles.insert(device_handles.end(),
677                           execution_options.device_handles().begin(),
678                           execution_options.device_handles().end());
679   }
680 
681   // Build the HloModules and compile to generate the executables.
682   //
683   // TODO(jlebar): There's currently no way to pass a device allocator to
684   // ExecuteGraphParallel, so we have to pass a null device_allocator below.
685   TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
686                       BuildExecutables(module_protos, std::move(module_configs),
687                                        execute_backend_.get(), all_executors,
688                                        {/*device_allocator=*/nullptr}));
689   std::vector<Executable*> executable_ptrs;
690   executable_ptrs.reserve(executables.size());
691   for (const auto& executable : executables) {
692     executable_ptrs.push_back(executable.get());
693   }
694 
695   std::vector<HloSnapshot> snapshots;
696   snapshots.resize(executable_ptrs.size());
697   for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
698     if (executable_ptrs[i]->dumping_snapshot()) {
699       *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
700       TF_ASSIGN_OR_RETURN(auto stream,
701                           execute_backend_->BorrowStream(
702                               all_executors[i][0]->device_ordinal()));
703       TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
704                                          execute_backend_->transfer_manager(),
705                                          &snapshots[i]));
706     }
707   }
708 
709   // If we have multiple executables to run, execute them all in parallel.  But
710   // if we only have one executable, execute it using the vanilla, non-parallel
711   // call.
712   //
713   // We do this because the Client API uses ExecuteGraphParallel when it wants
714   // to compile and run one computation without caching the executable, but not
715   // all backends support the async StreamExecutor API required by
716   // ExecuteParallelAndRegisterResult.
717   //
718   // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
719   // basically the same thing.
720   ExecutionProfile profile;
721   std::vector<GlobalDataHandle> outputs;
722   Status execution_status = OkStatus();
723 
724   if (executable_ptrs.size() == 1) {
725     StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
726         executable_ptrs[0], all_arguments[0], execute_backend_.get(),
727         device_handles[0], computation_names[0], &profile);
728     if (output_or_status.ok()) {
729       outputs.push_back(std::move(output_or_status).ValueOrDie());
730     } else {
731       execution_status = output_or_status.status();
732     }
733   } else {
734     StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
735         ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
736                                          execute_backend_.get(), device_handles,
737                                          computation_names, &profile);
738     if (outputs_or_status.ok()) {
739       outputs = std::move(outputs_or_status).ValueOrDie();
740     } else {
741       execution_status = outputs_or_status.status();
742     }
743   }
744 
745   if (!execution_status.ok()) {
746     // Execution failed so we don't have the results.  Dump the HLO snapshot
747     // with just the program arguments.
748     for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
749       DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
750     }
751   }
752 
753   TF_RETURN_IF_ERROR(execution_status);
754 
755   for (const GlobalDataHandle& output : outputs) {
756     ExecuteResponse response;
757     *response.mutable_output() = output;
758     *response.mutable_profile() = profile;
759     *result->add_responses() = response;
760   }
761 
762   for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
763     Executable* executable = executable_ptrs[i];
764     if (executable->dumping_snapshot()) {
765       TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
766                           allocation_tracker_.ResolveForReplica(outputs[i], 0));
767       TF_ASSIGN_OR_RETURN(auto stream,
768                           execute_backend_->BorrowStream(all_executors[i][0]));
769       TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
770                                       execute_backend_->transfer_manager(),
771                                       &snapshots[i]));
772       DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
773     }
774   }
775 
776   VLOG(1) << "successfully completed 'execute-graph-parallel' request";
777   return OkStatus();
778 }
779 
780 Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
781                                  GetDeviceHandlesResponse* result) {
782   const int64_t available_device_count = execute_backend_->device_count();
783   const int64_t replica_count = options_.number_of_replicas();
784   if (replica_count <= 0) {
785     return FailedPrecondition("Replica count must be a positive integer");
786   }
787   if (available_device_count < arg->device_count() * replica_count) {
788     return ResourceExhausted(
789         "Requested logical device count (%d) with replica count (%d) exceeds "
790         "the number of available physical devices on the target (%d)",
791         arg->device_count(), replica_count, available_device_count);
792   }
793 
794   for (int64_t i = 0; i < arg->device_count(); ++i) {
795     DeviceHandle device_handle;
796     device_handle.set_handle(i);
797     device_handle.set_device_count(arg->device_count());
798     *result->add_device_handles() = device_handle;
799   }
800 
801   return OkStatus();
802 }
803 
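// Compiles a single HLO module proto into an executable for the given backend
// and stream executor, optionally skipping the HLO optimization passes when
// run_backend_only is true.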
804 StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
805     const HloModuleProto& module_proto,
806     std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
807     se::StreamExecutor* executor, const Compiler::CompileOptions& options,
808     bool run_backend_only) {
809   VLOG(1) << StrFormat(
810       "BuildExecutable on service %p with serialized module proto: %s", this,
811       module_proto.name());
812 
813   TF_ASSIGN_OR_RETURN(
814       std::unique_ptr<HloModule> module,
815       CreateModuleFromProto(module_proto, *module_config, run_backend_only));
816   UpdateEntryComputationLayout(
817       module.get(), std::bind(&Compiler::DefaultDeviceShapeRepresentation,
818                               backend->compiler(), std::placeholders::_1));
819   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
820 
821   std::unique_ptr<HloProto> hlo_proto_before_opt;
822   if (!run_backend_only) {
823     // Save proto state before optimizations if we want a snapshot.
824     // When run_backend_only is enabled the post-optimization HLO will be the
825     // same as the pre-optimization HLO.
826     if (DumpingEnabledForHloModule(*module)) {
827       hlo_proto_before_opt = std::make_unique<HloProto>(MakeHloProto(*module));
828     }
829     TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
830                                     std::move(module), executor, options));
831   }
832 
833   TF_ASSIGN_OR_RETURN(
834       std::unique_ptr<Executable> executable,
835       backend->compiler()->RunBackend(std::move(module), executor, options));
836 
837   const HloProto* hlo_proto_after_opt = executable->hlo_proto();
838 
839   // If dumping is enabled, RunBackend(...) will emit an hlo_proto in the
840   // executable. This contains the buffer_assignment that is only available
841   // after RunBackend(). If hlo_proto_before_opt is not null, then we replace
842   // its buffer_assignment with the one from after_opt and then store it into
843   // the executable.
844   if (hlo_proto_before_opt != nullptr && hlo_proto_after_opt != nullptr) {
845     CHECK(DumpingEnabledForHloModule(executable->module()));
846     *hlo_proto_before_opt->mutable_buffer_assignment() =
847         hlo_proto_after_opt->buffer_assignment();
848     executable->set_hlo_proto(std::move(hlo_proto_before_opt));
849   }
850   return std::move(executable);
851 }
852 
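// Handles a Compile request: builds an executable for the computation on the
// default stream executor and inserts it into the compilation cache,
// returning a handle that a later Execute request can refer to.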
853 Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
854   VLOG(1) << "running compile request";
855   if (!arg->has_computation()) {
856     return InvalidArgument("computations may not be empty");
857   }
858   if (!arg->computation().has_host_program_shape()) {
859     return InvalidArgument("program shape may not be empty");
860   }
861 
862   if (arg->execution_options().device_handles_size() > 1) {
863     return InvalidArgument(
864         "The compile request does not support multiple device handles.");
865   }
866 
867   std::vector<Shape> argument_shapes;
868   argument_shapes.reserve(arg->input_shape_with_layout_size());
869   std::vector<const Shape*> argument_shape_ptrs;
870   for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
871     argument_shapes.push_back(Shape(shape_proto));
872     argument_shape_ptrs.push_back(&argument_shapes.back());
873   }
874   TF_ASSIGN_OR_RETURN(
875       std::unique_ptr<HloModuleConfig> module_config,
876       CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
877                          argument_shape_ptrs, &arg->execution_options()));
878   VLOG(3) << "Compile created HloModuleConfig computation layout: "
879           << module_config->entry_computation_layout().ToString();
880 
881   TF_ASSIGN_OR_RETURN(
882       std::unique_ptr<Executable> executable,
883       BuildExecutable(arg->computation(), std::move(module_config),
884                       execute_backend_.get(),
885                       execute_backend_->default_stream_executor(),
886                       {/*device_allocator=*/nullptr}));
887 
888   *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));
889 
890   VLOG(1) << "successfully completed 'compile' request";
891   return OkStatus();
892 }
893 
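// Handles an Execute request: looks up the compiled executable by handle,
// checks the argument shapes against the entry computation layout, runs the
// executable on all replicas, and registers the result.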
894 Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
895   VLOG(1) << "running execute request";
896   if (!arg->has_handle()) {
897     return InvalidArgument("execution handle should not be empty");
898   }
899   TF_ASSIGN_OR_RETURN(auto executable,
900                       compilation_cache_.LookUp(arg->handle()));
901 
902   TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
903                                               SingleComputationDeviceHandle()));
904   TF_ASSIGN_OR_RETURN(
905       std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
906       ResolveAndValidateArguments(arg->arguments(), replicas));
907 
908   // Check that replicated_arguments have the same shapes and layouts as the
909   // parameters of the module config used when creating the executable.
910   const int64_t num_module_args =
911       executable->module_config().entry_computation_layout().parameter_count();
912   if (num_module_args != arg->arguments_size()) {
913     return InvalidArgument(
914         "The executable expects %lld arguments, but sees %lld.",
915         num_module_args, arg->arguments_size());
916   }
917   for (int64_t i = 0; i < num_module_args; i++) {
918     const Shape& shape_module =
919         executable->module_config().entry_computation_layout().parameter_shape(
920             i);
921     const Shape& shape_arg = replicated_arguments.front()[i]->on_device_shape();
922     if (!ShapeUtil::Equal(shape_module, shape_arg)) {
923       return InvalidArgumentStrCat(
924           "The executable expects the ", i, "th argument in shape ",
925           ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
926           ShapeUtil::HumanStringWithLayout(shape_arg));
927     }
928   }
929 
930   TF_ASSIGN_OR_RETURN(auto stream,
931                       execute_backend_->BorrowStream(
932                           execute_backend_->default_stream_executor()));
933   HloSnapshot snapshot;
934   if (executable->dumping_snapshot()) {
935     *snapshot.mutable_hlo() = *executable->hlo_proto();
936     snapshot.set_execution_platform(execute_backend_->platform()->Name());
937     TF_RETURN_IF_ERROR(
938         RecordArguments(replicated_arguments.front(), stream.get(),
939                         execute_backend_->transfer_manager(), &snapshot));
940   }
941 
942   TF_ASSIGN_OR_RETURN(
943       *result->mutable_output(),
944       ExecuteAndRegisterResult(executable.get(), replicated_arguments,
945                                execute_backend_.get(),
946                                SingleComputationDeviceHandle(),
947                                "result of " + executable->module().name(),
948                                result->mutable_profile()));
949 
950   if (executable->dumping_snapshot()) {
951     TF_ASSIGN_OR_RETURN(
952         const ShapedBuffer* result_buffer,
953         allocation_tracker_.ResolveForReplica(result->output(), 0));
954     TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
955                                     execute_backend_->transfer_manager(),
956                                     &snapshot));
957     DumpHloSnapshotIfEnabled(executable->module(), snapshot);
958   }
959 
960   VLOG(1) << "successfully completed 'execute' request";
961   return OkStatus();
962 }
963 
964 Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
965                                  WaitForExecutionResponse* result) {
966   TF_ASSIGN_OR_RETURN(const auto execution,
967                       execution_tracker_.Resolve(arg->execution()));
968 
969   TF_RETURN_IF_ERROR(execution->BlockUntilDone());
970 
971   *result->mutable_output() = execution->result();
972   *result->mutable_profile() = execution->profile();
973 
974   TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
975   VLOG(1) << "successfully completed 'wait-for-execution' request";
976   return OkStatus();
977 }
978 
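// Transfers the contents of a previously registered device buffer back to the
// client as a literal, relaid out to the requested layout when one is given.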
979 Status Service::TransferToClient(const TransferToClientRequest* arg,
980                                  TransferToClientResponse* result) {
981   TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
982                       allocation_tracker_.ResolveForReplica(arg->data(), 0));
983 
984   Shape return_shape;
985   if (arg->has_shape_with_layout()) {
986     return_shape = Shape(arg->shape_with_layout());
987     if (!LayoutUtil::HasLayout(return_shape)) {
988       return InvalidArgument("shape_with_layout must have layout if present.");
989     }
990   } else {
991     return_shape = Shape(shaped_buffer->on_device_shape());
992   }
993 
994   TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
995                                        shaped_buffer->device_ordinal()));
996 
997   TF_ASSIGN_OR_RETURN(
998       Literal result_literal,
999       execute_backend_->transfer_manager()->TransferLiteralFromDevice(
1000           stream.get(), *shaped_buffer));
1001 
1002   if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
1003     *result->mutable_literal() = result_literal.ToProto();
1004   } else {
1005     *result->mutable_literal() =
1006         result_literal.Relayout(return_shape).ToProto();
1007   }
1008   return OkStatus();
1009 }
1010 
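// Transfers a literal from the client to device memory, allocating and
// populating one buffer per replica, and registers the replicated buffers
// under a single global data handle.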
1011 Status Service::TransferToServer(const TransferToServerRequest* arg,
1012                                  TransferToServerResponse* result) {
1013   TF_ASSIGN_OR_RETURN(Literal literal,
1014                       Literal::CreateFromProto(arg->literal()));
1015   const Shape& shape = literal.shape();
1016 
1017   std::vector<se::StreamExecutor*> replicas;
1018   if (arg->has_device_handle()) {
1019     TF_ASSIGN_OR_RETURN(replicas,
1020                         Replicas(*execute_backend_, arg->device_handle()));
1021   } else {
1022     TF_ASSIGN_OR_RETURN(
1023         replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
1024   }
1025 
1026   // Allocate memory in each replica and transfer the data to all replicas.
1027   std::vector<ScopedShapedBuffer> replicated_buffers;
1028   replicated_buffers.reserve(replicas.size());
1029   for (se::StreamExecutor* executor : replicas) {
1030     auto device_shape_representation_fn = [this](const Shape& shape) {
1031       return execute_backend_->compiler()->DefaultDeviceShapeRepresentation(
1032           shape);
1033     };
1034     TF_ASSIGN_OR_RETURN(
1035         ScopedShapedBuffer shaped_buffer,
1036         execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
1037             shape, execute_backend_->memory_allocator(),
1038             executor->device_ordinal(), device_shape_representation_fn));
1039     TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
1040     TF_RETURN_IF_ERROR(
1041         execute_backend_->transfer_manager()->TransferLiteralToDevice(
1042             stream.get(), literal, shaped_buffer));
1043     replicated_buffers.emplace_back(std::move(shaped_buffer));
1044   }
1045   TF_ASSIGN_OR_RETURN(*result->mutable_data(),
1046                       allocation_tracker_.RegisterReplicatedBuffers(
1047                           std::move(replicated_buffers),
1048                           StrCat("TransferToServer literal of shape ",
1049                                  ShapeUtil::HumanString(shape))));
1050 
1051   return OkStatus();
1052 }
1053 
1054 Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
1055                                  TransferToInfeedResponse* result) {
1056   const int64_t replica_count = options_.number_of_replicas();
1057   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
1058     return FailedPrecondition(
1059         "%s",
1060         StrCat("The replica_id=", arg->replica_id(),
1061                " on TransferToInfeedRequest not in range [0, replica_count=",
1062                replica_count, ")."));
1063   }
1064 
1065   se::StreamExecutor* executor;
1066   if (arg->has_device_handle()) {
1067     TF_ASSIGN_OR_RETURN(auto replicas,
1068                         Replicas(*execute_backend_, arg->device_handle()));
1069     executor = replicas[arg->replica_id()];
1070   } else {
1071     TF_ASSIGN_OR_RETURN(
1072         auto replicas,
1073         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
1074     executor = replicas[arg->replica_id()];
1075   }
1076 
1077   TF_ASSIGN_OR_RETURN(Literal literal,
1078                       Literal::CreateFromProto(arg->literal()));
1079   return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
1080                                                                        literal);
1081 }
1082 
1083 Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
1084                                     TransferFromOutfeedResponse* result) {
1085   const int64_t replica_count = options_.number_of_replicas();
1086   if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
1087     return FailedPrecondition(
1088         "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
1089         arg->replica_id(), replica_count);
1090   }
1091 
1092   se::StreamExecutor* executor;
1093   if (arg->has_device_handle()) {
1094     TF_ASSIGN_OR_RETURN(auto replicas,
1095                         Replicas(*execute_backend_, arg->device_handle()));
1096     executor = replicas[arg->replica_id()];
1097   } else {
1098     TF_ASSIGN_OR_RETURN(
1099         auto replicas,
1100         Replicas(*execute_backend_, SingleComputationDeviceHandle()));
1101     executor = replicas[arg->replica_id()];
1102   }
1103 
1104   auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));
1105 
1106   TF_RETURN_IF_ERROR(
1107       execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
1108           executor, &literal));
1109   *result->mutable_literal() = literal.ToProto();
1110   return OkStatus();
1111 }
1112 
1113 Status Service::ResetDevice(const ResetDeviceRequest* arg,
1114                             ResetDeviceResponse* result) {
1115   return execute_backend_->ResetDevices();
1116 }
1117 
1118 Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
1119                                      ComputeConstantResponse* result) {
1120   if (!arg->has_computation()) {
1121     return InvalidArgument("computations may not be empty");
1122   }
1123   if (!arg->computation().has_host_program_shape()) {
1124     return InvalidArgument("program shape may not be empty");
1125   }
1126   if (arg->computation().host_program_shape().parameters_size() != 0) {
1127     return InvalidArgument(
1128         "constant computation may not depend on any parameters.");
1129   }
1130 
1131   ProgramShape program_shape(arg->computation().host_program_shape());
1132   TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
1133   std::optional<Layout> output_layout;
1134   if (arg->has_output_layout()) {
1135     output_layout = Layout::CreateFromProto(arg->output_layout());
1136     TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
1137         *output_layout, program_shape.result()));
1138   }
1139 
1140   HloModuleConfig config(program_shape);
1141 
1142   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1143                       CreateModuleFromProto(arg->computation(), config));
1144   DynamicPadder dynamic_padder;
1145   TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status());
1146 
1147   TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
1148                       DynamicDimensionInference::Run(module.get()));
1149 
1150   HloEvaluator evaluator;
1151   evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
1152   evaluator.set_custom_call_handler(
1153       [](HloInstruction* custom_call,
1154          absl::Span<const Literal*> operands) -> StatusOr<Literal> {
1155         if (custom_call->custom_call_target() == "SliceToDynamic") {
1156           auto result = operands[0]->Clone();
1157           for (int64_t i = 0; i < result.shape().rank(); ++i) {
1158             result.SetDynamicSize(i, operands[1 + i]->Get<int32_t>({}));
1159           }
1160           return result.ToStatic();
1161         }
1162         return Unimplemented("Custom call %s is not supported: %s",
1163                              custom_call->custom_call_target(),
1164                              custom_call->ToString());
1165       });
1166   TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
1167 
1168   // The requested result layout does not affect the evaluator's result, so
1169   // explicitly relayout the result here.
1170   //
1171   // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
1172   if (output_layout.has_value()) {
1173     result_literal = result_literal.Relayout(*output_layout);
1174   }
1175   *result->mutable_literal() = result_literal.ToProto();
1176 
1177   return OkStatus();
1178 }
1179 
1180 Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
1181   TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
1182                       allocation_tracker_.ResolveForReplica(arg->data(), 0));
1183   *result->mutable_shape() = buffer->on_device_shape().ToProto();
1184   return OkStatus();
1185 }
1186 
1187 Status Service::GetComputationGraphStats(
1188     const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
1189   if (!arg->has_computation()) {
1190     return InvalidArgument("Computations may not be empty.");
1191   }
1192   if (!arg->computation().has_host_program_shape()) {
1193     return InvalidArgument("Program shape may not be empty.");
1194   }
1195 
1196   HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
1197   config.set_debug_options(arg->debug_options());
1198   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
1199                       CreateModuleFromProto(arg->computation(), config));
1200   UpdateEntryComputationLayout(
1201       module.get(),
1202       std::bind(&Compiler::DefaultDeviceShapeRepresentation,
1203                 execute_backend_->compiler(), std::placeholders::_1));
1204   DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
1205 
1206   // Run HLO analysis to get the computation statistics.
1207   HloCostAnalysis analysis(
1208       execute_backend_->compiler()->ShapeSizeBytesFunction());
1209 
1210   TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
1211 
1212   ComputationStats stats;
1213   stats.set_flop_count(analysis.flop_count());
1214   stats.set_transcendental_count(analysis.transcendental_count());
1215   *result->mutable_stats() = stats;
1216   return OkStatus();
1217 }
1218 
1219 DeviceHandle Service::SingleComputationDeviceHandle() const {
1220   DeviceHandle device_handle;
1221   device_handle.set_handle(0);
1222   device_handle.set_device_count(1);
1223   return device_handle;
1224 }
1225 
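// Returns the stream executors backing each replica of the given device
// handle, as assigned by the backend's computation placer.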
1226 StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
1227     const Backend& backend, const DeviceHandle& device_handle) const {
1228   std::vector<se::StreamExecutor*> replicas;
1229   for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
1230     // From the computation placer, find out the device ids of the replicas for
1231     // the given device handle.
1232     TF_ASSIGN_OR_RETURN(
1233         int device_ordinal,
1234         backend.computation_placer()->DeviceId(replica, device_handle.handle(),
1235                                                options_.number_of_replicas(),
1236                                                device_handle.device_count()));
1237     TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
1238     replicas.push_back(executor);
1239   }
1240   return replicas;
1241 }
1242 
1243 }  // namespace xla
1244