1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RENDEZVOUS_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_RENDEZVOUS_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/synchronization/mutex.h" 26 #include "absl/time/time.h" 27 #include "absl/types/span.h" 28 29 namespace xla { 30 31 template <typename K, typename V> 32 class ThreadSafeMap { 33 public: 34 V& operator[](const K& key) { 35 absl::MutexLock lock(&mutex_); 36 std::unique_ptr<V>& value = map_[key]; 37 if (value == nullptr) value = std::make_unique<V>(); 38 return *value; 39 } 40 ForEachValue(const std::function<void (V &)> & fn)41 void ForEachValue(const std::function<void(V&)>& fn) { 42 absl::MutexLock lock(&mutex_); 43 for (const auto& [_, value] : map_) fn(*value); 44 } 45 46 private: 47 absl::Mutex mutex_; 48 absl::flat_hash_map<K, std::unique_ptr<V>> map_ ABSL_GUARDED_BY(mutex_); 49 }; 50 51 void AwaitAndLogIfStuck(absl::Mutex& mutex, const absl::Condition& condition, 52 absl::Duration warn_stuck_timeout, 53 absl::Duration terminate_timeout); 54 55 // A rendezvous for a group of threads. 56 // 57 // The group of threads identifies itself with a key that must be unique to the 58 // the group. When all threads have arrived at the rendezvous, one thread 59 // executes the given function with the values supplied by each thread, and all 60 // threads receive the result. 61 // TODO(cjfj): Replace XLA rendezvous code with this simpler implementation. 62 template <typename R, typename K, typename V> 63 std::shared_ptr<R> RendezvousSingle( 64 const K& key, const V& value, size_t num_threads, 65 const std::function<R(absl::Span<const V* const>)>& fn, 66 absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), 67 absl::Duration terminate_timeout = absl::InfiniteDuration()) { 68 // Fast-path (DO NOT REMOVE: the logic below doesn't work for single thread). 69 if (num_threads == 1) return std::make_shared<R>(fn({&value})); 70 71 struct State { 72 absl::Mutex mutex; 73 std::vector<const V*> values ABSL_GUARDED_BY(mutex); 74 std::shared_ptr<R> result ABSL_GUARDED_BY(mutex); 75 }; 76 77 static auto& states = *new ThreadSafeMap<K, State>; 78 State& state = states[key]; 79 80 absl::MutexLock lock(&state.mutex); 81 state.values.push_back(&value); 82 83 std::shared_ptr<R> result; 84 if (state.values.size() == num_threads) { 85 // Last thread to arrive executes the function. 86 CHECK(state.result == nullptr); 87 result = std::make_shared<R>(fn(state.values)); 88 state.result = result; 89 state.values.clear(); 90 } else { 91 absl::Condition result_ready( 92 +[](std::shared_ptr<R>* ptr) { return ptr->get() != nullptr; }, 93 &state.result); 94 AwaitAndLogIfStuck(state.mutex, result_ready, warn_stuck_timeout, 95 terminate_timeout); 96 97 // There is one use of the result in the shared state, plus one use for each 98 // thread that has already retrieved the result. 99 if (state.result.use_count() < num_threads) { 100 result = state.result; 101 } else { 102 // Last thread to retrieve the result takes the result from the state, 103 // allowing the other threads to exit the function. 104 return std::move(state.result); 105 } 106 } 107 108 // Wait for all threads to have retrieved the result. Without this, a thread 109 // could duplicate or delete its copy of the result, invalidating the use 110 // count logic above. 111 absl::Condition result_taken( 112 +[](std::shared_ptr<R>* ptr) { return ptr->get() == nullptr; }, 113 &state.result); 114 AwaitAndLogIfStuck(state.mutex, result_taken, warn_stuck_timeout, 115 terminate_timeout); 116 return result; 117 } 118 119 // A rendezvous for a group of threads. 120 // 121 // The group of threads identifies itself with a key that must be unique to the 122 // the group. When all threads have arrived at the rendezvous, one thread 123 // executes the given function and all threads receive the result. 124 // TODO(cjfj): Replace XLA rendezvous code with this simpler implementation. 125 template <typename R, typename K> 126 std::shared_ptr<R> RendezvousSingle( 127 const K& key, size_t num_threads, const std::function<R()>& fn, 128 absl::Duration warn_stuck_timeout = absl::InfiniteDuration(), 129 absl::Duration terminate_timeout = absl::InfiniteDuration()) { 130 // Pass an arbitrary value that is ignored. 131 return RendezvousSingle<R, K, int>( 132 key, 0, num_threads, [fn](absl::Span<const int* const>) { return fn(); }, 133 warn_stuck_timeout, terminate_timeout); 134 } 135 136 } // namespace xla 137 138 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_RENDEZVOUS_H_ 139