xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/rendezvous.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RENDEZVOUS_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_RENDEZVOUS_H_
18 
19 #include <functional>
20 #include <memory>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/synchronization/mutex.h"
26 #include "absl/time/time.h"
27 #include "absl/types/span.h"
28 
29 namespace xla {
30 
31 template <typename K, typename V>
32 class ThreadSafeMap {
33  public:
34   V& operator[](const K& key) {
35     absl::MutexLock lock(&mutex_);
36     std::unique_ptr<V>& value = map_[key];
37     if (value == nullptr) value = std::make_unique<V>();
38     return *value;
39   }
40 
ForEachValue(const std::function<void (V &)> & fn)41   void ForEachValue(const std::function<void(V&)>& fn) {
42     absl::MutexLock lock(&mutex_);
43     for (const auto& [_, value] : map_) fn(*value);
44   }
45 
46  private:
47   absl::Mutex mutex_;
48   absl::flat_hash_map<K, std::unique_ptr<V>> map_ ABSL_GUARDED_BY(mutex_);
49 };
50 
// Blocks (with `mutex` held) until `condition` becomes true.
// NOTE(review): based on the name and the call sites below, this presumably
// logs a warning once `warn_stuck_timeout` elapses and aborts the process
// after `terminate_timeout` — confirm against the definition in the
// corresponding .cc file, which is not visible here.
void AwaitAndLogIfStuck(absl::Mutex& mutex, const absl::Condition& condition,
                        absl::Duration warn_stuck_timeout,
                        absl::Duration terminate_timeout);
54 
// A rendezvous for a group of threads.
//
// The group of threads identifies itself with a key that must be unique to
// the group. When all threads have arrived at the rendezvous, one thread
// executes the given function with the values supplied by each thread, and
// all threads receive the result.
// TODO(cjfj): Replace XLA rendezvous code with this simpler implementation.
template <typename R, typename K, typename V>
std::shared_ptr<R> RendezvousSingle(
    const K& key, const V& value, size_t num_threads,
    const std::function<R(absl::Span<const V* const>)>& fn,
    absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
    absl::Duration terminate_timeout = absl::InfiniteDuration()) {
  // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread).
  if (num_threads == 1) return std::make_shared<R>(fn({&value}));

  // Per-key shared state. `result` doubles as the completion signal: it stays
  // null until the last arriving thread has executed `fn`.
  struct State {
    absl::Mutex mutex;
    std::vector<const V*> values ABSL_GUARDED_BY(mutex);
    std::shared_ptr<R> result ABSL_GUARDED_BY(mutex);
  };

  // Intentionally leaked (never destroyed), so lookups remain safe even
  // during static destruction at process exit.
  static auto& states = *new ThreadSafeMap<K, State>;
  State& state = states[key];

  absl::MutexLock lock(&state.mutex);
  state.values.push_back(&value);

  std::shared_ptr<R> result;
  if (state.values.size() == num_threads) {
    // Last thread to arrive executes the function.
    CHECK(state.result == nullptr);
    result = std::make_shared<R>(fn(state.values));
    state.result = result;
    state.values.clear();
  } else {
    // Earlier arrivals block here until the executing thread publishes
    // `state.result`.
    absl::Condition result_ready(
        +[](std::shared_ptr<R>* ptr) { return ptr->get() != nullptr; },
        &state.result);
    AwaitAndLogIfStuck(state.mutex, result_ready, warn_stuck_timeout,
                       terminate_timeout);

    // There is one use of the result in the shared state, plus one use for each
    // thread that has already retrieved the result.
    // NOTE(review): use_count() is `long`, num_threads is `size_t` — a
    // signed/unsigned comparison; benign for realistic thread counts.
    if (state.result.use_count() < num_threads) {
      result = state.result;
    } else {
      // Last thread to retrieve the result takes the result from the state,
      // allowing the other threads to exit the function.
      return std::move(state.result);
    }
  }

  // Wait for all threads to have retrieved the result. Without this, a thread
  // could duplicate or delete its copy of the result, invalidating the use
  // count logic above.
  absl::Condition result_taken(
      +[](std::shared_ptr<R>* ptr) { return ptr->get() == nullptr; },
      &state.result);
  AwaitAndLogIfStuck(state.mutex, result_taken, warn_stuck_timeout,
                     terminate_timeout);
  return result;
}
118 
119 // A rendezvous for a group of threads.
120 //
121 // The group of threads identifies itself with a key that must be unique to the
122 // the group. When all threads have arrived at the rendezvous, one thread
123 // executes the given function and all threads receive the result.
124 // TODO(cjfj): Replace XLA rendezvous code with this simpler implementation.
125 template <typename R, typename K>
126 std::shared_ptr<R> RendezvousSingle(
127     const K& key, size_t num_threads, const std::function<R()>& fn,
128     absl::Duration warn_stuck_timeout = absl::InfiniteDuration(),
129     absl::Duration terminate_timeout = absl::InfiniteDuration()) {
130   // Pass an arbitrary value that is ignored.
131   return RendezvousSingle<R, K, int>(
132       key, 0, num_threads, [fn](absl::Span<const int* const>) { return fn(); },
133       warn_stuck_timeout, terminate_timeout);
134 }
135 
136 }  // namespace xla
137 
138 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_RENDEZVOUS_H_
139