1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
18
#include <array>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.h"
#include "tensorflow/compiler/xla/python/tpu_driver/tpu_driver.pb.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/threadpool.h"
40
41 namespace xla {
42
// Returns the platform name served by this client ("tpu"). The characters live
// in a function-local static array, so every call yields the same pointer.
inline const char* TpuPlatform() {
  static constexpr char kPlatformName[] = "tpu";
  const char* const name = kPlatformName;
  return name;
}
47
48 class PyTpuClient;
49
50 class TpuDevice : public PjRtDevice {
51 public:
52 TpuDevice(int id, int process_index, const std::array<int, 3>& coords,
53 int core_on_chip);
54
coords()55 const std::array<int, 3>& coords() const { return coords_; }
core_on_chip()56 int core_on_chip() const { return core_on_chip_; }
57
58 absl::string_view DebugString() const override;
59
60 absl::string_view ToString() const override;
61
62 static xla::StatusOr<std::vector<std::shared_ptr<xla::PjRtDevice>>>
63 GetTpuDevices(const tpu_driver::SystemInfo& system_info);
64
client()65 PjRtClient* client() const override { return nullptr; }
tpu_client()66 PyTpuClient* tpu_client() const { return tpu_client_; }
set_tpu_client(PyTpuClient * tpu_client)67 void set_tpu_client(PyTpuClient* tpu_client) { tpu_client_ = tpu_client; }
68
IsAddressable()69 bool IsAddressable() const override { return false; }
70
id()71 int id() const override { return id_; }
72
process_index()73 int process_index() const override { return process_index_; }
74
local_hardware_id()75 int local_hardware_id() const override { return -1; }
76
device_kind()77 absl::string_view device_kind() const override { return device_kind_; }
78
TransferToInfeed(const LiteralSlice & literal)79 Status TransferToInfeed(const LiteralSlice& literal) override {
80 return Unimplemented("Infeed not yet implemented via this API");
81 }
82
TransferFromOutfeed(MutableBorrowingLiteral literal)83 Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
84 return Unimplemented("Outfeed not yet implemented via this API");
85 }
86
CreateAsyncTrackingEvent(absl::string_view description)87 std::unique_ptr<ScopedAsyncTrackingEvent> CreateAsyncTrackingEvent(
88 absl::string_view description) const override {
89 return nullptr;
90 }
91
Attributes()92 const absl::flat_hash_map<std::string, PjRtDeviceAttribute>& Attributes()
93 const override {
94 return attributes_;
95 }
96
97 private:
98 const int id_;
99 const int process_index_;
100 const std::array<int, 3> coords_;
101 const std::string device_kind_ = "Cloud TPU";
102 std::string debug_string_;
103 std::string to_string_;
104 const absl::flat_hash_map<std::string, PjRtDeviceAttribute> attributes_ = {};
105 // Index of the core of the same chip.
106 int core_on_chip_;
107 PyTpuClient* tpu_client_;
108 };
109
110 // Encapsulates the state of Python session with XLA.
111 class PyTpuClient : public std::enable_shared_from_this<PyTpuClient> {
112 public:
113 // Initializes a local XLA client for `platform_name`. Returns an error if no
114 // such platform exists, or if the platform has no visible devices.
115 static StatusOr<std::shared_ptr<PyTpuClient>> Get(const std::string& worker);
116
117 explicit PyTpuClient(std::string platform_name,
118 std::unique_ptr<tpu_driver::TpuDriver> driver,
119 std::vector<std::shared_ptr<PjRtDevice>> devices,
120 int process_index);
121 virtual ~PyTpuClient() = default;
122
123 PyTpuClient(const PyTpuClient&) = delete;
124 PyTpuClient(PyTpuClient&&) = delete;
125 PyTpuClient& operator=(const PyTpuClient&) = delete;
126 PyTpuClient& operator=(PyTpuClient&&) = delete;
127
128 Status TransferToInfeed(const LiteralSlice& literal, int device_id);
129 StatusOr<Literal> TransferFromOutfeed(const Shape& shape, int device_id);
130
131 virtual StatusOr<DeviceAssignment> GetDefaultDeviceAssignment(
132 int num_replicas, int num_partitions) const;
133
device_count()134 int device_count() const { return devices_.size(); }
local_device_count()135 int local_device_count() const { return local_devices_.size(); }
devices()136 const std::vector<std::shared_ptr<PjRtDevice>>& devices() { return devices_; }
local_devices()137 const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() {
138 return local_devices_;
139 }
id_to_device()140 const std::map<int, std::shared_ptr<PjRtDevice>>& id_to_device() const {
141 return id_to_device_;
142 }
process_index()143 int process_index() const { return process_index_; }
platform_name()144 const absl::string_view platform_name() const { return platform_name_; }
platform_version()145 const absl::string_view platform_version() const { return platform_version_; }
146
ChooseCompactLayoutForShape(Shape subshape)147 StatusOr<Shape> ChooseCompactLayoutForShape(Shape subshape) {
148 return Unimplemented("ChooseCompactLayoutForShape not implemented.");
149 }
150
151 // Returns a bad status containing `caller_name` if `device_id` doesn't
152 // correspond to a valid device at the POD-slice boundary.
153 Status CheckDeviceId(int device_id, absl::string_view caller_name);
154
driver()155 tpu_driver::TpuDriver* driver() { return driver_.get(); }
156
GetThreadPool()157 tensorflow::thread::ThreadPool* GetThreadPool() { return pool_.get(); }
158
159 protected:
160 std::string platform_name_;
161 std::string platform_version_;
162 std::unique_ptr<tpu_driver::TpuDriver> driver_;
163
164 // Includes all devices, including non-local devices on multi-host platforms.
165 std::vector<std::shared_ptr<PjRtDevice>> devices_;
166 // Maps Device::id() to the corresponding Device. Includes all devices.
167 std::map<int, std::shared_ptr<PjRtDevice>> id_to_device_;
168 // Local devices indexed by local device ordinal.
169 std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
170 int process_index_;
171
172 // A thread pool for scheduling core executions in parallel.
173 std::unique_ptr<tensorflow::thread::ThreadPool> pool_;
174 };
175
176 // Manages a buffer shared amongst multiple users. Buffers are asynchronously
177 // deallocated after the last use.
178 struct TpuSharedBuffer final {
179 public:
TpuSharedBufferfinal180 TpuSharedBuffer(tpu_driver::TpuDriver* driver,
181 std::unique_ptr<tpu_driver::BufferHandle> handle,
182 std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use,
183 std::shared_ptr<PjRtDevice> src_device)
184 : driver(driver),
185 device(std::move(src_device)),
186 handle(std::move(handle)),
187 wait_for_use(std::move(wait_for_use)) {}
188
~TpuSharedBufferfinal189 ~TpuSharedBuffer() {
190 std::vector<tpu_driver::Event*> events;
191 for (const auto& e : wait_for_use) {
192 events.push_back(e.get());
193 }
194 driver->Deallocate(std::move(handle), events);
195 }
196
197 tpu_driver::TpuDriver* const driver;
198 const std::shared_ptr<PjRtDevice> device;
199
200 std::unique_ptr<tpu_driver::BufferHandle> handle;
201 std::vector<std::shared_ptr<tpu_driver::Event>> wait_for_use;
202 };
203
204 // Holds a reference from Python to one or more device buffers.
205 // A PyTpuBuffer can be either valid or invalid. An invalid buffer is one that
206 // has never been initialized, or a buffer that has been deleted (e.g., by
207 // calling Delete). We allow PyTpuBuffer objects to outlive the underlying
208 // device buffers so we can decouple buffer lifetimes from the corresponding
209 // Python references if needed.
210 // Thread-safe.
211 class PyTpuBuffer {
212 public:
213 // `tuple_shape` can be at most a one-level tuple combining non-tuple leaves.
214 static StatusOr<std::unique_ptr<PyTpuBuffer>> FromLiterals(
215 std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
216 std::shared_ptr<void> leaves_reference,
217 std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);
218
219 // Supports nested tuple creation.
220 static StatusOr<std::unique_ptr<PyTpuBuffer>> MakeTuple(
221 absl::Span<PyTpuBuffer* const> buffers,
222 std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);
223
224 PyTpuBuffer() = delete;
225 PyTpuBuffer(Shape on_host_shape,
226 std::shared_ptr<TpuSharedBuffer> device_buffer,
227 std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers,
228 std::shared_ptr<PyTpuClient> client);
229
230 PyTpuBuffer(const PyTpuBuffer&) = delete;
231 PyTpuBuffer(PyTpuBuffer&&) = delete;
232 PyTpuBuffer& operator=(const PyTpuBuffer&) = delete;
233 PyTpuBuffer& operator=(PyTpuBuffer&&) = delete;
234
on_host_shape()235 const Shape& on_host_shape() const { return on_host_shape_; }
device()236 std::shared_ptr<PjRtDevice> device() const { return device_; }
platform_name()237 const absl::string_view platform_name() const {
238 return client_->platform_name();
239 }
client()240 std::shared_ptr<PyTpuClient> client() const { return client_; }
241
242 // Returns the buffer's value as a tuple DAG of Python arrays. If the value
243 // has previously been prefetched to the host, then returns the prefetched
244 // version, otherwise copies the buffer to the host. Blocks until the
245 // value is ready.
246 StatusOr<std::shared_ptr<Literal>> ToLiteral();
247
248 // Initiates a copy of the buffer to the host. Does not block waiting for
249 // the transfer to complete. The value can be retrieved by a later call to
250 // ToLiteral().
251 Status CopyToHostAsync();
252
253 // Returns the associated device buffer. Returns a nullptr if the buffer is
254 // invalid.
255 std::shared_ptr<TpuSharedBuffer> DeviceBuffer() const;
256
257 // Deletes the device memory associated with this buffer, leaving it in an
258 // invalid state.
259 void Delete();
260
261 // Destructures a tuple-valued PyTpuBuffer into its constituent elements.
262 StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> DestructureTuple();
263
264 // Copies the buffer to target device `dst_device` and returns a PyTpuBuffer
265 // object holding the context to the target device buffer.
266 StatusOr<std::unique_ptr<PyTpuBuffer>> CopyToDevice(
267 std::shared_ptr<PjRtDevice> dst_device);
268
269 // Blocks the host until the buffer's value has been computed and is ready for
270 // immediate use on the device. Useful in particular for timing benchmarks.
271 Status BlockHostUntilReady();
272
273 // Allocates uninitialized buffers on device `device_id`. If `shape` is a
274 // tuple, the returned buffer corresponds to the root tuple buffer.
275 static StatusOr<std::unique_ptr<PyTpuBuffer>> AllocateBuffer(
276 const Shape& shape, std::shared_ptr<PyTpuClient> client,
277 std::shared_ptr<PjRtDevice> device);
278
279 private:
280 // Initializes a just allocated device buffer. The returned event will be
281 // placed into the buffer's `wait_for_use` list.
282 using BufferInitializer = std::function<std::shared_ptr<tpu_driver::Event>(
283 tpu_driver::BufferHandle*)>;
284 // Allocates and optionally initializes a non-tuple buffer on the device.
285 static StatusOr<std::unique_ptr<PyTpuBuffer>> CreateBuffer(
286 const Shape& non_tuple_shape,
287 std::optional<BufferInitializer> initializer,
288 std::shared_ptr<PyTpuClient> client, std::shared_ptr<PjRtDevice> device);
289
290 const std::shared_ptr<PyTpuClient> client_;
291 const Shape on_host_shape_;
292 const std::shared_ptr<PjRtDevice> device_;
293
294 // If this is a tuple, `device_buffer_` stores the tuple buffer and
295 // `child_buffers_` stores the child buffers; else, `device_buffer_` stores
296 // the data content and `child_buffers_` is empty.
297 mutable absl::Mutex mu_;
298 std::shared_ptr<TpuSharedBuffer> device_buffer_ ABSL_GUARDED_BY(mu_);
299 std::vector<std::shared_ptr<TpuSharedBuffer>> child_buffers_
300 ABSL_GUARDED_BY(mu_);
301 // The cached value of the buffer on the host, produced either from a call to
302 // CopyToHost or from a call to ToLiteral. Once a value has been fetched to
303 // the host, it persists Delete() is called or the PyTpuBuffer is destroyed.
304 struct HostValue {
305 absl::Mutex mutex;
306 absl::Notification ready;
307 int pending_ops;
308 // status and value are valid for reading only after `ready` has been
309 // notified.
310 Status status;
311 std::shared_ptr<Literal> value;
312 };
313 std::shared_ptr<HostValue> host_value_ ABSL_GUARDED_BY(mu_);
314 };
315
316 // A dummy token that is always ready. PyTpuExecutable::Execute() is blocking
317 // until the computation finishes.
318 class PyTpuToken {
319 public:
PyTpuToken()320 PyTpuToken() {}
Await()321 Status Await() { return Status::OK(); }
322 };
323
324 class PyShardedTpuToken {
325 public:
PyShardedTpuToken()326 PyShardedTpuToken() {}
Await()327 Status Await() { return Status::OK(); }
GetPyToken(int i)328 PyTpuToken GetPyToken(int i) { return PyTpuToken(); }
329 };
330
331 // Represents a compiled computation that can be executed given handles to
332 // device-allocated literals. Wraps an XLA LocalExecutable.
333 class PyTpuExecutable {
334 public:
335 static StatusOr<std::unique_ptr<PyTpuExecutable>> Compile(
336 const XlaComputation& computation,
337 std::optional<std::vector<Shape>> argument_layouts,
338 const ExecutableBuildOptions* build_options,
339 std::shared_ptr<PyTpuClient> client, bool tuple_arguments);
340
341 static StatusOr<std::unique_ptr<PyTpuExecutable>> CompileMlir(
342 mlir::ModuleOp module, std::optional<std::vector<Shape>> argument_layouts,
343 const ExecutableBuildOptions* build_options,
344 std::shared_ptr<PyTpuClient> client, bool tuple_arguments);
345
346 PyTpuExecutable(
347 std::unique_ptr<tpu_driver::CompiledProgramHandle> compiled_program,
348 DeviceAssignment device_assignment, std::shared_ptr<PyTpuClient> client,
349 xla::Shape result_shape, bool tuple_arguments);
~PyTpuExecutable()350 virtual ~PyTpuExecutable() {
351 for (auto it = executables_.begin(); it != executables_.end(); ++it) {
352 client_->driver()->UnloadProgram(std::move(it->second), {});
353 }
354 }
355
356 PyTpuExecutable(const PyTpuExecutable&) = delete;
357 PyTpuExecutable(PyTpuExecutable&&) = delete;
358 PyTpuExecutable& operator=(const PyTpuExecutable&) = delete;
359 PyTpuExecutable& operator=(PyTpuExecutable&&) = delete;
360
client()361 std::shared_ptr<PyTpuClient> client() const { return client_; }
362
num_replicas()363 int num_replicas() const { return device_assignment_.replica_count(); }
num_partitions()364 int num_partitions() const { return device_assignment_.computation_count(); }
365
SizeOfGeneratedCodeInBytes()366 int64_t SizeOfGeneratedCodeInBytes() const {
367 CHECK_GE(executables_.size(), 1);
368 return executables_.begin()->second->size_in_bytes();
369 }
370
device_assignment()371 const DeviceAssignment& device_assignment() const {
372 return device_assignment_;
373 }
374
local_logical_device_ids()375 const std::vector<std::pair<int, int>>& local_logical_device_ids() const {
376 return local_logical_device_ids_;
377 }
378
local_devices()379 const std::vector<std::shared_ptr<PjRtDevice>>& local_devices() const {
380 return local_devices_;
381 }
382
383 // TODO(power): Both Execute and ExecutePerOnLocalDevices block and wait
384 // inside for computation to finish. Coordinate with JAX code change to see if
385 // we can make both Execute and ExecutePerReplica non-blocking.
386 StatusOr<std::vector<std::unique_ptr<PyTpuBuffer>>> Execute(
387 absl::Span<PyTpuBuffer* const> argument_handles);
388
389 StatusOr<std::pair<std::vector<std::unique_ptr<PyTpuBuffer>>, PyTpuToken>>
ExecuteWithToken(absl::Span<PyTpuBuffer * const> argument_handles)390 ExecuteWithToken(absl::Span<PyTpuBuffer* const> argument_handles) {
391 TF_ASSIGN_OR_RETURN(auto results, Execute(argument_handles));
392 return std::pair<std::vector<std::unique_ptr<PyTpuBuffer>>, PyTpuToken>(
393 std::move(results), PyTpuToken());
394 }
395
396 // Execute on local devices. Takes a sequence of argument lists (one argument
397 // list per local device) and returns a tuple of results (one result per local
398 // device). The number of argument lists must be equal to the local device
399 // count.
400 StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
401 ExecuteOnLocalDevices(
402 absl::Span<const std::vector<PyTpuBuffer*>> argument_handles);
403
404 StatusOr<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>>
405 ExecuteShardedOnLocalDevices(
406 absl::Span<const std::vector<PyTpuBuffer*>> args);
407
408 StatusOr<std::pair<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>,
409 PyShardedTpuToken>>
ExecuteShardedOnLocalDevicesWithTokens(absl::Span<const std::vector<PyTpuBuffer * >> args)410 ExecuteShardedOnLocalDevicesWithTokens(
411 absl::Span<const std::vector<PyTpuBuffer*>> args) {
412 TF_ASSIGN_OR_RETURN(auto results, ExecuteShardedOnLocalDevices(args));
413
414 TF_RET_CHECK(!args.empty());
415 return std::pair<std::vector<std::vector<std::unique_ptr<PyTpuBuffer>>>,
416 PyShardedTpuToken>(std::move(results),
417 PyShardedTpuToken());
418 }
419
Delete()420 void Delete() { executables_.clear(); }
421
422 private:
423 struct ExecuteResult {
424 std::unique_ptr<PyTpuBuffer> buffer;
425 std::shared_ptr<tpu_driver::Event> on_execute_finished;
426 };
427
428 ExecuteResult ExecuteHelper(
429 absl::Span<const std::vector<PyTpuBuffer*>> all_core_arguments,
430 absl::Span<PyTpuBuffer* const> this_core_arguments, int replica,
431 int partition, const RunId& run_id);
432
433 std::shared_ptr<PyTpuClient> const client_;
434 std::map<int, std::unique_ptr<tpu_driver::LoadedProgramHandle>> executables_;
435 const DeviceAssignment device_assignment_;
436 const bool tuple_arguments_;
437
438 // The replica and partition indices of device_assignment_ to be run by this
439 // client. On single-host platforms without partitioning, this is all replicas
440 // (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
441 // on multi-host platforms.
442 // If there are 4 replicas and 2 partitions on a single host platform, size of
443 // local_logical_device_ids_ is 4*2 = 8.
444 std::vector<std::pair<int, int>> local_logical_device_ids_;
445
446 // local_devices_[i] is the Device to which local_logical_device_ids_[i] is
447 // assigned.
448 // shared_ptrs instead of unique_ptrs to play well with the Python bindings
449 // (see xla.cc).
450 std::vector<std::shared_ptr<PjRtDevice>> local_devices_;
451
452 xla::Shape result_shape_;
453 };
454
455 } // namespace xla
456
457 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_TPU_DRIVER_CLIENT_TPU_CLIENT_H_
458