1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 26 #include "tensorflow/compiler/xla/python/py_buffer.h" 27 #include "tensorflow/compiler/xla/python/py_client.h" 28 #include "tensorflow/compiler/xla/python/traceback.h" 29 #include "tensorflow/compiler/xla/statusor.h" 30 #include "tensorflow/compiler/xla/types.h" 31 32 namespace xla { 33 34 class PyToken { 35 public: 36 PyToken() = default; PyToken(PjRtFuture<Status> future)37 explicit PyToken(PjRtFuture<Status> future) : future_(std::move(future)) {} 38 ReadyPyToken()39 static PyToken ReadyPyToken() { 40 return PyToken(PjRtFuture<Status>(OkStatus())); 41 } 42 43 Status Await(); 44 45 private: 46 PjRtFuture<Status> future_; 47 }; 48 49 // PyShardedToken contains a PyToken for each device's execution. 50 class PyShardedToken { 51 public: 52 // Default construction creates a always-ready token. 53 PyShardedToken() = default; PyShardedToken(std::vector<PjRtFuture<Status>> futures)54 explicit PyShardedToken(std::vector<PjRtFuture<Status>> futures) 55 : futures_(std::move(futures)) {} 56 GetPyToken(int device_id)57 PyToken GetPyToken(int device_id) const { 58 if (futures_.empty()) return PyToken::ReadyPyToken(); 59 return PyToken(futures_.at(device_id)); 60 } 61 62 Status Await(); 63 64 private: 65 std::vector<PjRtFuture<Status>> futures_; 66 }; 67 68 // Python wrapper around PjRtExecutable. We use a wrapper class: 69 // a) to keep the PyClient alive via a std::shared_ptr<> 70 // b) to add Python-specific functionality. 71 class PyExecutable : public std::enable_shared_from_this<PyExecutable> { 72 public: 73 PyExecutable(std::shared_ptr<PyClient> client, 74 std::unique_ptr<PjRtLoadedExecutable> executable, 75 std::shared_ptr<Traceback> traceback, 76 std::optional<std::string> fingerprint, 77 std::vector<pybind11::capsule> host_callbacks); 78 ~PyExecutable(); 79 client()80 std::shared_ptr<PyClient> client() const { return client_; } executable()81 std::shared_ptr<PjRtLoadedExecutable> executable() const { 82 return executable_; 83 } 84 85 absl::Span<const PjRtLoadedExecutable::LogicalDeviceIds> addressable_device_logical_ids()86 addressable_device_logical_ids() const { 87 return executable_->addressable_device_logical_ids(); 88 } 89 90 std::vector<ClientAndPtr<PjRtDevice>> AddressableDevices() const; 91 SizeOfGeneratedCodeInBytes()92 int64_t SizeOfGeneratedCodeInBytes() const { 93 return executable_->SizeOfGeneratedCodeInBytes(); 94 } 95 GetCompiledMemoryStats()96 StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const { 97 return executable_->GetCompiledMemoryStats(); 98 } 99 Delete()100 void Delete() { return executable_->Delete(); } 101 is_deleted()102 bool is_deleted() { return executable_->IsDeleted(); } 103 104 StatusOr<std::vector<PyBuffer::object>> Execute( 105 absl::Span<PyBuffer::object const> args, PjRtDevice* device); 106 107 StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>> ExecuteWithToken( 108 absl::Span<PyBuffer::object const> args, PjRtDevice* device); 109 110 // Takes args indexed by argid then deviceid, transposes them, and passes to 111 // PjRtExecutable::Execute. The result is similarly transposed back into the 112 // argid,deviceid format. 113 // args is [num_args x num_devices]. 114 StatusOr<std::vector<std::vector<PyBuffer::object>>> 115 ExecuteShardedOnLocalDevices( 116 absl::Span<const std::vector<PyBuffer::object>> args); 117 118 StatusOr< 119 std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>> 120 ExecuteShardedOnLocalDevicesWithTokens( 121 absl::Span<const std::vector<PyBuffer::object>> args); 122 123 StatusOr<std::vector<std::shared_ptr<HloModule>>> HloModules() const; 124 traceback()125 Traceback* traceback() { return traceback_.get(); } 126 pjrt_executable()127 const PjRtLoadedExecutable& pjrt_executable() const { return *executable_; } 128 mutable_pjrt_executable()129 PjRtLoadedExecutable* mutable_pjrt_executable() const { 130 return executable_.get(); 131 } options()132 const ExecuteOptions& options() const { return options_; } fingerprint()133 const std::optional<std::string>& fingerprint() const { return fingerprint_; } 134 135 // Keep `obj` alive as long as PyExecutable. 136 void KeepAlive(pybind11::object obj); 137 138 private: 139 StatusOr<std::pair<std::vector<PyBuffer::object>, PyToken>> ExecuteInternal( 140 absl::Span<PyBuffer::object const> args, PjRtDevice* device, 141 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures); 142 StatusOr< 143 std::pair<std::vector<std::vector<PyBuffer::object>>, PyShardedToken>> 144 ExecuteShardedOnLocalDevicesInternal( 145 absl::Span<const std::vector<PyBuffer::object>> args, 146 std::optional<std::vector<PjRtFuture<Status>>>& returned_futures); 147 148 friend class PyClient; 149 150 std::shared_ptr<PyClient> client_; 151 std::shared_ptr<PjRtLoadedExecutable> executable_; 152 std::shared_ptr<Traceback> traceback_; 153 154 // Identical executables (i.e. representing the same program) will have the 155 // same fingerprint. nullopt on platforms or executables where fingerprints 156 // aren't implemented. 157 std::optional<std::string> fingerprint_; 158 159 // The python callbacks implemented using send/recv support. 160 std::vector<pybind11::capsule> host_callbacks_; 161 162 // The options to pass to `executable_.Execute`. 163 ExecuteOptions options_; 164 165 // Python objects to keep alive as requested by user. 166 std::vector<pybind11::object> keepalives_; 167 168 // Doubly-linked list of all executables known to the client. Protected by the 169 // GIL. 170 PyExecutable* next_; 171 PyExecutable* prev_; 172 }; 173 174 } // namespace xla 175 176 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_EXECUTABLE_H_ 177