xref: /aosp_15_r20/external/tensorflow/tensorflow/python/training/queue_runner_impl.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Create threads to run multiple enqueue ops."""
17import threading
18import weakref
19
20from tensorflow.core.protobuf import queue_runner_pb2
21from tensorflow.python.client import session
22from tensorflow.python.eager import context
23from tensorflow.python.framework import errors
24from tensorflow.python.framework import ops
25from tensorflow.python.platform import tf_logging as logging
26from tensorflow.python.util import deprecation
27from tensorflow.python.util.tf_export import tf_export
28
29_DEPRECATION_INSTRUCTION = (
30    "To construct input pipelines, use the `tf.data` module.")
31
32
@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner:
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads reads records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility
  """

  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If `queue_runner_def` is specified together with `queue` or
        `enqueue_ops`.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided but is not a
        non-empty tuple of `tf.errors.OpError` subclasses.
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.  Weak keys, so a session that is otherwise
    # unreferenced does not stay alive because of this bookkeeping.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
        When `None`, one is created with `queue.close()`.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
        When `None`, one is created with
        `queue.close(cancel_pending_enqueues=True)`.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.  Defaults to
        `(errors.OutOfRangeError,)` when `None`.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    # Validate before creating any close ops, so a bad argument does not add
    # ops to the graph.
    if queue_closed_exception_types is None:
      queue_closed_exception_types = (errors.OutOfRangeError,)
    elif (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(isinstance(t, type) and issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
      # The value is wrapped in a 1-tuple: `"%s" % some_tuple` would unpack
      # the tuple into separate format arguments and fail for any length
      # other than one, masking this message with a formatting error.
      raise TypeError(
          "queue_closed_exception_types, when provided, "
          "must be a tuple of tf.error types, but saw: %s"
          % (queue_closed_exception_types,))
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._queue_closed_exception_types = tuple(queue_closed_exception_types)
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    self._close_op = close_op
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    self._cancel_op = cancel_op
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Looks up the queue, enqueue ops, close op and cancel op by name in the
    current default graph, after prepending `import_scope` to each name.

    Args:
      queue_runner_def: A `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    # The proto stores error codes; map them back to exception classes.
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    """The queue this runner was built for."""
    return self._queue

  @property
  def enqueue_ops(self):
    """The enqueue ops; `create_threads()` makes one thread per op."""
    return self._enqueue_ops

  @property
  def close_op(self):
    """Op that closes the queue while preserving pending enqueues."""
    return self._close_op

  @property
  def cancel_op(self):
    """Op that closes the queue and cancels pending enqueues."""
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    """Tuple of exception types treated as "the queue has been closed"."""
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects.  The list is empty if no exception
      was captured.  (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    The loop ends when `coord` requests a stop, when the enqueue raises one of
    `queue_closed_exception_types`, or when any other exception is raised.
    Each exit path decrements this session's count in `_runs_per_session`
    exactly once.  The thread whose queue-closed decrement brings the count to
    zero also runs `close_op`.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.

    Raises:
      Exception: Without `coord`, any exception other than the queue-closed
        types is logged, recorded in `exceptions_raised`, and re-raised.
    """
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Blocks until `coord` requests a stop, then runs `cancel_op`.  Errors from
    `cancel_op` are logged at verbosity 1 and otherwise ignored.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched.  It creates
    a list of threads, optionally starting them.  There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions.  If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean.  If `True` make the threads daemon threads.
      start: Boolean.  If `True` starts the threads.  If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      # A positive count means threads for this session are still running:
      # no new threads to return.  Unseen sessions count as zero.
      if self._runs_per_session.get(sess, 0) > 0:
        return []
      self._runs_per_session[sess] = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the queue's name is not
      in the specified name scope.
    """
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      # Exception classes are stored as their error codes.
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)
389
390
@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Registers a `QueueRunner` in a graph collection so it can be started later.

  Models built from many queues make it hard to keep track of every queue
  runner that has to be started.  Putting each runner into a well-known graph
  collection lets the companion `start_queue_runners()` find and launch all of
  them in one call.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` naming the graph collection that receives the
      queue runner.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)
415
416
@tf_export(v1=["train.queue_runner.start_queue_runners",
               "train.start_queue_runners"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the session's graph.  It returns
  the list of all threads.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.  Empty when `sess` is a `MonitoredSession` or
    `SingularMonitoredSession` (kept for backward compatibility).

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: If `sess` is None and no default `tf.compat.v1.Session` is
      registered.
    TypeError: If `sess` is not a `tf.compat.v1.Session` object.
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    # Read the collection once, from the session's graph.  This is the same
    # graph whose runners are started below, so the "no queue runners"
    # warning agrees with what is actually launched even when `sess.graph`
    # is not the caller's default graph.
    queue_runners = ops.get_collection(collection)
    if not queue_runners:
      logging.warning(
          "`tf.train.start_queue_runners()` was called when no queue runners "
          "were defined. You can safely remove the call to this deprecated "
          "function.")
    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads
485
486
# Register the QueueRunner <-> `QueueRunnerDef` converters for the
# `GraphKeys.QUEUE_RUNNERS` collection, so runners stored there can be
# serialized to and restored from protocol buffers.
ops.register_proto_function(
    ops.GraphKeys.QUEUE_RUNNERS,
    proto_type=queue_runner_pb2.QueueRunnerDef,
    from_proto=QueueRunner.from_proto,
    to_proto=QueueRunner.to_proto)
491