# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
16"""Functions for saving and loading a Keras Model from HDF5 format."""
17
18import json
19import os
20
21import numpy as np
22
23from tensorflow.python.keras import backend
24from tensorflow.python.keras import optimizer_v1
25from tensorflow.python.keras.saving import model_config as model_config_lib
26from tensorflow.python.keras.saving import saving_utils
27from tensorflow.python.keras.saving.saved_model import json_utils
28from tensorflow.python.keras.utils.generic_utils import LazyLoader
29from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
30from tensorflow.python.ops import variables as variables_module
31from tensorflow.python.platform import gfile
32from tensorflow.python.platform import tf_logging as logging
33
34
35# pylint: disable=g-import-not-at-top
36try:
37  import h5py
38  HDF5_OBJECT_HEADER_LIMIT = 64512
39except ImportError:
40  h5py = None
41# pylint: enable=g-import-not-at-top
42
43# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
44# once the issue with copybara is fixed.
45# pylint:disable=g-inconsistent-quotes
46sequential_lib = LazyLoader(
47    "sequential_lib", globals(),
48    "tensorflow.python.keras.engine.sequential")
49# pylint:enable=g-inconsistent-quotes
50
51
52def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
53  """Saves a model to a HDF5 file.
54
55  The saved model contains:
56      - the model's configuration (topology)
57      - the model's weights
58      - the model's optimizer's state (if any)
59
60  Thus the saved model can be reinstantiated in
61  the exact same state, without any of the code
62  used for model definition or training.
63
64  Args:
65      model: Keras model instance to be saved.
66      filepath: One of the following:
67          - String, path where to save the model
68          - `h5py.File` object where to save the model
69      overwrite: Whether we should overwrite any existing
70          model at the target location, or instead
71          ask the user with a manual prompt.
72      include_optimizer: If True, save optimizer's state together.
73
74  Raises:
75      ImportError: if h5py is not available.
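
  Example:

  A minimal usage sketch (the toy model and save path below are
  illustrative):

  ```python
  import tensorflow as tf
  from tensorflow.python.keras.saving import hdf5_format

  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse')
  hdf5_format.save_model_to_hdf5(model, '/tmp/model.h5')
  ```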
76  """
77
78  if h5py is None:
79    raise ImportError('`save_model` requires h5py.')
80
81  # TODO(psv) Add warning when we save models that contain non-serializable
82  # entities like metrics added using `add_metric` and losses added using
83  # `add_loss.`
84  if len(model.weights) != len(model._undeduplicated_weights):
85    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
86                    'This is usually caused by `Variable`s being shared by '
87                    'Layers in the Model. These `Variable`s will be treated '
88                    'as separate `Variable`s when the Model is restored. To '
89                    'avoid this, please save with `save_format="tf"`.')
90
91  if not isinstance(filepath, h5py.File):
92    # If file exists and should not be overwritten.
93    if not overwrite and os.path.isfile(filepath):
94      proceed = ask_to_proceed_with_overwrite(filepath)
95      if not proceed:
96        return
97
98    # Try creating dir if not exist
99    dirpath = os.path.dirname(filepath)
100    if not os.path.exists(dirpath):
101      gfile.MakeDirs(dirpath)
102
103    f = h5py.File(filepath, mode='w')
104    opened_new_file = True
105  else:
106    f = filepath
107    opened_new_file = False
108
109  try:
110    model_metadata = saving_utils.model_metadata(model, include_optimizer)
111    for k, v in model_metadata.items():
112      if isinstance(v, (dict, list, tuple)):
113        f.attrs[k] = json.dumps(
114            v, default=json_utils.get_json_type).encode('utf8')
115      else:
116        f.attrs[k] = v
117
118    model_weights_group = f.create_group('model_weights')
119    model_layers = model.layers
120    save_weights_to_hdf5_group(model_weights_group, model_layers)
121
122    # TODO(b/128683857): Add integration tests between tf.keras and external
123    # Keras, to avoid breaking TF.js users.
124    if (include_optimizer and model.optimizer and
125        not isinstance(model.optimizer, optimizer_v1.TFOptimizer)):
126      save_optimizer_weights_to_hdf5_group(f, model.optimizer)
127
128    f.flush()
129  finally:
130    if opened_new_file:
131      f.close()
132
133
134def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
135  """Loads a model saved via `save_model_to_hdf5`.
136
137  Args:
138      filepath: One of the following:
139          - String, path to the saved model
140          - `h5py.File` object from which to load the model
141      custom_objects: Optional dictionary mapping names
142          (strings) to custom classes or functions to be
143          considered during deserialization.
144      compile: Boolean, whether to compile the model
145          after loading.
146
147  Returns:
148      A Keras model instance. If an optimizer was found
149      as part of the saved model, the model is already
150      compiled. Otherwise, the model is uncompiled and
151      a warning will be displayed. When `compile` is set
152      to False, the compilation is omitted without any
153      warning.
154
155  Raises:
156      ImportError: if h5py is not available.
157      ValueError: In case of an invalid savefile.
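
  Example:

  A minimal sketch (assumes a model previously written by
  `save_model_to_hdf5`; the path is illustrative):

  ```python
  from tensorflow.python.keras.saving import hdf5_format

  model = hdf5_format.load_model_from_hdf5('/tmp/model.h5')
  model.summary()
  ```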
158  """
159  if h5py is None:
160    raise ImportError('`load_model` requires h5py.')
161
162  if not custom_objects:
163    custom_objects = {}
164
165  opened_new_file = not isinstance(filepath, h5py.File)
166  if opened_new_file:
167    f = h5py.File(filepath, mode='r')
168  else:
169    f = filepath
170
171  model = None
172  try:
173    # instantiate model
174    model_config = f.attrs.get('model_config')
175    if model_config is None:
176      raise ValueError('No model found in config file.')
177    if hasattr(model_config, 'decode'):
178      model_config = model_config.decode('utf-8')
179    model_config = json_utils.decode(model_config)
180    model = model_config_lib.model_from_config(model_config,
181                                               custom_objects=custom_objects)
182
183    # set weights
184    load_weights_from_hdf5_group(f['model_weights'], model.layers)
185
186    if compile:
187      # instantiate optimizer
188      training_config = f.attrs.get('training_config')
189      if hasattr(training_config, 'decode'):
190        training_config = training_config.decode('utf-8')
191      if training_config is None:
192        logging.warning('No training configuration found in the save file, so '
193                        'the model was *not* compiled. Compile it manually.')
194        return model
195      training_config = json_utils.decode(training_config)
196
197      # Compile model.
198      model.compile(**saving_utils.compile_args_from_training_config(
199          training_config, custom_objects), from_serialized=True)
200      saving_utils.try_build_compiled_arguments(model)
201
202      # Set optimizer weights.
203      if 'optimizer_weights' in f:
204        try:
205          model.optimizer._create_all_weights(model.trainable_variables)
206        except (NotImplementedError, AttributeError):
207          logging.warning(
208              'Error when creating the weights of optimizer {}, making it '
209              'impossible to restore the saved optimizer state. As a result, '
210              'your model is starting with a freshly initialized optimizer.')
211
212        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
213        try:
214          model.optimizer.set_weights(optimizer_weight_values)
215        except ValueError:
216          logging.warning('Error in loading the saved optimizer '
217                          'state. As a result, your model is '
218                          'starting with a freshly initialized '
219                          'optimizer.')
220  finally:
221    if opened_new_file:
222      f.close()
223  return model
224
225
226def preprocess_weights_for_loading(layer,
227                                   weights,
228                                   original_keras_version=None,
229                                   original_backend=None):
230  """Preprocess layer weights between different Keras formats.
231
232  Converts layers weights from Keras 1 format to Keras 2 and also weights of
233  CuDNN layers in Keras 2.
234
235  Args:
236      layer: Layer instance.
237      weights: List of weights values (Numpy arrays).
238      original_keras_version: Keras version for the weights, as a string.
239      original_backend: Keras backend the weights were trained with,
240          as a string.
241
242  Returns:
243      A list of weights values (Numpy arrays).
244  """
  def convert_nested_bidirectional(weights):
    """Converts layers nested in `Bidirectional` wrapper.

    This function uses `preprocess_weights_for_loading()` for converting
    layers.

    Args:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    num_weights_per_layer = len(weights) // 2
    forward_weights = preprocess_weights_for_loading(
        layer.forward_layer, weights[:num_weights_per_layer],
        original_keras_version, original_backend)
    backward_weights = preprocess_weights_for_loading(
        layer.backward_layer, weights[num_weights_per_layer:],
        original_keras_version, original_backend)
    return forward_weights + backward_weights

  def convert_nested_time_distributed(weights):
    """Converts layers nested in `TimeDistributed` wrapper.

    This function uses `preprocess_weights_for_loading()` for converting nested
    layers.

    Args:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    return preprocess_weights_for_loading(
        layer.layer, weights, original_keras_version, original_backend)

  def convert_nested_model(weights):
    """Converts layers nested in `Model` or `Sequential`.

    This function uses `preprocess_weights_for_loading()` for converting nested
    layers.

    Args:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    trainable_weights = weights[:len(layer.trainable_weights)]
    non_trainable_weights = weights[len(layer.trainable_weights):]

    new_trainable_weights = []
    new_non_trainable_weights = []

    for sublayer in layer.layers:
      num_trainable_weights = len(sublayer.trainable_weights)
      num_non_trainable_weights = len(sublayer.non_trainable_weights)
      if sublayer.weights:
        preprocessed = preprocess_weights_for_loading(
            layer=sublayer,
            weights=(trainable_weights[:num_trainable_weights] +
                     non_trainable_weights[:num_non_trainable_weights]),
            original_keras_version=original_keras_version,
            original_backend=original_backend)
        new_trainable_weights.extend(preprocessed[:num_trainable_weights])
        new_non_trainable_weights.extend(preprocessed[num_trainable_weights:])

        trainable_weights = trainable_weights[num_trainable_weights:]
        non_trainable_weights = non_trainable_weights[
            num_non_trainable_weights:]

    return new_trainable_weights + new_non_trainable_weights

  # Convert layers nested in Bidirectional/TimeDistributed/Model/Sequential.
  # Both transformations should be run for the Keras 1->2 conversion
  # and for the conversion of CuDNN layers.
  if layer.__class__.__name__ == 'Bidirectional':
    weights = convert_nested_bidirectional(weights)
  if layer.__class__.__name__ == 'TimeDistributed':
    weights = convert_nested_time_distributed(weights)
  elif layer.__class__.__name__ in ['Model', 'Sequential', 'Functional']:
    weights = convert_nested_model(weights)

  if original_keras_version == '1':
    if layer.__class__.__name__ == 'TimeDistributed':
      weights = preprocess_weights_for_loading(
          layer.layer, weights, original_keras_version, original_backend)

    if layer.__class__.__name__ == 'Conv1D':
      shape = weights[0].shape
      # Handle Keras 1.1 format.
      if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
        # Legacy shape:
        # (filters, input_dim, filter_length, 1)
        assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
                                                           1)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
      weights[0] = weights[0][:, 0, :, :]

    if layer.__class__.__name__ == 'Conv2D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))

    if layer.__class__.__name__ == 'Conv2DTranspose':
      if layer.data_format == 'channels_last':
        # old: (kernel_rows, kernel_cols, stack_size, filters)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (2, 3, 0, 1))

    if layer.__class__.__name__ == 'Conv3D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, ...)
        # new: (..., stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))

    if layer.__class__.__name__ == 'GRU':
      if len(weights) == 9:
        kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[4], weights[7]], axis=-1)
        bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ == 'LSTM':
      if len(weights) == 12:
        # old: i, c, f, o
        # new: i, f, c, o
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if layer.__class__.__name__ == 'ConvLSTM2D':
      if len(weights) == 12:
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        if layer.data_format == 'channels_first':
          # old: (filters, stack_size, kernel_rows, kernel_cols)
          # new: (kernel_rows, kernel_cols, stack_size, filters)
          kernel = np.transpose(kernel, (2, 3, 1, 0))
          recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
        weights = [kernel, recurrent_kernel, bias]

  conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
  if layer.__class__.__name__ in conv_layers:
    if backend.int_shape(layer.weights[0]) != weights[0].shape:
      weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
      if layer.__class__.__name__ == 'ConvLSTM2D':
        weights[1] = np.transpose(weights[1], (3, 2, 0, 1))

  # Convert CuDNN layers.
  return _convert_rnn_weights(layer, weights)


def _convert_rnn_weights(layer, weights):
413  """Converts weights for RNN layers between native and CuDNN format.
414
415  Input kernels for each gate are transposed and converted between Fortran
416  and C layout, recurrent kernels are transposed. For LSTM biases are summed/
417  split in half, for GRU biases are reshaped.
418
419  Weights can be converted in both directions between `LSTM` and`CuDNNSLTM`
420  and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not
421  compatible with `CuDNNGRU`.
422
423  For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made.
424
425  Args:
426      layer: Target layer instance.
427      weights: List of source weights values (input kernels, recurrent
428          kernels, [biases]) (Numpy arrays).
429
430  Returns:
431      A list of converted weights values (Numpy arrays).
432
433  Raises:
434      ValueError: for incompatible GRU layer/weights or incompatible biases
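
  Example:

  A small sketch of the LSTM bias relationship (shapes only; `units=2` is
  illustrative):

  ```python
  import numpy as np

  units, n_gates = 2, 4
  # CuDNNLSTM keeps separate input and recurrent biases: shape (16,).
  cudnn_biases = np.ones(2 * units * n_gates)
  # LSTM uses a single merged bias: shape (8,), the sum of the two halves.
  lstm_biases = np.sum(np.split(cudnn_biases, 2, axis=0), axis=0)
  ```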
435  """
436
437  def transform_kernels(kernels, func, n_gates):
438    """Transforms kernel for each gate separately using given function.
439
440    Args:
441        kernels: Stacked array of kernels for individual gates.
442        func: Function applied to kernel of each gate.
443        n_gates: Number of gates (4 for LSTM, 3 for GRU).
444
445    Returns:
446        Stacked array of transformed kernels.
447    """
448    return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)])
449
450  def transpose_input(from_cudnn):
451    """Makes a function that transforms input kernels from/to CuDNN format.
452
453    It keeps the shape, but changes between the layout (Fortran/C). Eg.:
454
455    ```
456    Keras                 CuDNN
457    [[0, 1, 2],  <--->  [[0, 2, 4],
458     [3, 4, 5]]          [1, 3, 5]]
459    ```
460
461    It can be passed to `transform_kernels()`.
462
463    Args:
464        from_cudnn: `True` if source weights are in CuDNN format, `False`
465            if they're in plain Keras format.
466
467    Returns:
468        Function that converts input kernel to the other format.
469    """
470    order = 'F' if from_cudnn else 'C'
471
472    def transform(kernel):
473      return kernel.T.reshape(kernel.shape, order=order)
474
475    return transform
476
477  target_class = layer.__class__.__name__
478
479  # convert the weights between CuDNNLSTM and LSTM
480  if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3:
481    # determine if we're loading a CuDNNLSTM layer
482    # from the number of bias weights:
483    # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
484    # if there's no bias weight in the file, skip this conversion
485    units = weights[1].shape[0]
486    bias_shape = weights[2].shape
487    n_gates = 4
488
489    if bias_shape == (2 * units * n_gates,):
490      source = 'CuDNNLSTM'
491    elif bias_shape == (units * n_gates,):
492      source = 'LSTM'
493    else:
494      raise ValueError('Invalid bias shape: ' + str(bias_shape))
495
496    def convert_lstm_weights(weights, from_cudnn=True):
497      """Converts the weights between CuDNNLSTM and LSTM.
498
499      Args:
500        weights: Original weights.
501        from_cudnn: Indicates whether original weights are from CuDNN layer.
502
503      Returns:
504        Updated weights compatible with LSTM.
505      """
506
507      # Transpose (and reshape) input and recurrent kernels
508      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
509                                  n_gates)
510      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
511      if from_cudnn:
512        # merge input and recurrent biases into a single set
513        biases = np.sum(np.split(weights[2], 2, axis=0), axis=0)
514      else:
515        # Split single set of biases evenly to two sets. The way of
516        # splitting doesn't matter as long as the two sets sum is kept.
517        biases = np.tile(0.5 * weights[2], 2)
518      return [kernels, recurrent_kernels, biases]
519
520    if source != target_class:
521      weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM')
522
523  # convert the weights between CuDNNGRU and GRU(reset_after=True)
  if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3:
    # We can determine the source of the weights from the shape of the bias.
    # If there is no bias, we skip the conversion, since
    # CuDNNGRU always has biases.

    units = weights[1].shape[0]
    bias_shape = weights[2].shape
    n_gates = 3

    def convert_gru_weights(weights, from_cudnn=True):
      """Converts the weights between CuDNNGRU and GRU.

      Args:
        weights: Original weights.
        from_cudnn: Indicates whether original weights are from a CuDNN layer.

      Returns:
        Updated weights compatible with GRU.
      """

      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
                                  n_gates)
      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
      biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1)
      return [kernels, recurrent_kernels, biases]

    if bias_shape == (2 * units * n_gates,):
      source = 'CuDNNGRU'
    elif bias_shape == (2, units * n_gates):
      source = 'GRU(reset_after=True)'
    elif bias_shape == (units * n_gates,):
      source = 'GRU(reset_after=False)'
    else:
      raise ValueError('Invalid bias shape: ' + str(bias_shape))

    if target_class == 'CuDNNGRU':
      target = 'CuDNNGRU'
    elif layer.reset_after:
      target = 'GRU(reset_after=True)'
    else:
      target = 'GRU(reset_after=False)'

    # Only convert between different types.
    if source != target:
      types = (source, target)
      if 'GRU(reset_after=False)' in types:
        raise ValueError('%s is not compatible with %s' % types)
      if source == 'CuDNNGRU':
        weights = convert_gru_weights(weights, from_cudnn=True)
      elif source == 'GRU(reset_after=True)':
        weights = convert_gru_weights(weights, from_cudnn=False)

  return weights


def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
580  """Saves optimizer weights of a optimizer to a HDF5 group.
581
582  Args:
583      hdf5_group: HDF5 group.
584      optimizer: optimizer instance.
585  """
586
587  symbolic_weights = getattr(optimizer, 'weights')
588  if symbolic_weights:
589    weights_group = hdf5_group.create_group('optimizer_weights')
590    weight_names = [str(w.name).encode('utf8') for w in symbolic_weights]
591    save_attributes_to_hdf5_group(weights_group, 'weight_names', weight_names)
592    weight_values = backend.batch_get_value(symbolic_weights)
593    for name, val in zip(weight_names, weight_values):
594      param_dset = weights_group.create_dataset(
595          name, val.shape, dtype=val.dtype)
596      if not val.shape:
597        # scalar
598        param_dset[()] = val
599      else:
600        param_dset[:] = val
601
602
603def load_optimizer_weights_from_hdf5_group(hdf5_group):
604  """Load optimizer weights from a HDF5 group.
605
606  Args:
607      hdf5_group: A pointer to a HDF5 group.
608
609  Returns:
610      data: List of optimizer weight names.
611  """
612  weights_group = hdf5_group['optimizer_weights']
613  optimizer_weight_names = load_attributes_from_hdf5_group(
614      weights_group, 'weight_names')
615  return [weights_group[weight_name] for weight_name in optimizer_weight_names]
616
617
618def save_weights_to_hdf5_group(f, layers):
619  """Saves the weights of a list of layers to a HDF5 group.
620
621  Args:
622      f: HDF5 group.
623      layers: List of layer instances.
624  """
625  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
626
627  save_attributes_to_hdf5_group(
628      f, 'layer_names', [layer.name.encode('utf8') for layer in layers])
629  f.attrs['backend'] = backend.backend().encode('utf8')
630  f.attrs['keras_version'] = str(keras_version).encode('utf8')
631
632  # Sort model layers by layer name to ensure that group names are strictly
633  # growing to avoid prefix issues.
634  for layer in sorted(layers, key=lambda x: x.name):
635    g = f.create_group(layer.name)
636    weights = _legacy_weights(layer)
637    weight_values = backend.batch_get_value(weights)
638    weight_names = [w.name.encode('utf8') for w in weights]
639    save_attributes_to_hdf5_group(g, 'weight_names', weight_names)
640    for name, val in zip(weight_names, weight_values):
641      param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
642      if not val.shape:
643        # scalar
644        param_dset[()] = val
645      else:
646        param_dset[:] = val
647
648
649def load_weights_from_hdf5_group(f, layers):
650  """Implements topological (order-based) weight loading.
651
652  Args:
653      f: A pointer to a HDF5 group.
654      layers: a list of target layers.
655
656  Raises:
657      ValueError: in case of mismatch between provided layers
658          and weights file.
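
  Example:

  A minimal sketch (assumes `model` is an existing Keras model that was
  saved with `save_model_to_hdf5`, so its weights live under the
  'model_weights' group; the path is illustrative):

  ```python
  import h5py

  with h5py.File('/tmp/model.h5', 'r') as f:
    load_weights_from_hdf5_group(f['model_weights'], model.layers)
  ```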
659  """
660  if 'keras_version' in f.attrs:
661    original_keras_version = f.attrs['keras_version']
662    if hasattr(original_keras_version, 'decode'):
663      original_keras_version = original_keras_version.decode('utf8')
664  else:
665    original_keras_version = '1'
666  if 'backend' in f.attrs:
667    original_backend = f.attrs['backend']
668    if hasattr(original_backend, 'decode'):
669      original_backend = original_backend.decode('utf8')
670  else:
671    original_backend = None
672
673  filtered_layers = []
674  for layer in layers:
675    weights = _legacy_weights(layer)
676    if weights:
677      filtered_layers.append(layer)
678
679  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
680  filtered_layer_names = []
681  for name in layer_names:
682    g = f[name]
683    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
684    if weight_names:
685      filtered_layer_names.append(name)
686  layer_names = filtered_layer_names
687  if len(layer_names) != len(filtered_layers):
688    raise ValueError('You are trying to load a weight file '
689                     'containing ' + str(len(layer_names)) +
690                     ' layers into a model with ' + str(len(filtered_layers)) +
691                     ' layers.')
692
693  # We batch weight value assignments in a single backend call
694  # which provides a speedup in TensorFlow.
695  weight_value_tuples = []
696  for k, name in enumerate(layer_names):
697    g = f[name]
698    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
699    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
700    layer = filtered_layers[k]
701    symbolic_weights = _legacy_weights(layer)
702    weight_values = preprocess_weights_for_loading(
703        layer, weight_values, original_keras_version, original_backend)
704    if len(weight_values) != len(symbolic_weights):
705      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
706                       '" in the current model) was found to '
707                       'correspond to layer ' + name + ' in the save file. '
708                       'However the new layer ' + layer.name + ' expects ' +
709                       str(len(symbolic_weights)) +
710                       ' weights, but the saved weights have ' +
711                       str(len(weight_values)) + ' elements.')
712    weight_value_tuples += zip(symbolic_weights, weight_values)
713  backend.batch_set_value(weight_value_tuples)
714
715
716def load_weights_from_hdf5_group_by_name(
717    f, layers, skip_mismatch=False):
718  """Implements name-based weight loading.
719
720  (instead of topological weight loading).
721
722  Layers that have no matching name are skipped.
723
724  Args:
725      f: A pointer to a HDF5 group.
726      layers: a list of target layers.
727      skip_mismatch: Boolean, whether to skip loading of layers
728          where there is a mismatch in the number of weights,
729          or a mismatch in the shape of the weights.
730
731  Raises:
732      ValueError: in case of mismatch between provided layers
733          and weights file and skip_match=False.
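
  Example:

  A minimal sketch (assumes `model` is an existing Keras model and a
  weights-only HDF5 file whose layer groups sit at the file root, as written
  by `save_weights_to_hdf5_group`; the path is illustrative):

  ```python
  import h5py

  with h5py.File('/tmp/weights.h5', 'r') as f:
    load_weights_from_hdf5_group_by_name(f, model.layers, skip_mismatch=True)
  ```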
734  """
735  if 'keras_version' in f.attrs:
736    original_keras_version = f.attrs['keras_version']
737    if hasattr(original_keras_version, 'decode'):
738      original_keras_version = original_keras_version.decode('utf8')
739  else:
740    original_keras_version = '1'
741  if 'backend' in f.attrs:
742    original_backend = f.attrs['backend']
743    if hasattr(original_backend, 'decode'):
744      original_backend = original_backend.decode('utf8')
745  else:
746    original_backend = None
747
748  # New file format.
749  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
750
751  # Reverse index of layer name to list of layers with name.
752  index = {}
753  for layer in layers:
754    if layer.name:
755      index.setdefault(layer.name, []).append(layer)
756
757  # We batch weight value assignments in a single backend call
758  # which provides a speedup in TensorFlow.
759  weight_value_tuples = []
760  for k, name in enumerate(layer_names):
761    g = f[name]
762    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
763    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
764
765    for layer in index.get(name, []):
766      symbolic_weights = _legacy_weights(layer)
767      weight_values = preprocess_weights_for_loading(
768          layer, weight_values, original_keras_version, original_backend)
769      if len(weight_values) != len(symbolic_weights):
770        if skip_mismatch:
771          logging.warning('Skipping loading of weights for '
772                          'layer {}'.format(layer.name) + ' due to mismatch '
773                          'in number of weights ({} vs {}).'.format(
774                              len(symbolic_weights), len(weight_values)))
775          continue
776        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
777                         '") expects ' + str(len(symbolic_weights)) +
778                         ' weight(s), but the saved weights' + ' have ' +
779                         str(len(weight_values)) + ' element(s).')
780      # Set values.
781      for i in range(len(weight_values)):
782        if backend.int_shape(symbolic_weights[i]) != weight_values[i].shape:
783          if skip_mismatch:
784            logging.warning('Skipping loading of weights for '
785                            'layer {}'.format(layer.name) + ' due to '
786                            'mismatch in shape ({} vs {}).'.format(
787                                symbolic_weights[i].shape,
788                                weight_values[i].shape))
789            continue
790          raise ValueError('Layer #' + str(k) +' (named "' + layer.name +
791                           '"), weight ' + str(symbolic_weights[i]) +
792                           ' has shape {}'.format(backend.int_shape(
793                               symbolic_weights[i])) +
794                           ', but the saved weight has shape ' +
795                           str(weight_values[i].shape) + '.')
796
797        else:
798          weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
799  backend.batch_set_value(weight_value_tuples)
800
801
802def save_attributes_to_hdf5_group(group, name, data):
803  """Saves attributes (data) of the specified name into the HDF5 group.
804
805  This method deals with an inherent problem of HDF5 file which is not
806  able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
807
808  Args:
809      group: A pointer to a HDF5 group.
810      name: A name of the attributes to save.
811      data: Attributes data to store.
812
813  Raises:
814    RuntimeError: If any single attribute is too large to be saved.
815  """
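
  Example:

  A minimal round-trip sketch (the file path and attribute data are
  illustrative):

  ```python
  import h5py

  with h5py.File('/tmp/attrs.h5', 'w') as f:
    g = f.create_group('model_weights')
    save_attributes_to_hdf5_group(g, 'layer_names', [b'dense_1'])
    # Small data is stored as a single attribute; data larger than
    # HDF5_OBJECT_HEADER_LIMIT is split into numbered chunks named
    # 'layer_names0', 'layer_names1', ...
    assert load_attributes_from_hdf5_group(g, 'layer_names') == ['dense_1']
  ```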
816  # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
817  # because in that case even chunking the array would not make the saving
818  # possible.
819  bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
820
821  # Expecting this to never be true.
822  if bad_attributes:
823    raise RuntimeError('The following attributes cannot be saved to HDF5 '
824                       'file because they are larger than %d bytes: %s' %
825                       (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes)))
826
827  data_npy = np.asarray(data)
828
829  num_chunks = 1
830  chunked_data = np.array_split(data_npy, num_chunks)
831
832  # This will never loop forever thanks to the test above.
833  while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
834    num_chunks += 1
835    chunked_data = np.array_split(data_npy, num_chunks)
836
837  if num_chunks > 1:
838    for chunk_id, chunk_data in enumerate(chunked_data):
839      group.attrs['%s%d' % (name, chunk_id)] = chunk_data
840  else:
841    group.attrs[name] = data
842
843
844def load_attributes_from_hdf5_group(group, name):
845  """Loads attributes of the specified name from the HDF5 group.
846
847  This method deals with an inherent problem
848  of HDF5 file which is not able to store
849  data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
850
851  Args:
852      group: A pointer to a HDF5 group.
853      name: A name of the attributes to load.
854
855  Returns:
856      data: Attributes data.
857  """
  if name in group.attrs:
    data = [
        n.decode('utf8') if hasattr(n, 'decode') else n
        for n in group.attrs[name]
    ]
  else:
    data = []
    chunk_id = 0
    while '%s%d' % (name, chunk_id) in group.attrs:
      data.extend([
          n.decode('utf8') if hasattr(n, 'decode') else n
          for n in group.attrs['%s%d' % (name, chunk_id)]
      ])
      chunk_id += 1
  return data


def _legacy_weights(layer):
876  """DO NOT USE.
877
878  For legacy reason, the layer.weights was in the order of
879  [self.trainable_weights + self.non_trainable_weights], and this order was
880  used for preserving the weights in h5 format. The new order of layer.weights
881  are the same as layer.get_weights() which is more intuitive for user. To
882  keep supporting the existing saved h5 file, this method should be used to
883  save/load weights. In future version, we will delete this method and
884  introduce a breaking change for h5 and stay with the new order for weights.
885
886  Args:
887    layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance.
888
889  Returns:
890    A list of variables with the order of trainable_weights, followed by
891      non_trainable_weights.
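
  Example:

  An illustrative sketch (BatchNormalization has both trainable and
  non-trainable weights):

  ```python
  import tensorflow as tf

  layer = tf.keras.layers.BatchNormalization()
  layer.build((None, 4))
  # Trainable weights (gamma, beta) come first, then non-trainable ones
  # (moving_mean, moving_variance).
  print([w.name for w in _legacy_weights(layer)])
  ```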
892  """
893  weights = layer.trainable_weights + layer.non_trainable_weights
894  if any(not isinstance(w, variables_module.Variable) for w in weights):
895    raise NotImplementedError(
896        'Save or restore weights that is not an instance of `tf.Variable` is '
897        'not supported in h5, use `save_format=\'tf\'` instead. Got a model '
898        'or layer {} with weights {}'.format(layer.__class__.__name__, weights))
899  return weights
900