1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TensorArray: a dynamically sized array of Tensors.""" 16# Mixture of pep8 and non-pep8 names, so disable pylint bad-name 17# pylint: disable=g-bad-name 18import contextlib 19 20import traceback 21import weakref 22 23import numpy as np 24 25from tensorflow.python.eager import context 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors_impl 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import tensor_spec 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.framework import type_spec 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import control_flow_util 36from tensorflow.python.ops import gen_control_flow_ops 37from tensorflow.python.ops import gen_data_flow_ops 38from tensorflow.python.ops import list_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.util import tf_should_use 42from tensorflow.python.util.tf_export import tf_export 43 44 45# _GraphTensorArray accesses many of the hidden generated ops, but is in 46# fact built to wrap these methods. 47# pylint: disable=protected-access 48class _GraphTensorArray: 49 """Graph-mode implementation of TensorArray.""" 50 51 def __init__(self, 52 dtype, 53 size=None, 54 dynamic_size=None, 55 clear_after_read=None, 56 tensor_array_name=None, 57 handle=None, 58 flow=None, 59 infer_shape=True, 60 element_shape=None, 61 colocate_with_first_write_call=True, 62 name=None): 63 """Constructs a graph mode TensorArray. 64 65 Args: 66 dtype: (required) data type of the TensorArray. 67 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 68 Required if handle is not provided. 69 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 70 can grow the TensorArray past its initial size. Default: False. 71 clear_after_read: Boolean (optional, default: True). If True, clear 72 TensorArray values after reading them. This disables read-many 73 semantics, but allows early release of memory. 74 tensor_array_name: (optional) Python string: the name of the TensorArray. 75 This is used when creating the TensorArray handle. If this value is 76 set, handle should be None. 77 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 78 is set, tensor_array_name should be None. Only supported in graph mode. 79 flow: (optional) A float `Tensor` scalar coming from an existing 80 `TensorArray.flow`. Only supported in graph mode. 81 infer_shape: (optional, default: True) If True, shape inference is 82 enabled. In this case, all elements must have the same shape. 83 element_shape: (optional, default: None) A `TensorShape` object specifying 84 the shape constraints of each of the elements of the TensorArray. Need 85 not be fully defined. 86 colocate_with_first_write_call: If `True`, the TensorArray will be 87 colocated on the same device as the Tensor used on its first write 88 (write operations include `write`, `unstack`, and `split`). If `False`, 89 the TensorArray will be placed on the device determined by the device 90 context available during its initialization. 91 name: A name for the operation (optional). 92 93 Raises: 94 ValueError: if both handle and tensor_array_name are provided. 95 TypeError: if handle is provided but is not a Tensor. 96 """ 97 if handle is not None and tensor_array_name: 98 raise ValueError( 99 "Cannot provide both `handle` and `tensor_array_name` arguments at " 100 "the same time.") 101 if handle is not None and not isinstance(handle, ops.Tensor): 102 raise TypeError( 103 f"Expected `handle` to be a Tensor, but got `{handle}` of type " 104 f"`{type(handle)}` instead.") 105 if handle is None and size is None: 106 raise ValueError( 107 "Argument `size` must be provided if handle is not provided.") 108 if handle is not None and size is not None: 109 raise ValueError("Cannot provide both a `handle` and `size` arguments " 110 "at the same time.") 111 if handle is not None and element_shape is not None: 112 raise ValueError( 113 "Cannot provide both `handle` and `element_shape` arguments " 114 "at the same time.") 115 if handle is not None and dynamic_size is not None: 116 raise ValueError( 117 "Cannot provide both `handle` and `dynamic_size` arguments " 118 "at the same time.") 119 if handle is not None and clear_after_read is not None: 120 raise ValueError( 121 "Cannot provide both `handle` and `clear_after_read` arguments " 122 "at the same time.") 123 124 if clear_after_read is None: 125 clear_after_read = True 126 self._dynamic_size = dynamic_size or False 127 self._dtype = dtypes.as_dtype(dtype).base_dtype 128 129 # Used to keep track of what tensors the TensorArray should be 130 # colocated with. We choose to colocate the TensorArray with the 131 # first tensor written to it. 132 self._colocate_with_first_write_call = colocate_with_first_write_call 133 if colocate_with_first_write_call: 134 self._colocate_with = [] 135 else: 136 self._colocate_with = None 137 138 # Record the current static shape for the array elements. The element 139 # shape is defined either by `element_shape` or the shape of the tensor 140 # of the first write. If `infer_shape` is true, all writes checks for 141 # shape equality. 142 self._element_shape = [tensor_shape.as_shape(element_shape)] 143 self._infer_shape = infer_shape 144 self._size = size 145 with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: 146 if handle is not None: 147 self._handle = handle 148 if flow is None: 149 raise ValueError("flow must not be None if handle is not None.") 150 self._flow = flow 151 else: 152 # Construct the TensorArray with an empty device. The first 153 # write into the TensorArray from a Tensor with a set device 154 # will retroactively set the device value of this op. 155 def create(): 156 """Create the TensorArray op.""" 157 return gen_data_flow_ops.tensor_array_v3( 158 dtype=dtype, 159 size=size, 160 element_shape=element_shape, 161 identical_element_shapes=infer_shape, 162 dynamic_size=self._dynamic_size, 163 clear_after_read=clear_after_read, 164 tensor_array_name=tensor_array_name, 165 name=scope) 166 167 if colocate_with_first_write_call: 168 with ops.device(None), ops.colocate_with(None, ignore_existing=True): 169 self._handle, self._flow = create() 170 else: 171 self._handle, self._flow = create() 172 173 @property 174 def flow(self): 175 return self._flow 176 177 @property 178 def dtype(self): 179 return self._dtype 180 181 @property 182 def handle(self): 183 return self._handle 184 185 @property 186 def element_shape(self): 187 return self._element_shape[0] 188 189 def _check_element_shape(self, shape): 190 """Changes the element shape of the array given a shape to merge with. 191 192 Args: 193 shape: A `TensorShape` object to merge with. 194 195 Raises: 196 ValueError: if the provided shape is incompatible with the current 197 element shape of the `TensorArray`. 198 """ 199 if not shape.is_compatible_with(self.element_shape): 200 raise ValueError("Inconsistent shapes: saw %s but expected %s " % 201 (shape, self.element_shape)) 202 if self._infer_shape: 203 self._element_shape[0] = self.element_shape.merge_with(shape) 204 205 @contextlib.contextmanager 206 def _maybe_colocate_with(self, value): 207 """Colocate operations with an internal colocation group or `value`. 208 209 Args: 210 value: `Tensor`, the tensor to try to colocate with. 211 212 Yields: 213 Does not yield anything, but the new context is a colocation context. 214 215 If no internal colocation group is set, colocate with `value` and set 216 the internal colocation group to be value. 217 """ 218 if not self._colocate_with_first_write_call: 219 yield 220 else: 221 if not self._colocate_with: 222 self._colocate_with.append(value) 223 with ops.colocate_with(self._colocate_with[0]): 224 yield 225 226 def identity(self): 227 """See TensorArray.""" 228 flow = array_ops.identity(self._flow) 229 return build_ta_with_new_flow(self, flow) 230 231 def grad(self, source, flow=None, name=None): 232 """See TensorArray.""" 233 # tensor_array_grad requires a flow input when forward 234 # TensorArrays are dynamically sized. This forces the creation 235 # of the grad TensorArray only once the final forward array's size 236 # is fixed. 237 if flow is None: 238 flow = self.flow 239 with ops.name_scope(name, "TensorArrayGrad", [self._handle]): 240 with ops.colocate_with(self._handle): 241 g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3( 242 handle=self._handle, source=source, flow_in=flow, name=name) 243 with ops.control_dependencies([g_handle]): 244 flow = array_ops.identity(flow, name="gradient_flow") 245 g = TensorArray( 246 dtype=self._dtype, 247 handle=g_handle, 248 flow=flow, 249 infer_shape=self._infer_shape, 250 colocate_with_first_write_call=False) 251 # pylint: disable=protected-access 252 g._implementation._element_shape = self._element_shape 253 # pylint: enable=protected-access 254 return g 255 256 def read(self, index, name=None): 257 """See TensorArray.""" 258 value = gen_data_flow_ops.tensor_array_read_v3( 259 handle=self._handle, 260 index=index, 261 flow_in=self._flow, 262 dtype=self._dtype, 263 name=name) 264 if self._element_shape: 265 value.set_shape(self._element_shape[0].dims) 266 return value 267 268 def write(self, index, value, name=None): 269 """See TensorArray.""" 270 with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]): 271 # TODO(b/129870929): Fix after all callers provide proper init dtype. 272 value = ops.convert_to_tensor( 273 value, preferred_dtype=self._dtype, name="value") 274 _check_dtypes(value, self._dtype) 275 self._check_element_shape(value.shape) 276 with self._maybe_colocate_with(value): 277 flow_out = gen_data_flow_ops.tensor_array_write_v3( 278 handle=self._handle, 279 index=index, 280 value=value, 281 flow_in=self._flow, 282 name=name) 283 return build_ta_with_new_flow(self, flow_out) 284 285 def stack(self, name=None): 286 """See TensorArray.""" 287 with ops.colocate_with(self._handle): 288 with ops.name_scope(name, "TensorArrayStack", [self._handle]): 289 value = self.gather(math_ops.range(0, self.size()), name=name) 290 if (self.element_shape and not self._dynamic_size and 291 self._size is not None): 292 value.set_shape([tensor_util.constant_value(self._size)] + 293 self.element_shape.dims) 294 return value 295 296 def gather(self, indices, name=None): 297 """See TensorArray.""" 298 if self._element_shape: 299 element_shape = self._element_shape[0] 300 else: 301 element_shape = tensor_shape.unknown_shape(None) 302 value = gen_data_flow_ops.tensor_array_gather_v3( 303 handle=self._handle, 304 indices=indices, 305 flow_in=self._flow, 306 dtype=self._dtype, 307 name=name, 308 element_shape=element_shape) 309 if self.element_shape: 310 value.set_shape([None] + self.element_shape.dims) 311 return value 312 313 def concat(self, name=None): 314 """See TensorArray.""" 315 value, _ = gen_data_flow_ops.tensor_array_concat_v3( 316 handle=self._handle, 317 flow_in=self._flow, 318 dtype=self._dtype, 319 name=name, 320 element_shape_except0=self.element_shape[1:]) 321 if self.element_shape: 322 value.set_shape([None] + self.element_shape.dims[1:]) 323 return value 324 325 @tf_should_use.should_use_result 326 def unstack(self, value, name=None): 327 """See TensorArray.""" 328 with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]): 329 num_elements = array_ops.shape(value)[0] 330 return self.scatter( 331 indices=math_ops.range(0, num_elements), value=value, name=name) 332 333 @tf_should_use.should_use_result 334 def scatter(self, indices, value, name=None): 335 """See TensorArray.""" 336 with ops.name_scope(name, "TensorArrayScatter", 337 [self._handle, value, indices]): 338 # TODO(b/129870929): Fix after all callers provide proper init dtype. 339 value = ops.convert_to_tensor( 340 value, preferred_dtype=self._dtype, name="value") 341 _check_dtypes(value, self._dtype) 342 if not context.executing_eagerly(): 343 self._check_element_shape(value.shape[1:]) 344 with self._maybe_colocate_with(value): 345 flow_out = gen_data_flow_ops.tensor_array_scatter_v3( 346 handle=self._handle, 347 indices=indices, 348 value=value, 349 flow_in=self._flow, 350 name=name) 351 return build_ta_with_new_flow(self, flow_out) 352 353 @tf_should_use.should_use_result 354 def split(self, value, lengths, name=None): 355 """See TensorArray.""" 356 with ops.name_scope(name, "TensorArraySplit", 357 [self._handle, value, lengths]): 358 value = ops.convert_to_tensor(value, dtype=self._dtype, name="value") 359 with self._maybe_colocate_with(value): 360 lengths_64 = math_ops.cast(lengths, dtypes.int64) 361 if not context.executing_eagerly(): 362 clengths = tensor_util.constant_value(lengths_64) 363 if value.shape.dims is not None and clengths is not None: 364 if clengths.shape and clengths.max() == clengths.min(): 365 self._check_element_shape( 366 tensor_shape.TensorShape([clengths[0] 367 ]).concatenate(value.shape[1:])) 368 flow_out = gen_data_flow_ops.tensor_array_split_v3( 369 handle=self._handle, 370 value=value, 371 lengths=lengths_64, 372 flow_in=self._flow, 373 name=name) 374 return build_ta_with_new_flow(self, flow_out) 375 376 def size(self, name=None): 377 """See TensorArray.""" 378 if not self._dynamic_size and self._size is not None: 379 return ops.convert_to_tensor(self._size, dtype=dtypes.int32) 380 else: 381 return gen_data_flow_ops.tensor_array_size_v3( 382 handle=self._handle, flow_in=self.flow, name=name) 383 384 @tf_should_use.should_use_result 385 def close(self, name=None): 386 """See TensorArray.""" 387 return gen_data_flow_ops.tensor_array_close_v3( 388 handle=self._handle, name=name) 389 390 391class _GraphTensorArrayV2: 392 """Graph-mode implementation of TensorArray backed by TensorLists. 393 394 The backing tensor of this TensorArray is a TensorList variant tensor which is 395 stored in the `flow`. The `handle` is always none here. The reason we use the 396 `flow` field and not the `handle` field is to ensure backwards compatibility 397 with legacy control flow. 398 """ 399 400 def __init__(self, 401 dtype, 402 size=None, 403 dynamic_size=None, 404 clear_after_read=None, 405 tensor_array_name=None, 406 handle=None, 407 flow=None, 408 infer_shape=True, 409 element_shape=None, 410 colocate_with_first_write_call=True, 411 name=None): 412 """Constructs a graph mode TensorArray. 413 414 Args: 415 dtype: (required) data type of the TensorArray. 416 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 417 Required if flow is not provided. 418 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 419 can grow the TensorArray past its initial size. Default: False. 420 clear_after_read: (optional) unused. Not supported in TensorLists. 421 tensor_array_name: (optional) unused. 422 handle: (optional) Must always be None. 423 flow: (optional) A variant `Tensor` scalar for a TensorList. 424 infer_shape: (optional, default: True) If True, shape inference is 425 enabled. In this case, all elements must have the same shape. 426 element_shape: (optional, default: None) A `TensorShape` object specifying 427 the shape constraints of each of the elements of the TensorArray. Need 428 not be fully defined. 429 colocate_with_first_write_call: (optional). unused. 430 name: (optional) A name for the operation. 431 432 Raises: 433 ValueError: if both handle and tensor_array_name are provided. 434 TypeError: if handle is provided but is not a Tensor. 435 """ 436 assert handle is None 437 del handle 438 del clear_after_read 439 del tensor_array_name 440 del colocate_with_first_write_call 441 442 self._dynamic_size = dynamic_size 443 self._size = size 444 445 if (flow is not None and 446 (not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)): 447 raise TypeError( 448 f"Expected `flow` to be a variant tensor, but received `{flow.dtype}` " 449 f"instead.") 450 if flow is None and size is None: 451 raise ValueError("Argument `size` must be provided if argument `flow` " 452 "is not provided.") 453 if flow is not None and size is not None: 454 raise ValueError("Cannot provide both `flow` and `size` arguments " 455 "at the same time.") 456 if flow is not None and element_shape is not None: 457 raise ValueError( 458 "Cannot provide both `flow` and `element_shape` arguments" 459 "at the same time.") 460 461 self._dtype = dtypes.as_dtype(dtype).base_dtype 462 463 # Record the current static shape for the array elements. The element 464 # shape is defined either by `element_shape` or the shape of the tensor 465 # of the first write. If `infer_shape` is true, all writes checks for 466 # shape equality. 467 self._element_shape = [tensor_shape.as_shape(element_shape)] 468 self._infer_shape = infer_shape 469 with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope: 470 if flow is None: 471 self._flow = list_ops.tensor_list_reserve( 472 element_shape=element_shape, 473 num_elements=size, 474 element_dtype=dtype, 475 name=scope) 476 else: 477 self._flow = flow 478 479 # For backwards compatibility. 480 self._colocate_with_first_write_call = None 481 self._colocate_with = None 482 483 @property 484 def flow(self): 485 return self._flow 486 487 @property 488 def dtype(self): 489 return self._dtype 490 491 @property 492 def element_shape(self): 493 return self._element_shape[0] 494 495 @property 496 def handle(self): 497 # We intentionally do not raise an error so that legacy while_loop does not 498 # complain. 499 return None 500 501 def _check_element_shape(self, shape): 502 """Changes the element shape of the array given a shape to merge with. 503 504 Args: 505 shape: A `TensorShape` object to merge with. 506 507 Raises: 508 ValueError: if the provided shape is incompatible with the current 509 element shape of the `TensorArray`. 510 """ 511 if not shape.is_compatible_with(self.element_shape): 512 raise ValueError("Inconsistent shapes: saw %s but expected %s " % 513 (shape, self.element_shape)) 514 if self._infer_shape: 515 self._element_shape[0] = self.element_shape.merge_with(shape) 516 517 def identity(self): 518 """See TensorArray.""" 519 flow = array_ops.identity(self._flow) 520 return build_ta_with_new_flow(self, flow) 521 522 def grad(self, source, flow=None, name=None): 523 """Not supported.""" 524 raise NotImplementedError() 525 526 def read(self, index, name=None): 527 """See TensorArray.""" 528 with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]): 529 value = list_ops.tensor_list_get_item( 530 input_handle=self._flow, 531 index=index, 532 element_dtype=self._dtype, 533 element_shape=self.element_shape, 534 name=name) 535 return value 536 537 def write(self, index, value, name=None): 538 """See TensorArray.""" 539 with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]): 540 # TODO(b/129870929): Fix after all callers provide proper init dtype. 541 value = ops.convert_to_tensor( 542 value, preferred_dtype=self._dtype, name="value") 543 _check_dtypes(value, self._dtype) 544 self._check_element_shape(value.shape) 545 flow_out = list_ops.tensor_list_set_item( 546 input_handle=self._flow, 547 index=index, 548 item=value, 549 resize_if_index_out_of_bounds=self._dynamic_size, 550 name=name) 551 return build_ta_with_new_flow(self, flow_out) 552 553 def stack(self, name=None): 554 """See TensorArray.""" 555 with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]): 556 # TODO(b/139941163): remove constant_value after changing num_elements to regular input 557 if not self._dynamic_size and self._size is not None: 558 ta_size = tensor_util.constant_value(self._size) 559 else: 560 ta_size = -1 561 value = list_ops.tensor_list_stack( 562 input_handle=self._flow, 563 element_dtype=self._dtype, 564 num_elements=ta_size, 565 element_shape=self.element_shape) 566 return value 567 568 def gather(self, indices, name=None): 569 """See TensorArray.""" 570 value = list_ops.tensor_list_gather( 571 input_handle=self._flow, 572 indices=indices, 573 element_dtype=self._dtype, 574 element_shape=self.element_shape, 575 name=name) 576 return value 577 578 def concat(self, name=None): 579 """See TensorArray.""" 580 if self.element_shape: 581 element_shape = [None] + self.element_shape.dims[1:] 582 else: 583 element_shape = None 584 585 value = list_ops.tensor_list_concat( 586 input_handle=self._flow, 587 element_dtype=self._dtype, 588 element_shape=element_shape, 589 name=name) 590 return value 591 592 @tf_should_use.should_use_result 593 def unstack(self, value, name=None): 594 """See TensorArray.""" 595 with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]): 596 # TODO(b/129870929): Fix after all callers provide proper init dtype. 597 value = ops.convert_to_tensor( 598 value, preferred_dtype=self._dtype, name="value") 599 _check_dtypes(value, self._dtype) 600 self._check_element_shape(value.shape[1:]) 601 flow_out = list_ops.tensor_list_from_tensor( 602 tensor=value, element_shape=value.shape[1:]) 603 return build_ta_with_new_flow(self, flow_out) 604 605 @tf_should_use.should_use_result 606 def scatter(self, indices, value, name=None): 607 """See TensorArray.""" 608 with ops.name_scope(name, "TensorArrayScatter", 609 [self._flow, value, indices]): 610 # TODO(b/129870929): Fix after all callers provide proper init dtype. 611 value = ops.convert_to_tensor( 612 value, preferred_dtype=self._dtype, name="value") 613 _check_dtypes(value, self._dtype) 614 self._check_element_shape(value.shape[1:]) 615 flow_out = list_ops.tensor_list_scatter( 616 tensor=value, 617 indices=indices, 618 element_shape=self.element_shape, 619 input_handle=self._flow) 620 return build_ta_with_new_flow(self, flow_out) 621 622 @tf_should_use.should_use_result 623 def split(self, value, lengths, name=None): 624 """See TensorArray.""" 625 with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]): 626 # TODO(b/129870929): Fix after all callers provide proper init dtype. 627 value = ops.convert_to_tensor( 628 value, preferred_dtype=self._dtype, name="value") 629 _check_dtypes(value, self._dtype) 630 lengths_64 = math_ops.cast(lengths, dtypes.int64) 631 if not context.executing_eagerly(): 632 clengths = tensor_util.constant_value(lengths_64) 633 if value.shape.dims is not None and clengths is not None: 634 if clengths.shape and clengths.max() == clengths.min(): 635 self._check_element_shape( 636 tensor_shape.TensorShape([clengths[0] 637 ]).concatenate(value.shape[1:])) 638 flow_out = list_ops.tensor_list_split( 639 tensor=value, 640 lengths=lengths_64, 641 element_shape=self.element_shape, 642 name=name) 643 return build_ta_with_new_flow(self, flow_out) 644 645 def size(self, name=None): 646 """See TensorArray.""" 647 if not self._dynamic_size and self._size is not None: 648 return ops.convert_to_tensor(self._size, dtype=dtypes.int32) 649 else: 650 return list_ops.tensor_list_length(input_handle=self._flow, name=name) 651 652 def close(self, name=None): 653 """See TensorArray.""" 654 return gen_control_flow_ops.no_op(name=name) 655 656 657# pylint: enable=protected-access 658 659 660class _EagerTensorArray: 661 """Eager-compatible implementation of TensorArray.""" 662 663 def __init__(self, 664 dtype, 665 size=None, 666 dynamic_size=None, 667 clear_after_read=None, 668 tensor_array_name=None, 669 handle=None, 670 flow=None, 671 infer_shape=True, 672 element_shape=None, 673 colocate_with_first_write_call=True, 674 name=None): 675 """Constructs a TensorArray compatible with eager execution. 676 677 Args: 678 dtype: (required) data type of the TensorArray. 679 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 680 Required if handle is not provided. 681 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 682 can grow the TensorArray past its initial size. Default: False. 683 clear_after_read: Boolean (optional, default: True). If True, clear 684 TensorArray values after reading them. This disables read-many 685 semantics, but allows early release of memory. 686 tensor_array_name: unused. 687 handle: unsupported. 688 flow: unsupported. 689 infer_shape: used for error checking, same semantics as TensorArray. 690 element_shape: used for error checking, same semantics as TensorArray. 691 colocate_with_first_write_call: unsupported. 692 name: unsupported. 693 694 Raises: 695 ValueError: handle or flow are supplied, or if size is not supplied. 696 """ 697 698 del (flow, tensor_array_name, name) # Unused. 699 700 if handle is not None: 701 raise ValueError("TensorArray handles are not supported when eager " 702 "execution is enabled.") 703 if size is None: 704 raise ValueError("Size must be declared for TensorArrays when eager " 705 "execution is enabled.") 706 707 # These attributes are not meaningful when eager is enabled, but some 708 # library functions (e.g., those in control_flow_ops.py) access them to 709 # create new tensor arrays; as such, we define them for the sake of 710 # compatibility. 711 self._handle = None 712 # we assign a dummy value to _flow in case other code assumes it to be 713 # a Tensor 714 self._flow = constant_op.constant(0, dtype=dtypes.int32) 715 self._infer_shape = infer_shape 716 self._element_shape = tensor_shape.as_shape(element_shape) 717 self._colocate_with_first_write_call = colocate_with_first_write_call 718 719 self._dtype = dtypes.as_dtype(dtype).base_dtype 720 self._dynamic_size = dynamic_size or False 721 self._clear_after_read = (True 722 if clear_after_read is None else clear_after_read) 723 self._previously_read_indices = [] 724 725 if isinstance(size, ops.EagerTensor): 726 size = size.numpy() 727 self._tensor_array = [None for _ in range(size)] 728 729 @property 730 def flow(self): 731 """For compatibility; flows are not meaningful when eager is enabled.""" 732 return self._flow 733 734 @property 735 def dtype(self): 736 return self._dtype 737 738 @property 739 def handle(self): 740 """For compatibility; handles are not meaningful when eager is enabled.""" 741 return self._handle 742 743 @property 744 def element_shape(self): 745 return self._element_shape 746 747 def identity(self): 748 """See TensorArray.""" 749 return self.parent() 750 751 def grad(self, source, flow=None, name=None): 752 raise NotImplementedError( 753 "TensorArray.grad is not supported when executing eagerly; eager's " 754 "gradient implementation does not use/need this function to compute " 755 "gradients of operations that use TensorArrays.") 756 757 def read(self, index, name=None): 758 """See TensorArray.""" 759 del name # not meaningful when executing eagerly. 760 761 if isinstance(index, ops.EagerTensor): 762 index = index.numpy() 763 764 if index < 0: 765 raise errors_impl.OutOfRangeError( 766 None, None, 767 "Reading from negative indices (index %d) is not allowed." % index) 768 769 if index >= len(self._tensor_array): 770 raise errors_impl.OutOfRangeError( 771 None, None, "Tried to read from index %d but array size is: %d " % 772 (index, len(self._tensor_array))) 773 774 tensor = self._tensor_array[index] 775 if tensor is None: 776 if index in self._previously_read_indices: 777 raise errors_impl.InvalidArgumentError( 778 None, None, 779 "Could not read index %d twice because it was cleared after " 780 "a previous read (perhaps try setting clear_after_read = false?)" % 781 index) 782 else: 783 tensor = self._maybe_zero(index) 784 785 if self._clear_after_read: 786 self._tensor_array[index] = None 787 self._previously_read_indices.append(index) 788 return tensor 789 790 def _write(self, index, value): 791 """Writes `value` into index named by `index`. 792 793 Args: 794 index: 0-D. int32 scalar with the index to write to. 795 value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`. 796 797 Raises: 798 errors_impl.InvalidArgumentError: `value` dtype does not match dtype. 799 errors_impl.OutOfRangeError: `index` is out of bounds. 800 ValueError: shape of `value` is not consistent with inferred shape. 801 """ 802 803 if isinstance(index, ops.EagerTensor): 804 index = index.numpy() 805 806 if index < 0: 807 raise errors_impl.OutOfRangeError( 808 None, None, 809 "Writing to negative indices (index %d) is not allowed." % index) 810 811 size = len(self._tensor_array) 812 if index >= size: 813 if not self._dynamic_size: 814 raise errors_impl.OutOfRangeError( 815 None, None, 816 "Tried to write to index %d but array is not resizeable and size " 817 "is: %d " % (index, size)) 818 self._tensor_array.extend(None for _ in range(index - size + 1)) 819 820 if not isinstance(value, ops.EagerTensor): 821 # TODO(b/129870929): Fix after all callers provide proper init dtype. 822 value = ops.convert_to_tensor( 823 value, preferred_dtype=self._dtype, name="value") 824 825 if self._dtype != value.dtype: 826 raise errors_impl.InvalidArgumentError( 827 None, None, 828 "TensorArray dtype is %s but Op is trying to write dtype %s " % 829 (self._dtype.name, value.dtype.name)) 830 831 if not self._element_shape.is_compatible_with(value.shape): 832 raise ValueError("Incompatible shape for value (%s), expected (%s)" % 833 (value.shape, self._element_shape)) 834 835 if self._infer_shape: 836 self._element_shape = self._element_shape.merge_with(value.shape) 837 838 self._tensor_array[index] = value 839 840 def write(self, index, value, name=None): 841 """See TensorArray.""" 842 del name # not meaningful when executing eagerly. 843 self._write(index, value) 844 return self.parent() 845 846 def _maybe_zero(self, ix): 847 val = self._tensor_array[ix] 848 if val is None: 849 val = self._tensor_array[ix] = array_ops.zeros( 850 shape=self._element_shape, dtype=self._dtype) 851 return val 852 853 def stack(self, name=None): 854 """See TensorArray.""" 855 if self._tensor_array: 856 for ix in range(len(self._tensor_array)): 857 self._maybe_zero(ix) 858 if not self._tensor_array and self._element_shape.is_fully_defined(): 859 return ops.convert_to_tensor( 860 np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype) 861 else: 862 return ops.convert_to_tensor( 863 self._tensor_array, name=name, dtype=self._dtype) 864 865 def gather(self, indices, name=None): 866 """See TensorArray.""" 867 del name # not meaningful when executing eagerly. 868 if isinstance(indices, ops.EagerTensor): 869 indices = indices.numpy() 870 return array_ops.stack([self._maybe_zero(i) for i in indices]) 871 872 def concat(self, name=None): 873 """See TensorArray.""" 874 try: 875 return array_ops.concat( 876 [self._maybe_zero(ix) for ix in range(len(self._tensor_array))], 877 0, 878 name=name) 879 except errors_impl.OpError: 880 # Reproduce a subset of the error-handling for graph-mode TensorArrays. 881 shapes = [t.shape for t in self._tensor_array] 882 ndims = [s.ndims for s in shapes] 883 if 0 in ndims: 884 idx = ndims.index(0) 885 raise errors_impl.InvalidArgumentError( 886 None, None, "Concat saw a scalar shape at index %d but requires " 887 "at least vectors." % idx) 888 else: 889 raise 890 891 def unstack(self, value, name=None): 892 """See TensorArray.""" 893 tensors = array_ops.unstack(value, name=name) 894 if len(tensors) > len(self._tensor_array) and not self._dynamic_size: 895 raise ValueError( 896 "Cannot unstack %d tensors into a TensorArray of static size %d " % 897 (len(tensors), len(self._tensor_array))) 898 self._tensor_array = tensors 899 return self.parent() 900 901 def scatter(self, indices, value, name=None): 902 """See TensorArray.""" 903 del name # not meaningful when executing eagerly. 904 if isinstance(indices, ops.EagerTensor): 905 indices = indices.numpy() 906 for index, val in zip(indices, array_ops.unstack(value)): 907 self._write(index, val) # pylint: disable=protected-access 908 return self.parent() 909 910 def split(self, value, lengths, name=None): 911 """See TensorArray.""" 912 # TODO(b/129870929): Fix after all callers provide proper init dtype. 913 value = ops.convert_to_tensor( 914 value, preferred_dtype=self._dtype, name="value") 915 _check_dtypes(value, self._dtype) 916 lengths = ops.convert_to_tensor(lengths) 917 sum_lengths = math_ops.reduce_sum(lengths) 918 if lengths.shape.ndims != 1: 919 raise errors_impl.InvalidArgumentError( 920 None, None, "Expected lengths to be a vector, received shape: %s " % 921 lengths.shape.as_list()) 922 elif value.shape.ndims == 0: 923 raise errors_impl.InvalidArgumentError( 924 None, None, "Expected value to be at least a vector, " 925 "but received shape: %s " % value.shape.as_list()) 926 elif sum_lengths.numpy() != value.shape.as_list()[0]: 927 raise errors_impl.InvalidArgumentError( 928 None, None, "Expected sum of lengths to be equal to " 929 "values.shape[0], but sum of lengths is %d and " 930 "value's shape is: %s " % (sum_lengths.numpy(), 931 value.shape.as_list())) 932 elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array): 933 raise errors_impl.InvalidArgumentError( 934 None, None, "TensorArray's size is not equal to the size of " 935 "lengths (%d vs. %d), and the TensorArray is not marked as " 936 "dynamically resizeable." % 937 (len(self._tensor_array), lengths.shape[0])) 938 else: 939 self._tensor_array = array_ops.split(value, lengths, name=name) 940 return self.parent() 941 942 def size(self, name=None): 943 """See TensorArray.""" 944 del name # not meaningful when executing eagerly. 945 return constant_op.constant(len(self._tensor_array)) 946 947 def close(self, name=None): 948 del name # not meaningful when executing eagerly. 949 del self._tensor_array[:] 950 951 952# TensorArray is designed to hide an underlying implementation object 953# and as such accesses many of that object's hidden fields. 954# pylint: disable=protected-access 955# pylint:disable=line-too-long 956@tf_export("TensorArray") 957class TensorArray: 958 """Class wrapping dynamic-sized, per-time-step, Tensor arrays. 959 960 This class is meant to be used with dynamic iteration primitives such as 961 `while_loop` and `map_fn`. It supports gradient back-propagation via special 962 "flow" control flow dependencies. 963 964 Note that although the array can be read multiple times and positions can be 965 overwritten, behavior may be undefined when storing multiple references to 966 the same array and clear_after_read is False. In particular, avoid using 967 methods like concat() to convert an intermediate TensorArray to a Tensor, 968 then further modifying the TensorArray, particularly if you need to backprop 969 through it later. 970 971 Example 1: Plain reading and writing. 972 973 >>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False) 974 >>> ta = ta.write(0, 10) 975 >>> ta = ta.write(1, 20) 976 >>> ta = ta.write(2, 30) 977 >>> 978 >>> ta.read(0) 979 <tf.Tensor: shape=(), dtype=float32, numpy=10.0> 980 >>> ta.read(1) 981 <tf.Tensor: shape=(), dtype=float32, numpy=20.0> 982 >>> ta.read(2) 983 <tf.Tensor: shape=(), dtype=float32, numpy=30.0> 984 >>> ta.stack() 985 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.], 986 dtype=float32)> 987 988 Example 2: Fibonacci sequence algorithm that writes in a loop then returns. 989 990 >>> @tf.function 991 ... def fibonacci(n): 992 ... ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True) 993 ... ta = ta.unstack([0., 1.]) 994 ... 995 ... for i in range(2, n): 996 ... ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2)) 997 ... 998 ... return ta.stack() 999 >>> 1000 >>> fibonacci(7) 1001 <tf.Tensor: shape=(7,), dtype=float32, 1002 numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)> 1003 1004 Example 3: A simple loop interacting with a `tf.Variable`. 1005 1006 >>> v = tf.Variable(1) 1007 >>> @tf.function 1008 ... def f(x): 1009 ... ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True) 1010 ... for i in tf.range(x): 1011 ... v.assign_add(i) 1012 ... ta = ta.write(i, v) 1013 ... return ta.stack() 1014 >>> f(5) 1015 <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1, 2, 4, 7, 11], 1016 dtype=int32)> 1017 """ 1018 1019 def __init__(self, 1020 dtype, 1021 size=None, 1022 dynamic_size=None, 1023 clear_after_read=None, 1024 tensor_array_name=None, 1025 handle=None, 1026 flow=None, 1027 infer_shape=True, 1028 element_shape=None, 1029 colocate_with_first_write_call=True, 1030 name=None): 1031 """Construct a new TensorArray or wrap an existing TensorArray handle. 1032 1033 A note about the parameter `name`: 1034 1035 The name of the `TensorArray` (even if passed in) is uniquified: each time 1036 a new `TensorArray` is created at runtime it is assigned its own name for 1037 the duration of the run. This avoids name collisions if a `TensorArray` 1038 is created within a `while_loop`. 1039 1040 Args: 1041 dtype: (required) data type of the TensorArray. 1042 size: (optional) int32 scalar `Tensor`: the size of the TensorArray. 1043 Required if handle is not provided. 1044 dynamic_size: (optional) Python bool: If true, writes to the TensorArray 1045 can grow the TensorArray past its initial size. Default: False. 1046 clear_after_read: Boolean (optional, default: True). If True, clear 1047 TensorArray values after reading them. This disables read-many 1048 semantics, but allows early release of memory. 1049 tensor_array_name: (optional) Python string: the name of the TensorArray. 1050 This is used when creating the TensorArray handle. If this value is 1051 set, handle should be None. 1052 handle: (optional) A `Tensor` handle to an existing TensorArray. If this 1053 is set, tensor_array_name should be None. Only supported in graph mode. 1054 flow: (optional) A float `Tensor` scalar coming from an existing 1055 `TensorArray.flow`. Only supported in graph mode. 1056 infer_shape: (optional, default: True) If True, shape inference is 1057 enabled. In this case, all elements must have the same shape. 1058 element_shape: (optional, default: None) A `TensorShape` object specifying 1059 the shape constraints of each of the elements of the TensorArray. Need 1060 not be fully defined. 1061 colocate_with_first_write_call: If `True`, the TensorArray will be 1062 colocated on the same device as the Tensor used on its first write 1063 (write operations include `write`, `unstack`, and `split`). If `False`, 1064 the TensorArray will be placed on the device determined by the device 1065 context available during its initialization. 1066 name: A name for the operation (optional). 1067 1068 Raises: 1069 ValueError: if both handle and tensor_array_name are provided. 1070 TypeError: if handle is provided but is not a Tensor. 1071 """ 1072 if (context.executing_eagerly() and 1073 (flow is None or flow.dtype != dtypes.variant)): 1074 # It is possible to create a Variant-style TensorArray even in eager mode, 1075 # and this is fine but can have performance implications in eager. 1076 # An example of when this happens is if a tf.function returns a 1077 # TensorArray in its output; its flow variant object is returned to Eager. 1078 # This can be wrapped back up in a Variant-style TensorArray. 1079 implementation = _EagerTensorArray 1080 elif (flow is not None and flow.dtype == dtypes.variant or 1081 control_flow_util.EnableControlFlowV2(ops.get_default_graph())): 1082 implementation = _GraphTensorArrayV2 1083 else: 1084 implementation = _GraphTensorArray 1085 self._implementation = implementation( 1086 dtype, 1087 size=size, 1088 dynamic_size=dynamic_size, 1089 clear_after_read=clear_after_read, 1090 tensor_array_name=tensor_array_name, 1091 handle=handle, 1092 flow=flow, 1093 infer_shape=infer_shape, 1094 element_shape=element_shape, 1095 colocate_with_first_write_call=colocate_with_first_write_call, 1096 name=name) 1097 1098 self._implementation.parent = weakref.ref(self) 1099 1100 @property 1101 def flow(self): 1102 """The flow `Tensor` forcing ops leading to this TensorArray state.""" 1103 return self._implementation._flow 1104 1105 @property 1106 def dtype(self): 1107 """The data type of this TensorArray.""" 1108 return self._implementation._dtype 1109 1110 @property 1111 def handle(self): 1112 """The reference to the TensorArray.""" 1113 return self._implementation.handle 1114 1115 @property 1116 def element_shape(self): 1117 """The `tf.TensorShape` of elements in this TensorArray.""" 1118 return self._implementation.element_shape 1119 1120 @property 1121 def dynamic_size(self): 1122 """Python bool; if `True` the TensorArray can grow dynamically.""" 1123 return self._implementation._dynamic_size 1124 1125 @property 1126 def _infer_shape(self): 1127 # TODO(slebedev): consider making public or changing TensorArrayStructure 1128 # to access _implementation directly. Note that dynamic_size is also 1129 # only used by TensorArrayStructure. 1130 return self._implementation._infer_shape 1131 1132 def identity(self): 1133 """Returns a TensorArray with the same content and properties. 1134 1135 Returns: 1136 A new TensorArray object with flow that ensures the control dependencies 1137 from the contexts will become control dependencies for writes, reads, etc. 1138 Use this object for all subsequent operations. 1139 """ 1140 return self._implementation.identity() 1141 1142 def grad(self, source, flow=None, name=None): 1143 return self._implementation.grad(source, flow=flow, name=name) 1144 1145 def read(self, index, name=None): 1146 """Read the value at location `index` in the TensorArray. 1147 1148 Args: 1149 index: 0-D. int32 tensor with the index to read from. 1150 name: A name for the operation (optional). 1151 1152 Returns: 1153 The tensor at index `index`. 1154 """ 1155 return self._implementation.read(index, name=name) 1156 1157 @tf_should_use.should_use_result(warn_in_eager=True) 1158 def write(self, index, value, name=None): 1159 """Write `value` into index `index` of the TensorArray. 1160 1161 Args: 1162 index: 0-D. int32 scalar with the index to write to. 1163 value: N-D. Tensor of type `dtype`. The Tensor to write to this index. 1164 name: A name for the operation (optional). 1165 1166 Returns: 1167 A new TensorArray object with flow that ensures the write occurs. 1168 Use this object for all subsequent operations. 1169 1170 Raises: 1171 ValueError: if there are more writers than specified. 1172 """ 1173 return self._implementation.write(index, value, name=name) 1174 1175 def stack(self, name=None): 1176 """Return the values in the TensorArray as a stacked `Tensor`. 1177 1178 All of the values must have been written and their shapes must all match. 1179 If input shapes have rank-`R`, then output shape will have rank-`(R+1)`. 1180 1181 For example: 1182 1183 1184 >>> ta = tf.TensorArray(tf.int32, size=3) 1185 >>> ta.write(0, tf.constant([1, 2])) 1186 >>> ta.write(1, tf.constant([3, 4])) 1187 >>> ta.write(2, tf.constant([5, 6])) 1188 >>> ta.stack() 1189 <tf.Tensor: shape=(3, 2), dtype=int32, numpy= 1190 array([[1, 2], 1191 [3, 4], 1192 [5, 6]], dtype=int32)> 1193 1194 1195 Args: 1196 name: A name for the operation (optional). 1197 1198 Returns: 1199 All the tensors in the TensorArray stacked into one tensor. 1200 """ 1201 return self._implementation.stack(name=name) 1202 1203 def gather(self, indices, name=None): 1204 """Return selected values in the TensorArray as a packed `Tensor`. 1205 1206 All of selected values must have been written and their shapes 1207 must all match. 1208 1209 Args: 1210 indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If the 1211 `TensorArray` is not dynamic, `max_value=size()`. 1212 name: A name for the operation (optional). 1213 1214 Returns: 1215 The tensors in the `TensorArray` selected by `indices`, packed into one 1216 tensor. 1217 """ 1218 return self._implementation.gather(indices, name=name) 1219 1220 def concat(self, name=None): 1221 """Return the values in the TensorArray as a concatenated `Tensor`. 1222 1223 All of the values must have been written, their ranks must match, and 1224 and their shapes must all match for all dimensions except the first. 1225 1226 Args: 1227 name: A name for the operation (optional). 1228 1229 Returns: 1230 All the tensors in the TensorArray concatenated into one tensor. 1231 """ 1232 return self._implementation.concat(name=name) 1233 1234 @tf_should_use.should_use_result 1235 def unstack(self, value, name=None): 1236 """Unstack the values of a `Tensor` in the TensorArray. 1237 1238 If input value shapes have rank-`R`, then the output TensorArray will 1239 contain elements whose shapes are rank-`(R-1)`. 1240 1241 Args: 1242 value: (N+1)-D. Tensor of type `dtype`. The Tensor to unstack. 1243 name: A name for the operation (optional). 1244 1245 Returns: 1246 A new TensorArray object with flow that ensures the unstack occurs. 1247 Use this object for all subsequent operations. 1248 1249 Raises: 1250 ValueError: if the shape inference fails. 1251 """ 1252 return self._implementation.unstack(value, name=name) 1253 1254 @tf_should_use.should_use_result 1255 def scatter(self, indices, value, name=None): 1256 """Scatter the values of a `Tensor` in specific indices of a `TensorArray`. 1257 1258 Args: 1259 indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If the 1260 `TensorArray` is not dynamic, `max_value=size()`. 1261 value: (N+1)-D. Tensor of type `dtype`. The Tensor to unpack. 1262 name: A name for the operation (optional). 1263 1264 Returns: 1265 A new TensorArray object with flow that ensures the scatter occurs. 1266 Use this object for all subsequent operations. 1267 1268 Raises: 1269 ValueError: if the shape inference fails. 1270 """ 1271 return self._implementation.scatter(indices, value, name=name) 1272 1273 @tf_should_use.should_use_result 1274 def split(self, value, lengths, name=None): 1275 """Split the values of a `Tensor` into the TensorArray. 1276 1277 Args: 1278 value: (N+1)-D. Tensor of type `dtype`. The Tensor to split. 1279 lengths: 1-D. int32 vector with the lengths to use when splitting `value` 1280 along its first dimension. 1281 name: A name for the operation (optional). 1282 1283 Returns: 1284 A new TensorArray object with flow that ensures the split occurs. 1285 Use this object for all subsequent operations. 1286 1287 Raises: 1288 ValueError: if the shape inference fails. 1289 """ 1290 return self._implementation.split(value, lengths, name=name) 1291 1292 def size(self, name=None): 1293 """Return the size of the TensorArray.""" 1294 return self._implementation.size(name=name) 1295 1296 @tf_should_use.should_use_result 1297 def close(self, name=None): 1298 """Close the current TensorArray.""" 1299 return self._implementation.close(name=name) 1300 1301 1302def build_ta_with_new_flow(old_ta, flow): 1303 """Builds a TensorArray with a new `flow` tensor.""" 1304 # Sometimes we get old_ta as the implementation, sometimes it's the 1305 # TensorArray wrapper object. 1306 impl = (old_ta._implementation if isinstance(old_ta, TensorArray) else old_ta) 1307 1308 if not context.executing_eagerly(): 1309 if (not isinstance(impl, _GraphTensorArrayV2) and 1310 control_flow_util.EnableControlFlowV2(ops.get_default_graph())): 1311 raise NotImplementedError("Attempting to build a graph-mode TF2-style " 1312 "TensorArray from either an eager-mode " 1313 "TensorArray or a TF1-style TensorArray. " 1314 "This is not currently supported. You may be " 1315 "attempting to capture a TensorArray " 1316 "inside a tf.function or tf.data map function. " 1317 "Instead, construct a new TensorArray inside " 1318 "the function.") 1319 new_ta = TensorArray( 1320 dtype=impl.dtype, 1321 handle=impl.handle, 1322 flow=flow, 1323 infer_shape=impl._infer_shape, 1324 colocate_with_first_write_call=impl._colocate_with_first_write_call) 1325 new_impl = new_ta._implementation 1326 new_impl._dynamic_size = impl._dynamic_size 1327 new_impl._size = impl._size 1328 new_impl._colocate_with = impl._colocate_with 1329 new_impl._element_shape = impl._element_shape # Share _element_shape. 1330 return new_ta 1331 1332 1333# pylint: enable=protected-access 1334 1335 1336def _check_dtypes(value, dtype): 1337 if value.dtype != dtype: 1338 logging.error("Error: Input value {} has dtype {}, but expected dtype {}. " 1339 "This leads to undefined behavior and will be an error " 1340 "in future versions of TensorFlow. Traceback:\n{}".format( 1341 value, str(value.dtype), str(dtype), 1342 "".join(traceback.format_stack()))) 1343 1344 1345@tf_export("TensorArraySpec") 1346@type_spec.register("tf.TensorArraySpec") 1347class TensorArraySpec(type_spec.TypeSpec): 1348 """Type specification for a `tf.TensorArray`.""" 1349 1350 __slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"] 1351 1352 value_type = property(lambda self: TensorArray) 1353 1354 def __init__(self, 1355 element_shape=None, 1356 dtype=dtypes.float32, 1357 dynamic_size=False, 1358 infer_shape=True): 1359 """Constructs a type specification for a `tf.TensorArray`. 1360 1361 Args: 1362 element_shape: The shape of each element in the `TensorArray`. 1363 dtype: Data type of the `TensorArray`. 1364 dynamic_size: Whether the `TensorArray` can grow past its initial size. 1365 infer_shape: Whether shape inference is enabled. 1366 """ 1367 self._element_shape = tensor_shape.as_shape(element_shape) 1368 self._dtype = dtypes.as_dtype(dtype) 1369 self._dynamic_size = dynamic_size 1370 self._infer_shape = infer_shape 1371 1372 def is_subtype_of(self, other): 1373 # pylint: disable=protected-access 1374 return (isinstance(other, TensorArraySpec) and 1375 self._dtype == other._dtype and 1376 self._dynamic_size == other._dynamic_size) 1377 1378 def most_specific_common_supertype(self, others): 1379 """Returns the most specific supertype of `self` and `others`. 1380 1381 Args: 1382 others: A Sequence of `TypeSpec`. 1383 1384 Returns `None` if a supertype does not exist. 1385 """ 1386 # pylint: disable=protected-access 1387 if not all(isinstance(other, TensorArraySpec) for other in others): 1388 return False 1389 1390 common_shape = self._element_shape.most_specific_common_supertype( 1391 other._element_shape for other in others) 1392 if common_shape is None: 1393 return None 1394 1395 if not all(self._dtype == other._dtype for other in others): 1396 return None 1397 1398 if not all(self._dynamic_size == other._dynamic_size for other in others): 1399 return None 1400 1401 infer_shape = self._infer_shape and all( 1402 other._infer_shape for other in others) 1403 1404 return TensorArraySpec(common_shape, self._dtype, self._dynamic_size, 1405 infer_shape) 1406 1407 def is_compatible_with(self, other): 1408 # pylint: disable=protected-access 1409 if not isinstance(other, type_spec.TypeSpec): 1410 other = type_spec.type_spec_from_value(other) 1411 1412 # Note: we intentionally exclude infer_shape in this check. 1413 return (isinstance(other, TensorArraySpec) and 1414 self._dtype.is_compatible_with(other._dtype) and 1415 self._element_shape.is_compatible_with(other._element_shape) and 1416 self._dynamic_size == other._dynamic_size) 1417 1418 def _serialize(self): 1419 return (self._element_shape, self._dtype, self._dynamic_size, 1420 self._infer_shape) 1421 1422 @property 1423 def _component_specs(self): 1424 return [tensor_spec.TensorSpec([], dtypes.variant)] 1425 1426 def _to_components(self, value): 1427 if not isinstance(value, TensorArray): 1428 raise TypeError("Expected value to be a TensorArray, but got: `{}`".format( 1429 type(value))) 1430 if value.flow is not None and value.flow.dtype == dtypes.variant: 1431 return [value.flow] 1432 else: 1433 # Convert to a TF2-style TensorArray. 1434 # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or 1435 # "implementation / as_variant" arg to TensorArray constructor. 1436 with ops.name_scope("convert_tensor_array"): 1437 flow = list_ops.tensor_list_from_tensor( 1438 tensor=value.stack(), element_shape=value.element_shape) 1439 return [flow] 1440 1441 def _from_components(self, tensor_list): 1442 # This will return a TF2 Graph-style TensorArray because tensor_list[0] is 1443 # a variant object. size == -1 implies unknown size. 1444 ret = TensorArray( 1445 dtype=self._dtype, 1446 flow=tensor_list[0], 1447 dynamic_size=self._dynamic_size, 1448 infer_shape=self._infer_shape) 1449 ret._implementation._element_shape = [self._element_shape] # pylint: disable=protected-access 1450 return ret 1451 1452 @staticmethod 1453 def from_value(value): 1454 if not isinstance(value, TensorArray): 1455 raise TypeError("Expected value to be a TensorArray, but got: `{}`".format( 1456 type(value))) 1457 1458 return TensorArraySpec( 1459 dtype=value.dtype, 1460 element_shape=value.element_shape, 1461 dynamic_size=value.dynamic_size, 1462 infer_shape=value._infer_shape) # pylint: disable=protected-access 1463 1464 def _to_legacy_output_types(self): 1465 return self._dtype 1466 1467 def _to_legacy_output_shapes(self): 1468 # Sneak the dynamic_size and infer_shape values into the legacy shape. 1469 return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape 1470 ]).concatenate(self._element_shape)) 1471 1472 def _to_legacy_output_classes(self): 1473 return TensorArray 1474 1475 1476# Register the TypeSpec for TensorArray. If TensorArray is updated to be a 1477# CompositeTensor, then this registration can be deleted. 1478type_spec.register_type_spec_from_value_converter( 1479 TensorArray, TensorArraySpec.from_value, allow_subclass=True) 1480