# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""A `Network` is a way to compose layers: the topological form of a `Model`."""

import collections
import copy
import itertools
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-classes-have-attributes
class Functional(training_lib.Model):
  """A `Functional` model is a `Model` defined as a directed graph of layers.

  Three types of `Model` exist: subclassed `Model`, `Functional` model,
  and `Sequential` (a special case of `Functional`).
  In general, more Keras features are supported with `Functional`
  than with subclassed `Model`s, specifically:

  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`, `model.to_json()`)
  - Whole-model saving (`model.save()`)

  A `Functional` model can be instantiated by passing two arguments to
  `__init__`. The first argument is the `keras.Input` Tensors that represent
  the inputs to the model. The second argument specifies the output
  tensors that represent the outputs of this model. Both arguments can be a
  nested structure of tensors.

  Example:

  ```
  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
  model = keras.Model(inputs, outputs)
  ```

  A `Functional` model constructed using the Functional API can also include
  raw TensorFlow functions, with the exception of functions that create
  Variables or assign ops.

  Example:

  ```
  inputs = keras.Input(shape=(10,))
  x = keras.layers.Dense(1)(inputs)
  outputs = tf.nn.relu(x)
  model = keras.Model(inputs, outputs)
  ```

  Args:
    inputs: List of input tensors (must be created via `tf.keras.Input()`).
    outputs: List of output tensors.
    name: String, optional. Name of the model.
    trainable: Boolean, optional. If the model's variables should be trainable.
  """

  # See tf.Module for the usage of this property.
  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
  # flatten the key since it is trying to convert Trackable/Layer to a string.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_layer_call_argspecs', '_compiled_trainable_state',
       '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'),
      training_lib.Model._TF_MODULE_IGNORED_PROPERTIES
  ))

  @trackable.no_automatic_dependency_tracking
  def __init__(self, inputs, outputs, name=None, trainable=True,
               **kwargs):
    # This is used by the Model class, since we have some logic to swap the
    # class in the __new__ method, which will lead to __init__ being invoked
    # twice. Use skip_init to skip one of the invocations of __init__ to
    # avoid any side effects.
    skip_init = kwargs.pop('skip_init', False)
    if skip_init:
      return
    generic_utils.validate_kwargs(kwargs, {})
    super(Functional, self).__init__(name=name, trainable=trainable)
    self._init_graph_network(inputs, outputs)

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs):
    # This method is needed for Sequential to reinitialize the graph network
    # when a layer is added or removed.
    self._is_graph_network = True

    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_inputs = inputs
    self._nested_outputs = outputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    # Models constructed with a single Tensor or list of Tensors can
    # be called with a dict, where the keys of the dict are the names
    # of the `Input` objects. Extra keys are ignored with a warning.
    if not nest.is_nested(self._nested_inputs):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, (list, tuple)) and
          not any(nest.is_nested(t) for t in self._nested_inputs)):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, dict) and
          not any(nest.is_nested(t) for t in self._nested_inputs.values())):
      self._enable_dict_to_input_mapping = True
    else:
      self._enable_dict_to_input_mapping = False

    if not ops.executing_eagerly_outside_functions():
      if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
        base_layer_utils.create_keras_history(self._nested_outputs)

    self._validate_graph_inputs_and_outputs()

    # A Network does not create weights of its own, thus it is already
    # built.
    self.built = True
    self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs)
    self._compute_output_and_mask_jointly = True
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True
    self._expects_mask_arg = True
    # A graph network does not autocast inputs, as its layers will cast them
    # instead.
    self._autocast = False

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one
    # pass, then cache them here. When any of these outputs is queried later,
    # we retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so it has only one node
      # and one output tensor.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._self_tracked_trackables = layers
    self._layer_call_argspecs = {}
    for layer in self._self_tracked_trackables:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    # Build self.input_names and self.output_names.
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may
        # not have a shape attribute that's meaningful (sparse, for instance,
        # has a tensor that's non-constant and needs to be fed). This means
        # that input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()
    self._set_save_spec(self._nested_inputs)
    tf_utils.assert_no_legacy_layers(self.layers)

  @property
  def input(self):
    """Retrieves the input tensor(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer.

    Returns:
      Input tensor or list of input tensors.

    Raises:
      RuntimeError: If called in Eager mode.
      AttributeError: If no inbound nodes are found.
    """
    return self._nested_inputs

  @property
  def input_shape(self):
    """Retrieves the input shape(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer, or if all inputs
    have the same shape.

    Returns:
      Input shape, as an integer shape tuple
      (or list of shape tuples, one tuple per input tensor).

    Raises:
      AttributeError: if the layer has no defined input_shape.
      RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.input)

  @property
  def input_spec(self):
    if hasattr(self, '_manual_input_spec'):
      return self._manual_input_spec
    if (isinstance(self._nested_inputs, (dict, list, tuple)) and
        len(self._nested_inputs) != len(self.inputs)):
      # Case where we have a nested structure.
      # In such a case we can't safely run any checks.
      return None
    if isinstance(self._nested_inputs, dict):
      # Case where `_nested_inputs` is a plain dict of Inputs.
      names = sorted(self._nested_inputs.keys())
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(self._nested_inputs[name]),
          allow_last_axis_squeeze=True, name=name) for name in names]
    else:
      # Single input, or list / tuple of inputs.
      # The data may be passed as a dict keyed by input name.
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
          name=x._keras_history.layer.name) for x in self.inputs]

  @input_spec.setter
  def input_spec(self, value):
    self._manual_input_spec = value

  @property
  def output(self):
    """Retrieves the output tensor(s) of a layer.

    Only applicable if the layer has exactly one output,
    i.e. if it is connected to one incoming layer.

    Returns:
      Output tensor or list of output tensors.

    Raises:
      AttributeError: if the layer is connected to more than one incoming
        layer.
      RuntimeError: if called in Eager mode.
    """
    return self._nested_outputs

  @property
  def output_shape(self):
    """Retrieves the output shape(s) of a layer.

    Only applicable if the layer has one output,
    or if all outputs have the same shape.

    Returns:
      Output shape, as an integer shape tuple
      (or list of shape tuples, one tuple per output tensor).

    Raises:
      AttributeError: if the layer has no defined output shape.
      RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.output)

  def _set_output_names(self):
    """Assigns unique names to the Network's outputs.

    Output layers with multiple output tensors would otherwise lead to
    duplicate names in self.output_names.
    """
    uniquified = []
    output_names = set()
    prefix_count = {}
    for layer in self._output_layers:
      proposal = layer.name
      while proposal in output_names:
        existing_count = prefix_count.get(layer.name, 1)
        proposal = '{}_{}'.format(layer.name, existing_count)
        prefix_count[layer.name] = existing_count + 1
      output_names.add(proposal)
      uniquified.append(proposal)
    self.output_names = uniquified

  @property
  def _layer_checkpoint_dependencies(self):
    """Dictionary of layer dependencies to be included in the checkpoint."""
    weight_layer_index = 0

    dependencies = collections.OrderedDict()
    for layer_index, layer in enumerate(self.layers):
      try:
        if layer.weights:
          # Keep a separate index for layers which have weights. This allows
          # users to insert Layers without weights anywhere in the network
          # without breaking checkpoints.
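          # For example (hypothetical layers): inserting an `Activation`
          # between two `Dense` layers leaves the Dense layers' keys
          # ('layer_with_weights-0', 'layer_with_weights-1') unchanged.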
          dependencies['layer_with_weights-%d' % weight_layer_index] = layer
          weight_layer_index += 1
      except ValueError:
        # The layer might have weights, but may not be built yet. We just
        # treat it as a layer without weights.
        pass

      # Even if it doesn't have weights, we should still track everything in
      # case it has/will have Trackable dependencies.
      dependencies['layer-%d' % layer_index] = layer
    return dependencies

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    dependencies = self._layer_checkpoint_dependencies
    dependencies.update(
        super(Functional, self)._trackable_children(save_type, **kwargs))
    return dependencies

  def _lookup_dependency(self, name):
    layer_dependencies = self._layer_checkpoint_dependencies
    if name in layer_dependencies:
      return layer_dependencies[name]
    return super(Functional, self)._lookup_dependency(name)

  def _handle_deferred_layer_dependencies(self, layers):
    """Handles layer checkpoint dependencies that are added after init."""
    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
    for layer in layers:
      if layer in layer_to_name:
        self._handle_deferred_dependencies(name=layer_to_name[layer],
                                           trackable=layer)

  @property
  def _should_compute_mask(self):
    return True

  def compute_mask(self, inputs, mask):
    # TODO(omalleyt): b/123540974 This function is not really safe to call
    # by itself because it will duplicate any updates and losses in graph
    # mode by `call`ing the Layers again.
    output_tensors = self._run_internal_graph(inputs, mask=mask)
    return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
                              output_tensors)

  @doc_controls.do_not_doc_inheritable
  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (i.e. builds a new computational graph from the provided inputs).

    Args:
      inputs: A tensor or list of tensors.
      training: Boolean or boolean scalar tensor, indicating whether to run
        the `Network` in training mode or inference mode.
      mask: A mask or list of masks. A mask can be
        either a tensor or None (no mask).

    Returns:
      A tensor if there is a single output, or
      a list of tensors if there is more than one output.
    """
    return self._run_internal_graph(
        inputs, training=training, mask=mask)

  def compute_output_shape(self, input_shape):
    # Convert any shapes in tuple format to TensorShapes.
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    # Use a tuple of TensorShapes as the cache key, since a tuple is hashable.
    try:
      cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
      if cache_key in self._output_shape_cache:
        # Cache hit. Return shapes as TensorShapes.
        return self._output_shape_cache[cache_key]
    except ValueError:
      # In case there is an unknown TensorShape (e.g. for a sparse tensor
      # input), we skip the caching since the shape is unknown.
      pass

    layers_to_output_shapes = {}
    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
      # It's an input layer: `compute_output_shape` is identity,
      # and there is only one node and one tensor.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          layer = node.layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Get the input shapes for the first argument of the node.
          layer_input_shapes = []
          layer_inputs = node.call_args[0]
          for layer_input in nest.flatten(layer_inputs):
            kh = layer_input._keras_history
            input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
                                                          kh.tensor_index)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(layer_inputs,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

    # Read final output shapes from layers_to_output_shapes.
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes

  def _init_set_name(self, name, zero_based=True):
    if not name:
      cls_name = self.__class__.__name__
      if self.__class__ == Functional:
        # Hide the `Functional` class name from the user, since it is not a
        # publicly visible class. Use "Model" instead.
        cls_name = 'Model'
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(cls_name),
          zero_based=zero_based)
    else:
      self._name = name

  def _run_internal_graph(self, inputs, training=None, mask=None):
    """Computes output tensors for new inputs.

    # Note:
      - Can be run on non-Keras tensors.

    Args:
      inputs: Tensor or nested structure of Tensors.
      training: Boolean learning phase.
      mask: (Optional) Tensor or nested structure of Tensors.

    Returns:
      output_tensors
    """
    inputs = self._flatten_to_reference_inputs(inputs)
    if mask is None:
      masks = [None] * len(inputs)
    else:
      masks = self._flatten_to_reference_inputs(mask)
    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
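    # Keyed by `str(id(reference_tensor))`; each value is a list holding one
    # copy of the computed tensor per expected usage, so that entries can be
    # popped as they are consumed (see `_compute_tensor_usage_count`).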
    tensor_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, inputs):
      y = self._conform_to_reference_input(y, ref_input=x)
      x_id = str(id(x))
      tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    for depth in depth_keys:
      nodes = nodes_by_depth[depth]
      for node in nodes:
        if node.is_input:
          continue  # Input tensors already exist.

        if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
          continue  # Node is not computable, try skipping.

        args, kwargs = node.map_arguments(tensor_dict)
        outputs = node.layer(*args, **kwargs)

        # Update tensor_dict.
        for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
          tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    output_tensors = []
    for x in self.outputs:
      x_id = str(id(x))
      assert x_id in tensor_dict, 'Could not compute output ' + str(x)
      output_tensors.append(tensor_dict[x_id].pop())

    return nest.pack_sequence_as(self._nested_outputs, output_tensors)

  def _flatten_to_reference_inputs(self, tensors):
    """Maps `tensors` to their respective `keras.Input`."""
    if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
      ref_inputs = self._nested_inputs
      if not nest.is_nested(ref_inputs):
        ref_inputs = [self._nested_inputs]
      if isinstance(ref_inputs, dict):
        # In the case that the graph is constructed with dict input tensors,
        # we will use the original dict keys to match the keys in the input
        # data. Note that model.inputs uses nest.flatten to process the
        # input tensors, which means the dict input tensors are ordered by
        # their keys.
        ref_input_names = sorted(ref_inputs.keys())
      else:
        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]

      # Raise a warning if the input data contains more keys than the model's
      # input tensors.
      if len(tensors) > len(ref_input_names):
        warnings.warn(
            'Input dict contained keys {} which did not match any model input. '
            'They will be ignored by the model.'.format(
                [n for n in tensors.keys() if n not in ref_input_names])
        )

      try:
        # Flatten in the order `Input`s were passed during Model construction.
        return [tensors[n] for n in ref_input_names]
      except KeyError:
        # TODO(b/151582614)
        return nest.flatten(tensors)

    # Otherwise both self.inputs and tensors will already be in same order.
    return nest.flatten(tensors)

  def _conform_to_reference_input(self, tensor, ref_input):
    """Set shape and dtype based on `keras.Input`s."""
    if isinstance(tensor, ops.Tensor):
      # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use
      # the shape specified by the `keras.Input`.
      t_shape = tensor.shape
      t_rank = t_shape.rank
      ref_shape = ref_input.shape
      ref_rank = ref_shape.rank
      keras_history = getattr(tensor, '_keras_history', None)
      if t_rank is not None and ref_rank is not None:
        # Should squeeze last dimension.
        # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
        if (t_rank == ref_rank + 1 and t_shape[-1] == 1):
          tensor = array_ops.squeeze_v2(tensor, axis=-1)
        # Should expand last dimension.
        # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
        elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
        if keras_history is not None:  # Restore keras history.
          tensor._keras_history = keras_history

      # Add shape hints to Tensors that may have None shape dims but have
      # shapes defined by the `keras.Input` (not applicable in eager mode).
      if not context.executing_eagerly():
        try:
          tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              'called on an input with incompatible shape {}.'.format(
                  ref_input.shape, ref_input, tensor.shape))

      # Dtype casting.
      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
    elif tf_utils.is_extension_type(tensor):
      # Dtype casting (if the extension type has a non-variant dtype and
      # supports being cast).
      ref_input_dtype = getattr(ref_input, 'dtype', None)
      if ref_input_dtype is not None and ref_input_dtype != dtypes.variant:
        tensor = math_ops.cast(tensor, dtype=ref_input_dtype)

    return tensor

  def get_config(self):
    return copy.deepcopy(get_network_config(self))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Args:
      config: Model config dictionary.
      custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

    Returns:
      A model instance.

    Raises:
      ValueError: In case of improperly formatted config dict.
    """
    with generic_utils.SharedObjectLoadingScope():
      input_tensors, output_tensors, created_layers = reconstruct_from_config(
          config, custom_objects)
      model = cls(inputs=input_tensors, outputs=output_tensors,
                  name=config.get('name'))
      connect_ancillary_layers(model, created_layers)
      return model

  def _validate_graph_inputs_and_outputs(self):
    """Validates the inputs and outputs of a Graph Network."""
    # Check for redundancy in inputs.
    if len({id(i) for i in self.inputs}) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.keras.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # pylint: disable=protected-access
      layer = x._keras_history.layer
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and not layer._inbound_nodes[0].is_input):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' model inputs must come from '
                        '`tf.keras.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor, '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))

    # Check compatibility of batch sizes of Input Layers.
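    # For instance (hypothetical inputs), `Input(shape=(1,), batch_size=32)`
    # and `Input(shape=(1,), batch_size=64)` cannot coexist in one model;
    # inputs with an unspecified (None) batch size are ignored by the check.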
    input_batch_sizes = [
        training_utils.get_static_batch_size(x._keras_history.layer)
        for x in self.inputs
    ]
    consistent_batch_size = None
    for batch_size in input_batch_sizes:
      if batch_size is not None:
        if (consistent_batch_size is not None and
            batch_size != consistent_batch_size):
          raise ValueError('The specified batch sizes of the Input Layers'
                           ' are incompatible. Found batch sizes: {}'.format(
                               input_batch_sizes))
        consistent_batch_size = batch_size

    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors of a ' + cls_name + ' model must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))

  def _insert_layers(self, layers, relevant_nodes=None):
    """Inserts Layers into the Network after Network creation.

    This is only valid for Keras Graph Networks. Layers added via this
    function will be included in the `call` computation and `get_config` of
    this Network. They will not be added to the Network's outputs.

    Args:
      layers: Arbitrary nested structure of Layers. Layers must be reachable
        from one or more of the `keras.Input` Tensors that correspond to this
        Network's inputs.
      relevant_nodes: Nodes from the Layers that should be considered part of
        this Network. If `None`, all Nodes will be considered part of this
        Network.

    Raises:
      ValueError: If the layers depend on `Input`s not found in this Model.
    """
    layers = nest.flatten(layers)
    tf_utils.assert_no_legacy_layers(layers)
    node_to_depth = {}
    for depth, nodes in self._nodes_by_depth.items():
      node_to_depth.update({node: depth for node in nodes})
    # The nodes of these Layers that are relevant to this Network. If not
    # provided, assume all Nodes are relevant.
    if not relevant_nodes:
      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))

    def _get_min_depth(node):
      """Gets the minimum depth at which node can be computed."""
      min_depth = 0
      for layer, node_id, _, _ in node.iterate_inbound():
        inbound_node = layer._inbound_nodes[node_id]
        if inbound_node in node_to_depth:
          min_depth = min(min_depth, node_to_depth[inbound_node])
        elif inbound_node not in network_nodes:
          continue
        else:
          # Previous relevant nodes haven't been processed yet.
          return None
      # New node is one shallower than its shallowest input.
      return min_depth - 1

    # Insert nodes into `_nodes_by_depth` and other node attrs.
    unprocessed_nodes = copy.copy(relevant_nodes)
    i = 0
    while unprocessed_nodes:
      i += 1
      # Do a sanity check: this limit can be hit if `Input`s from outside
      # this Model are being relied on.
      if i > 10000:
        raise ValueError('Layers could not be added due to missing '
                         'dependencies.')

      node = unprocessed_nodes.pop(0)
      depth = _get_min_depth(node)
      if depth is None:  # Defer until inbound nodes are processed.
        unprocessed_nodes.append(node)
        continue
      node_key = _make_node_key(node.layer.name,
                                node.layer._inbound_nodes.index(node))
      if node_key not in self._network_nodes:
        node_to_depth[node] = depth
        self._network_nodes.add(node_key)
        self._nodes_by_depth[depth].append(node)

    # Insert layers and update other layer attrs.
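    # `layer_set` gives O(1) membership checks so that layers already tracked
    # by this network are not appended twice.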
    layer_set = set(self._self_tracked_trackables)
    deferred_layers = []
    for layer in layers:
      if layer not in layer_set:
        self._self_tracked_trackables.append(layer)
        deferred_layers.append(layer)
        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(
            layer.call)
        layer_set.add(layer)
    self._handle_deferred_layer_dependencies(deferred_layers)

    self._compute_tensor_usage_count()

  def _compute_tensor_usage_count(self):
    """Computes the number of usages of all the output tensors of layers.

    The computed tensor usage count is saved as `self._tensor_usage_count`.
    This is later used for saving memory in eager computation by releasing
    no-longer-needed tensors as early as possible.
    """
    tensor_usage_count = collections.Counter()
    available_tensors = set(str(id(tensor)) for tensor in self.inputs)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      for node in self._nodes_by_depth[depth]:
        input_tensors = {
            str(id(tensor)) for tensor in nest.flatten(node.keras_inputs)
        }
        if input_tensors.issubset(available_tensors):
          for tensor in nest.flatten(node.keras_inputs):
            tensor_usage_count[str(id(tensor))] += 1

          for output_tensor in nest.flatten(node.outputs):
            available_tensors.add(str(id(output_tensor)))

    for tensor in self.outputs:
      tensor_usage_count[str(id(tensor))] += 1

    self._tensor_usage_count = tensor_usage_count

  def _assert_weights_created(self):
    # Override the implementation in Model.
    # The Functional model should always have its weights created already.
    return

  def _graph_network_add_loss(self, symbolic_loss):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
    # Losses must be keyed on inputs no matter what in order to be supported in
    # DistributionStrategy.
    add_loss_layer = base_layer.AddLoss(
        unconditional=False, dtype=symbolic_loss.dtype)
    add_loss_layer(symbolic_loss)
    new_nodes.extend(add_loss_layer.inbound_nodes)
    new_layers.append(add_loss_layer)
    self._insert_layers(new_layers, new_nodes)

  def _graph_network_add_metric(self, value, aggregation, name):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
    add_metric_layer = base_layer.AddMetric(
        aggregation, name, dtype=value.dtype)
    add_metric_layer(value)
    new_nodes.extend(add_metric_layer.inbound_nodes)
    new_layers.append(add_metric_layer)
    self._insert_layers(new_layers, new_nodes)

  @property
  def _trackable_saved_model_saver(self):
    return network_serialization.NetworkSavedModelSaver(self)

  def _get_save_spec(self, dynamic_batch=True):
    if getattr(self, '_has_explicit_input_shape', True):
      # Functional models and Sequential models that have an explicit input
      # shape should use the batch size set by the input layer.
      dynamic_batch = False
    return super(Functional, self)._get_save_spec(dynamic_batch)


def _make_node_key(layer_name, node_index):
  return layer_name + '_ib-' + str(node_index)


def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gathers its layers and nodes.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: list of Node instances.
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph).
  """
  # "depth" is the number of layers between the output Node and the Node.
  # Nodes are ordered from inputs -> outputs.
  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
  network_nodes = {
      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
      for node in nodes_in_decreasing_depth
  }

  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer.
    previous_depth = layers_depths.get(node.layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all nodes it is connected to + 1.
    for node_dep in node.parent_nodes:
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
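  # Walk nodes from the deepest (closest to the inputs) to the shallowest;
  # a node's outputs become computable only once all of its inputs are,
  # otherwise the graph is disconnected and an error is raised below.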
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.layer
      if layer and not node.is_input:
        for x in nest.flatten(node.keras_inputs):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.outputs):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name uniqueness, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth


def _build_map(outputs):
  """Topologically sorts nodes in order from inputs to outputs.

  It uses a depth-first search to topologically sort nodes that appear in the
  _keras_history connectivity metadata of `outputs`.

  Args:
    outputs: the output tensors whose _keras_history metadata should be
      walked. This may be an arbitrary nested structure.

  Returns:
    A tuple like (ordered_nodes, layer_to_first_traversal_index)
    ordered_nodes: list of nodes appearing in the keras history,
      topologically sorted from original inputs to the `outputs`.
      (If outputs have different sets of ancestors, the inputs to one output
      may appear after a different output.)
    layer_to_first_traversal_index:
      A dict mapping layer to the traversal index in the DFS where it is
      seen. Note: if a layer is shared by several nodes, the dict will only
      store the index corresponding to the *first* time the layer is seen.
  """
  finished_nodes = set()
  nodes_in_progress = set()
  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
  layer_indices = {}  # layer -> in traversal order.
  for output in nest.flatten(outputs):
    _build_map_helper(output, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices)
  return nodes_in_decreasing_depth, layer_indices


def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices):
  """Recursive helper for `_build_map`."""
  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

  # Don't repeat work for shared subgraphs.
  if node in finished_nodes:
    return

  # Prevent cycles.
  if node in nodes_in_progress:
    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
                     '" is part of a cycle.')

  # Store the traversal order for layer sorting.
  if layer not in layer_indices:
    layer_indices[layer] = len(layer_indices)

  # Propagate to all previous tensors connected to this node.
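  # The node is marked in progress before recursing so that the cycle check
  # above can detect it if the recursion reaches it again.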
  nodes_in_progress.add(node)
  if not node.is_input:
    for tensor in node.keras_inputs:
      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                        nodes_in_decreasing_depth, layer_indices)

  finished_nodes.add(node)
  nodes_in_progress.remove(node)
  nodes_in_decreasing_depth.append(node)


def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of (List[Node], List[Layer]).
  """
  if not ops.executing_eagerly_outside_functions():
    base_layer_utils.create_keras_history(outputs)
  # Keep only nodes and layers in the topology between inputs and outputs.
  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers


def _should_skip_first_node(layer):
  """Returns True if the first layer node should not be saved or loaded."""
  # Networks that are constructed with an Input layer/shape start with a
  # pre-existing node linking their input to output. This node is excluded
  # from the network config.
  if layer._self_tracked_trackables:
    return (isinstance(layer, Functional) and
            # Filter out Sequential models without an input shape.
            isinstance(layer._self_tracked_trackables[0],
                       input_layer_module.InputLayer))
  else:
    return isinstance(layer, Functional)


def connect_ancillary_layers(model, created_layers):
  """Adds layers that are not connected to the outputs to the model."""
  # Layers not connected to outputs, such as those added in `add_loss`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  if ancillary_layers:
    relevant_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if _should_skip_first_node(layer) else layer.inbound_nodes
        for layer in created_layers.values()
    ])
    model._insert_layers(ancillary_layers, relevant_nodes)
  return model


def reconstruct_from_config(config, custom_objects=None, created_layers=None):
  """Reconstructs graph from config object.

  Args:
    config: Dictionary returned from Network.get_config().
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.
    created_layers: Optional dictionary mapping names to Layer objects. Any
      layer not in this dictionary will be created and added to the dict.
      This function will add new nodes to all layers (excluding InputLayers),
      instead of re-using pre-existing nodes in the layers.

  Returns:
    Tuple of (input tensors, output tensors, dictionary of created layers)
  """
  # Layer instances created during the graph reconstruction process.
  created_layers = created_layers or collections.OrderedDict()

  # Maps input data (tuple of inbound layer name, node index) from the config
  # to node indices in the newly generated model. The node indices may be
  # different if the layers have already been called previously.
  node_index_map = {}
  node_count_by_layer = {}

  # Dictionary mapping layer instances to
  # node data that specifies a layer call.
  # It acts as a queue that maintains any unprocessed
  # layer call until it becomes possible to process it
  # (i.e. until the input tensors to the call all exist).
  unprocessed_nodes = {}

  def add_unprocessed_node(layer, node_data):
    if layer not in unprocessed_nodes:
      unprocessed_nodes[layer] = [node_data]
    else:
      unprocessed_nodes[layer].append(node_data)

  def get_node_index(layer, config_node_index):
    """Returns node index in layer (might differ from config_node_index)."""
    if isinstance(layer, input_layer_module.InputLayer):
      return 0
    return node_index_map.get((layer.name, config_node_index), None)

  def _deserialize_keras_tensors(kwargs, layer_map):
    """Deserializes Keras Tensors passed to `call`."""

    def _deserialize_keras_tensor(t):
      """Deserializes a single Keras Tensor passed to `call`."""
      if isinstance(t, tf_utils.ListWrapper):
        t = t.as_list()
        layer_name = t[0]
        node_index = t[1]
        tensor_index = t[2]

        layer = layer_map[layer_name]
        new_node_index = get_node_index(layer, node_index)
        if new_node_index is None:
          # The inbound node may not have been processed yet (this can
          # happen, e.g., if it depends on a different set of inputs than
          # those that have been processed already). Raise an IndexError so
          # that the current node puts itself back on the unprocessed queue.
          # Caution: this may lead to infinite loops for malformed network
          # configurations (or when there is a bug in the network config
          # loading code).
          raise IndexError
        node = layer._inbound_nodes[new_node_index]
        return nest.flatten(node.outputs)[tensor_index]
      return t

    kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
    return nest.map_structure(_deserialize_keras_tensor, kwargs)

  def process_node(layer, node_data):
    """Deserializes a node.

    Args:
      layer: layer instance.
      node_data: Nested structure of `ListWrapper`.

    Raises:
      ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        try:
          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
        except IndexError:
          # Happens if keras tensors in kwargs are still unprocessed.
          add_unprocessed_node(layer, node_data)
          return
      else:
        raise ValueError('Improperly formatted model config.')

      if inbound_layer_name != node_module._CONSTANT_VALUE:
        inbound_layer = created_layers[inbound_layer_name]
        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

        if inbound_node_index is None:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
      else:
        # We received a constant with no Keras history attached.
        input_tensors.append(inbound_tensor_index)
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
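    # Calling the layer assigns fresh node indices, which may differ from the
    # indices in the config; `node_index_map` (updated below) records the
    # correspondence.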
    if input_tensors is not None:
      if not layer._preserve_input_structure_in_config:
        input_tensors = (
            base_layer_utils.unnest_if_single_tensor(input_tensors))
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1

  def process_layer(layer_data):
    """Deserializes a layer, then calls it on appropriate inputs.

    Args:
      layer_data: layer config dict.

    Raises:
      ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in case of layers shared at different topological depths
      # (e.g. a model such as A(B(A(B(x))))).
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed.
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
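  # For example, for a model such as A(B(A(B(x)))), the second node of A can
  # only be processed once the second node of B exists, so it is re-enqueued
  # in the meantime.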
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers


def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers.

  Returns:
    Config dictionary.
  """
  serialize_layer_fn = (
      serialize_layer_fn or generic_utils.serialize_keras_object)
  config = {
      'name': network.name,
  }
  node_conversion_map = {}
  for layer in network.layers:
    kept_nodes = 1 if _should_skip_first_node(layer) else 0
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        node_conversion_map[node_key] = kept_nodes
        kept_nodes += 1
  layer_configs = []

  with generic_utils.SharedObjectSavingScope():
    for layer in network.layers:  # From the earliest layers on.
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in network._network_nodes and not node.is_input:
          # The node is relevant to the model:
          # add to filtered_inbound_nodes.
          node_data = node.serialize(_make_node_key, node_conversion_map)
          filtered_inbound_nodes.append(node_data)

      layer_config = serialize_layer_fn(layer)
      layer_config['name'] = layer.name
      layer_config['inbound_nodes'] = filtered_inbound_nodes
      layer_configs.append(layer_config)
    config['layers'] = layer_configs

  # Gather info about inputs and outputs.
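  # Each model input/output is recorded as [layer_name, new_node_index,
  # tensor_index], with node indices remapped through `node_conversion_map`
  # so that skipped (input) nodes do not shift them.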
  model_inputs = []
  for i in range(len(network._input_layers)):
    layer, node_index, tensor_index = network._input_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_inputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
  # Preserve external Keras compat for Models with a single input.
  if not nest.is_nested(model_inputs):
    model_inputs = [model_inputs]
  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
  config['input_layers'] = model_inputs

  model_outputs = []
  for i in range(len(network._output_layers)):
    layer, node_index, tensor_index = network._output_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_outputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
  # Preserve external Keras compat for Models with a single output.
  if not nest.is_nested(model_outputs):
    model_outputs = [model_outputs]
  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
  config['output_layers'] = model_outputs
  return config


def shape_with_no_batch_size(x):
  if x.shape.rank is None:
    return None
  shape = x.shape.as_list()
  if shape:
    shape[0] = None
  return shape


class ModuleWrapper(base_layer.Layer):
  """Wrapper for `tf.Module`s to support the Functional and Sequential APIs."""

  def __init__(self, module, method_name=None, **kwargs):
    """Initializes the wrapper Layer for this module.

    Args:
      module: The `tf.Module` instance to be wrapped.
      method_name: (Optional) str. The name of the method to use as the
        forward pass of the module. If not set, defaults to '__call__' if
        defined, or 'call'.
      **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.

    Raises:
      ValueError: If `method_name` is not defined on `module`.
    """
    super(ModuleWrapper, self).__init__(**kwargs)
    if method_name is None:
      if hasattr(module, '__call__'):
        method_name = '__call__'
      elif hasattr(module, 'call'):
        method_name = 'call'
    if method_name is None or not hasattr(module, method_name):
      raise ValueError('{} is not defined on object {}'.format(
          method_name, module))

    self._module = module
    self._method_name = method_name

    # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
    method = getattr(module, method_name)
    method_arg_spec = tf_inspect.getfullargspec(method)
    self._expects_training_arg = ('training' in method_arg_spec.args or
                                  method_arg_spec.varkw is not None)
    self._expects_mask_arg = ('mask' in method_arg_spec.args or
                              method_arg_spec.varkw is not None)

  def call(self, *args, **kwargs):
    if 'training' in kwargs and not self._expects_training_arg:
      kwargs.pop('training')
    if 'mask' in kwargs and not self._expects_mask_arg:
      kwargs.pop('mask')
    return getattr(self._module, self._method_name)(*args, **kwargs)
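

# A minimal usage sketch (illustrative only; not part of the library).
# `Functional` is normally instantiated through the public `keras.Model` API,
# and `ModuleWrapper` adapts a plain `tf.Module` for the Functional/Sequential
# APIs. `ReluModule` below is a hypothetical module, not a Keras symbol.
#
#   import tensorflow as tf
#   from tensorflow import keras
#
#   class ReluModule(tf.Module):
#
#     def __call__(self, x):
#       return tf.nn.relu(x)
#
#   inputs = keras.Input(shape=(4,))
#   x = keras.layers.Dense(4)(inputs)
#   outputs = ModuleWrapper(ReluModule())(x)
#   model = keras.Model(inputs, outputs)  # builds a `Functional` model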