# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A type for representing values that may or may not exist."""
import abc

from tensorflow.python.data.util import structure
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("experimental.Optional", "data.experimental.Optional")
@deprecation.deprecated_endpoints("data.experimental.Optional")
class Optional(composite_tensor.CompositeTensor, metaclass=abc.ABCMeta):
  """Represents a value that may or may not be present.

  A `tf.experimental.Optional` can represent the result of an operation that may
  fail as a value, rather than raising an exception and halting execution. For
  example, `tf.data.Iterator.get_next_as_optional()` returns a
  `tf.experimental.Optional` that either contains the next element of an
  iterator if one exists, or an "empty" value that indicates the end of the
  sequence has been reached.

  `tf.experimental.Optional` can only be used with values that are convertible
  to `tf.Tensor` or `tf.CompositeTensor`.

  One can create a `tf.experimental.Optional` from a value using the
  `from_value()` method:

  >>> optional = tf.experimental.Optional.from_value(42)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> print(optional.get_value())
  tf.Tensor(42, shape=(), dtype=int32)

  or without a value using the `empty()` method:

  >>> optional = tf.experimental.Optional.empty(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractmethod
  def has_value(self, name=None):
    """Returns a tensor that evaluates to `True` if this optional has a value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.bool`.
    """
    raise NotImplementedError("Optional.has_value()")

  @abc.abstractmethod
  def get_value(self, name=None):
    """Returns the value wrapped by this optional.

    If this optional does not have a value (i.e. `self.has_value()` evaluates to
    `False`), this operation will raise `tf.errors.InvalidArgumentError` at
    runtime.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The wrapped value.
    """
    raise NotImplementedError("Optional.get_value()")

  # `abc.abstractproperty` has been deprecated since Python 3.3; stacking
  # `property` on top of `abc.abstractmethod` is the documented equivalent and
  # keeps the property abstract for subclasses.
  @property
  @abc.abstractmethod
  def element_spec(self):
    """The type specification of an element of this optional.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.element_spec)
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A (nested) structure of `tf.TypeSpec` objects matching the structure of an
      element of this optional, specifying the type of individual components.
    """
    raise NotImplementedError("Optional.element_spec")

  @staticmethod
  def empty(element_spec):
    """Returns an `Optional` that has no value.

    NOTE: This method takes an argument that defines the structure of the value
    that would be contained in the returned `Optional` if it had a value.

    >>> optional = tf.experimental.Optional.empty(
    ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Args:
      element_spec: A (nested) structure of `tf.TypeSpec` objects matching the
        structure of an element of this optional.

    Returns:
      A `tf.experimental.Optional` with no value.
    """
    return _OptionalImpl(gen_dataset_ops.optional_none(), element_spec)

  @staticmethod
  def from_value(value):
    """Returns a `tf.experimental.Optional` that wraps the given value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      value: A value to wrap. The value must be convertible to `tf.Tensor` or
        `tf.CompositeTensor`.

    Returns:
      A `tf.experimental.Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      # The spec inference and flattening run in a nested "value" scope; the
      # outer scope name is reused as the name of the wrapping op below.
      with ops.name_scope("value"):
        element_spec = structure.type_spec_from_value(value)
        encoded_value = structure.to_tensor_list(element_spec, value)

      return _OptionalImpl(
          gen_dataset_ops.optional_from_value(encoded_value, name=scope),
          element_spec)


class _OptionalImpl(Optional):
  """Concrete implementation of `tf.experimental.Optional`.

  NOTE(mrry): This implementation is kept private, to avoid defining
  `Optional.__init__()` in the public API.
  """

  def __init__(self, variant_tensor, element_spec):
    """Creates an optional backed by `variant_tensor`.

    Args:
      variant_tensor: The tensor handed to the optional ops (e.g. the output of
        `optional_none` or `optional_from_value`). `OptionalSpec` describes it
        as a scalar tensor of type `tf.variant`.
      element_spec: A (nested) structure of `tf.TypeSpec` objects describing
        the value this optional holds, or would hold if it had one.
    """
    super().__init__()
    self._variant_tensor = variant_tensor
    self._element_spec = element_spec

  def has_value(self, name=None):
    """See `Optional.has_value()`; the op is colocated with the variant."""
    with ops.colocate_with(self._variant_tensor):
      return gen_dataset_ops.optional_has_value(self._variant_tensor, name=name)

  def get_value(self, name=None):
    """See `Optional.get_value()`.

    Extracts the flat list of component tensors from the variant and rebuilds
    the structured value described by `element_spec`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The wrapped value, structured according to `self.element_spec`.
    """
    # TODO(b/110122868): Consolidate the restructuring logic with similar logic
    # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    with ops.name_scope(name, "OptionalGetValue",
                        [self._variant_tensor]) as scope:
      with ops.colocate_with(self._variant_tensor):
        result = gen_dataset_ops.optional_get_value(
            self._variant_tensor,
            name=scope,
            output_types=structure.get_flat_tensor_types(self._element_spec),
            output_shapes=structure.get_flat_tensor_shapes(self._element_spec))
      # NOTE: We do not colocate the deserialization of composite tensors
      # because not all ops are guaranteed to have non-GPU kernels.
      return structure.from_tensor_list(self._element_spec, result)

  @property
  def element_spec(self):
    """The element spec this optional was constructed with."""
    return self._element_spec

  @property
  def _type_spec(self):
    """An `OptionalSpec` built from this optional's `element_spec`."""
    return OptionalSpec.from_value(self)


@tf_export(
    "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
class OptionalSpec(type_spec.TypeSpec):
  """Type specification for `tf.experimental.Optional`.

  For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
  `tf.experimental.Optional` as an input argument:

  >>> @tf.function(input_signature=[tf.OptionalSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def maybe_square(optional):
  ...   if optional.has_value():
  ...     x = optional.get_value()
  ...     return x * x
  ...   return -1
  >>> optional = tf.experimental.Optional.from_value(5)
  >>> print(maybe_square(optional))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A (nested) structure of `TypeSpec` objects that represents the
      type specification of the optional element.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    """Creates a spec for optionals whose elements match `element_spec`.

    Args:
      element_spec: A (nested) structure of `TypeSpec` objects that represents
        the type specification of the optional element.
    """
    super().__init__()
    self._element_spec = element_spec

  @property
  def value_type(self):
    """The Python type of values described by this spec (`_OptionalImpl`)."""
    return _OptionalImpl

  def _serialize(self):
    """Returns the constructor arguments as a 1-tuple: `(element_spec,)`."""
    return (self._element_spec,)

  @property
  def _component_specs(self):
    """A single-element list holding a scalar `tf.variant` `TensorSpec`."""
    return [tensor_spec.TensorSpec((), dtypes.variant)]

  def _to_components(self, value):
    """Returns a single-element list holding `value`'s variant tensor."""
    return [value._variant_tensor]  # pylint: disable=protected-access

  def _from_components(self, flat_value):
    """Rebuilds an optional from `flat_value[0]` and this spec's element spec."""
    # pylint: disable=protected-access
    return _OptionalImpl(flat_value[0], self._element_spec)

  @staticmethod
  def from_value(value):
    """Returns an `OptionalSpec` built from `value.element_spec`."""
    return OptionalSpec(value.element_spec)

  def _to_legacy_output_types(self):
    """Returns this spec itself as the legacy output types representation."""
    return self

  def _to_legacy_output_shapes(self):
    """Returns this spec itself as the legacy output shapes representation."""
    return self

  def _to_legacy_output_classes(self):
    """Returns this spec itself as the legacy output classes representation."""
    return self