/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/service.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_module_util.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace {

using absl::StrCat;
using absl::StrFormat;

// Argument used when calling DumpHloModuleIfEnabled before optimizations are
// performed on an HloModule.
constexpr char kBeforeOptimizationsDumpName[] = "before_optimizations";

// Records the arguments used to invoke a computation in an HloSnapshot proto.
Status RecordArguments(const absl::Span<const ShapedBuffer* const> arguments,
                       se::Stream* stream, TransferManager* transfer_manager,
                       HloSnapshot* module) {
  module->clear_arguments();
  for (const ShapedBuffer* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        Literal literal,
        transfer_manager->TransferLiteralFromDevice(stream, *argument));
    *module->add_arguments() = literal.ToProto();
  }
  return OkStatus();
}

// Records the result of a computation in a HloSnapshot proto.
Status RecordResult(const ShapedBuffer& result, se::Stream* stream,
                    TransferManager* transfer_manager, HloSnapshot* module) {
  module->clear_result();
  TF_ASSIGN_OR_RETURN(
      Literal literal,
      transfer_manager->TransferLiteralFromDevice(stream, result));
  *module->mutable_result() = literal.ToProto();
  return OkStatus();
}

}  // namespace

ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) {
  platform_ = platform;
  return *this;
}

se::Platform* ServiceOptions::platform() const { return platform_; }

ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
  number_of_replicas_ = number_of_replicas;
  return *this;
}

int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }

ServiceOptions& ServiceOptions::set_intra_op_parallelism_threads(
    int num_threads) {
  intra_op_parallelism_threads_ = num_threads;
  return *this;
}

int ServiceOptions::intra_op_parallelism_threads() const {
  return intra_op_parallelism_threads_;
}

ServiceOptions& ServiceOptions::set_allowed_devices(
    const std::optional<std::set<int>>& allowed_devices) {
  allowed_devices_ = allowed_devices;
  return *this;
}

const std::optional<std::set<int>>& ServiceOptions::allowed_devices() const {
  return allowed_devices_;
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    se::Platform* platform) {
  ServiceOptions default_options;
  default_options.set_platform(platform);
  return NewService(default_options);
}

/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
    const ServiceOptions& options) {
  se::Platform* platform = options.platform();
  std::unique_ptr<Backend> execute_backend;
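  // Fall back to the default platform when the caller did not specify one.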
  if (platform == nullptr) {
    TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
  }
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_options.set_allowed_devices(options.allowed_devices());
  TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));

  std::unique_ptr<Service> service(
      new Service(options, std::move(execute_backend)));
  return std::move(service);
}

Service::Service(const ServiceOptions& options,
                 std::unique_ptr<Backend> execute_backend)
    : options_(options),
      allocation_tracker_(execute_backend.get()),
      execute_backend_(std::move(execute_backend)) {
  CHECK_GT(options_.number_of_replicas(), 0);
  if (execute_backend_) {
    if (execute_backend_->device_count() > 0) {
      CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
          << "Requested more replicas than there are devices.";
    }
    LOG(INFO) << StrFormat(
        "XLA service %p initialized for platform %s (this does not guarantee "
        "that XLA will be used). Devices:",
        this, execute_backend_->platform()->Name());
    auto stream_executors = execute_backend_->stream_executors();
    for (int i = 0; i < execute_backend_->device_count(); ++i) {
      se::StreamExecutor* executor = stream_executors.at(i);
      const auto& description = executor->GetDeviceDescription();
      LOG(INFO) << StrFormat("  StreamExecutor device (%d): %s, %s", i,
                             description.name(),
                             description.platform_version());
    }
  } else {
    VLOG(1) << "XLA compile-only service constructed";
  }
}

Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg,
                                    CreateChannelHandleResponse* result) {
  TF_ASSIGN_OR_RETURN(*result->mutable_channel(),
                      channel_tracker_.NewChannel(arg->channel_type()));
  return OkStatus();
}

Status Service::Unregister(const UnregisterRequest* arg,
                           UnregisterResponse* result) {
  Status status;
  for (auto& data : arg->data()) {
    Status unregister_status = allocation_tracker_.Unregister(data);
    if (!unregister_status.ok() && status.ok()) {
      status = unregister_status;
    }
  }
  return status;
}

// Deconstructs a previously-allocated global handle.
Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
                                 DeconstructTupleResponse* result) {
  TF_ASSIGN_OR_RETURN(
      std::vector<GlobalDataHandle> elements,
      allocation_tracker_.DeconstructTuple(arg->tuple_handle()));

  for (auto& element : elements) {
    *result->add_element_handles() = element;
  }
  return OkStatus();
}

Status Service::ValidateResultShape(const Shape& client_shape,
                                    const Shape& result_shape) const {
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        ShapeUtil::HumanStringWithLayout(client_shape),
        ShapeUtil::HumanString(result_shape));
  }
  return OkStatus();
}

StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
Service::ResolveAndValidateArguments(
    absl::Span<const GlobalDataHandle* const> arguments,
    absl::Span<se::StreamExecutor* const> stream_executors) const {
  CHECK_EQ(options_.number_of_replicas(), stream_executors.size());
  std::vector<std::vector<const ShapedBuffer*>> replicated_arguments;
  replicated_arguments.resize(options_.number_of_replicas());
  for (size_t i = 0; i < arguments.size(); ++i) {
    auto buffer_status = allocation_tracker_.Resolve(*arguments[i]);
    if (!buffer_status.ok()) {
      return tensorflow::errors::CreateWithUpdatedMessage(
          buffer_status.status(),
          StrCat(buffer_status.status().error_message(), ", ",
                 "failed to resolve allocation for parameter ", i));
    }
    auto replicated_buffers = buffer_status.ValueOrDie();
    CHECK_EQ(options_.number_of_replicas(), replicated_buffers.size());
    for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
      const ShapedBuffer* shaped_buffer = replicated_buffers[replica];
      replicated_arguments[replica].push_back(shaped_buffer);
    }
  }
  return replicated_arguments;
}

StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const Shape* const> argument_shapes,
    const ExecutionOptions* execution_options,
    const AotCompilationOptions* aot_options) {
  int default_num_replicas = options_.number_of_replicas();
  std::optional<int> num_threads;
  if (execute_backend_ != nullptr &&
      execute_backend_->eigen_intra_op_thread_pool() != nullptr) {
    num_threads = execute_backend_->eigen_intra_op_thread_pool()->NumThreads();
  }

  return xla::CreateModuleConfig(program_shape, argument_shapes,
                                 execution_options, default_num_replicas,
                                 num_threads, aot_options);
}

StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
    const ProgramShape& program_shape,
    absl::Span<const ShapedBuffer* const> arguments,
    const ExecutionOptions& execution_options,
    const AotCompilationOptions* aot_options) {
  std::vector<const Shape*> argument_shapes;
  for (const auto* arg : arguments) {
    argument_shapes.push_back(&arg->on_device_shape());
  }
  return CreateModuleConfig(program_shape, argument_shapes, &execution_options,
                            aot_options);
}

StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
    const std::vector<const HloModuleProto*>& module_protos,
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
    const Compiler::CompileOptions& options, bool run_backend_only) {
  VLOG(1) << StrFormat("BuildExecutable on service %p", this);

  VLOG(1) << "Computations:";
  for (const HloModuleProto* proto : module_protos) {
    VLOG(1) << proto->name();
  }

  CHECK_EQ(module_protos.size(), module_configs.size());
  auto module_group =
      std::make_unique<HloModuleGroup>(module_protos[0]->name());
  for (int64_t i = 0, end = module_protos.size(); i < end; ++i) {
    const HloModuleProto* proto = module_protos[i];
    const HloModuleConfig& config = *module_configs[i];
    TF_ASSIGN_OR_RETURN(
        auto module, CreateModuleFromProto(*proto, config, run_backend_only));
    UpdateEntryComputationLayout(
        module.get(), std::bind(&Compiler::DefaultDeviceShapeRepresentation,
                                backend->compiler(), std::placeholders::_1));
    DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
    module_group->push_back(std::move(module));
  }

  std::vector<std::unique_ptr<Executable>> executables;
  if (!run_backend_only) {
    TF_ASSIGN_OR_RETURN(executables, backend->compiler()->Compile(
                                         std::move(module_group),
                                         std::move(executors), options));
  } else {
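    // run_backend_only: skip the HLO optimization passes and lower each
    // module individually via RunBackend.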
    auto modules = module_group->ConsumeModules();
    for (std::unique_ptr<HloModule>& module : modules) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                          backend->compiler()->RunBackend(
                              std::move(module), executors[0][0], options));
      executables.push_back(std::move(executable));
    }
  }

  return std::move(executables);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
Service::BuildAotResults(
    const std::vector<const HloModuleProto*>& module_protos,
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
    Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
    const Compiler::CompileOptions& options, bool run_backend_only) {
  VLOG(1) << StrFormat("BuildAotResults on service %p", this);

  VLOG(1) << "Computations:";
  for (const HloModuleProto* proto : module_protos) {
    VLOG(1) << proto->name();
  }

  CHECK_EQ(module_protos.size(), module_configs.size());
  auto module_group =
      std::make_unique<HloModuleGroup>(module_protos[0]->name());
  for (int64_t i = 0, end = module_protos.size(); i < end; ++i) {
    const HloModuleProto* proto = module_protos[i];
    const HloModuleConfig& config = *module_configs[i];
    TF_ASSIGN_OR_RETURN(
        auto module, CreateModuleFromProto(*proto, config, run_backend_only));
    DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);
    module_group->push_back(std::move(module));
  }

  AotCompilationOptions aot_options(backend->compiler()->PlatformId());
  aot_options.set_executor(executors[0][0]);
  aot_options.set_device_allocator(options.device_allocator);
  aot_options.set_run_backend_only(run_backend_only);

  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<AotCompilationResult>> aot_results,
      backend->compiler()->CompileAheadOfTime(std::move(module_group),
                                              aot_options));
  return std::move(aot_results);
}

StatusOr<std::vector<GlobalDataHandle>>
Service::ExecuteParallelAndRegisterResult(
    absl::Span<Executable* const> executables,
    absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
    Backend* backend, absl::Span<const DeviceHandle> device_handles,
    absl::Span<const std::string> result_tags, ExecutionProfile* profile) {
  // Streams on which the computations are launched, so we can wait for them
  // to complete.
  std::vector<StreamPool::Ptr> streams;
  std::vector<std::unique_ptr<se::Timer>> timers;

  // Global data handles for the computation results, one for each computation.
  std::vector<GlobalDataHandle> result_handles;

  // Device ID to stream executor, populated only with devices that are being
  // profiled.
  std::map<int64_t, se::Stream*> index_to_profiled_streams;

  // Build DeviceAssignment for all cores based on the provided device handles.
  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     executables.size());
  for (int64_t i = 0; i < executables.size(); i++) {
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    for (int64_t replica = 0, end = replicas.size(); replica < end; ++replica) {
      device_assignment(replica, i) = replicas[replica]->device_ordinal();
    }
  }

  for (int64_t i = 0, end = executables.size(); i < end; i++) {
    // Stream executors for the replicas of the current computation.
    TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
    CHECK_EQ(replicas.size(), arguments[i].size());
    std::vector<ScopedShapedBuffer> result_buffers;
    const int64_t n = replicas.size();
    result_buffers.reserve(n);
    for (int64_t replica = 0; replica < n; ++replica) {
      TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                          backend->BorrowStream(replicas[replica]));
      streams.push_back(std::move(stream));

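      // Only replica 0 of each computation is timed; the maximum across
      // computations is used to fill in the execution profile below.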
      if (replica == 0 && profile != nullptr) {
        timers.push_back(std::make_unique<se::Timer>(streams.back()->parent()));
        streams.back()
            ->InitTimer(timers.back().get())
            .ThenStartTimer(timers.back().get());
        CHECK(timers.front() != nullptr);
      }

      if (replica == 0 &&
          executables[i]->module_config().debug_options().xla_hlo_profile() &&
          executables[i]->hlo_profiling_enabled()) {
        index_to_profiled_streams[i] = streams.back().get();
      }

      // Set up run options.
      ExecutableRunOptions options;
      options.set_stream(streams.back().get());
      options.set_allocator(backend->memory_allocator());
      options.set_intra_op_thread_pool(
          backend->eigen_intra_op_thread_pool_device());
      options.set_device_assignment(&device_assignment);
      // Use run-time profile information from execution_profile on the 0th
      // device.
      if (i == 0) {
        options.set_execution_profile(profile);
      }
      ServiceExecutableRunOptions run_options(options,
                                              backend->StreamBorrower());

      // Asynchronously launch the computation.
      TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                          executables[i]->ExecuteAsyncOnStream(
                              &run_options, arguments[i][replica],
                              /*hlo_execution_profile=*/nullptr));

      if (replica == 0 && profile != nullptr) {
        streams.back()->ThenStopTimer(timers.back().get());
      }

      result_buffers.push_back(std::move(result));
    }
    TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
                        allocation_tracker_.RegisterReplicatedBuffers(
                            std::move(result_buffers), result_tags[i]));
    result_handles.push_back(handle);
  }

  // Wait for all executions to complete.
  for (int64_t i = 0, end = streams.size(); i < end; ++i) {
    Status block_status = streams[i]->BlockHostUntilDone();
    if (!block_status.ok()) {
      return InternalError("failed to complete execution for stream %d: %s", i,
                           block_status.error_message());
    }
  }

  if (profile != nullptr) {
    CHECK(!timers.empty());
    std::vector<uint64_t> timer_nanoseconds;
    timer_nanoseconds.reserve(timers.size());
    for (auto& timer : timers) {
      timer_nanoseconds.push_back(timer->Nanoseconds());
    }
    uint64_t nanoseconds =
        *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end());

    // Overall execution time (in nanoseconds) from the executor timer.
    profile->set_compute_and_transfer_time_ns(nanoseconds);

    // TODO(b/28123297): On GPU we end up including transfer time in
    // the compute time this way. Instead, we should get the correct
    // value by measuring it. Setting the field here at least lets
    // benchmarks provide *some* value for GPU computations.
    //
    // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
    // the compute time without the transfer time, so this way we get the
    // correct compute time. We should instead have the correct value for
    // compute_and_transfer_time and set compute_time to the compute time.
    if (profile->compute_time_ns() == 0) {
      profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
    }
  }

  return result_handles;
}

StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
    Executable* executable,
    absl::Span<const std::vector<const ShapedBuffer*>> arguments,
    Backend* backend, const DeviceHandle& device_handle,
    const std::string& result_tag, ExecutionProfile* profile) {
  // Set up streams.
  std::vector<StreamPool::Ptr> streams;

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handle));
  TF_RET_CHECK(!replicas.empty());
  for (se::StreamExecutor* executor : replicas) {
    TF_ASSIGN_OR_RETURN(StreamPool::Ptr stream,
                        backend->BorrowStream(executor));
    streams.push_back(std::move(stream));
  }

  DeviceAssignment device_assignment(options_.number_of_replicas(),
                                     /*computation_count=*/1);
  for (int64_t replica = 0; replica < replicas.size(); ++replica) {
    device_assignment(replica, 0) = replicas[replica]->device_ordinal();
  }

  // Set up run options.
  std::vector<ServiceExecutableRunOptions> run_options;
  run_options.reserve(streams.size());
  for (const StreamPool::Ptr& stream : streams) {
    ExecutableRunOptions options;
    options.set_stream(stream.get());
    options.set_device_ordinal(stream->parent()->device_ordinal());
    options.set_allocator(backend->memory_allocator());
    options.set_intra_op_thread_pool(
        backend->eigen_intra_op_thread_pool_device());
    options.set_device_assignment(&device_assignment);
    options.set_execution_profile(profile);
    run_options.emplace_back(options, backend->StreamBorrower());
  }

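  // Single-replica case: execute on the sole stream and register the result
  // directly.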
  if (options_.number_of_replicas() == 1) {
    TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper(
                                         &run_options[0], arguments[0]));
    return allocation_tracker_.Register(std::move(result), result_tag);
  }

  // TODO(b/69985541): Support profiling also on this path.

  std::vector<absl::Span<const ShapedBuffer* const>> replicated_arguments;
  for (const auto& arg : arguments) {
    replicated_arguments.push_back(arg);
  }

  TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
                                        run_options, replicated_arguments));
  TF_RET_CHECK(!results.empty());
  return allocation_tracker_.RegisterReplicatedBuffers(std::move(results),
                                                       result_tag);
}

StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors(
    const ExecutionOptions& execution_options, int64_t requests_size,
    int64_t request_index) const {
  if (execution_options.device_handles().empty()) {
    return FailedPrecondition(
        "device handles must be given to execute parallel computations");
  }
  if (requests_size > 1 && execution_options.device_handles_size() > 1) {
    return InvalidArgument(
        "Parallel requests with multiple device handles is not supported. "
        "Found %d parallel requests, with request %d containing %d device "
        "handles.",
        requests_size, request_index, execution_options.device_handles_size());
  }
  std::vector<se::StreamExecutor*> executors;
  for (const auto& device_handle : execution_options.device_handles()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, device_handle));
    se::StreamExecutor* executor = replicas[0];
    CHECK(executor != nullptr);
    executors.push_back(executor);
  }
  return executors;
}

StatusOr<std::vector<std::vector<const ShapedBuffer*>>> Service::GetArguments(
    const ExecutionOptions& execution_options,
    absl::Span<const GlobalDataHandle* const> arguments) const {
  // Resolve the allocations for the arguments of the computation, and create
  // a vector of device memory offsets for the arguments from the allocations.
  // In the case of partitioned computations, assume all arguments go on the
  // zeroth core.
  TF_ASSIGN_OR_RETURN(
      auto replicas,
      Replicas(*execute_backend_, execution_options.device_handles(0)));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arguments, replicas));
  return replicated_arguments;
}

Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                                     ExecuteParallelResponse* result) {
  VLOG(1) << "running execute-graph-parallel request";

  std::vector<std::vector<std::vector<const ShapedBuffer*>>> all_arguments;
  std::vector<std::vector<se::StreamExecutor*>> all_executors;
  std::vector<const HloModuleProto*> module_protos;
  std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
  std::vector<std::string> computation_names;
  std::vector<DeviceHandle> device_handles;

  int num_requested_devices =
      std::accumulate(arg->requests().begin(), arg->requests().end(), 0,
                      [](int a, const ExecuteGraphRequest& r) -> int {
                        return a + r.execution_options().device_handles_size();
                      });
  if (num_requested_devices * options_.number_of_replicas() >
      execute_backend_->device_count()) {
    return FailedPrecondition(
        "there are not enough stream executors to execute %d computations",
        num_requested_devices);
  }

  for (int64_t i = 0; i < arg->requests_size(); ++i) {
    // Get the stream executor for the i'th computation. This stream executor
    // is one of the executors to run the replicated computation.
    const ExecutionOptions& execution_options =
        arg->requests(i).execution_options();
    const ExecuteGraphRequest& request = arg->requests(i);
    TF_RET_CHECK(request.has_computation()) << "computations may not be empty";
    TF_RET_CHECK(request.computation().has_host_program_shape())
        << "program shape may not be empty";

    // Get the executors.
    TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options,
                                                     arg->requests_size(), i));

    // Get the replicated arguments.
    TF_ASSIGN_OR_RETURN(auto replicated_arguments,
                        GetArguments(execution_options, request.arguments()));

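    // Rewrite array subshapes whose layouts lack tiling information to the
    // transfer manager's device shape representation.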
    for (auto& args : replicated_arguments) {
      for (auto& arg : args) {
        auto update_shape_with_empty_tiles = [this](
                                                 Shape* subshape,
                                                 const xla::ShapeIndex& index) {
          if (subshape->IsArray() && subshape->layout().tiles().empty()) {
            *subshape =
                execute_backend_->transfer_manager()->HostShapeToDeviceShape(
                    *subshape);
          }
        };
        ShapeUtil::ForEachMutableSubshape(
            const_cast<Shape*>(&arg->on_device_shape()),
            update_shape_with_empty_tiles);
      }
    }

    // Create an HloModuleConfig object for the computation, given the shape of
    // the program and the argument allocations. Here, we care only about the
    // shapes of the arguments, so, it is sufficient to use the arguments of
    // replica 0.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloModuleConfig> module_config,
        CreateModuleConfig(
            ProgramShape{request.computation().host_program_shape()},
            replicated_arguments.front(), request.execution_options()));
    VLOG(3)
        << "ExecuteGraphParallel created HloModuleConfig computation layout: "
        << module_config->entry_computation_layout().ToString();

    // Adds to the vectors to build and execute the computations after the loop.
    all_arguments.push_back(replicated_arguments);
    all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}});
    module_protos.push_back(&request.computation());
    module_configs.push_back(std::move(module_config));
    computation_names.insert(computation_names.end(), executors.size(),
                             request.computation().name());
    all_executors.push_back(executors);
    device_handles.insert(device_handles.end(),
                          execution_options.device_handles().begin(),
                          execution_options.device_handles().end());
  }

  // Build the HloModules and compile to generate the executables.
  //
  // TODO(jlebar): There's currently no way to pass a device allocator to
  // ExecuteGraphParallel, so we have to pass a null device_allocator below.
  TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
                      BuildExecutables(module_protos, std::move(module_configs),
                                       execute_backend_.get(), all_executors,
                                       {/*device_allocator=*/nullptr}));
  std::vector<Executable*> executable_ptrs;
  executable_ptrs.reserve(executables.size());
  for (const auto& executable : executables) {
    executable_ptrs.push_back(executable.get());
  }

  std::vector<HloSnapshot> snapshots;
  snapshots.resize(executable_ptrs.size());
  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    if (executable_ptrs[i]->dumping_snapshot()) {
      *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto();
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(
                              all_executors[i][0]->device_ordinal()));
      TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(),
                                         execute_backend_->transfer_manager(),
                                         &snapshots[i]));
    }
  }

  // If we have multiple executables to run, execute them all in parallel. But
  // if we only have one executable, execute it using the vanilla, non-parallel
  // call.
  //
  // We do this because the Client API uses ExecuteGraphParallel when it wants
  // to compile and run one computation without caching the executable, but not
  // all backends support the async StreamExecutor API required by
  // ExecuteParallelAndRegisterResult.
  //
  // TODO(b/122731460): Consolidate Execute{,Parallel}AndRegisterResult; they do
  // basically the same thing.
  ExecutionProfile profile;
  std::vector<GlobalDataHandle> outputs;
  Status execution_status = OkStatus();

  if (executable_ptrs.size() == 1) {
    StatusOr<GlobalDataHandle> output_or_status = ExecuteAndRegisterResult(
        executable_ptrs[0], all_arguments[0], execute_backend_.get(),
        device_handles[0], computation_names[0], &profile);
    if (output_or_status.ok()) {
      outputs.push_back(std::move(output_or_status).ValueOrDie());
    } else {
      execution_status = output_or_status.status();
    }
  } else {
    StatusOr<std::vector<GlobalDataHandle>> outputs_or_status =
        ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
                                         execute_backend_.get(), device_handles,
                                         computation_names, &profile);
    if (outputs_or_status.ok()) {
      outputs = std::move(outputs_or_status).ValueOrDie();
    } else {
      execution_status = outputs_or_status.status();
    }
  }

  if (!execution_status.ok()) {
    // Execution failed so we don't have the results. Dump the HLO snapshot
    // with just the program arguments.
    for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
      DumpHloSnapshotIfEnabled(executable_ptrs[i]->module(), snapshots[i]);
    }
  }

  TF_RETURN_IF_ERROR(execution_status);

  for (const GlobalDataHandle& output : outputs) {
    ExecuteResponse response;
    *response.mutable_output() = output;
    *response.mutable_profile() = profile;
    *result->add_responses() = response;
  }

  for (int i = 0, end = executable_ptrs.size(); i < end; i++) {
    Executable* executable = executable_ptrs[i];
    if (executable->dumping_snapshot()) {
      TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer,
                          allocation_tracker_.ResolveForReplica(outputs[i], 0));
      TF_ASSIGN_OR_RETURN(auto stream,
                          execute_backend_->BorrowStream(all_executors[i][0]));
      TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                      execute_backend_->transfer_manager(),
                                      &snapshots[i]));
      DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]);
    }
  }

  VLOG(1) << "successfully completed 'execute-graph-parallel' request";
  return OkStatus();
}

Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                                 GetDeviceHandlesResponse* result) {
  const int64_t available_device_count = execute_backend_->device_count();
  const int64_t replica_count = options_.number_of_replicas();
  if (replica_count <= 0) {
    return FailedPrecondition("Replica count must be a positive integer");
  }
  if (available_device_count < arg->device_count() * replica_count) {
    return ResourceExhausted(
        "Requested logical device count (%d) with replica count (%d) exceeds "
        "the number of available physical devices on the target (%d)",
        arg->device_count(), replica_count, available_device_count);
  }

  for (int64_t i = 0; i < arg->device_count(); ++i) {
    DeviceHandle device_handle;
    device_handle.set_handle(i);
    device_handle.set_device_count(arg->device_count());
    *result->add_device_handles() = device_handle;
  }

  return OkStatus();
}

StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
    const HloModuleProto& module_proto,
    std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
    se::StreamExecutor* executor, const Compiler::CompileOptions& options,
    bool run_backend_only) {
  VLOG(1) << StrFormat(
      "BuildExecutable on service %p with serialized module proto: %s", this,
      module_proto.name());

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModule> module,
      CreateModuleFromProto(module_proto, *module_config, run_backend_only));
  UpdateEntryComputationLayout(
      module.get(), std::bind(&Compiler::DefaultDeviceShapeRepresentation,
                              backend->compiler(), std::placeholders::_1));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  std::unique_ptr<HloProto> hlo_proto_before_opt;
  if (!run_backend_only) {
    // Save proto state before optimizations if we want a snapshot.
    // When run_backend_only is enabled the post-optimization HLO will be the
    // same as the pre-optimization HLO.
    if (DumpingEnabledForHloModule(*module)) {
      hlo_proto_before_opt = std::make_unique<HloProto>(MakeHloProto(*module));
    }
    TF_ASSIGN_OR_RETURN(module, backend->compiler()->RunHloPasses(
                                    std::move(module), executor, options));
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend->compiler()->RunBackend(std::move(module), executor, options));

  const HloProto* hlo_proto_after_opt = executable->hlo_proto();

  // If dumping is enabled RunBackend(...) will emit a hlo_proto in the
  // executable. This contains the buffer_assignment that is only available
  // after RunBackend(). If hlo_proto_before_opt is not null, then we replace
  // its buffer_assignment with the one from after_opt and then store it into
  // the executable.
  if (hlo_proto_before_opt != nullptr && hlo_proto_after_opt != nullptr) {
    CHECK(DumpingEnabledForHloModule(executable->module()));
    *hlo_proto_before_opt->mutable_buffer_assignment() =
        hlo_proto_after_opt->buffer_assignment();
    executable->set_hlo_proto(std::move(hlo_proto_before_opt));
  }
  return std::move(executable);
}

Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
  VLOG(1) << "running compile request";
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }

  if (arg->execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "The compile request does not support multiple device handles.");
  }

  std::vector<Shape> argument_shapes;
  argument_shapes.reserve(arg->input_shape_with_layout_size());
  std::vector<const Shape*> argument_shape_ptrs;
  for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
    argument_shapes.push_back(Shape(shape_proto));
    argument_shape_ptrs.push_back(&argument_shapes.back());
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
                         argument_shape_ptrs, &arg->execution_options()));
  VLOG(3) << "Compile created HloModuleConfig computation layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      BuildExecutable(arg->computation(), std::move(module_config),
                      execute_backend_.get(),
                      execute_backend_->default_stream_executor(),
                      {/*device_allocator=*/nullptr}));

  *result->mutable_handle() = compilation_cache_.Insert(std::move(executable));

  VLOG(1) << "successfully completed 'compile' request";
  return OkStatus();
}

Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) {
  VLOG(1) << "running execute request";
  if (!arg->has_handle()) {
    return InvalidArgument("execution handle should not be empty");
  }
  TF_ASSIGN_OR_RETURN(auto executable,
                      compilation_cache_.LookUp(arg->handle()));

  TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_,
                                              SingleComputationDeviceHandle()));
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<const ShapedBuffer*>> replicated_arguments,
      ResolveAndValidateArguments(arg->arguments(), replicas));

  // Check that the replicated arguments match the shapes and layouts in the
  // module config that was used to create the executable.
  const int64_t num_module_args =
      executable->module_config().entry_computation_layout().parameter_count();
  if (num_module_args != arg->arguments_size()) {
    return InvalidArgument(
        "The executable expects %lld arguments, but sees %lld.",
        num_module_args, arg->arguments_size());
  }
  for (int64_t i = 0; i < num_module_args; i++) {
    const Shape& shape_module =
        executable->module_config().entry_computation_layout().parameter_shape(
            i);
    const Shape& shape_arg = replicated_arguments.front()[i]->on_device_shape();
    if (!ShapeUtil::Equal(shape_module, shape_arg)) {
      return InvalidArgumentStrCat(
          "The executable expects the ", i, "th argument in shape ",
          ShapeUtil::HumanStringWithLayout(shape_module), " but sees ",
          ShapeUtil::HumanStringWithLayout(shape_arg));
    }
  }

  TF_ASSIGN_OR_RETURN(auto stream,
                      execute_backend_->BorrowStream(
                          execute_backend_->default_stream_executor()));
  HloSnapshot snapshot;
  if (executable->dumping_snapshot()) {
    *snapshot.mutable_hlo() = *executable->hlo_proto();
    snapshot.set_execution_platform(execute_backend_->platform()->Name());
    TF_RETURN_IF_ERROR(
        RecordArguments(replicated_arguments.front(), stream.get(),
                        execute_backend_->transfer_manager(), &snapshot));
  }

  TF_ASSIGN_OR_RETURN(
      *result->mutable_output(),
      ExecuteAndRegisterResult(executable.get(), replicated_arguments,
                               execute_backend_.get(),
                               SingleComputationDeviceHandle(),
                               "result of " + executable->module().name(),
                               result->mutable_profile()));

  if (executable->dumping_snapshot()) {
    TF_ASSIGN_OR_RETURN(
        const ShapedBuffer* result_buffer,
        allocation_tracker_.ResolveForReplica(result->output(), 0));
    TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(),
                                    execute_backend_->transfer_manager(),
                                    &snapshot));
    DumpHloSnapshotIfEnabled(executable->module(), snapshot);
  }

  VLOG(1) << "successfully completed 'execute' request";
  return OkStatus();
}

Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
                                 WaitForExecutionResponse* result) {
  TF_ASSIGN_OR_RETURN(const auto execution,
                      execution_tracker_.Resolve(arg->execution()));

  TF_RETURN_IF_ERROR(execution->BlockUntilDone());

  *result->mutable_output() = execution->result();
  *result->mutable_profile() = execution->profile();

  TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
  VLOG(1) << "successfully completed 'wait-for-execution' request";
  return OkStatus();
}

Status Service::TransferToClient(const TransferToClientRequest* arg,
                                 TransferToClientResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));

  Shape return_shape;
  if (arg->has_shape_with_layout()) {
    return_shape = Shape(arg->shape_with_layout());
    if (!LayoutUtil::HasLayout(return_shape)) {
      return InvalidArgument("shape_with_layout must have layout if present.");
    }
  } else {
    return_shape = Shape(shaped_buffer->on_device_shape());
  }

  TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
                                       shaped_buffer->device_ordinal()));

  TF_ASSIGN_OR_RETURN(
      Literal result_literal,
      execute_backend_->transfer_manager()->TransferLiteralFromDevice(
          stream.get(), *shaped_buffer));

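  // Relayout on the host only when the requested layout differs from the
  // layout of the literal transferred back from the device.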
  if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
    *result->mutable_literal() = result_literal.ToProto();
  } else {
    *result->mutable_literal() =
        result_literal.Relayout(return_shape).ToProto();
  }
  return OkStatus();
}

Status Service::TransferToServer(const TransferToServerRequest* arg,
                                 TransferToServerResponse* result) {
  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  const Shape& shape = literal.shape();

  std::vector<se::StreamExecutor*> replicas;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
  } else {
    TF_ASSIGN_OR_RETURN(
        replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle()));
  }

  // Allocate memory in each replica and transfer the data to all replicas.
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.reserve(replicas.size());
  for (se::StreamExecutor* executor : replicas) {
    auto device_shape_representation_fn = [this](const Shape& shape) {
      return execute_backend_->compiler()->DefaultDeviceShapeRepresentation(
          shape);
    };
    TF_ASSIGN_OR_RETURN(
        ScopedShapedBuffer shaped_buffer,
        execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
            shape, execute_backend_->memory_allocator(),
            executor->device_ordinal(), device_shape_representation_fn));
    TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(executor));
    TF_RETURN_IF_ERROR(
        execute_backend_->transfer_manager()->TransferLiteralToDevice(
            stream.get(), literal, shaped_buffer));
    replicated_buffers.emplace_back(std::move(shaped_buffer));
  }
  TF_ASSIGN_OR_RETURN(*result->mutable_data(),
                      allocation_tracker_.RegisterReplicatedBuffers(
                          std::move(replicated_buffers),
                          StrCat("TransferToServer literal of shape ",
                                 ShapeUtil::HumanString(shape))));

  return OkStatus();
}

Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
                                 TransferToInfeedResponse* result) {
  const int64_t replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "%s",
        StrCat("The replica_id=", arg->replica_id(),
               " on TransferToInfeedRequest not in range [0, replica_count=",
               replica_count, ")."));
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  TF_ASSIGN_OR_RETURN(Literal literal,
                      Literal::CreateFromProto(arg->literal()));
  return execute_backend_->transfer_manager()->TransferLiteralToInfeed(executor,
                                                                       literal);
}

Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                                    TransferFromOutfeedResponse* result) {
  const int64_t replica_count = options_.number_of_replicas();
  if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
    return FailedPrecondition(
        "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)",
        arg->replica_id(), replica_count);
  }

  se::StreamExecutor* executor;
  if (arg->has_device_handle()) {
    TF_ASSIGN_OR_RETURN(auto replicas,
                        Replicas(*execute_backend_, arg->device_handle()));
    executor = replicas[arg->replica_id()];
  } else {
    TF_ASSIGN_OR_RETURN(
        auto replicas,
        Replicas(*execute_backend_, SingleComputationDeviceHandle()));
    executor = replicas[arg->replica_id()];
  }

  auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));

  TF_RETURN_IF_ERROR(
      execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
          executor, &literal));
  *result->mutable_literal() = literal.ToProto();
  return OkStatus();
}

Status Service::ResetDevice(const ResetDeviceRequest* arg,
                            ResetDeviceResponse* result) {
  return execute_backend_->ResetDevices();
}

Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                                     ComputeConstantResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("computations may not be empty");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("program shape may not be empty");
  }
  if (arg->computation().host_program_shape().parameters_size() != 0) {
    return InvalidArgument(
        "constant computation may not depend on any parameters.");
  }

  ProgramShape program_shape(arg->computation().host_program_shape());
  TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
  std::optional<Layout> output_layout;
  if (arg->has_output_layout()) {
    output_layout = Layout::CreateFromProto(arg->output_layout());
    TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
        *output_layout, program_shape.result()));
  }

  HloModuleConfig config(program_shape);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  DynamicPadder dynamic_padder;
  TF_RETURN_IF_ERROR(dynamic_padder.Run(module.get()).status());

  TF_ASSIGN_OR_RETURN(DynamicDimensionInference dynamic_dimension_inference,
                      DynamicDimensionInference::Run(module.get()));

  HloEvaluator evaluator;
  evaluator.set_dynamic_dimension_inference(&dynamic_dimension_inference);
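  // Interpret the SliceToDynamic custom call by cloning the static operand and
  // setting its dynamic dimension sizes from the trailing scalar operands.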
  evaluator.set_custom_call_handler(
      [](HloInstruction* custom_call,
         absl::Span<const Literal*> operands) -> StatusOr<Literal> {
        if (custom_call->custom_call_target() == "SliceToDynamic") {
          auto result = operands[0]->Clone();
          for (int64_t i = 0; i < result.shape().rank(); ++i) {
            result.SetDynamicSize(i, operands[1 + i]->Get<int32_t>({}));
          }
          return result.ToStatic();
        }
        return Unimplemented("Custom call %s is not supported: %s",
                             custom_call->custom_call_target(),
                             custom_call->ToString());
      });
  TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));

  // The evaluator does not apply the requested result layout, so relayout the
  // result explicitly here.
  //
  // TODO(b/77824332): Make HloEvaluator take care of the re-layout.
  if (output_layout.has_value()) {
    result_literal = result_literal.Relayout(*output_layout);
  }
  *result->mutable_literal() = result_literal.ToProto();

  return OkStatus();
}

Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
  TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
                      allocation_tracker_.ResolveForReplica(arg->data(), 0));
  *result->mutable_shape() = buffer->on_device_shape().ToProto();
  return OkStatus();
}

Status Service::GetComputationGraphStats(
    const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) {
  if (!arg->has_computation()) {
    return InvalidArgument("Computations may not be empty.");
  }
  if (!arg->computation().has_host_program_shape()) {
    return InvalidArgument("Program shape may not be empty.");
  }

  HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
  config.set_debug_options(arg->debug_options());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      CreateModuleFromProto(arg->computation(), config));
  UpdateEntryComputationLayout(
      module.get(),
      std::bind(&Compiler::DefaultDeviceShapeRepresentation,
                execute_backend_->compiler(), std::placeholders::_1));
  DumpHloModuleIfEnabled(*module, kBeforeOptimizationsDumpName);

  // Run HLO analysis to get the computation statistics.
  HloCostAnalysis analysis(
      execute_backend_->compiler()->ShapeSizeBytesFunction());

  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));

  ComputationStats stats;
  stats.set_flop_count(analysis.flop_count());
  stats.set_transcendental_count(analysis.transcendental_count());
  *result->mutable_stats() = stats;
  return OkStatus();
}

DeviceHandle Service::SingleComputationDeviceHandle() const {
  DeviceHandle device_handle;
  device_handle.set_handle(0);
  device_handle.set_device_count(1);
  return device_handle;
}

StatusOr<std::vector<se::StreamExecutor*>> Service::Replicas(
    const Backend& backend, const DeviceHandle& device_handle) const {
  std::vector<se::StreamExecutor*> replicas;
  for (int replica = 0; replica < options_.number_of_replicas(); ++replica) {
    // From the computation placer, find out the device ids of the replicas for
    // the given device handle.
    TF_ASSIGN_OR_RETURN(
        int device_ordinal,
        backend.computation_placer()->DeviceId(replica, device_handle.handle(),
                                               options_.number_of_replicas(),
                                               device_handle.device_count()));
    TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal));
    replicas.push_back(executor);
  }
  return replicas;
}

}  // namespace xla