# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
import threading
import weakref

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_DEPRECATION_INSTRUCTION = (
    "To construct input pipelines, use the `tf.data` module.")


@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner:
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second set
  of threads read records from the files, processes them, and enqueues tensors
  on a second queue; a third set of threads dequeues these input records to
  construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility
  """

  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue.  That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`.  Each thread will run its
    enqueue op in parallel with the other threads.  The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If `queue_runner_def` is specified together with `queue` or
        `enqueue_ops`.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    if queue_closed_exception_types is not None:
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        # The offending value is wrapped in a 1-tuple on purpose: it may itself
        # be a tuple, and a bare tuple on the right of `%` is unpacked as the
        # argument list, which breaks the formatting for any length other
        # than one.
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a tuple of tf.error types, but saw: %s"
            % (queue_closed_exception_types,))
    self._queue_closed_exception_types = queue_closed_exception_types
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    else:
      self._queue_closed_exception_types = tuple(
          self._queue_closed_exception_types)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    The queue and the ops are looked up by name in the current default graph.

    Args:
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    """The queue this runner enqueues into."""
    return self._queue

  @property
  def enqueue_ops(self):
    """The enqueue ops; `create_threads()` makes one thread per op."""
    return self._enqueue_ops

  @property
  def close_op(self):
    """Op that closes the queue while preserving pending enqueues."""
    return self._close_op

  @property
  def cancel_op(self):
    """Op that closes the queue and cancels pending enqueues."""
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    """Tuple of exception types treated as "the queue was closed"."""
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects.  The list is empty if no exception
      was captured.  (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    The loop ends when the coordinator (if any) asks to stop, when the enqueue
    op raises one of `queue_closed_exception_types`, or when any other
    exception is raised. In every case the per-session count of outstanding
    runs is decremented exactly once; the `decremented` flag prevents the
    `finally` clause from counting a queue-closed exit twice.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            # The last run to finish for this session closes the queue.
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        # Without a coordinator the exception is recorded for
        # `exceptions_raised` and then propagated out of the thread.
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Blocks on `coord.wait_for_stop()` and then runs `cancel_op` once.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched.  It creates
    a list of threads, optionally starting them.  There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions.  If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean.  If `True` make the threads daemon threads.
      start: Boolean.  If `True` starts the threads.  If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      try:
        if self._runs_per_session[sess] > 0:
          # Already started: no new threads to return.
          return []
      except KeyError:
        # We haven't seen this session yet.
        pass
      # Every enqueue thread decrements this count once when it terminates.
      self._runs_per_session[sess] = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the name of this
      `QueueRunner`'s queue does not start with `export_scope`.
    """
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)


@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run.  This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)


@tf_export(v1=["train.queue_runner.start_queue_runners",
               "train.start_queue_runners"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the graph.  It returns
  the list of all threads.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.  The list is empty when `sess` is a
    `MonitoredSession` or `SingularMonitoredSession`.

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: If `sess` is None and there isn't any default session.
    TypeError: If `sess` is not a `tf.compat.v1.Session` object.
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    # Look the collection up once, in the session's graph, so that the
    # emptiness warning refers to the same runners that are started below
    # even when the caller's default graph is a different graph.
    queue_runners = ops.get_collection(collection)
    if not queue_runners:
      logging.warning(
          "`tf.train.start_queue_runners()` was called when no queue runners "
          "were defined. You can safely remove the call to this deprecated "
          "function.")

    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)