1 // Copyright 2021 gRPC authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H 16 #define GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H 17 18 #include <grpc/support/port_platform.h> 19 20 #include <assert.h> 21 #include <stddef.h> 22 23 #include <array> 24 #include <tuple> 25 #include <type_traits> 26 #include <utility> 27 28 #include "absl/utility/utility.h" 29 30 #include "src/core/lib/gprpp/bitset.h" 31 #include "src/core/lib/gprpp/construct_destruct.h" 32 #include "src/core/lib/promise/detail/promise_like.h" 33 #include "src/core/lib/promise/poll.h" 34 35 namespace grpc_core { 36 namespace promise_detail { 37 38 // This union can either be a functor, or the result of the functor (after 39 // mapping via a trait). Allows us to remember the result of one joined functor 40 // until the rest are ready. 41 template <typename Traits, typename F> 42 union Fused { Fused(F && f)43 explicit Fused(F&& f) : f(std::forward<F>(f)) {} Fused(PromiseLike<F> && f)44 explicit Fused(PromiseLike<F>&& f) : f(std::forward<PromiseLike<F>>(f)) {} ~Fused()45 ~Fused() {} 46 // Wrap the functor in a PromiseLike to handle immediately returning functors 47 // and the like. 48 using Promise = PromiseLike<F>; 49 GPR_NO_UNIQUE_ADDRESS Promise f; 50 // Compute the result type: We take the result of the promise, and pass it via 51 // our traits, so that, for example, TryJoin and take a StatusOr<T> and just 52 // store a T. 53 using Result = typename Traits::template ResultType<typename Promise::Result>; 54 GPR_NO_UNIQUE_ADDRESS Result result; 55 }; 56 57 // A join gets composed of joints... these are just wrappers around a Fused for 58 // their data, with some machinery as methods to get the system working. 59 template <typename Traits, size_t kRemaining, typename... Fs> 60 struct Joint : public Joint<Traits, kRemaining - 1, Fs...> { 61 // The index into Fs for this Joint 62 static constexpr size_t kIdx = sizeof...(Fs) - kRemaining; 63 // The next join (the one we derive from) 64 using NextJoint = Joint<Traits, kRemaining - 1, Fs...>; 65 // From Fs, extract the functor for this joint. 66 using F = typename std::tuple_element<kIdx, std::tuple<Fs...>>::type; 67 // Generate the Fused type for this functor. 68 using Fsd = Fused<Traits, F>; 69 GPR_NO_UNIQUE_ADDRESS Fsd fused; 70 // Figure out what kind of bitmask will be used by the outer join. 71 using Bits = BitSet<sizeof...(Fs)>; 72 // Initialize from a tuple of pointers to Fs JointJoint73 explicit Joint(std::tuple<Fs*...> fs) 74 : NextJoint(fs), fused(std::move(*std::get<kIdx>(fs))) {} 75 // Copy: assume that the Fuse is still in the promise state (since it's not 76 // legal to copy after the first poll!) JointJoint77 Joint(const Joint& j) : NextJoint(j), fused(j.fused.f) {} 78 // Move: assume that the Fuse is still in the promise state (since it's not 79 // legal to move after the first poll!) JointJoint80 Joint(Joint&& j) noexcept 81 : NextJoint(std::forward<NextJoint>(j)), fused(std::move(j.fused.f)) {} 82 // Destruct: check bits to see if we're in promise or result state, and call 83 // the appropriate destructor. Recursively, call up through the join. DestructAllJoint84 void DestructAll(const Bits& bits) { 85 if (!bits.is_set(kIdx)) { 86 Destruct(&fused.f); 87 } else { 88 Destruct(&fused.result); 89 } 90 NextJoint::DestructAll(bits); 91 } 92 // Poll all joints up, and then call finally. 93 template <typename F> 94 auto Run(Bits* bits, F finally) -> decltype(finally()) { 95 // If we're still in the promise state... 96 if (!bits->is_set(kIdx)) { 97 // Poll the promise 98 auto r = fused.f(); 99 if (auto* p = r.value_if_ready()) { 100 // If it's done, then ask the trait to unwrap it and store that result 101 // in the Fused, and continue the iteration. Note that OnResult could 102 // instead choose to return a value instead of recursing through the 103 // iteration, in that case we continue returning the same result up. 104 // Here is where TryJoin can escape out. 105 return Traits::OnResult( 106 std::move(*p), [this, bits, &finally](typename Fsd::Result result) { 107 bits->set(kIdx); 108 Destruct(&fused.f); 109 Construct(&fused.result, std::move(result)); 110 return NextJoint::Run(bits, std::move(finally)); 111 }); 112 } 113 } 114 // That joint is still pending... we'll still poll the result of the joints. 115 return NextJoint::Run(bits, std::move(finally)); 116 } 117 }; 118 119 // Terminating joint... for each of the recursions, do the thing we're supposed 120 // to do at the end. 121 template <typename Traits, typename... Fs> 122 struct Joint<Traits, 0, Fs...> { 123 explicit Joint(std::tuple<Fs*...>) {} 124 Joint(const Joint&) {} 125 Joint(Joint&&) noexcept {} 126 template <typename T> 127 void DestructAll(const T&) {} 128 template <typename F> 129 auto Run(BitSet<sizeof...(Fs)>*, F finally) -> decltype(finally()) { 130 return finally(); 131 } 132 }; 133 134 template <typename Traits, typename... Fs> 135 class BasicJoin { 136 private: 137 // How many things are we joining? 138 static constexpr size_t N = sizeof...(Fs); 139 // Bitset: if a bit is 0, that joint is still in promise state. If it's 1, 140 // then the joint has a result. 141 GPR_NO_UNIQUE_ADDRESS BitSet<N> state_; 142 // The actual joints, wrapped in an anonymous union to give us control of 143 // construction/destruction. 144 union { 145 GPR_NO_UNIQUE_ADDRESS Joint<Traits, sizeof...(Fs), Fs...> joints_; 146 }; 147 148 // Access joint index I 149 template <size_t I> 150 Joint<Traits, sizeof...(Fs) - I, Fs...>* GetJoint() { 151 return static_cast<Joint<Traits, sizeof...(Fs) - I, Fs...>*>(&joints_); 152 } 153 154 // The tuple of results of all our promises 155 using Tuple = std::tuple<typename Fused<Traits, Fs>::Result...>; 156 157 // Collect up all the results and construct a tuple. 158 template <size_t... I> 159 Tuple Finish(absl::index_sequence<I...>) { 160 return Tuple(std::move(GetJoint<I>()->fused.result)...); 161 } 162 163 public: 164 explicit BasicJoin(Fs&&... fs) : joints_(std::tuple<Fs*...>(&fs...)) {} 165 BasicJoin& operator=(const BasicJoin&) = delete; 166 // Copy a join - only available before polling. 167 BasicJoin(const BasicJoin& other) { 168 assert(other.state_.none()); 169 Construct(&joints_, other.joints_); 170 } 171 // Move a join - only available before polling. 172 BasicJoin(BasicJoin&& other) noexcept { 173 assert(other.state_.none()); 174 Construct(&joints_, std::move(other.joints_)); 175 } 176 ~BasicJoin() { joints_.DestructAll(state_); } 177 using Result = decltype(Traits::Wrap(std::declval<Tuple>())); 178 // Poll the join 179 Poll<Result> operator()() { 180 // Poll the joints... 181 return joints_.Run(&state_, [this]() -> Poll<Result> { 182 // If all of them are completed, collect the results, and then ask our 183 // traits to wrap them - allowing for example TryJoin to turn tuple<A,B,C> 184 // into StatusOr<tuple<A,B,C>>. 185 if (state_.all()) { 186 return Traits::Wrap(Finish(absl::make_index_sequence<N>())); 187 } else { 188 return Pending(); 189 } 190 }); 191 } 192 }; 193 194 } // namespace promise_detail 195 } // namespace grpc_core 196 197 #endif // GRPC_SRC_CORE_LIB_PROMISE_DETAIL_BASIC_JOIN_H 198