# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow-related utilities."""

import collections
import copy
import functools
import numpy as np

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.utils import object_identity
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


def is_tensor_or_tensor_list(v):
  """Returns whether `v` (or its first flattened element) is a `Tensor`.

  Only the first element of `nest.flatten(v)` is inspected; the remaining
  elements are not checked.

  Args:
    v: An object or a nested structure of objects.

  Returns:
    `True` if `v` flattens to a non-empty list whose first element is an
    `ops.Tensor`, `False` otherwise.
  """
  flat = nest.flatten(v)
  return bool(flat) and isinstance(flat[0], ops.Tensor)


def get_reachable_from_inputs(inputs, targets=None):
  """Returns the set of tensors/ops reachable from `inputs`.

  Walks the graph breadth-first from `inputs`, following op outputs, op
  control outputs, variable ops and tensor consumers. Stops early once all
  `targets` have been found (`targets` is optional).

  Only valid in Symbolic mode, not Eager mode.

  Args:
    inputs: List of tensors.
    targets: List of tensors.

  Returns:
    A set of tensors reachable from the inputs (includes the inputs themselves).

  Raises:
    TypeError: If the traversal encounters an object that is not an
      `Operation`, `Variable`, or tensor (user-registered symbolic tensor
      types are skipped instead).
  """
  inputs = nest.flatten(inputs, expand_composites=True)
  reachable = object_identity.ObjectIdentitySet(inputs)
  if targets:
    # NOTE(review): `targets` is flattened without `expand_composites`, unlike
    # `inputs`; a composite target may therefore never be discarded and the
    # early exit below never taken. Confirm whether that is intended.
    remaining_targets = object_identity.ObjectIdentitySet(nest.flatten(targets))
  queue = collections.deque(inputs)
  # The registry of user-convertible types is not modified during the
  # traversal, so build the `isinstance` tuple once instead of per element.
  user_types = tuple(_user_convertible_tensor_types)

  while queue:
    # `pop()` from the right + `appendleft()` below gives FIFO (BFS) order.
    x = queue.pop()
    if isinstance(x, user_types):
      # Can't find consumers of user-specific types.
      continue

    if isinstance(x, ops.Operation):
      # Copy so that `+=` below does not mutate the op's own output list.
      outputs = x.outputs[:] or []
      outputs += x._control_outputs  # pylint: disable=protected-access
    elif isinstance(x, variables.Variable):
      try:
        outputs = [x.op]
      except AttributeError:
        # Variables can be created in an Eager context.
        outputs = []
    elif tensor_util.is_tf_type(x):
      outputs = x.consumers()
    else:
      raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))

    for y in outputs:
      if y not in reachable:
        reachable.add(y)
        if targets:
          remaining_targets.discard(y)
        queue.appendleft(y)

    if targets and not remaining_targets:
      return reachable

  return reachable


# This function needs access to private functions of `nest`.
# pylint: disable=protected-access
def map_structure_with_atomic(is_atomic_fn, map_fn, nested):
  """Maps the atomic elements of a nested structure.

  Unlike `nest.map_structure`, the caller decides what counts as a leaf via
  `is_atomic_fn`, so e.g. a whole tuple can be treated as a single element.

  Args:
    is_atomic_fn: A function that determines if an element of `nested` is
      atomic.
    map_fn: The function to apply to atomic elements of `nested`.
    nested: A nested structure.

  Returns:
    The nested structure, with atomic elements mapped according to `map_fn`.

  Raises:
    ValueError: If an element that is neither atomic nor a sequence is
      encountered.
  """
  if is_atomic_fn(nested):
    return map_fn(nested)

  # Recursively convert.
  if not nest.is_nested(nested):
    raise ValueError(
        'Received non-atomic and non-sequence element: {}'.format(nested))
  if nest.is_mapping(nested):
    # Mapping values are visited in sorted-key order before being repacked
    # by `nest._sequence_like`.
    values = [nested[k] for k in sorted(nested.keys())]
  elif nest.is_attrs(nested):
    values = _astuple(nested)
  else:
    values = nested
  mapped_values = [
      map_structure_with_atomic(is_atomic_fn, map_fn, ele) for ele in values
  ]
  return nest._sequence_like(nested, mapped_values)


def get_shapes(tensors):
  """Gets shapes from tensors.

  Args:
    tensors: A nested structure of objects exposing a `.shape` attribute.

  Returns:
    The same structure with every leaf replaced by its `.shape`.
  """
  return nest.map_structure(lambda x: x.shape, tensors)


# pylint: enable=protected-access


def convert_shapes(input_shape, to_tuples=True):
  """Converts nested shape representations to desired format.

  Performs:

  TensorShapes -> tuples if `to_tuples=True`.
  tuples of int or None -> TensorShapes if `to_tuples=False`.

  Valid objects to be converted are:
  - TensorShapes
  - tuples or lists whose elements are ints, `Dimension`s, or None.
  - ints
  - `Dimension`s
  - None

  Args:
    input_shape: A nested structure of objects to be converted to TensorShapes.
    to_tuples: If `True`, converts all TensorShape to tuples. Otherwise converts
      all tuples representing shapes to TensorShapes.

  Returns:
    Nested structure of shapes in desired format.

  Raises:
    ValueError: when the input tensor shape can't be converted to tuples, e.g.
      unknown tensor shape, or when an element is neither a shape nor a
      nested structure.
  """

  def _is_shape_component(value):
    return value is None or isinstance(value, (int, tensor_shape.Dimension))

  def _is_atomic_shape(input_shape):
    # Ex: TensorShape or (None, 10, 32) or 5 or `None`
    if _is_shape_component(input_shape):
      return True
    if isinstance(input_shape, tensor_shape.TensorShape):
      return True
    if (isinstance(input_shape, (tuple, list)) and
        all(_is_shape_component(ele) for ele in input_shape)):
      return True
    return False

  def _convert_shape(input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if to_tuples:
      input_shape = tuple(input_shape.as_list())
    return input_shape

  return map_structure_with_atomic(_is_atomic_shape, _convert_shape,
                                   input_shape)


class ListWrapper(object):
  """A wrapper for lists to be treated as elements for `nest`."""

  def __init__(self, list_to_wrap):
    # The wrapped list is stored as-is (not copied).
    self._list = list_to_wrap

  def as_list(self):
    """Returns the wrapped list object itself."""
    return self._list


def convert_inner_node_data(nested, wrap=False):
  """Either wraps or unwraps innermost node data lists in `ListWrapper` objects.

  Args:
    nested: A nested data structure.
    wrap: If `True`, wrap innermost lists in `ListWrapper` objects. If `False`,
      unwraps `ListWrapper` objects into lists.

  Returns:
    Structure of same type as nested, with lists wrapped/unwrapped.
  """

  def _is_serialized_node_data(nested):
    # Node data can be of form `[layer_name, node_id, tensor_id]` or
    # `[layer_name, node_id, tensor_id, kwargs]`.
    if (isinstance(nested, list) and (len(nested) in [3, 4]) and
        isinstance(nested[0], str)):
      return True
    return False

  def _is_atomic_nested(nested):
    """Returns `True` if `nested` is a list representing node data."""
    if isinstance(nested, ListWrapper):
      return True
    if _is_serialized_node_data(nested):
      return True
    return not nest.is_nested(nested)

  def _convert_object_or_list(nested):
    """Convert b/t `ListWrapper` object and list representations."""
    if wrap:
      if isinstance(nested, ListWrapper):
        return nested
      if _is_serialized_node_data(nested):
        return ListWrapper(nested)
      return nested
    else:
      if isinstance(nested, ListWrapper):
        return nested.as_list()
      return nested

  return map_structure_with_atomic(_is_atomic_nested, _convert_object_or_list,
                                   nested)


def shape_type_conversion(fn):
  """Decorator that handles tuple/TensorShape conversion.

  Used in `compute_output_shape` and `build`. Shapes are passed into `fn` as
  tuples and whatever `fn` returns is converted back to `TensorShape`s.

  Args:
    fn: function to wrap.

  Returns:
    Wrapped function.
  """

  # `functools.wraps` keeps `fn`'s name and docstring on the wrapper, which
  # keeps tracebacks and introspection of decorated methods readable.
  @functools.wraps(fn)
  def wrapper(instance, input_shape):
    # Pass shapes as tuples to `fn`
    # This preserves compatibility with external Keras.
    if input_shape is not None:
      input_shape = convert_shapes(input_shape, to_tuples=True)
    output_shape = fn(instance, input_shape)
    # Return shapes from `fn` as TensorShapes.
    if output_shape is not None:
      output_shape = convert_shapes(output_shape, to_tuples=False)
    return output_shape

  return wrapper


def are_all_symbolic_tensors(tensors):
  """Returns whether every element of `tensors` is a symbolic tensor.

  Args:
    tensors: An iterable of objects to test with `is_symbolic_tensor`.

  Returns:
    `True` if all elements are symbolic (vacuously `True` when empty).
  """
  return all(map(is_symbolic_tensor, tensors))


# Classes registered via `register_symbolic_tensor_type`; instances of these
# are treated as symbolic tensors by the helpers in this module.
_user_convertible_tensor_types = set()


def is_extension_type(tensor):
  """Returns whether a tensor is of an ExtensionType.

  See github.com/tensorflow/community/pull/269. Currently it works by checking
  if `tensor` is a `CompositeTensor` instance, but this will be changed to use
  an appropriate extensiontype protocol check once ExtensionType is made
  public.

  Args:
    tensor: An object to test

  Returns:
    True if the tensor is an extension type object, false if not.
  """
  return isinstance(tensor, composite_tensor.CompositeTensor)


def is_symbolic_tensor(tensor):
  """Returns whether a tensor is symbolic (from a TF graph) or an eager tensor.

  A Variable can be seen as either: it is considered symbolic
  when we are in a graph scope, and eager when we are in an eager scope.

  Args:
    tensor: A tensor instance to test.

  Returns:
    True for symbolic tensors, False for eager tensors.
  """
  if isinstance(tensor, ops.Tensor):
    return hasattr(tensor, 'graph')
  elif is_extension_type(tensor):
    component_tensors = nest.flatten(tensor, expand_composites=True)
    return any(hasattr(t, 'graph') for t in component_tensors)
  elif isinstance(tensor, variables.Variable):
    # Variables that are output of a Keras Layer in Functional API mode
    # should be considered symbolic.
    # TODO(omalleyt): We need a better way to check this in order to
    # enable `run_eagerly=True` for Models containing Layers that
    # return Variables as outputs.
    # `bool()` so the `_keras_history` object itself is never returned.
    return bool(getattr(tensor, '_keras_history', False) or
                not context.executing_eagerly())
  elif isinstance(tensor, tuple(_user_convertible_tensor_types)):
    tensor = ops.convert_to_tensor_or_composite(tensor)
    return is_symbolic_tensor(tensor)
  else:
    return False


@keras_export('keras.__internal__.utils.register_symbolic_tensor_type', v1=[])
def register_symbolic_tensor_type(cls):
  """Allows users to specify types regarded as symbolic `Tensor`s.

  Used in conjunction with `tf.register_tensor_conversion_function`, calling
  `tf.keras.__internal__.utils.register_symbolic_tensor_type(cls)`
  allows non-`Tensor` objects to be plumbed through Keras layers.

  Example:

  ```python
  # One-time setup.
  class Foo(object):
    def __init__(self, input_):
      self._input = input_
    def value(self):
      return tf.constant(42.)

  tf.register_tensor_conversion_function(
      Foo, lambda x, *args, **kwargs: x.value())

  tf.keras.__internal__.utils.register_symbolic_tensor_type(Foo)

  # User-land.
  layer = tf.keras.layers.Lambda(lambda input_: Foo(input_))
  ```

  Registering the same class more than once is a no-op.

  Args:
    cls: A `class` type which shall be regarded as a symbolic `Tensor`.
  """
  # The module-level set is only mutated in place, so no `global` is needed.
  if cls not in _user_convertible_tensor_types:
    keras_tensor.register_keras_tensor_specialization(
        cls, keras_tensor.UserRegisteredTypeKerasTensor)
  _user_convertible_tensor_types.add(cls)


def type_spec_from_value(value):
  """Grab type_spec without converting array-likes to tensors.

  Args:
    value: An extension-type object, an array-like exposing `shape` and
      `dtype`, or any value accepted by `type_spec.type_spec_from_value`.

  Returns:
    The value's `TypeSpec`.
  """
  if is_extension_type(value):
    return value._type_spec  # pylint: disable=protected-access
  # Get a TensorSpec for array-like data without
  # converting the data to a Tensor
  if hasattr(value, 'shape') and hasattr(value, 'dtype'):
    return tensor_spec.TensorSpec(value.shape, value.dtype)
  else:
    return type_spec.type_spec_from_value(value)


def is_ragged(tensor):
  """Returns true if `tensor` is a ragged tensor or ragged tensor value."""
  return isinstance(
      tensor,
      (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue))


def is_sparse(tensor):
  """Returns true if `tensor` is a sparse tensor or sparse tensor value."""
  return isinstance(
      tensor,
      (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue))


def is_tensor_or_variable(x):
  """Returns true if `x` is a TF tensor type or a `Variable`."""
  return tensor_util.is_tf_type(x) or isinstance(x, variables.Variable)


def assert_no_legacy_layers(layers):
  """Prevent tf.layers.Layers from being used with Keras.

  Certain legacy layers inherit from their keras analogs; however they are
  not supported with keras and can lead to subtle and hard to diagnose bugs.

  Args:
    layers: A list of layers to check

  Raises:
    TypeError: If any elements of layers are tf.layers.Layers
  """

  # isinstance check for tf.layers.Layer introduces a circular dependency.
  legacy_layers = [l for l in layers if getattr(l, '_is_legacy_layer', None)]
  if legacy_layers:
    layer_str = '\n'.join(' ' + str(l) for l in legacy_layers)
    raise TypeError(
        'The following are legacy tf.layers.Layers:\n{}\nTo use keras as a '
        'framework (for instance using the Network, Model, or Sequential '
        'classes), please use the tf.keras.layers implementation instead. '
        '(Or, if writing custom layers, subclass from tf.keras.layers rather '
        'than tf.layers)'.format(layer_str))


@tf_contextlib.contextmanager
def maybe_init_scope(layer):
  """Open an `init_scope` if in V2 mode and using the keras graph.

  Args:
    layer: The Layer/Model that is currently active.

  Yields:
    None
  """
  # Don't open an init_scope in V1 mode or when using legacy tf.layers.
  if (ops.executing_eagerly_outside_functions() and
      getattr(layer, '_keras_style', True)):
    with ops.init_scope():
      yield
  else:
    yield


@tf_contextlib.contextmanager
def graph_context_for_symbolic_tensors(*args, **kwargs):
  """Returns graph context manager if any of the inputs is a symbolic tensor.

  Positional and keyword argument values are checked with
  `is_symbolic_tensor`; if any matches, the Keras graph (`K.get_graph()`) is
  made the default graph for the duration of the context.
  """
  if any(is_symbolic_tensor(v) for v in list(args) + list(kwargs.values())):
    with K.get_graph().as_default():
      yield
  else:
    yield


def dataset_is_infinite(dataset):
  """True if the passed dataset is infinite.

  Returns:
    When executing eagerly outside functions, a boolean tensor produced by
    `math_ops.equal`. Otherwise the cardinality is evaluated with the Keras
    session and the comparison result (a NumPy boolean) is returned.
  """
  if ops.executing_eagerly_outside_functions():
    return math_ops.equal(
        cardinality.cardinality(dataset), cardinality.INFINITE)
  else:
    dataset_size = K.get_session().run(cardinality.cardinality(dataset))
    return dataset_size == cardinality.INFINITE


def get_tensor_spec(t, dynamic_batch=False, name=None):
  """Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.

  Args:
    t: A `TypeSpec`, an extension-type object, a Keras tensor with a
      `_keras_history`, or any object exposing `shape` and `dtype`.
    dynamic_batch: If `True`, return a copy of the spec whose first dimension
      is `None` (only when the spec's shape has known rank > 0).
    name: Name for the `TensorSpec`; only used when the spec is built from
      `shape`/`dtype`.

  Returns:
    A `TypeSpec`, or `None` if `t` is not tensor-like.
  """
  # pylint: disable=protected-access
  if isinstance(t, type_spec.TypeSpec):
    spec = t
  elif is_extension_type(t):
    # TODO(b/148821952): Should these specs have a name attr?
    spec = t._type_spec
  elif (hasattr(t, '_keras_history') and
        hasattr(t._keras_history[0], '_type_spec')):
    # NOTE(review): this branch returns early, so `dynamic_batch` and `name`
    # are not applied to the layer's `_type_spec`. Confirm this is intended.
    return t._keras_history[0]._type_spec
  elif hasattr(t, 'shape') and hasattr(t, 'dtype'):
    spec = tensor_spec.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)
  else:
    return None  # Allow non-Tensors to pass through.

  if not dynamic_batch:
    return spec

  # Deep-copy so the caller's spec is not mutated below.
  dynamic_batch_spec = copy.deepcopy(spec)
  # RaggedTensorSpec only has a private _shape.
  shape = dynamic_batch_spec._shape
  if shape.rank is not None and shape.rank > 0:
    shape_list = shape.as_list()
    shape_list[0] = None
    dynamic_batch_spec._shape = tensor_shape.TensorShape(shape_list)
  return dynamic_batch_spec
  # pylint: enable=protected-access


def sync_to_numpy_or_python_type(tensors):
  """Syncs and converts a structure of `Tensor`s to `NumPy` arrays or Python scalar types.

  For each tensor, it calls `tensor.numpy()`. If the result is a scalar value,
  it converts it to a Python type, such as a float or int, by calling
  `result.item()`.

  Numpy scalars are converted, as Python types are often more convenient to deal
  with. This is especially useful for bfloat16 Numpy scalars, which don't
  support as many operations as other Numpy values.

  Async strategies (such as `TPUStrategy` and `ParameterServerStrategy`) are
  forced to sync during this process. A `RemoteValue` is resolved with its
  `fetch()` method.

  Args:
    tensors: A structure of tensors.

  Returns:
    `tensors`, but scalar tensors are converted to Python types and non-scalar
    tensors are converted to Numpy arrays. Non-`Tensor` leaves (including
    ragged and sparse tensors) are returned unchanged.
  """
  if isinstance(tensors, coordinator_lib.RemoteValue):
    return tensors.fetch()

  def _to_single_numpy_or_python_type(t):
    if isinstance(t, ops.Tensor):
      x = t.numpy()
      return x.item() if np.ndim(x) == 0 else x
    return t  # Don't turn ragged or sparse tensors to NumPy.

  return nest.map_structure(_to_single_numpy_or_python_type, tensors)


def _astuple(attrs):
  """Converts the given attrs to tuple non-recursively.

  Args:
    attrs: An instance of an `attrs`-decorated class.

  Returns:
    A tuple of the instance's field values, in `__attrs_attrs__` order.

  Raises:
    ValueError: If the class of `attrs` has no `__attrs_attrs__`.
  """
  cls = type(attrs)
  fields = getattr(cls, '__attrs_attrs__', None)
  if fields is None:
    raise ValueError('%r is not an attrs-decorated class.' % cls)
  return tuple(getattr(attrs, field.name) for field in fields)