xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/type_spec.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Type specifications for TensorFlow APIs."""
16
17import abc
18import functools
19import re
20from typing import Any, List, Optional, Sequence, Type
21import warnings
22
23import numpy as np
24
25from tensorflow.core.function import trace_type
26from tensorflow.core.protobuf import struct_pb2
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.platform import tf_logging as logging
31from tensorflow.python.types import trace
32from tensorflow.python.util import _pywrap_utils
33from tensorflow.python.util import compat
34from tensorflow.python.util import deprecation
35from tensorflow.python.util import nest
36from tensorflow.python.util import tf_decorator
37from tensorflow.python.util.lazy_loader import LazyLoader
38from tensorflow.python.util.tf_export import tf_export
39
40# Use LazyLoader to avoid circular dependencies.
41tensor_spec = LazyLoader(
42    "tensor_spec", globals(),
43    "tensorflow.python.framework.tensor_spec")
44ops = LazyLoader("ops", globals(),
45                 "tensorflow.python.framework.ops")
46# TODO(b/238903802): Remove this dependency.
47nested_structure_coder = LazyLoader(
48    "nested_structure_coder", globals(),
49    "tensorflow.python.saved_model.nested_structure_coder")
50
51
@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
class TypeSpec(
    trace.TraceType, trace_type.Serializable, metaclass=abc.ABCMeta):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs.  Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
  currently supported.  In particular, we may make breaking changes to the
  private methods and properties defined by this base class.

  Example:

  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
  >>> @tf.function(input_signature=[spec])
  ... def double(x):
  ...   return x * 2
  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
  <tf.RaggedTensor [[2, 4], [6]]>
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  #   * A "component encoding" for values.
  #   * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value.  Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute.  Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive.  In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`.  See the documentation for
  # `_serialize()` for more information.

  __slots__ = []

  @property
  @abc.abstractmethod
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec.

    In particular, all values that are compatible with this TypeSpec must be an
    instance of this type.
    """
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Returns True if `self` is a subtype of `other`.

    Implements the tf.types.experimental.func.TraceType interface.

    If not overridden by a subclass, the default behavior is to assume the
    TypeSpec is covariant upon attributes that implement TraceType and
    invariant upon rest of the attributes as well as the structure and type
    of the TypeSpec.

    Args:
      other: A TraceType object.
    """
    if type(self) is not type(other):
      return False

    is_subtype = True

    def check_attribute(attribute_self, attribute_other):
      nonlocal is_subtype
      if not is_subtype:
        return

      if isinstance(attribute_self, trace.TraceType):
        if not attribute_self.is_subtype_of(attribute_other):
          is_subtype = False
          return
      else:
        # Non-TraceType attributes are treated invariantly: they must be equal.
        if attribute_self != attribute_other:
          is_subtype = False

    try:
      # TODO(b/217959193): Replace _serialize with parameter decomposition.
      nest.map_structure(check_attribute, self._serialize(), other._serialize())  # pylint: disable=protected-access
    except (ValueError, TypeError):
      # Mismatched structures (or incomparable values) mean `self` cannot be a
      # subtype of `other`.
      return False

    return is_subtype

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["TypeSpec"]:
    """Returns the most specific supertype TypeSpec of `self` and `others`.

    Implements the tf.types.experimental.func.TraceType interface.

    If not overridden by a subclass, the default behavior is to assume the
    TypeSpec is covariant upon attributes that implement TraceType and
    invariant upon rest of the attributes as well as the structure and type
    of the TypeSpec.

    Args:
      others: A sequence of TraceTypes.
    """
    if any(type(self) is not type(other) for other in others):
      return None

    has_supertype = True

    def make_supertype_attribute(attribute_self, *attribute_others):
      nonlocal has_supertype
      if not has_supertype:
        return

      if isinstance(attribute_self, trace.TraceType):
        attribute_supertype = attribute_self.most_specific_common_supertype(
            attribute_others)
        if attribute_supertype is None:
          has_supertype = False
          return
        return attribute_supertype
      else:
        # Non-TraceType attributes must match exactly across all specs.
        if not all(attribute_self == attribute_other
                   for attribute_other in attribute_others):
          has_supertype = False
          return
        return attribute_self

    try:
      # TODO(b/217959193): Replace _serialize with parameter decomposition.
      serialized_supertype = nest.map_structure(
          make_supertype_attribute, self._serialize(),
          *(o._serialize() for o in others))  # pylint: disable=protected-access
    except (ValueError, TypeError):
      return None

    return self._deserialize(serialized_supertype) if has_supertype else None

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TypeSpecProto]:
    """Returns the type of proto associated with TypeSpec serialization.

    Do NOT override for custom non-TF types.
    """
    return struct_pb2.TypeSpecProto

  @classmethod
  def experimental_from_proto(cls,
                              proto: struct_pb2.TypeSpecProto) -> "TypeSpec":
    """Returns a TypeSpec instance based on the serialized proto.

    Do NOT override for custom non-TF types.

    Args:
      proto: Proto generated using 'experimental_as_proto'.
    """
    return nested_structure_coder.decode_proto(
        struct_pb2.StructuredValue(type_spec_value=proto))

  def experimental_as_proto(self) -> struct_pb2.TypeSpecProto:
    """Returns a proto representation of the TypeSpec instance.

    Do NOT override for custom non-TF types.
    """
    return nested_structure_coder.encode_structure(self).type_spec_value

  # TODO(b/223659753): Return the actual Tensor-based value instead of spec.
  def _placeholder_value(self) -> "TypeSpec":
    """Value used for tracing a function signature with this TraceType."""
    return self

  # TODO(b/225058047): Reconsider semantics.
  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec.

    Prefer using "is_subtype_of" and "most_specific_common_supertype" wherever
    possible.

    Args:
      spec_or_value: A TypeSpec or TypeSpec associated value to compare against.
    """
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(), spec_or_value._serialize())  # pylint: disable=protected-access

  @deprecation.deprecated(None, "Use most_specific_common_supertype instead.")
  def most_specific_compatible_type(self, other: "TypeSpec") -> "TypeSpec":
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Deprecated. Please use `most_specific_common_supertype` instead.
    Do not override this function.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    result = self.most_specific_common_supertype([other])
    if result is None:
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    return result

  # TODO(b/226395276): Delete after removing usages.
  def _with_tensor_ranks_only(self) -> "TypeSpec":
    """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where any `TensorShape`
      information has been relaxed to include only tensor rank (and not
      the dimension sizes for individual axes).
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, relax any TensorSpec or TensorShape values, and
    # deserialize the result.

    def relax(value):
      if isinstance(value, TypeSpec):
        return value._with_tensor_ranks_only()  # pylint: disable=protected-access
      elif (isinstance(value, tensor_shape.TensorShape) and
            value.rank is not None):
        return tensor_shape.TensorShape([None] * value.rank)
      else:
        return value

    return self._deserialize(nest.map_structure(relax, self._serialize()))

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TypeSpec":
    """Returns a TypeSpec compatible with `self`, with tensor names removed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where the name of any
      `TensorSpec` is set to `None`.
    """

    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, set the TensorSpecs' names to None, and deserialize the
    # result.

    def rename(value):
      if isinstance(value, TypeSpec):
        return value._without_tensor_names()  # pylint: disable=protected-access
      return value

    return self._deserialize(nest.map_structure(rename, self._serialize()))

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`.  (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @property
  @abc.abstractmethod
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method.  In particular, for a
      TypeSpec `spec` and a compatible value `value`:

      ```
      nest.map_structure(lambda t, c: assert t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value) -> List["ops.Tensor"]:
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`.  However, subclasses may override this to return a
    different tensor encoding for values.  In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched.  See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list: List["ops.Tensor"]) -> Any:
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
      `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self,
                                   tensor_list: List["ops.Tensor"]) -> Any:
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.  (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(
        nest.pack_sequence_as(
            self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `serialize`.

    Args:
      serialization: A value returned by _serialize.  In some contexts,
        `namedtuple`s in `serialization` may not have the identical type that
        was returned by `_serialize` (but its type will still be a `namedtuple`
        type with the same type name and field names).  For example, the code
        that loads a SavedModel does not have access to the original
        `namedtuple` type, so it dynamically creates a new `namedtuple` type
        with the same type name and field names as the original one.  If
        necessary, you can check `serialization` for these duck-typed
        `namedtuple` types, and restore them to the original type. (E.g., this
        would be necessary if you rely on type checks such as `isinstance` for
        this `TypeSpec`'s member variables).

    Returns:
      A `TypeSpec` of type `cls`.
    """
    return cls(*serialization)  # pytype: disable=not-instantiable  # trace-all-classes

  # === Operators ===

  def __eq__(self, other) -> bool:
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other) -> bool:
    return not self == other

  def __hash__(self) -> int:
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    return type(self), self._serialize()

  def __repr__(self) -> str:
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  # TODO(b/154541175): Currently this usage is used to represent a Tensor
  # argument not a TensorSpec argument as it should be.
  def __tf_tracing_type__(self,
                          context: trace.TracingContext) -> trace.TraceType:
    return self

  def __check_tensor_list(self, tensor_list):
    """Raises an exception if tensor_list incompatible w/ flat_tensor_specs."""
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError(f"Cannot create a {self.value_type.__name__} from the "
                       f"tensor list because the TypeSpec expects "
                       f"{len(expected)} items, but the provided tensor list "
                       f"has {len(specs)} items.")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError(f"Cannot create a {self.value_type.__name__} from the "
                         f"tensor list because item {i} ({tensor_list[i]!r}) "
                         f"is incompatible with the expected TypeSpec {s2}.")

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value, (int, float, bool, np.generic, dtypes.DType, TypeSpec,
                          tensor_shape.TensorShape)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      # Sort by key so that dicts with the same contents hash identically.
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError(f"Cannot generate a hashable key for {self} because "
                     f"the _serialize() method "
                     f"returned an unsupported value of type {type(value)}")

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __same_types(a, b):
    """Returns whether a and b have the same type, up to namedtuple equivalence.

    Consistent with tf.nest.assert_same_structure(), two namedtuple types
    are considered the same iff they agree in their class name (without
    qualification by module name) and in their sequence of field names.
    This makes namedtuples recreated by nested_structure_coder compatible with
    their original Python definition.

    Args:
      a: a Python object.
      b: a Python object.

    Returns:
      A boolean that is true iff type(a) and type(b) are the same object
      or equivalent namedtuple types.
    """
    if nest.is_namedtuple(a) and nest.is_namedtuple(b):
      return nest.same_namedtuples(a, b)
    else:
      return type(a) is type(b)

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    if isinstance(a, TypeSpec):
      return a.is_compatible_with(b)
    if not TypeSpec.__same_types(a, b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and
              all(TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b
605
# Register TypeSpec with the trace_type serialization machinery (TypeSpec
# subclasses trace_type.Serializable above), enabling proto round-trips via
# experimental_as_proto / experimental_from_proto.
trace_type.register_serializable(TypeSpec)
607
608
class TypeSpecBatchEncoder(object, metaclass=abc.ABCMeta):
  """Class used to encode and decode composite tensor values for batching.

  In order to be batched and unbatched by APIs such as `tf.data.Dataset` and
  `tf.map_fn`, composite tensors must be encoded using flat tensors that can
  themselves be batched or unbatched.  `TypeSpecBatchEncoder`s are
  responsible for implementing this encoding.

  If a composite tensor's shape is a prefix of the shape of all of its
  component tensors, then this encoding can usually be performed by just
  returning those component tensors as a list.  But if the composite tensor
  has components whose shape has a more complex relationship to the shape
  of the composite tensor, then a custom `TypeSpecBatchEncoder` may
  need to be implemented.
  """

  @abc.abstractmethod
  def batch(self, spec, batch_size):
    """Returns the TypeSpec representing a batch of values described by `spec`.

    Args:
      spec: The `TypeSpec` for an individual value.
      batch_size: An `int` indicating the number of values that are batched
        together, or `None` if the batch size is not known.

    Returns:
      A `TypeSpec` for a batch of values.
    """
    raise NotImplementedError(f"{type(self).__name__}.batch")

  @abc.abstractmethod
  def unbatch(self, spec):
    """Returns the TypeSpec for a single unbatched element in `spec`.

    Args:
      spec: The `TypeSpec` for a batch of values.

    Returns:
      A `TypeSpec` for an individual value.
    """
    raise NotImplementedError(f"{type(self).__name__}.unbatch")

  @abc.abstractmethod
  def encode(self, spec, value, minimum_rank=0):
    """Encodes `value` as a nest of batchable `Tensor` or `CompositeTensor`.

    Args:
      spec: The TypeSpec of the value to encode.
      value: A value compatible with `spec`.
      minimum_rank: The minimum rank for the returned Tensors, CompositeTensors,
        and ExtensionType values.  This can be used to ensure that the encoded
        values can be unbatched this number of times.   If `minimum_rank>0`,
        then `t.shape[:minimum_rank]` must be compatible for all values `t`
        returned by `encode`.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.Tensor`s, batchable
      `tf.CompositeTensor`s, or `tf.ExtensionType`s.  Stacking, unstacking, or
      concatenating these encoded values and then decoding the result must be
      equivalent to stacking, unstacking, or concatenating the original values.
    """
    raise NotImplementedError(f"{type(self).__name__}.encode")

  @abc.abstractmethod
  def decode(self, spec, encoded_value):
    """Decodes `encoded_value` from a batchable tensor encoding.

    Args:
      spec: The TypeSpec for the result value.  If encoded values with spec `s`
        were batched, then `spec` should be `s.batch(batch_size)`; or if encoded
        values with spec `s` were unbatched, then `spec` should be
        `s.unbatch()`.
      encoded_value: A nest of values returned by `encode`; or a nest of values
        that was formed by stacking, unstacking, or concatenating the
        corresponding elements of values returned by `encode`.

    Returns:
      A value compatible with `spec`.
    """
    raise NotImplementedError(f"{type(self).__name__}.decode")

  @abc.abstractmethod
  def encoding_specs(self, spec):
    """Returns a nest of `TypeSpec`(s) describing the encoding for `spec`.

    Args:
      spec: The TypeSpec whose encoding should be described.

    Returns:
      A nest (as defined by `tf.nest`) of `tf.TypeSpec`, describing the values
      that are returned by `self.encode(spec, ...)`.  All TypeSpecs in this
      nest must be batchable.
    """
    raise NotImplementedError(f"{type(self).__name__}.encoding_specs")
703
704
class LegacyTypeSpecBatchEncoder(TypeSpecBatchEncoder):
  """TypeSpecBatchEncoder for legacy composite tensor classes.

  Delegates each operation to the corresponding private method on the
  `TypeSpec` itself (`_batch`, `_unbatch`, `_to_tensor_list`, etc.).

  TODO(edloper): Update existing composite tensors to use non-legacy
    CompositeTensorBatchEncoders.
  """

  def batch(self, type_spec, batch_size):
    return type_spec._batch(batch_size)  # pylint: disable=protected-access

  def unbatch(self, type_spec):
    return type_spec._unbatch()  # pylint: disable=protected-access

  def encode(self, type_spec, value, minimum_rank=0):
    if minimum_rank == 0:
      return type_spec._to_tensor_list(value)  # pylint: disable=protected-access
    elif minimum_rank == 1:
      if not isinstance(type_spec, BatchableTypeSpec):
        # Bug fix: `type_spec` is an instance, so use type(type_spec).__name__
        # (a bare `type_spec.__name__` raised AttributeError instead of the
        # intended ValueError).
        raise ValueError(f"{type(type_spec).__name__}.encode does not support "
                         "minimum_rank>0.")
      return type_spec._to_batched_tensor_list(value)  # pylint: disable=protected-access
    else:
      raise ValueError(f"{type(type_spec).__name__}.encode does not support "
                       "minimum_rank>1.")

  def decode(self, type_spec, encoded_value):
    return type_spec._from_tensor_list(encoded_value)  # pylint: disable=protected-access

  def encoding_specs(self, spec):
    return spec._flat_tensor_specs  # pylint: disable=protected-access
735
736
class BatchableTypeSpec(TypeSpec, metaclass=abc.ABCMeta):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching.  In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists.  Unlike the component encoding
  (returned by `self._to_components`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  # Default encoder: delegates batching/encoding back to this spec's private
  # `_batch`/`_unbatch`/`_to_tensor_list` methods.
  __batch_encoder__ = LegacyTypeSpecBatchEncoder()
755
  @abc.abstractmethod
  def _batch(self, batch_size: Optional[int]) -> TypeSpec:
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch, or
        `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError(f"{type(self).__name__}._batch")
768
  @abc.abstractmethod
  def _unbatch(self) -> TypeSpec:
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError(f"{type(self).__name__}._unbatch")
777
778# LINT.IfChange
779  @property
780  def _flat_tensor_specs(self) -> List[TypeSpec]:
781    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
782    component_flat_tensor_specs = nest.map_structure(
783        functools.partial(get_batchable_flat_tensor_specs, context_spec=self),
784        self._component_specs)
785    return nest.flatten(component_flat_tensor_specs)
786# LINT.ThenChange(//tensorflow/python/framework/type_utils.py:_specs_for_flat_tensors)
787# Note that _specs_for_flat_tensors in type_utils.py must correspond
788# _flat_tensor_specs in this class and any derived classes.
789
790  def _to_tensor_list(
791      self, value: composite_tensor.CompositeTensor) -> List["ops.Tensor"]:
792    """Encodes `value` as a flat list of `ops.Tensor`."""
793    component_tensor_lists = nest.map_structure(batchable_to_tensor_list,
794                                                self._component_specs,
795                                                self._to_components(value))
796    return nest.flatten(component_tensor_lists)
797
798  def _to_batched_tensor_list(
799      self, value: composite_tensor.CompositeTensor) -> List["ops.Tensor"]:
800    """Encodes `value` as a flat list of `ops.Tensor` each with rank>0."""
801    get_spec_tensor_list = lambda spec, v: (  # pylint: disable=g-long-lambda
802        batchable_to_tensor_list(spec, v, minimum_rank=1)
803        if isinstance(spec, BatchableTypeSpec) else spec._to_tensor_list(v))  # pylint: disable=protected-access
804    component_batched_tensor_lists = nest.map_structure(
805        get_spec_tensor_list, self._component_specs, self._to_components(value))
806    tensor_list = nest.flatten(component_batched_tensor_lists)
807    if any(t.shape.ndims == 0 for t in tensor_list):
808      raise ValueError(
809          f"While converting {value} to a list of tensors for batching, "
810          f"found a scalar item which cannot be batched.")
811    return tensor_list
812
813  def _from_compatible_tensor_list(
814      self,
815      tensor_list: List["ops.Tensor"]) -> composite_tensor.CompositeTensor:
816    """Reconstructs a value from a compatible flat list of `ops.Tensor`."""
817    flat_specs = nest.map_structure(
818        functools.partial(get_batchable_flat_tensor_specs, context_spec=self),
819        self._component_specs)
820    nested_tensor_list = nest.pack_sequence_as(flat_specs, tensor_list)
821    components = nest.map_structure_up_to(self._component_specs,
822                                          batchable_from_tensor_list,
823                                          self._component_specs,
824                                          nested_tensor_list)
825    return self._from_components(components)
826
827
def get_batchable_flat_tensor_specs(spec, context_spec=None):
  """Returns the flat tensor specs for `spec`."""
  # A TensorSpec is already flat: it is its own single-element encoding.
  if isinstance(spec, tensor_spec.TensorSpec):
    return [spec]

  if hasattr(spec, "__batch_encoder__"):
    # Recurse into the encoder's encoding specs, keeping the original
    # `context_spec` for error reporting.
    recurse = functools.partial(
        get_batchable_flat_tensor_specs, context_spec=context_spec)
    nested_specs = nest.map_structure(
        recurse, spec.__batch_encoder__.encoding_specs(spec))
    return nest.flatten(nested_specs)

  # TODO(edloper) Fix existing CompositeTensors that permit this, and
  # then turn this warning into an error.
  warnings.warn(f"Batchable type {context_spec} contains non-batchable "
                f"field or component with type {spec}.")
  return spec._flat_tensor_specs  # pylint: disable=protected-access
844
845
def batchable_to_tensor_list(spec, value, minimum_rank=0):
  """Returns a list of tensors encoding `value`, whose type is `spec`."""
  # A TensorSpec's value is already a tensor; wrap it in a singleton list.
  if isinstance(spec, tensor_spec.TensorSpec):
    return [value]

  if not hasattr(spec, "__batch_encoder__"):
    # Legacy path: no batch encoder available.
    return spec._to_tensor_list(value)  # pylint: disable=protected-access

  # Encode `value`, then recursively flatten each piece of the encoding.
  encoder = spec.__batch_encoder__
  encoded_value = encoder.encode(spec, value, minimum_rank)
  encoded_specs = encoder.encoding_specs(spec)
  recurse = functools.partial(
      batchable_to_tensor_list, minimum_rank=minimum_rank)
  nested_lists = nest.map_structure(recurse, encoded_specs, encoded_value)
  return nest.flatten(nested_lists)
859
860
def batchable_from_tensor_list(spec, tensor_list):
  """Returns a value with type `spec` decoded from `tensor_list`."""
  # A TensorSpec decodes to the single tensor in the list.
  if isinstance(spec, tensor_spec.TensorSpec):
    assert len(tensor_list) == 1
    return tensor_list[0]

  if not hasattr(spec, "__batch_encoder__"):
    # Legacy path: no batch encoder available.
    return spec._from_compatible_tensor_list(tensor_list)  # pylint: disable=protected-access

  # Re-nest `tensor_list` to mirror the encoding specs, decode each piece
  # recursively, then let the encoder rebuild the value.
  encoder = spec.__batch_encoder__
  encoded_specs = encoder.encoding_specs(spec)
  flat_specs = nest.map_structure(get_batchable_flat_tensor_specs,
                                  encoded_specs)
  encoded_flats = nest.pack_sequence_as(flat_specs, tensor_list)
  encoded_value = nest.map_structure_up_to(encoded_specs,
                                           batchable_from_tensor_list,
                                           encoded_specs, encoded_flats)
  return encoder.decode(spec, encoded_value)
877
878
@tf_export("type_spec_from_value")
def type_spec_from_value(value) -> TypeSpec:
  """Returns a `tf.TypeSpec` that represents the given `value`.

  Examples:

    >>> tf.type_spec_from_value(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)
    >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64))
    TensorSpec(shape=(2,), dtype=tf.float64, name=None)
    >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]]))
    RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64)

    >>> example_input = tf.ragged.constant([[1, 2], [3]])
    >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)])
    ... def f(x):
    ...   return tf.reduce_sum(x, axis=1)

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs. Accepted
      types for `value` include `tf.Tensor`, any value that can be converted to
      `tf.Tensor` using `tf.convert_to_tensor`, and any subclass of
      `CompositeTensor` (such as `tf.RaggedTensor`).

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
  """
  spec = _type_spec_from_value(value)

  if spec is None:
    # Fallback: try converting value to a tensor.
    try:
      spec = _type_spec_from_value(ops.convert_to_tensor(value))
    except (ValueError, TypeError) as e:
      logging.vlog(
          3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))

  if spec is not None:
    return spec

  raise TypeError(f"Could not build a TypeSpec for {value} of "
                  f"unsupported type {type(value)}.")
926
927
def _type_spec_from_value(value) -> TypeSpec:
  """Returns a `TypeSpec` that represents the given `value`."""
  if isinstance(value, ops.Tensor):
    # Note: we do not include Tensor names when constructing TypeSpecs.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # A non-empty list whose elements all share a common batchable supertype can
  # be described precisely as a batch of that supertype (unlike the
  # `convert_to_tensor` fallback, which may lose type information).
  if isinstance(value, list) and value:
    element_specs = [_type_spec_from_value(element) for element in value]
    first_spec = element_specs[0]
    if isinstance(first_spec, BatchableTypeSpec):
      common_spec = first_spec.most_specific_common_supertype(element_specs[1:])
      if common_spec is not None:
        return common_spec._batch(len(element_specs))  # pylint: disable=protected-access

  # Most recent registration wins, so scan the registry back-to-front.
  for type_object, converter_fn, allow_subclass in reversed(
      _TYPE_CONVERSION_FUNCTION_REGISTRY):
    if ((type(value) is type_object) or  # pylint: disable=unidiomatic-typecheck
        (allow_subclass and isinstance(value, type_object))):
      return converter_fn(value)

  return None
955
956
957_TYPE_CONVERSION_FUNCTION_REGISTRY = []
958
959
def register_type_spec_from_value_converter(type_object,
                                            converter_fn,
                                            allow_subclass=False):
  """Registers a function for converting values with a given type to TypeSpecs.

  If multiple registered `type_object`s match a value, then the most recent
  registration takes precedence.  Custom converters should not be defined for
  `CompositeTensor`s; use `CompositeTensor._type_spec` instead.

  Args:
    type_object: A Python `type` object representing the type of values accepted
      by `converter_fn`.
    converter_fn: A function that takes one argument (an instance of the type
      represented by `type_object`) and returns a `TypeSpec`.
    allow_subclass: If true, then use `isinstance(value, type_object)` to check
      for matches.  If false, then use `type(value) is type_object`.
  """
  # Unwrap decorated types so matching is done against the underlying type.
  _, type_object = tf_decorator.unwrap(type_object)
  entry = (type_object, converter_fn, allow_subclass)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(entry)
980
981
# Register TypeSpec with the C++ pywrap utilities layer.
# NOTE(review): presumably needed so native code (e.g. nest/flatten) can
# recognize TypeSpec instances — confirm against _pywrap_utils.
_pywrap_utils.RegisterType("TypeSpec", TypeSpec)

# Bidirectional maps between TypeSpec subclasses and their registered,
# globally-unique names.  Maintained by `register`; queried by `get_name`
# and `lookup`.
_TYPE_SPEC_TO_NAME = {}
_NAME_TO_TYPE_SPEC = {}

# Regular expression for valid TypeSpec names.
_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+$")
989
990
991# TODO(b/173744905) tf_export this as "tf.register_type_spec".  (And add a
992# usage example to the docstring, once the API is public.)
993#
994# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than
995# TypeSpec (once we do refactoring to move to_components/from_components from
996# TypeSpec to ExtensionType).
def register(name):
  """Decorator used to register a globally unique name for a TypeSpec subclass.

  Args:
    name: The name of the type spec.  Must be globally unique.  Must have the
      form `"{project_name}.{type_name}"`.  E.g. `"my_project.MyTypeSpec"`.

  Returns:
    A class decorator that registers the decorated class with the given name.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If `name` does not have the required dotted form.
  """
  # Use f-strings for error messages, consistent with the rest of this file.
  if not isinstance(name, str):
    raise TypeError(f"Expected `name` to be a string; got {name!r}")
  if not _REGISTERED_NAME_RE.match(name):
    raise ValueError(
        "Registered name must have the form '{project_name}.{type_name}' "
        f"(e.g. 'my_project.MyTypeSpec'); got {name!r}.")

  def decorator_fn(cls):
    """Registers `cls` under `name`, enforcing a one-to-one mapping."""
    if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
      raise TypeError(f"Expected `cls` to be a TypeSpec; got {cls!r}")
    # Reject both re-registering a class and reusing a name.
    if cls in _TYPE_SPEC_TO_NAME:
      raise ValueError(
          f"Class {cls.__module__}.{cls.__name__} has already been registered "
          f"with name {_TYPE_SPEC_TO_NAME[cls]}.")
    if name in _NAME_TO_TYPE_SPEC:
      raise ValueError(
          f"Name {name} has already been registered for class "
          f"{_NAME_TO_TYPE_SPEC[name].__module__}."
          f"{_NAME_TO_TYPE_SPEC[name].__name__}.")
    _TYPE_SPEC_TO_NAME[cls] = name
    _NAME_TO_TYPE_SPEC[name] = cls
    return cls

  return decorator_fn
1029
1030
1031# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name)
def get_name(cls):
  """Returns the registered name for TypeSpec `cls`.

  Args:
    cls: A TypeSpec subclass that was previously registered via `register`.

  Returns:
    The globally unique name that was registered for `cls`.

  Raises:
    TypeError: If `cls` is not a TypeSpec subclass.
    ValueError: If `cls` has not been registered.
  """
  # Use f-strings for error messages, consistent with the rest of this file.
  if not (isinstance(cls, type) and issubclass(cls, TypeSpec)):
    raise TypeError(f"Expected `cls` to be a TypeSpec; got {cls!r}")
  if cls not in _TYPE_SPEC_TO_NAME:
    raise ValueError(
        f"TypeSpec {cls.__module__}.{cls.__name__} has not been registered.")
  return _TYPE_SPEC_TO_NAME[cls]
1040
1041
1042# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name)
def lookup(name):
  """Returns the TypeSpec that has been registered with name `name`.

  Args:
    name: A string name that was previously registered via `register`.

  Returns:
    The TypeSpec subclass registered under `name`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If no TypeSpec has been registered with `name`.
  """
  # Use f-strings for error messages, consistent with the rest of this file.
  if not isinstance(name, str):
    raise TypeError(f"Expected `name` to be a string; got {name!r}")
  if name not in _NAME_TO_TYPE_SPEC:
    raise ValueError(f"No TypeSpec has been registered with name {name!r}")
  return _NAME_TO_TYPE_SPEC[name]
1050