xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/utils/generic_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
16
17import binascii
18import codecs
19import importlib
20import marshal
21import os
22import re
23import sys
24import threading
25import time
26import types as python_types
27import warnings
28import weakref
29
30import numpy as np
31
32from tensorflow.python.keras.utils import tf_contextlib
33from tensorflow.python.keras.utils import tf_inspect
34from tensorflow.python.util import nest
35from tensorflow.python.util import tf_decorator
36from tensorflow.python.util.tf_export import keras_export
37
# Registry mapping a serialized name to the custom object it stands for.
# `CustomObjectScope` extends it temporarily and `register_keras_serializable`
# adds entries under the key 'package>name'.
_GLOBAL_CUSTOM_OBJECTS = {}
# Reverse registry filled in by `register_keras_serializable`: maps a
# registered object back to its 'package>name' key.
_GLOBAL_CUSTOM_NAMES = {}

# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
48
49
@keras_export('keras.utils.custom_object_scope',  # pylint: disable=g-classes-have-attributes
              'keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Context manager that temporarily registers custom objects with Keras.

  While the scope is active, the given `{name: object}` pairs are merged into
  the global custom-object registry, so that Keras methods such as
  `tf.keras.models.load_model` or `tf.keras.models.model_from_config` can
  resolve custom objects (e.g. a custom layer or metric) referenced by a saved
  config. On exit the registry is rolled back to the state it had on entry.

  Example:

  Consider a custom regularizer `my_regularizer`:

  ```python
  layer = Dense(3, kernel_regularizer=my_regularizer)
  config = layer.get_config()  # Config contains a reference to `my_regularizer`
  ...
  # Later:
  with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)
  ```

  Args:
      *args: Dictionary or dictionaries of `{name: object}` pairs.
  """

  def __init__(self, *args):
    # Nothing is registered until the scope is entered.
    self.backup = None
    self.custom_objects = args

  def __enter__(self):
    # Snapshot the registry first so `__exit__` can undo every change made
    # while the scope was open, not only the ones made here.
    self.backup = dict(_GLOBAL_CUSTOM_OBJECTS)
    for mapping in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(mapping)
    return self

  def __exit__(self, *args, **kwargs):
    # Mutate the shared dict in place: other code may hold a live reference
    # to it (see `get_custom_objects`).
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
90
91
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Returns the global custom-object registry itself (not a copy).

  Because the returned dictionary is the live registry, any mutation made
  through it is immediately visible to Keras. Prefer `custom_object_scope`
  for adding and removing entries; use this accessor when direct access to
  the current collection of custom objects is needed.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      The global dictionary of names to objects (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  return _GLOBAL_CUSTOM_OBJECTS
111
112
# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly.  Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = 'shared_object_id'


# Thread-local holders for the shared-object state. The accessor functions
# below read a `disabled` attribute from the first and a `scope` attribute
# from the other two, falling back to a default when the attribute is unset.
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()
125
126
127# Attributes on the threadlocal variable must be set per-thread, thus we
128# cannot initialize these globally. Instead, we have accessor functions with
129# default values.
def _shared_object_disabled():
  """Reports whether shared object handling is disabled on this thread.

  Returns:
    The `disabled` attribute of the thread-local `SHARED_OBJECT_DISABLED`, or
    `False` when the current thread has never set it.
  """
  try:
    return SHARED_OBJECT_DISABLED.disabled
  except AttributeError:
    # The attribute only exists on threads that have assigned it.
    return False
133
134
def _shared_object_loading_scope():
  """Get the current shared object loading scope in a threadsafe manner.

  Returns:
    The `scope` attribute of the thread-local `SHARED_OBJECT_LOADING`, or a
    fresh `NoopLoadingScope` when the current thread has not set one.
  """
  try:
    return SHARED_OBJECT_LOADING.scope
  except AttributeError:
    # No scope is installed on this thread. Build the no-op fallback only
    # here, instead of allocating one on every call as a `getattr` default
    # argument would.
    return NoopLoadingScope()
138
139
def _shared_object_saving_scope():
  """Fetches the shared object saving scope active on this thread.

  Returns:
    The `scope` attribute of the thread-local `SHARED_OBJECT_SAVING`, or
    `None` when the current thread has not set one.
  """
  try:
    return SHARED_OBJECT_SAVING.scope
  except AttributeError:
    # The attribute only exists on threads that have assigned it.
    return None
143
144
class DisableSharedObjectScope(object):
  """A context manager for disabling handling of shared objects.

  Disables shared object handling for both saving and loading.

  Created primarily for use with `clone_model`, which does extra surgery that
  is incompatible with shared objects.

  On exit, the disabled flag and the loading/saving scopes are restored to the
  values they had when the scope was entered, so these scopes can be nested.
  """

  def __enter__(self):
    # Remember the previous flag value rather than assuming it was `False`:
    # when scopes are nested, leaving the inner one must not re-enable shared
    # object handling while the outer one is still open.
    self._orig_disabled = _shared_object_disabled()
    SHARED_OBJECT_DISABLED.disabled = True
    self._orig_loading_scope = _shared_object_loading_scope()
    self._orig_saving_scope = _shared_object_saving_scope()

  def __exit__(self, *args, **kwargs):
    SHARED_OBJECT_DISABLED.disabled = self._orig_disabled
    SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
    SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
163
164
class NoopLoadingScope(object):
  """Stand-in loading scope that never remembers anything.

  Used as the default shared object loading scope so that serialization code
  that doesn't care about shared objects (e.g. when serializing a single
  object) can call `get`/`set` unconditionally.
  """

  def get(self, unused_object_id):
    """Always reports that no object has been loaded for the given ID."""
    return None

  def set(self, object_id, obj):
    """Accepts the object and discards it; nothing is recorded."""
    return None
177
178
class SharedObjectLoadingScope(object):
  """A context manager that remembers objects as they are deserialized.

  Some objects are shared across multiple layers. To restore the network
  structure faithfully, deserialization code looks up each shared object ID in
  this scope and re-uses the already-built object instead of cloning it.
  """

  def __enter__(self):
    # With shared objects disabled, hand out a scope that records nothing and
    # leave the thread-local scope untouched.
    if _shared_object_disabled():
      return NoopLoadingScope()

    self._obj_ids_to_obj = {}
    SHARED_OBJECT_LOADING.scope = self
    return self

  def get(self, object_id):
    """Looks up the object previously stored under a shared object ID.

    Args:
      object_id: shared object ID to look up. May be `None`, which callers
        pass when a config carries no shared object ID.

    Returns:
      The stored object if this ID has been seen before, otherwise `None`.
    """
    # `None` is handled here so that calling code does not have to check.
    return None if object_id is None else self._obj_ids_to_obj.get(object_id)

  def set(self, object_id, obj):
    """Records an instantiated object under its ID; ignores a `None` ID."""
    if object_id is not None:
      self._obj_ids_to_obj[object_id] = obj

  def __exit__(self, *args, **kwargs):
    # Leaving the scope installs a fresh no-op scope on this thread.
    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
222
223
class SharedObjectConfig(dict):
  """A config dictionary that counts how often it has been referenced.

  The first time the config is referenced again after creation, the shared
  object ID is written into the dictionary itself, which allows shared objects
  to be reconstructed at load time.

  Subclassing something like `collections.UserDict` or `collections.Mapping`
  would normally be more proper than subclassing `dict`. Python's json encoder
  does not support `Mapping`s, though, and serialization must keep working.
  Subclassing `dict` is safe here because no core method is overridden; the
  class only adds a method for reference counting.
  """

  def __init__(self, base_config, object_id, **kwargs):
    super().__init__(base_config, **kwargs)
    self.object_id = object_id
    # The config starts out with the single reference that created it.
    self.ref_count = 1

  def increment_ref_count(self):
    """Counts one more reference, attaching the shared ID on first re-use."""
    # The ID is attached only once the object has been seen more than once.
    # Attaching it only when strictly necessary makes backwards compatibility
    # breakage less likely.
    first_reuse = self.ref_count == 1
    self.ref_count += 1
    if first_reuse:
      self[SHARED_OBJECT_KEY] = self.object_id
254
255
class SharedObjectSavingScope(object):
  """Context manager that tracks shared object configs during serialization."""

  def __enter__(self):
    if _shared_object_disabled():
      return None

    # Serialization can happen at a number of layers for a number of reasons,
    # so a saving scope may be opened inside another one. In that case the
    # outermost scope is used and the inner one is ignored, since there is not
    # (yet) a reasonable use case for nested, distinct scopes.
    outer_scope = _shared_object_saving_scope()
    self._passthrough = outer_scope is not None
    if self._passthrough:
      return outer_scope

    SHARED_OBJECT_SAVING.scope = self
    self._shared_objects_config = weakref.WeakKeyDictionary()
    self._next_id = 0
    return self

  def get_config(self, obj):
    """Returns the `SharedObjectConfig` already recorded for `obj`, if any.

    A successful lookup counts as one more reference to the config.

    Args:
      obj: The object for which to retrieve the `SharedObjectConfig`.

    Returns:
      The `SharedObjectConfig` for the object if it has been seen, otherwise
        `None`.
    """
    try:
      known_config = self._shared_objects_config[obj]
    except (TypeError, KeyError):
      # `KeyError`: the object has not been seen yet. `TypeError`: the object
      # cannot be used as a key (e.g. it is unhashable); shared object support
      # is simply skipped for it.
      return None
    known_config.increment_ref_count()
    return known_config

  def create_config(self, base_config, obj):
    """Builds a new `SharedObjectConfig` for `obj` and tries to record it."""
    new_config = SharedObjectConfig(base_config, self._next_id)
    self._next_id += 1
    try:
      self._shared_objects_config[obj] = new_config
    except TypeError:
      # The object cannot be used as a key (e.g. it is unhashable). The config
      # is still returned, just without shared object support.
      pass
    return new_config

  def __exit__(self, *args, **kwargs):
    # A pass-through scope never installed itself, so it has nothing to undo.
    if getattr(self, '_passthrough', False):
      return
    SHARED_OBJECT_SAVING.scope = None
318
319
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None):
  """Builds the serialized `{'class_name', 'config'}` form of an object.

  Args:
    cls_name: Name to store under the 'class_name' key.
    cls_config: Config to store under the 'config' key.
    obj: Optional object being serialized. When given and a
      `SharedObjectSavingScope` is active, the scope is used to track it.
    shared_object_id: Optional shared object ID to store in the result.

  Returns:
    A plain dict, or, when a saving scope is active and `obj` is given, the
    `SharedObjectConfig` the scope holds (or newly creates) for `obj`.
  """
  serialized = {'class_name': cls_name, 'config': cls_config}

  # This function is also called on some branches of the load path, where a
  # shared object ID may already exist and should be retained.
  if shared_object_id is not None:
    serialized[SHARED_OBJECT_KEY] = shared_object_id

  saving_scope = _shared_object_saving_scope()
  if saving_scope is None or obj is None:
    return serialized

  # With an active scope, re-use the config if this object was serialized
  # before. Re-use stores an extra ID field in the config, which lets the
  # shared object relationship be re-created at load time.
  existing_config = saving_scope.get_config(obj)
  if existing_config is not None:
    return existing_config
  return saving_scope.create_config(serialized, obj)
342
343
@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  The returned decorator adds the decorated class or function to the Keras
  custom object dictionary, so that it can be serialized and deserialized
  without an entry in the user-provided custom object dict. It also records
  the string key that Keras uses to serialize the object.

  Classes must implement the `get_config()` method to be registered.
  Functions do not have this requirement.

  The object is registered under the key 'package>name', where `name`
  defaults to the object name if not passed.

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class' name will be used.

  Returns:
    A decorator that registers the decorated class with the passed names.
  """

  def decorator(arg):
    """Registers `arg` under 'package>name' and returns it unchanged."""
    class_name = arg.__name__ if name is None else name
    registered_name = package + '>' + class_name

    is_class_without_config = (
        tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'))
    if is_class_without_config:
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Neither the key nor the object may be registered twice.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      previous = _GLOBAL_CUSTOM_OBJECTS[registered_name]
      raise ValueError(
          '%s has already been registered to %s' % (registered_name, previous))

    if arg in _GLOBAL_CUSTOM_NAMES:
      raise ValueError('%s has already been registered to %s' %
                       (arg, _GLOBAL_CUSTOM_NAMES[arg]))

    # Record both directions: object -> name and name -> object.
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    return arg

  return decorator
391
392
@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Looks up the serialization name of an object.

  This function is part of the Keras serialization and deserialization
  framework. It maps objects to the string names used for them during
  serialization/deserialization.

  Args:
    obj: The object to look up.

  Returns:
    The name the object was registered under, or `obj.__name__` if the object
      is not registered.
  """
  return (_GLOBAL_CUSTOM_NAMES[obj]
          if obj in _GLOBAL_CUSTOM_NAMES else obj.__name__)
412
413
@tf_contextlib.contextmanager
def skip_failed_serialization():
  """Context manager that turns on `_SKIP_FAILED_SERIALIZATION`.

  While the flag is set, `serialize_keras_object` returns a placeholder config
  instead of re-raising a `NotImplementedError` from `get_config`. The flag's
  previous value is restored on exit, even if the body raises.
  """
  global _SKIP_FAILED_SERIALIZATION
  previous_flag = _SKIP_FAILED_SERIALIZATION
  _SKIP_FAILED_SERIALIZATION = True
  try:
    yield
  finally:
    _SKIP_FAILED_SERIALIZATION = previous_flag
423
424
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the object associated with `name`, if Keras knows about it.

  This function is part of the Keras serialization and deserialization
  framework. It maps strings to the objects associated with them for
  serialization/deserialization.

  The global custom object registry is searched first, then `custom_objects`,
  then `module_objects`.

  Example:
  ```
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    The object associated with 'name', or None if no such object exists.
  """
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]
  # The optional dictionaries are skipped when missing or empty.
  for registry in (custom_objects, module_objects):
    if registry and name in registry:
      return registry[name]
  return None
459
460
# pylint: disable=g-bad-exception-name
class CustomMaskWarning(Warning):
  """Warning category for mask-supporting objects lacking a custom config."""
# pylint: enable=g-bad-exception-name
465
466
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Calls made while underneath the `SharedObjectSavingScope` context manager
  cause any objects re-used across multiple layers to be saved with a special
  shared object ID, which allows the network to be re-created properly during
  deserialization.

  Args:
    instance: The object to serialize.

  Returns:
    `None` for `None`; a dict-like, JSON-compatible representation for objects
    that have `get_config`; the registered (or `__name__`) string for other
    named objects such as functions.

  Raises:
    ValueError: if the object has neither `get_config` nor `__name__`.
    NotImplementedError: if `get_config` raises it and failed serialization is
      not being skipped.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  # For v1 layers, checking supports_masking is not enough. We have to also
  # check whether compute_mask has been overridden.
  uses_mask = getattr(instance, 'supports_masking', False)
  if not uses_mask and hasattr(instance, 'compute_mask'):
    uses_mask = not is_default(instance.compute_mask)
  if uses_mask and is_default(instance.get_config):
    warnings.warn('Custom mask layers require a config and must override '
                  'get_config. When loading, the custom mask layer must be '
                  'passed to the custom_objects argument.',
                  category=CustomMaskWarning)

  if not hasattr(instance, 'get_config'):
    # Objects without a config (e.g. functions) are serialized by name.
    if hasattr(instance, '__name__'):
      return get_registered_name(instance)
    raise ValueError('Cannot serialize', instance)

  name = get_registered_name(instance.__class__)
  try:
    config = instance.get_config()
  except NotImplementedError as e:
    if _SKIP_FAILED_SERIALIZATION:
      return serialize_keras_class_and_config(
          name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
    raise e

  serialization_config = {}
  for key, item in config.items():
    if isinstance(item, str):
      serialization_config[key] = item
      continue

    # Any object of a different type needs to be converted to string or dict
    # for serialization (e.g. custom functions, custom classes)
    try:
      serialized_item = serialize_keras_object(item)
    except ValueError:
      # Not serializable as a Keras object: keep the raw value.
      serialization_config[key] = item
      continue
    # Mark dicts produced from non-dict values so that deserialization knows
    # to turn them back into objects.
    if isinstance(serialized_item, dict) and not isinstance(item, dict):
      serialized_item['__passive_serialization__'] = True
    serialization_config[key] = serialized_item

  return serialize_keras_class_and_config(
      name, serialization_config, instance)
531
532
def get_custom_objects_by_name(item, custom_objects=None):
  """Looks `item` up in the global, then the local, custom objects.

  Args:
    item: The name to look up.
    custom_objects: Optional dictionary of local custom objects.

  Returns:
    The matching object, or `None` if `item` is in neither dictionary.
  """
  if item in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[item]
  if custom_objects and item in custom_objects:
    return custom_objects[item]
  return None
540
541
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Resolves a serialized Keras object into its class and class config.

  Note that `config['config']` is updated in place: entries that were
  passively serialized, or that name a registered custom function, are
  replaced by the deserialized objects.

  Args:
    config: Dict with 'class_name' and 'config' keys.
    module_objects: Optional dictionary to look the class name up in.
    custom_objects: Optional dictionary of custom objects to look names up in.
    printable_module_name: Human-readable object type used in error messages.

  Returns:
    A `(cls, cls_config)` tuple.

  Raises:
    ValueError: if `config` is malformed or its class name is unknown.
  """
  is_well_formed = (isinstance(config, dict)
                    and 'class_name' in config
                    and 'config' in config)
  if not is_well_formed:
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = get_registered_object(class_name, custom_objects, module_objects)
  if cls is None:
    raise ValueError(
        'Unknown {}: {}. Please ensure this object is '
        'passed to the `custom_objects` argument. See '
        'https://www.tensorflow.org/guide/keras/save_and_serialize'
        '#registering_the_custom_object for details.'
        .format(printable_module_name, class_name))

  cls_config = config['config']
  # A list config is returned as-is, together with the class, so that it can
  # be deserialized recursively. This happens for old versions of the
  # sequential model (e.g. `keras_version` == "2.0.6"), which are serialized
  # in a different structure, for example
  # "{'class_name': 'Sequential',
  #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
  if isinstance(cls_config, list):
    return (cls, cls_config)

  # Replacements are collected first and applied afterwards, so the config is
  # not modified while it is being iterated.
  replacements = {}
  for key, item in cls_config.items():
    if key == 'name':
      # Assume that the value of 'name' is a string that should not be
      # deserialized as a function. This avoids the corner case where
      # cls_config['name'] has an identical name to a custom function and
      # gets converted into that function.
      replacements[key] = item
    elif isinstance(item, dict) and '__passive_serialization__' in item:
      replacements[key] = deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    # TODO(momernick): Should this also have 'module_objects'?
    elif isinstance(item, str):
      # Handle custom functions here. Functions are saved by name only, so a
      # string that matches a registered function is converted back to that
      # function.
      # Note that a string field could have a naming conflict with a custom
      # function name, but this should be a rare case. There is no such issue
      # for a conflict with a custom object, since the config of an object
      # will always be a dict.
      candidate = get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(candidate):
        replacements[key] = candidate

  for key, value in replacements.items():
    cls_config[key] = value

  return (cls, cls_config)
603
604
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  This function is for mid-level library implementers rather than end users.

  Importantly, this utility requires you to provide the dict of `module_objects`
  to use for looking up the object config; this is not populated by default.
  If you need a deserialization utility that has preexisting knowledge of
  built-in Keras objects, use e.g. `keras.layers.deserialize(config)`,
  `keras.metrics.deserialize(config)`, etc.

  Calling `deserialize_keras_object` while underneath the
  `SharedObjectLoadingScope` context manager will cause any already-seen shared
  objects to be returned as-is rather than creating a new object.

  Args:
    identifier: the serialized form of the object. May be `None`, a config
      dict, a string name, or an already-deserialized function.
    module_objects: A dictionary of built-in objects to look the name up in.
      Generally, `module_objects` is provided by midlevel library implementers.
      `None` is treated as an empty dictionary.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, `custom_objects` is provided by the end user.
    printable_module_name: A human-readable string representing the type of the
      object. Printed in case of exception.

  Returns:
    The deserialized object.

  Raises:
    ValueError: if a config dict is malformed, if a class or object name cannot
      be resolved, or if `identifier` is of an unsupported type.

  Example:

  A mid-level library implementer might want to implement a utility for
  retrieving an object from its config, as such:

  ```python
  def deserialize(config, custom_objects=None):
     return deserialize_keras_object(
       config,
       module_objects=globals(),
       custom_objects=custom_objects,
       printable_module_name="MyObjectType",
     )
  ```

  This is how e.g. `keras.layers.deserialize()` is implemented.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    # If this object has already been loaded (i.e. it's shared between multiple
    # objects), return the already-loaded object.
    shared_object_id = config.get(SHARED_OBJECT_KEY)
    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
    if shared_object is not None:
      return shared_object

    custom_objects = custom_objects or {}
    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)

      if 'custom_objects' in arg_spec.args:
        # `from_config` takes the custom objects explicitly: pass the global
        # registry merged with (and overridden by) the caller's objects.
        deserialized_obj = cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      else:
        with CustomObjectScope(custom_objects):
          deserialized_obj = cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      with CustomObjectScope(custom_objects):
        deserialized_obj = cls(**cls_config)

    # Add object to shared objects, in case we find it referenced again.
    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

    return deserialized_obj

  elif isinstance(identifier, str):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      # `module_objects` defaults to `None`. Treat that as "no built-in
      # objects" so an unknown name produces the descriptive ValueError below
      # instead of an AttributeError on `None.get`.
      obj = (module_objects or {}).get(object_name)
      if obj is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'
            .format(printable_module_name, object_name))

    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
722
723
def func_dump(func):
  """Serializes a user defined function.

  The function's code object is marshalled and base64-encoded into an ASCII
  string. Defaults and closure cell contents are returned as they are.

  Args:
      func: the function to serialize.

  Returns:
      A tuple `(code, defaults, closure)`. `closure` is `None` when the
      function has no closure.
  """
  raw_code = marshal.dumps(func.__code__)
  if os.name == 'nt':
    # On Windows, backslash bytes in the marshalled code are replaced by
    # forward slashes before encoding.
    raw_code = raw_code.replace(b'\\', b'/')
  code = codecs.encode(raw_code, 'base64').decode('ascii')

  cells = func.__closure__
  closure = tuple(cell.cell_contents for cell in cells) if cells else None
  return code, func.__defaults__, closure
745
746
def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user defined function.

  Args:
      code: base64-encoded, marshalled bytecode of the function as a string,
        or a `(code, defaults, closure)` tuple/list as produced by a previous
        dump.
      defaults: defaults of the function, as a tuple or list, or `None`.
      closure: closure of the function, as an iterable of values or cells, or
        `None`.
      globs: dictionary of global objects. Defaults to this module's globals.

  Returns:
      A function object.
  """
  if isinstance(code, (tuple, list)):  # unpack previous dump
    code, defaults, closure = code
  # `FunctionType` only accepts a tuple (or None) for `argdefs`. Defaults that
  # went through JSON come back as a list, so convert them no matter whether
  # they arrived packed with `code` or as the `defaults` argument.
  if isinstance(defaults, list):
    defaults = tuple(defaults)

  def ensure_value_to_cell(value):
    """Ensures that a value is converted to a python cell object.

    Args:
        value: Any value that needs to be casted to the cell type

    Returns:
        A value wrapped as a cell object (see function "func_load")
    """

    def dummy_fn():
      # pylint: disable=pointless-statement
      value  # just access it so it gets captured in .__closure__

    # `dummy_fn` closes over `value`, so its only closure cell holds it.
    cell_value = dummy_fn.__closure__[0]
    if isinstance(value, type(cell_value)):
      return value  # already a cell
    return cell_value

  if closure is not None:
    closure = tuple(ensure_value_to_cell(item) for item in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    # Not ASCII base64: fall back to the string's raw bytes.
    raw_code = code.encode('raw_unicode_escape')
  # NOTE(security): the `marshal` module is not intended to be secure against
  # erroneous or maliciously constructed data, and the resulting function runs
  # arbitrary bytecode. Only load functions from trusted sources.
  code = marshal.loads(raw_code)
  if globs is None:
    globs = globals()
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)
794
795
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Args:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: Whether a `**kwargs` parameter on `fn` counts as accepting
        `name` when there is no parameter called `name`.

  Returns:
      bool, whether `fn` accepts a `name` keyword argument.
  """
  spec = tf_inspect.getfullargspec(fn)
  if name in spec.args or name in spec.kwonlyargs:
    return True
  # No parameter with that name: only a catch-all `**kwargs` can accept it.
  return bool(accept_all and spec.varkw is not None)
812
813
814@keras_export('keras.utils.Progbar')
class Progbar(object):
  """Displays a progress bar.

  Args:
      target: Total number of steps expected, None if unknown.
      width: Progress bar width on screen.
      verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
      interval: Minimum visual progress update interval (in seconds).
      stateful_metrics: Iterable of string names of metrics that should *not* be
        averaged over time. Metrics in this list will be displayed as-is. All
        others will be averaged by the progbar before display.
      unit_name: Display name for step counts (usually "step" or "sample").
  """

  def __init__(self,
               target,
               width=30,
               verbose=1,
               interval=0.05,
               stateful_metrics=None,
               unit_name='step'):
    self.target = target
    self.width = width
    self.verbose = verbose
    self.interval = interval
    self.unit_name = unit_name
    self.stateful_metrics = set(stateful_metrics) if stateful_metrics else set()

    # When any of these hold, `update` redraws the bar in place (backspaces
    # plus carriage return); otherwise every refresh starts on a new line.
    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                              sys.stdout.isatty()) or
                             'ipykernel' in sys.modules or
                             'posix' in sys.modules or
                             'PYCHARM_HOSTED' in os.environ)
    self._total_width = 0
    self._seen_so_far = 0
    # We use a dict + list to avoid garbage collection
    # issues found in OrderedDict
    self._values = {}
    self._values_order = []
    self._start = time.time()
    self._last_update = 0

    self._time_after_first_step = None

  def update(self, current, values=None, finalize=None):
    """Updates the progress bar.

    Args:
        current: Index of current step.
        values: List of tuples: `(name, value_for_last_step)`. If `name` is in
          `stateful_metrics`, `value_for_last_step` will be displayed as-is.
          Else, an average of the metric over time will be displayed.
        finalize: Whether this is the last update for the progress bar. If
          `None`, defaults to `current >= self.target`.
    """
    if finalize is None:
      if self.target is None:
        finalize = False
      else:
        finalize = current >= self.target

    values = values or []
    for k, v in values:
      if k not in self._values_order:
        self._values_order.append(k)
      if k not in self.stateful_metrics:
        # In the case that progress bar doesn't have a target value in the first
        # epoch, both on_batch_end and on_epoch_end will be called, which will
        # cause 'current' and 'self._seen_so_far' to have the same value. Force
        # the minimal value to 1 here, otherwise stateful_metric will be 0s.
        value_base = max(current - self._seen_so_far, 1)
        if k not in self._values:
          self._values[k] = [v * value_base, value_base]
        else:
          self._values[k][0] += v * value_base
          self._values[k][1] += value_base
      else:
        # Stateful metrics output a numeric value. This representation
        # means "take an average from a single value" but keeps the
        # numeric formatting.
        self._values[k] = [v, 1]
    self._seen_so_far = current

    now = time.time()
    info = ' - %.0fs' % (now - self._start)
    if self.verbose == 1:
      # Throttle redraws, but never drop the final one.
      if now - self._last_update < self.interval and not finalize:
        return

      prev_total_width = self._total_width
      if self._dynamic_display:
        sys.stdout.write('\b' * prev_total_width)
        sys.stdout.write('\r')
      else:
        sys.stdout.write('\n')

      if self.target is not None:
        numdigits = int(np.log10(self.target)) + 1
        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
        prog = float(current) / self.target
        prog_width = int(self.width * prog)
        if prog_width > 0:
          bar += ('=' * (prog_width - 1))
          if current < self.target:
            bar += '>'
          else:
            bar += '='
        bar += ('.' * (self.width - prog_width))
        bar += ']'
      else:
        bar = '%7d/Unknown' % current

      self._total_width = len(bar)
      sys.stdout.write(bar)

      time_per_unit = self._estimate_step_duration(current, now)

      if self.target is None or finalize:
        # No ETA to show: report the per-unit duration in a readable unit.
        if time_per_unit >= 1 or time_per_unit == 0:
          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
        elif time_per_unit >= 1e-3:
          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
        else:
          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
      else:
        eta = time_per_unit * (self.target - current)
        if eta > 3600:
          eta_format = '%d:%02d:%02d' % (eta // 3600,
                                         (eta % 3600) // 60, eta % 60)
        elif eta > 60:
          eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
          eta_format = '%ds' % eta

        info = ' - ETA: %s' % eta_format

      for k in self._values_order:
        info += ' - %s:' % k
        info += self._format_metric(k)

      self._total_width += len(info)
      if prev_total_width > self._total_width:
        # Pad with spaces so leftovers of a longer previous line are erased.
        info += (' ' * (prev_total_width - self._total_width))

      if finalize:
        info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

    elif self.verbose == 2:
      if finalize:
        if self.target is not None:
          numdigits = int(np.log10(self.target)) + 1
          count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
        else:
          # An unknown target used to crash in `np.log10(None)`; print the same
          # counter that verbose mode 1 uses for that case instead.
          count = '%7d/Unknown' % current
        info = count + info
        for k in self._values_order:
          info += ' - %s:' % k
          info += self._format_metric(k)
        info += '\n'

        sys.stdout.write(info)
        sys.stdout.flush()

    self._last_update = now

  def add(self, n, values=None):
    """Advances the bar by `n` steps relative to the last seen step.

    Args:
        n: Number of steps to add to the current position.
        values: Same meaning as the `values` argument of `update`.
    """
    self.update(self._seen_so_far + n, values)

  def _format_metric(self, name):
    """Returns the display text (with a leading space) for metric `name`.

    Shared by both verbosity modes so that they format values identically:
    fixed-point when the magnitude exceeds 1e-3, scientific notation otherwise.
    The magnitude test uses `abs` so negative averages are not forced into
    scientific notation.

    Args:
        name: Key into `self._values`.

    Returns:
        A string such as `' 0.1234'` or `' 1.0000e-05'`.
    """
    entry = self._values[name]
    if not isinstance(entry, list):
      return ' %s' % entry
    avg = np.mean(entry[0] / max(1, entry[1]))
    if abs(avg) > 1e-3:
      return ' %.4f' % avg
    return ' %.4e' % avg

  def _estimate_step_duration(self, current, now):
    """Estimate the duration of a single step.

    Given the step number `current` and the corresponding time `now`
    this function returns an estimate for how long a single step
    takes. If this is called before one step has been completed
    (i.e. `current == 0`) then zero is given as an estimate. The duration
    estimate ignores the duration of the (assumed to be non-representative)
    first step for estimates when more steps are available (i.e. `current>1`).

    Args:
      current: Index of current step.
      now: The current time.

    Returns:
      Estimate of the duration of a single step.
    """
    if current:
      # there are a few special scenarios here:
      # 1) somebody is calling the progress bar without ever supplying step 1
      # 2) somebody is calling the progress bar and supplies step one multiple
      #    times, e.g. as part of a finalizing call
      # in these cases, we just fall back to the simple calculation
      if self._time_after_first_step is not None and current > 1:
        time_per_unit = (now - self._time_after_first_step) / (current - 1)
      else:
        time_per_unit = (now - self._start) / current

      if current == 1:
        self._time_after_first_step = now
      return time_per_unit
    else:
      return 0

  def _update_stateful_metrics(self, stateful_metrics):
    """Adds the given metric names to the set of stateful metrics."""
    self.stateful_metrics = self.stateful_metrics.union(stateful_metrics)
1030
1031
def make_batches(size, batch_size):
  """Splits the range `[0, size)` into consecutive `(start, end)` index pairs.

  Every pair spans `batch_size` indices except possibly the last one, whose
  end is clipped to `size`.

  Args:
      size: Integer, total size of the data to slice into batches.
      batch_size: Integer, batch size.

  Returns:
      A list of `(start, end)` tuples of array indices.
  """
  batch_count = int(np.ceil(size / float(batch_size)))
  batches = []
  for index in range(batch_count):
    begin = index * batch_size
    # Clip the final batch so that it never runs past the data.
    end = min(size, (index + 1) * batch_size)
    batches.append((begin, end))
  return batches
1045
1046
def slice_arrays(arrays, start=None, stop=None):
  """Slice an array or list of arrays.

  This takes an array-like, or a list of
  array-likes, and outputs:
      - arrays[start:stop] if `arrays` is an array-like
      - [x[start:stop] for x in arrays] if `arrays` is a list

  Can also work on list/array of indices: `slice_arrays(x, indices)`

  Args:
      arrays: Single array or list of arrays.
      start: can be an integer index (start index) or a list/array of indices
      stop: integer (stop index); should be None if `start` was a list.

  Returns:
      A slice of the array(s). `[None]` is returned when `arrays` is None or
      is a single object that does not support indexing. When `arrays` is a
      list, entries that are None or not indexable come back as None.

  Raises:
      ValueError: If the value of start is a list and stop is not None.
  """
  if arrays is None:
    return [None]
  if isinstance(start, list) and stop is not None:
    raise ValueError('The stop argument has to be None if the value of start '
                     'is a list.')

  # Anything with a length is treated as a collection of indices rather than
  # as the lower bound of a `start:stop` range.
  use_indices = hasattr(start, '__len__')
  if use_indices and hasattr(start, 'shape'):
    # hdf5 datasets only support list objects as indices
    start = start.tolist()

  if isinstance(arrays, list):
    if use_indices:
      return [None if x is None else x[start] for x in arrays]
    # None has no `__getitem__`, so this also maps None entries to None.
    return [
        x[start:stop] if hasattr(x, '__getitem__') else None for x in arrays
    ]

  if use_indices:
    return arrays[start]
  # The sliceability check must look at `arrays`; checking `start` here made
  # every plain integer range fall through to `[None]`.
  if hasattr(arrays, '__getitem__'):
    return arrays[start:stop]
  return [None]
1091
1092
def to_list(x):
  """Wraps a non-list value in a single-element list.

  A `list` instance is handed back unchanged (the same object, not a copy).
  Anything else, including a tensor or a tuple, becomes the only element of
  a new list.

  Args:
      x: target object to be normalized.

  Returns:
      A list.
  """
  return x if isinstance(x, list) else [x]
1108
1109
def to_snake_case(name):
  """Converts a CamelCase name to snake_case.

  Args:
      name: String, typically a class name such as `'MyLayerName'`.

  Returns:
      The lower-cased name with underscores inserted at word boundaries, e.g.
      `'my_layer_name'`. A result that would start with `'_'` is prefixed with
      `'private'`. An empty `name` yields an empty string.
  """
  # First split before a capitalized word ("HTTPServer" -> "HTTP_Server"),
  # then between a lowercase letter and the capital that follows it.
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  # `startswith` (rather than indexing) keeps an empty name from raising.
  if insecure.startswith('_'):
    return 'private' + insecure
  return insecure
1118
1119
def is_all_none(structure):
  """Returns True when every leaf of the flattened `structure` is None."""
  # Test identity against None instead of relying on element truthiness,
  # because the flattened structure may contain Tensors.
  return all(leaf is None for leaf in nest.flatten(structure))
1127
1128
def check_for_unexpected_keys(name, input_dict, expected_values):
  """Raises if `input_dict` has keys that are not in `expected_values`.

  Args:
      name: Label for the dictionary, used only in the error message.
      input_dict: Dictionary whose keys are validated.
      expected_values: Iterable of the allowed keys.

  Raises:
      ValueError: If at least one key of `input_dict` is not expected.
  """
  unknown = set(input_dict.keys()).difference(expected_values)
  if not unknown:
    return
  message = ('Unknown entries in {} dictionary: {}. Only expected '
             'following keys: {}')
  raise ValueError(message.format(name, list(unknown), expected_values))
1135
1136
def validate_kwargs(kwargs,
                    allowed_kwargs,
                    error_message='Keyword argument not understood:'):
  """Verifies that each keyword in `kwargs` belongs to `allowed_kwargs`.

  Args:
      kwargs: Iterable of keyword names (for example a kwargs dict).
      allowed_kwargs: Container of the accepted keyword names.
      error_message: First argument of the raised exception.

  Raises:
      TypeError: On the first keyword that is not allowed. The exception is
        built from `error_message` and the offending keyword.
  """
  for keyword in kwargs:
    if keyword in allowed_kwargs:
      continue
    raise TypeError(error_message, keyword)
1144
1145
def validate_config(config):
  """Tells whether `config` looks like a usable layer config.

  Args:
      config: Object to check.

  Returns:
      True if `config` is a dict that does not contain the
      `_LAYER_UNDEFINED_CONFIG_KEY` marker key, False otherwise.
  """
  if not isinstance(config, dict):
    return False
  return _LAYER_UNDEFINED_CONFIG_KEY not in config
1149
1150
def default(method):
  """Marks `method` as a default implementation and returns it.

  The mark is an `_is_default` attribute set to True on the method object,
  which `is_default` reads back to detect overrides in subclasses.
  """
  setattr(method, '_is_default', True)
  return method
1155
1156
def is_default(method):
  """Returns the `_is_default` mark of `method`, or False when it has none."""
  try:
    return method._is_default  # pylint: disable=protected-access
  except AttributeError:
    return False
1160
1161
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
  """Copies the filtered attributes of `modules` into `target_dict`.

  Args:
      target_dict: Dictionary updated in place, keyed by attribute name.
      modules: Iterable of modules (or other objects) to scan with `dir`.
      obj_filter: Callable taking an attribute value; the attribute is stored
        only when it returns a truthy value.
  """
  for module in modules:
    for attr_name in dir(module):
      candidate = getattr(module, attr_name)
      if not obj_filter(candidate):
        continue
      # Later modules overwrite earlier entries that share a name.
      target_dict[attr_name] = candidate
1168
1169
class LazyLoader(python_types.ModuleType):
  """Module stand-in that postpones the real import until it is needed.

  The import happens on the first attribute lookup that the placeholder
  cannot answer itself. This is mainly used to avoid pulling in large
  dependencies.

  Args:
      local_name: Name under which the module is bound in the parent module.
      parent_module_globals: Globals dictionary of the parent module; the real
        module replaces the placeholder there once loaded.
      name: Fully qualified name of the module to import.
  """

  def __init__(self, local_name, parent_module_globals, name):
    self._local_name = local_name
    self._parent_module_globals = parent_module_globals
    super(LazyLoader, self).__init__(name)

  def _load(self):
    """Imports the real module and binds it in the parent's globals."""
    loaded = importlib.import_module(self.__name__)
    # Swap the placeholder for the real module in the parent's namespace.
    self._parent_module_globals[self._local_name] = loaded
    # Mirror the module's attributes on this object so that anyone still
    #   holding the LazyLoader gets fast lookups (__getattr__ is only called
    #   on lookups that fail).
    self.__dict__.update(loaded.__dict__)
    return loaded

  def __getattr__(self, item):
    return getattr(self._load(), item)
1192
1193
# Aliases

# Lowercase name bound to the same `CustomObjectScope` class object.
custom_object_scope = CustomObjectScope  # pylint: disable=invalid-name
1197