/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_

#include <functional>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/allocation_tracker.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/channel_tracker.h"
#include "tensorflow/compiler/xla/service/compilation_cache.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/execution_tracker.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

// Options to configure the service when it is created.
class ServiceOptions {
 public:
  // Set the platform backing the service, or nullptr for the default platform.
  ServiceOptions& set_platform(se::Platform* platform);
  se::Platform* platform() const;

  // Set the default number of replicas to use when compiling replicated
  // programs.
  ServiceOptions& set_number_of_replicas(int number_of_replicas);
  int number_of_replicas() const;

  // Sets the thread pool size for parallel execution of an individual operator.
  ServiceOptions& set_intra_op_parallelism_threads(int num_threads);
  int intra_op_parallelism_threads() const;

  // Sets the allowed_devices set for selectively constructing stream executors
  // on the platform.
  ServiceOptions& set_allowed_devices(
      const std::optional<std::set<int>>& allowed_devices);
  const std::optional<std::set<int>>& allowed_devices() const;

 private:
  se::Platform* platform_ = nullptr;
  int number_of_replicas_ = 1;
  int intra_op_parallelism_threads_ = -1;
  std::optional<std::set<int>> allowed_devices_;
};
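
// Illustrative usage sketch (not part of this header's API surface): a caller
// might configure ServiceOptions before constructing a Service. The setters
// return the options object by reference, so calls can be chained.
//
//   ServiceOptions options;
//   options.set_number_of_replicas(2)
//       .set_intra_op_parallelism_threads(4)
//       .set_allowed_devices(std::set<int>{0, 1});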

// The XLA service object, which is the same across all platforms. It maintains
// the service state of computations and allocations, and delegates
// target-specific requests to the target-specific infrastructure
// (target-specific compiler, StreamExecutor).
class Service : public ServiceInterface {
 public:
  // Factory method for creating a new Service.
  static StatusOr<std::unique_ptr<Service>> NewService(
      se::Platform* platform = nullptr);
  static StatusOr<std::unique_ptr<Service>> NewService(
      const ServiceOptions& options);
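  //
  // Example (a minimal sketch, not a prescribed pattern): creating a Service
  // from explicit options and unwrapping the StatusOr result.
  //
  //   ServiceOptions options;
  //   options.set_number_of_replicas(1);
  //   StatusOr<std::unique_ptr<Service>> service_or =
  //       Service::NewService(options);
  //   if (!service_or.ok()) {
  //     LOG(FATAL) << service_or.status();
  //   }
  //   std::unique_ptr<Service> service = std::move(service_or).value();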

  // Unregisters a previously-allocated global handle.
  //
  // If the handle given is not currently allocated, a NOT_FOUND status is
  // returned.
  Status Unregister(const UnregisterRequest* arg,
                    UnregisterResponse* result) override;

  // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
  // element in the tuple.
  Status DeconstructTuple(const DeconstructTupleRequest* arg,
                          DeconstructTupleResponse* result) override;

  // Compiles a computation into an executable. The request contains the whole
  // computation graph. Returns the handle to the executable.
  Status Compile(const CompileRequest* arg, CompileResponse* result) override;

  // Executes an executable with the provided global data passed as immutable
  // arguments. The request contains the handle to the executable. Returns
  // global data output and execution timing.
  Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override;

  // Executes one or more computations in parallel with the provided global data
  // passed as immutable arguments. Returns global data output for each
  // computation.
  Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
                              ExecuteParallelResponse* result) override;

  // Requests one or more device handles from the target.
  //
  // When N device handles are requested and the number of replicas is R, at
  // least N * R devices must be available. The devices are assigned based on
  // the device ordinals such that the first R available devices are assigned to
  // the first set of replicas, and the next R devices to the second set of
  // replicas, etc. Each returned device handle represents the device with the
  // replica id 0.
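  //
  // For example (with illustrative numbers): requesting N = 2 device handles
  // when the number of replicas is R = 3 requires at least 6 devices; the
  // first three available devices back the replicas of the first handle and
  // the next three back the second.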
  Status GetDeviceHandles(const GetDeviceHandlesRequest* arg,
                          GetDeviceHandlesResponse* result) override;

  // Waits until the specified execution is complete and returns the result.
  // Calling this API multiple times with the same execution handle returns an
  // error, since the execution handle is destroyed after the first call.
  Status WaitForExecution(const WaitForExecutionRequest* arg,
                          WaitForExecutionResponse* result) override;

  // Requests that global data be transferred to the client in literal form.
  Status TransferToClient(const TransferToClientRequest* arg,
                          TransferToClientResponse* result) override;

  // Transfers data from a literal provided by the client, into device memory.
  Status TransferToServer(const TransferToServerRequest* arg,
                          TransferToServerResponse* result) override;

  // Transfers data from a literal provided by the client, into the Infeed
  // buffer of the device.
  Status TransferToInfeed(const TransferToInfeedRequest* arg,
                          TransferToInfeedResponse* result) override;

  // Transfers data from the Outfeed of the device to the literal provided by
  // the client.
  Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
                             TransferFromOutfeedResponse* result) override;

  // Resets devices, clearing all existing state on all the devices associated
  // with this service (including memory allocated on the devices).
  //
  // ResetDevice may only be called when no state from a previous Execution on
  // the device is needed by the next Execution.
  //
  // ResetDevice should be called before an Execution that expects the device
  // to be in the reset state, for example when the prior Execution modifies
  // device state (e.g., architectural state) that the next Execution depends
  // on.
  Status ResetDevice(const ResetDeviceRequest* arg,
                     ResetDeviceResponse* result) override;

  Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
                              ComputeConstantResponse* result) override;

  // Returns the shape (with layout) of an array associated with a given data
  // handle.
  Status GetShape(const GetShapeRequest* arg,
                  GetShapeResponse* result) override;

  // Retrieves the statistics of a computation.
  Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg,
                                  ComputationStatsResponse* result) override;

  // Creates a unique channel handle that can be used for Send/Recv
  // instructions.
  Status CreateChannelHandle(const CreateChannelHandleRequest* arg,
                             CreateChannelHandleResponse* result) override;

  // Returns the backend used to execute computations.
  const Backend& backend() const { return *execute_backend_; }
  Backend* mutable_backend() { return execute_backend_.get(); }

  // Creates an HloModuleConfig for the given program shape and arguments.
  // aot_options is optional; if not given, a default is used.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      absl::Span<const Shape* const> argument_shapes,
      const ExecutionOptions* execution_options,
      const AotCompilationOptions* aot_options = nullptr);

 private:
  // A private overload for Service itself, used by other methods within this
  // class.
  StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
      const ProgramShape& program_shape,
      absl::Span<const ShapedBuffer* const> arguments,
      const ExecutionOptions& execution_options,
      const AotCompilationOptions* aot_options = nullptr);

  // Prepares the executors for parallel execution.
  StatusOr<std::vector<se::StreamExecutor*>> GetExecutors(
      const ExecutionOptions& execution_options, int64_t requests_size,
      int64_t request_index) const;

  // Prepares the arguments for parallel execution.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>> GetArguments(
      const ExecutionOptions& execution_options,
      absl::Span<const GlobalDataHandle* const> arguments) const;

 protected:
  friend class LocalExecutable;

  // The constructor is protected. Use the NewService factory to create new
  // service objects.
  Service(const ServiceOptions& options,
          std::unique_ptr<Backend> execute_backend);

  // Resolves the given argument handles in the allocation tracker and returns
  // the corresponding allocations for every replica. The function also verifies
  // that each allocation matches the execution platform and device ordinal of
  // the corresponding replica.
  StatusOr<std::vector<std::vector<const ShapedBuffer*>>>
  ResolveAndValidateArguments(
      absl::Span<const GlobalDataHandle* const> arguments,
      absl::Span<se::StreamExecutor* const> stream_executors) const;

  // Builds an Executable for the given parameters.
  //
  // If the device allocator in `options` is not null, the compiler may use it
  // to allocate temp buffers, which the compiler is responsible for freeing.
  // The allocator given here need not match the allocator used when running
  // the executable.
  StatusOr<std::unique_ptr<Executable>> BuildExecutable(
      const HloModuleProto& module_proto,
      std::unique_ptr<HloModuleConfig> module_config, Backend* backend,
      se::StreamExecutor* executor, const Compiler::CompileOptions& options,
      bool run_backend_only = false);

  // Same as BuildExecutable() above, but builds a list of Executables for the
  // given computations that may interact with each other.
  StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
      const std::vector<const HloModuleProto*>& module_protos,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
      const Compiler::CompileOptions& options, bool run_backend_only = false);

  // Same as BuildExecutable() above, but builds a list of
  // AotCompilationResult(s), which can be persisted to later load Executable
  // objects.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> BuildAotResults(
      const std::vector<const HloModuleProto*>& module_protos,
      std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
      Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors,
      const Compiler::CompileOptions& options, bool run_backend_only = false);

  // Runs the given executable with the given arguments and registers the
  // result in the allocation tracker. The handle of the result from the
  // tracker is returned. If the parameter "profile" is not null, it points to
  // an ExecutionProfile object which will be filled in with profile data.
  StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
      Executable* executable,
      absl::Span<const std::vector<const ShapedBuffer*>> arguments,
      Backend* backend, const DeviceHandle& device_handle,
      const std::string& result_tag, ExecutionProfile* profile);

  // Runs the given executables with the given arguments and registers the
  // result from each executable in the allocation tracker. The handles of the
  // results from the tracker are returned.
  StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
      absl::Span<Executable* const> executables,
      absl::Span<const std::vector<std::vector<const ShapedBuffer*>>> arguments,
      Backend* backend, absl::Span<const DeviceHandle> device_handles,
      absl::Span<const std::string> result_tags, ExecutionProfile* profile);

  // Convenience function which checks whether the given client_shape
  // (presumably passed by the client to set the result layout) is valid for the
  // given computation result shape.
  Status ValidateResultShape(const Shape& client_shape,
                             const Shape& result_shape) const;

  // Returns the stream executors assigned to the replicas represented by the
  // given device handle. Each device_handle is a virtual replicated device that
  // represents a set of physical devices for the replicas.
  StatusOr<std::vector<se::StreamExecutor*>> Replicas(
      const Backend& backend, const DeviceHandle& device_handle) const;

  // Returns the device handle that represents the replicated device for a
  // single computation that is not model-parallelized.
  DeviceHandle SingleComputationDeviceHandle() const;

  ServiceOptions options_;

  // Cache containing previously built Executables.
  CompilationCache compilation_cache_;

  // Tracks channels created via the API.
  ChannelTracker channel_tracker_;

  // Tracks allocations made via the API and computation execution.
  AllocationTracker allocation_tracker_;

  // Tracks asynchronously launched executions via the API.
  ExecutionTracker execution_tracker_;

  // Backend to compile and execute computations on.
  std::unique_ptr<Backend> execute_backend_;

  Service(const Service&) = delete;
  Service& operator=(const Service&) = delete;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_