# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""

import collections
import copy
import os

from tensorflow.python.eager import def_function
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def extract_model_metrics(model):
  """Convert metrics from a Keras model `compile` API to dictionary.

  This is used for converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    Dictionary mapping metric names to metric instances. May return `None` if
    the model does not contain any metrics.
  """
  if getattr(model, '_compile_metrics', None):
    # TODO(psv/kathywu): use this implementation in model to estimator flow.
    # We are not using model.metrics here because we want to exclude the metrics
    # added using `add_metric` API.
    return {m.name: m for m in model._compile_metric_functions}  # pylint: disable=protected-access
  return None


def model_input_signature(model, keep_original_batch_size=False):
  """Inspect model to get its input signature.

  The model's input signature is a list with a single (possibly-nested) object.
  This is due to the Keras-enforced restriction that tensor inputs must be
  passed in as the first argument.

  For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: A boolean indicating whether we want to keep using
      the original batch size or set it to None. Default is `False`, which means
      that the batch dim of the returned input signature will always be set to
      `None`.

  Returns:
    A list containing either a single TensorSpec or an object with nested
    TensorSpecs. This list does not contain the `training` argument.
    Returns `None` if the model has no save spec.
  """
  input_specs = model._get_save_spec(dynamic_batch=not keep_original_batch_size)  # pylint: disable=protected-access
  if input_specs is None:
    return None
  input_specs = _enforce_names_consistency(input_specs)
  # Return a list with a single element as the model's input signature.
  if isinstance(input_specs,
                collections.abc.Sequence) and len(input_specs) == 1:
    # Note that the isinstance check filters out single-element dictionaries,
    # which should also be wrapped as a single-element list.
    return input_specs
  else:
    return [input_specs]


def raise_model_input_error(model):
  """Raises the error used when a model's input shapes are unknown.

  Args:
    model: The Keras model that could not be saved; it is interpolated into the
      error message.

  Raises:
    ValueError: always.
  """
  raise ValueError(
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.'.format(model))


def trace_model_call(model, input_signature=None):
  """Trace the model call to create a tf.function for exporting a Keras model.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  # Resolution order: explicit argument, then the signature of an already
  # decorated `model.call`, then the signature inferred from the model itself.
  if input_signature is None:
    if isinstance(model.call, def_function.Function):
      input_signature = model.call.input_signature

  if input_signature is None:
    input_signature = model_input_signature(model)

  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # When given a single input, Keras models will call the model on the tensor
    # rather than a list consisting of the single tensor.
    inputs = args[0] if len(input_signature) == 1 else list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=inputs, build_graph=False, training=False, saving=True):
      outputs = model(inputs, training=False)

    # Outputs always has to be a flat dict.
    output_names = model.output_names  # Functional Model.
    if output_names is None:  # Subclassed Model.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      output_names = compile_utils.create_pseudo_output_names(outputs)
    outputs = nest.flatten(outputs)
    return {name: output for name, output in zip(output_names, outputs)}

  return _wrapped_model


def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata.

  Args:
    model: The Keras model to describe.
    include_optimizer: Whether to add a `training_config` entry (compile
      arguments plus optimizer config) when the model has been compiled.
    require_config: If `True`, a `NotImplementedError` from
      `model.get_config()` is propagated; otherwise the `config` entry is
      omitted from `model_config`.

  Returns:
    A dict with `keras_version`, `backend`, `model_config` and, when
    applicable, `training_config`.

  Raises:
    NotImplementedError: if `model.get_config()` is not implemented and
      `require_config` is `True`, or if the model's optimizer was restored
      from a SavedModel.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError:
    if require_config:
      raise  # Preserve the original traceback.

  metadata = dict(
      keras_version=str(keras_version),
      backend=K.backend(),
      model_config=model_config)
  if model.optimizer and include_optimizer:
    if isinstance(model.optimizer, optimizer_v1.TFOptimizer):
      logging.warning(
          'TensorFlow optimizers do not '
          'make it possible to access '
          'optimizer attributes or optimizer state '
          'after instantiation. '
          'As a result, we cannot save the optimizer '
          'as part of the model save file. '
          'You will have to compile your model again after loading it. '
          'Prefer using a Keras optimizer instead '
          '(see keras.io/optimizers).')
    elif model._compile_was_called:  # pylint: disable=protected-access
      training_config = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
      training_config.pop('optimizer', None)  # Handled separately.
      metadata['training_config'] = _serialize_nested_config(training_config)
      if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
        raise NotImplementedError(
            'As of now, Optimizers loaded from SavedModel cannot be saved. '
            'If you\'re calling `model.save` or `tf.keras.models.save_model`,'
            ' please set the `include_optimizer` option to `False`. For '
            '`tf.saved_model.save`, delete the optimizer from the model.')
      optimizer_config = {
          'class_name':
              generic_utils.get_registered_name(model.optimizer.__class__),
          'config':
              model.optimizer.get_config()
      }
      metadata['training_config']['optimizer_config'] = optimizer_config
  return metadata


def should_overwrite(filepath, overwrite):
  """Returns whether the filepath should be overwritten.

  Args:
    filepath: Destination path.
    overwrite: If `True`, always allow overwriting.

  Returns:
    `True` when `overwrite` is set or `filepath` is not an existing regular
    file; otherwise the result of `ask_to_proceed_with_overwrite(filepath)`.
  """
  # If file exists and should not be overwritten.
  if not overwrite and os.path.isfile(filepath):
    return ask_to_proceed_with_overwrite(filepath)
  return True


def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config.

  Args:
    training_config: Dict produced by `model_metadata` under the
      `training_config` key. It must contain `optimizer_config` and
      `loss_weights`.
    custom_objects: Optional dict of custom objects made available while
      deserializing the optimizer, losses and metrics.

  Returns:
    A dict of keyword arguments suitable for `model.compile`.
  """
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a dict, so the key must be looked up with `.get`;
    # `hasattr(training_config, 'sample_weight_mode')` was always False and
    # silently discarded a saved value.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    loss_weights = training_config['loss_weights']

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)


def _deserialize_nested_config(deserialize_fn, config):
  """Deserializes arbitrary Keras `config` using `deserialize_fn`.

  Leaves (serialized Keras objects or strings) are passed to `deserialize_fn`;
  dicts and lists/tuples are recursed into. Tuples come back as lists.

  Raises:
    ValueError: if `config` contains a value of an unsupported type.
  """

  def _is_single_object(obj):
    if isinstance(obj, dict) and 'class_name' in obj:
      return True  # Serialized Keras object.
    if isinstance(obj, str):
      return True  # Serialized function or string.
    return False

  if config is None:
    return None
  if _is_single_object(config):
    return deserialize_fn(config)
  elif isinstance(config, dict):
    return {
        k: _deserialize_nested_config(deserialize_fn, v)
        for k, v in config.items()
    }
  elif isinstance(config, (tuple, list)):
    return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]

  raise ValueError('Saved configuration not understood.')


def _serialize_nested_config(config):
  """Serialized a nested structure of Keras objects.

  Callable leaves are converted with `generic_utils.serialize_keras_object`;
  all other leaves are returned unchanged.
  """

  def _serialize_fn(obj):
    if callable(obj):
      return generic_utils.serialize_keras_object(obj)
    return obj

  return nest.map_structure(_serialize_fn, config)


def _deserialize_metric(metric_config):
  """Deserialize metrics, leaving special strings untouched."""
  from tensorflow.python.keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  if metric_config in ['accuracy', 'acc', 'crossentropy', 'ce']:
    # Do not deserialize accuracy and cross-entropy strings as we have special
    # case handling for these in compile, based on model output shape.
    return metric_config
  return metrics_module.deserialize(metric_config)


def _enforce_names_consistency(specs):
  """Enforces that either all specs have names or none do.

  If only some of the flattened specs are named, every spec is deep-copied and
  its name cleared; otherwise `specs` is returned unchanged.
  """

  def _has_name(spec):
    return hasattr(spec, 'name') and spec.name is not None

  def _clear_name(spec):
    # Copy first so the caller's spec objects are not mutated.
    spec = copy.deepcopy(spec)
    if hasattr(spec, 'name'):
      spec._name = None  # pylint:disable=protected-access
    return spec

  flat_specs = nest.flatten(specs)
  name_inconsistency = (
      any(_has_name(s) for s in flat_specs) and
      not all(_has_name(s) for s in flat_specs))

  if name_inconsistency:
    specs = nest.map_structure(_clear_name, specs)
  return specs


def try_build_compiled_arguments(model):
  """Best-effort build of the compiled loss and metrics of a loaded model.

  Skipped for V1 layers/models and for models whose `outputs` is `None`. Any
  failure while building is logged as a warning rather than raised.

  Args:
    model: A compiled Keras model.
  """
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      if not model.compiled_loss.built:
        model.compiled_loss.build(model.outputs)
      if not model.compiled_metrics.built:
        # `model.outputs` is passed for both arguments; presumably only their
        # structure is needed to build the metrics -- TODO confirm.
        model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:  # pylint: disable=broad-except
      # Deliberately best-effort, but do not swallow KeyboardInterrupt or
      # SystemExit as the previous bare `except:` did.
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')


def is_hdf5_filepath(filepath):
  """Returns whether `filepath` has an HDF5/Keras file extension.

  Args:
    filepath: A string path.

  Returns:
    `True` if `filepath` ends with `.h5`, `.keras` or `.hdf5`.
  """
  return filepath.endswith(('.h5', '.keras', '.hdf5'))