xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/ops/optional_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A type for representing values that may or may not exist."""
16import abc
17
18from tensorflow.python.data.util import structure
19from tensorflow.python.framework import composite_tensor
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_spec
23from tensorflow.python.framework import type_spec
24from tensorflow.python.ops import gen_dataset_ops
25from tensorflow.python.util import deprecation
26from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export("experimental.Optional", "data.experimental.Optional")
@deprecation.deprecated_endpoints("data.experimental.Optional")
class Optional(composite_tensor.CompositeTensor, metaclass=abc.ABCMeta):
  """Represents a value that may or may not be present.

  A `tf.experimental.Optional` can represent the result of an operation that may
  fail as a value, rather than raising an exception and halting execution. For
  example, `tf.data.Iterator.get_next_as_optional()` returns a
  `tf.experimental.Optional` that either contains the next element of an
  iterator if one exists, or an "empty" value that indicates the end of the
  sequence has been reached.

  `tf.experimental.Optional` can only be used with values that are convertible
  to `tf.Tensor` or `tf.CompositeTensor`.

  One can create a `tf.experimental.Optional` from a value using the
  `from_value()` method:

  >>> optional = tf.experimental.Optional.from_value(42)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> print(optional.get_value())
  tf.Tensor(42, shape=(), dtype=int32)

  or without a value using the `empty()` method:

  >>> optional = tf.experimental.Optional.empty(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractmethod
  def has_value(self, name=None):
    """Returns a tensor that evaluates to `True` if this optional has a value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.bool`.
    """
    raise NotImplementedError("Optional.has_value()")

  @abc.abstractmethod
  def get_value(self, name=None):
    """Returns the value wrapped by this optional.

    If this optional does not have a value (i.e. `self.has_value()` evaluates to
    `False`), this operation will raise `tf.errors.InvalidArgumentError` at
    runtime.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The wrapped value.
    """
    raise NotImplementedError("Optional.get_value()")

  # `abc.abstractproperty` has been deprecated since Python 3.3. Stacking
  # `property` on top of `abc.abstractmethod` is the documented replacement
  # and still forces concrete subclasses to define `element_spec`.
  @property
  @abc.abstractmethod
  def element_spec(self):
    """The type specification of an element of this optional.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.element_spec)
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this optional, specifying the type of individual components.
    """
    raise NotImplementedError("Optional.element_spec")

  @staticmethod
  def empty(element_spec):
    """Returns an `Optional` that has no value.

    NOTE: This method takes an argument that defines the structure of the value
    that would be contained in the returned `Optional` if it had a value.

    >>> optional = tf.experimental.Optional.empty(
    ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Args:
      element_spec: A (nested) structure of `tf.TypeSpec` objects matching the
        structure of an element of this optional.

    Returns:
      A `tf.experimental.Optional` with no value.
    """
    # The "none" variant carries no data. `element_spec` is recorded alongside
    # it so that `get_value()` knows the flat output types/shapes to request.
    return _OptionalImpl(gen_dataset_ops.optional_none(), element_spec)

  @staticmethod
  def from_value(value):
    """Returns a `tf.experimental.Optional` that wraps the given value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      value: A value to wrap. The value must be convertible to `tf.Tensor` or
        `tf.CompositeTensor`.

    Returns:
      A `tf.experimental.Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      # Flatten `value` into a list of tensors described by its inferred spec.
      # Any ops this creates are grouped under the nested "value" scope.
      with ops.name_scope("value"):
        element_spec = structure.type_spec_from_value(value)
        encoded_value = structure.to_tensor_list(element_spec, value)

    # The op is created after the scope has been exited. Its name is the
    # captured "optional" scope string.
    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        element_spec)
157
158
159class _OptionalImpl(Optional):
160  """Concrete implementation of `tf.experimental.Optional`.
161
162  NOTE(mrry): This implementation is kept private, to avoid defining
163  `Optional.__init__()` in the public API.
164  """
165
166  def __init__(self, variant_tensor, element_spec):
167    super().__init__()
168    self._variant_tensor = variant_tensor
169    self._element_spec = element_spec
170
171  def has_value(self, name=None):
172    with ops.colocate_with(self._variant_tensor):
173      return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)
174
175  def get_value(self, name=None):
176    # TODO(b/110122868): Consolidate the restructuring logic with similar logic
177    # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
178    with ops.name_scope(name, "OptionalGetValue",
179                        [self._variant_tensor]) as scope:
180      with ops.colocate_with(self._variant_tensor):
181        result = gen_dataset_ops.optional_get_value(
182            self._variant_tensor,
183            name=scope,
184            output_types=structure.get_flat_tensor_types(self._element_spec),
185            output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
186      # NOTE: We do not colocate the deserialization of composite tensors
187      # because not all ops are guaranteed to have non-GPU kernels.
188      return structure.from_tensor_list(self._element_spec, result)
189
190  @property
191  def element_spec(self):
192    return self._element_spec
193
194  @property
195  def _type_spec(self):
196    return OptionalSpec.from_value(self)
197
198
199@tf_export(
200    "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
201class OptionalSpec(type_spec.TypeSpec):
202  """Type specification for `tf.experimental.Optional`.
203
204  For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
205  `tf.experimental.Optional` as an input argument:
206
207  >>> @tf.function(input_signature=[tf.OptionalSpec(
208  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
209  ... def maybe_square(optional):
210  ...   if optional.has_value():
211  ...     x = optional.get_value()
212  ...     return x * x
213  ...   return -1
214  >>> optional = tf.experimental.Optional.from_value(5)
215  >>> print(maybe_square(optional))
216  tf.Tensor(25, shape=(), dtype=int32)
217
218  Attributes:
219    element_spec: A (nested) structure of `TypeSpec` objects that represents the
220      type specification of the optional element.
221  """
222
223  __slots__ = ["_element_spec"]
224
225  def __init__(self, element_spec):
226    super().__init__()
227    self._element_spec = element_spec
228
229  @property
230  def value_type(self):
231    return _OptionalImpl
232
233  def _serialize(self):
234    return (self._element_spec,)
235
236  @property
237  def _component_specs(self):
238    return [tensor_spec.TensorSpec((), dtypes.variant)]
239
240  def _to_components(self, value):
241    return [value._variant_tensor]  # pylint: disable=protected-access
242
243  def _from_components(self, flat_value):
244    # pylint: disable=protected-access
245    return _OptionalImpl(flat_value[0], self._element_spec)
246
247  @staticmethod
248  def from_value(value):
249    return OptionalSpec(value.element_spec)
250
251  def _to_legacy_output_types(self):
252    return self
253
254  def _to_legacy_output_shapes(self):
255    return self
256
257  def _to_legacy_output_classes(self):
258    return self
259