# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
16"""A `Network` is way to compose layers: the topological form of a `Model`."""

import collections
import copy
import itertools
import warnings

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import base as trackable
from tensorflow.python.util import nest
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-classes-have-attributes
class Functional(training_lib.Model):
  """A `Functional` model is a `Model` defined as a directed graph of layers.

  Three types of `Model` exist: subclassed `Model`, `Functional` model,
  and `Sequential` (a special case of `Functional`).
  In general, more Keras features are supported with `Functional`
  than with subclassed `Model`s, specifically:

  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`, `model.to_json()`)
  - Whole-model saving (`model.save()`)

  A `Functional` model can be instantiated by passing two arguments to
  `__init__`. The first argument is the `keras.Input` Tensors that represent
  the inputs to the model. The second argument specifies the output
  tensors that represent the outputs of this model. Both arguments can be a
  nested structure of tensors.

  Example:

  ```
  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
  model = keras.Model(inputs, outputs)
  ```

  A `Functional` model constructed using the Functional API can also include raw
  TensorFlow functions, with the exception of functions that create Variables
  or assign ops.

  Example:

  ```
  inputs = keras.Input(shape=(10,))
  x = keras.layers.Dense(1)(inputs)
  outputs = tf.nn.relu(x)
  model = keras.Model(inputs, outputs)
  ```

  Args:
    inputs: List of input tensors (must be created via `tf.keras.Input()`).
    outputs: List of output tensors.
    name: String, optional. Name of the model.
    trainable: Boolean, optional. Whether the model's variables should be
      trainable.
  """

  # See tf.Module for the usage of this property.
  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
  # flatten the key since it is trying to convert Trackable/Layer to a string.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_layer_call_argspecs', '_compiled_trainable_state',
       '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'),
      training_lib.Model._TF_MODULE_IGNORED_PROPERTIES
  ))

  @trackable.no_automatic_dependency_tracking
  def __init__(self, inputs, outputs, name=None, trainable=True,
               **kwargs):
    # This is used by the Model class, since we have some logic to swap the
    # class in the __new__ method, which will lead to __init__ being invoked
    # twice. Use skip_init to skip one of the invocations of __init__ and
    # avoid any side effects.
    skip_init = kwargs.pop('skip_init', False)
    if skip_init:
      return
    generic_utils.validate_kwargs(kwargs, {})
    super(Functional, self).__init__(name=name, trainable=trainable)
    self._init_graph_network(inputs, outputs)

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs):
    # This method is needed for Sequential to reinitialize the graph network
    # when a layer is added or removed.
    self._is_graph_network = True

    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_inputs = inputs
    self._nested_outputs = outputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    # Models constructed with a single Tensor or list of Tensors can
    # be called with a dict, where the keys of the dict are the names
    # of the `Input` objects. Extra keys are ignored with a warning.
    if not nest.is_nested(self._nested_inputs):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, (list, tuple)) and
          not any(nest.is_nested(t) for t in self._nested_inputs)):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, dict) and
          not any(nest.is_nested(t) for t in self._nested_inputs.values())):
      self._enable_dict_to_input_mapping = True
    else:
      self._enable_dict_to_input_mapping = False
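    # Example of the mapping this enables (a sketch; assumes a model whose
    # `Input`s are named 'x1' and 'x2'):
    #
    #   model({'x1': x1_data, 'x2': x2_data})  # keys matched to Input names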

    if not ops.executing_eagerly_outside_functions():
      if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
        base_layer_utils.create_keras_history(self._nested_outputs)

    self._validate_graph_inputs_and_outputs()

    # A Network does not create weights of its own, thus it is already
    # built.
    self.built = True
    self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs)
    self._compute_output_and_mask_jointly = True
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True
    self._expects_mask_arg = True
    # A graph network does not autocast inputs, as its layers will cast them
    # instead.
    self._autocast = False

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one
    # pass, then cache them here. When any of these outputs is queried later,
    # we retrieve it from the cache instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._self_tracked_trackables = layers
    self._layer_call_argspecs = {}
    for layer in self._self_tracked_trackables:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    # Build self.input_names and self.output_names.
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may not
        # have a shape attribute that's meaningful (sparse, for instance, has
        # a tensor that's non-constant and needs to be fed). This means that
        # input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()
    self._set_save_spec(self._nested_inputs)
    tf_utils.assert_no_legacy_layers(self.layers)

  @property
  def input(self):
    """Retrieves the input tensor(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer.

    Returns:
        Input tensor or list of input tensors.

    Raises:
      RuntimeError: If called in Eager mode.
      AttributeError: If no inbound nodes are found.
    """
    return self._nested_inputs

  @property
  def input_shape(self):
    """Retrieves the input shape(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer, or if all inputs
    have the same shape.

    Returns:
        Input shape, as an integer shape tuple
        (or list of shape tuples, one tuple per input tensor).

    Raises:
        AttributeError: if the layer has no defined input_shape.
        RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.input)

  @property
  def input_spec(self):
    if hasattr(self, '_manual_input_spec'):
      return self._manual_input_spec
    if (isinstance(self._nested_inputs, (dict, list, tuple)) and
        len(self._nested_inputs) != len(self.inputs)):
      # Case where we have a nested structure.
      # In such a case we can't safely run any checks.
      return None
    if isinstance(self._nested_inputs, dict):
      # Case where `_nested_inputs` is a plain dict of Inputs.
      names = sorted(self._nested_inputs.keys())
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(self._nested_inputs[name]),
          allow_last_axis_squeeze=True, name=name) for name in names]
    else:
      # Single input, or list / tuple of inputs.
      # The data may be passed as a dict keyed by input name.
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
          name=x._keras_history.layer.name) for x in self.inputs]

  @input_spec.setter
  def input_spec(self, value):
    self._manual_input_spec = value

  @property
  def output(self):
    """Retrieves the output tensor(s) of a layer.

    Only applicable if the layer has exactly one output,
    i.e. if it is connected to one incoming layer.

    Returns:
      Output tensor or list of output tensors.

    Raises:
      AttributeError: if the layer is connected to more than one incoming
        layer.
      RuntimeError: if called in Eager mode.
    """
    return self._nested_outputs

  @property
  def output_shape(self):
    """Retrieves the output shape(s) of a layer.

    Only applicable if the layer has one output,
    or if all outputs have the same shape.

    Returns:
        Output shape, as an integer shape tuple
        (or list of shape tuples, one tuple per output tensor).

    Raises:
        AttributeError: if the layer has no defined output shape.
        RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.output)

  def _set_output_names(self):
    """Assigns unique names to the Network's outputs.

    Output layers with multiple output tensors would otherwise lead to
    duplicate names in self.output_names.
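
    For example (a sketch), a single output layer named 'dense' contributing
    two output tensors yields `output_names == ['dense', 'dense_1']`.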
327    """
328    uniquified = []
329    output_names = set()
330    prefix_count = {}
331    for layer in self._output_layers:
332      proposal = layer.name
333      while proposal in output_names:
334        existing_count = prefix_count.get(layer.name, 1)
335        proposal = '{}_{}'.format(layer.name, existing_count)
336        prefix_count[layer.name] = existing_count + 1
337      output_names.add(proposal)
338      uniquified.append(proposal)
339    self.output_names = uniquified
340
  @property
  def _layer_checkpoint_dependencies(self):
    """Dictionary of layer dependencies to be included in the checkpoint."""
    weight_layer_index = 0

    dependencies = collections.OrderedDict()
    for layer_index, layer in enumerate(self.layers):
      try:
        if layer.weights:
          # Keep a separate index for layers which have weights. This allows
          # users to insert Layers without weights anywhere in the network
          # without breaking checkpoints.
          dependencies['layer_with_weights-%d' % weight_layer_index] = layer
          weight_layer_index += 1
      except ValueError:
        # The layer might have weights, but may not be built yet. We just
        # treat it as a layer without weights.
        pass

      # Even if it doesn't have weights, we should still track everything in
      # case it has/will have Trackable dependencies.
      dependencies['layer-%d' % layer_index] = layer
    return dependencies

  def _trackable_children(self,
                          save_type=trackable.SaveType.CHECKPOINT,
                          **kwargs):
    dependencies = self._layer_checkpoint_dependencies
    dependencies.update(
        super(Functional, self)._trackable_children(save_type, **kwargs))
    return dependencies

  def _lookup_dependency(self, name):
    layer_dependencies = self._layer_checkpoint_dependencies
    if name in layer_dependencies:
      return layer_dependencies[name]
    return super(Functional, self)._lookup_dependency(name)

  def _handle_deferred_layer_dependencies(self, layers):
    """Handles layer checkpoint dependencies that are added after init."""
    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
    for layer in layers:
      if layer in layer_to_name:
        self._handle_deferred_dependencies(name=layer_to_name[layer],
                                           trackable=layer)

  @property
  def _should_compute_mask(self):
    return True

  def compute_mask(self, inputs, mask):
    # TODO(omalleyt): b/123540974 This function is not really safe to call
    # by itself because it will duplicate any updates and losses in graph
    # mode by `call`ing the Layers again.
    output_tensors = self._run_internal_graph(inputs, mask=mask)
    return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
                              output_tensors)

  @doc_controls.do_not_doc_inheritable
  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (i.e. it builds a new computational graph from the provided inputs).

    Args:
        inputs: A tensor or list of tensors.
        training: Boolean or boolean scalar tensor, indicating whether to run
          the `Network` in training mode or inference mode.
        mask: A mask or list of masks. A mask can be
            either a tensor or None (no mask).

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there is more than one output.
    """
    return self._run_internal_graph(
        inputs, training=training, mask=mask)

  def compute_output_shape(self, input_shape):
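    # Usage sketch (hypothetical shapes; assumes a model mapping a 10-d input
    # to a 1-d output):
    #
    #   model.compute_output_shape((None, 10))  # -> TensorShape([None, 1])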
    # Convert any shapes in tuple format to TensorShapes.
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    # Use a tuple of TensorShapes as the cache key, since a tuple is hashable.
    try:
      cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
      if cache_key in self._output_shape_cache:
        # Cache hit. Return shapes as TensorShapes.
        return self._output_shape_cache[cache_key]
    except ValueError:
      # In case there are unknown TensorShapes (e.g. for sparse tensor
      # inputs), skip caching since the shape cannot be used as a key.
      pass

    layers_to_output_shapes = {}
    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
      # It's an input layer: `compute_output_shape` is the identity,
      # and there is only one node and one tensor output.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          layer = node.layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Get the input shapes for the first argument of the node
          layer_input_shapes = []
          layer_inputs = node.call_args[0]
          for layer_input in nest.flatten(layer_inputs):
            kh = layer_input._keras_history
            input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
                                                          kh.tensor_index)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(layer_inputs,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

      # Read final output shapes from layers_to_output_shapes.
      output_shapes = []
      for i in range(len(self._output_layers)):
        layer, node_index, tensor_index = self._output_coordinates[i]
        shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
        output_shapes.append(layers_to_output_shapes[shape_key])
      output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
      # Store in cache.
      self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes

  def _init_set_name(self, name, zero_based=True):
    if not name:
      cls_name = self.__class__.__name__
      if self.__class__ == Functional:
        # Hide the Functional class name from the user, since it's not a
        # publicly visible class. Use "Model" instead.
        cls_name = 'Model'
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(cls_name),
          zero_based=zero_based)
    else:
      self._name = name

  def _run_internal_graph(self, inputs, training=None, mask=None):
    """Computes output tensors for new inputs.

    # Note:
        - Can be run on non-Keras tensors.

    Args:
        inputs: Tensor or nested structure of Tensors.
        training: Boolean learning phase.
        mask: (Optional) Tensor or nested structure of Tensors.

    Returns:
        output_tensors
    """
    inputs = self._flatten_to_reference_inputs(inputs)
    if mask is None:
      masks = [None] * len(inputs)
    else:
      masks = self._flatten_to_reference_inputs(mask)
    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, inputs):
      y = self._conform_to_reference_input(y, ref_input=x)
      x_id = str(id(x))
      tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    for depth in depth_keys:
      nodes = nodes_by_depth[depth]
      for node in nodes:
        if node.is_input:
          continue  # Input tensors already exist.

        if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
          continue  # Node is not computable, try skipping.

        args, kwargs = node.map_arguments(tensor_dict)
        outputs = node.layer(*args, **kwargs)

        # Update tensor_dict.
        for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
          tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    output_tensors = []
    for x in self.outputs:
      x_id = str(id(x))
      assert x_id in tensor_dict, 'Could not compute output ' + str(x)
      output_tensors.append(tensor_dict[x_id].pop())

    return nest.pack_sequence_as(self._nested_outputs, output_tensors)

  def _flatten_to_reference_inputs(self, tensors):
    """Maps `tensors` to their respective `keras.Input`."""
    if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
      ref_inputs = self._nested_inputs
      if not nest.is_nested(ref_inputs):
        ref_inputs = [self._nested_inputs]
      if isinstance(ref_inputs, dict):
        # If the graph was constructed with dict input tensors, use the
        # original dict keys to map to the keys in the input data. Note that
        # model.inputs uses nest.flatten to process the input tensors, which
        # means the dict input tensors are ordered by their keys.
        ref_input_names = sorted(ref_inputs.keys())
      else:
        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]

      # Raise a warning if the input data contains more keys than the model's
      # input tensors.
      if len(tensors) > len(ref_input_names):
        warnings.warn(
            'Input dict contained keys {} which did not match any model input. '
            'They will be ignored by the model.'.format(
                [n for n in tensors.keys() if n not in ref_input_names])
            )

      try:
        # Flatten in the order `Input`s were passed during Model construction.
        return [tensors[n] for n in ref_input_names]
      except KeyError:
        # TODO(b/151582614)
        return nest.flatten(tensors)

    # Otherwise both self.inputs and tensors will already be in the same order.
    return nest.flatten(tensors)

  def _conform_to_reference_input(self, tensor, ref_input):
    """Sets shape and dtype based on `keras.Input`s."""
    if isinstance(tensor, ops.Tensor):
      # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use
      # the shape specified by the `keras.Input`.
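      # E.g. (a sketch) a tensor of shape (32, 1) fed to an Input of shape
      # (None,) is squeezed to (32,); the reverse case is expanded.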
      t_shape = tensor.shape
      t_rank = t_shape.rank
      ref_shape = ref_input.shape
      ref_rank = ref_shape.rank
      keras_history = getattr(tensor, '_keras_history', None)
      if t_rank is not None and ref_rank is not None:
        # Should squeeze last dimension.
        # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
        if (t_rank == ref_rank + 1 and t_shape[-1] == 1):
          tensor = array_ops.squeeze_v2(tensor, axis=-1)
        # Should expand last dimension.
        # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
        elif (t_rank == ref_rank - 1 and ref_shape[-1] == 1):
          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
      if keras_history is not None:  # Restore keras history.
        tensor._keras_history = keras_history

      # Add shape hints to Tensors that may have None shape dims but have
      # shapes defined by the `keras.Input` (not applicable in eager mode).
      if not context.executing_eagerly():
        try:
          tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              'called on an input with incompatible shape {}.'.format(
                  ref_input.shape, ref_input, tensor.shape))

      # Dtype casting.
      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
    elif tf_utils.is_extension_type(tensor):
      # Dtype casting (if the extension type has a non-variant dtype and
      # supports being cast).
      ref_input_dtype = getattr(ref_input, 'dtype', None)
      if ref_input_dtype is not None and ref_input_dtype != dtypes.variant:
        tensor = math_ops.cast(tensor, dtype=ref_input_dtype)

    return tensor

  def get_config(self):
    return copy.deepcopy(get_network_config(self))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Args:
        config: Model config dictionary.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.

    Returns:
        A model instance.

    Raises:
        ValueError: In case of improperly formatted config dict.
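
    Example (a sketch of the config round trip; assumes `model` is an
    existing `Functional` model):

    ```
    config = model.get_config()
    new_model = keras.Model.from_config(config)
    ```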
665    """
666    with generic_utils.SharedObjectLoadingScope():
667      input_tensors, output_tensors, created_layers = reconstruct_from_config(
668          config, custom_objects)
669      model = cls(inputs=input_tensors, outputs=output_tensors,
670                  name=config.get('name'))
671      connect_ancillary_layers(model, created_layers)
672      return model
673
  def _validate_graph_inputs_and_outputs(self):
    """Validates the inputs and outputs of a Graph Network."""
    # Check for redundancy in inputs.
    if len({id(i) for i in self.inputs}) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.keras.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # pylint: disable=protected-access
      layer = x._keras_history.layer
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and not layer._inbound_nodes[0].is_input):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' model inputs must come from '
                        '`tf.keras.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor, '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))

    # Check compatibility of batch sizes of Input Layers.
    input_batch_sizes = [
        training_utils.get_static_batch_size(x._keras_history.layer)
        for x in self.inputs
    ]
    consistent_batch_size = None
    for batch_size in input_batch_sizes:
      if batch_size is not None:
        if (consistent_batch_size is not None and
            batch_size != consistent_batch_size):
          raise ValueError('The specified batch sizes of the Input Layers'
                           ' are incompatible. Found batch sizes: {}'.format(
                               input_batch_sizes))
        consistent_batch_size = batch_size

    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors of a ' + cls_name + ' model must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))

  def _insert_layers(self, layers, relevant_nodes=None):
    """Inserts Layers into the Network after Network creation.

    This is only valid for Keras Graph Networks. Layers added via this function
    will be included in the `call` computation and `get_config` of this Network.
    They will not be added to the Network's outputs.

    Args:
      layers: Arbitrary nested structure of Layers. Layers must be reachable
        from one or more of the `keras.Input` Tensors that correspond to this
        Network's inputs.
      relevant_nodes: Nodes from the Layers that should be considered part of
        this Network. If `None`, all Nodes will be considered part of this
        Network.

    Raises:
      ValueError: If the layers depend on `Input`s not found in this Model.
    """
    layers = nest.flatten(layers)
    tf_utils.assert_no_legacy_layers(layers)
    node_to_depth = {}
    for depth, nodes in self._nodes_by_depth.items():
      node_to_depth.update({node: depth for node in nodes})
    # The nodes of these Layers that are relevant to this Network. If not
    # provided, assume all Nodes are relevant.
    if not relevant_nodes:
      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))

    def _get_min_depth(node):
      """Gets the minimum depth at which node can be computed."""
      min_depth = 0
      for layer, node_id, _, _ in node.iterate_inbound():
        inbound_node = layer._inbound_nodes[node_id]
        if inbound_node in node_to_depth:
          min_depth = min(min_depth, node_to_depth[inbound_node])
        elif inbound_node not in network_nodes:
          continue
        else:
          # Previous relevant nodes haven't been processed yet.
          return None
      # New node is one shallower than its shallowest input.
      return min_depth - 1

    # Insert nodes into `_nodes_by_depth` and other node attrs.
    unprocessed_nodes = copy.copy(relevant_nodes)
    i = 0
    while unprocessed_nodes:
      i += 1
      # Do a sanity check. This can occur if `Input`s from outside this Model
      # are being relied on.
      if i > 10000:
        raise ValueError('Layers could not be added due to missing '
                         'dependencies.')

      node = unprocessed_nodes.pop(0)
      depth = _get_min_depth(node)
      if depth is None:  # Defer until inbound nodes are processed.
        unprocessed_nodes.append(node)
        continue
      node_key = _make_node_key(node.layer.name,
                                node.layer._inbound_nodes.index(node))
      if node_key not in self._network_nodes:
        node_to_depth[node] = depth
        self._network_nodes.add(node_key)
        self._nodes_by_depth[depth].append(node)

    # Insert layers and update other layer attrs.
    layer_set = set(self._self_tracked_trackables)
    deferred_layers = []
    for layer in layers:
      if layer not in layer_set:
        self._self_tracked_trackables.append(layer)
        deferred_layers.append(layer)
        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
        layer_set.add(layer)
    self._handle_deferred_layer_dependencies(deferred_layers)

    self._compute_tensor_usage_count()

  def _compute_tensor_usage_count(self):
    """Computes the number of usages of all the output tensors of layers.

    The computed tensor usage count is saved as `self._tensor_usage_count`.
    This is later used for saving memory in eager computation by releasing
    no-longer-needed tensors as early as possible.
    """
    tensor_usage_count = collections.Counter()
    available_tensors = set(str(id(tensor)) for tensor in self.inputs)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      for node in self._nodes_by_depth[depth]:
        input_tensors = {
            str(id(tensor)) for tensor in nest.flatten(node.keras_inputs)
        }
        if input_tensors.issubset(available_tensors):
          for tensor in nest.flatten(node.keras_inputs):
            tensor_usage_count[str(id(tensor))] += 1

          for output_tensor in nest.flatten(node.outputs):
            available_tensors.add(str(id(output_tensor)))

    for tensor in self.outputs:
      tensor_usage_count[str(id(tensor))] += 1

    self._tensor_usage_count = tensor_usage_count

  def _assert_weights_created(self):
    # Override the implementation in Model.
    # The Functional model should always have its weights created already.
    return

  def _graph_network_add_loss(self, symbolic_loss):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
    # Losses must be keyed on inputs no matter what in order to be supported in
    # DistributionStrategy.
    add_loss_layer = base_layer.AddLoss(
        unconditional=False, dtype=symbolic_loss.dtype)
    add_loss_layer(symbolic_loss)
    new_nodes.extend(add_loss_layer.inbound_nodes)
    new_layers.append(add_loss_layer)
    self._insert_layers(new_layers, new_nodes)

  def _graph_network_add_metric(self, value, aggregation, name):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
    add_metric_layer = base_layer.AddMetric(
        aggregation, name, dtype=value.dtype)
    add_metric_layer(value)
    new_nodes.extend(add_metric_layer.inbound_nodes)
    new_layers.append(add_metric_layer)
    self._insert_layers(new_layers, new_nodes)

  @property
  def _trackable_saved_model_saver(self):
    return network_serialization.NetworkSavedModelSaver(self)

  def _get_save_spec(self, dynamic_batch=True):
    if getattr(self, '_has_explicit_input_shape', True):
      # Functional models and Sequential models that have an explicit input
      # shape should use the batch size set by the input layer.
      dynamic_batch = False
    return super(Functional, self)._get_save_spec(dynamic_batch)


def _make_node_key(layer_name, node_index):
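  # E.g. (a sketch) _make_node_key('dense', 0) -> 'dense_ib-0'; used as a
  # stable identifier for nodes in config serialization.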
  return layer_name + '_ib-' + str(node_index)


def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gathers its layers and nodes.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: list of Node instances.
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph).
  """
  # "depth" is the number of layers between the output Node and a given Node.
  # Nodes are ordered from inputs -> outputs.
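  # E.g. (a sketch) for a chain input -> dense -> softmax, the softmax node
  # has depth 0, the dense node depth 1, and the input node depth 2.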
  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
  network_nodes = {
      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
      for node in nodes_in_decreasing_depth
  }

  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer
    previous_depth = layers_depths.get(node.layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all nodes it is connected to + 1.
    for node_dep in node.parent_nodes:
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.layer
      if layer and not node.is_input:
        for x in nest.flatten(node.keras_inputs):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.outputs):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name uniqueness, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth


def _build_map(outputs):
  """This method topologically sorts nodes in order from inputs to outputs.

  It uses a depth-first search to topologically sort nodes that appear in the
  _keras_history connectivity metadata of `outputs`.

  Args:
    outputs: the output tensors whose _keras_history metadata should be
      walked. This may be an arbitrary nested structure.

  Returns:
    A tuple like (ordered_nodes, layer_to_first_traversal_index)
    ordered_nodes: list of nodes appearing in the keras history, topologically
      sorted from original inputs to the `outputs`.
      (If outputs have different sets of ancestors, the inputs to one output
      may appear after a different output).
    layer_to_first_traversal_index:
      A dict mapping layer to the traversal index in the DFS where it is
      seen. Note: if a layer is shared by several nodes, the dict will only
      store the index corresponding to the *first* time the layer is seen.
  """
  finished_nodes = set()
  nodes_in_progress = set()
  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
  layer_indices = {}  # layer -> in traversal order.
  for output in nest.flatten(outputs):
    _build_map_helper(output, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices)
  return nodes_in_decreasing_depth, layer_indices


def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices):
  """Recursive helper for `_build_map`."""
  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

  # Don't repeat work for shared subgraphs
  if node in finished_nodes:
    return

  # Prevent cycles.
  if node in nodes_in_progress:
    raise ValueError('The tensor ' + str(tensor) + ' at layer "' + layer.name +
                     '" is part of a cycle.')

  # Store the traversal order for layer sorting.
  if layer not in layer_indices:
    layer_indices[layer] = len(layer_indices)

  # Propagate to all previous tensors connected to this node.
  nodes_in_progress.add(node)
  if not node.is_input:
    for tensor in node.keras_inputs:
      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                        nodes_in_decreasing_depth, layer_indices)

  finished_nodes.add(node)
  nodes_in_progress.remove(node)
  nodes_in_decreasing_depth.append(node)


def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of (List[Node], List[Layer]).
  """
  if not ops.executing_eagerly_outside_functions():
    base_layer_utils.create_keras_history(outputs)
  # Keep only nodes and layers in the topology between inputs and outputs.
  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers


def _should_skip_first_node(layer):
  """Returns True if the first layer node should not be saved or loaded."""
  # Networks that are constructed with an Input layer/shape start with a
  # pre-existing node linking their input to output. This node is excluded from
  # the network config.
  if layer._self_tracked_trackables:
    return (isinstance(layer, Functional) and
            # Filter out Sequential models without an input shape.
            isinstance(layer._self_tracked_trackables[0],
                       input_layer_module.InputLayer))
  else:
    return isinstance(layer, Functional)


def connect_ancillary_layers(model, created_layers):
  """Adds layers that are not connected to the outputs to the model."""
  # Layers not connected to outputs, such as those added in `add_loss`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  if ancillary_layers:
    relevant_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if _should_skip_first_node(layer) else layer.inbound_nodes
        for layer in created_layers.values()
    ])
    model._insert_layers(ancillary_layers, relevant_nodes)
  return model


def reconstruct_from_config(config, custom_objects=None, created_layers=None):
  """Reconstructs graph from config object.

  Args:
    config: Dictionary returned from Network.get_config()
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.
    created_layers: Optional dictionary mapping names to Layer objects. Any
      layer not in this dictionary will be created and added to the dict.
      This function will add new nodes to all layers (excluding InputLayers),
      instead of re-using pre-existing nodes in the layers.

  Returns:
    Tuple of (input tensors, output tensors, dictionary of created layers)
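
  Example (a sketch; `config` is assumed to come from `get_network_config`,
  and this mirrors what `Functional.from_config` does internally):

  ```
  inputs, outputs, created_layers = reconstruct_from_config(config)
  model = Functional(inputs, outputs, name=config.get('name'))
  ```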
1128  """
1129  # Layer instances created during the graph reconstruction process.
1130  created_layers = created_layers or collections.OrderedDict()
1131
1132  # Maps input data (tuple of inbound layer name, node index) from the config
1133  # to node indices in the newly generated model. The node indices may be
1134  # different if the layers have already been called previously.
1135  node_index_map = {}
1136  node_count_by_layer = {}
1137
1138  # Dictionary mapping layer instances to
1139  # node data that specifies a layer call.
1140  # It acts as a queue that maintains any unprocessed
1141  # layer call until it becomes possible to process it
1142  # (i.e. until the input tensors to the call all exist).
1143  unprocessed_nodes = {}
1144
1145  def add_unprocessed_node(layer, node_data):
1146    if layer not in unprocessed_nodes:
1147      unprocessed_nodes[layer] = [node_data]
1148    else:
1149      unprocessed_nodes[layer].append(node_data)
1150
1151  def get_node_index(layer, config_node_index):
1152    """Returns node index in layer (might differ from config_node_index)."""
1153    if isinstance(layer, input_layer_module.InputLayer):
1154      return 0
1155    return node_index_map.get((layer.name, config_node_index), None)
1156
  def _deserialize_keras_tensors(kwargs, layer_map):
    """Deserializes Keras Tensors passed to `call`."""

    def _deserialize_keras_tensor(t):
      """Deserializes a single Keras Tensor passed to `call`."""
      if isinstance(t, tf_utils.ListWrapper):
        t = t.as_list()
        layer_name = t[0]
        node_index = t[1]
        tensor_index = t[2]

        layer = layer_map[layer_name]
        new_node_index = get_node_index(layer, node_index)
        if new_node_index is None:
          # The inbound node may not have been processed yet (this can happen
          # e.g. if it depends on a different set of inputs than those that
          # have been processed already). Raise an IndexError so that the
          # current node puts itself back on the unprocessed queue.
          # Caution: this may lead to infinite loops for malformed network
          # configurations (or when there is a bug in the network config
          # loading code).
          raise IndexError
        node = layer._inbound_nodes[new_node_index]
        return nest.flatten(node.outputs)[tensor_index]
      return t

    kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
    return nest.map_structure(_deserialize_keras_tensor, kwargs)

  def process_node(layer, node_data):
    """Deserializes a node.

    Args:
        layer: layer instance.
        node_data: Nested structure of `ListWrapper`.

    Raises:
        ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        try:
          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
        except IndexError:
          # Happens if keras tensors in kwargs are still unprocessed.
          add_unprocessed_node(layer, node_data)
          return
      else:
        raise ValueError('Improperly formatted model config.')

      if inbound_layer_name != node_module._CONSTANT_VALUE:
        inbound_layer = created_layers[inbound_layer_name]
        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

        if inbound_node_index is None:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
      else:
        # We received a constant with no Keras history attached.
        input_tensors.append(inbound_tensor_index)
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
    if input_tensors is not None:
      if not layer._preserve_input_structure_in_config:
        input_tensors = (
            base_layer_utils.unnest_if_single_tensor(input_tensors))
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1

  def process_layer(layer_data):
    """Deserializes a layer, then calls it on appropriate inputs.

    Args:
        layer_data: layer config dict.

    Raises:
        ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in case of layers shared at different topological depths
      # (e.g. a model such as A(B(A(B(x))))).
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed.
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers


def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers.

  Returns:
    Config dictionary.
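
  Example (a sketch; `model` is an assumed `Functional` instance, and
  `Functional.get_config` returns a deep copy of this dictionary):

  ```
  config = get_network_config(model)
  ```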
1327  """
1328  serialize_layer_fn = (
1329      serialize_layer_fn or generic_utils.serialize_keras_object)
1330  config = {
1331      'name': network.name,
1332  }
1333  node_conversion_map = {}
1334  for layer in network.layers:
1335    kept_nodes = 1 if _should_skip_first_node(layer) else 0
1336    for original_node_index, node in enumerate(layer._inbound_nodes):
1337      node_key = _make_node_key(layer.name, original_node_index)
1338      if node_key in network._network_nodes:
1339        node_conversion_map[node_key] = kept_nodes
1340        kept_nodes += 1
1341  layer_configs = []
1342
1343  with generic_utils.SharedObjectSavingScope():
1344    for layer in network.layers:  # From the earliest layers on.
1345      filtered_inbound_nodes = []
1346      for original_node_index, node in enumerate(layer._inbound_nodes):
1347        node_key = _make_node_key(layer.name, original_node_index)
1348        if node_key in network._network_nodes and not node.is_input:
1349          # The node is relevant to the model:
1350          # add to filtered_inbound_nodes.
1351          node_data = node.serialize(_make_node_key, node_conversion_map)
1352          filtered_inbound_nodes.append(node_data)
1353
1354      layer_config = serialize_layer_fn(layer)
1355      layer_config['name'] = layer.name
1356      layer_config['inbound_nodes'] = filtered_inbound_nodes
1357      layer_configs.append(layer_config)
1358    config['layers'] = layer_configs
1359
1360  # Gather info about inputs and outputs.
1361  model_inputs = []
1362  for i in range(len(network._input_layers)):
1363    layer, node_index, tensor_index = network._input_coordinates[i]
1364    node_key = _make_node_key(layer.name, node_index)
1365    if node_key not in network._network_nodes:
1366      continue
1367    new_node_index = node_conversion_map[node_key]
1368    model_inputs.append(
1369        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
1370  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
1371  # Preserve external Keras compat for Models with single input.
1372  if not nest.is_nested(model_inputs):
1373    model_inputs = [model_inputs]
1374  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
1375  config['input_layers'] = model_inputs
1376
1377  model_outputs = []
1378  for i in range(len(network._output_layers)):
1379    layer, node_index, tensor_index = network._output_coordinates[i]
1380    node_key = _make_node_key(layer.name, node_index)
1381    if node_key not in network._network_nodes:
1382      continue
1383    new_node_index = node_conversion_map[node_key]
1384    model_outputs.append(
1385        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
1386  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
1387  # Preserve external Keras compat for Models with single output.
1388  if not nest.is_nested(model_outputs):
1389    model_outputs = [model_outputs]
1390  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
1391  config['output_layers'] = model_outputs
1392  return config
1393
1394
def shape_with_no_batch_size(x):
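  # E.g. (a sketch) an Input with shape (32, 10) -> [None, 10]; unknown-rank
  # shapes map to None.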
  if x.shape.rank is None:
    return None
  shape = x.shape.as_list()
  if shape:
    shape[0] = None
  return shape


class ModuleWrapper(base_layer.Layer):
  """Wrapper for `tf.Module`s to support the Functional and Sequential API."""

  def __init__(self, module, method_name=None, **kwargs):
    """Initializes the wrapper Layer for this module.

    Args:
      module: The `tf.Module` instance to be wrapped.
      method_name: (Optional) str. The name of the method to use as the forward
        pass of the module. If not set, defaults to '__call__' if defined, or
        'call'.
      **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.

    Raises:
      ValueError: If `method_name` is not defined on `module`.
    """
    super(ModuleWrapper, self).__init__(**kwargs)
    if method_name is None:
      if hasattr(module, '__call__'):
        method_name = '__call__'
      elif hasattr(module, 'call'):
        method_name = 'call'
    if method_name is None or not hasattr(module, method_name):
      raise ValueError('{} is not defined on object {}'.format(
          method_name, module))

    self._module = module
    self._method_name = method_name

    # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
    method = getattr(module, method_name)
    method_arg_spec = tf_inspect.getfullargspec(method)
    self._expects_training_arg = ('training' in method_arg_spec.args or
                                  method_arg_spec.varkw is not None)
    self._expects_mask_arg = ('mask' in method_arg_spec.args or
                              method_arg_spec.varkw is not None)

  def call(self, *args, **kwargs):
    if 'training' in kwargs and not self._expects_training_arg:
      kwargs.pop('training')
    if 'mask' in kwargs and not self._expects_mask_arg:
      kwargs.pop('mask')
    return getattr(self._module, self._method_name)(*args, **kwargs)
1447