xref: /aosp_15_r20/external/tensorflow/tensorflow/python/feature_column/serialization.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FeatureColumn serialization, deserialization logic."""
16
17import six
18
19from tensorflow.python.feature_column import feature_column_v2 as fc_lib
20from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
21from tensorflow.python.ops import init_ops
22from tensorflow.python.util import tf_decorator
23from tensorflow.python.util import tf_inspect
24from tensorflow.python.util.tf_export import tf_export
25
26
# Classes that `deserialize_feature_column` can resolve from a serialized
# 'class_name' without the caller supplying them via `custom_objects`
# (it indexes this list by `cls.__name__`).
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn,
    fc_lib.CrossedColumn,
    fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn,
    fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn,
    fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn,
    fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn,
    fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn,
    fc_lib.WeightedCategoricalColumn,
    # Not a FeatureColumn. NOTE(review): presumably listed so that nested,
    # passively serialized initializers can be looked up by class name.
    init_ops.TruncatedNormal,
    sfc_lib.SequenceNumericColumn,
]
36
37
@tf_export('__internal__.feature_column.serialize_feature_column', v1=[])
def serialize_feature_column(fc):
  """Serializes a FeatureColumn or a raw string key.

  This method should only be used to serialize parent FeatureColumns when
  implementing FeatureColumn.get_config(), else serialize_feature_columns()
  is preferable.

  This serialization also keeps information of the FeatureColumn class, so
  deserialization is possible without knowing the class type. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    A dict with 'class_name' and 'config' keys for a FeatureColumn; string
    keys are returned unchanged.

  Raises:
    ValueError: if `fc` is neither a string nor a FeatureColumn.
  """
  if isinstance(fc, str):
    # Raw feature keys need no conversion; pass them through.
    return fc
  elif isinstance(fc, fc_lib.FeatureColumn):
    # Record the concrete class name so the deserializer can pick the class.
    return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
  else:
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))
85
86
@tf_export('__internal__.feature_column.deserialize_feature_column', v1=[])
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Deserializes a `config` generated with `serialize_feature_column`.

  This method should only be used to deserialize parent FeatureColumns when
  implementing FeatureColumn.from_config(), else deserialize_feature_columns()
  is preferable.

  Args:
    config: A Dict with the serialization of feature columns acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict[String, FeatureColumn] of existing columns in order
      to avoid duplication. When provided, newly deserialized columns are
      added to it in place.

  Returns:
    `config` unchanged if it is a string (a raw feature key). Otherwise a
    FeatureColumn corresponding to `config`; if a column with the same class
    and name is already in `columns_by_name`, that existing instance is
    returned instead of the freshly built one.

  Raises:
    ValueError: if `config` has invalid format (e.g: expected keys missing,
      or refers to unknown classes), or if it resolves to something that is
      not a FeatureColumn class.
  """
  # TODO(b/118939620): Simplify code if Keras utils support object deduping.
  if isinstance(config, str):
    return config
  # A dict from class_name to class for all FeatureColumns in this module.
  # FeatureColumns not part of the module can be passed as custom_objects.
  module_feature_column_classes = {
      cls.__name__: cls for cls in _FEATURE_COLUMNS}
  if columns_by_name is None:
    columns_by_name = {}

  (cls,
   cls_config) = _class_and_config_for_serialized_keras_object(
       config,
       module_objects=module_feature_column_classes,
       custom_objects=custom_objects,
       printable_module_name='feature_column_v2')

  # `custom_objects` may map the name to a non-class (e.g. a function), on
  # which `issubclass` raises TypeError; check for a class first so callers
  # get the documented ValueError.
  if not (isinstance(cls, type) and issubclass(cls, fc_lib.FeatureColumn)):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(cls))

  # Always deserialize the FeatureColumn, in order to get the name.
  new_instance = cls.from_config(
      cls_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # If the name already exists, re-use the column from columns_by_name,
  # (new_instance remains unused).
  return columns_by_name.setdefault(
      _column_name_with_class_name(new_instance), new_instance)
143
144
def serialize_feature_columns(feature_columns):
  """Serializes a list of FeatureColumns.

  Every element goes through `serialize_feature_column`, so the output can be
  handed to `deserialize_feature_columns` to rebuild the original columns.

  Args:
    feature_columns: An iterable of FeatureColumns; raw string keys are also
      accepted and come back unchanged.

  Returns:
    A list of Keras-style config dicts (or strings), in input order.

  Raises:
    ValueError: if an element is neither a string nor a FeatureColumn.
  """
  return list(map(serialize_feature_column, feature_columns))
162
163
def deserialize_feature_columns(configs, custom_objects=None):
  """Deserializes a list of FeatureColumn configs.

  This is the inverse of `serialize_feature_columns`.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    A list of FeatureColumn objects (raw string keys pass through unchanged),
    in the same order as `configs`.

  Raises:
    ValueError: if a config is malformed or refers to an unknown class.
  """
  # A single registry shared by every entry, so a column referenced by
  # several configs is materialized as one instance.
  shared_registry = {}
  return [
      deserialize_feature_column(
          serialized,
          custom_objects=custom_objects,
          columns_by_name=shared_registry)
      for serialized in configs
  ]
187
188
189def _column_name_with_class_name(fc):
190  """Returns a unique name for the feature column used during deduping.
191
192  Without this two FeatureColumns that have the same name and where
193  one wraps the other, such as an IndicatorColumn wrapping a
194  SequenceCategoricalColumn, will fail to deserialize because they will have the
195  same name in columns_by_name, causing the wrong column to be returned.
196
197  Args:
198    fc: A FeatureColumn.
199
200  Returns:
201    A unique name as a string.
202  """
203  return fc.__class__.__name__ + ':' + fc.name
204
205
def _serialize_keras_object(instance):
  """Serializes a Keras-style object into a config dict or a name string.

  Objects exposing `get_config()` become {'class_name': ..., 'config': ...}.
  Their non-string config values are serialized recursively. A dict produced
  that way for a non-dict value is tagged with '__passive_serialization__'.
  Values that cannot be serialized (ValueError) are kept unchanged. Other
  objects that have a `__name__` (e.g. functions) are reduced to that name.

  Args:
    instance: The object to serialize. It may be wrapped by a TF decorator
      (it is unwrapped first), or None.

  Returns:
    None, a config dict, or a name string.

  Raises:
    ValueError: if the unwrapped object has neither `get_config` nor
      `__name__`.
  """
  _, target = tf_decorator.unwrap(instance)
  if target is None:
    return None

  if not hasattr(target, 'get_config'):
    # No config available: fall back to serializing by name.
    if hasattr(target, '__name__'):
      return target.__name__
    raise ValueError('Cannot serialize', target)

  class_name = target.__class__.__name__
  serialized = {}
  for field, value in target.get_config().items():
    if isinstance(value, six.string_types):
      serialized[field] = value
      continue

    # Non-string values (e.g. custom functions or classes) must be turned
    # into a string or dict; if that is impossible, keep the raw value.
    try:
      nested = _serialize_keras_object(value)
    except ValueError:
      serialized[field] = value
      continue

    if isinstance(nested, dict) and not isinstance(value, dict):
      nested['__passive_serialization__'] = True
    serialized[field] = nested

  return {'class_name': class_name, 'config': serialized}
235
236
237def _deserialize_keras_object(identifier,
238                              module_objects=None,
239                              custom_objects=None,
240                              printable_module_name='object'):
241  """Turns the serialized form of a Keras object back into an actual object."""
242  if identifier is None:
243    return None
244
245  if isinstance(identifier, dict):
246    # In this case we are dealing with a Keras config dictionary.
247    config = identifier
248    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
249        config, module_objects, custom_objects, printable_module_name)
250
251    if hasattr(cls, 'from_config'):
252      arg_spec = tf_inspect.getfullargspec(cls.from_config)
253      custom_objects = custom_objects or {}
254
255      if 'custom_objects' in arg_spec.args:
256        return cls.from_config(
257            cls_config,
258            custom_objects=dict(
259                list(custom_objects.items())))
260      return cls.from_config(cls_config)
261    else:
262      # Then `cls` may be a function returning a class.
263      # in this case by convention `config` holds
264      # the kwargs of the function.
265      custom_objects = custom_objects or {}
266      return cls(**cls_config)
267  elif isinstance(identifier, six.string_types):
268    object_name = identifier
269    if custom_objects and object_name in custom_objects:
270      obj = custom_objects.get(object_name)
271    else:
272      obj = module_objects.get(object_name)
273      if obj is None:
274        raise ValueError(
275            'Unknown ' + printable_module_name + ': ' + object_name)
276    # Classes passed by name are instantiated with no args, functions are
277    # returned as-is.
278    if tf_inspect.isclass(obj):
279      return obj()
280    return obj
281  elif tf_inspect.isfunction(identifier):
282    # If a function has already been deserialized, return as is.
283    return identifier
284  else:
285    raise ValueError('Could not interpret serialized %s: %s' %
286                     (printable_module_name, identifier))
287
288
289def _class_and_config_for_serialized_keras_object(
290    config,
291    module_objects=None,
292    custom_objects=None,
293    printable_module_name='object'):
294  """Returns the class name and config for a serialized keras object."""
295  if (not isinstance(config, dict) or 'class_name' not in config or
296      'config' not in config):
297    raise ValueError('Improper config format: ' + str(config))
298
299  class_name = config['class_name']
300  cls = _get_registered_object(class_name, custom_objects=custom_objects,
301                               module_objects=module_objects)
302  if cls is None:
303    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
304
305  cls_config = config['config']
306
307  deserialized_objects = {}
308  for key, item in cls_config.items():
309    if isinstance(item, dict) and '__passive_serialization__' in item:
310      deserialized_objects[key] = _deserialize_keras_object(
311          item,
312          module_objects=module_objects,
313          custom_objects=custom_objects,
314          printable_module_name='config_item')
315    elif (isinstance(item, six.string_types) and
316          tf_inspect.isfunction(_get_registered_object(item, custom_objects))):
317      # Handle custom functions here. When saving functions, we only save the
318      # function's name as a string. If we find a matching string in the custom
319      # objects during deserialization, we convert the string back to the
320      # original function.
321      # Note that a potential issue is that a string field could have a naming
322      # conflict with a custom function name, but this should be a rare case.
323      # This issue does not occur if a string field has a naming conflict with
324      # a custom object, since the config of an object will always be a dict.
325      deserialized_objects[key] = _get_registered_object(item, custom_objects)
326  for key, item in deserialized_objects.items():
327    cls_config[key] = deserialized_objects[key]
328
329  return (cls, cls_config)
330
331
332def _get_registered_object(name, custom_objects=None, module_objects=None):
333  if custom_objects and name in custom_objects:
334    return custom_objects[name]
335  elif module_objects and name in module_objects:
336    return module_objects[name]
337  return None
338
339