# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras Input Tensor used to track Functional API topology."""

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec as type_spec_module
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_operators  # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import nest

# pylint: disable=g-classes-have-attributes


# TensorFlow tensors have a maximum rank of 254
# (see `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h),
# so we do not try to infer values for int32 tensors larger than this,
# as they cannot represent shapes.
_MAX_TENSOR_RANK = 254


class KerasTensor(object):
  """A representation of a Keras input/output during Functional API construction.

  `KerasTensor`s are tensor-like objects that represent the symbolic inputs
  and outputs of Keras layers during Functional model construction. They are
  comprised of the `tf.TypeSpec` of the (Composite)Tensor that will be
  consumed/produced in the corresponding location of the Functional model.

  KerasTensors are intended as a private API, so users should never need to
  directly instantiate `KerasTensor`s.

  **Building Functional Models with KerasTensors**
  `tf.keras.Input` produces `KerasTensor`s that represent the symbolic inputs
  to your model.

  Passing a `KerasTensor` to a `tf.keras.Layer` `__call__` lets the layer know
  that you are building a Functional model. The layer `__call__` will
  infer the output signature and return `KerasTensor`s with `tf.TypeSpec`s
  corresponding to the symbolic outputs of that layer call. These output
  `KerasTensor`s will have all of the internal KerasHistory metadata attached
  to them that Keras needs to construct a Functional Model.

  Currently, layers infer the output signature by:
    * creating a scratch `FuncGraph`
    * making placeholders in the scratch graph that match the input typespecs
    * calling `layer.call` on these placeholders
    * extracting the signatures of the outputs before clearing the scratch
      graph

  (Note: names assigned to KerasTensors by this process are not guaranteed to
  be unique, and are subject to implementation details.)

  `tf.nest` methods are used to ensure all of the input/output data
  structures get maintained, with elements swapped between KerasTensors and
  placeholders.
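
  For example, a minimal sketch of typical Functional-model construction
  (using only the public `tf.keras` API; users never instantiate
  `KerasTensor`s directly):

  ```python
  inputs = tf.keras.Input(shape=(8,))    # `inputs` is a KerasTensor.
  x = tf.keras.layers.Dense(4)(inputs)   # `x` is a KerasTensor w/ KerasHistory.
  outputs = tf.keras.layers.Dense(1)(x)
  model = tf.keras.Model(inputs, outputs)
  ```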

  In rare cases (such as when directly manipulating shapes using Keras layers),
  the layer may be able to partially infer the value of the output in addition
  to just inferring the signature.
  When this happens, the returned KerasTensor will also contain the inferred
  value information. Follow-on layers can use this information during their
  own output signature inference.
  E.g. if one layer produces a symbolic `KerasTensor` that the next layer uses
  as the shape of its outputs, partially knowing the value helps infer the
  output shape.

  **Automatically converting TF APIs to layers**:
  If you pass a `KerasTensor` to a TF API that supports dispatching,
  Keras will automatically turn that API call into a lambda
  layer in the Functional model, and return KerasTensors representing the
  symbolic outputs.

  Most TF APIs that take only tensors as input and produce output tensors
  will support dispatching.

  Calling a `tf.function` does not support dispatching, so you cannot pass
  `KerasTensor`s as inputs to a `tf.function`.

  Higher-order APIs that take methods which produce tensors (e.g.
  `tf.while_loop`, `tf.map_fn`, `tf.cond`) also do not currently support
  dispatching. So, you cannot directly pass KerasTensors as inputs to these
  APIs either. If you want to use these APIs inside of a Functional model,
  you must put them inside of a custom layer.
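
  For instance, in the sketch below the `tf.abs` call dispatches to Keras and
  is recorded as an op-wrapping lambda layer in the resulting Functional
  model (the exact wrapping layer used is an implementation detail):

  ```python
  inputs = tf.keras.Input(shape=(8,))
  outputs = tf.abs(inputs)  # Dispatches to Keras; returns a KerasTensor.
  model = tf.keras.Model(inputs, outputs)
  ```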

  Args:
    type_spec: The `tf.TypeSpec` for the symbolic input created by
      `tf.keras.Input`, or symbolically inferred for the output
      during a symbolic layer `__call__`.
    inferred_value: (Optional) a non-symbolic static value, possibly partially
      specified, that could be symbolically inferred for the outputs during
      a symbolic layer `__call__`. This will generally only happen when
      grabbing and manipulating `tf.int32` shapes directly as tensors.
      Statically inferring values in this way and storing them in the
      KerasTensor allows follow-on layers to infer output signatures
      more effectively (e.g. when using a symbolic shape tensor to later
      construct a tensor with that shape).
    name: (optional) string name for this KerasTensor. Names automatically
      generated by symbolic layer `__call__`s are not guaranteed to be unique,
      and are subject to implementation details.
  """

  def __init__(self, type_spec, inferred_value=None, name=None):
    """Constructs a KerasTensor."""
    if not isinstance(type_spec, type_spec_module.TypeSpec):
      raise ValueError('KerasTensors must be constructed with a `tf.TypeSpec`.')

    self._type_spec = type_spec
    self._inferred_value = inferred_value
    self._name = name

  @property
  def type_spec(self):
    """Returns the `tf.TypeSpec` symbolically inferred for this Keras output."""
    return self._type_spec

  @property
  def shape(self):
    """Returns the `TensorShape` symbolically inferred for this Keras output."""
    # TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
    # We may need to raise an error when it's not valid for a type_spec,
    # but some Keras code (e.g. build-related stuff) will likely fail when
    # it can't access shape or dtype.
    return self._type_spec._shape  # pylint: disable=protected-access

  @classmethod
  def from_tensor(cls, tensor):
    """Convert a traced (composite)tensor to a representative KerasTensor."""
    if isinstance(tensor, ops.Tensor):
      name = getattr(tensor, 'name', None)
      type_spec = type_spec_module.type_spec_from_value(tensor)
      inferred_value = None
      if (type_spec.dtype == dtypes.int32 and type_spec.shape.rank is not None
          and type_spec.shape.rank < 2):
        # If this tensor might be representing shape information
        # (dtype=int32, rank of 0 or 1, not too large to represent a shape),
        # we attempt to capture any value information TensorFlow's
        # shape handling can extract from the current scratch graph.
        #
        # Even though Keras layers each trace in their own scratch
        # graph, this shape value info extraction allows us to capture
        # a sizable and useful subset of the C++ shape value inference TF can
        # do if all tf ops appear in the same graph when using shape ops.
        #
        # Examples of things this cannot infer concrete dimensions for
        # that the full single-graph C++ shape inference sometimes can are:
        #   * cases where the shape tensor is cast out of int32 before being
        #     manipulated w/ floating point numbers then converted back
        #   * cases where int32 tensors w/ rank >= 2 are manipulated before
        #     being used as a shape tensor
        #   * cases where int32 tensors too large to represent shapes are
        #     manipulated to a smaller size before being used as a shape
        #     tensor
        inferred_value = array_ops.ones(shape=tensor).shape
        if inferred_value.dims:
          inferred_value = inferred_value.as_list()
          if len(inferred_value) > _MAX_TENSOR_RANK:
            inferred_value = None
        else:
          inferred_value = None

      return KerasTensor(type_spec, inferred_value=inferred_value, name=name)
    else:
      # Fall back to the generic arbitrary-typespec KerasTensor.
      name = getattr(tensor, 'name', None)
      type_spec = type_spec_module.type_spec_from_value(tensor)
      return cls(type_spec, name=name)

  @classmethod
  def from_type_spec(cls, type_spec, name=None):
    return cls(type_spec=type_spec, name=name)

  def _to_placeholder(self):
    """Convert this KerasTensor to a placeholder in a graph."""
    # If there is an inferred value for this tensor, inject the inferred value.
    if self._inferred_value is not None:
      # If we suspect this KerasTensor might be representing a shape tensor,
      # and we were able to extract value information with TensorFlow's shape
      # handling when making the KerasTensor, we construct the placeholder by
      # re-injecting the inferred value information into the graph. We
      # do this injection through the shape of a placeholder, because that
      # allows us to specify partially-unspecified shape values.
      #
      # See the comment on value extraction inside `from_tensor` for more info.
      inferred_value = array_ops.shape(
          array_ops.placeholder(
              shape=self._inferred_value, dtype=dtypes.int32))
      if self.type_spec.shape.rank == 0:
        # `tf.shape` always returns a rank-1 tensor, so we may need to turn
        # it back into a scalar.
        inferred_value = inferred_value[0]
      return inferred_value

    # Use the generic conversion from typespec to a placeholder.
    def component_to_placeholder(component):
      return array_ops.placeholder(component.dtype, component.shape)

    return nest.map_structure(
        component_to_placeholder, self.type_spec, expand_composites=True)

  def get_shape(self):
    return self.shape

  def __len__(self):
    raise TypeError('Keras symbolic inputs/outputs do not '
                    'implement `__len__`. You may be '
                    'trying to pass Keras symbolic inputs/outputs '
                    'to a TF API that does not register dispatching, '
                    'preventing Keras from automatically '
                    'converting the API call to a lambda layer '
                    'in the Functional Model. This error will also get raised '
                    'if you try asserting a symbolic input/output directly.')

  @property
  def op(self):
    raise TypeError('Keras symbolic inputs/outputs do not '
                    'implement `op`. You may be '
                    'trying to pass Keras symbolic inputs/outputs '
                    'to a TF API that does not register dispatching, '
                    'preventing Keras from automatically '
                    'converting the API call to a lambda layer '
                    'in the Functional Model.')

  def __hash__(self):
    raise TypeError('Tensors are unhashable. (%s) '
                    'Instead, use tensor.ref() as the key.' % self)

  # Note: This enables the KerasTensor's overloaded "right" binary
  # operators to run when the left operand is an ndarray, because it
  # accords the Tensor class higher priority than an ndarray, or a
  # numpy matrix.
  # In the future explore changing this to using numpy's __numpy_ufunc__
  # mechanism, which allows more control over how Tensors interact
  # with ndarrays.
  __array_priority__ = 100

  def __array__(self):
    raise TypeError(
        'Cannot convert a symbolic Keras input/output to a numpy array. '
        'This error may indicate that you\'re trying to pass a symbolic value '
        'to a NumPy call, which is not supported. Or, '
        'you may be trying to pass Keras symbolic inputs/outputs '
        'to a TF API that does not register dispatching, '
        'preventing Keras from automatically '
        'converting the API call to a lambda layer '
        'in the Functional Model.')

  @property
  def is_tensor_like(self):
    return True

  def set_shape(self, shape):
    """Updates the shape of this KerasTensor. Mimics `tf.Tensor.set_shape()`."""
    if not isinstance(shape, tensor_shape.TensorShape):
      shape = tensor_shape.TensorShape(shape)
    if shape.dims is not None:
      dim_list = [dim.value for dim in shape.dims]
      for dim in range(len(dim_list)):
        if dim_list[dim] is None and self.shape.dims is not None:
          dim_list[dim] = self.shape.dims[dim]
      shape = tensor_shape.TensorShape(dim_list)
    if not self.shape.is_compatible_with(shape):
      raise ValueError(
          "Keras symbolic input/output's shape %s is not "
          "compatible with supplied shape %s" %
          (self.shape, shape))
    else:
      self._type_spec._shape = shape  # pylint: disable=protected-access

  def __str__(self):
    symbolic_description = ''
    inferred_value_string = ''
    name_string = ''

    if hasattr(self, '_keras_history'):
      layer = self._keras_history.layer
      symbolic_description = (
          ', description="created by layer \'%s\'"' % (layer.name,))
    if self._inferred_value is not None:
      inferred_value_string = (
          ', inferred_value=%s' % self._inferred_value)
    if self.name is not None:
      name_string = ', name=\'%s\'' % self._name
    return 'KerasTensor(type_spec=%s%s%s%s)' % (
        self.type_spec, inferred_value_string,
        name_string, symbolic_description)

  def __repr__(self):
    symbolic_description = ''
    inferred_value_string = ''
    if isinstance(self.type_spec, tensor_spec.TensorSpec):
      type_spec_string = 'shape=%s dtype=%s' % (self.shape, self.dtype.name)
    else:
      type_spec_string = 'type_spec=%s' % self.type_spec

    if hasattr(self, '_keras_history'):
      layer = self._keras_history.layer
      symbolic_description = ' (created by layer \'%s\')' % (layer.name,)
    if self._inferred_value is not None:
      inferred_value_string = (
          ' inferred_value=%s' % self._inferred_value)
    return '<KerasTensor: %s%s%s>' % (
        type_spec_string, inferred_value_string, symbolic_description)

  @property
  def dtype(self):
    """Returns the `dtype` symbolically inferred for this Keras output."""
    # TODO(kaftan): This is only valid for normal/sparse/ragged tensors.
    # We may need to raise an error when it's not valid for a type_spec,
    # but some Keras code (e.g. build-related stuff) will likely fail when
    # it can't access shape or dtype.
    return self._type_spec._dtype  # pylint: disable=protected-access

  def ref(self):
    """Returns a hashable reference object to this KerasTensor.

    The primary use case for this API is to put KerasTensors in a
    set/dictionary. We can't put tensors in a set/dictionary as
    `tensor.__hash__()` is not available and tensor equality (`==`) is supposed
    to produce a tensor representing if the two inputs are equal.

    See the documentation of `tf.Tensor.ref()` for more info.
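
    A brief sketch (`kt` here can be any KerasTensor, e.g. one produced by
    `tf.keras.Input`):

    ```python
    kt = tf.keras.Input(shape=(4,))
    metadata = {kt.ref(): 'foo'}   # Hashable dictionary key.
    assert kt.ref().deref() is kt  # Recovers the original KerasTensor.
    ```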
    """
    return object_identity.Reference(self)

  def __iter__(self):
    shape = None
    if self.shape.ndims is not None:
      shape = [dim.value for dim in self.shape.dims]

    if shape is None:
      raise TypeError('Cannot iterate over a Tensor with unknown shape.')
    if not shape:
      raise TypeError('Cannot iterate over a scalar.')
    if shape[0] is None:
      raise TypeError(
          'Cannot iterate over a Tensor with unknown first dimension.')
    return _KerasTensorIterator(self, shape[0])

  @property
  def name(self):
    """Returns the (non-unique, optional) name of this symbolic Keras value."""
    return self._name

  @classmethod
  def _overload_all_operators(cls, tensor_class):  # pylint: disable=invalid-name
    """Register overloads for all operators."""
    for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
      cls._overload_operator(tensor_class, operator)

    # We include `experimental_ref` for versions of TensorFlow that
    # still include the deprecated method in Tensors.
    if hasattr(tensor_class, 'experimental_ref'):
      cls._overload_operator(tensor_class, 'experimental_ref')

  @classmethod
  def _overload_operator(cls, tensor_class, operator):  # pylint: disable=invalid-name
    """Overload an operator with the same implementation as a base Tensor class.

    We pull the operator out of the class dynamically to avoid ordering issues.

    Args:
      tensor_class: The (Composite)Tensor to get the method from.
      operator: string. The operator name.
    """
    tensor_oper = getattr(tensor_class, operator)

    # Compatibility with Python 2:
    # Python 2 unbound methods have type checks for the first arg,
    # so we need to extract the underlying function.
    tensor_oper = getattr(tensor_oper, '__func__', tensor_oper)

    setattr(cls, operator, tensor_oper)


KerasTensor._overload_all_operators(ops.Tensor)  # pylint: disable=protected-access


class SparseKerasTensor(KerasTensor):
  """A specialized KerasTensor representation for `tf.sparse.SparseTensor`s.

  Specifically, it specializes the conversion to a placeholder in order
  to maintain dense shape information.
  """

  def _to_placeholder(self):
    spec = self.type_spec

    # nest.map_structure loses dense shape information for sparse tensors.
    # So, we special-case sparse placeholder creation.
    # This only preserves shape information for top-level sparse tensors;
    # not for sparse tensors that are nested inside another composite
    # tensor.
    return array_ops.sparse_placeholder(dtype=spec.dtype, shape=spec.shape)


class RaggedKerasTensor(KerasTensor):
  """A specialized KerasTensor representation for `tf.RaggedTensor`s.

  Specifically, it:

  1. Specializes the conversion to a placeholder in order
     to maintain shape information for non-ragged dimensions.
  2. Overloads the KerasTensor's operators with the RaggedTensor versions
     when they don't match the `tf.Tensor` versions.
  3. Exposes some of the instance methods/attributes that are unique to
     the RaggedTensor API (such as `ragged_rank`, used in the sketch below).
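
  For example (a brief sketch; `tf.keras.Input(..., ragged=True)` is the
  public entry point that produces these):

  ```python
  inputs = tf.keras.Input(shape=(None,), ragged=True)  # RaggedKerasTensor
  print(inputs.ragged_rank)  # -> 1
  ```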
  """

  def _to_placeholder(self):
    ragged_spec = self.type_spec
    if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
      return super(RaggedKerasTensor, self)._to_placeholder()

    # Build a dense placeholder for the flat values, then wrap it back up
    # with one row partition per ragged dimension.
    flat_shape = ragged_spec.shape[ragged_spec.ragged_rank:]
    result = array_ops.placeholder(ragged_spec.dtype, flat_shape)

    # known_num_splits[i] is the statically-known product of dims [0..i]
    # (i.e. the number of rows when partitioning axis i + 1), or None if any
    # of those dims is unknown.
    known_num_splits = []
    prod = 1
    for axis_size in ragged_spec.shape:
      if prod is not None:
        if axis_size is None or (
            getattr(axis_size, 'value', True) is None):
          prod = None
        else:
          prod = prod * axis_size
      known_num_splits.append(prod)

    # Rebuild the ragged partitions from the innermost ragged axis outward,
    # using a row_splits placeholder when the axis size is unknown and a
    # uniform row length otherwise.
    for axis in range(ragged_spec.ragged_rank, 0, -1):
      axis_size = ragged_spec.shape[axis]
      if axis_size is None or (getattr(axis_size, 'value', True) is None):
        num_splits = known_num_splits[axis - 1]
        if num_splits is not None:
          num_splits = num_splits + 1
        splits = array_ops.placeholder(
            ragged_spec.row_splits_dtype, [num_splits])
        result = ragged_tensor.RaggedTensor.from_row_splits(
            result, splits, validate=False)
      else:
        rowlen = constant_op.constant(axis_size, ragged_spec.row_splits_dtype)
        result = ragged_tensor.RaggedTensor.from_uniform_row_length(
            result, rowlen, validate=False)
    return result

  @property
  def ragged_rank(self):
    return self.type_spec.ragged_rank

# Overload slicing
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__getitem__')  # pylint: disable=protected-access

# Overload math ops
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__add__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__radd__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__mul__')  # pylint: disable=protected-access
RaggedKerasTensor._overload_operator(ragged_tensor.RaggedTensor, '__rmul__')  # pylint: disable=protected-access


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break TensorFlow Probability
# while they finish migrating to composite tensors.
class UserRegisteredSpec(type_spec_module.TypeSpec):
  """TypeSpec to represent user-registered symbolic objects."""

  def __init__(self, shape, dtype):
    self.shape = shape
    self._dtype = dtype
    self.dtype = dtype

  def _component_specs(self):
    raise NotImplementedError

  def _from_components(self, components):
    raise NotImplementedError

  def _serialize(self):
    raise NotImplementedError

  def _to_components(self, value):
    raise NotImplementedError

  def value_type(self):
    raise NotImplementedError


# TODO(b/161487382):
# Special-case user-registered symbolic objects (registered by the
# private `register_symbolic_tensor_type` method) by passing them between
# scratch graphs directly.
# This is needed to not break TensorFlow Probability
# while they finish migrating to composite tensors.
class UserRegisteredTypeKerasTensor(KerasTensor):
  """KerasTensor that represents legacy register_symbolic_tensor_type."""

  def __init__(self, user_registered_symbolic_object):
    x = user_registered_symbolic_object
    self._user_registered_symbolic_object = x
    type_spec = UserRegisteredSpec(x.shape, x.dtype)
    name = getattr(x, 'name', None)

    super(UserRegisteredTypeKerasTensor, self).__init__(type_spec, name=name)

  @classmethod
  def from_tensor(cls, tensor):
    return cls(tensor)

  @classmethod
  def from_type_spec(cls, type_spec, name=None):
    raise NotImplementedError('You cannot instantiate a KerasTensor '
                              'directly from TypeSpec: %s' % type_spec)

  def _to_placeholder(self):
    return self._user_registered_symbolic_object


class _KerasTensorIterator(object):
  """Iterates over the leading dim of a KerasTensor. Performs no error checks."""

  def __init__(self, tensor, dim0):
    self._tensor = tensor
    self._index = 0
    self._limit = dim0

  def __iter__(self):
    return self

  def __next__(self):
    if self._index == self._limit:
      raise StopIteration
    result = self._tensor[self._index]
    self._index += 1
    return result


# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
#   1. we do a check w/ isinstance because a key lookup based on class
#      would miss subclasses
#   2. a list allows us to control lookup ordering
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
keras_tensor_classes = [
    (ops.Tensor, KerasTensor),
    (sparse_tensor.SparseTensor, SparseKerasTensor),
    (ragged_tensor.RaggedTensor, RaggedKerasTensor),
    (object, KerasTensor)
]


def register_keras_tensor_specialization(cls, keras_tensor_subclass):
  """Register a specialized KerasTensor subclass for a Tensor type."""
  # We always leave (object, KerasTensor) at the end as a generic fallback.
  keras_tensor_classes.insert(-1, (cls, keras_tensor_subclass))
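
# For example (a hypothetical sketch; `MyCompositeTensor` and
# `MyCompositeKerasTensor` are illustrative names, not classes defined in
# TensorFlow), a custom composite tensor type could be given its own
# specialized symbolic representation via:
#
#   register_keras_tensor_specialization(
#       MyCompositeTensor, MyCompositeKerasTensor)
#
# After this, `keras_tensor_from_tensor` (below) will pick
# `MyCompositeKerasTensor` for `MyCompositeTensor` values instead of falling
# back to the generic (object, KerasTensor) entry.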


def keras_tensor_to_placeholder(x):
  """Construct a graph placeholder to represent a KerasTensor when tracing."""
  if isinstance(x, KerasTensor):
    return x._to_placeholder()  # pylint: disable=protected-access
  else:
    return x


def keras_tensor_from_tensor(tensor):
  """Convert a traced (composite)tensor to a representative KerasTensor."""
  # Create a specialized KerasTensor that supports instance methods,
  # operators, and additional value inference if possible.
  keras_tensor_cls = None
  for tensor_type, cls in keras_tensor_classes:
    if isinstance(tensor, tensor_type):
      keras_tensor_cls = cls
      break

  out = keras_tensor_cls.from_tensor(tensor)

  if hasattr(tensor, '_keras_mask'):
    out._keras_mask = keras_tensor_from_tensor(tensor._keras_mask)  # pylint: disable=protected-access
  return out


def keras_tensor_from_type_spec(type_spec, name=None):
  """Convert a TypeSpec to a representative KerasTensor."""
  # Create a specialized KerasTensor that supports instance methods,
  # operators, and additional value inference if possible.
  keras_tensor_cls = None
  value_type = type_spec.value_type
  for tensor_type, cls in keras_tensor_classes:
    if issubclass(value_type, tensor_type):
      keras_tensor_cls = cls
      break

  return keras_tensor_cls.from_type_spec(type_spec, name=name)
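
# A brief usage sketch of the conversion helpers above (for illustration
# only; these are internal utilities used during Functional API tracing):
#
#   spec = tensor_spec.TensorSpec(shape=(None, 8), dtype=dtypes.float32)
#   kt = keras_tensor_from_type_spec(spec, name='x')
#   # `kt` is a KerasTensor with shape (None, 8) and dtype float32.
#   ph = keras_tensor_to_placeholder(kt)
#   # `ph` is a placeholder with the same shape/dtype (when called inside a
#   # FuncGraph).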