1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #ifndef GRPC_SRC_CPP_COMMON_CHANNEL_FILTER_H
20 #define GRPC_SRC_CPP_COMMON_CHANNEL_FILTER_H
21 
22 #include <stddef.h>
23 
24 #include <functional>
25 #include <new>
26 #include <string>
27 #include <utility>
28 
29 #include "absl/status/status.h"
30 #include "absl/types/optional.h"
31 
32 #include <grpc/grpc.h>
33 #include <grpcpp/support/config.h>
34 
35 #include "src/core/lib/channel/channel_args.h"
36 #include "src/core/lib/channel/channel_fwd.h"
37 #include "src/core/lib/channel/channel_stack.h"
38 #include "src/core/lib/channel/context.h"
39 #include "src/core/lib/iomgr/closure.h"
40 #include "src/core/lib/iomgr/error.h"
41 #include "src/core/lib/iomgr/polling_entity.h"
42 #include "src/core/lib/slice/slice_buffer.h"
43 #include "src/core/lib/surface/channel_stack_type.h"
44 #include "src/core/lib/transport/metadata_batch.h"
45 #include "src/core/lib/transport/transport.h"
46 
/// An interface for defining channel filters in C++.
///
/// To define a filter, implement one subclass of \c CallData and one
/// subclass of \c ChannelData, then register the pair using something like:
/// \code{.cpp}
///   RegisterChannelFilter<MyChannelDataSubclass, MyCallDataSubclass>(
///       "name-of-filter", GRPC_SERVER_CHANNEL, INT_MAX, nullptr);
/// \endcode
/// Passing \c nullptr as the last argument adds the filter unconditionally.
55 
56 namespace grpc {
57 
58 /// A C++ wrapper for the \c grpc_metadata_batch struct.
59 class MetadataBatch {
60  public:
61   /// Borrows a pointer to \a batch, but does NOT take ownership.
62   /// The caller must ensure that \a batch continues to exist for as
63   /// long as the MetadataBatch object does.
MetadataBatch(grpc_metadata_batch * batch)64   explicit MetadataBatch(grpc_metadata_batch* batch) : batch_(batch) {}
65 
batch()66   grpc_metadata_batch* batch() const { return batch_; }
67 
68   /// Adds metadata.
69   void AddMetadata(const string& key, const string& value);
70 
71  private:
72   grpc_metadata_batch* batch_;  // Not owned.
73 };
74 
75 /// A C++ wrapper for the \c grpc_transport_op struct.
76 class TransportOp {
77  public:
78   /// Borrows a pointer to \a op, but does NOT take ownership.
79   /// The caller must ensure that \a op continues to exist for as
80   /// long as the TransportOp object does.
TransportOp(grpc_transport_op * op)81   explicit TransportOp(grpc_transport_op* op) : op_(op) {}
82 
op()83   grpc_transport_op* op() const { return op_; }
84 
85   // TODO(roth): Add a C++ wrapper for grpc_error?
disconnect_with_error()86   grpc_error_handle disconnect_with_error() const {
87     return op_->disconnect_with_error;
88   }
send_goaway()89   bool send_goaway() const { return !op_->goaway_error.ok(); }
90 
91   // TODO(roth): Add methods for additional fields as needed.
92 
93  private:
94   grpc_transport_op* op_;  // Not owned.
95 };
96 
97 /// A C++ wrapper for the \c grpc_transport_stream_op_batch struct.
98 class TransportStreamOpBatch {
99  public:
100   /// Borrows a pointer to \a op, but does NOT take ownership.
101   /// The caller must ensure that \a op continues to exist for as
102   /// long as the TransportStreamOpBatch object does.
TransportStreamOpBatch(grpc_transport_stream_op_batch * op)103   explicit TransportStreamOpBatch(grpc_transport_stream_op_batch* op)
104       : op_(op),
105         send_initial_metadata_(
106             op->send_initial_metadata
107                 ? op->payload->send_initial_metadata.send_initial_metadata
108                 : nullptr),
109         send_trailing_metadata_(
110             op->send_trailing_metadata
111                 ? op->payload->send_trailing_metadata.send_trailing_metadata
112                 : nullptr),
113         recv_initial_metadata_(
114             op->recv_initial_metadata
115                 ? op->payload->recv_initial_metadata.recv_initial_metadata
116                 : nullptr),
117         recv_trailing_metadata_(
118             op->recv_trailing_metadata
119                 ? op->payload->recv_trailing_metadata.recv_trailing_metadata
120                 : nullptr) {}
121 
op()122   grpc_transport_stream_op_batch* op() const { return op_; }
123 
on_complete()124   grpc_closure* on_complete() const { return op_->on_complete; }
set_on_complete(grpc_closure * closure)125   void set_on_complete(grpc_closure* closure) { op_->on_complete = closure; }
126 
send_initial_metadata()127   MetadataBatch* send_initial_metadata() {
128     return op_->send_initial_metadata ? &send_initial_metadata_ : nullptr;
129   }
send_trailing_metadata()130   MetadataBatch* send_trailing_metadata() {
131     return op_->send_trailing_metadata ? &send_trailing_metadata_ : nullptr;
132   }
recv_initial_metadata()133   MetadataBatch* recv_initial_metadata() {
134     return op_->recv_initial_metadata ? &recv_initial_metadata_ : nullptr;
135   }
recv_trailing_metadata()136   MetadataBatch* recv_trailing_metadata() {
137     return op_->recv_trailing_metadata ? &recv_trailing_metadata_ : nullptr;
138   }
139 
recv_initial_metadata_ready()140   grpc_closure* recv_initial_metadata_ready() const {
141     return op_->recv_initial_metadata
142                ? op_->payload->recv_initial_metadata.recv_initial_metadata_ready
143                : nullptr;
144   }
set_recv_initial_metadata_ready(grpc_closure * closure)145   void set_recv_initial_metadata_ready(grpc_closure* closure) {
146     op_->payload->recv_initial_metadata.recv_initial_metadata_ready = closure;
147   }
148 
send_message()149   grpc_core::SliceBuffer* send_message() const {
150     return op_->send_message ? op_->payload->send_message.send_message
151                              : nullptr;
152   }
153 
set_send_message(grpc_core::SliceBuffer * send_message)154   void set_send_message(grpc_core::SliceBuffer* send_message) {
155     op_->send_message = true;
156     op_->payload->send_message.send_message = send_message;
157   }
158 
recv_message()159   absl::optional<grpc_core::SliceBuffer>* recv_message() const {
160     return op_->recv_message ? op_->payload->recv_message.recv_message
161                              : nullptr;
162   }
set_recv_message(absl::optional<grpc_core::SliceBuffer> * recv_message)163   void set_recv_message(absl::optional<grpc_core::SliceBuffer>* recv_message) {
164     op_->recv_message = true;
165     op_->payload->recv_message.recv_message = recv_message;
166   }
167 
get_census_context()168   census_context* get_census_context() const {
169     return static_cast<census_context*>(
170         op_->payload->context[GRPC_CONTEXT_TRACING].value);
171   }
172 
173  private:
174   grpc_transport_stream_op_batch* op_;  // Not owned.
175   MetadataBatch send_initial_metadata_;
176   MetadataBatch send_trailing_metadata_;
177   MetadataBatch recv_initial_metadata_;
178   MetadataBatch recv_trailing_metadata_;
179 };
180 
181 /// Represents channel data.
182 class ChannelData {
183  public:
ChannelData()184   ChannelData() {}
~ChannelData()185   virtual ~ChannelData() {}
186 
187   // TODO(roth): Come up with a more C++-like API for the channel element.
188 
189   /// Initializes the channel data.
Init(grpc_channel_element *,grpc_channel_element_args *)190   virtual grpc_error_handle Init(grpc_channel_element* /*elem*/,
191                                  grpc_channel_element_args* /*args*/) {
192     return absl::OkStatus();
193   }
194 
195   // Called before destruction.
Destroy(grpc_channel_element *)196   virtual void Destroy(grpc_channel_element* /*elem*/) {}
197 
198   virtual void StartTransportOp(grpc_channel_element* elem, TransportOp* op);
199 
200   virtual void GetInfo(grpc_channel_element* elem,
201                        const grpc_channel_info* channel_info);
202 };
203 
204 /// Represents call data.
205 class CallData {
206  public:
CallData()207   CallData() {}
~CallData()208   virtual ~CallData() {}
209 
210   // TODO(roth): Come up with a more C++-like API for the call element.
211 
212   /// Initializes the call data.
Init(grpc_call_element *,const grpc_call_element_args *)213   virtual grpc_error_handle Init(grpc_call_element* /*elem*/,
214                                  const grpc_call_element_args* /*args*/) {
215     return absl::OkStatus();
216   }
217 
218   // Called before destruction.
Destroy(grpc_call_element *,const grpc_call_final_info *,grpc_closure *)219   virtual void Destroy(grpc_call_element* /*elem*/,
220                        const grpc_call_final_info* /*final_info*/,
221                        grpc_closure* /*then_call_closure*/) {}
222 
223   /// Starts a new stream operation.
224   virtual void StartTransportStreamOpBatch(grpc_call_element* elem,
225                                            TransportStreamOpBatch* op);
226 
227   /// Sets a pollset or pollset set.
228   virtual void SetPollsetOrPollsetSet(grpc_call_element* elem,
229                                       grpc_polling_entity* pollent);
230 };
231 
232 namespace internal {
233 
234 // Defines static members for passing to C core.
235 // Members of this class correspond to the members of the C
236 // grpc_channel_filter struct.
237 template <typename ChannelDataType, typename CallDataType>
238 class ChannelFilter final {
239  public:
240   static const size_t channel_data_size = sizeof(ChannelDataType);
241 
InitChannelElement(grpc_channel_element * elem,grpc_channel_element_args * args)242   static grpc_error_handle InitChannelElement(grpc_channel_element* elem,
243                                               grpc_channel_element_args* args) {
244     // Construct the object in the already-allocated memory.
245     ChannelDataType* channel_data = new (elem->channel_data) ChannelDataType();
246     return channel_data->Init(elem, args);
247   }
248 
DestroyChannelElement(grpc_channel_element * elem)249   static void DestroyChannelElement(grpc_channel_element* elem) {
250     ChannelDataType* channel_data =
251         static_cast<ChannelDataType*>(elem->channel_data);
252     channel_data->Destroy(elem);
253     channel_data->~ChannelDataType();
254   }
255 
StartTransportOp(grpc_channel_element * elem,grpc_transport_op * op)256   static void StartTransportOp(grpc_channel_element* elem,
257                                grpc_transport_op* op) {
258     ChannelDataType* channel_data =
259         static_cast<ChannelDataType*>(elem->channel_data);
260     TransportOp op_wrapper(op);
261     channel_data->StartTransportOp(elem, &op_wrapper);
262   }
263 
GetChannelInfo(grpc_channel_element * elem,const grpc_channel_info * channel_info)264   static void GetChannelInfo(grpc_channel_element* elem,
265                              const grpc_channel_info* channel_info) {
266     ChannelDataType* channel_data =
267         static_cast<ChannelDataType*>(elem->channel_data);
268     channel_data->GetInfo(elem, channel_info);
269   }
270 
271   static const size_t call_data_size = sizeof(CallDataType);
272 
InitCallElement(grpc_call_element * elem,const grpc_call_element_args * args)273   static grpc_error_handle InitCallElement(grpc_call_element* elem,
274                                            const grpc_call_element_args* args) {
275     // Construct the object in the already-allocated memory.
276     CallDataType* call_data = new (elem->call_data) CallDataType();
277     return call_data->Init(elem, args);
278   }
279 
DestroyCallElement(grpc_call_element * elem,const grpc_call_final_info * final_info,grpc_closure * then_call_closure)280   static void DestroyCallElement(grpc_call_element* elem,
281                                  const grpc_call_final_info* final_info,
282                                  grpc_closure* then_call_closure) {
283     CallDataType* call_data = static_cast<CallDataType*>(elem->call_data);
284     call_data->Destroy(elem, final_info, then_call_closure);
285     call_data->~CallDataType();
286   }
287 
StartTransportStreamOpBatch(grpc_call_element * elem,grpc_transport_stream_op_batch * op)288   static void StartTransportStreamOpBatch(grpc_call_element* elem,
289                                           grpc_transport_stream_op_batch* op) {
290     CallDataType* call_data = static_cast<CallDataType*>(elem->call_data);
291     TransportStreamOpBatch op_wrapper(op);
292     call_data->StartTransportStreamOpBatch(elem, &op_wrapper);
293   }
294 
SetPollsetOrPollsetSet(grpc_call_element * elem,grpc_polling_entity * pollent)295   static void SetPollsetOrPollsetSet(grpc_call_element* elem,
296                                      grpc_polling_entity* pollent) {
297     CallDataType* call_data = static_cast<CallDataType*>(elem->call_data);
298     call_data->SetPollsetOrPollsetSet(elem, pollent);
299   }
300 };
301 
// Non-template registration entry point used by the public
// grpc::RegisterChannelFilter<ChannelDataType, CallDataType>() template.
// Registers the type-erased \a filter for channels of \a stack_type at
// \a priority. \a include_filter may be empty; when non-empty it is
// presumably consulted with the channel's args to decide whether to add the
// filter (confirm against the out-of-line definition). The template passes a
// pointer to a function-local static, so \a filter is expected to outlive
// the registration.
void RegisterChannelFilter(
    grpc_channel_stack_type stack_type, int priority,
    std::function<bool(const grpc_core::ChannelArgs&)> include_filter,
    const grpc_channel_filter* filter);
306 
307 }  // namespace internal
308 
309 /// Registers a new filter.
310 /// Must be called by only one thread at a time.
311 /// The \a include_filter argument specifies a function that will be called
312 /// to determine at run-time whether or not to add the filter. If the
313 /// value is nullptr, the filter will be added unconditionally.
314 /// If the channel stack type is GRPC_CLIENT_SUBCHANNEL, the caller should
315 /// ensure that subchannels with different filter lists will always have
316 /// different channel args. This requires setting a channel arg in case the
317 /// registration function relies on some condition other than channel args to
318 /// decide whether to add a filter or not.
319 template <typename ChannelDataType, typename CallDataType>
RegisterChannelFilter(const char * name,grpc_channel_stack_type stack_type,int priority,std::function<bool (const grpc_core::ChannelArgs &)> include_filter)320 void RegisterChannelFilter(
321     const char* name, grpc_channel_stack_type stack_type, int priority,
322     std::function<bool(const grpc_core::ChannelArgs&)> include_filter) {
323   using FilterType = internal::ChannelFilter<ChannelDataType, CallDataType>;
324   static const grpc_channel_filter filter = {
325       FilterType::StartTransportStreamOpBatch,
326       nullptr,
327       FilterType::StartTransportOp,
328       FilterType::call_data_size,
329       FilterType::InitCallElement,
330       FilterType::SetPollsetOrPollsetSet,
331       FilterType::DestroyCallElement,
332       FilterType::channel_data_size,
333       FilterType::InitChannelElement,
334       grpc_channel_stack_no_post_init,
335       FilterType::DestroyChannelElement,
336       FilterType::GetChannelInfo,
337       name};
338   grpc::internal::RegisterChannelFilter(stack_type, priority,
339                                         std::move(include_filter), &filter);
340 }
341 
342 }  // namespace grpc
343 
344 #endif  // GRPC_SRC_CPP_COMMON_CHANNEL_FILTER_H
345