1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Type specifications for TensorFlow APIs.""" 16 17import abc 18import functools 19import re 20from typing import Any, List, Optional, Sequence, Type 21import warnings 22 23import numpy as np 24 25from tensorflow.core.function import trace_type 26from tensorflow.core.protobuf import struct_pb2 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import tensor_shape 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.python.types import trace 32from tensorflow.python.util import _pywrap_utils 33from tensorflow.python.util import compat 34from tensorflow.python.util import deprecation 35from tensorflow.python.util import nest 36from tensorflow.python.util import tf_decorator 37from tensorflow.python.util.lazy_loader import LazyLoader 38from tensorflow.python.util.tf_export import tf_export 39 40# Use LazyLoader to avoid circular dependencies. 41tensor_spec = LazyLoader( 42 "tensor_spec", globals(), 43 "tensorflow.python.framework.tensor_spec") 44ops = LazyLoader("ops", globals(), 45 "tensorflow.python.framework.ops") 46# TODO(b/238903802): Remove this dependency. 
nested_structure_coder = LazyLoader(
    "nested_structure_coder", globals(),
    "tensorflow.python.saved_model.nested_structure_coder")


@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
class TypeSpec(
    trace.TraceType, trace_type.Serializable, metaclass=abc.ABCMeta):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs. Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
  currently supported. In particular, we may make breaking changes to the
  private methods and properties defined by this base class.

  Example:

  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
  >>> @tf.function(input_signature=[spec])
  ... def double(x):
  ...   return x * 2
  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
  <tf.RaggedTensor [[2, 4], [6]]>
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  # * A "component encoding" for values.
  # * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value. Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute. Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive. In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`. See the documentation for
  # `_serialize()` for more information.

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec.

    In particular, all values that are compatible with this TypeSpec must be an
    instance of this type.
    """
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if `self` is a subtype of `other`.

    Implements the tf.types.experimental.func.TraceType interface.

    If not overridden by a subclass, the default behavior is to assume the
    TypeSpec is covariant upon attributes that implement TraceType and
    invariant upon rest of the attributes as well as the structure and type
    of the TypeSpec.

    Args:
      other: A TraceType object.
    """
    if type(self) is not type(other):
      return False

    is_subtype = True

    def check_attribute(attribute_self, attribute_other):
      nonlocal is_subtype
      if not is_subtype:
        return

      if isinstance(attribute_self, trace.TraceType):
        if not attribute_self.is_subtype_of(attribute_other):
          is_subtype = False
          return
      else:
        if attribute_self != attribute_other:
          is_subtype = False

    try:
      # TODO(b/217959193): Replace _serialize with parameter decomposition.
      nest.map_structure(check_attribute, self._serialize(), other._serialize())  # pylint: disable=protected-access
    except (ValueError, TypeError):
      return False

    return is_subtype

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["TypeSpec"]:
    """Returns the most specific supertype TypeSpec of `self` and `others`.

    Implements the tf.types.experimental.func.TraceType interface.

    If not overridden by a subclass, the default behavior is to assume the
    TypeSpec is covariant upon attributes that implement TraceType and
    invariant upon rest of the attributes as well as the structure and type
    of the TypeSpec.

    Args:
      others: A sequence of TraceTypes.
    """
    if any(type(self) is not type(other) for other in others):
      return None

    has_supertype = True

    def make_supertype_attribute(attribute_self, *attribute_others):
      nonlocal has_supertype
      if not has_supertype:
        return

      if isinstance(attribute_self, trace.TraceType):
        attribute_supertype = attribute_self.most_specific_common_supertype(
            attribute_others)
        if attribute_supertype is None:
          has_supertype = False
          return
        return attribute_supertype
      else:
        if not all(attribute_self == attribute_other
                   for attribute_other in attribute_others):
          has_supertype = False
          return
        return attribute_self

    try:
      # TODO(b/217959193): Replace _serialize with parameter decomposition.
      serialized_supertype = nest.map_structure(
          make_supertype_attribute, self._serialize(),
          *(o._serialize() for o in others))  # pylint: disable=protected-access
    except (ValueError, TypeError):
      return None

    return self._deserialize(serialized_supertype) if has_supertype else None

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TypeSpecProto]:
    """Returns the type of proto associated with TypeSpec serialization.

    Do NOT override for custom non-TF types.
    """
    return struct_pb2.TypeSpecProto

  @classmethod
  def experimental_from_proto(cls,
                              proto: struct_pb2.TypeSpecProto) -> "TypeSpec":
    """Returns a TypeSpec instance based on the serialized proto.

    Do NOT override for custom non-TF types.

    Args:
      proto: Proto generated using 'experimental_as_proto'.
    """
    return nested_structure_coder.decode_proto(
        struct_pb2.StructuredValue(type_spec_value=proto))

  def experimental_as_proto(self) -> struct_pb2.TypeSpecProto:
    """Returns a proto representation of the TypeSpec instance.

    Do NOT override for custom non-TF types.
    """
    return nested_structure_coder.encode_structure(self).type_spec_value

  # TODO(b/223659753): Return the actual Tensor-based value instead of spec.
  def _placeholder_value(self) -> "TypeSpec":
    """Value used for tracing a function signature with this TraceType."""
    return self

  # TODO(b/225058047): Reconsider semantics.
  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec.

    Prefer using "is_subtype_of" and "most_specific_common_supertype" wherever
    possible.

    Args:
      spec_or_value: A TypeSpec or TypeSpec associated value to compare against.
    """
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(), spec_or_value._serialize())  # pylint: disable=protected-access

  @deprecation.deprecated(None, "Use most_specific_common_supertype instead.")
  def most_specific_compatible_type(self, other: "TypeSpec") -> "TypeSpec":
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Deprecated. Please use `most_specific_common_supertype` instead.
    Do not override this function.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    result = self.most_specific_common_supertype([other])
    if result is None:
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    return result

  # TODO(b/226395276): Delete after removing usages.
  def _with_tensor_ranks_only(self) -> "TypeSpec":
    """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where any `TensorShape`
      information has been relaxed to include only tensor rank (and not
      the dimension sizes for individual axes).
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, relax any TensorSpec or TensorShape values, and
    # deserialize the result.

    def relax(value):
      if isinstance(value, TypeSpec):
        return value._with_tensor_ranks_only()  # pylint: disable=protected-access
      elif (isinstance(value, tensor_shape.TensorShape) and
            value.rank is not None):
        return tensor_shape.TensorShape([None] * value.rank)
      else:
        return value

    return self._deserialize(nest.map_structure(relax, self._serialize()))

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TypeSpec":
    """Returns a TypeSpec compatible with `self`, with tensor names removed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where the name of any
      `TensorSpec` is set to `None`.
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, set the TensorSpecs' names to None, and deserialize the
    # result.

    def rename(value):
      if isinstance(value, TypeSpec):
        return value._without_tensor_names()  # pylint: disable=protected-access
      return value

    return self._deserialize(nest.map_structure(rename, self._serialize()))

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`. (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`. (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method. In particular, for a
      TypeSpec `spec` and a compatible value `value`:

      ```
      nest.map_structure(lambda t, c: assert t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value) -> List["ops.Tensor"]:
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`. However, subclasses may override this to return a
    different tensor encoding for values. In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched. See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`. (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list: List["ops.Tensor"]) -> Any:
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
        `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self,
                                   tensor_list: List["ops.Tensor"]) -> Any:
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`. (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(
        nest.pack_sequence_as(
            self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `serialize`.

    Args:
      serialization: A value returned by _serialize. In some contexts,
        `namedtuple`s in `serialization` may not have the identical type that
        was returned by `_serialize` (but its type will still be a `namedtuple`
        type with the same type name and field names). For example, the code
        that loads a SavedModel does not have access to the original
        `namedtuple` type, so it dynamically creates a new `namedtuple` type
        with the same type name and field names as the original one. If
        necessary, you can check `serialization` for these duck-typed
        `namedtuple` types, and restore them to the original type. (E.g., this
        would be necessary if you rely on type checks such as `isinstance` for
        this `TypeSpec`'s member variables).

    Returns:
      A `TypeSpec` of type `cls`.
    """
    return cls(*serialization)  # pytype: disable=not-instantiable  # trace-all-classes

  # === Operators ===

  def __eq__(self, other) -> bool:
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other) -> bool:
    return not self == other

  def __hash__(self) -> int:
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self) -> str:
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  # TODO(b/154541175): Currently this usage is used to represent a Tensor
  # argument not a TensorSpec argument as it should be.
  def __tf_tracing_type__(self,
                          context: trace.TracingContext) -> trace.TraceType:
    return self

  def __check_tensor_list(self, tensor_list):
    """Raises an exception if tensor_list incompatible w/ flat_tensor_specs."""
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError(f"Cannot create a {self.value_type.__name__} from the "
                       f"tensor list because the TypeSpec expects "
                       f"{len(expected)} items, but the provided tensor list "
                       f"has {len(specs)} items.")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError(f"Cannot create a {self.value_type.__name__} from the "
                         f"tensor list because item {i} ({tensor_list[i]!r}) "
                         f"is incompatible with the expected TypeSpec {s2}.")

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value, (int, float, bool, np.generic, dtypes.DType, TypeSpec,
                          tensor_shape.TensorShape)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError(f"Cannot generate a hashable key for {self} because "
                     f"the _serialize() method "
                     f"returned an unsupported value of type {type(value)}")

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __same_types(a, b):
    """Returns whether a and b have the same type, up to namedtuple equivalence.

    Consistent with tf.nest.assert_same_structure(), two namedtuple types
    are considered the same iff they agree in their class name (without
    qualification by module name) and in their sequence of field names.
    This makes namedtuples recreated by nested_structure_coder compatible with
    their original Python definition.

    Args:
      a: a Python object.
      b: a Python object.

    Returns:
      A boolean that is true iff type(a) and type(b) are the same object
      or equivalent namedtuple types.
    """
    if nest.is_namedtuple(a) and nest.is_namedtuple(b):
      return nest.same_namedtuples(a, b)
    else:
      return type(a) is type(b)

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if isinstance(a, TypeSpec):
      return a.is_compatible_with(b)
    if not TypeSpec.__same_types(a, b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and
              all(TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b


trace_type.register_serializable(TypeSpec)
class TypeSpecBatchEncoder(object, metaclass=abc.ABCMeta):
  """Class used to encode and decode composite tensor values for batching.

  In order to be batched and unbatched by APIs such as `tf.data.Dataset` and
  `tf.map_fn`, composite tensors must be encoded using flat tensors that can
  themselves be batched or unbatched. `TypeSpecBatchEncoder`s are
  responsible for implementing this encoding.

  If a composite tensor's shape is a prefix of the shape of all of its
  component tensors, then this encoding can usually be performed by just
  returning those component tensors as a list. But if the composite tensor
  has components whose shape has a more complex relationship to the shape
  of the composite tensor, then a custom `TypeSpecBatchEncoder` may
  need to be implemented.
  """

  @abc.abstractmethod
  def batch(self, spec, batch_size):
    """Returns the TypeSpec representing a batch of values described by `spec`.

    Args:
      spec: The `TypeSpec` for an individual value.
      batch_size: An `int` indicating the number of values that are batched
        together, or `None` if the batch size is not known.

    Returns:
      A `TypeSpec` for a batch of values.
    """
    raise NotImplementedError(f"{type(self).__name__}.batch")

  @abc.abstractmethod
  def unbatch(self, spec):
    """Returns the TypeSpec for a single unbatched element in `spec`.

    Args:
      spec: The `TypeSpec` for a batch of values.

    Returns:
      A `TypeSpec` for an individual value.
    """
    raise NotImplementedError(f"{type(self).__name__}.unbatch")

  @abc.abstractmethod
  def encode(self, spec, value, minimum_rank=0):
    """Encodes `value` as a nest of batchable `Tensor` or `CompositeTensor`.

    Args:
      spec: The TypeSpec of the value to encode.
      value: A value compatible with `spec`.
      minimum_rank: The minimum rank for the returned Tensors, CompositeTensors,
        and ExtensionType values. This can be used to ensure that the encoded
        values can be unbatched this number of times. If `minimum_rank>0`,
        then `t.shape[:minimum_rank]` must be compatible for all values `t`
        returned by `encode`.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.Tensor`s, batchable
      `tf.CompositeTensor`s, or `tf.ExtensionType`s. Stacking, unstacking, or
      concatenating these encoded values and then decoding the result must be
      equivalent to stacking, unstacking, or concatenating the original values.
    """
    raise NotImplementedError(f"{type(self).__name__}.encode")

  @abc.abstractmethod
  def decode(self, spec, encoded_value):
    """Decodes `value` from a batchable tensor encoding.

    Args:
      spec: The TypeSpec for the result value. If encoded values with spec `s`
        were batched, then `spec` should be `s.batch(batch_size)`; or if encoded
        values with spec `s` were unbatched, then `spec` should be
        `s.unbatch()`.
      encoded_value: A nest of values returned by `encode`; or a nest of values
        that was formed by stacking, unstacking, or concatenating the
        corresponding elements of values returned by `encode`.

    Returns:
      A value compatible with `spec`.
    """
    raise NotImplementedError(f"{type(self).__name__}.decode")

  @abc.abstractmethod
  def encoding_specs(self, spec):
    """Returns a nest of `TypeSpec`(s) describing the encoding for `spec`.

    Args:
      spec: The TypeSpec whose encoding should be described.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.TypeSpec`, describing the values
      that are returned by `self.encode(spec, ...)`. All TypeSpecs in this
      nest must be batchable.
    """
    raise NotImplementedError(f"{type(self).__name__}.encoding_specs")
class LegacyTypeSpecBatchEncoder(TypeSpecBatchEncoder):
  """TypeSpecBatchEncoder for legacy composite tensor classes.

  TODO(edloper): Update existing composite tensors to use non-legacy
  CompositeTensorBatchEncoders.
  """

  def batch(self, type_spec, batch_size):
    return type_spec._batch(batch_size)  # pylint: disable=protected-access

  def unbatch(self, type_spec):
    return type_spec._unbatch()  # pylint: disable=protected-access

  def encode(self, type_spec, value, minimum_rank=0):
    if minimum_rank == 0:
      return type_spec._to_tensor_list(value)  # pylint: disable=protected-access
    elif minimum_rank == 1:
      if not isinstance(type_spec, BatchableTypeSpec):
        # `type_spec` is an instance, not a class, so its name must be read
        # from its type (`type_spec.__name__` would raise AttributeError).
        raise ValueError(f"{type(type_spec).__name__}.encode does not support "
                         "minimum_rank>0.")
      return type_spec._to_batched_tensor_list(value)  # pylint: disable=protected-access
    else:
      raise ValueError(f"{type(type_spec).__name__}.encode does not support "
                       "minimum_rank>1.")

  def decode(self, type_spec, encoded_value):
    return type_spec._from_tensor_list(encoded_value)  # pylint: disable=protected-access

  def encoding_specs(self, spec):
    return spec._flat_tensor_specs  # pylint: disable=protected-access


class BatchableTypeSpec(TypeSpec, metaclass=abc.ABCMeta):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching. In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists. Unlike the component encoding
  (returned by `self._to_components`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  __batch_encoder__ = LegacyTypeSpecBatchEncoder()

  @abc.abstractmethod
  def _batch(self, batch_size) -> TypeSpec:
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch, or
        `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError(f"{type(self).__name__}._batch")

  @abc.abstractmethod
  def _unbatch(self) -> TypeSpec:
    """Returns a TypeSpec representing a single element this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError(f"{type(self).__name__}._unbatch")

# LINT.IfChange
  @property
  def _flat_tensor_specs(self) -> List[TypeSpec]:
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    component_flat_tensor_specs = nest.map_structure(
        functools.partial(get_batchable_flat_tensor_specs, context_spec=self),
        self._component_specs)
    return nest.flatten(component_flat_tensor_specs)
# LINT.ThenChange(//tensorflow/python/framework/type_utils.py:_specs_for_flat_tensors)
# Note that _specs_for_flat_tensors in type_utils.py must correspond
# _flat_tensor_specs in this class and any derived classes.

  def _to_tensor_list(
      self, value: composite_tensor.CompositeTensor) -> List["ops.Tensor"]:
    """Encodes `value` as a flat list of `ops.Tensor`."""
    component_tensor_lists = nest.map_structure(batchable_to_tensor_list,
                                                self._component_specs,
                                                self._to_components(value))
    return nest.flatten(component_tensor_lists)

  def _to_batched_tensor_list(
      self, value: composite_tensor.CompositeTensor) -> List["ops.Tensor"]:
    """Encodes `value` as a flat list of `ops.Tensor` each with rank>0."""
    get_spec_tensor_list = lambda spec, v: (  # pylint: disable=g-long-lambda
        batchable_to_tensor_list(spec, v, minimum_rank=1)
        if isinstance(spec, BatchableTypeSpec) else spec._to_tensor_list(v))  # pylint: disable=protected-access
    component_batched_tensor_lists = nest.map_structure(
        get_spec_tensor_list, self._component_specs, self._to_components(value))
    tensor_list = nest.flatten(component_batched_tensor_lists)
    if any(t.shape.ndims == 0 for t in tensor_list):
      raise ValueError(
          f"While converting {value} to a list of tensors for batching, "
          f"found a scalar item which cannot be batched.")
    return tensor_list

  def _from_compatible_tensor_list(
      self,
      tensor_list: List["ops.Tensor"]) -> composite_tensor.CompositeTensor:
    """Reconstructs a value from a compatible flat list of `ops.Tensor`."""
    flat_specs = nest.map_structure(
        functools.partial(get_batchable_flat_tensor_specs, context_spec=self),
        self._component_specs)
    nested_tensor_list = nest.pack_sequence_as(flat_specs, tensor_list)
    components = nest.map_structure_up_to(self._component_specs,
                                          batchable_from_tensor_list,
                                          self._component_specs,
                                          nested_tensor_list)
    return self._from_components(components)


def get_batchable_flat_tensor_specs(spec, context_spec=None):
  """Returns the flat tensor specs for `spec`."""
  if isinstance(spec, tensor_spec.TensorSpec):
    return [spec]
  elif hasattr(spec, "__batch_encoder__"):
    encoding_specs = nest.map_structure(
        functools.partial(
            get_batchable_flat_tensor_specs, context_spec=context_spec),
        spec.__batch_encoder__.encoding_specs(spec))
    return nest.flatten(encoding_specs)
  else:
    # TODO(edloper) Fix existing CompositeTensors that permit this, and
    # then turn this warning into an error.
    warnings.warn(f"Batchable type {context_spec} contains non-batchable "
                  f"field or component with type {spec}.")
    return spec._flat_tensor_specs  # pylint: disable=protected-access
hasattr(spec, "__batch_encoder__"): 833 encoding_specs = nest.map_structure( 834 functools.partial( 835 get_batchable_flat_tensor_specs, context_spec=context_spec), 836 spec.__batch_encoder__.encoding_specs(spec)) 837 return nest.flatten(encoding_specs) 838 else: 839 # TODO(edloper) Fix existing CompositeTensors that permit this, and 840 # then turn this warning into an error. 841 warnings.warn(f"Batchable type {context_spec} contains non-batchable " 842 f"field or component with type {spec}.") 843 return spec._flat_tensor_specs # pylint: disable=protected-access 844 845 846def batchable_to_tensor_list(spec, value, minimum_rank=0): 847 """Returns a list of tensors encoding `value`, whose type is `spec`.""" 848 if isinstance(spec, tensor_spec.TensorSpec): 849 return [value] 850 elif hasattr(spec, "__batch_encoder__"): 851 encoded_value = spec.__batch_encoder__.encode(spec, value, minimum_rank) 852 encoded_specs = spec.__batch_encoder__.encoding_specs(spec) 853 encoded_flats = nest.map_structure( 854 functools.partial(batchable_to_tensor_list, minimum_rank=minimum_rank), 855 encoded_specs, encoded_value) 856 return nest.flatten(encoded_flats) 857 else: 858 return spec._to_tensor_list(value) # pylint: disable=protected-access 859 860 861def batchable_from_tensor_list(spec, tensor_list): 862 """Returns a value with type `spec` decoded from `tensor_list`.""" 863 if isinstance(spec, tensor_spec.TensorSpec): 864 assert len(tensor_list) == 1 865 return tensor_list[0] 866 elif hasattr(spec, "__batch_encoder__"): 867 encoded_specs = spec.__batch_encoder__.encoding_specs(spec) 868 flat_specs = nest.map_structure(get_batchable_flat_tensor_specs, 869 encoded_specs) 870 encoded_flats = nest.pack_sequence_as(flat_specs, tensor_list) 871 encoded_value = nest.map_structure_up_to(encoded_specs, 872 batchable_from_tensor_list, 873 encoded_specs, encoded_flats) 874 return spec.__batch_encoder__.decode(spec, encoded_value) 875 else: 876 return 
spec._from_compatible_tensor_list(tensor_list) # pylint: disable=protected-access 877 878 879@tf_export("type_spec_from_value") 880def type_spec_from_value(value) -> TypeSpec: 881 """Returns a `tf.TypeSpec` that represents the given `value`. 882 883 Examples: 884 885 >>> tf.type_spec_from_value(tf.constant([1, 2, 3])) 886 TensorSpec(shape=(3,), dtype=tf.int32, name=None) 887 >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64)) 888 TensorSpec(shape=(2,), dtype=tf.float64, name=None) 889 >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]])) 890 RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64) 891 892 >>> example_input = tf.ragged.constant([[1, 2], [3]]) 893 >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)]) 894 ... def f(x): 895 ... return tf.reduce_sum(x, axis=1) 896 897 Args: 898 value: A value that can be accepted or returned by TensorFlow APIs. Accepted 899 types for `value` include `tf.Tensor`, any value that can be converted to 900 `tf.Tensor` using `tf.convert_to_tensor`, and any subclass of 901 `CompositeTensor` (such as `tf.RaggedTensor`). 902 903 Returns: 904 A `TypeSpec` that is compatible with `value`. 905 906 Raises: 907 TypeError: If a TypeSpec cannot be built for `value`, because its type 908 is not supported. 909 """ 910 spec = _type_spec_from_value(value) 911 if spec is not None: 912 return spec 913 914 # Fallback: try converting value to a tensor. 
915 try: 916 tensor = ops.convert_to_tensor(value) 917 spec = _type_spec_from_value(tensor) 918 if spec is not None: 919 return spec 920 except (ValueError, TypeError) as e: 921 logging.vlog( 922 3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e)) 923 924 raise TypeError(f"Could not build a TypeSpec for {value} of " 925 f"unsupported type {type(value)}.") 926 927 928def _type_spec_from_value(value) -> TypeSpec: 929 """Returns a `TypeSpec` that represents the given `value`.""" 930 if isinstance(value, ops.Tensor): 931 # Note: we do not include Tensor names when constructing TypeSpecs. 932 return tensor_spec.TensorSpec(value.shape, value.dtype) 933 934 if isinstance(value, composite_tensor.CompositeTensor): 935 return value._type_spec # pylint: disable=protected-access 936 937 # If `value` is a list and all of its elements can be represented by the same 938 # batchable type spec, then we can represent the entire list using a single 939 # type spec that captures the type accurately (unlike the `convert_to_tensor` 940 # fallback). 941 if isinstance(value, list) and value: 942 subspecs = [_type_spec_from_value(v) for v in value] 943 if isinstance(subspecs[0], BatchableTypeSpec): 944 merged_subspec = subspecs[0].most_specific_common_supertype(subspecs[1:]) 945 if merged_subspec is not None: 946 return merged_subspec._batch(len(subspecs)) # pylint: disable=protected-access 947 948 for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY): 949 type_object, converter_fn, allow_subclass = entry 950 if ((type(value) is type_object) or # pylint: disable=unidiomatic-typecheck 951 (allow_subclass and isinstance(value, type_object))): 952 return converter_fn(value) 953 954 return None 955 956 957_TYPE_CONVERSION_FUNCTION_REGISTRY = [] 958 959 960def register_type_spec_from_value_converter(type_object, 961 converter_fn, 962 allow_subclass=False): 963 """Registers a function for converting values with a given type to TypeSpecs. 
964 965 If multiple registered `type_object`s match a value, then the most recent 966 registration takes precedence. Custom converters should not be defined for 967 `CompositeTensor`s; use `CompositeTensor._type_spec` instead. 968 969 Args: 970 type_object: A Python `type` object representing the type of values accepted 971 by `converter_fn`. 972 converter_fn: A function that takes one argument (an instance of the type 973 represented by `type_object`) and returns a `TypeSpec`. 974 allow_subclass: If true, then use `isinstance(value, type_object)` to check 975 for matches. If false, then use `type(value) is type_object`. 976 """ 977 _, type_object = tf_decorator.unwrap(type_object) 978 _TYPE_CONVERSION_FUNCTION_REGISTRY.append( 979 (type_object, converter_fn, allow_subclass)) 980 981 982_pywrap_utils.RegisterType("TypeSpec", TypeSpec) 983 984_TYPE_SPEC_TO_NAME = {} 985_NAME_TO_TYPE_SPEC = {} 986 987# Regular expression for valid TypeSpec names. 988_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$") 989 990 991# TODO(b/173744905) tf_export this as "tf.register_type_spec". (And add a 992# usage example to the docstring, once the API is public.) 993# 994# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than 995# TypeSpec (once we do refactoring to move to_components/from_components from 996# TypeSpec to ExtensionType). 997def register(name): 998 """Decorator used to register a globally unique name for a TypeSpec subclass. 999 1000 Args: 1001 name: The name of the type spec. Must be globally unique. Must have the 1002 form `"{project_name}.{type_name}"`. E.g. `"my_project.MyTypeSpec"`. 1003 1004 Returns: 1005 A class decorator that registers the decorated class with the given name. 1006 """ 1007 if not isinstance(name, str): 1008 raise TypeError("Expected `name` to be a string; got %r" % (name,)) 1009 if not _REGISTERED_NAME_RE.match(name): 1010 raise ValueError( 1011 "Registered name must have the form '{project_name}.{type_name}' " 1012 "(e.g. 
'my_project.MyTypeSpec'); got %r." % name) 1013 1014 def decorator_fn(cls): 1015 if not (isinstance(cls, type) and issubclass(cls, TypeSpec)): 1016 raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,)) 1017 if cls in _TYPE_SPEC_TO_NAME: 1018 raise ValueError("Class %s.%s has already been registered with name %s." % 1019 (cls.__module__, cls.__name__, _TYPE_SPEC_TO_NAME[cls])) 1020 if name in _NAME_TO_TYPE_SPEC: 1021 raise ValueError("Name %s has already been registered for class %s.%s." % 1022 (name, _NAME_TO_TYPE_SPEC[name].__module__, 1023 _NAME_TO_TYPE_SPEC[name].__name__)) 1024 _TYPE_SPEC_TO_NAME[cls] = name 1025 _NAME_TO_TYPE_SPEC[name] = cls 1026 return cls 1027 1028 return decorator_fn 1029 1030 1031# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name) 1032def get_name(cls): 1033 """Returns the registered name for TypeSpec `cls`.""" 1034 if not (isinstance(cls, type) and issubclass(cls, TypeSpec)): 1035 raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,)) 1036 if cls not in _TYPE_SPEC_TO_NAME: 1037 raise ValueError("TypeSpec %s.%s has not been registered." % 1038 (cls.__module__, cls.__name__)) 1039 return _TYPE_SPEC_TO_NAME[cls] 1040 1041 1042# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name) 1043def lookup(name): 1044 """Returns the TypeSpec that has been registered with name `name`.""" 1045 if not isinstance(name, str): 1046 raise TypeError("Expected `name` to be a string; got %r" % (name,)) 1047 if name not in _NAME_TO_TYPE_SPEC: 1048 raise ValueError("No TypeSpec has been registered with name %r" % (name,)) 1049 return _NAME_TO_TYPE_SPEC[name] 1050