# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""

import functools
import threading

from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.trackable import base as tracking
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

_call_context = threading.local()


def create_mean_metric(value, name=None):
  # Importing `keras` imports `base_layer` and then this module, and the
  # metrics module relies on `base_layer`, which results in a cyclic
  # dependency. Import locally instead.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
  return metric_obj, metric_obj(value)


def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary util to create a variable (relies on `variable_scope.variable`).

  Some reuse-related technicalities prevent us from using
  `variable_scope.get_variable()` directly, so we use a subcomponent
  that has fewer constraints (`variable_scope.variable()`).

  In the longer term, it seems like a similar "default variable creator" method
  should exist in `Trackable` instead. When this happens, we can get
  rid of this temporary solution.

  TODO(fchollet): remove this method when no longer needed.

  Args:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
    initializer: Initializer instance (callable).
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. variables, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  initializing_from_value = False
  if initializer is not None and not callable(initializer):
    initializing_from_value = True

  if initializing_from_value:
    init_val = initializer
    variable_dtype = None
  else:
    # Instantiate initializer if provided initializer is a type object.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    init_val = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype
  if use_resource is None:
    use_resource = True

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  return tf_variables.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape if variable_shape else None)


def collect_previous_mask(input_tensors):
  """Retrieves the output mask(s) of the previous node.

  Args:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    A mask tensor or list of mask tensors.
  """

  def _collect_previous_mask(x):
    return getattr(x, '_keras_mask', None)

  return nest.map_structure(_collect_previous_mask, input_tensors)


def have_all_keras_metadata(tensors):
  return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
  return array_ops.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
  """Wraps TensorFlow Operations for compatibility with the Functional API.

  This method checks to see if a Tensor in `tensors` is missing Keras metadata
  and has its origin in a Keras `Input` Layer. If so, this method will replace
  the raw TensorFlow Operations that created this tensor with
  `TensorFlowOpLayer` instances that create identical operations.

  Any Tensors not originating from a Keras `Input` Layer will be treated as
  constants when constructing `TensorFlowOpLayer` instances.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw TensorFlow operations.
  """
  _, created_layers = _create_keras_history_helper(tensors, set(), [])
  return created_layers


# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to TF op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via `get_config` (only via SavedModel). It may
# also change the semantics of whether generated random numbers are generated
# once and reused, or recomputed each time.
# Note: This path triggers for TPUEstimators / XLA-compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped
      in `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.
  """
  if ops.executing_eagerly_outside_functions():
    raise ValueError(
        '`create_keras_history` should only be called if eager is disabled!')
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property.
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In legacy graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
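            # Keep the symbolic tensor as-is; it is recorded as a constant
            # input to the wrapping `TensorFlowOpLayer` without being
            # evaluated here.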
            constants[i] = op_input
          else:
            with ops.init_scope():
              constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'TensorFlow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
  # Preserve compatibility with older configs.
  flat_input_tensors = nest.flatten(input_tensors)
  # If this is a single element but not a dict, unwrap. If this is a dict,
  # assume the first layer expects a dict (as is the case with a
  # DenseFeatures layer); pass through.
  if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
    input_tensors = flat_input_tensors[0]
  return input_tensors


def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  This will never return True inside a sublayer, because sublayers
  do not need to create Keras History. Otherwise, this returns True
  if one or more of `tensors` originates from a `keras.Input` and
  does not have `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of whether we are
      currently outside of a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  input_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  if all(
      getattr(tensor, '_keras_history', None) is not None
      for tensor in input_tensors):
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)


def is_in_keras_graph():
  """Returns whether we are currently executing inside of a Keras graph."""
  return call_context().in_keras_graph


def is_in_eager_or_tf_function():
  """Returns whether we are in eager mode or inside of a `tf.function`."""
  return context.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
  """Returns whether we are inside of a `tf.function`."""
  # Check if running in V1 graph mode.
  if not ops.executing_eagerly_outside_functions():
    return False
  if not ops.inside_function():
    return False
  # Check if inside Keras FuncGraph.
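  # The Keras graph is a FuncGraph too, but code running in it behaves like
  # legacy graph mode rather than like the body of a `tf.function`.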
  if is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  if (getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function')):
    return False
  return True


def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  checked_tensors = set()
  tensors_to_check = nest.flatten(tensors)

  while tensors_to_check:
    new_tensors_to_check = []
    for tensor in tensors_to_check:
      if id(tensor) in checked_tensors:
        continue

      checked_tensors.add(id(tensor))

      if getattr(tensor, '_keras_history_checked', None) is not None:
        continue
      if getattr(tensor, '_keras_history', None) is not None:
        return True

      try:
        new_tensors_to_check.extend(tensor.op.inputs)
      except AttributeError:
        # In case `tensor` is a Variable created in an Eager context.
        pass

    tensors_to_check = new_tensors_to_check

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False


def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  This prevents Layers from attempting to create TensorFlowOpLayers
  for these Tensors.

  Args:
    tensors: An arbitrary structure of Tensors.
  """

  def _mark_checked(tensor):
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)


def call_context():
  """Returns the currently active `CallContext`."""
  call_ctx = getattr(_call_context, 'call_context', None)
  if call_ctx is None:
    call_ctx = CallContext()
    _call_context.call_context = call_ctx
  return call_ctx


# Inject the call_context function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_call_context_function(call_context)


class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set
      to `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # Handle `in_call` separately as it is the most-read attr and reading it is
    # on the hot path.
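    # All other per-call attributes are grouped in the `_state` dict below,
    # which `CallContextManager` swaps in and out around each `call`.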
    self.in_call = False
    self._state = {
        'layer': None,
        'inputs': None,
        'build_graph': False,
        'training': None,
        'saving': None
    }
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    state = {
        'layer': layer,
        'inputs': inputs,
        'build_graph': build_graph,
        'training': training,
        'saving': saving
    }
    return CallContextManager(self, state)

  @property
  def layer(self):
    return self._state['layer']

  @property
  def inputs(self):
    return self._state['inputs']

  @property
  def build_graph(self):
    return self._state['build_graph']

  @property
  def training(self):
    return self._state['training']

  @property
  def saving(self):
    return self._state['saving']

  @property
  def frozen(self):
    layer = self._state['layer']
    if not layer:
      return False
    return not layer.trainable

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')


class CallContextManager(object):
  """Context manager for `CallContext`."""

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    call_ctx = self._call_ctx
    self._prev_in_call = call_ctx.in_call
    self._prev_state = call_ctx._state

    call_ctx.in_call = True
    call_ctx._state = self._state

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
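    # Only track the Keras-graph state when this call is building a graph;
    # purely eager calls never enter the Keras graph.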
    if self._build_graph:
      self._prev_in_keras_graph = call_ctx._in_keras_graph
      call_ctx._in_keras_graph = (
          call_ctx._in_keras_graph or
          getattr(backend.get_graph(), 'name', None) == 'keras_graph')

  def __exit__(self, *exc_info):
    call_ctx = self._call_ctx
    call_ctx.in_call = self._prev_in_call
    call_ctx._state = self._prev_state

    if self._build_graph:
      call_ctx._in_keras_graph = self._prev_in_keras_graph


def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`."""
  # `argspec.args` starts with ['self', 'inputs'].
  full_args = dict(zip(argspec.args[2:], args))
  full_args.update(kwargs)
  return 'training' in full_args and full_args['training'] is not None


def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model."""
  return (layer.__module__.find('keras.engine') == -1 and
          layer.__module__.find('keras.layers') == -1)


def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel."""
  return layer.__module__.find('keras.saving.saved_model') != -1


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` methods match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Args:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  if (force_raise or
      (ops.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)):
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control flow '
          'branch, e.g.:\n{bad_example}\nThis is currently not supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer dynamic'
          ' (eager-only) by passing `dynamic=True` to the layer constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps."""

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tf_type(tensor):
      return tensor

    # pylint: disable=protected-access
    return_tensor = acd.mark_as_return(tensor)
    if getattr(tensor, '_keras_mask', None) is not None:
      return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
    else:
      return_tensor._keras_mask = None

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      return_tensor._tfp_distribution = tensor._tfp_distribution

    return return_tensor
  # pylint: enable=protected-access

  return nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this function
  is only useful if `tf.compat.v1.disable_v2_behavior` has been called. Since
  mixed precision requires V2 dtype behavior to be enabled, this function allows
  you to use mixed precision in Keras layers if `disable_v2_behavior` has been
  called.

  When enabled, the dtype of Keras layers defaults to floatx (which is typically
  float32) instead of None. In addition, layers will automatically cast
  floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt their layer out of the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting to
  floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
  will default to the global policy instead of floatx. Layers will automatically
  cast inputs to the policy's compute_dtype.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled."""
  if V2_DTYPE_BEHAVIOR is None:
    return tf2.enabled()
  return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they can
  be saved and restored with the correct data and without adding additional ops
  on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    # 'Saveables' won't exist when we're passed a legacy TF1 table like
    # a StaticHashTable.
    if not saveables:
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []

    elif len(saveables) == 1:
      saveable = list(saveables)[0]

      if ops.executing_eagerly_outside_functions():
        # If we're in eager mode, we need to defer calling the Trackable's
        # saveable() callable until data export time.
        # However, it is safe to call the saveable as many times as we want, so
        # we will call it now to figure out how many tensors this Trackable
        # will produce.
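        # Store the callable itself; the getter/setter below re-invoke it so
        # the returned specs reflect the trackable's state at export time.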
        self._saveable = saveable
        self._num_tensors = len(self._saveable().specs)
        self._setter = lambda weights: self._saveable().restore(weights, None)
        self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
      else:
        # If we're in Graph mode, we need to evaluate the Saveable only once and
        # cache the resulting restore graph. Failing to do this will result in
        # new assignment ops being added to the graph each time set_weights() is
        # called.
        self._placeholder_tensors = []
        self._saveable = saveable()
        self._num_tensors = len(self._saveable.specs)
        for spec in self._saveable.specs:
          tensor = spec.tensor
          self._placeholder_tensors.append(
              array_ops.placeholder(tensor.dtype, tensor.shape))
        self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                                 None)
        self._setter = self._set_weights_v1
        self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
    else:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, len(saveables)))

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    if len(weights) != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of ' +
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, len(weights)))
    self._setter(weights)

  def get_tensors(self):
    return self._getter()

  def _set_weights_v1(self, weights):
    feed_dict = {}
    for idx, tensor in enumerate(weights):
      feed_dict[self._placeholder_tensors[idx]] = tensor
    backend.get_session().run(self._assign_op, feed_dict)


class StaticTableHandler(TrackableWeightHandler):
  """Wrapper for handling weight collection for static hash tables."""

  def __init__(self, getter_lambda):  # pylint: disable=super-init-not-called
    self._num_tensors = 2
    self._getter = getter_lambda
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    def raise_error(_):
      raise RuntimeError('This layer contains a static lookup table, which '
                         'cannot be changed via set_weights().')

    self._setter = raise_error


def no_ragged_support(inputs, layer_name):
  input_list = nest.flatten(inputs)
  if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
    raise ValueError('Layer %s does not support RaggedTensors as input. '
                     'Inputs received: %s. You can try converting your '
                     'input to a uniform tensor.' % (layer_name, inputs))


def is_split_variable(v):
  """Returns True if `v` is a PartitionedVariable or a ShardedVariable."""
  return hasattr(v, '_variable_list') or hasattr(v, '_variables')


def has_weights(obj):
  obj_type = type(obj)
  return (hasattr(obj_type, 'trainable_weights') and
          hasattr(obj_type, 'non_trainable_weights') and
          not isinstance(obj, type))


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes
# issues in eager mode because the child layers may have graph losses
# (thus model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')
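# Illustrative sketch (not part of the module API): the base Layer's `__call__`
# is assumed to drive `CallContext` roughly as follows, entering the context
# before running the user-defined `call` and restoring the previous state
# afterwards. Names such as `my_layer` and `inputs` are hypothetical.
#
#   ctx = call_context()
#   with ctx.enter(layer=my_layer, inputs=inputs, build_graph=False,
#                  training=True):
#     assert ctx.in_call and ctx.layer is my_layer
#     outputs = my_layer.call(inputs, training=True)
#   assert not ctx.in_call  # Previous state is restored on exit.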