1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_FUTURE_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_FUTURE_H_ 18 19 #include <functional> 20 #include <utility> 21 22 #include "absl/functional/any_invocable.h" 23 #include "absl/types/span.h" 24 #include "tfrt/host_context/async_dispatch.h" // from @tf_runtime 25 #include "tfrt/host_context/async_value.h" // from @tf_runtime 26 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime 27 #include "tfrt/support/ref_count.h" // from @tf_runtime 28 29 namespace xla { 30 31 template <class T> 32 class PjRtFuture; 33 34 // An RAII event that a caller can use to tell the PjRtClient about asynchronous 35 // actions outside PjRt. 36 // 37 // A ScopedAsyncTrackingEvent can be generated by the caller by calling a method 38 // on PjRtDevice, and the creation of a ScopedAsyncTrackingEvent tells the 39 // PjRtClient that the client is creating some outstanding asynchronous work 40 // that depends on activities happening on the PjRtDevice. 41 // 42 // The caller can indicate that a ScopedAsyncTrackingEvent event cannot complete 43 // until after some PjRtFuture becomes ready, by calling 44 // future.AssertHappensBefore(event). 45 // 46 // The caller indicates that the work tracked by the ScopedAsyncTrackingEvent 47 // has completed by letting the event go out of scope. 48 // 49 // ScopedAsyncTrackingEvents are used by some PjRtClient implementations to 50 // monitor system-wide dependencies. 51 class ScopedAsyncTrackingEvent { 52 public: 53 virtual ~ScopedAsyncTrackingEvent() = default; 54 55 private: 56 template <class T> 57 friend class PjRtFuture; 58 59 // Indicates that the ScopedAsyncTrackingEvent won't complete until dependency 60 // becomes available. Called only by PjRtFuture. 61 virtual void AddDependency( 62 tfrt::RCReference<tfrt::AsyncValue> dependency) = 0; 63 }; 64 65 // Helpers for using PjRtFutures. 66 struct PjRtFutureHelpers { 67 public: 68 // Keys that are returned by an implementation-specific handler when a client 69 // starts to block on a promise. 70 // 71 // For now, contains a single UID that can be used to identify a TraceMe, but 72 // made extensible to allow support for other profilers such as endoscope. 73 struct ProfilingKeys { 74 uint64_t traceme_context_id = -1; 75 }; 76 // Signature of handler called by the PjRtFuture class before it starts to 77 // block a thread. 78 using OnBlockStartFn = std::function<ProfilingKeys()>; 79 // Signature of handler called by the PjRtFuture class after it finishes 80 // blocking a thread. 81 using OnBlockEndFn = std::function<void(ProfilingKeys)>; 82 }; 83 84 // PjRtFuture<T> is a simple future that is returned by PjRt APIs that 85 // enqueue asynchronous work, reporting a value of type T (frequently T=Status) 86 // when the work is complete. 87 // 88 // PjRtFuture can be used by the client to wait for work to complete, either via 89 // a blocking call or a callback. 90 // 91 // The implementation wraps a TFRT AsyncValueRef<T>, but we prefer to 92 // encapsulate the AVR rather than returning it directly for two reasons. 93 // 94 // First, we want to retain portability in case a future implementation moves 95 // away from AsyncValueRef ---- we don't want clients to call arbitrary 96 // AsyncValueRef APIs. 97 // 98 // Second, we want to export different semantics, for example we support 99 // integration between blocking and profiling (e.g., TraceMe). 100 // 101 // There are two ways to construct a PjRtFuture, one used by clients that 102 // natively use TFRT, which already have import APIs for constructing 103 // AsyncValueRefs; and another that avoids exposing TFRT APIs and can be used by 104 // non-TFRT clients. 105 template <class T> 106 class PjRtFuture { 107 public: 108 // Wrapper for AsyncValueRef<T> that can be used by clients that don't 109 // natively use TFRT. 110 struct Promise { 111 public: 112 // Creates an empty promise with !this == true. 113 explicit Promise() = default; 114 Promise(Promise&& other) = default; PromisePromise115 Promise(const Promise& other) : avr(other.avr.CopyRef()) {} 116 Promise& operator=(const Promise& other) { 117 avr = other.avr.CopyRef(); 118 return *this; 119 } 120 bool operator!() { return !avr; } 121 122 // Sets the value of the promise. Must be called at most once. 123 // 124 // After Set is called, value will be delivered to waiters on the parent 125 // PjRtFuture, via blocking or callbacks. SetPromise126 void Set(T value) { avr.emplace(std::move(value)); } 127 128 private: 129 friend class PjRtFuture<T>; PromisePromise130 explicit Promise(tfrt::AsyncValueRef<T> ref) : avr(std::move(ref)) {} 131 // The underlying TFRT value that can be waited on. 132 tfrt::AsyncValueRef<T> avr; 133 }; 134 135 // Returns a Promise that can be used to construct a PjRtFuture, and then Set 136 // later. 137 // 138 // Used by clients that do not use TFRT natively. CreatePromise()139 static Promise CreatePromise() { 140 return Promise(tfrt::MakeUnconstructedAsyncValueRef<T>()); 141 } 142 143 PjRtFuture() = default; 144 IsValid()145 bool IsValid() const { return promise_ref_ != nullptr; } 146 147 // Constructor for an already-available PjRtFuture. 148 // 149 // Typically used to eagerly return error values when async work will not 150 // be enqueued, e.g., due to invalid arguments. PjRtFuture(T t)151 explicit PjRtFuture(T t) 152 : promise_ref_(tfrt::MakeAvailableAsyncValueRef<T>(t)), 153 on_block_start_([]() { return PjRtFutureHelpers::ProfilingKeys(); }), 154 on_block_end_([](PjRtFutureHelpers::ProfilingKeys) {}) {} 155 156 // Constructor used by clients that natively use TFRT and already have a 157 // host_ctx that should be used for awaiting promises. 158 // 159 // on_block_start is called before Await starts to block. 160 // on_block_end is called after Await finishes blocking. 161 explicit PjRtFuture( 162 tfrt::AsyncValueRef<T> async_value, 163 PjRtFutureHelpers::OnBlockStartFn on_block_start = 164 []() { return PjRtFutureHelpers::ProfilingKeys(); }, 165 PjRtFutureHelpers::OnBlockEndFn on_block_end = 166 [](PjRtFutureHelpers::ProfilingKeys) {}) promise_ref_(std::move (async_value))167 : promise_ref_(std::move(async_value)), 168 on_block_start_(std::move(on_block_start)), 169 on_block_end_(std::move(on_block_end)) {} 170 171 // Constructor used by clients that don't natively use TFRT and want to use 172 // the wrapped PjRtFuture<T>::Promise class. 173 // 174 // on_block_start is called before Await starts to block. 175 // on_block_end is called after Await finishes blocking. 176 explicit PjRtFuture( 177 Promise promise, 178 PjRtFutureHelpers::OnBlockStartFn on_block_start = 179 []() { return PjRtFutureHelpers::ProfilingKeys(); }, 180 PjRtFutureHelpers::OnBlockEndFn on_block_end = 181 [](PjRtFutureHelpers::ProfilingKeys) {}) 182 : promise_ref_(std::move(promise.avr)), 183 on_block_start_(std::move(on_block_start)), 184 on_block_end_(std::move(on_block_end)) {} 185 186 // Two functions exist to know whether the future is ready, to accomodate 187 // the fact some backends (e.g. disributed ones) could take a non-trivial time 188 // to check the state of a future. 189 // 190 // `IsReady()` is guaranteed to return true if the future became ready before 191 // `IsReady()` was called. `IsReady()` will return immediately if a call to 192 // `Await()` has already returned, or any callback passed to `OnReady` has 193 // already been triggered. Otherwise IsReady() may block for the duration of a 194 // network message on some backends. IsReady()195 bool IsReady() { 196 CHECK(IsValid()); 197 return promise_ref_.IsAvailable(); 198 } 199 // `IsKnownReady()` is guaranteed to return immediately. `IsKnownReady()` will 200 // always return true if a call to `Await()` has already returned, or any 201 // callback passed to `OnReady` has already been triggered. Otherwise, 202 // `IsKnownReady()` may return false in some cases in which the future was 203 // ready before `IsKnownReady()` was called. IsKnownReady()204 bool IsKnownReady() { 205 CHECK(IsValid()); 206 return promise_ref_.IsAvailable(); 207 } 208 209 // Blocks the calling thread until the promise is ready, then returns the 210 // final value. Await()211 T Await() { 212 CHECK(IsValid()); 213 if (!promise_ref_.IsAvailable()) { 214 const auto keys = on_block_start_(); 215 tfrt::Await({promise_ref_.GetAsyncValue()}); 216 on_block_end_(keys); 217 } 218 DCHECK(promise_ref_.IsConcrete()); 219 return *promise_ref_; 220 } 221 222 // Registers callback to be called once the promise is ready, with the final 223 // value. 224 // 225 // callback may be called on an internal system thread or the calling thread. 226 // The client should avoid any potentially re-entrant API calls within the 227 // callback, for example by using the callback to enqueue work on a 228 // client-owned threadpool. OnReady(absl::AnyInvocable<void (T)&&> callback)229 void OnReady(absl::AnyInvocable<void(T) &&> callback) { 230 CHECK(IsValid()); 231 promise_ref_.AndThen([promise = promise_ref_.AsPtr(), 232 callback = std::move(callback)]() mutable { 233 DCHECK(promise.IsConcrete()); 234 std::move(callback)(*promise); 235 }); 236 } 237 238 // Indicates that event will not complete until after this becomes ready. 239 // 240 // May safely be called with event==nullptr in which case AssertHappensBefore 241 // has no effect. AssertHappensBefore(ScopedAsyncTrackingEvent * event)242 void AssertHappensBefore(ScopedAsyncTrackingEvent* event) { 243 CHECK(IsValid()); 244 if (event) { 245 event->AddDependency(promise_ref_.CopyRCRef()); 246 } 247 } 248 249 private: 250 // Wrapped object to wait on. 251 tfrt::AsyncValueRef<T> promise_ref_; 252 // Function that is called before a thread starts blocking on the promise. 253 PjRtFutureHelpers::OnBlockStartFn on_block_start_; 254 // Function that is called after a thread finishes blocking on the promise. 255 PjRtFutureHelpers::OnBlockEndFn on_block_end_; 256 }; 257 258 } // namespace xla 259 260 #endif // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_FUTURE_H_ 261