/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_

#include <array>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/threadpool.h"

namespace xla {

inline const char* TpuPlatform() {
  static constexpr char kTpuPlatform[] = "tpu";
  return kTpuPlatform;
}

class PyTpuClient;

class TpuDevice : public PjRtDevice {
 public:
  TpuDevice(int id, int process_index, const std::array<int, 3>& coords,
            int core_on_chip);

  const std::array<int, 3>& coords() const { return coords_; }
  int core_on_chip() const { return core_on_chip_; }

  absl::string_view DebugString() const override;

  absl::string_view ToString() const override;

  static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
  GetTpuDevices(const tpu_driver::SystemInfo& system_info);

  PjRtClient* client() const override { return nullptr; }
  PyTpuClient* tpu_client() const { return tpu_client_; }
  void set_tpu_client(PyTpuClient* tpu_client) { tpu_client_ = tpu_client; }

  bool IsAddressable() const override { return false; }

  int id() const override { return id_; }

  int process_index() const override { return process_index_; }

  int local_hardware_id() const override { return -1; }

  absl::string_view device_kind() const override { return device_kind_; }

  Status TransferToInfeed(const LiteralSlice& literal) override {
    return Unimplemented("Infeed not yet implemented via this API");
  }

  Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
    return Unimplemented("Outfeed not yet implemented via this API");
  }

  std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
      absl::string_view description) const override {
    return nullptr;
  }

  const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
      const override {
    return attributes_;
  }

 private:
  const int id_;
  const int process_index_;
  const std::array<int, 3> coords_;
  const std::string device_kind_ = "Cloud TPU";
  std::string debug_string_;
  std::string to_string_;
  const absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_ = {};
  // Index of this core within its chip.
  int core_on_chip_;
  PyTpuClient* tpu_client_;
};

// Encapsulates the state of a Python session with XLA.
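//
// A minimal usage sketch; the worker address below is hypothetical and its
// exact format depends on the TPU driver configuration:
//
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<PyTpuClient> client,
//                       PyTpuClient::Get("grpc://localhost:8470"));
//   for (const std::shared_ptr<PjRtDevice>& device : client->devices()) {
//     LOG(INFO) << device->DebugString();
//   }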
class PyTpuClient : public std::enable_shared_from_this<PyTpuClient> {
 public:
  // Initializes a local XLA client for `platform_name`. Returns an error if no
  // such platform exists, or if the platform has no visible devices.
  static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);

  explicit PyTpuClient(std::string platform_name,
                       std::unique_ptr<tpu_driver::TpuDriver> driver,
                       std::vector<std::shared_ptr<PjRtDevice>> devices,
                       int process_index);
  virtual ~PyTpuClient() = default;

  PyTpuClient(const PyTpuClient&) = delete;
  PyTpuClient(PyTpuClient&&) = delete;
  PyTpuClient& operator=(const PyTpuClient&) = delete;
  PyTpuClient& operator=(PyTpuClient&&) = delete;

  Status TransferToInfeed(const LiteralSlice& literal, int device_id);
  StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_id);

  virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
      int num_replicas, int num_partitions) const;

  int device_count() const { return devices_.size(); }
  int local_device_count() const { return local_devices_.size(); }
  const std::vector<std::shared_ptr<PjRtDevice>>& devices() { return devices_; }
  const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() {
    return local_devices_;
  }
  const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
    return id_to_device_;
  }
  int process_index() const { return process_index_; }
  const absl::string_view platform_name() const { return platform_name_; }
  const absl::string_view platform_version() const { return platform_version_; }

  StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
    return Unimplemented("ChooseCompactLayoutForShape not implemented.");
  }

  // Returns a bad status containing `caller_name` if `device_id` doesn't
  // correspond to a valid device at the POD-slice boundary.
  Status CheckDeviceId(int device_id, absl::string_view caller_name);

  tpu_driver::TpuDriver* driver() { return driver_.get(); }

  tensorflow::thread::ThreadPool* GetThreadPool() { return pool_.get(); }

 protected:
  std::string platform_name_;
  std::string platform_version_;
  std::unique_ptr<tpu_driver::TpuDriver> driver_;

  // Includes all devices, including non-local devices on multi-host platforms.
  std::vector<std::shared_ptr<PjRtDevice>> devices_;
  // Maps Device::id() to the corresponding Device. Includes all devices.
  std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
  // Local devices indexed by local device ordinal.
  std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
  int process_index_;

  // A thread pool for scheduling core executions in parallel.
  std::unique_ptr<tensorflow::thread::ThreadPool> pool_;
};

// Manages a buffer shared amongst multiple users. Buffers are asynchronously
// deallocated after the last use.
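//
// A minimal construction sketch, assuming `driver`, `handle`, and `device`
// were obtained elsewhere; the events recorded in `wait_for_use` are handed to
// TpuDriver::Deallocate by the destructor, so the device memory is freed only
// after those pending uses complete:
//
//   auto buffer = std::make_shared<TpuSharedBuffer>(
//       driver, std::move(handle),
//       std::vector<std::shared_ptr<tpu_driver::Event>>{}, device);
//   buffer->wait_for_use.push_back(copy_event);  // `copy_event` is hypothetical.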
struct TpuSharedBuffer final {
 public:
  TpuSharedBuffer(tpu_driver::TpuDriver* driver,
                  std::unique_ptr<tpu_driver::BufferHandle> handle,
                  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use,
                  std::shared_ptr<PjRtDevice> src_device)
      : driver(driver),
        device(std::move(src_device)),
        handle(std::move(handle)),
        wait_for_use(std::move(wait_for_use)) {}

  ~TpuSharedBuffer() {
    std::vector<tpu_driver::Event*> events;
    for (const auto& e : wait_for_use) {
      events.push_back(e.get());
    }
    driver->Deallocate(std::move(handle), events);
  }

  tpu_driver::TpuDriver* const driver;
  const std::shared_ptr<PjRtDevice> device;

  std::unique_ptr<tpu_driver::BufferHandle> handle;
  std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
};

// Holds a reference from Python to one or more device buffers.
// A PyTpuBuffer can be either valid or invalid. An invalid buffer is one that
// has never been initialized, or a buffer that has been deleted (e.g., by
// calling Delete). We allow PyTpuBuffer objects to outlive the underlying
// device buffers so we can decouple buffer lifetimes from the corresponding
// Python references if needed.
// Thread-safe.
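//
// A minimal sketch of staging host data onto a device and reading it back;
// `data`, `shape`, `leaves_reference`, and `client` are hypothetical
// (a caller-owned host buffer kept alive via `leaves_reference`, its matching
// non-tuple XLA shape, and a previously created PyTpuClient):
//
//   std::vector<BorrowingLiteral> leaves;
//   leaves.emplace_back(static_cast<const char*>(data), shape);
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PyTpuBuffer> buffer,
//       PyTpuBuffer::FromLiterals(std::move(leaves), shape, leaves_reference,
//                                 client, client->local_devices().front()));
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> result, buffer->ToLiteral());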
class PyTpuBuffer {
 public:
  // `tuple_shape` can be at most a one-level tuple combining non-tuple leaves.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> FromLiterals(
      std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
      std::shared_ptr<void> leaves_reference,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  // Supports nested tuple creation.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
      absl::Span<PyTpuBuffer* const> buffers,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  PyTpuBuffer() = delete;
  PyTpuBuffer(Shape on_host_shape,
              std::shared_ptr<TpuSharedBuffer> device_buffer,
              std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
              std::shared_ptr<PyTpuClient> client);

  PyTpuBuffer(const PyTpuBuffer&) = delete;
  PyTpuBuffer(PyTpuBuffer&&) = delete;
  PyTpuBuffer& operator=(const PyTpuBuffer&) = delete;
  PyTpuBuffer& operator=(PyTpuBuffer&&) = delete;

  const Shape& on_host_shape() const { return on_host_shape_; }
  std::shared_ptr<PjRtDevice> device() const { return device_; }
  const absl::string_view platform_name() const {
    return client_->platform_name();
  }
  std::shared_ptr<PyTpuClient> client() const { return client_; }

  // Returns the buffer's value as a Literal (a tuple DAG of arrays). If the
  // value has previously been prefetched to the host, returns the prefetched
  // version; otherwise copies the buffer to the host. Blocks until the value
  // is ready.
  StatusOr<std::shared_ptr<Literal>> ToLiteral();

  // Initiates a copy of the buffer to the host. Does not block waiting for
  // the transfer to complete. The value can be retrieved by a later call to
  // ToLiteral().
  Status CopyToHostAsync();

  // Returns the associated device buffer. Returns a nullptr if the buffer is
  // invalid.
  std::shared_ptr<TpuSharedBuffer> DeviceBuffer() const;

  // Deletes the device memory associated with this buffer, leaving it in an
  // invalid state.
  void Delete();

  // Destructures a tuple-valued PyTpuBuffer into its constituent elements.
  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> DestructureTuple();

  // Copies the buffer to the target device `dst_device` and returns a
  // PyTpuBuffer object holding a handle to the target device buffer.
  StatusOr<std::unique_ptr<PyTpuBuffer>> CopyToDevice(
      std::shared_ptr<PjRtDevice> dst_device);

  // Blocks the host until the buffer's value has been computed and is ready
  // for immediate use on the device. Useful in particular for timing
  // benchmarks.
  Status BlockHostUntilReady();

  // Allocates uninitialized buffers on `device`. If `shape` is a tuple, the
  // returned buffer corresponds to the root tuple buffer.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> AllocateBuffer(
      const Shape& shape, std::shared_ptr<PyTpuClient> client,
      std::shared_ptr<PjRtDevice> device);

 private:
  // Initializes a just allocated device buffer. The returned event will be
  // placed into the buffer's `wait_for_use` list.
  using BufferInitializer = std::function<std::shared_ptr<tpu_driver::Event>(
      tpu_driver::BufferHandle*)>;
  // Allocates and optionally initializes a non-tuple buffer on the device.
  static StatusOr<std::unique_ptr<PyTpuBuffer>> CreateBuffer(
      const Shape& non_tuple_shape,
      std::optional<BufferInitializer> initializer,
      std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);

  const std::shared_ptr<PyTpuClient> client_;
  const Shape on_host_shape_;
  const std::shared_ptr<PjRtDevice> device_;

  // If this is a tuple, `device_buffer_` stores the tuple buffer and
  // `child_buffers_` stores the child buffers; else, `device_buffer_` stores
  // the data content and `child_buffers_` is empty.
  mutable absl::Mutex mu_;
  std::shared_ptr<TpuSharedBuffer> device_buffer_ ABSL_GUARDED_BY(mu_);
  std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers_
      ABSL_GUARDED_BY(mu_);
  // The cached value of the buffer on the host, produced either from a call to
  // CopyToHostAsync or from a call to ToLiteral. Once a value has been fetched
  // to the host, it persists until Delete() is called or the PyTpuBuffer is
  // destroyed.
  struct HostValue {
    absl::Mutex mutex;
    absl::Notification ready;
    int pending_ops;
    // status and value are valid for reading only after `ready` has been
    // notified.
    Status status;
    std::shared_ptr<Literal> value;
  };
  std::shared_ptr<HostValue> host_value_ ABSL_GUARDED_BY(mu_);
};

// A dummy token that is always ready. PyTpuExecutable::Execute() blocks until
// the computation finishes.
class PyTpuToken {
 public:
  PyTpuToken() {}
  Status Await() { return Status::OK(); }
};

class PyShardedTpuToken {
 public:
  PyShardedTpuToken() {}
  Status Await() { return Status::OK(); }
  PyTpuToken GetPyToken(int i) { return PyTpuToken(); }
};

// Represents a compiled computation that can be executed given handles to
// device-allocated literals. Wraps one or more loaded TPU driver programs.
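//
// A minimal sketch of compiling and running a computation; `computation` is a
// hypothetical XlaComputation that takes no arguments, `client` a previously
// created PyTpuClient, and passing std::nullopt layouts with null build
// options is assumed to select the defaults:
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<PyTpuExecutable> executable,
//       PyTpuExecutable::Compile(computation,
//                                /*argument_layouts=*/std::nullopt,
//                                /*build_options=*/nullptr, client,
//                                /*tuple_arguments=*/false));
//   TF_ASSIGN_OR_RETURN(
//       std::vector<std::unique_ptr<PyTpuBuffer>> outputs,
//       executable->Execute(/*argument_handles=*/{}));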
class PyTpuExecutable {
 public:
  static StatusOr<std::unique_ptr<PyTpuExecutable>> Compile(
      const XlaComputation& computation,
      std::optional<std::vector<Shape>> argument_layouts,
      const ExecutableBuildOptions* build_options,
      std::shared_ptr<PyTpuClient> client, bool tuple_arguments);

  static StatusOr<std::unique_ptr<PyTpuExecutable>> CompileMlir(
      mlir::ModuleOp module, std::optional<std::vector<Shape>> argument_layouts,
      const ExecutableBuildOptions* build_options,
      std::shared_ptr<PyTpuClient> client, bool tuple_arguments);

  PyTpuExecutable(
      std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
      DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
      xla::Shape result_shape, bool tuple_arguments);
  virtual ~PyTpuExecutable() {
    for (auto it = executables_.begin(); it != executables_.end(); ++it) {
      client_->driver()->UnloadProgram(std::move(it->second), {});
    }
  }

  PyTpuExecutable(const PyTpuExecutable&) = delete;
  PyTpuExecutable(PyTpuExecutable&&) = delete;
  PyTpuExecutable& operator=(const PyTpuExecutable&) = delete;
  PyTpuExecutable& operator=(PyTpuExecutable&&) = delete;

  std::shared_ptr<PyTpuClient> client() const { return client_; }

  int num_replicas() const { return device_assignment_.replica_count(); }
  int num_partitions() const { return device_assignment_.computation_count(); }

  int64_t SizeOfGeneratedCodeInBytes() const {
    CHECK_GE(executables_.size(), 1);
    return executables_.begin()->second->size_in_bytes();
  }

  const DeviceAssignment& device_assignment() const {
    return device_assignment_;
  }

  const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
    return local_logical_device_ids_;
  }

  const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
    return local_devices_;
  }

  // TODO(power): Both Execute and ExecuteOnLocalDevices block and wait
  // internally for the computation to finish. Coordinate with the JAX code
  // change to see if we can make both Execute and ExecuteOnLocalDevices
  // non-blocking.
  StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
      absl::Span<PyTpuBuffer* const> argument_handles);

  StatusOr<std::pair<std::vector<std::unique_ptr<PyTpuBuffer>>, PyTpuToken>>
  ExecuteWithToken(absl::Span<PyTpuBuffer* const> argument_handles) {
    TF_ASSIGN_OR_RETURN(auto results, Execute(argument_handles));
    return std::pair<std::vector<std::unique_ptr<PyTpuBuffer>>, PyTpuToken>(
        std::move(results), PyTpuToken());
  }

  // Execute on local devices. Takes a sequence of argument lists (one argument
  // list per local device) and returns a tuple of results (one result per
  // local device). The number of argument lists must be equal to the local
  // device count.
  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
  ExecuteOnLocalDevices(
      absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);

  StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
  ExecuteShardedOnLocalDevices(
      absl::Span<const std::vector<PyTpuBuffer*>> args);

  StatusOr<std::pair<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>,
                     PyShardedTpuToken>>
  ExecuteShardedOnLocalDevicesWithTokens(
      absl::Span<const std::vector<PyTpuBuffer*>> args) {
    TF_ASSIGN_OR_RETURN(auto results, ExecuteShardedOnLocalDevices(args));

    TF_RET_CHECK(!args.empty());
    return std::pair<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>,
                     PyShardedTpuToken>(std::move(results),
                                        PyShardedTpuToken());
  }

  void Delete() { executables_.clear(); }

 private:
  struct ExecuteResult {
    std::unique_ptr<PyTpuBuffer> buffer;
    std::shared_ptr<tpu_driver::Event> on_execute_finished;
  };

  ExecuteResult ExecuteHelper(
      absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
      absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
      int partition, const RunId& run_id);

  std::shared_ptr<PyTpuClient> const client_;
  std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
  const DeviceAssignment device_assignment_;
  const bool tuple_arguments_;

  // The replica and partition indices of device_assignment_ to be run by this
  // client. On single-host platforms without partitioning, this is all
  // replicas (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be
  // the case on multi-host platforms. For example, with 4 replicas and 2
  // partitions on a single-host platform, local_logical_device_ids_ has
  // 4*2 = 8 entries.
  std::vector<std::pair<int, int>> local_logical_device_ids_;

  // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
  // assigned.
  // shared_ptrs instead of unique_ptrs to play well with the Python bindings
  // (see xla.cc).
  std::vector<std::shared_ptr<PjRtDevice>> local_devices_;

  xla::Shape result_shape_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_