/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

class LocalExecutable {
 public:
  // Low-level constructor; LocalClient::Compile() is the usual way to create
  // executables.
  LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
                  ExecutableBuildOptions build_options);

  // Run the compiled computation with the given arguments and options and
  // return the result.
  StatusOr<ScopedShapedBuffer> Run(
      const absl::Span<const ShapedBuffer* const> arguments,
      ExecutableRunOptions run_options);

  // Similar to Run(), but allows for donating argument buffers to the
  // executable.
  StatusOr<ExecutionOutput> Run(std::vector<ExecutionInput> arguments,
                                ExecutableRunOptions run_options);

  // Similar to Run(), but need not block the host waiting for the computation
  // to complete before returning.
  StatusOr<ScopedShapedBuffer> RunAsync(
      const absl::Span<const ShapedBuffer* const> arguments,
      ExecutableRunOptions run_options);

  // Similar to RunAsync(), but allows for donating argument buffers to the
  // executable.
  StatusOr<ExecutionOutput> RunAsync(std::vector<ExecutionInput> arguments,
                                     ExecutableRunOptions run_options);
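
  // Example (illustrative sketch only; assumes `executable` came from
  // LocalClient::Compile(), `client` is the owning LocalClient, and
  // `arg_buffer` is a ScopedShapedBuffer already resident on the target
  // device):
  //
  //   xla::ExecutableRunOptions run_options;
  //   run_options.set_allocator(client->backend().memory_allocator());
  //   xla::StatusOr<xla::ScopedShapedBuffer> result =
  //       executable->Run({&arg_buffer}, run_options);
  //
  // The Run() overloads block until the computation has finished; the
  // RunAsync() overloads may return before the device-side work completes.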

  // Return the options used to build the executable.
  const ExecutableBuildOptions& build_options() const { return build_options_; }

  // Return the built executable.
  Executable* executable() const { return executable_.get(); }

 private:
  StatusOr<ExecutionOutput> RunAsync(
      absl::Span<Shape const* const> argument_host_shapes,
      std::vector<ExecutionInput> arguments, ExecutableRunOptions run_options);

  // Validates that the given arguments and options satisfy various constraints
  // of the computation.
  //
  // The given ExecutableRunOptions override any values from the TF_XLA_FLAGS
  // environment variable.
  Status ValidateExecutionOptions(const ExecutableRunOptions& run_options,
                                  const Backend& backend);

  // Returns a literal containing the contents of the given ShapedBuffer.
  StatusOr<Literal> LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer);

  StatusOr<std::pair<ServiceExecutableRunOptions, StreamPool::Ptr>> RunHelper(
      const absl::Span<const Shape* const> argument_shapes,
      ExecutableRunOptions run_options);

  // The ordinal of the device which this executable was compiled for. The
  // executable can run on all equivalent devices (as determined by
  // Backend::devices_equivalent).
  int build_device_ordinal() const { return build_options_.device_ordinal(); }

  template <typename T>
  StatusOr<T> AsyncCallAndBlockHostUntilDone(
      absl::Span<Shape const* const> argument_shapes,
      const ExecutableRunOptions& run_options,
      std::function<StatusOr<T>(const ExecutableRunOptions&)> async_callback) {
    TF_ASSIGN_OR_RETURN(auto options_and_stream,
                        RunHelper(argument_shapes, run_options));
    ExecutableRunOptions options = options_and_stream.first.run_options();
    options.set_device_ordinal(-1);
    StatusOr<T> result = async_callback(options);
    Status block_status = options.stream()->BlockHostUntilDone();
    TF_RETURN_IF_ERROR(result.status());
    TF_RETURN_IF_ERROR(block_status);
    return result;
  }

  // Compiled computation.
  std::unique_ptr<Executable> executable_;

  // Execution backend.
  Backend* backend_ = nullptr;

  // Options used to build the executable.
  const ExecutableBuildOptions build_options_;
};

// An XLA Client specialization for use when the client and service run in
// the same process.
class LocalClient : public Client {
 public:
  explicit LocalClient(LocalService* service)
      : Client(service), local_service_(service) {}

  LocalClient(const LocalClient&) = delete;
  void operator=(const LocalClient&) = delete;

  // Build and return LocalExecutable objects (one per partition, as specified
  // by the build options). The executable is compiled using the given
  // XlaComputation, argument layouts and options.
  //
  // The given ExecutableBuildOptions overrides any values from the XLA_FLAGS
  // environment variable.
  StatusOr<std::vector<std::unique_ptr<LocalExecutable>>> Compile(
      const XlaComputation& computation,
      const absl::Span<const Shape* const> argument_layouts,
      const ExecutableBuildOptions& options);
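
  // Example (illustrative sketch; assumes `computation` is an XlaComputation
  // built elsewhere, e.g. with XlaBuilder, taking one f32[2] parameter, and
  // that `client` is a LocalClient obtained from ClientLibrary):
  //
  //   xla::Shape arg_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
  //   xla::StatusOr<std::vector<std::unique_ptr<xla::LocalExecutable>>>
  //       executables = client->Compile(computation, {&arg_shape},
  //                                     xla::ExecutableBuildOptions());
  //
  // With the default build options the result contains a single executable;
  // multi-partition builds produce one executable per partition.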

  // Same as Compile() above, but return AotCompilationResult objects (instead
  // of LocalExecutable objects), which can be persisted to later load
  // LocalExecutable(s) using the Load() method below.
  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(const XlaComputation& computation,
                     const absl::Span<const Shape* const> argument_layouts,
                     const ExecutableBuildOptions& options);

  // Return a LocalExecutable object loaded from a serialized
  // AotCompilationResult.
  StatusOr<std::unique_ptr<LocalExecutable>> Load(
      const std::string& serialized_aot_result,
      const ExecutableBuildOptions& options);
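
  // Example (illustrative sketch; how the AotCompilationResult is serialized
  // is an assumption here: SerializeAsString() is shown for illustration and
  // may differ across versions):
  //
  //   TF_ASSIGN_OR_RETURN(auto aot_results,
  //                       client->CompileAheadOfTime(computation,
  //                                                  {&arg_shape}, options));
  //   TF_ASSIGN_OR_RETURN(std::string serialized,
  //                       aot_results[0]->SerializeAsString());
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
  //                       client->Load(serialized, options));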

  // Copy the literal data to the device with the given ordinal and return as a
  // ScopedShapedBuffer. If non-null, the given memory allocator is used for
  // device memory allocation. If null, the default memory allocator for the
  // device is used.
  StatusOr<ScopedShapedBuffer> LiteralToShapedBuffer(
      const LiteralSlice& literal, int device_ordinal,
      se::DeviceMemoryAllocator* allocator = nullptr);

  // Transfer the BorrowingLiteral to the device with the given ordinal.
  StatusOr<TransferToServerResponse> TransferToLocalServer(
      const ::xla::BorrowingLiteral& literal, int device_ordinal);

  // Copy the data from the device contained in the given ShapedBuffer and
  // return as a Literal.
  StatusOr<Literal> ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
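
  // Example host-to-device-to-host round trip (illustrative sketch; assumes
  // the default device and default allocator are acceptable):
  //
  //   xla::Literal input = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
  //   TF_ASSIGN_OR_RETURN(
  //       xla::ScopedShapedBuffer device_buffer,
  //       client->LiteralToShapedBuffer(input,
  //                                     client->default_device_ordinal()));
  //   TF_ASSIGN_OR_RETURN(xla::Literal output,
  //                       client->ShapedBufferToLiteral(device_buffer));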

  // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid
  // as long as the handle is valid.
  StatusOr<const ShapedBuffer*> GlobalDataToShapedBuffer(
      const GlobalDataHandle& data, int replica_number);

  // Transfer the given literal to the infeed queue of the given device.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferToInfeed.
  Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);

  // Transfer and return a value from the outfeed of the given device. The
  // shape of the object to transfer is determined by `literal`'s shape.
  // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
  // not inherit from Client and there is no possibility of confusion with
  // Client::TransferFromOutfeed.
  Status TransferFromOutfeedLocal(int device_ordinal,
                                  MutableBorrowingLiteral literal);

  // Returns the device ordinal that corresponds to the given replica number.
  //
  // This returns an error if there is not a one-to-one correspondence of
  // replicas to device ordinals, but is useful as a short-term mechanism for
  // the "easy" case where a single replica is a single device.
  StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);

  // Returns the platform that the underlying service targets.
  se::Platform* platform() const;

  // Returns the number of devices on the system of the service platform
  // type. Not all devices may be supported by the service (see the
  // device_ordinal_supported method).
  int device_count() const;

  // Returns the default device ordinal that the service will run computations
  // on if no device ordinal is specified in execute options.
  int default_device_ordinal() const;

  // Returns whether the device with the given ordinal can be used by the
  // service to execute computations. Not all devices of a particular platform
  // may be usable by the service (e.g., a GPU with insufficient CUDA compute
  // capability).
  bool device_ordinal_supported(int device_ordinal) const;
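
  // Example (illustrative sketch of enumerating usable devices):
  //
  //   for (int ordinal = 0; ordinal < client->device_count(); ++ordinal) {
  //     if (client->device_ordinal_supported(ordinal)) {
  //       // `ordinal` may be passed as the target device in
  //       // ExecutableBuildOptions / ExecutableRunOptions.
  //     }
  //   }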

  // Returns the backend used to execute computations.
  const Backend& backend() const;
  Backend* mutable_backend();

 private:
  LocalService* local_service_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_