# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Code for model cloning, plus model-related API entries."""

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


# API entries importable from `keras.models`:
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
Functional = functional.Functional  # pylint: disable=invalid-name
save_model = save.save_model
load_model = save.load_model
model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json


# Callable used to clone a layer with weights preserved.
def share_weights(layer):
  return layer
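
# A hedged usage sketch (not exercised in this module): passing
# `share_weights` as the `clone_function` of `clone_model` yields a clone
# whose non-input layers are the original layer objects, so their variables
# are shared, e.g.
#
#   shared_clone = clone_model(model, clone_function=share_weights)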


def _clone_layer(layer):
  return layer.__class__.from_config(layer.get_config())


def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
  """Inserts ancillary layers into the model in the proper order."""
  # Sort `AddMetric` layers so they agree with metrics_names.
  metric_layers = [
      layer for layer in ancillary_layers if isinstance(layer, AddMetric)
  ]
  metric_layers.sort(key=lambda layer: metrics_names.index(layer.metric_name))
  ancillary_layers = [
      layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
  ] + metric_layers
  model._insert_layers(ancillary_layers, relevant_nodes=list(new_nodes))


def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
  """Uses the layers in `layer_map` to make new nodes based on `nodes_by_depth`.

  Args:
    nodes_by_depth: Provides structure information to create new nodes.
    layer_fn: Function to clone layers.
    layer_map: Map from layers in `model` to new layers.
    tensor_map: Map from tensors in `model` to newly computed tensors.

  Returns:
    A set of new nodes. `layer_map` and `tensor_map` are updated.
  """
  # Iterate over every node in the reference model, in depth order.
  new_nodes = set()
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = nodes_by_depth[depth]
    for node in nodes:
      # Recover the corresponding layer.
      layer = node.outbound_layer

      # Get or create layer.
      if layer not in layer_map:
        new_layer = layer_fn(layer)
        layer_map[layer] = new_layer
        layer = new_layer
      else:
        # Reuse previously cloned layer.
        layer = layer_map[layer]
        # Don't call InputLayer multiple times.
        if isinstance(layer, InputLayer):
          continue

      # If all of the node's input tensors are available in tensor_map,
      # then call the node's layer on them.
      if all(
          tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
        # Call layer.
        args = nest.map_structure(lambda t: tensor_map.get(t, t),
                                  node.call_args)
        kwargs = nest.map_structure(lambda t: tensor_map.get(t, t),
                                    node.call_kwargs)
        output_tensors = layer(*args, **kwargs)

        # Thread-safe way to keep track of what node was created.
        first_output_tensor = nest.flatten(output_tensors)[0]
        new_nodes.add(
            layer._inbound_nodes[first_output_tensor._keras_history.node_index])

        for x, y in zip(
            nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
          tensor_map[x] = y
  return new_nodes


def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Input layers are always cloned.

  Args:
      model: Instance of `Model`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.
      layer_fn: callable to be applied on non-input layers in the model. By
          default it clones the layer. Another option is to preserve the layer
          in order to share its weights. This is required when we create a
          per-replica copy of the model with a distribution strategy; we want
          the weights to be shared, but we still need to feed inputs
          separately, so we create new input layers.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new input tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value.
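
  Example (an illustrative sketch; this helper is module-private, and
  `tf.keras` is assumed to be importable):

  ```python
  import tensorflow as tf

  inputs = tf.keras.Input(shape=(4,))
  outputs = tf.keras.layers.Dense(2)(inputs)
  original = tf.keras.Model(inputs, outputs)

  fresh = _clone_functional_model(original)  # New layers, new weights.
  shared = _clone_functional_model(original, layer_fn=share_weights)
  # `shared` has new input layers but reuses the original layers' weights.
  ```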
158  """
159  if not isinstance(model, Model):
160    raise ValueError('Expected `model` argument '
161                     'to be a `Model` instance, got ', model)
162  if isinstance(model, Sequential):
163    raise ValueError('Expected `model` argument '
164                     'to be a functional `Model` instance, '
165                     'got a `Sequential` instance instead:', model)
166  if not model._is_graph_network:
167    raise ValueError('Expected `model` argument '
168                     'to be a functional `Model` instance, '
169                     'but got a subclass model instead.')
170
171  new_input_layers = {}  # Cache for created layers.
172  if input_tensors is not None:
173    # Make sure that all input tensors come from a Keras layer.
174    input_tensors = nest.flatten(input_tensors)
175    for i, input_tensor in enumerate(input_tensors):
176      original_input_layer = model._input_layers[i]
177
178      # Cache input layer. Create a new layer if the tensor is originally not
179      # from a Keras layer.
180      if not backend.is_keras_tensor(input_tensor):
181        name = original_input_layer.name
182        input_tensor = Input(tensor=input_tensor,
183                             name='input_wrapper_for_' + name)
184        newly_created_input_layer = input_tensor._keras_history.layer
185        new_input_layers[original_input_layer] = newly_created_input_layer
186      else:
187        new_input_layers[original_input_layer] = original_input_layer
188
189  if not callable(layer_fn):
190    raise ValueError('Expected `layer_fn` argument to be a callable.')
191
192  model_configs, created_layers = _clone_layers_and_model_config(
193      model, new_input_layers, layer_fn)
194  # Reconstruct model from the config, using the cloned layers.
195  input_tensors, output_tensors, created_layers = (
196      functional.reconstruct_from_config(model_configs,
197                                         created_layers=created_layers))
198  metrics_names = model.metrics_names
199  model = Model(input_tensors, output_tensors, name=model.name)
200  # Layers not directly tied to outputs of the Model, such as loss layers
201  # created in `add_loss` and `add_metric`.
202  ancillary_layers = [
203      layer for layer in created_layers.values() if layer not in model.layers
204  ]
205  # TODO(b/162887610): This may need to adjust the inbound node index if the
206  # created layers had already been used to define other models.
207  if ancillary_layers:
208    new_nodes = nest.flatten([
209        layer.inbound_nodes[1:]
210        if functional._should_skip_first_node(layer)
211        else layer.inbound_nodes for layer in created_layers.values()
212    ])
213    _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
214  return model
215
216
217def _clone_layers_and_model_config(model, input_layers, layer_fn):
218  """Clones all layers, and returns the model config without serializing layers.
219
220  This function ensures that only the node graph is retrieved when getting the
221  model config. The `layer_fn` used to clone layers might not rely on
222  `layer.get_config()`, so some custom layers do not define `get_config`.
223  Trying to retrieve the config results in errors.
224
225  Args:
226    model: A Functional model.
227    input_layers: Dictionary mapping input layers in `model` to new input layers
228    layer_fn: Function used to clone all non-input layers.
229
230  Returns:
231    Model config object, and a dictionary of newly created layers.
232  """
233  created_layers = {}
234  def _copy_layer(layer):
235    # Whenever the network config attempts to get the layer serialization,
236    # return a dummy dictionary.
237    if layer in input_layers:
238      created_layers[layer.name] = input_layers[layer]
239    elif layer in model._input_layers:
240      created_layers[layer.name] = InputLayer(**layer.get_config())
241    else:
242      created_layers[layer.name] = layer_fn(layer)
243    return {}
244
245  config = functional.get_network_config(
246      model, serialize_layer_fn=_copy_layer)
247  return config, created_layers
248
249
250def _remove_ancillary_layers(model, layer_map, layers):
251  """Removes and returns any ancillary layers from `layers` based on `model`.
252
253  Ancillary layers are part of the model topology but not used to compute the
254  model outputs, e.g., layers from `add_loss` and `add_metric`.
255
256  Args:
257    model: A Keras Model.
258    layer_map: A map to from layers in the `model` to those in `layers`.
259    layers: A list of all layers.
260
261  Returns:
262    Two lists of layers: (1) `layers` with the ancillary layers removed, and (2)
263    the ancillary layers.
264  """
265  ancillary_layers = []  # Additional layers for computing losses and metrics.
266  if not model._is_graph_network:
267    return layers, ancillary_layers
268
269  # Ancillary layers are those with depth < 0.
270  depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
271  depths.sort(reverse=True)  # Order topologically from inputs to outputs.
272  for depth in depths:
273    for node in model._nodes_by_depth[depth]:
274      ancillary_layers.append(layer_map[node.outbound_layer])
275
276  return [l for l in layers if l not in ancillary_layers], ancillary_layers
277
278
279def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
280  """Clone a `Sequential` model instance.
281
282  Model cloning is similar to calling a model on new inputs,
283  except that it creates new layers (and thus new weights) instead
284  of sharing the weights of the existing layers.
285
286  Args:
287      model: Instance of `Sequential`.
288      input_tensors: optional list of input tensors
289          to build the model upon. If not provided,
290          placeholders will be created.
291      layer_fn: callable to be applied on non-input layers in the model. By
292          default it clones the layer. Another example is to preserve the layer
293          to share the weights. This is required when we create a per-replica
294          copy of the model with distribution strategy; we want the weights to
295          be shared but still feed inputs separately so we create new input
296          layers.
297
298  Returns:
299      An instance of `Sequential` reproducing the behavior
300      of the original model, on top of new inputs tensors,
301      using newly instantiated weights.
302
303  Raises:
304      ValueError: in case of invalid `model` argument value or `layer_fn`
305      argument value.
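
  Example (an illustrative sketch; this helper is module-private, and
  `tf.keras` is assumed to be importable):

  ```python
  import tensorflow as tf

  original = tf.keras.Sequential(
      [tf.keras.layers.Dense(2, input_shape=(4,))])
  fresh = _clone_sequential_model(original)  # New layers, new weights.
  shared = _clone_sequential_model(original, layer_fn=share_weights)
  ```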
306  """
307  if not isinstance(model, Sequential):
308    raise ValueError('Expected `model` argument '
309                     'to be a `Sequential` model instance, '
310                     'but got:', model)
311
312  if not callable(layer_fn):
313    raise ValueError('Expected `layer_fn` argument to be a callable.')
314
315  layers = []  # Layers needed to compute the model's outputs.
316  layer_map = {}
317  # Ensure that all layers are cloned. The model's layers
318  # property will exclude the initial InputLayer (if it exists) in the model,
319  # resulting in a different Sequential model structure.
320  for layer in model._flatten_layers(include_self=False, recursive=False):
321    if isinstance(layer, InputLayer) and input_tensors is not None:
322      # If input tensors are provided, the original model's InputLayer is
323      # overwritten with a different InputLayer.
324      continue
325    cloned_layer = (
326        _clone_layer(layer)
327        if isinstance(layer, InputLayer) else layer_fn(layer))
328    layers.append(cloned_layer)
329    layer_map[layer] = cloned_layer
330  layers, ancillary_layers = _remove_ancillary_layers(model, layer_map, layers)
331
332  if input_tensors is None:
333    cloned_model = Sequential(layers=layers, name=model.name)
334  elif len(generic_utils.to_list(input_tensors)) != 1:
335    raise ValueError('To clone a `Sequential` model, we expect '
336                     ' at most one tensor '
337                     'as part of `input_tensors`.')
338  else:
339    # Overwrite the original model's input layer.
340    if isinstance(input_tensors, tuple):
341      input_tensors = list(input_tensors)
342    x = generic_utils.to_list(input_tensors)[0]
343    if backend.is_keras_tensor(x):
344      origin_layer = x._keras_history.layer
345      if isinstance(origin_layer, InputLayer):
346        cloned_model = Sequential(
347            layers=[origin_layer] + layers, name=model.name)
348      else:
349        raise ValueError('Cannot clone a `Sequential` model on top '
350                         'of a tensor that comes from a Keras layer '
351                         'other than an `InputLayer`. '
352                         'Use the functional API instead.')
353    else:
354      input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
355      input_layer = input_tensor._keras_history.layer
356      cloned_model = Sequential(layers=[input_layer] + layers, name=model.name)
357
358  if not ancillary_layers:
359    return cloned_model
360
361  tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
362  for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
363    nodes = model._nodes_by_depth[depth]
364    # This should be safe in a Sequential model. In an arbitrary network, you
365    # need to sort using the outbound layer of the node as a key.
366    for cloned_node, node in zip(cloned_nodes, nodes):
367      if isinstance(cloned_node.output_tensors, list):
368        for j, output_tensor in enumerate(cloned_node.output_tensors):
369          tensor_map[node.output_tensors[j]] = output_tensor
370      else:
371        tensor_map[node.output_tensors] = cloned_node.output_tensors
372  # Ancillary nodes have negative depth.
373  new_nodes = _make_new_nodes(
374      {
375          depth: nodes
376          for depth, nodes in model._nodes_by_depth.items()
377          if depth < 0
378      }, layer_fn, layer_map, tensor_map)
379  _insert_ancillary_layers(cloned_model, ancillary_layers, model.metrics_names,
380                           new_nodes)
381  return cloned_model
382
383
384@keras_export('keras.models.clone_model')
385def clone_model(model, input_tensors=None, clone_function=None):
386  """Clone a Functional or Sequential `Model` instance.
387
388  Model cloning is similar to calling a model on new inputs,
389  except that it creates new layers (and thus new weights) instead
390  of sharing the weights of the existing layers.
391
392  Note that
393  `clone_model` will not preserve the uniqueness of shared objects within the
394  model (e.g. a single variable attached to two distinct layers will be
395  restored as two separate variables).
396
397  Args:
398      model: Instance of `Model`
399          (could be a Functional model or a Sequential model).
400      input_tensors: optional list of input tensors or InputLayer objects
401          to build the model upon. If not provided,
402          new `Input` objects will be created.
403      clone_function: Callable to be used to clone each layer in the target
404          model (except `InputLayer` instances). It takes as argument the layer
405          instance to be cloned, and returns the corresponding layer instance to
406          be used in the model copy. If unspecified, this callable defaults to
407          the following serialization/deserialization function:
408          `lambda layer: layer.__class__.from_config(layer.get_config())`.
409          By passing a custom callable, you can customize your copy of the
410          model, e.g. by wrapping certain layers of interest (you might want to
411          replace all `LSTM` instances with equivalent
412          `Bidirectional(LSTM(...))` instances, for example).
413
414  Returns:
415    An instance of `Model` reproducing the behavior
416    of the original model, on top of new inputs tensors,
417    using newly instantiated weights. The cloned model may behave
418    differently from the original model if a custom `clone_function`
419    modifies the layer.
420
421  Example:
422
423  ```python
424  # Create a test Sequential model.
425  model = keras.Sequential([
426      keras.Input(shape=(728,)),
427      keras.layers.Dense(32, activation='relu'),
428      keras.layers.Dense(1, activation='sigmoid'),
429  ])
430  # Create a copy of the test model (with freshly initialized weights).
431  new_model = clone_model(model)
432  ```
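
  A sketch of the custom-`clone_function` case mentioned above (assuming
  `keras.layers.LSTM` and `keras.layers.Bidirectional` are in scope):

  ```python
  def clone_fn(layer):
    if isinstance(layer, keras.layers.LSTM):
      # Wrap each LSTM in a Bidirectional wrapper; note that this changes
      # the cloned model's output shapes relative to the original.
      return keras.layers.Bidirectional(
          keras.layers.LSTM.from_config(layer.get_config()))
    return layer.__class__.from_config(layer.get_config())

  new_model = clone_model(model, clone_function=clone_fn)
  ```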

  Note that subclassed models cannot be cloned, since their internal
  layer structure is not known. To achieve functionality equivalent to
  `clone_model` in the case of a subclassed model, simply make sure
  that the model class implements `get_config()`
  (and optionally `from_config()`), and call:

  ```python
  new_model = model.__class__.from_config(model.get_config())
  ```
  """
  with generic_utils.DisableSharedObjectScope():
    if clone_function is None:
      clone_function = _clone_layer

    if isinstance(model, Sequential):
      return _clone_sequential_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)
    else:
      return _clone_functional_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)


# "Clone" a subclassed model by resetting all of its attributes.
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not serializable.
  To "instantiate" an identical model in a new TF graph, we reuse the original
  model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However, the original state can be fully restored by calling
  `_in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as inner layer.
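
  Example (a sketch of the reset/restore round trip):

  ```python
  _in_place_subclassed_model_reset(model)  # Destructive: clears layer state.
  # ... use `model` as if it were a fresh clone (e.g. in a new graph) ...
  in_place_subclassed_model_state_restoration(model)  # Restore old state.
  ```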
477  """
478  assert not model._is_graph_network  # Only makes sense for subclassed networks
479  # Select correct base class for new Model.
480  version_utils.swap_class(model.__class__, training.Model, training_v1.Model,
481                           ops.executing_eagerly_outside_functions())
482  # Retrieve all layers tracked by the model as well as their attribute names
483  attributes_cache = {}
484  for name in dir(model):
485    # Skip attrs that track other trackables.
486    if name == 'submodules' or name == '_self_tracked_trackables':
487      continue
488
489    try:
490      value = getattr(model, name)
491    except (AttributeError, ValueError, TypeError):
492      continue
493    if isinstance(value, Layer):
494      attributes_cache[name] = value
495      assert value in model.layers
496      if hasattr(value, 'layers') and value.layers:
497        raise ValueError('We do not support the use of nested layers '
498                         'in `model_to_estimator` at this time. Found nested '
499                         'layer: %s' % value)
500    elif isinstance(
501        value, (list, tuple)) and name not in ('layers', '_layers', 'metrics',
502                                               '_compile_metric_functions',
503                                               '_output_loss_metrics'):
504      # Handle case: list/tuple of layers (also tracked by the Network API).
505      if value and all(isinstance(val, Layer) for val in value):
506        raise ValueError('We do not support the use of list-of-layers '
507                         'attributes in subclassed models used with '
508                         '`model_to_estimator` at this time. Found list '
509                         'model: %s' % name)
510
511  # Replace layers on the model with fresh layers
512  layers_to_names = {value: key for key, value in attributes_cache.items()}
513  original_layers = list(
514      model._flatten_layers(include_self=False, recursive=False))
515  setattr_tracking = model._setattr_tracking
516  model._setattr_tracking = False
517  model._self_tracked_trackables = []
518  for layer in original_layers:  # We preserve layer order.
519    config = layer.get_config()
520    # This will not work for nested subclassed models used as layers.
521    # This would be theoretically possible to support, but would add complexity.
522    # Only do it if users complain.
523    if isinstance(layer, training.Model) and not layer._is_graph_network:
524      raise ValueError('We do not support the use of nested subclassed models '
525                       'in `model_to_estimator` at this time. Found nested '
526                       'model: %s' % layer)
527    fresh_layer = layer.__class__.from_config(config)
528    name = layers_to_names[layer]
529    setattr(model, name, fresh_layer)
530    model._self_tracked_trackables.append(fresh_layer)
531
532  # Cache original model build attributes (in addition to layers)
533  if (not hasattr(model, '_original_attributes_cache') or
534      model._original_attributes_cache is None):
535    if model.built:
536      attributes_to_cache = [
537          'inputs',
538          'outputs',
539          'total_loss',
540          'optimizer',
541          'train_function',
542          'test_function',
543          'predict_function',
544          '_training_endpoints',
545          '_collected_trainable_weights',
546          '_feed_inputs',
547          '_feed_input_names',
548          '_feed_input_shapes',
549      ]
550      for name in attributes_to_cache:
551        attributes_cache[name] = getattr(model, name)
552  model._original_attributes_cache = attributes_cache
553  _reset_build_compile_trackers(model)
554  model._setattr_tracking = setattr_tracking
555
556
557def _reset_build_compile_trackers(model):
558  """Reset state trackers for model.
559
560  Note that we do not actually zero out attributes such as optimizer,
561  but instead rely on the expectation that all of the attrs will be
562  over-written on calling build/compile/etc. This is somewhat fragile,
563  insofar as we check elsewhere for the presence of these attributes as
564  evidence of having been built/compiled/etc. Pending a better way to do this,
565  we reset key attributes here to allow building and compiling.
566
567  Args:
568    model: the model that is being reset
569  """
570  # Reset build state
571  model.built = False
572  model.inputs = None
573  model.outputs = None
574  # Reset compile state
575  model._is_compiled = False  # pylint:disable=protected-access
576  if not ops.executing_eagerly_outside_functions():
577    model._v1_compile_was_called = False
578  model.optimizer = None
579
580
581@keras_export(
582    'keras.__internal__.models.in_place_subclassed_model_state_restoration',
583    v1=[])
584def in_place_subclassed_model_state_restoration(model):
585  """Restores the original state of a model after it was "reset".
586
587  This undoes this action of `_in_place_subclassed_model_reset`, which is called
588  in `clone_and_build_model` if `in_place_reset` is set to True.
589
590  Args:
591    model: Instance of a Keras model created via subclassing, on which
592      `_in_place_subclassed_model_reset` was previously called.
593  """
594  assert not model._is_graph_network
595  # Restore layers and build attributes
596  if (hasattr(model, '_original_attributes_cache') and
597      model._original_attributes_cache is not None):
598    # Models have sticky attribute assignment, so we want to be careful to add
599    # back the previous attributes and track Layers by their original names
600    # without adding dependencies on "utility" attributes which Models exempt
601    # when they're constructed.
602    setattr_tracking = model._setattr_tracking
603    model._setattr_tracking = False
604    model._self_tracked_trackables = []
605    for name, value in model._original_attributes_cache.items():
606      setattr(model, name, value)
607      if isinstance(value, Layer):
608        model._self_tracked_trackables.append(value)
609    model._original_attributes_cache = None
610    model._setattr_tracking = setattr_tracking
611  else:
612    # Restore to the state of a never-called model.
613    _reset_build_compile_trackers(model)
614
615
616@keras_export('keras.__internal__.models.clone_and_build_model', v1=[])
617def clone_and_build_model(
618    model, input_tensors=None, target_tensors=None, custom_objects=None,
619    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
620    optimizer_config=None):
621  """Clone a `Model` and build/compile it with the same settings used before.
622
623  This function can be run in the same graph or in a separate graph from the
624  model. When using a separate graph, `in_place_reset` must be `False`.
625
626  Note that, currently, the clone produced from this function may not work with
627  TPU DistributionStrategy. Try at your own risk.
628
629  Args:
630    model: `tf.keras.Model` object. Can be Functional, Sequential, or
631      sub-classed.
632    input_tensors: Optional list or dictionary of input tensors to build the
633      model upon. If not provided, placeholders will be created.
634    target_tensors: Optional list of target tensors for compiling the model. If
635      not provided, placeholders will be created.
636    custom_objects: Optional dictionary mapping string names to custom classes
637      or functions.
638    compile_clone: Boolean, whether to compile model clone (default `True`).
639    in_place_reset: Boolean, whether to reset the model in place. Only used if
640      the model is a subclassed model. In the case of a subclassed model,
641      this argument must be set to `True` (default `False`). To restore the
642      original model, use the function
643      `in_place_subclassed_model_state_restoration(model)`.
644    optimizer_iterations: An iterations variable that will be incremented by the
645      optimizer if the clone is compiled. This argument is used when a Keras
646      model is cloned into an Estimator model function, because Estimators
647      create their own global step variable.
648    optimizer_config: Optimizer config dictionary or list of dictionary
649      returned from `get_config()`. This argument should be defined if
650      `clone_and_build_model` is called in a different graph or session from
651      the original model, and the optimizer is an instance of `OptimizerV2`.
652
653  Returns:
654    Clone of the model.
655
656  Raises:
657    ValueError: Cloning fails in the following cases
658      - cloning a subclassed model with `in_place_reset` set to False.
659      - compiling the clone when the original model has not been compiled.
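
  Example (an illustrative sketch of this internal helper; `tf.keras` is
  assumed to be importable):

  ```python
  import tensorflow as tf

  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse')
  clone = clone_and_build_model(model, compile_clone=True)
  # `clone` has freshly initialized weights, but the same compile settings
  # (a separate optimizer instance and cloned metrics) as `model`.
  ```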
660  """
661  # Grab optimizer now, as we reset-in-place for subclassed models, but
662  # want to maintain access to the original optimizer.
663  orig_optimizer = model.optimizer
664  if compile_clone and not orig_optimizer:
665    raise ValueError(
666        'Error when cloning model: compile_clone was set to True, but the '
667        'original model has not been compiled.')
668
669  if compile_clone:
670    compile_args = model._get_compile_args()  # pylint: disable=protected-access
671    # Allows this method to be robust to switching graph and eager classes.
672    model._get_compile_args = lambda: compile_args
673
674  with CustomObjectScope(custom_objects or {}):
675    if model._is_graph_network:
676      clone = clone_model(model, input_tensors=input_tensors)
677    elif isinstance(model, Sequential):
678      clone = clone_model(model, input_tensors=input_tensors)
679      if (not clone._is_graph_network and model._build_input_shape is not None):
680        if ops.executing_eagerly_outside_functions():
681          clone.build(model._build_input_shape)
682        else:
683          clone._set_inputs(
684              backend.placeholder(
685                  model._build_input_shape, dtype=model.inputs[0].dtype))
686    else:
687      try:
688        # Prefer cloning the model if serial/deserial logic is implemented for
689        # subclassed model.
690        clone = model.__class__.from_config(model.get_config())
691      except NotImplementedError:
692        logging.warning('This model is a subclassed model. Please implement '
693                        '`get_config` and `from_config` to better support '
694                        'cloning the model.')
695        if not in_place_reset:
696          raise ValueError(
697              'This model is a subclassed model. '
698              'Such a model cannot be cloned, but there is a workaround where '
699              'the model is reset in-place. To use this, please set the '
700              'argument `in_place_reset` to `True`. This will reset the '
701              'attributes in the original model. To restore the attributes, '
702              'call `in_place_subclassed_model_state_restoration(model)`.')
703        clone = model
704        _in_place_subclassed_model_reset(clone)
705      if input_tensors is not None:
706        if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
707          input_tensors = input_tensors[0]
708        clone._set_inputs(input_tensors)
709
710  if compile_clone:
711    if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
712      optimizer = optimizer_v1.TFOptimizer(
713          orig_optimizer.optimizer, optimizer_iterations)
714      backend.track_tf_optimizer(optimizer)
715    else:
716      if not isinstance(orig_optimizer, (tuple, list)):
717        orig_optimizer = [orig_optimizer]
718      if optimizer_config is None:
719        optimizer = [
720            opt.__class__.from_config(opt.get_config())
721            for opt in orig_optimizer
722        ]
723      elif isinstance(optimizer_config, dict):
724        optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
725      else:
726        # optimizer config is list of dict, same order as orig_optimizer.
727        optimizer = [
728            opt.__class__.from_config(opt_config)
729            for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
730        ]
731      if optimizer_iterations is not None:
732        for opt in optimizer:
733          opt.iterations = optimizer_iterations
734
735      if len(optimizer) == 1:
736        optimizer = optimizer[0]
737
738    compile_args['optimizer'] = optimizer
739    if target_tensors is not None:
740      compile_args['target_tensors'] = target_tensors
741    # Ensure Metric objects in new model are separate from existing model.
742    compile_args['metrics'] = metrics_module.clone_metrics(
743        compile_args['metrics'])
744    compile_args['weighted_metrics'] = metrics_module.clone_metrics(
745        compile_args['weighted_metrics'])
746    clone.compile(**compile_args)
747
748  return clone
749