xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/pjrt_future.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_FUTURE_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_FUTURE_H_
18 
#include <cstdint>
#include <functional>
#include <limits>
#include <utility>

#include "absl/functional/any_invocable.h"
#include "absl/types/span.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime
28 
29 namespace xla {
30 
31 template <class T>
32 class PjRtFuture;
33 
34 // An RAII event that a caller can use to tell the PjRtClient about asynchronous
35 // actions outside PjRt.
36 //
37 // A ScopedAsyncTrackingEvent can be generated by the caller by calling a method
38 // on PjRtDevice, and the creation of a ScopedAsyncTrackingEvent tells the
39 // PjRtClient that the client is creating some outstanding asynchronous work
40 // that depends on activities happening on the PjRtDevice.
41 //
42 // The caller can indicate that a ScopedAsyncTrackingEvent event cannot complete
43 // until after some PjRtFuture becomes ready, by calling
44 // future.AssertHappensBefore(event).
45 //
46 // The caller indicates that the work tracked by the ScopedAsyncTrackingEvent
47 // has completed by letting the event go out of scope.
48 //
49 // ScopedAsyncTrackingEvents are used by some PjRtClient implementations to
50 // monitor system-wide dependencies.
51 class ScopedAsyncTrackingEvent {
52  public:
53   virtual ~ScopedAsyncTrackingEvent() = default;
54 
55  private:
56   template <class T>
57   friend class PjRtFuture;
58 
59   // Indicates that the ScopedAsyncTrackingEvent won't complete until dependency
60   // becomes available. Called only by PjRtFuture.
61   virtual void AddDependency(
62       tfrt::RCReference<tfrt::AsyncValue> dependency) = 0;
63 };
64 
65 // Helpers for using PjRtFutures.
struct PjRtFutureHelpers {
 public:
  // Keys that are returned by an implementation-specific handler when a client
  // starts to block on a promise.
  //
  // For now, contains a single UID that can be used to identify a TraceMe, but
  // made extensible to allow support for other profilers such as endoscope.
  struct ProfilingKeys {
    // Defaults to the largest uint64_t value (the same bit pattern the
    // previous `= -1` initializer produced), spelled out explicitly instead of
    // relying on an implicit signed-to-unsigned conversion.
    uint64_t traceme_context_id = std::numeric_limits<uint64_t>::max();
  };

  // Signature of handler called by the PjRtFuture class before it starts to
  // block a thread. The returned keys are handed back to the matching
  // OnBlockEndFn.
  using OnBlockStartFn = std::function<ProfilingKeys()>;

  // Signature of handler called by the PjRtFuture class after it finishes
  // blocking a thread. Receives the keys produced by OnBlockStartFn.
  using OnBlockEndFn = std::function<void(ProfilingKeys)>;
};
83 
84 // PjRtFuture<T> is a simple future that is returned by PjRt APIs that
85 // enqueue asynchronous work, reporting a value of type T (frequently T=Status)
86 // when the work is complete.
87 //
88 // PjRtFuture can be used by the client to wait for work to complete, either via
89 // a blocking call or a callback.
90 //
91 // The implementation wraps a TFRT AsyncValueRef<T>, but we prefer to
92 // encapsulate the AVR rather than returning it directly for two reasons.
93 //
94 // First, we want to retain portability in case a future implementation moves
95 // away from AsyncValueRef ---- we don't want clients to call arbitrary
96 // AsyncValueRef APIs.
97 //
98 // Second, we want to export different semantics, for example we support
99 // integration between blocking and profiling (e.g., TraceMe).
100 //
101 // There are two ways to construct a PjRtFuture, one used by clients that
102 // natively use TFRT, which already have import APIs for constructing
103 // AsyncValueRefs; and another that avoids exposing TFRT APIs and can be used by
104 // non-TFRT clients.
105 template <class T>
106 class PjRtFuture {
107  public:
108   // Wrapper for AsyncValueRef<T> that can be used by clients that don't
109   // natively use TFRT.
110   struct Promise {
111    public:
112     // Creates an empty promise with !this == true.
113     explicit Promise() = default;
114     Promise(Promise&& other) = default;
PromisePromise115     Promise(const Promise& other) : avr(other.avr.CopyRef()) {}
116     Promise& operator=(const Promise& other) {
117       avr = other.avr.CopyRef();
118       return *this;
119     }
120     bool operator!() { return !avr; }
121 
122     // Sets the value of the promise. Must be called at most once.
123     //
124     // After Set is called, value will be delivered to waiters on the parent
125     // PjRtFuture, via blocking or callbacks.
SetPromise126     void Set(T value) { avr.emplace(std::move(value)); }
127 
128    private:
129     friend class PjRtFuture<T>;
PromisePromise130     explicit Promise(tfrt::AsyncValueRef<T> ref) : avr(std::move(ref)) {}
131     // The underlying TFRT value that can be waited on.
132     tfrt::AsyncValueRef<T> avr;
133   };
134 
135   // Returns a Promise that can be used to construct a PjRtFuture, and then Set
136   // later.
137   //
138   // Used by clients that do not use TFRT natively.
CreatePromise()139   static Promise CreatePromise() {
140     return Promise(tfrt::MakeUnconstructedAsyncValueRef<T>());
141   }
142 
143   PjRtFuture() = default;
144 
IsValid()145   bool IsValid() const { return promise_ref_ != nullptr; }
146 
147   // Constructor for an already-available PjRtFuture.
148   //
149   // Typically used to eagerly return error values when async work will not
150   // be enqueued, e.g., due to invalid arguments.
PjRtFuture(T t)151   explicit PjRtFuture(T t)
152       : promise_ref_(tfrt::MakeAvailableAsyncValueRef<T>(t)),
153         on_block_start_([]() { return PjRtFutureHelpers::ProfilingKeys(); }),
154         on_block_end_([](PjRtFutureHelpers::ProfilingKeys) {}) {}
155 
156   // Constructor used by clients that natively use TFRT and already have a
157   // host_ctx that should be used for awaiting promises.
158   //
159   // on_block_start is called before Await starts to block.
160   // on_block_end is called after Await finishes blocking.
161   explicit PjRtFuture(
162       tfrt::AsyncValueRef<T> async_value,
163       PjRtFutureHelpers::OnBlockStartFn on_block_start =
164           []() { return PjRtFutureHelpers::ProfilingKeys(); },
165       PjRtFutureHelpers::OnBlockEndFn on_block_end =
166           [](PjRtFutureHelpers::ProfilingKeys) {})
promise_ref_(std::move (async_value))167       : promise_ref_(std::move(async_value)),
168         on_block_start_(std::move(on_block_start)),
169         on_block_end_(std::move(on_block_end)) {}
170 
171   // Constructor used by clients that don't natively use TFRT and want to use
172   // the wrapped PjRtFuture<T>::Promise class.
173   //
174   // on_block_start is called before Await starts to block.
175   // on_block_end is called after Await finishes blocking.
176   explicit PjRtFuture(
177       Promise promise,
178       PjRtFutureHelpers::OnBlockStartFn on_block_start =
179           []() { return PjRtFutureHelpers::ProfilingKeys(); },
180       PjRtFutureHelpers::OnBlockEndFn on_block_end =
181           [](PjRtFutureHelpers::ProfilingKeys) {})
182       : promise_ref_(std::move(promise.avr)),
183         on_block_start_(std::move(on_block_start)),
184         on_block_end_(std::move(on_block_end)) {}
185 
186   // Two functions exist to know whether the future is ready, to accomodate
187   // the fact some backends (e.g. disributed ones) could take a non-trivial time
188   // to check the state of a future.
189   //
190   // `IsReady()` is guaranteed to return true if the future became ready before
191   // `IsReady()` was called. `IsReady()` will return immediately if a call to
192   // `Await()` has already returned, or any callback passed to `OnReady` has
193   // already been triggered. Otherwise IsReady() may block for the duration of a
194   // network message on some backends.
IsReady()195   bool IsReady() {
196     CHECK(IsValid());
197     return promise_ref_.IsAvailable();
198   }
199   // `IsKnownReady()` is guaranteed to return immediately. `IsKnownReady()` will
200   // always return true if a call to `Await()` has already returned, or any
201   // callback passed to `OnReady` has already been triggered. Otherwise,
202   // `IsKnownReady()` may return false in some cases in which the future was
203   // ready before `IsKnownReady()` was called.
IsKnownReady()204   bool IsKnownReady() {
205     CHECK(IsValid());
206     return promise_ref_.IsAvailable();
207   }
208 
209   // Blocks the calling thread until the promise is ready, then returns the
210   // final value.
Await()211   T Await() {
212     CHECK(IsValid());
213     if (!promise_ref_.IsAvailable()) {
214       const auto keys = on_block_start_();
215       tfrt::Await({promise_ref_.GetAsyncValue()});
216       on_block_end_(keys);
217     }
218     DCHECK(promise_ref_.IsConcrete());
219     return *promise_ref_;
220   }
221 
222   // Registers callback to be called once the promise is ready, with the final
223   // value.
224   //
225   // callback may be called on an internal system thread or the calling thread.
226   // The client should avoid any potentially re-entrant API calls within the
227   // callback, for example by using the callback to enqueue work on a
228   // client-owned threadpool.
OnReady(absl::AnyInvocable<void (T)&&> callback)229   void OnReady(absl::AnyInvocable<void(T) &&> callback) {
230     CHECK(IsValid());
231     promise_ref_.AndThen([promise = promise_ref_.AsPtr(),
232                           callback = std::move(callback)]() mutable {
233       DCHECK(promise.IsConcrete());
234       std::move(callback)(*promise);
235     });
236   }
237 
238   // Indicates that event will not complete until after this becomes ready.
239   //
240   // May safely be called with event==nullptr in which case AssertHappensBefore
241   // has no effect.
AssertHappensBefore(ScopedAsyncTrackingEvent * event)242   void AssertHappensBefore(ScopedAsyncTrackingEvent* event) {
243     CHECK(IsValid());
244     if (event) {
245       event->AddDependency(promise_ref_.CopyRCRef());
246     }
247   }
248 
249  private:
250   // Wrapped object to wait on.
251   tfrt::AsyncValueRef<T> promise_ref_;
252   // Function that is called before a thread starts blocking on the promise.
253   PjRtFutureHelpers::OnBlockStartFn on_block_start_;
254   // Function that is called after a thread finishes blocking on the promise.
255   PjRtFutureHelpers::OnBlockEndFn on_block_end_;
256 };
257 
258 }  // namespace xla
259 
260 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_FUTURE_H_
261