1# Copyright 2017, Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Bi-directional streaming RPC helpers."""
16
17import collections
18import datetime
19import logging
20import queue as queue_module
21import threading
22import time
23
24from google.api_core import exceptions
25
26_LOGGER = logging.getLogger(__name__)
27_BIDIRECTIONAL_CONSUMER_NAME = "Thread-ConsumeBidirectionalStream"
28
29
30class _RequestQueueGenerator(object):
31    """A helper for sending requests to a gRPC stream from a Queue.
32
33    This generator takes requests off a given queue and yields them to gRPC.
34
35    This helper is useful when you have an indeterminate, indefinite, or
36    otherwise open-ended set of requests to send through a request-streaming
37    (or bidirectional) RPC.
38
39    The reason this is necessary is because gRPC takes an iterator as the
40    request for request-streaming RPCs. gRPC consumes this iterator in another
41    thread to allow it to block while generating requests for the stream.
42    However, if the generator blocks indefinitely gRPC will not be able to
43    clean up the thread as it'll be blocked on `next(iterator)` and not be able
44    to check the channel status to stop iterating. This helper mitigates that
45    by waiting on the queue with a timeout and checking the RPC state before
46    yielding.
47
48    Finally, it allows for retrying without swapping queues because if it does
49    pull an item off the queue when the RPC is inactive, it'll immediately put
50    it back and then exit. This is necessary because yielding the item in this
51    case will cause gRPC to discard it. In practice, this means that the order
52    of messages is not guaranteed. If such a thing is necessary it would be
53    easy to use a priority queue.
54
55    Example::
56
57        requests = request_queue_generator(q)
58        call = stub.StreamingRequest(iter(requests))
59        requests.call = call
60
61        for response in call:
62            print(response)
63            q.put(...)
64
65    Note that it is possible to accomplish this behavior without "spinning"
66    (using a queue timeout). One possible way would be to use more threads to
67    multiplex the grpc end event with the queue, another possible way is to
68    use selectors and a custom event/queue object. Both of these approaches
69    are significant from an engineering perspective for small benefit - the
70    CPU consumed by spinning is pretty minuscule.
71
72    Args:
73        queue (queue_module.Queue): The request queue.
74        period (float): The number of seconds to wait for items from the queue
75            before checking if the RPC is cancelled. In practice, this
76            determines the maximum amount of time the request consumption
77            thread will live after the RPC is cancelled.
78        initial_request (Union[protobuf.Message,
79                Callable[None, protobuf.Message]]): The initial request to
80            yield. This is done independently of the request queue to allow fo
81            easily restarting streams that require some initial configuration
82            request.
83    """
84
85    def __init__(self, queue, period=1, initial_request=None):
86        self._queue = queue
87        self._period = period
88        self._initial_request = initial_request
89        self.call = None
90
91    def _is_active(self):
92        # Note: there is a possibility that this starts *before* the call
93        # property is set. So we have to check if self.call is set before
94        # seeing if it's active.
95        if self.call is not None and not self.call.is_active():
96            return False
97        else:
98            return True
99
100    def __iter__(self):
101        if self._initial_request is not None:
102            if callable(self._initial_request):
103                yield self._initial_request()
104            else:
105                yield self._initial_request
106
107        while True:
108            try:
109                item = self._queue.get(timeout=self._period)
110            except queue_module.Empty:
111                if not self._is_active():
112                    _LOGGER.debug(
113                        "Empty queue and inactive call, exiting request " "generator."
114                    )
115                    return
116                else:
117                    # call is still active, keep waiting for queue items.
118                    continue
119
120            # The consumer explicitly sent "None", indicating that the request
121            # should end.
122            if item is None:
123                _LOGGER.debug("Cleanly exiting request generator.")
124                return
125
126            if not self._is_active():
127                # We have an item, but the call is closed. We should put the
128                # item back on the queue so that the next call can consume it.
129                self._queue.put(item)
130                _LOGGER.debug(
131                    "Inactive call, replacing item on queue and exiting "
132                    "request generator."
133                )
134                return
135
136            yield item
137
138
139class _Throttle(object):
140    """A context manager limiting the total entries in a sliding time window.
141
142    If more than ``access_limit`` attempts are made to enter the context manager
143    instance in the last ``time window`` interval, the exceeding requests block
144    until enough time elapses.
145
146    The context manager instances are thread-safe and can be shared between
147    multiple threads. If multiple requests are blocked and waiting to enter,
148    the exact order in which they are allowed to proceed is not determined.
149
150    Example::
151
152        max_three_per_second = _Throttle(
153            access_limit=3, time_window=datetime.timedelta(seconds=1)
154        )
155
156        for i in range(5):
157            with max_three_per_second as time_waited:
158                print("{}: Waited {} seconds to enter".format(i, time_waited))
159
160    Args:
161        access_limit (int): the maximum number of entries allowed in the time window
162        time_window (datetime.timedelta): the width of the sliding time window
163    """
164
165    def __init__(self, access_limit, time_window):
166        if access_limit < 1:
167            raise ValueError("access_limit argument must be positive")
168
169        if time_window <= datetime.timedelta(0):
170            raise ValueError("time_window argument must be a positive timedelta")
171
172        self._time_window = time_window
173        self._access_limit = access_limit
174        self._past_entries = collections.deque(
175            maxlen=access_limit
176        )  # least recent first
177        self._entry_lock = threading.Lock()
178
179    def __enter__(self):
180        with self._entry_lock:
181            cutoff_time = datetime.datetime.now() - self._time_window
182
183            # drop the entries that are too old, as they are no longer relevant
184            while self._past_entries and self._past_entries[0] < cutoff_time:
185                self._past_entries.popleft()
186
187            if len(self._past_entries) < self._access_limit:
188                self._past_entries.append(datetime.datetime.now())
189                return 0.0  # no waiting was needed
190
191            to_wait = (self._past_entries[0] - cutoff_time).total_seconds()
192            time.sleep(to_wait)
193
194            self._past_entries.append(datetime.datetime.now())
195            return to_wait
196
197    def __exit__(self, *_):
198        pass
199
200    def __repr__(self):
201        return "{}(access_limit={}, time_window={})".format(
202            self.__class__.__name__, self._access_limit, repr(self._time_window)
203        )
204
205
206class BidiRpc(object):
207    """A helper for consuming a bi-directional streaming RPC.
208
209    This maps gRPC's built-in interface which uses a request iterator and a
210    response iterator into a socket-like :func:`send` and :func:`recv`. This
211    is a more useful pattern for long-running or asymmetric streams (streams
212    where there is not a direct correlation between the requests and
213    responses).
214
215    Example::
216
217        initial_request = example_pb2.StreamingRpcRequest(
218            setting='example')
219        rpc = BidiRpc(
220            stub.StreamingRpc,
221            initial_request=initial_request,
222            metadata=[('name', 'value')]
223        )
224
225        rpc.open()
226
227        while rpc.is_active():
228            print(rpc.recv())
229            rpc.send(example_pb2.StreamingRpcRequest(
230                data='example'))
231
232    This does *not* retry the stream on errors. See :class:`ResumableBidiRpc`.
233
234    Args:
235        start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
236            start the RPC.
237        initial_request (Union[protobuf.Message,
238                Callable[None, protobuf.Message]]): The initial request to
239            yield. This is useful if an initial request is needed to start the
240            stream.
241        metadata (Sequence[Tuple(str, str)]): RPC metadata to include in
242            the request.
243    """
244
245    def __init__(self, start_rpc, initial_request=None, metadata=None):
246        self._start_rpc = start_rpc
247        self._initial_request = initial_request
248        self._rpc_metadata = metadata
249        self._request_queue = queue_module.Queue()
250        self._request_generator = None
251        self._is_active = False
252        self._callbacks = []
253        self.call = None
254
255    def add_done_callback(self, callback):
256        """Adds a callback that will be called when the RPC terminates.
257
258        This occurs when the RPC errors or is successfully terminated.
259
260        Args:
261            callback (Callable[[grpc.Future], None]): The callback to execute.
262                It will be provided with the same gRPC future as the underlying
263                stream which will also be a :class:`grpc.Call`.
264        """
265        self._callbacks.append(callback)
266
267    def _on_call_done(self, future):
268        for callback in self._callbacks:
269            callback(future)
270
271    def open(self):
272        """Opens the stream."""
273        if self.is_active:
274            raise ValueError("Can not open an already open stream.")
275
276        request_generator = _RequestQueueGenerator(
277            self._request_queue, initial_request=self._initial_request
278        )
279        call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata)
280
281        request_generator.call = call
282
283        # TODO: api_core should expose the future interface for wrapped
284        # callables as well.
285        if hasattr(call, "_wrapped"):  # pragma: NO COVER
286            call._wrapped.add_done_callback(self._on_call_done)
287        else:
288            call.add_done_callback(self._on_call_done)
289
290        self._request_generator = request_generator
291        self.call = call
292
293    def close(self):
294        """Closes the stream."""
295        if self.call is None:
296            return
297
298        self._request_queue.put(None)
299        self.call.cancel()
300        self._request_generator = None
301        # Don't set self.call to None. Keep it around so that send/recv can
302        # raise the error.
303
304    def send(self, request):
305        """Queue a message to be sent on the stream.
306
307        Send is non-blocking.
308
309        If the underlying RPC has been closed, this will raise.
310
311        Args:
312            request (protobuf.Message): The request to send.
313        """
314        if self.call is None:
315            raise ValueError("Can not send() on an RPC that has never been open()ed.")
316
317        # Don't use self.is_active(), as ResumableBidiRpc will overload it
318        # to mean something semantically different.
319        if self.call.is_active():
320            self._request_queue.put(request)
321        else:
322            # calling next should cause the call to raise.
323            next(self.call)
324
325    def recv(self):
326        """Wait for a message to be returned from the stream.
327
328        Recv is blocking.
329
330        If the underlying RPC has been closed, this will raise.
331
332        Returns:
333            protobuf.Message: The received message.
334        """
335        if self.call is None:
336            raise ValueError("Can not recv() on an RPC that has never been open()ed.")
337
338        return next(self.call)
339
340    @property
341    def is_active(self):
342        """bool: True if this stream is currently open and active."""
343        return self.call is not None and self.call.is_active()
344
345    @property
346    def pending_requests(self):
347        """int: Returns an estimate of the number of queued requests."""
348        return self._request_queue.qsize()
349
350
351def _never_terminate(future_or_error):
352    """By default, no errors cause BiDi termination."""
353    return False
354
355
356class ResumableBidiRpc(BidiRpc):
357    """A :class:`BidiRpc` that can automatically resume the stream on errors.
358
359    It uses the ``should_recover`` arg to determine if it should re-establish
360    the stream on error.
361
362    Example::
363
364        def should_recover(exc):
365            return (
366                isinstance(exc, grpc.RpcError) and
367                exc.code() == grpc.StatusCode.UNVAILABLE)
368
369        initial_request = example_pb2.StreamingRpcRequest(
370            setting='example')
371
372        metadata = [('header_name', 'value')]
373
374        rpc = ResumableBidiRpc(
375            stub.StreamingRpc,
376            should_recover=should_recover,
377            initial_request=initial_request,
378            metadata=metadata
379        )
380
381        rpc.open()
382
383        while rpc.is_active():
384            print(rpc.recv())
385            rpc.send(example_pb2.StreamingRpcRequest(
386                data='example'))
387
388    Args:
389        start_rpc (grpc.StreamStreamMultiCallable): The gRPC method used to
390            start the RPC.
391        initial_request (Union[protobuf.Message,
392                Callable[None, protobuf.Message]]): The initial request to
393            yield. This is useful if an initial request is needed to start the
394            stream.
395        should_recover (Callable[[Exception], bool]): A function that returns
396            True if the stream should be recovered. This will be called
397            whenever an error is encountered on the stream.
398        should_terminate (Callable[[Exception], bool]): A function that returns
399            True if the stream should be terminated. This will be called
400            whenever an error is encountered on the stream.
401        metadata Sequence[Tuple(str, str)]: RPC metadata to include in
402            the request.
403        throttle_reopen (bool): If ``True``, throttling will be applied to
404            stream reopen calls. Defaults to ``False``.
405    """
406
407    def __init__(
408        self,
409        start_rpc,
410        should_recover,
411        should_terminate=_never_terminate,
412        initial_request=None,
413        metadata=None,
414        throttle_reopen=False,
415    ):
416        super(ResumableBidiRpc, self).__init__(start_rpc, initial_request, metadata)
417        self._should_recover = should_recover
418        self._should_terminate = should_terminate
419        self._operational_lock = threading.RLock()
420        self._finalized = False
421        self._finalize_lock = threading.Lock()
422
423        if throttle_reopen:
424            self._reopen_throttle = _Throttle(
425                access_limit=5, time_window=datetime.timedelta(seconds=10)
426            )
427        else:
428            self._reopen_throttle = None
429
430    def _finalize(self, result):
431        with self._finalize_lock:
432            if self._finalized:
433                return
434
435            for callback in self._callbacks:
436                callback(result)
437
438            self._finalized = True
439
440    def _on_call_done(self, future):
441        # Unlike the base class, we only execute the callbacks on a terminal
442        # error, not for errors that we can recover from. Note that grpc's
443        # "future" here is also a grpc.RpcError.
444        with self._operational_lock:
445            if self._should_terminate(future):
446                self._finalize(future)
447            elif not self._should_recover(future):
448                self._finalize(future)
449            else:
450                _LOGGER.debug("Re-opening stream from gRPC callback.")
451                self._reopen()
452
453    def _reopen(self):
454        with self._operational_lock:
455            # Another thread already managed to re-open this stream.
456            if self.call is not None and self.call.is_active():
457                _LOGGER.debug("Stream was already re-established.")
458                return
459
460            self.call = None
461            # Request generator should exit cleanly since the RPC its bound to
462            # has exited.
463            self._request_generator = None
464
465            # Note: we do not currently do any sort of backoff here. The
466            # assumption is that re-establishing the stream under normal
467            # circumstances will happen in intervals greater than 60s.
468            # However, it is possible in a degenerative case that the server
469            # closes the stream rapidly which would lead to thrashing here,
470            # but hopefully in those cases the server would return a non-
471            # retryable error.
472
473            try:
474                if self._reopen_throttle:
475                    with self._reopen_throttle:
476                        self.open()
477                else:
478                    self.open()
479            # If re-opening or re-calling the method fails for any reason,
480            # consider it a terminal error and finalize the stream.
481            except Exception as exc:
482                _LOGGER.debug("Failed to re-open stream due to %s", exc)
483                self._finalize(exc)
484                raise
485
486            _LOGGER.info("Re-established stream")
487
488    def _recoverable(self, method, *args, **kwargs):
489        """Wraps a method to recover the stream and retry on error.
490
491        If a retryable error occurs while making the call, then the stream will
492        be re-opened and the method will be retried. This happens indefinitely
493        so long as the error is a retryable one. If an error occurs while
494        re-opening the stream, then this method will raise immediately and
495        trigger finalization of this object.
496
497        Args:
498            method (Callable[..., Any]): The method to call.
499            args: The args to pass to the method.
500            kwargs: The kwargs to pass to the method.
501        """
502        while True:
503            try:
504                return method(*args, **kwargs)
505
506            except Exception as exc:
507                with self._operational_lock:
508                    _LOGGER.debug("Call to retryable %r caused %s.", method, exc)
509
510                    if self._should_terminate(exc):
511                        self.close()
512                        _LOGGER.debug("Terminating %r due to %s.", method, exc)
513                        self._finalize(exc)
514                        break
515
516                    if not self._should_recover(exc):
517                        self.close()
518                        _LOGGER.debug("Not retrying %r due to %s.", method, exc)
519                        self._finalize(exc)
520                        raise exc
521
522                    _LOGGER.debug("Re-opening stream from retryable %r.", method)
523                    self._reopen()
524
525    def _send(self, request):
526        # Grab a reference to the RPC call. Because another thread (notably
527        # the gRPC error thread) can modify self.call (by invoking reopen),
528        # we should ensure our reference can not change underneath us.
529        # If self.call is modified (such as replaced with a new RPC call) then
530        # this will use the "old" RPC, which should result in the same
531        # exception passed into gRPC's error handler being raised here, which
532        # will be handled by the usual error handling in retryable.
533        with self._operational_lock:
534            call = self.call
535
536        if call is None:
537            raise ValueError("Can not send() on an RPC that has never been open()ed.")
538
539        # Don't use self.is_active(), as ResumableBidiRpc will overload it
540        # to mean something semantically different.
541        if call.is_active():
542            self._request_queue.put(request)
543            pass
544        else:
545            # calling next should cause the call to raise.
546            next(call)
547
548    def send(self, request):
549        return self._recoverable(self._send, request)
550
551    def _recv(self):
552        with self._operational_lock:
553            call = self.call
554
555        if call is None:
556            raise ValueError("Can not recv() on an RPC that has never been open()ed.")
557
558        return next(call)
559
560    def recv(self):
561        return self._recoverable(self._recv)
562
563    def close(self):
564        self._finalize(None)
565        super(ResumableBidiRpc, self).close()
566
567    @property
568    def is_active(self):
569        """bool: True if this stream is currently open and active."""
570        # Use the operational lock. It's entirely possible for something
571        # to check the active state *while* the RPC is being retried.
572        # Also, use finalized to track the actual terminal state here.
573        # This is because if the stream is re-established by the gRPC thread
574        # it's technically possible to check this between when gRPC marks the
575        # RPC as inactive and when gRPC executes our callback that re-opens
576        # the stream.
577        with self._operational_lock:
578            return self.call is not None and not self._finalized
579
580
581class BackgroundConsumer(object):
582    """A bi-directional stream consumer that runs in a separate thread.
583
584    This maps the consumption of a stream into a callback-based model. It also
585    provides :func:`pause` and :func:`resume` to allow for flow-control.
586
587    Example::
588
589        def should_recover(exc):
590            return (
591                isinstance(exc, grpc.RpcError) and
592                exc.code() == grpc.StatusCode.UNVAILABLE)
593
594        initial_request = example_pb2.StreamingRpcRequest(
595            setting='example')
596
597        rpc = ResumeableBidiRpc(
598            stub.StreamingRpc,
599            initial_request=initial_request,
600            should_recover=should_recover)
601
602        def on_response(response):
603            print(response)
604
605        consumer = BackgroundConsumer(rpc, on_response)
606        consumer.start()
607
608    Note that error handling *must* be done by using the provided
609    ``bidi_rpc``'s ``add_done_callback``. This helper will automatically exit
610    whenever the RPC itself exits and will not provide any error details.
611
612    Args:
613        bidi_rpc (BidiRpc): The RPC to consume. Should not have been
614            ``open()``ed yet.
615        on_response (Callable[[protobuf.Message], None]): The callback to
616            be called for every response on the stream.
617    """
618
619    def __init__(self, bidi_rpc, on_response):
620        self._bidi_rpc = bidi_rpc
621        self._on_response = on_response
622        self._paused = False
623        self._wake = threading.Condition()
624        self._thread = None
625        self._operational_lock = threading.Lock()
626
627    def _on_call_done(self, future):
628        # Resume the thread if it's paused, this prevents blocking forever
629        # when the RPC has terminated.
630        self.resume()
631
632    def _thread_main(self, ready):
633        try:
634            ready.set()
635            self._bidi_rpc.add_done_callback(self._on_call_done)
636            self._bidi_rpc.open()
637
638            while self._bidi_rpc.is_active:
639                # Do not allow the paused status to change at all during this
640                # section. There is a condition where we could be resumed
641                # between checking if we are paused and calling wake.wait(),
642                # which means that we will miss the notification to wake up
643                # (oops!) and wait for a notification that will never come.
644                # Keeping the lock throughout avoids that.
645                # In the future, we could use `Condition.wait_for` if we drop
646                # Python 2.7.
647                # See: https://github.com/googleapis/python-api-core/issues/211
648                with self._wake:
649                    while self._paused:
650                        _LOGGER.debug("paused, waiting for waking.")
651                        self._wake.wait()
652                        _LOGGER.debug("woken.")
653
654                _LOGGER.debug("waiting for recv.")
655                response = self._bidi_rpc.recv()
656                _LOGGER.debug("recved response.")
657                self._on_response(response)
658
659        except exceptions.GoogleAPICallError as exc:
660            _LOGGER.debug(
661                "%s caught error %s and will exit. Generally this is due to "
662                "the RPC itself being cancelled and the error will be "
663                "surfaced to the calling code.",
664                _BIDIRECTIONAL_CONSUMER_NAME,
665                exc,
666                exc_info=True,
667            )
668
669        except Exception as exc:
670            _LOGGER.exception(
671                "%s caught unexpected exception %s and will exit.",
672                _BIDIRECTIONAL_CONSUMER_NAME,
673                exc,
674            )
675
676        _LOGGER.info("%s exiting", _BIDIRECTIONAL_CONSUMER_NAME)
677
678    def start(self):
679        """Start the background thread and begin consuming the thread."""
680        with self._operational_lock:
681            ready = threading.Event()
682            thread = threading.Thread(
683                name=_BIDIRECTIONAL_CONSUMER_NAME,
684                target=self._thread_main,
685                args=(ready,),
686            )
687            thread.daemon = True
688            thread.start()
689            # Other parts of the code rely on `thread.is_alive` which
690            # isn't sufficient to know if a thread is active, just that it may
691            # soon be active. This can cause races. Further protect
692            # against races by using a ready event and wait on it to be set.
693            ready.wait()
694            self._thread = thread
695            _LOGGER.debug("Started helper thread %s", thread.name)
696
697    def stop(self):
698        """Stop consuming the stream and shutdown the background thread."""
699        with self._operational_lock:
700            self._bidi_rpc.close()
701
702            if self._thread is not None:
703                # Resume the thread to wake it up in case it is sleeping.
704                self.resume()
705                # The daemonized thread may itself block, so don't wait
706                # for it longer than a second.
707                self._thread.join(1.0)
708                if self._thread.is_alive():  # pragma: NO COVER
709                    _LOGGER.warning("Background thread did not exit.")
710
711            self._thread = None
712
713    @property
714    def is_active(self):
715        """bool: True if the background thread is active."""
716        return self._thread is not None and self._thread.is_alive()
717
718    def pause(self):
719        """Pauses the response stream.
720
721        This does *not* pause the request stream.
722        """
723        with self._wake:
724            self._paused = True
725
726    def resume(self):
727        """Resumes the response stream."""
728        with self._wake:
729            self._paused = False
730            self._wake.notify_all()
731
732    @property
733    def is_paused(self):
734        """bool: True if the response stream is paused."""
735        return self._paused
736