# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data Flow Operations."""
# pylint: disable=g-bad-name
import functools
import hashlib
import threading

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

# pylint: enable=wildcard-import


def _as_type_list(dtypes):
  """Convert dtypes to a list of types."""
  assert dtypes is not None
  if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
    # We have a single type.
    return [dtypes]
  else:
    # We have a list or tuple of types.
    return list(dtypes)


def _as_shape_list(shapes,
                   dtypes,
                   unknown_dim_allowed=False,
                   unknown_rank_allowed=False):
  """Convert shapes to a list of tuples of int (or None)."""
  del dtypes
  if unknown_dim_allowed:
    if (not isinstance(shapes, collections_abc.Sequence) or not shapes or
        any(shape is None or isinstance(shape, int) for shape in shapes)):
      raise ValueError(
          "When providing partial shapes, a list of shapes must be provided.")
  if shapes is None:
    return None
  if isinstance(shapes, tensor_shape.TensorShape):
    shapes = [shapes]
  if not isinstance(shapes, (tuple, list)):
    raise TypeError(
        "Shapes must be a TensorShape or a list or tuple of TensorShapes, "
        f"got {type(shapes)} instead.")
  if all(shape is None or isinstance(shape, int) for shape in shapes):
    # We have a single shape.
    shapes = [shapes]
  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
  if not unknown_dim_allowed:
    if any(not shape.is_fully_defined() for shape in shapes):
      raise ValueError(f"All shapes must be fully defined: {shapes}")
  if not unknown_rank_allowed:
    if any(shape.dims is None for shape in shapes):
      raise ValueError(f"All shapes must have a defined rank: {shapes}")

  return shapes


def _as_name_list(names, dtypes):
  if names is None:
    return None
  if not isinstance(names, (list, tuple)):
    names = [names]
  if len(names) != len(dtypes):
    raise ValueError("List of names must have the same length as the list "
                     f"of dtypes, received len(names)={len(names)}, "
                     f"len(dtypes)={len(dtypes)}")
  return list(names)


def _shape_common(s1, s2):
  """The greatest lower bound (ordered by specificity) TensorShape."""
  s1 = tensor_shape.TensorShape(s1)
  s2 = tensor_shape.TensorShape(s2)
  if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
    return tensor_shape.unknown_shape()
  d = [
      d1 if d1 is not None and d1 == d2 else None
      for (d1, d2) in zip(s1.as_list(), s2.as_list())
  ]
  return tensor_shape.TensorShape(d)


# pylint: disable=protected-access
@tf_export("queue.QueueBase",
           v1=["queue.QueueBase", "io.QueueBase", "QueueBase"])
@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"])
class QueueBase:
  """Base class for queue implementations.

  A queue is a TensorFlow data structure that stores tensors across
  multiple steps, and exposes operations that enqueue and dequeue
  tensors.

  Each queue element is a tuple of one or more tensors, where each
  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, and versions that support enqueuing and
  dequeuing a batch of elements at once.

  See `tf.queue.FIFOQueue` and
  `tf.queue.RandomShuffleQueue` for concrete
  implementations of this class, and instructions on how to create
  them.
  """

  def __init__(self, dtypes, shapes, names, queue_ref):
    """Constructs a queue object from a queue reference.

    The two optional lists, `shapes` and `names`, must be of the same length
    as `dtypes` if provided. The values at a given index `i` indicate the
    shape and name to use for the corresponding queue component in `dtypes`.

    Args:
      dtypes: A list of types. The length of dtypes must equal the number
        of tensors in each element.
      shapes: Constraints on the shapes of tensors in an element:
        A list of shape tuples or None. This list is the same length
        as dtypes. If the shape of any tensor in the element is constrained,
        all must be; shapes can be None if the shapes should not be
        constrained.
      names: Optional list of names. If provided, the `enqueue()` and
        `dequeue()` methods will use dictionaries with these names as keys.
        Must be None or a list or tuple of the same length as `dtypes`.
      queue_ref: The queue reference, i.e. the output of the queue op.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    self._dtypes = dtypes
    if shapes is not None:
      if len(shapes) != len(dtypes):
        raise ValueError("Queue shapes must have the same length as dtypes, "
                         f"received len(shapes)={len(shapes)}, "
                         f"len(dtypes)={len(dtypes)}")
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
    if names is not None:
      if len(names) != len(dtypes):
        raise ValueError("Queue names must have the same length as dtypes, "
                         f"received len(names)={len(names)}, "
                         f"len(dtypes)={len(dtypes)}")
      self._names = names
    else:
      self._names = None
    self._queue_ref = queue_ref
    if isinstance(queue_ref, ops.EagerTensor):
      if context.context().scope_name:
        self._name = context.context().scope_name
      else:
        self._name = "Empty"
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          queue_ref, None)
    else:
      self._name = self._queue_ref.op.name.split("/")[-1]

  @staticmethod
  def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: When `queues` is not a list of `QueueBase` objects,
        or when the data types of `queues` are not all the same.
    """
    if ((not queues) or (not isinstance(queues, list)) or
        (not all(isinstance(x, QueueBase) for x in queues))):
      raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    if not all(dtypes == q.dtypes for q in queues[1:]):
      raise TypeError("Queues do not have matching component dtypes.")

    names = queues[0].names
    if not all(names == q.names for q in queues[1:]):
      raise TypeError("Queues do not have matching component names.")

    queue_shapes = [q.shapes for q in queues]
    reduced_shapes = [
        functools.reduce(_shape_common, s) for s in zip(*queue_shapes)
    ]

    queue_refs = array_ops.stack([x.queue_ref for x in queues])
    selected_queue = array_ops.gather(queue_refs, index)
    return QueueBase(
        dtypes=dtypes,
        shapes=reduced_shapes,
        names=names,
        queue_ref=selected_queue)

  @property
  def queue_ref(self):
    """The underlying queue reference."""
    return self._queue_ref

  @property
  def name(self):
    """The name of the underlying queue."""
    if context.executing_eagerly():
      return self._name
    return self._queue_ref.op.name

  @property
  def dtypes(self):
    """The list of dtypes for each component of a queue element."""
    return self._dtypes

  @property
  def shapes(self):
    """The list of shapes for each component of a queue element."""
    return self._shapes

  @property
  def names(self):
    """The list of names for each component of a queue element."""
    return self._names

  def _check_enqueue_dtypes(self, vals):
    """Validate and convert `vals` to a list of `Tensor`s.

    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
    dictionary with tensor values.

    If it is a dictionary, the queue must have been constructed with a
    `names` attribute and the dictionary keys must match the queue names.
    If the queue was constructed with a `names` attribute, `vals` must
    be a dictionary.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      A list of `Tensor` objects.

    Raises:
      ValueError: If `vals` is invalid.
    """
    if isinstance(vals, dict):
      if not self._names:
        raise ValueError("Queue must have names to enqueue a dictionary")
      if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
        raise ValueError("Keys in dictionary to enqueue do not match "
                         f"names of Queue. Dictionary: {sorted(vals.keys())}, "
                         f"Queue: {sorted(self._names)}")
      # The order of values in `self._names` indicates the order in which the
      # tensors in the dictionary `vals` must be listed.
      vals = [vals[k] for k in self._names]
    else:
      if self._names:
        raise ValueError("You must enqueue a dictionary in a Queue with names")
      if not isinstance(vals, (list, tuple)):
        vals = [vals]

    tensors = []
    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
      tensors.append(
          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))

    return tensors

  def _scope_vals(self, vals):
    """Return a list of values to pass to `name_scope()`.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      The values in vals as a list.
    """
    if isinstance(vals, (list, tuple)):
      return vals
    elif isinstance(vals, dict):
      return vals.values()
    else:
      return [vals]

  def enqueue(self, vals, name=None):
    """Enqueues one element to this queue.

    If the queue is full when this operation executes, it will block
    until the element has been enqueued.

    At runtime, this operation may raise an error if the queue is
    closed via `tf.QueueBase.close` before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed via `tf.Session.close`,
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary containing
        the values to enqueue.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a new tuple of tensors to the queue.
    """
    with ops.name_scope(name, "%s_enqueue" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      for val, shape in zip(vals, self._shapes):
        val.get_shape().assert_is_compatible_with(shape)

      if self._queue_ref.dtype == _dtypes.resource:
        return gen_data_flow_ops.queue_enqueue_v2(
            self._queue_ref, vals, name=scope)
      else:
        return gen_data_flow_ops.queue_enqueue(
            self._queue_ref, vals, name=scope)

  def enqueue_many(self, vals, name=None):
    """Enqueues zero or more elements to this queue.

    This operation slices each component tensor along the 0th dimension to
    make multiple queue elements. All of the tensors in `vals` must have the
    same size in the 0th dimension.

    If the queue is full when this operation executes, it will block
    until all of the elements have been enqueued.

    At runtime, this operation may raise an error if the queue is
    closed via `tf.QueueBase.close` before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed via `tf.Session.close`,
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary
        from which the queue elements are taken.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a batch of tuples of tensors to the queue.
    """
    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      # NOTE(fchollet): the code that follows is verbose because it needs to be
      # compatible with both TF v1 TensorShape behavior and TF v2 behavior.
      batch_dim = tensor_shape.dimension_value(
          vals[0].get_shape().with_rank_at_least(1)[0])
      batch_dim = tensor_shape.Dimension(batch_dim)
      for val, shape in zip(vals, self._shapes):
        val_batch_dim = tensor_shape.dimension_value(
            val.get_shape().with_rank_at_least(1)[0])
        val_batch_dim = tensor_shape.Dimension(val_batch_dim)
        batch_dim = batch_dim.merge_with(val_batch_dim)
        val.get_shape()[1:].assert_is_compatible_with(shape)

      return gen_data_flow_ops.queue_enqueue_many_v2(
          self._queue_ref, vals, name=scope)

  def _dequeue_return_value(self, tensors):
    """Return the value to return from a dequeue op.

    If the queue has names, return a dictionary with the
    names as keys. Otherwise return either a single tensor
    or a list of tensors depending on the length of `tensors`.

    Args:
      tensors: List of tensors from the dequeue op.

    Returns:
      A single tensor, a list of tensors, or a dictionary
      of tensors.
    """
    if self._names:
      # The returned values in `tensors` are in the same order as
      # the names in `self._names`.
      return {n: tensors[i] for i, n in enumerate(self._names)}
    elif len(tensors) == 1:
      return tensors[0]
    else:
      return tensors

  def dequeue(self, name=None):
    """Dequeues one element from this queue.

    If the queue is empty when this operation executes, it will block
    until there is an element to dequeue.

    At runtime, this operation may raise an error if the queue is
    closed via `tf.QueueBase.close` before or during its execution. If the
    queue is closed, the queue is empty, and there are no pending
    enqueue operations that can fulfill this request,
    `tf.errors.OutOfRangeError` will be raised. If the session is
    closed via `tf.Session.close`,
    `tf.errors.CancelledError` will be raised.

    Args:
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was dequeued.
    """
    if name is None:
      name = "%s_Dequeue" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      ret = gen_data_flow_ops.queue_dequeue_v2(
          self._queue_ref, self._dtypes, name=name)
    else:
      ret = gen_data_flow_ops.queue_dequeue(
          self._queue_ref, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the `QueueBase` object.
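    # In graph mode, the static shapes registered with the queue are copied
    # onto the dequeued tensors here, since the generated dequeue op cannot
    # infer them on its own.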
    if not context.executing_eagerly():
      op = ret[0].op
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(shape)

    return self._dequeue_return_value(ret)

  def dequeue_many(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. All of the
    components in the dequeued tuple will have size `n` in the 0th dimension.

    If the queue is closed and there are fewer than `n` elements left, then an
    `OutOfRange` exception is raised.

    At runtime, this operation may raise an error if the queue is
    closed via `tf.QueueBase.close` before or during its execution. If the
    queue is closed, the queue contains fewer than `n` elements, and
    there are no pending enqueue operations that can fulfill this
    request, `tf.errors.OutOfRangeError` will be raised. If the
    session is closed via `tf.Session.close`,
    `tf.errors.CancelledError` will be raised.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The list of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueMany" % self._name

    ret = gen_data_flow_ops.queue_dequeue_many_v2(
        self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    if not context.executing_eagerly():
      op = ret[0].op
      batch_dim = tensor_shape.Dimension(
          tensor_util.constant_value(op.inputs[1]))
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(
            tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def dequeue_up_to(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    **Note** This operation is not supported by all queues. If a queue does not
    support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. If the queue
    has not been closed, all of the components in the dequeued tuple
    will have size `n` in the 0th dimension.

    If the queue is closed and there are more than `0` but fewer than
    `n` elements remaining, then instead of raising a
    `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`,
    fewer than `n` elements are returned immediately. If the queue is
    closed and there are `0` elements left in the queue, then a
    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
    Otherwise the behavior is identical to `dequeue_many`.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueUpTo" % self._name

    ret = gen_data_flow_ops.queue_dequeue_up_to_v2(
        self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
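    # The leading (batch) dimension is deliberately left unknown: once the
    # queue has been closed, dequeue_up_to may return fewer than `n` elements.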
    if not context.executing_eagerly():
      op = ret[0].op
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this queue.

    This operation signals that no more elements will be enqueued in
    the given queue. Subsequent `enqueue` and `enqueue_many`
    operations will fail. Subsequent `dequeue` and `dequeue_many`
    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would otherwise block waiting for more elements (if close
    hadn't been called) will now fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests will also
    be canceled.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: A name for the operation (optional).

    Returns:
      The operation that closes the queue.
    """
    if name is None:
      name = "%s_Close" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_close_v2(
          self._queue_ref,
          cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)
    else:
      return gen_data_flow_ops.queue_close(
          self._queue_ref,
          cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)

  def is_closed(self, name=None):
    """Returns true if queue is closed.

    This operation returns true if the queue is closed and false if the queue
    is open.

    Args:
      name: A name for the operation (optional).

    Returns:
      True if the queue is closed and false if the queue is open.
    """
    if name is None:
      name = "%s_Is_Closed" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)

  def size(self, name=None):
    """Compute the number of elements in this queue.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this queue.
    """
    if name is None:
      name = "%s_Size" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops.queue_size(self._queue_ref, name=name)


def _shared_name(shared_name):
  if context.executing_eagerly():
    return str(ops.uid())
  return shared_name


@tf_export(
    "queue.RandomShuffleQueue",
    v1=["queue.RandomShuffleQueue",
        "io.RandomShuffleQueue", "RandomShuffleQueue"])
@deprecation.deprecated_endpoints(
    ["io.RandomShuffleQueue", "RandomShuffleQueue"])
class RandomShuffleQueue(QueueBase):
  """A queue implementation that dequeues elements in a random order.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               min_after_dequeue,
               dtypes,
               shapes=None,
               names=None,
               seed=None,
               shared_name=None,
               name="random_shuffle_queue"):
    """Create a queue that dequeues elements in a random order.

    A `RandomShuffleQueue` has bounded capacity; supports multiple
    concurrent producers and consumers; and provides exactly-once
    delivery.
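
    A minimal usage sketch (the capacity, min_after_dequeue, shapes and
    values below are illustrative only); construction and the enqueue and
    dequeue methods come from this class and `tf.queue.QueueBase`:

      q = tf.queue.RandomShuffleQueue(
          capacity=10, min_after_dequeue=2, dtypes=[tf.float32], shapes=[()])
      q.enqueue_many([[1.0, 2.0, 3.0]])
      val = q.dequeue()  # One of 1.0, 2.0 or 3.0, chosen at random.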

    A `RandomShuffleQueue` holds a list of up to `capacity`
    elements. Each element is a fixed-length tuple of tensors whose
    dtypes are described by `dtypes`, and whose shapes are optionally
    described by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    The `min_after_dequeue` argument allows the caller to specify a
    minimum number of elements that will remain in the queue after a
    `dequeue` or `dequeue_many` operation completes, to ensure a
    minimum level of mixing of elements. This invariant is maintained
    by blocking those operations until sufficient elements have been
    enqueued. The `min_after_dequeue` argument is ignored after the
    queue has been closed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      min_after_dequeue: An integer (described above).
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      seed: A Python integer. Used to create a random seed. See
        `tf.compat.v1.set_random_seed`
        for behavior.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    seed1, seed2 = random_seed.get_seed(seed)
    if seed1 is None and seed2 is None:
      seed1, seed2 = 0, 0
    elif seed is None and shared_name is not None:
      # This means that graph seed is provided but op seed is not provided.
      # If shared_name is also provided, make seed2 depend only on the graph
      # seed and shared_name. (seed2 from get_seed() is generally dependent on
      # the id of the last op created.)
      string = (str(seed1) + shared_name).encode("utf-8")
      seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
    queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        seed=seed1,
        seed2=seed2,
        shared_name=_shared_name(shared_name),
        name=name)

    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"])
@deprecation.deprecated_endpoints("FIFOQueue")
class FIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in first-out order.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.
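
    A minimal usage sketch (the capacity, shapes and values below are
    illustrative only); the enqueue and dequeue methods are those described
    in `tf.queue.QueueBase`:

      q = tf.queue.FIFOQueue(capacity=3, dtypes=[tf.float32], shapes=[()])
      q.enqueue(1.0)
      q.enqueue_many([[2.0, 3.0]])
      q.dequeue()        # ==> 1.0
      q.dequeue_many(2)  # ==> [2.0, 3.0]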

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    with ops.init_scope(), ops.device("CPU"):
      queue_ref = gen_data_flow_ops.fifo_queue_v2(
          component_types=dtypes,
          shapes=shapes,
          capacity=capacity,
          shared_name=_shared_name(shared_name),
          name=name)

    super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)


# TODO(allenl): If GPU-compatible queues turn out to be useful, we should
# implement GPU kernels for EnqueueMany and DequeueMany so we can make the
# public FIFOQueue GPU-compatible and remove this internal version.
class GPUCompatibleFIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in first-out order.

  GPUCompatibleFIFOQueue is like FIFOQueue, but the queue resource may be
  placed either on a CPU or on a GPU. It is not cross-device: enqueues and
  dequeues will be colocated with the queue resource. GPUCompatibleFIFOQueue
  only supports enqueue and dequeue at the moment, not enqueue_many or
  dequeue_many.

  See `tf.queue.QueueBase` for a description of the methods on this class.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    with ops.init_scope():
      queue_ref = gen_data_flow_ops.fifo_queue_v2(
          component_types=dtypes,
          shapes=shapes,
          capacity=capacity,
          shared_name=_shared_name(shared_name),
          name=name)

    super(GPUCompatibleFIFOQueue, self).__init__(
        dtypes, shapes, names, queue_ref)

  def enqueue_many(self, vals, name=None):
    """enqueue_many is not supported on GPUCompatibleFIFOQueue."""
    raise NotImplementedError(
        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
        "only enqueue and dequeue.")

  def dequeue_many(self, n, name=None):
    """dequeue_many is not supported on GPUCompatibleFIFOQueue."""
    raise NotImplementedError(
        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
        "only enqueue and dequeue.")


@tf_export(
    "queue.PaddingFIFOQueue",
    v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"])
@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"])
class PaddingFIFOQueue(QueueBase):
  """A FIFOQueue that supports batching variable-sized tensors by padding.

  A `PaddingFIFOQueue` may contain components with dynamic shape, while also
  supporting `dequeue_many`. See the constructor for more details.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes,
               names=None,
               shared_name=None,
               name="padding_fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are described by the `shapes`
    argument.

    The `shapes` argument must be specified; each component of a queue
    element must have the respective shape. Shapes of fixed
    rank but variable size are allowed by setting any shape dimension to None.
    In this case, the inputs' shape may vary along the given dimension, and
    `dequeue_many` will pad the given dimension with zeros up to the maximum
    shape of all elements in the given batch.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: A list of `TensorShape` objects, with the same length as
        `dtypes`. Any dimension in the `TensorShape` containing value
        `None` is dynamic and allows values to be enqueued with
        variable size in that dimension.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.

    Raises:
      ValueError: If shapes is not a list of shapes, or the lengths of dtypes
        and shapes do not match, or if names is specified and the lengths of
        dtypes and names do not match.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
    names = _as_name_list(names, dtypes)
    if len(dtypes) != len(shapes):
      raise ValueError("Shapes must be provided for all components, "
                       f"but received {len(dtypes)} dtypes and "
                       f"{len(shapes)} shapes.")
    queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        shared_name=_shared_name(shared_name),
        name=name)

    super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("queue.PriorityQueue",
           v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"])
@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"])
class PriorityQueue(QueueBase):
  """A queue implementation that dequeues elements in prioritized order.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               types,
               shapes=None,
               names=None,
               shared_name=None,
               name="priority_queue"):
    """Creates a queue that dequeues elements in a prioritized order.

    A `PriorityQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PriorityQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `types`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Enqueues and Dequeues to the `PriorityQueue` must include an additional
    tuple entry at the beginning: the `priority`. The priority must be
    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      types: A list of `DType` objects. The length of `types` must equal
        the number of tensors in each queue element, excluding the leading
        priority component. The first tensor in each element is the priority,
        which must be of type int64.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `types`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `types`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
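
    A short usage sketch (the capacity, shapes and values below are
    illustrative only); note the int64 priority that leads every enqueued
    tuple:

      q = tf.queue.PriorityQueue(capacity=10, types=[tf.string], shapes=[()])
      q.enqueue((tf.constant(2, dtype=tf.int64), "low"))
      q.enqueue((tf.constant(1, dtype=tf.int64), "high"))
      priority, value = q.dequeue()  # The entry with priority 1 comes first.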
    """
    types = _as_type_list(types)
    shapes = _as_shape_list(shapes, types)

    queue_ref = gen_data_flow_ops.priority_queue_v2(
        component_types=types,
        shapes=shapes,
        capacity=capacity,
        shared_name=_shared_name(shared_name),
        name=name)

    priority_dtypes = [_dtypes.int64] + types
    priority_shapes = [()] + shapes if shapes else shapes

    super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
                                        queue_ref)


# TODO(josh11b): class BatchQueue(QueueBase):


class Barrier:
  """Represents a key-value map that persists across graph executions."""

  def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
    """Creates a barrier that persists across different graph executions.

    A barrier represents a key-value map, where each key is a string, and
    each value is a tuple of tensors.

    At runtime, the barrier contains 'complete' and 'incomplete'
    elements. A complete element has defined tensors for all
    components of its value tuple, and may be accessed using
    take_many. An incomplete element has some undefined components in
    its value tuple, and may be updated using insert_many.

    The barrier call `take_many` outputs values in a particular order.
    First, it only outputs completed values. Second, the order in which
    completed values are returned matches the order in which their very
    first component was inserted into the barrier. So, for example, for this
    sequence of insertions and removals:

      barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
      barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
      barrier.insert_many(1, keys=["k1"], values=[1]).run()
      barrier.insert_many(0, keys=["k3"], values=["c"]).run()
      barrier.insert_many(1, keys=["k3"], values=[3]).run()
      barrier.insert_many(1, keys=["k2"], values=[2]).run()

      (indices, keys, values) = barrier.take_many(2)
      (indices_val, keys_val, values0_val, values1_val) = (
          session.run([indices, keys, values[0], values[1]]))

    The output will be (up to permutation of "k1" and "k2"):

      indices_val == (-2**63, -2**63)
      keys_val == ("k1", "k2")
      values0_val == ("a", "b")
      values1_val == (1, 2)

    Note the key "k2" was inserted into the barrier before "k3". Even though
    "k3" was completed first, both are complete by the time
    take_many is called. As a result, "k2" is prioritized and "k1" and "k2"
    are returned first. "k3" remains in the barrier until the next execution
    of `take_many`. Since "k1" and "k2" had their first insertions into
    the barrier together, their indices are the same (-2**63). The index
    of "k3" will be -2**63 + 1, because it was the next new inserted key.

    Args:
      types: A single dtype or a tuple of dtypes, corresponding to the
        dtypes of the tensor elements that comprise a value in this barrier.
      shapes: Optional. Constraints on the shapes of tensors in the values:
        a single tensor shape tuple; a tuple of tensor shape tuples
        for each barrier-element tuple component; or None if the shape should
        not be constrained.
      shared_name: Optional. If non-empty, this barrier will be shared under
        the given name across multiple sessions.
      name: Optional name for the barrier op.

    Raises:
      ValueError: If one of the `shapes` indicates no elements.
    """
    self._types = _as_type_list(types)

    if shapes is not None:
      shapes = _as_shape_list(shapes, self._types)
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
      for i, shape in enumerate(self._shapes):
        if shape.num_elements() == 0:
          raise ValueError("Empty tensors are not supported, but received "
                           f"shape '{shape}' at index {i}")
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._types]

    self._barrier_ref = gen_data_flow_ops.barrier(
        component_types=self._types,
        shapes=self._shapes,
        shared_name=shared_name,
        name=name)
    if context.executing_eagerly():
      self._name = context.context().scope_name
    else:
      self._name = self._barrier_ref.op.name.split("/")[-1]

  @property
  def barrier_ref(self):
    """Get the underlying barrier reference."""
    return self._barrier_ref

  @property
  def name(self):
    """The name of the underlying barrier."""
    if context.executing_eagerly():
      return self._name
    return self._barrier_ref.op.name

  def insert_many(self, component_index, keys, values, name=None):
    """For each key, assigns the respective value to the specified component.

    This operation updates each element at component_index.

    Args:
      component_index: The component of the value that is being assigned.
      keys: A vector of keys, with length n.
      values: An any-dimensional tensor of values, which are associated with
        the respective keys. The first dimension must have length n.
      name: Optional name for the op.

    Returns:
      The operation that performs the insertion.

    Raises:
      InvalidArgumentError: If inserting keys and values without elements.
    """
    if name is None:
      name = "%s_BarrierInsertMany" % self._name
    return gen_data_flow_ops.barrier_insert_many(
        self._barrier_ref, keys, values, component_index, name=name)

  def take_many(self,
                num_elements,
                allow_small_batch=False,
                timeout=None,
                name=None):
    """Takes the given number of completed elements from this barrier.

    This operation concatenates completed-element component tensors along
    the 0th dimension to make a single component tensor.

    If the barrier has no completed elements, this operation will block
    until there are 'num_elements' elements to take.

    TODO(b/25743580): the semantics of `allow_small_batch` are experimental
    and may be extended to other cases in the future.

    TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
    already when the barrier is closed, it will block forever. Fix this
    by using asynchronous operations.

    Args:
      num_elements: The number of elements to take.
      allow_small_batch: If the barrier is closed, don't block if there are
        fewer completed elements than requested, but instead return all
        available completed elements.
      timeout: This specifies the number of milliseconds to block
        before returning with DEADLINE_EXCEEDED. (This option is not
        supported yet.)
      name: A name for the operation (optional).

    Returns:
      A tuple of (index, key, value_list).
      "index" is an int64 tensor of length num_elements containing the
      index of the insert_many call for which the very first component of
      the given element was inserted into the Barrier, starting with
      the value -2**63. Note, this value is different from the
      index of the insert_many call for which the element was completed.
      "key" is a string tensor of length num_elements containing the keys.
      "value_list" is a tuple of tensors, each one with size num_elements
      in the 0th dimension for each component in the barrier's values.

    """
    if name is None:
      name = "%s_BarrierTakeMany" % self._name
    ret = gen_data_flow_ops.barrier_take_many(
        self._barrier_ref,
        num_elements,
        self._types,
        allow_small_batch,
        timeout,
        name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Barrier object.
    if not context.executing_eagerly():
      op = ret[0].op
      if allow_small_batch:
        batch_dim = None
      else:
        batch_dim = tensor_shape.Dimension(
            tensor_util.constant_value(op.inputs[1]))
      op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim]))  # indices
      op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim]))  # keys
      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
        output.set_shape(
            tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return ret

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this barrier.

    This operation signals that no more new key values will be inserted in the
    given barrier. Subsequent InsertMany operations with new keys will fail.
    InsertMany operations that just complement already existing keys with other
    components will continue to succeed. Subsequent TakeMany operations will
    continue to succeed if sufficient elements remain in the barrier. Subsequent
    TakeMany operations that would block will fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests to the
    underlying queue will also be canceled, and completing of already
    started values is no longer possible.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: Optional name for the op.

    Returns:
      The operation that closes the barrier.
    """
    if name is None:
      name = "%s_BarrierClose" % self._name
    return gen_data_flow_ops.barrier_close(
        self._barrier_ref,
        cancel_pending_enqueues=cancel_pending_enqueues,
        name=name)

  def ready_size(self, name=None):
    """Compute the number of complete elements in the given barrier.

    Args:
      name: A name for the operation (optional).

    Returns:
      A single-element tensor containing the number of complete elements in the
      given barrier.
    """
    if name is None:
      name = "%s_BarrierReadySize" % self._name
    return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name)

  def incomplete_size(self, name=None):
    """Compute the number of incomplete elements in the given barrier.

    Args:
      name: A name for the operation (optional).

    Returns:
      A single-element tensor containing the number of incomplete elements in
      the given barrier.
    """
    if name is None:
      name = "%s_BarrierIncompleteSize" % self._name
    return gen_data_flow_ops.barrier_incomplete_size(
        self._barrier_ref, name=name)


@tf_export(v1=["ConditionalAccumulatorBase"])
class ConditionalAccumulatorBase:
  """A conditional accumulator for aggregating gradients.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.
  """

  def __init__(self, dtype, shape, accumulator_ref):
    """Creates a new ConditionalAccumulator.

    Args:
      dtype: Datatype of the accumulated gradients.
      shape: Shape of the accumulated gradients.
      accumulator_ref: A handle to the conditional accumulator, created by
        subclasses.
    """
    self._dtype = dtype
    if shape is not None:
      self._shape = tensor_shape.TensorShape(shape)
    else:
      self._shape = tensor_shape.unknown_shape()
    self._accumulator_ref = accumulator_ref
    if context.executing_eagerly():
      self._name = context.context().scope_name
    else:
      self._name = self._accumulator_ref.op.name.split("/")[-1]

  @property
  def accumulator_ref(self):
    """The underlying accumulator reference."""
    return self._accumulator_ref

  @property
  def name(self):
    """The name of the underlying accumulator."""
    return self._name

  @property
  def dtype(self):
    """The datatype of the gradients accumulated by this accumulator."""
    return self._dtype

  def num_accumulated(self, name=None):
    """Number of gradients that have currently been aggregated in accumulator.

    Args:
      name: Optional name for the operation.

    Returns:
      Number of accumulated gradients currently in accumulator.
    """
    if name is None:
      name = "%s_NumAccumulated" % self._name

    return gen_data_flow_ops.resource_accumulator_num_accumulated(
        self._accumulator_ref, name=name)

  def set_global_step(self, new_global_step, name=None):
    """Sets the global time step of the accumulator.

    The operation logs a warning if we attempt to set to a time step that is
    lower than the accumulator's own time step.

    Args:
      new_global_step: Value of new time step. Can be a variable or a constant.
      name: Optional name for the operation.

    Returns:
      Operation that sets the accumulator's time step.
    """
    return gen_data_flow_ops.resource_accumulator_set_global_step(
        self._accumulator_ref,
        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
        name=name)


@tf_export(v1=["ConditionalAccumulator"])
class ConditionalAccumulator(ConditionalAccumulatorBase):
  """A conditional accumulator for aggregating gradients.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.
  """

  def __init__(self,
               dtype,
               shape=None,
               shared_name=None,
               name="conditional_accumulator",
               reduction_type="MEAN"):
    """Creates a new ConditionalAccumulator.

    Args:
      dtype: Datatype of the accumulated gradients.
      shape: Shape of the accumulated gradients.
      shared_name: Optional. If non-empty, this accumulator will be shared under
        the given name across multiple sessions.
      name: Optional name for the accumulator.
      reduction_type: Reduction type to use when taking the gradient.
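
    A minimal usage sketch (the dtype, shape and values below are illustrative
    only, and in graph mode the apply op must be run before the take op):

      acc = tf.compat.v1.ConditionalAccumulator(dtype=tf.float32, shape=())
      apply_op = acc.apply_grad(2.0, local_step=0)
      avg = acc.take_grad(num_required=1)  # Blocks until one gradient applied.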
    """
    accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
        dtype=dtype,
        shape=shape,
        shared_name=shared_name,
        name=name,
        reduction_type=reduction_type)
    if context.executing_eagerly():
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=accumulator_ref, handle_device=context.context().device_name)

    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)

  def apply_grad(self, grad, local_step=0, name=None):
    """Attempts to apply a gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
    is less than the accumulator's global time step.

    Args:
      grad: The gradient tensor to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      ValueError: If grad is of the wrong shape.
    """
    grad = ops.convert_to_tensor(grad, self._dtype)
    grad.get_shape().assert_is_compatible_with(self._shape)
    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)

    return gen_data_flow_ops.resource_accumulator_apply_gradient(
        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)

  def take_grad(self, num_required, name=None):
    """Attempts to extract the average gradient from the accumulator.

    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:

    - Counter of accumulated gradients is reset to 0.
    - Aggregated gradient is reset to 0 tensor.
    - Accumulator's internal time step is incremented by 1.

    Args:
      num_required: Number of gradients that need to have been aggregated.
      name: Optional name for the operation.

    Returns:
      A tensor holding the value of the average gradient.

    Raises:
      InvalidArgumentError: If num_required < 1
    """
    out = gen_data_flow_ops.resource_accumulator_take_gradient(
        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
    out.set_shape(self._shape)
    return out


@tf_export(
    v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"])
class SparseConditionalAccumulator(ConditionalAccumulatorBase):
  """A conditional accumulator for aggregating sparse gradients.

  Sparse gradients are represented by `IndexedSlices`.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.

  Args:
    dtype: Datatype of the accumulated gradients.
    shape: Shape of the accumulated gradients.
    shared_name: Optional. If non-empty, this accumulator will be shared under
      the given name across multiple sessions.
    name: Optional name for the accumulator.
    reduction_type: Reduction type to use when taking the gradient.
  """

  def __init__(self,
               dtype,
               shape=None,
               shared_name=None,
               name="sparse_conditional_accumulator",
               reduction_type="MEAN"):
    accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
        dtype=dtype,
        shape=shape,
        shared_name=shared_name,
        name=name,
        reduction_type=reduction_type)
    super(SparseConditionalAccumulator, self).__init__(dtype, shape,
                                                       accumulator_ref)

  def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
    """Attempts to apply a gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
    is less than the accumulator's global time step.

    Args:
      grad: The gradient `IndexedSlices` to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      InvalidArgumentError: If grad is of the wrong shape.
    """
    return self.apply_grad(
        grad_indices=grad.indices,
        grad_values=grad.values,
        grad_shape=grad.dense_shape,
        local_step=local_step,
        name=name)

  def apply_grad(self,
                 grad_indices,
                 grad_values,
                 grad_shape=None,
                 local_step=0,
                 name=None):
    """Attempts to apply a sparse gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
    is less than the accumulator's global time step.

    A sparse gradient is represented by its indices, values and possibly empty
    or None shape. Indices must be a vector representing the locations of
    non-zero entries in the tensor. Values are the non-zero slices of the
    gradient, and must have the same first dimension as indices, i.e., the nnz
    represented by indices and values must be consistent. Shape, if not empty
    or None, must be consistent with the accumulator's shape (if also
    provided).

    Example: A tensor [[0, 0], [0, 1], [2, 3]] can be represented by

      indices: [1, 2]
      values: [[0, 1], [2, 3]]
      shape: [3, 2]

    Args:
      grad_indices: Indices of the sparse gradient to be applied.
      grad_values: Values of the sparse gradient to be applied.
      grad_shape: Shape of the sparse gradient to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      InvalidArgumentError: If grad is of the wrong shape.
    """
    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
    return gen_data_flow_ops.sparse_accumulator_apply_gradient(
        self._accumulator_ref,
        local_step=local_step,
        gradient_indices=math_ops.cast(grad_indices, _dtypes.int64),
        gradient_values=grad_values,
        gradient_shape=math_ops.cast(
            [] if grad_shape is None else grad_shape, _dtypes.int64),
        has_known_shape=(grad_shape is not None),
        name=name)

  def take_grad(self, num_required, name=None):
    """Attempts to extract the average gradient from the accumulator.

    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:
    - Counter of accumulated gradients is reset to 0.
    - Aggregated gradient is reset to a 0 tensor.
    - Accumulator's internal time step is incremented by 1.

    Args:
      num_required: Number of gradients that need to have been aggregated.
      name: Optional name for the operation.

    Returns:
      A tuple of indices, values, and shape representing the average gradient.

    Raises:
      InvalidArgumentError: If `num_required` < 1.
    """
    return gen_data_flow_ops.sparse_accumulator_take_gradient(
        self._accumulator_ref, num_required, dtype=self._dtype, name=name)

  def take_indexed_slices_grad(self, num_required, name=None):
    """Attempts to extract the average gradient from the accumulator.

    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:

    - Counter of accumulated gradients is reset to 0.
    - Aggregated gradient is reset to a 0 tensor.
    - Accumulator's internal time step is incremented by 1.

    Args:
      num_required: Number of gradients that need to have been aggregated.
      name: Optional name for the operation.

    Returns:
      An `IndexedSlices` holding the value of the average gradient.

    Raises:
      InvalidArgumentError: If `num_required` < 1.
    """
    return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
    return indexed_slices.IndexedSlices(
        indices=return_val.indices,
        values=return_val.values,
        dense_shape=return_val.shape)

  # SparseConditionalAccumulator is not switched to resource. Use old kernels.
  def num_accumulated(self, name=None):
    """Number of gradients that have currently been aggregated in accumulator.

    Args:
      name: Optional name for the operation.

    Returns:
      Number of accumulated gradients currently in the accumulator.
    """
    if name is None:
      name = "%s_NumAccumulated" % self._name

    return gen_data_flow_ops.accumulator_num_accumulated(
        self._accumulator_ref, name=name)

  def set_global_step(self, new_global_step, name=None):
    """Sets the global time step of the accumulator.

    The operation logs a warning if we attempt to set to a time step that is
    lower than the accumulator's own time step.

    Args:
      new_global_step: Value of the new time step. Can be a variable or a
        constant.
      name: Optional name for the operation.

    Returns:
      Operation that sets the accumulator's time step.
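
    Example:
      An illustrative sketch; `accumulator` and the `global_step` tensor
      (e.g. from `tf.compat.v1.train.get_global_step()`) are assumed to be
      defined elsewhere:

        sync_op = accumulator.set_global_step(global_step)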
    """
    return gen_data_flow_ops.accumulator_set_global_step(
        self._accumulator_ref,
        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
        name=name)


class BaseStagingArea:
  """Base class for Staging Areas."""
  _identifier = 0
  _lock = threading.Lock()

  def __init__(self,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               capacity=0,
               memory_limit=0):
    if shared_name is None:
      self._name = (
          ops.get_default_graph().unique_name(self.__class__.__name__))
    elif isinstance(shared_name, str):
      self._name = shared_name
    else:
      raise ValueError(f"shared_name must be a string, got {shared_name}")

    self._dtypes = dtypes

    if shapes is not None:
      if len(shapes) != len(dtypes):
        raise ValueError("StagingArea shapes must be the same length as dtypes")
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]

    if names is not None:
      if len(names) != len(dtypes):
        raise ValueError("StagingArea names must be the same length as dtypes")
      self._names = names
    else:
      self._names = None

    self._capacity = capacity
    self._memory_limit = memory_limit

    # All get and put ops must colocate with this op.
    with ops.name_scope("%s_root" % self._name):
      self._coloc_op = control_flow_ops.no_op()

  @property
  def name(self):
    """The name of the staging area."""
    return self._name

  @property
  def dtypes(self):
    """The list of dtypes for each component of a staging area element."""
    return self._dtypes

  @property
  def shapes(self):
    """The list of shapes for each component of a staging area element."""
    return self._shapes

  @property
  def names(self):
    """The list of names for each component of a staging area element."""
    return self._names

  @property
  def capacity(self):
    """The maximum number of elements of this staging area."""
    return self._capacity

  @property
  def memory_limit(self):
    """The maximum number of bytes of this staging area."""
    return self._memory_limit

  def _check_put_dtypes(self, vals, indices=None):
    """Validate and convert `vals` to a list of `Tensor`s.

    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
    dictionary with tensor values.

    If `vals` is a list, then the appropriate indices associated with the
    values must be provided.

    If it is a dictionary, the staging area must have been constructed with a
    `names` attribute and the dictionary keys must match the staging area
    names. `indices` will be inferred from the dictionary keys.
    If the staging area was constructed with a `names` attribute, `vals` must
    be a dictionary.

    Checks that the dtype and shape of each value matches that
    of the staging area.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.
      indices: (Optional.) Indices associated with the tensors in `vals`;
        required when `vals` is a list or tuple.

    Returns:
      A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
      and `indices` is a list of indices associated with the tensors.

    Raises:
      ValueError: If `vals` or `indices` is invalid.
    """
    if isinstance(vals, dict):
      if not self._names:
        raise ValueError(
            "Staging areas must have names to enqueue a dictionary")
      if not set(vals.keys()).issubset(self._names):
        raise ValueError("Keys in dictionary to put do not match names "
                         f"of staging area. Dictionary: {sorted(vals.keys())} "
                         f"Queue: {sorted(self._names)}")
      # The order of values in `self._names` indicates the order in which the
      # tensors in the dictionary `vals` must be listed.
      vals, indices, _ = zip(*[(vals[k], i, k)
                               for i, k in enumerate(self._names)
                               if k in vals])
    else:
      if self._names:
        raise ValueError("You must enqueue a dictionary in a staging area "
                         "with names")

      if indices is None:
        raise ValueError("Indices must be supplied when inserting a list "
                         "of tensors")

      if len(indices) != len(vals):
        raise ValueError(f"Number of indices {len(indices)} doesn't match "
                         f"number of values {len(vals)}")

      if not isinstance(vals, (list, tuple)):
        vals = [vals]
        indices = [0]

    # Sanity check the number of values.
    if not len(vals) <= len(self._dtypes):
      raise ValueError(f"Unexpected number of inputs {len(vals)} vs "
                       f"{len(self._dtypes)}")

    tensors = []

    for val, i in zip(vals, indices):
      dtype, shape = self._dtypes[i], self._shapes[i]
      # Check dtype.
      if val.dtype != dtype:
        raise ValueError(f"Datatypes do not match. "
                         f"Received val.dtype {str(val.dtype)} and "
                         f"dtype {str(dtype)}")
      # Check shape.
      val.get_shape().assert_is_compatible_with(shape)

      tensors.append(
          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))

    return tensors, indices

  def _create_device_transfers(self, tensors):
    """Encode inter-device transfers if the current device
    is not the same as the Staging Area's device.
    """

    if not isinstance(tensors, (tuple, list)):
      tensors = [tensors]

    curr_device_scope = control_flow_ops.no_op().device

    if curr_device_scope != self._coloc_op.device:
      tensors = [array_ops.identity(t) for t in tensors]

    return tensors

  def _get_return_value(self, tensors, indices):
    """Return the value to return from a get op.

    If the staging area has names, return a dictionary with the
    names as keys. Otherwise return either a single tensor
    or a list of tensors depending on the length of `tensors`.

    Args:
      tensors: List of tensors from the get op.
      indices: Indices of associated names and shapes.

    Returns:
      A single tensor, a list of tensors, or a dictionary
      of tensors.
    """

    tensors = self._create_device_transfers(tensors)

    # Set the static shape of each output from the staging area's shapes.
    for output, i in zip(tensors, indices):
      output.set_shape(self._shapes[i])

    if self._names:
      # The returned values in `tensors` are in the same order as
      # the names in `self._names`.
      return {self._names[i]: t for t, i in zip(tensors, indices)}
    return tensors

  def _scope_vals(self, vals):
    """Return a list of values to pass to `name_scope()`.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      The values in `vals` as a list.
    """
    if isinstance(vals, (list, tuple)):
      return vals
    elif isinstance(vals, dict):
      return vals.values()
    else:
      return [vals]


class StagingArea(BaseStagingArea):
  """Class for staging inputs. No ordering guarantees.

  A `StagingArea` is a TensorFlow data structure that stores tensors across
  multiple steps, and exposes operations that can put and get tensors.

  Each `StagingArea` element is a tuple of one or more tensors, where each
  tuple component has a static dtype, and may have a static shape.

  The capacity of a `StagingArea` may be bounded or unbounded.
  It supports multiple concurrent producers and consumers, and
  provides exactly-once delivery.

  Each element of a `StagingArea` is a fixed-length tuple of tensors whose
  dtypes are described by `dtypes`, and whose shapes are optionally described
  by the `shapes` argument.

  If the `shapes` argument is specified, each component of a staging area
  element must have the respective fixed shape. If it is
  unspecified, different elements may have different shapes.

  It can be configured with a capacity, in which case
  `put(values)` will block until space becomes available.

  Similarly, it can be configured with a memory limit, which
  will block `put(values)` until space is available.
  This is mostly useful for limiting the number of tensors on
  devices such as GPUs.

  All `get()` and `peek()` commands block if the requested data
  is not present in the staging area.
  """

  def __init__(self,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               capacity=0,
               memory_limit=0):
    """Constructs a staging area object.

    The two optional lists, `shapes` and `names`, must be of the same length
    as `dtypes` if provided. The values at a given index `i` indicate the
    shape and name to use for the corresponding queue component in `dtypes`.

    The device scope at the time of object creation determines where the
    storage for the `StagingArea` will reside. Calls to `put` will incur a copy
    to this memory space, if necessary. Tensors returned by `get` will be
    placed according to the device scope when `get` is called.

    Args:
      dtypes: A list of types. The length of dtypes must equal the number
        of tensors in each element.
      shapes: (Optional.) Constraints on the shapes of tensors in an element.
        A list of shape tuples or None. This list is the same length
        as dtypes. If the shape of any tensors in the element are constrained,
        all must be; shapes can be None if the shapes should not be
        constrained.
      names: (Optional.) If provided, the `get()` and
        `put()` methods will use dictionaries with these names as keys.
        Must be None or a list or tuple of the same length as `dtypes`.
      shared_name: (Optional.) A name to be used for the shared object. By
        passing the same name to two different python objects they will share
        the underlying staging area. Must be a string.
      capacity: (Optional.) Maximum number of elements.
        An integer. If zero, the Staging Area is unbounded.
      memory_limit: (Optional.) Maximum number of bytes of all tensors
        in the Staging Area.
        An integer. If zero, the Staging Area is unbounded.

    Raises:
      ValueError: If one of the arguments is invalid.
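
    Example:
      An illustrative graph-mode sketch of the warmup pattern described in
      `get()`; the input tensors `features` and `labels` and the step count
      are placeholders, not part of this API:

        area = StagingArea(dtypes=[tf.float32, tf.int32])
        stage_op = area.put([features, labels])
        staged_features, staged_labels = area.get()

        with tf.compat.v1.Session() as sess:
          sess.run(stage_op)  # Warmup: stage the first element.
          for _ in range(num_steps):
            # Consume the staged element and stage the next one in one step.
            sess.run([staged_features, staged_labels, stage_op])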
    """

    super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
                                      capacity, memory_limit)

  def put(self, values, name=None):
    """Create an op that places a value into the staging area.

    This operation will block if the `StagingArea` has reached
    its capacity.

    Args:
      values: A single tensor, a list or tuple of tensors, or a dictionary with
        tensor values. The number of elements must match the length of the
        list provided to the dtypes argument when creating the StagingArea.
      name: A name for the operation (optional).

    Returns:
      The created op.

    Raises:
      ValueError: If the number or type of inputs don't match the staging area.
    """
    with ops.name_scope(name, "%s_put" % self._name,
                        self._scope_vals(values)) as scope:

      if not isinstance(values, (list, tuple, dict)):
        values = [values]

      # Hard-code indices for this staging area.
      indices = list(range(len(values)))
      vals, _ = self._check_put_dtypes(values, indices)

      with ops.colocate_with(self._coloc_op):
        op = gen_data_flow_ops.stage(
            values=vals,
            shared_name=self._name,
            name=scope,
            capacity=self._capacity,
            memory_limit=self._memory_limit)

      return op

  def __internal_get(self, get_fn, name):
    with ops.colocate_with(self._coloc_op):
      ret = get_fn()

    indices = list(range(len(self._dtypes)))  # Hard coded
    return self._get_return_value(ret, indices)

  def get(self, name=None):
    """Gets one element from this staging area.

    If the staging area is empty when this operation executes, it will block
    until there is an element to dequeue.

    Note that unlike other ops that can block, like the queue Dequeue
    operations, this can stop other work from happening. To avoid this, the
    intended use is for this to be called only when there will be an element
    already available. One method for doing this in a training loop would be to
    run a `put()` call during a warmup session.run call, and then call both
    `get()` and `put()` in each subsequent step.

    The placement of the returned tensor will be determined by the current
    device scope when this function is called.

    Args:
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was gotten.
    """
    if name is None:
      name = "%s_get" % self._name

    # pylint: disable=bad-continuation
    fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
        shared_name=self._name, name=name,
        capacity=self._capacity,
        memory_limit=self._memory_limit)
    # pylint: enable=bad-continuation

    return self.__internal_get(fn, name)

  def peek(self, index, name=None):
    """Peeks at an element in the staging area.

    If the staging area is too small to contain the element at
    the specified index, it will block until enough elements
    are inserted to complete the operation.

    The placement of the returned tensor will be determined by
    the current device scope when this function is called.

    Args:
      index: The index of the tensor within the staging area
        to look up.
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was gotten.
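
    Example:
      An illustrative sketch; `area` is assumed to be a `StagingArea` that
      already holds an element at index 0:

        first_element = area.peek(0)  # Look at index 0 without removing it.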
    """
    if name is None:
      name = "%s_peek" % self._name

    # pylint: disable=bad-continuation
    fn = lambda: gen_data_flow_ops.stage_peek(index,
        dtypes=self._dtypes, shared_name=self._name,
        name=name, capacity=self._capacity,
        memory_limit=self._memory_limit)
    # pylint: enable=bad-continuation

    return self.__internal_get(fn, name)

  def size(self, name=None):
    """Returns the number of elements in the staging area.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_size" % self._name

    return gen_data_flow_ops.stage_size(
        name=name,
        shared_name=self._name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)

  def clear(self, name=None):
    """Clears the staging area.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_clear" % self._name

    return gen_data_flow_ops.stage_clear(
        name=name,
        shared_name=self._name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)


class MapStagingArea(BaseStagingArea):
  """A `MapStagingArea` is a TensorFlow data structure that stores tensors
  across multiple steps, and exposes operations that can put and get tensors.

  Each `MapStagingArea` element is a (key, value) pair.
  Only int64 keys are supported; other types should be
  hashed to produce a key.
  Values are a tuple of one or more tensors.
  Each tuple component has a static dtype,
  and may have a static shape.

  The capacity of a `MapStagingArea` may be bounded or unbounded.
  It supports multiple concurrent producers and consumers, and
  provides exactly-once delivery.

  Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
  whose dtypes are described by `dtypes`, and whose shapes are optionally
  described by the `shapes` argument.

  If the `shapes` argument is specified, each component of a staging area
  element must have the respective fixed shape. If it is
  unspecified, different elements may have different shapes.

  It behaves like an associative container with support for:

  - put(key, values)
  - peek(key), like dict.get(key)
  - get(key), like dict.pop(key)
  - get(key=None), like dict.popitem()
  - size()
  - clear()

  If `ordered` is True, a tree structure ordered by key will be used and
  get(key=None) will remove (key, value) pairs in increasing key order.
  Otherwise, an unordered hashtable is used.

  It can be configured with a capacity, in which case
  put(key, values) will block until space becomes available.

  Similarly, it can be configured with a memory limit, which
  will block put(key, values) until space is available.
  This is mostly useful for limiting the number of tensors on
  devices such as GPUs.

  All get() and peek() commands block if the requested
  (key, value) pair is not present in the staging area.

  Partial puts are supported and will be placed in an incomplete
  map until such time as all values associated with the key have
  been inserted. Once completed, this (key, value) pair will be
  inserted into the map. Data in the incomplete map
  counts towards the memory limit, but not towards the capacity limit.

  Partial gets from the map are also supported.
  This removes the partially requested tensors from the entry,
  but the entry is only removed from the map once all tensors
  associated with it are removed.
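
  Example:
    An illustrative graph-mode sketch; the int64 key value, the tensor
    `image`, and the session are placeholders, not part of this API:

      area = MapStagingArea(dtypes=[tf.float32], ordered=True)
      key = tf.constant(0, dtype=tf.int64)
      put_op = area.put(key, [image], indices=[0])
      got_key, got_vals = area.get(key)

      with tf.compat.v1.Session() as sess:
        sess.run(put_op)
        key_out, vals_out = sess.run([got_key, got_vals])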
  """

  def __init__(self,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               ordered=False,
               capacity=0,
               memory_limit=0):
    """Constructs a map staging area object.

    Args:
      dtypes: A list of types. The length of dtypes must equal the number
        of tensors in each element.
      capacity: (Optional.) Maximum number of elements.
        An integer. If zero, the Staging Area is unbounded.
      memory_limit: (Optional.) Maximum number of bytes of all tensors
        in the Staging Area (excluding keys).
        An integer. If zero, the Staging Area is unbounded.
      ordered: (Optional.) If True the underlying data structure
        is a tree ordered on key. Otherwise assume a hashtable.
      shapes: (Optional.) Constraints on the shapes of tensors in an element.
        A list of shape tuples or None. This list is the same length
        as dtypes. If the shape of any tensors in the element are constrained,
        all must be; shapes can be None if the shapes should not be
        constrained.
      names: (Optional.) If provided, the `get()` and
        `put()` methods will use dictionaries with these names as keys.
        Must be None or a list or tuple of the same length as `dtypes`.
      shared_name: (Optional.) A name to be used for the shared object. By
        passing the same name to two different python objects they will share
        the underlying staging area. Must be a string.

    Raises:
      ValueError: If one of the arguments is invalid.
    """

    super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
                                         capacity, memory_limit)

    # Defer to different kernels depending on whether the map is ordered.
    self._ordered = ordered

    if ordered:
      self._put_fn = gen_data_flow_ops.ordered_map_stage
      self._pop_fn = gen_data_flow_ops.ordered_map_unstage
      self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
      self._peek_fn = gen_data_flow_ops.ordered_map_peek
      self._size_fn = gen_data_flow_ops.ordered_map_size
      self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
      self._clear_fn = gen_data_flow_ops.ordered_map_clear
    else:
      self._put_fn = gen_data_flow_ops.map_stage
      self._pop_fn = gen_data_flow_ops.map_unstage
      self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
      self._peek_fn = gen_data_flow_ops.map_peek
      self._size_fn = gen_data_flow_ops.map_size
      self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
      self._clear_fn = gen_data_flow_ops.map_clear

  def put(self, key, vals, indices=None, name=None):
    """Create an op that stores the (key, vals) pair in the staging area.

    Incomplete puts are possible, preferably using a dictionary for `vals`,
    as the appropriate dtypes and shapes can be inferred from the dictionary
    keys. If `vals` is a list or tuple, `indices` must also be specified so
    that the op knows at which element position to perform the insert.

    This operation will block if the capacity or memory limit of this
    container is reached.

    Args:
      key: Key associated with the data.
      vals: Tensor (or a dict/tuple of Tensors) to place
        into the staging area.
      indices: (Optional.) If `vals` is a tuple/list, this is required.
      name: A name for the operation (optional).

    Returns:
      The created op.

    Raises:
      ValueError: If the number or type of inputs don't match the staging
        area.
    """

    with ops.name_scope(name, "%s_put" % self._name,
                        self._scope_vals(vals)) as scope:

      vals, indices = self._check_put_dtypes(vals, indices)

      with ops.colocate_with(self._coloc_op):
        op = self._put_fn(
            key,
            indices,
            vals,
            dtypes=self._dtypes,
            shared_name=self._name,
            name=scope,
            capacity=self._capacity,
            memory_limit=self._memory_limit)
    return op

  def _get_indices_and_dtypes(self, indices=None):
    if indices is None:
      indices = list(range(len(self._dtypes)))

    if not isinstance(indices, (tuple, list)):
      raise TypeError(f"Invalid indices type {type(indices)}")

    if len(indices) == 0:
      raise ValueError("Empty indices")

    if all(isinstance(i, str) for i in indices):
      if self._names is None:
        raise ValueError(f"String indices provided {indices}, but "
                         "this Staging Area was not created with names.")

      try:
        indices = [self._names.index(n) for n in indices]
      except ValueError:
        raise ValueError(f"Named index not in "
                         f"Staging Area names {self._names}")
    elif all(isinstance(i, int) for i in indices):
      pass
    else:
      raise TypeError(f"Mixed types in indices {indices}. "
                      "May only be str or int")

    dtypes = [self._dtypes[i] for i in indices]

    return indices, dtypes

  def peek(self, key, indices=None, name=None):
    """Peeks at staging area data associated with the key.

    If the key is not in the staging area, it will block
    until the associated (key, value) is inserted.

    Args:
      key: Key associated with the required data.
      indices: Partial list of tensors to retrieve (optional).
        A list of integer or string indices.
        String indices are only valid if the Staging Area
        has names associated with it.
      name: A name for the operation (optional).

    Returns:
      The created op.
    """

    if name is None:
      name = "%s_peek" % self._name

    indices, dtypes = self._get_indices_and_dtypes(indices)

    with ops.colocate_with(self._coloc_op):
      result = self._peek_fn(
          key,
          shared_name=self._name,
          indices=indices,
          dtypes=dtypes,
          name=name,
          capacity=self._capacity,
          memory_limit=self._memory_limit)

    return self._get_return_value(result, indices)

  def get(self, key=None, indices=None, name=None):
    """Removes and returns a (key, value) pair from the staging area.

    If a key is provided, the associated (key, value) pair is returned.
    If the key is not in the staging area, this method will block until
    the associated (key, value) is inserted.

    If no key is provided and the staging area is ordered,
    the (key, value) with the smallest key will be returned.
    Otherwise, a random (key, value) will be returned.

    If the staging area is empty when this operation executes,
    it will block until there is an element to dequeue.

    Args:
      key: Key associated with the required data (optional).
      indices: Partial list of tensors to retrieve (optional).
        A list of integer or string indices.
        String indices are only valid if the Staging Area
        has names associated with it.
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if key is None:
      return self._popitem(indices=indices, name=name)
    else:
      return self._pop(key, indices=indices, name=name)

  def _pop(self, key, indices=None, name=None):
    """Removes and returns the (key, value) pair associated with the key.

    If the key is not in the staging area, this method will block until
    the associated (key, value) is inserted.

    Args:
      key: Key associated with the required data.
      indices: Partial list of tensors to retrieve (optional).
        A list of integer or string indices.
        String indices are only valid if the Staging Area
        has names associated with it.
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_get" % self._name

    indices, dtypes = self._get_indices_and_dtypes(indices)

    with ops.colocate_with(self._coloc_op):
      result = self._pop_fn(
          key,
          shared_name=self._name,
          indices=indices,
          dtypes=dtypes,
          name=name,
          capacity=self._capacity,
          memory_limit=self._memory_limit)

    return key, self._get_return_value(result, indices)

  def _popitem(self, indices=None, name=None):
    """Removes and returns a (key, value) pair from the staging area.

    If the staging area is ordered, the (key, value) pair with the smallest
    key will be returned. Otherwise, a random (key, value) pair will be
    returned.
    If the staging area is empty when this operation executes,
    it will block until there is an element to dequeue.

    Args:
      indices: Partial list of tensors to retrieve (optional).
        A list of integer or string indices.
        String indices are only valid if the Staging Area
        has names associated with it.
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_get_nokey" % self._name

    indices, dtypes = self._get_indices_and_dtypes(indices)

    with ops.colocate_with(self._coloc_op):
      key, result = self._popitem_fn(
          shared_name=self._name,
          indices=indices,
          dtypes=dtypes,
          name=name,
          capacity=self._capacity,
          memory_limit=self._memory_limit)

    # Separate keys and results out from
    # the underlying namedtuple.
    key = self._create_device_transfers(key)[0]
    result = self._get_return_value(result, indices)

    return key, result

  def size(self, name=None):
    """Returns the number of elements in the staging area.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_size" % self._name

    return self._size_fn(
        shared_name=self._name,
        name=name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)

  def incomplete_size(self, name=None):
    """Returns the number of incomplete elements in the staging area.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_incomplete_size" % self._name

    return self._incomplete_size_fn(
        shared_name=self._name,
        name=name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)

  def clear(self, name=None):
    """Clears the staging area.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created op.
    """
    if name is None:
      name = "%s_clear" % self._name

    return self._clear_fn(
        shared_name=self._name,
        name=name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)


class RecordInput:
  """RecordInput asynchronously reads and randomly yields TFRecords.

  A RecordInput Op will continuously read a batch of records asynchronously
  into a buffer of some fixed capacity. It can also asynchronously yield
  random records from this buffer.

  It will not start yielding until at least `buffer_size / 2` elements have
  been placed into the buffer so that sufficient randomization can take place.

  The order in which files are read is shifted each epoch by `shift_ratio` so
  that the data is presented in a different order every epoch.
  """

  def __init__(self,
               file_pattern,
               batch_size=1,
               buffer_size=1,
               parallelism=1,
               shift_ratio=0,
               seed=0,
               name=None,
               batches=None,
               compression_type=None):
    """Constructs a RecordInput Op.

    Args:
      file_pattern: File path to the dataset, possibly containing wildcards.
        All matching files will be iterated over each epoch.
      batch_size: How many records to return at a time.
      buffer_size: The maximum number of records the buffer will contain.
      parallelism: How many reader threads to use for reading from files.
      shift_ratio: What percentage of the total number of files to move the
        start file forward by each epoch.
      seed: Specify the random number seed used by the generator that
        randomizes records.
      name: Optional name for the operation.
      batches: None by default, creating a single batch op. Otherwise specifies
        how many batches to create, which are returned as a list when
        `get_yield_op()` is called. An example use case is to split processing
        between devices on one computer.
      compression_type: The type of compression for the file. Currently ZLIB
        and GZIP are supported. Defaults to no compression.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    self._batch_size = batch_size
    if batches is not None:
      self._batch_size *= batches
    self._batches = batches
    self._file_pattern = file_pattern
    self._buffer_size = buffer_size
    self._parallelism = parallelism
    self._shift_ratio = shift_ratio
    self._seed = seed
    self._name = name
    self._compression_type = python_io.TFRecordCompressionType.NONE
    if compression_type is not None:
      self._compression_type = compression_type

  def get_yield_op(self):
    """Adds a node that yields a group of records every time it is executed.

    If the `batches` parameter of this `RecordInput` is not None, it yields a
    list of record batches with the specified `batch_size`.
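
    Example:
      An illustrative sketch; the file pattern is a placeholder:

        record_input = RecordInput(
            file_pattern="/path/to/train-*.tfrecord",
            batch_size=32,
            buffer_size=10000,
            parallelism=4)
        records = record_input.get_yield_op()  # Yields 32 records per run.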
    """
    compression_type = python_io.TFRecordOptions.get_compression_type_string(
        python_io.TFRecordOptions(self._compression_type))
    records = gen_data_flow_ops.record_input(
        file_pattern=self._file_pattern,
        file_buffer_size=self._buffer_size,
        file_parallelism=self._parallelism,
        file_shuffle_shift_ratio=self._shift_ratio,
        batch_size=self._batch_size,
        file_random_seed=self._seed,
        compression_type=compression_type,
        name=self._name)
    if self._batches is None:
      return records
    else:
      with ops.name_scope(self._name):
        batch_list = [[] for _ in range(self._batches)]
        records = array_ops.split(records, self._batch_size, 0)
        for index, protobuf in enumerate(records):
          batch_index = index % self._batches
          batch_list[batch_index].append(array_ops.reshape(protobuf, []))
        return batch_list