1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Functions for saving and loading a Keras Model from HDF5 format.""" 17 18import json 19import os 20 21import numpy as np 22 23from tensorflow.python.keras import backend 24from tensorflow.python.keras import optimizer_v1 25from tensorflow.python.keras.saving import model_config as model_config_lib 26from tensorflow.python.keras.saving import saving_utils 27from tensorflow.python.keras.saving.saved_model import json_utils 28from tensorflow.python.keras.utils.generic_utils import LazyLoader 29from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 30from tensorflow.python.ops import variables as variables_module 31from tensorflow.python.platform import gfile 32from tensorflow.python.platform import tf_logging as logging 33 34 35# pylint: disable=g-import-not-at-top 36try: 37 import h5py 38 HDF5_OBJECT_HEADER_LIMIT = 64512 39except ImportError: 40 h5py = None 41# pylint: enable=g-import-not-at-top 42 43# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 44# once the issue with copybara is fixed. 45# pylint:disable=g-inconsistent-quotes 46sequential_lib = LazyLoader( 47 "sequential_lib", globals(), 48 "tensorflow.python.keras.engine.sequential") 49# pylint:enable=g-inconsistent-quotes 50 51 52def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True): 53 """Saves a model to a HDF5 file. 54 55 The saved model contains: 56 - the model's configuration (topology) 57 - the model's weights 58 - the model's optimizer's state (if any) 59 60 Thus the saved model can be reinstantiated in 61 the exact same state, without any of the code 62 used for model definition or training. 63 64 Args: 65 model: Keras model instance to be saved. 66 filepath: One of the following: 67 - String, path where to save the model 68 - `h5py.File` object where to save the model 69 overwrite: Whether we should overwrite any existing 70 model at the target location, or instead 71 ask the user with a manual prompt. 72 include_optimizer: If True, save optimizer's state together. 73 74 Raises: 75 ImportError: if h5py is not available. 76 """ 77 78 if h5py is None: 79 raise ImportError('`save_model` requires h5py.') 80 81 # TODO(psv) Add warning when we save models that contain non-serializable 82 # entities like metrics added using `add_metric` and losses added using 83 # `add_loss.` 84 if len(model.weights) != len(model._undeduplicated_weights): 85 logging.warning('Found duplicated `Variable`s in Model\'s `weights`. ' 86 'This is usually caused by `Variable`s being shared by ' 87 'Layers in the Model. These `Variable`s will be treated ' 88 'as separate `Variable`s when the Model is restored. To ' 89 'avoid this, please save with `save_format="tf"`.') 90 91 if not isinstance(filepath, h5py.File): 92 # If file exists and should not be overwritten. 93 if not overwrite and os.path.isfile(filepath): 94 proceed = ask_to_proceed_with_overwrite(filepath) 95 if not proceed: 96 return 97 98 # Try creating dir if not exist 99 dirpath = os.path.dirname(filepath) 100 if not os.path.exists(dirpath): 101 gfile.MakeDirs(dirpath) 102 103 f = h5py.File(filepath, mode='w') 104 opened_new_file = True 105 else: 106 f = filepath 107 opened_new_file = False 108 109 try: 110 model_metadata = saving_utils.model_metadata(model, include_optimizer) 111 for k, v in model_metadata.items(): 112 if isinstance(v, (dict, list, tuple)): 113 f.attrs[k] = json.dumps( 114 v, default=json_utils.get_json_type).encode('utf8') 115 else: 116 f.attrs[k] = v 117 118 model_weights_group = f.create_group('model_weights') 119 model_layers = model.layers 120 save_weights_to_hdf5_group(model_weights_group, model_layers) 121 122 # TODO(b/128683857): Add integration tests between tf.keras and external 123 # Keras, to avoid breaking TF.js users. 124 if (include_optimizer and model.optimizer and 125 not isinstance(model.optimizer, optimizer_v1.TFOptimizer)): 126 save_optimizer_weights_to_hdf5_group(f, model.optimizer) 127 128 f.flush() 129 finally: 130 if opened_new_file: 131 f.close() 132 133 134def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin 135 """Loads a model saved via `save_model_to_hdf5`. 136 137 Args: 138 filepath: One of the following: 139 - String, path to the saved model 140 - `h5py.File` object from which to load the model 141 custom_objects: Optional dictionary mapping names 142 (strings) to custom classes or functions to be 143 considered during deserialization. 144 compile: Boolean, whether to compile the model 145 after loading. 146 147 Returns: 148 A Keras model instance. If an optimizer was found 149 as part of the saved model, the model is already 150 compiled. Otherwise, the model is uncompiled and 151 a warning will be displayed. When `compile` is set 152 to False, the compilation is omitted without any 153 warning. 154 155 Raises: 156 ImportError: if h5py is not available. 157 ValueError: In case of an invalid savefile. 158 """ 159 if h5py is None: 160 raise ImportError('`load_model` requires h5py.') 161 162 if not custom_objects: 163 custom_objects = {} 164 165 opened_new_file = not isinstance(filepath, h5py.File) 166 if opened_new_file: 167 f = h5py.File(filepath, mode='r') 168 else: 169 f = filepath 170 171 model = None 172 try: 173 # instantiate model 174 model_config = f.attrs.get('model_config') 175 if model_config is None: 176 raise ValueError('No model found in config file.') 177 if hasattr(model_config, 'decode'): 178 model_config = model_config.decode('utf-8') 179 model_config = json_utils.decode(model_config) 180 model = model_config_lib.model_from_config(model_config, 181 custom_objects=custom_objects) 182 183 # set weights 184 load_weights_from_hdf5_group(f['model_weights'], model.layers) 185 186 if compile: 187 # instantiate optimizer 188 training_config = f.attrs.get('training_config') 189 if hasattr(training_config, 'decode'): 190 training_config = training_config.decode('utf-8') 191 if training_config is None: 192 logging.warning('No training configuration found in the save file, so ' 193 'the model was *not* compiled. Compile it manually.') 194 return model 195 training_config = json_utils.decode(training_config) 196 197 # Compile model. 198 model.compile(**saving_utils.compile_args_from_training_config( 199 training_config, custom_objects), from_serialized=True) 200 saving_utils.try_build_compiled_arguments(model) 201 202 # Set optimizer weights. 203 if 'optimizer_weights' in f: 204 try: 205 model.optimizer._create_all_weights(model.trainable_variables) 206 except (NotImplementedError, AttributeError): 207 logging.warning( 208 'Error when creating the weights of optimizer {}, making it ' 209 'impossible to restore the saved optimizer state. As a result, ' 210 'your model is starting with a freshly initialized optimizer.') 211 212 optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f) 213 try: 214 model.optimizer.set_weights(optimizer_weight_values) 215 except ValueError: 216 logging.warning('Error in loading the saved optimizer ' 217 'state. As a result, your model is ' 218 'starting with a freshly initialized ' 219 'optimizer.') 220 finally: 221 if opened_new_file: 222 f.close() 223 return model 224 225 226def preprocess_weights_for_loading(layer, 227 weights, 228 original_keras_version=None, 229 original_backend=None): 230 """Preprocess layer weights between different Keras formats. 231 232 Converts layers weights from Keras 1 format to Keras 2 and also weights of 233 CuDNN layers in Keras 2. 234 235 Args: 236 layer: Layer instance. 237 weights: List of weights values (Numpy arrays). 238 original_keras_version: Keras version for the weights, as a string. 239 original_backend: Keras backend the weights were trained with, 240 as a string. 241 242 Returns: 243 A list of weights values (Numpy arrays). 244 """ 245 def convert_nested_bidirectional(weights): 246 """Converts layers nested in `Bidirectional` wrapper. 247 248 This function uses `preprocess_weights_for_loading()` for converting 249 layers. 250 251 Args: 252 weights: List of weights values (Numpy arrays). 253 254 Returns: 255 A list of weights values (Numpy arrays). 256 """ 257 num_weights_per_layer = len(weights) // 2 258 forward_weights = preprocess_weights_for_loading( 259 layer.forward_layer, weights[:num_weights_per_layer], 260 original_keras_version, original_backend) 261 backward_weights = preprocess_weights_for_loading( 262 layer.backward_layer, weights[num_weights_per_layer:], 263 original_keras_version, original_backend) 264 return forward_weights + backward_weights 265 266 def convert_nested_time_distributed(weights): 267 """Converts layers nested in `TimeDistributed` wrapper. 268 269 This function uses `preprocess_weights_for_loading()` for converting nested 270 layers. 271 272 Args: 273 weights: List of weights values (Numpy arrays). 274 275 Returns: 276 A list of weights values (Numpy arrays). 277 """ 278 return preprocess_weights_for_loading( 279 layer.layer, weights, original_keras_version, original_backend) 280 281 def convert_nested_model(weights): 282 """Converts layers nested in `Model` or `Sequential`. 283 284 This function uses `preprocess_weights_for_loading()` for converting nested 285 layers. 286 287 Args: 288 weights: List of weights values (Numpy arrays). 289 290 Returns: 291 A list of weights values (Numpy arrays). 292 """ 293 trainable_weights = weights[:len(layer.trainable_weights)] 294 non_trainable_weights = weights[len(layer.trainable_weights):] 295 296 new_trainable_weights = [] 297 new_non_trainable_weights = [] 298 299 for sublayer in layer.layers: 300 num_trainable_weights = len(sublayer.trainable_weights) 301 num_non_trainable_weights = len(sublayer.non_trainable_weights) 302 if sublayer.weights: 303 preprocessed = preprocess_weights_for_loading( 304 layer=sublayer, 305 weights=(trainable_weights[:num_trainable_weights] + 306 non_trainable_weights[:num_non_trainable_weights]), 307 original_keras_version=original_keras_version, 308 original_backend=original_backend) 309 new_trainable_weights.extend(preprocessed[:num_trainable_weights]) 310 new_non_trainable_weights.extend(preprocessed[num_trainable_weights:]) 311 312 trainable_weights = trainable_weights[num_trainable_weights:] 313 non_trainable_weights = non_trainable_weights[ 314 num_non_trainable_weights:] 315 316 return new_trainable_weights + new_non_trainable_weights 317 318 # Convert layers nested in Bidirectional/Model/Sequential. 319 # Both transformation should be ran for both Keras 1->2 conversion 320 # and for conversion of CuDNN layers. 321 if layer.__class__.__name__ == 'Bidirectional': 322 weights = convert_nested_bidirectional(weights) 323 if layer.__class__.__name__ == 'TimeDistributed': 324 weights = convert_nested_time_distributed(weights) 325 elif layer.__class__.__name__ in ['Model', 'Sequential', 'Functional']: 326 weights = convert_nested_model(weights) 327 328 if original_keras_version == '1': 329 if layer.__class__.__name__ == 'TimeDistributed': 330 weights = preprocess_weights_for_loading( 331 layer.layer, weights, original_keras_version, original_backend) 332 333 if layer.__class__.__name__ == 'Conv1D': 334 shape = weights[0].shape 335 # Handle Keras 1.1 format 336 if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: 337 # Legacy shape: 338 # (filters, input_dim, filter_length, 1) 339 assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], 340 1) 341 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 342 weights[0] = weights[0][:, 0, :, :] 343 344 if layer.__class__.__name__ == 'Conv2D': 345 if layer.data_format == 'channels_first': 346 # old: (filters, stack_size, kernel_rows, kernel_cols) 347 # new: (kernel_rows, kernel_cols, stack_size, filters) 348 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 349 350 if layer.__class__.__name__ == 'Conv2DTranspose': 351 if layer.data_format == 'channels_last': 352 # old: (kernel_rows, kernel_cols, stack_size, filters) 353 # new: (kernel_rows, kernel_cols, filters, stack_size) 354 weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) 355 if layer.data_format == 'channels_first': 356 # old: (filters, stack_size, kernel_rows, kernel_cols) 357 # new: (kernel_rows, kernel_cols, filters, stack_size) 358 weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) 359 360 if layer.__class__.__name__ == 'Conv3D': 361 if layer.data_format == 'channels_first': 362 # old: (filters, stack_size, ...) 363 # new: (..., stack_size, filters) 364 weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) 365 366 if layer.__class__.__name__ == 'GRU': 367 if len(weights) == 9: 368 kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) 369 recurrent_kernel = np.concatenate( 370 [weights[1], weights[4], weights[7]], axis=-1) 371 bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) 372 weights = [kernel, recurrent_kernel, bias] 373 374 if layer.__class__.__name__ == 'LSTM': 375 if len(weights) == 12: 376 # old: i, c, f, o 377 # new: i, f, c, o 378 kernel = np.concatenate( 379 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 380 recurrent_kernel = np.concatenate( 381 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 382 bias = np.concatenate( 383 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 384 weights = [kernel, recurrent_kernel, bias] 385 386 if layer.__class__.__name__ == 'ConvLSTM2D': 387 if len(weights) == 12: 388 kernel = np.concatenate( 389 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 390 recurrent_kernel = np.concatenate( 391 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 392 bias = np.concatenate( 393 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 394 if layer.data_format == 'channels_first': 395 # old: (filters, stack_size, kernel_rows, kernel_cols) 396 # new: (kernel_rows, kernel_cols, stack_size, filters) 397 kernel = np.transpose(kernel, (2, 3, 1, 0)) 398 recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) 399 weights = [kernel, recurrent_kernel, bias] 400 401 conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] 402 if layer.__class__.__name__ in conv_layers: 403 if backend.int_shape(layer.weights[0]) != weights[0].shape: 404 weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) 405 if layer.__class__.__name__ == 'ConvLSTM2D': 406 weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) 407 408 # convert CuDNN layers 409 return _convert_rnn_weights(layer, weights) 410 411 412def _convert_rnn_weights(layer, weights): 413 """Converts weights for RNN layers between native and CuDNN format. 414 415 Input kernels for each gate are transposed and converted between Fortran 416 and C layout, recurrent kernels are transposed. For LSTM biases are summed/ 417 split in half, for GRU biases are reshaped. 418 419 Weights can be converted in both directions between `LSTM` and`CuDNNSLTM` 420 and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not 421 compatible with `CuDNNGRU`. 422 423 For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made. 424 425 Args: 426 layer: Target layer instance. 427 weights: List of source weights values (input kernels, recurrent 428 kernels, [biases]) (Numpy arrays). 429 430 Returns: 431 A list of converted weights values (Numpy arrays). 432 433 Raises: 434 ValueError: for incompatible GRU layer/weights or incompatible biases 435 """ 436 437 def transform_kernels(kernels, func, n_gates): 438 """Transforms kernel for each gate separately using given function. 439 440 Args: 441 kernels: Stacked array of kernels for individual gates. 442 func: Function applied to kernel of each gate. 443 n_gates: Number of gates (4 for LSTM, 3 for GRU). 444 445 Returns: 446 Stacked array of transformed kernels. 447 """ 448 return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)]) 449 450 def transpose_input(from_cudnn): 451 """Makes a function that transforms input kernels from/to CuDNN format. 452 453 It keeps the shape, but changes between the layout (Fortran/C). Eg.: 454 455 ``` 456 Keras CuDNN 457 [[0, 1, 2], <---> [[0, 2, 4], 458 [3, 4, 5]] [1, 3, 5]] 459 ``` 460 461 It can be passed to `transform_kernels()`. 462 463 Args: 464 from_cudnn: `True` if source weights are in CuDNN format, `False` 465 if they're in plain Keras format. 466 467 Returns: 468 Function that converts input kernel to the other format. 469 """ 470 order = 'F' if from_cudnn else 'C' 471 472 def transform(kernel): 473 return kernel.T.reshape(kernel.shape, order=order) 474 475 return transform 476 477 target_class = layer.__class__.__name__ 478 479 # convert the weights between CuDNNLSTM and LSTM 480 if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3: 481 # determine if we're loading a CuDNNLSTM layer 482 # from the number of bias weights: 483 # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) 484 # if there's no bias weight in the file, skip this conversion 485 units = weights[1].shape[0] 486 bias_shape = weights[2].shape 487 n_gates = 4 488 489 if bias_shape == (2 * units * n_gates,): 490 source = 'CuDNNLSTM' 491 elif bias_shape == (units * n_gates,): 492 source = 'LSTM' 493 else: 494 raise ValueError('Invalid bias shape: ' + str(bias_shape)) 495 496 def convert_lstm_weights(weights, from_cudnn=True): 497 """Converts the weights between CuDNNLSTM and LSTM. 498 499 Args: 500 weights: Original weights. 501 from_cudnn: Indicates whether original weights are from CuDNN layer. 502 503 Returns: 504 Updated weights compatible with LSTM. 505 """ 506 507 # Transpose (and reshape) input and recurrent kernels 508 kernels = transform_kernels(weights[0], transpose_input(from_cudnn), 509 n_gates) 510 recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) 511 if from_cudnn: 512 # merge input and recurrent biases into a single set 513 biases = np.sum(np.split(weights[2], 2, axis=0), axis=0) 514 else: 515 # Split single set of biases evenly to two sets. The way of 516 # splitting doesn't matter as long as the two sets sum is kept. 517 biases = np.tile(0.5 * weights[2], 2) 518 return [kernels, recurrent_kernels, biases] 519 520 if source != target_class: 521 weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM') 522 523 # convert the weights between CuDNNGRU and GRU(reset_after=True) 524 if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3: 525 # We can determine the source of the weights from the shape of the bias. 526 # If there is no bias we skip the conversion since 527 # CuDNNGRU always has biases. 528 529 units = weights[1].shape[0] 530 bias_shape = weights[2].shape 531 n_gates = 3 532 533 def convert_gru_weights(weights, from_cudnn=True): 534 """Converts the weights between CuDNNGRU and GRU. 535 536 Args: 537 weights: Original weights. 538 from_cudnn: Indicates whether original weights are from CuDNN layer. 539 540 Returns: 541 Updated weights compatible with GRU. 542 """ 543 544 kernels = transform_kernels(weights[0], transpose_input(from_cudnn), 545 n_gates) 546 recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) 547 biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1) 548 return [kernels, recurrent_kernels, biases] 549 550 if bias_shape == (2 * units * n_gates,): 551 source = 'CuDNNGRU' 552 elif bias_shape == (2, units * n_gates): 553 source = 'GRU(reset_after=True)' 554 elif bias_shape == (units * n_gates,): 555 source = 'GRU(reset_after=False)' 556 else: 557 raise ValueError('Invalid bias shape: ' + str(bias_shape)) 558 559 if target_class == 'CuDNNGRU': 560 target = 'CuDNNGRU' 561 elif layer.reset_after: 562 target = 'GRU(reset_after=True)' 563 else: 564 target = 'GRU(reset_after=False)' 565 566 # only convert between different types 567 if source != target: 568 types = (source, target) 569 if 'GRU(reset_after=False)' in types: 570 raise ValueError('%s is not compatible with %s' % types) 571 if source == 'CuDNNGRU': 572 weights = convert_gru_weights(weights, from_cudnn=True) 573 elif source == 'GRU(reset_after=True)': 574 weights = convert_gru_weights(weights, from_cudnn=False) 575 576 return weights 577 578 579def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer): 580 """Saves optimizer weights of a optimizer to a HDF5 group. 581 582 Args: 583 hdf5_group: HDF5 group. 584 optimizer: optimizer instance. 585 """ 586 587 symbolic_weights = getattr(optimizer, 'weights') 588 if symbolic_weights: 589 weights_group = hdf5_group.create_group('optimizer_weights') 590 weight_names = [str(w.name).encode('utf8') for w in symbolic_weights] 591 save_attributes_to_hdf5_group(weights_group, 'weight_names', weight_names) 592 weight_values = backend.batch_get_value(symbolic_weights) 593 for name, val in zip(weight_names, weight_values): 594 param_dset = weights_group.create_dataset( 595 name, val.shape, dtype=val.dtype) 596 if not val.shape: 597 # scalar 598 param_dset[()] = val 599 else: 600 param_dset[:] = val 601 602 603def load_optimizer_weights_from_hdf5_group(hdf5_group): 604 """Load optimizer weights from a HDF5 group. 605 606 Args: 607 hdf5_group: A pointer to a HDF5 group. 608 609 Returns: 610 data: List of optimizer weight names. 611 """ 612 weights_group = hdf5_group['optimizer_weights'] 613 optimizer_weight_names = load_attributes_from_hdf5_group( 614 weights_group, 'weight_names') 615 return [weights_group[weight_name] for weight_name in optimizer_weight_names] 616 617 618def save_weights_to_hdf5_group(f, layers): 619 """Saves the weights of a list of layers to a HDF5 group. 620 621 Args: 622 f: HDF5 group. 623 layers: List of layer instances. 624 """ 625 from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 626 627 save_attributes_to_hdf5_group( 628 f, 'layer_names', [layer.name.encode('utf8') for layer in layers]) 629 f.attrs['backend'] = backend.backend().encode('utf8') 630 f.attrs['keras_version'] = str(keras_version).encode('utf8') 631 632 # Sort model layers by layer name to ensure that group names are strictly 633 # growing to avoid prefix issues. 634 for layer in sorted(layers, key=lambda x: x.name): 635 g = f.create_group(layer.name) 636 weights = _legacy_weights(layer) 637 weight_values = backend.batch_get_value(weights) 638 weight_names = [w.name.encode('utf8') for w in weights] 639 save_attributes_to_hdf5_group(g, 'weight_names', weight_names) 640 for name, val in zip(weight_names, weight_values): 641 param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) 642 if not val.shape: 643 # scalar 644 param_dset[()] = val 645 else: 646 param_dset[:] = val 647 648 649def load_weights_from_hdf5_group(f, layers): 650 """Implements topological (order-based) weight loading. 651 652 Args: 653 f: A pointer to a HDF5 group. 654 layers: a list of target layers. 655 656 Raises: 657 ValueError: in case of mismatch between provided layers 658 and weights file. 659 """ 660 if 'keras_version' in f.attrs: 661 original_keras_version = f.attrs['keras_version'] 662 if hasattr(original_keras_version, 'decode'): 663 original_keras_version = original_keras_version.decode('utf8') 664 else: 665 original_keras_version = '1' 666 if 'backend' in f.attrs: 667 original_backend = f.attrs['backend'] 668 if hasattr(original_backend, 'decode'): 669 original_backend = original_backend.decode('utf8') 670 else: 671 original_backend = None 672 673 filtered_layers = [] 674 for layer in layers: 675 weights = _legacy_weights(layer) 676 if weights: 677 filtered_layers.append(layer) 678 679 layer_names = load_attributes_from_hdf5_group(f, 'layer_names') 680 filtered_layer_names = [] 681 for name in layer_names: 682 g = f[name] 683 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 684 if weight_names: 685 filtered_layer_names.append(name) 686 layer_names = filtered_layer_names 687 if len(layer_names) != len(filtered_layers): 688 raise ValueError('You are trying to load a weight file ' 689 'containing ' + str(len(layer_names)) + 690 ' layers into a model with ' + str(len(filtered_layers)) + 691 ' layers.') 692 693 # We batch weight value assignments in a single backend call 694 # which provides a speedup in TensorFlow. 695 weight_value_tuples = [] 696 for k, name in enumerate(layer_names): 697 g = f[name] 698 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 699 weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names] 700 layer = filtered_layers[k] 701 symbolic_weights = _legacy_weights(layer) 702 weight_values = preprocess_weights_for_loading( 703 layer, weight_values, original_keras_version, original_backend) 704 if len(weight_values) != len(symbolic_weights): 705 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 706 '" in the current model) was found to ' 707 'correspond to layer ' + name + ' in the save file. ' 708 'However the new layer ' + layer.name + ' expects ' + 709 str(len(symbolic_weights)) + 710 ' weights, but the saved weights have ' + 711 str(len(weight_values)) + ' elements.') 712 weight_value_tuples += zip(symbolic_weights, weight_values) 713 backend.batch_set_value(weight_value_tuples) 714 715 716def load_weights_from_hdf5_group_by_name( 717 f, layers, skip_mismatch=False): 718 """Implements name-based weight loading. 719 720 (instead of topological weight loading). 721 722 Layers that have no matching name are skipped. 723 724 Args: 725 f: A pointer to a HDF5 group. 726 layers: a list of target layers. 727 skip_mismatch: Boolean, whether to skip loading of layers 728 where there is a mismatch in the number of weights, 729 or a mismatch in the shape of the weights. 730 731 Raises: 732 ValueError: in case of mismatch between provided layers 733 and weights file and skip_match=False. 734 """ 735 if 'keras_version' in f.attrs: 736 original_keras_version = f.attrs['keras_version'] 737 if hasattr(original_keras_version, 'decode'): 738 original_keras_version = original_keras_version.decode('utf8') 739 else: 740 original_keras_version = '1' 741 if 'backend' in f.attrs: 742 original_backend = f.attrs['backend'] 743 if hasattr(original_backend, 'decode'): 744 original_backend = original_backend.decode('utf8') 745 else: 746 original_backend = None 747 748 # New file format. 749 layer_names = load_attributes_from_hdf5_group(f, 'layer_names') 750 751 # Reverse index of layer name to list of layers with name. 752 index = {} 753 for layer in layers: 754 if layer.name: 755 index.setdefault(layer.name, []).append(layer) 756 757 # We batch weight value assignments in a single backend call 758 # which provides a speedup in TensorFlow. 759 weight_value_tuples = [] 760 for k, name in enumerate(layer_names): 761 g = f[name] 762 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 763 weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names] 764 765 for layer in index.get(name, []): 766 symbolic_weights = _legacy_weights(layer) 767 weight_values = preprocess_weights_for_loading( 768 layer, weight_values, original_keras_version, original_backend) 769 if len(weight_values) != len(symbolic_weights): 770 if skip_mismatch: 771 logging.warning('Skipping loading of weights for ' 772 'layer {}'.format(layer.name) + ' due to mismatch ' 773 'in number of weights ({} vs {}).'.format( 774 len(symbolic_weights), len(weight_values))) 775 continue 776 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 777 '") expects ' + str(len(symbolic_weights)) + 778 ' weight(s), but the saved weights' + ' have ' + 779 str(len(weight_values)) + ' element(s).') 780 # Set values. 781 for i in range(len(weight_values)): 782 if backend.int_shape(symbolic_weights[i]) != weight_values[i].shape: 783 if skip_mismatch: 784 logging.warning('Skipping loading of weights for ' 785 'layer {}'.format(layer.name) + ' due to ' 786 'mismatch in shape ({} vs {}).'.format( 787 symbolic_weights[i].shape, 788 weight_values[i].shape)) 789 continue 790 raise ValueError('Layer #' + str(k) +' (named "' + layer.name + 791 '"), weight ' + str(symbolic_weights[i]) + 792 ' has shape {}'.format(backend.int_shape( 793 symbolic_weights[i])) + 794 ', but the saved weight has shape ' + 795 str(weight_values[i].shape) + '.') 796 797 else: 798 weight_value_tuples.append((symbolic_weights[i], weight_values[i])) 799 backend.batch_set_value(weight_value_tuples) 800 801 802def save_attributes_to_hdf5_group(group, name, data): 803 """Saves attributes (data) of the specified name into the HDF5 group. 804 805 This method deals with an inherent problem of HDF5 file which is not 806 able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. 807 808 Args: 809 group: A pointer to a HDF5 group. 810 name: A name of the attributes to save. 811 data: Attributes data to store. 812 813 Raises: 814 RuntimeError: If any single attribute is too large to be saved. 815 """ 816 # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` 817 # because in that case even chunking the array would not make the saving 818 # possible. 819 bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] 820 821 # Expecting this to never be true. 822 if bad_attributes: 823 raise RuntimeError('The following attributes cannot be saved to HDF5 ' 824 'file because they are larger than %d bytes: %s' % 825 (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes))) 826 827 data_npy = np.asarray(data) 828 829 num_chunks = 1 830 chunked_data = np.array_split(data_npy, num_chunks) 831 832 # This will never loop forever thanks to the test above. 833 while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): 834 num_chunks += 1 835 chunked_data = np.array_split(data_npy, num_chunks) 836 837 if num_chunks > 1: 838 for chunk_id, chunk_data in enumerate(chunked_data): 839 group.attrs['%s%d' % (name, chunk_id)] = chunk_data 840 else: 841 group.attrs[name] = data 842 843 844def load_attributes_from_hdf5_group(group, name): 845 """Loads attributes of the specified name from the HDF5 group. 846 847 This method deals with an inherent problem 848 of HDF5 file which is not able to store 849 data larger than HDF5_OBJECT_HEADER_LIMIT bytes. 850 851 Args: 852 group: A pointer to a HDF5 group. 853 name: A name of the attributes to load. 854 855 Returns: 856 data: Attributes data. 857 """ 858 if name in group.attrs: 859 data = [ 860 n.decode('utf8') if hasattr(n, 'decode') else n 861 for n in group.attrs[name] 862 ] 863 else: 864 data = [] 865 chunk_id = 0 866 while '%s%d' % (name, chunk_id) in group.attrs: 867 data.extend([ 868 n.decode('utf8') if hasattr(n, 'decode') else n 869 for n in group.attrs['%s%d' % (name, chunk_id)] 870 ]) 871 chunk_id += 1 872 return data 873 874 875def _legacy_weights(layer): 876 """DO NOT USE. 877 878 For legacy reason, the layer.weights was in the order of 879 [self.trainable_weights + self.non_trainable_weights], and this order was 880 used for preserving the weights in h5 format. The new order of layer.weights 881 are the same as layer.get_weights() which is more intuitive for user. To 882 keep supporting the existing saved h5 file, this method should be used to 883 save/load weights. In future version, we will delete this method and 884 introduce a breaking change for h5 and stay with the new order for weights. 885 886 Args: 887 layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance. 888 889 Returns: 890 A list of variables with the order of trainable_weights, followed by 891 non_trainable_weights. 892 """ 893 weights = layer.trainable_weights + layer.non_trainable_weights 894 if any(not isinstance(w, variables_module.Variable) for w in weights): 895 raise NotImplementedError( 896 'Save or restore weights that is not an instance of `tf.Variable` is ' 897 'not supported in h5, use `save_format=\'tf\'` instead. Got a model ' 898 'or layer {} with weights {}'.format(layer.__class__.__name__, weights)) 899 return weights 900