# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FeatureColumn serialization, deserialization logic."""

import six

from tensorflow.python.feature_column import feature_column_v2 as fc_lib
from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# Classes that `deserialize_feature_column` can resolve by class name without
# the caller having to pass them through `custom_objects`.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn, fc_lib.CrossedColumn, fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn, fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn, fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn, fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn, fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn, fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal, sfc_lib.SequenceNumericColumn
]


@tf_export('__internal__.feature_column.serialize_feature_column', v1=[])
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    Keras serialization for FeatureColumns, leaves string keys unaffected.

  Raises:
    ValueError if called with input that is not string or FeatureColumn.
  """
  if isinstance(fc, six.string_types):
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))


@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable. Returns a FeatureColumn for this config.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. Newly deserialized columns are added to it.

  Raises:
    ValueError if `config` has invalid format (e.g: expected keys missing,
    or refers to unknown classes).

  Returns:
    A FeatureColumn corresponding to the input `config`. A string `config`
    (raw column key) is returned unchanged.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, six.string_types):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = _class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  if not issubclass(cls, fc_lib.FeatureColumn):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)


def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Returns a list of Keras-style config dicts that represent the input
  FeatureColumns and can be used with `deserialize_feature_columns` for
  reconstructing the original columns.

  Args:
    feature_columns: A list of FeatureColumns (raw string keys are passed
      through unchanged).

  Returns:
    Keras serialization for the list of FeatureColumns.

  Raises:
    ValueError if any element is neither a string nor a FeatureColumn.
  """
  return [serialize_feature_column(fc) for fc in feature_columns]


def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumns configs.

  Returns a list of FeatureColumns given a list of config dicts acquired by
  `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError if any of the configs has invalid format (e.g: expected keys
    missing, or refers to unknown classes).
  """
  # One shared registry for the whole list, so that columns appearing in
  # several configs are deserialized to the same instance.
  columns_by_name = {}
  return [
      deserialize_feature_column(c, custom_objects, columns_by_name)
      for c in configs
  ]


def _column_name_with_class_name(fc):
  """Returns a unique name for the feature column used during deduping.

  Without this two FeatureColumns that have the same name and where
  one wraps the other, such as an IndicatorColumn wrapping a
  SequenceCategoricalColumn, will fail to deserialize because they will have the
  same name in columns_by_name, causing the wrong column to be returned.

  Args:
    fc: A FeatureColumn.

  Returns:
    A unique name as a string.
  """
  return fc.__class__.__name__ + ':' + fc.name


def _serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize. May be None, an object exposing
      `get_config()`, or anything with a `__name__` (e.g. a function or class).

  Returns:
    None for a None input; a `{'class_name': ..., 'config': ...}` dict for
    objects with `get_config()`; otherwise the object's `__name__`.

  Raises:
    ValueError if `instance` has neither `get_config` nor `__name__`.
  """
  # Only the second element of the pair returned by `unwrap` is used.
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if hasattr(instance, 'get_config'):
    name = instance.__class__.__name__
    config = instance.get_config()
    serialization_config = {}
    for key, item in config.items():
      if isinstance(item, six.string_types):
        serialization_config[key] = item
        continue

      # Any object of a different type needs to be converted to string or dict
      # for serialization (e.g. custom functions, custom classes)
      try:
        serialized_item = _serialize_keras_object(item)
        if isinstance(serialized_item, dict) and not isinstance(item, dict):
          # Mark dicts produced by this function (as opposed to dicts that
          # were already plain dicts in the config) so that deserialization
          # knows to turn them back into objects.
          serialized_item['__passive_serialization__'] = True
        serialization_config[key] = serialized_item
      except ValueError:
        # Item is not serializable by this function; keep it as-is.
        serialization_config[key] = item

    return {'class_name': name, 'config': serialization_config}
  if hasattr(instance, '__name__'):
    return instance.__name__
  raise ValueError('Cannot serialize', instance)


def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: None, a config dict with 'class_name' and 'config' keys, the
      name of a registered object, or an already deserialized function.
    module_objects: Optional Dict from name to built-in object, consulted
      after `custom_objects`.
    custom_objects: Optional Dict from name to user-provided object; takes
      precedence over `module_objects`.
    printable_module_name: Label used in error messages.

  Returns:
    The deserialized object, or None if `identifier` is None.

  Raises:
    ValueError if `identifier` cannot be interpreted or names an unknown
    object.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Hand `from_config` its own copy of the custom objects mapping.
        return cls.from_config(
            cls_config, custom_objects=dict(custom_objects or {}))
      return cls.from_config(cls_config)
    # Then `cls` may be a function returning a class.
    # in this case by convention `config` holds
    # the kwargs of the function.
    return cls(**cls_config)
  elif isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    else:
      # `module_objects` defaults to None; treat that like an empty mapping so
      # an unknown name yields the ValueError below, not an AttributeError.
      obj = module_objects.get(object_name) if module_objects else None
      if obj is None:
        raise ValueError(
            'Unknown ' + printable_module_name + ': ' + object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))


def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and config for a serialized keras object.

  Args:
    config: A Dict with 'class_name' and 'config' keys.
    module_objects: Optional Dict from name to built-in object.
    custom_objects: Optional Dict from name to user-provided object; takes
      precedence over `module_objects`.
    printable_module_name: Label used in error messages.

  Returns:
    A `(cls, cls_config)` tuple. `cls_config` is a shallow copy of
    `config['config']` in which passively serialized objects and names of
    custom functions have been replaced by the actual objects; the input
    `config` is left unmodified.

  Raises:
    ValueError if `config` is malformed or names an unknown class.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(class_name, custom_objects=custom_objects,
                               module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  serialized_config = config['config']
  # Work on a copy so that the caller's config is not mutated.
  cls_config = dict(serialized_config)

  for key, item in serialized_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      cls_config[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        cls_config[key] = registered

  return (cls, cls_config)


def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name`, preferring `custom_objects` over `module_objects`.

  Args:
    name: Key to look up.
    custom_objects: Optional Dict searched first.
    module_objects: Optional Dict searched second.

  Returns:
    The registered object, or None if `name` is in neither mapping.
  """
  if custom_objects and name in custom_objects:
    return custom_objects[name]
  elif module_objects and name in module_objects:
    return module_objects[name]
  return None