# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes and functions implementing Layer SavedModel serialization."""

from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.keras.saving.saved_model import base_serialization
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.saving.saved_model import serialized_attributes
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.trackable import data_structures
from tensorflow.python.util import nest


class LayerSavedModelSaver(base_serialization.SavedModelSaver):
  """Implements Layer SavedModel serialization.

  Builds the Python-side metadata for a Keras layer (`python_properties`) and
  the objects/functions attached to its SavedModel node
  (`objects_to_serialize` / `functions_to_serialize`).
  """

  @property
  def object_identifier(self):
    """Identifier recorded for plain Keras layers."""
    return constants.LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dict for this layer; see `_python_properties_internal`."""
    # TODO(kathywu): Add python property validator
    return self._python_properties_internal()

  def _python_properties_internal(self):
    """Returns dictionary of all python properties.

    The dict starts with a fixed set of layer attributes. The output of
    `get_serialized(layer)` is merged into it. Optional `input_spec`,
    `activity_regularizer` and `build_input_shape` entries are appended when
    available.
    """
    # TODO(kathywu): Add support for metrics serialization.
    # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
    # the python config serialization has caught up.
    layer = self.obj
    metadata = {
        'name': layer.name,
        'trainable': layer.trainable,
        'expects_training_arg': layer._expects_training_arg,  # pylint: disable=protected-access
        'dtype': policy.serialize(layer._dtype_policy),  # pylint: disable=protected-access
        'batch_input_shape': getattr(layer, '_batch_input_shape', None),
        'stateful': layer.stateful,
        'must_restore_from_config': layer._must_restore_from_config,  # pylint: disable=protected-access
    }
    # Merge in the serialized config; keys it shares with the dict above are
    # overwritten in place.
    metadata.update(get_serialized(layer))

    input_spec = layer.input_spec
    if input_spec is not None:
      # Layer's input_spec has already been type-checked in the property setter.
      def _serialize_spec(spec):
        # Falsy entries (e.g. None placeholders) are kept as None.
        if not spec:
          return None
        return generic_utils.serialize_keras_object(spec)

      metadata['input_spec'] = nest.map_structure(_serialize_spec, input_spec)

    regularizer = layer.activity_regularizer
    # Only regularizers exposing `get_config` can be serialized this way.
    if regularizer is not None and hasattr(regularizer, 'get_config'):
      metadata['activity_regularizer'] = (
          generic_utils.serialize_keras_object(regularizer))

    build_input_shape = layer._build_input_shape  # pylint: disable=protected-access
    if build_input_shape is not None:
      metadata['build_input_shape'] = build_input_shape
    return metadata

  def objects_to_serialize(self, serialization_cache):
    """Returns the trackable objects to attach to this layer's node."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.objects_to_serialize

  def functions_to_serialize(self, serialization_cache):
    """Returns the functions to attach to this layer's node."""
    attributes = self._get_serialized_attributes(serialization_cache)
    return attributes.functions_to_serialize

  def _get_serialized_attributes(self, serialization_cache):
    """Generates or retrieves serialized attributes from cache.

    Args:
      serialization_cache: Dict shared across the save; the per-layer results
        are stored under `constants.KERAS_CACHE_KEY`, keyed by layer.

    Returns:
      The `SerializedAttributes` for `self.obj`. It is left unpopulated when
      the layer should be skipped or must be restored from config.
    """
    keras_cache = serialization_cache.setdefault(constants.KERAS_CACHE_KEY, {})
    layer = self.obj
    if layer in keras_cache:
      return keras_cache[layer]

    # The entry is cached *before* it is populated. NOTE(review): presumably
    # this lets re-entrant calls for the same layer (while its objects and
    # functions are wrapped) hit the cache; confirm.
    serialized_attr = serialized_attributes.SerializedAttributes.new(layer)
    keras_cache[layer] = serialized_attr

    skip = (save_impl.should_skip_serialization(layer) or
            layer._must_restore_from_config)  # pylint: disable=protected-access
    if skip:
      return serialized_attr

    objects, functions = self._get_serialized_attributes_internal(
        serialization_cache)
    serialized_attr.set_and_validate_objects(objects)
    serialized_attr.set_and_validate_functions(functions)
    return serialized_attr

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Returns `(objects, functions)` dicts of serialized attributes."""
    layer = self.obj
    wrapped_objects = save_impl.wrap_layer_objects(layer, serialization_cache)
    wrapped_functions = save_impl.wrap_layer_functions(
        layer, serialization_cache)
    # The attribute validator requires the default save signature key to be
    # present in the function dict, even when its value is None.
    wrapped_functions['_default_save_signature'] = None
    return wrapped_objects, wrapped_functions


# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
  """Serializes `obj` via `generic_utils.serialize_keras_object`.

  The call runs inside `generic_utils.skip_failed_serialization()`. The
  resulting config dictionary may be used when reviving the object: loading
  first attempts to revive it from this config, and if that fails, revives it
  from the SavedModel.
  """
  # The return stays inside the `with` block on purpose. If the context
  # manager ever swallows an exception, the function falls through and
  # returns None.
  with generic_utils.skip_failed_serialization():
    return generic_utils.serialize_keras_object(obj)


class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
  """InputLayer serialization."""

  @property
  def object_identifier(self):
    """Identifier recorded for InputLayer nodes."""
    return constants.INPUT_LAYER_IDENTIFIER

  @property
  def python_properties(self):
    """Metadata dict describing the InputLayer and its config."""
    layer = self.obj
    return {
        'class_name': type(layer).__name__,
        'name': layer.name,
        'dtype': layer.dtype,
        'sparse': layer.sparse,
        'ragged': layer.ragged,
        'batch_input_shape': layer._batch_input_shape,  # pylint: disable=protected-access
        'config': layer.get_config(),
    }

  def objects_to_serialize(self, serialization_cache):
    """InputLayers attach no extra objects."""
    del serialization_cache  # Unused.
    return {}

  def functions_to_serialize(self, serialization_cache):
    """InputLayers attach no extra functions."""
    del serialization_cache  # Unused.
    return {}


class RNNSavedModelSaver(LayerSavedModelSaver):
  """RNN layer serialization."""

  @property
  def object_identifier(self):
    """Identifier recorded for RNN layers."""
    return constants.RNN_LAYER_IDENTIFIER

  def _get_serialized_attributes_internal(self, serialization_cache):
    """Extends the base attributes with the layer's `states` object."""
    objects, functions = super()._get_serialized_attributes_internal(
        serialization_cache)
    wrapped_states = data_structures.wrap_or_unwrap(self.obj.states)
    # SavedModel requires every saved object to be Trackable. A tuple that
    # survives wrap_or_unwrap holds no trackable items, for example an empty
    # tuple or (None, None) for a stateless ConvLSTM2D. Converting it to a list
    # lets wrap_or_unwrap turn it into a Trackable for saving. ConvLSTM2D
    # handles the tuple/list conversion when loaded.
    if isinstance(wrapped_states, tuple):
      wrapped_states = data_structures.wrap_or_unwrap(list(wrapped_states))
    objects['states'] = wrapped_states
    return objects, functions


class IndexLookupLayerSavedModelSaver(LayerSavedModelSaver):
  """Index lookup layer serialization."""

  @property
  def python_properties(self):
    """Layer metadata with `config['vocabulary']` cleared for static tables."""
    # TODO(kathywu): Add python property validator
    metadata = self._python_properties_internal()
    config = metadata['config']
    if config.get('has_static_table', False):
      # NOTE(review): the vocabulary is set to None (not removed) when the
      # layer reports a static table. Presumably the table contents are
      # restored another way; confirm.
      config['vocabulary'] = None
    return metadata