# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python utilities required by Keras."""

import binascii
import codecs
import importlib
import marshal
import os
import re
import sys
import threading
import time
import types as python_types
import warnings
import weakref

import numpy as np

from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import keras_export

# Name -> object registry consulted during deserialization.
_GLOBAL_CUSTOM_OBJECTS = {}
# Object -> registered name; the inverse of the entries added by
# `register_keras_serializable`.
_GLOBAL_CUSTOM_NAMES = {}

# When True, a `NotImplementedError` raised by `get_config` on a custom model
# or layer is tolerated instead of propagated. It is only switched on while
# saving to SavedModel, where the config is not required.
_SKIP_FAILED_SERIALIZATION = False
# Key placed in the stand-in config of a layer that has no usable config.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'


@keras_export('keras.utils.custom_object_scope',  # pylint: disable=g-classes-have-attributes
              'keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Context manager that makes custom objects visible to Keras loading code.

  While the scope is active, the given `{name: object}` entries are merged into
  the global custom-object registry, so functions such as
  `tf.keras.models.load_model` or `tf.keras.models.model_from_config` can
  resolve custom objects (e.g. a custom layer or metric) that a saved config
  refers to by name. On exit the registry is restored to its prior contents.

  Example:

  Consider a custom regularizer `my_regularizer`:

  ```python
  layer = Dense(3, kernel_regularizer=my_regularizer)
  config = layer.get_config()  # Config contains a reference to `my_regularizer`
  ...
  # Later:
  with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)
  ```

  Args:
      *args: One or more dictionaries of `{name: object}` pairs.
  """

  def __init__(self, *args):
    # Each positional argument is a mapping merged into the registry on entry.
    self.custom_objects = args
    # Snapshot of the registry taken in __enter__; None until then.
    self.backup = None

  def __enter__(self):
    self.backup = dict(_GLOBAL_CUSTOM_OBJECTS)
    for extra_objects in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(extra_objects)
    return self

  def __exit__(self, *args, **kwargs):
    # Mutate the registry in place (rather than rebinding it) because
    # `get_custom_objects` hands out a live reference to this same dict.
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)


@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Returns the live global dictionary of custom objects.

  `custom_object_scope` is the preferred way to add and remove custom objects,
  but this accessor exposes the underlying registry for direct manipulation.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      The global dictionary mapping names to classes
      (`_GLOBAL_CUSTOM_OBJECTS`); changes to it take effect immediately.
  """
  registry = _GLOBAL_CUSTOM_OBJECTS
  return registry


# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly.
Without this ID, we would have no way of 117# determining whether a config is a description of a new object that 118# should be created or is merely a reference to an already-created object. 119SHARED_OBJECT_KEY = 'shared_object_id' 120 121 122SHARED_OBJECT_DISABLED = threading.local() 123SHARED_OBJECT_LOADING = threading.local() 124SHARED_OBJECT_SAVING = threading.local() 125 126 127# Attributes on the threadlocal variable must be set per-thread, thus we 128# cannot initialize these globally. Instead, we have accessor functions with 129# default values. 130def _shared_object_disabled(): 131 """Get whether shared object handling is disabled in a threadsafe manner.""" 132 return getattr(SHARED_OBJECT_DISABLED, 'disabled', False) 133 134 135def _shared_object_loading_scope(): 136 """Get the current shared object saving scope in a threadsafe manner.""" 137 return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope()) 138 139 140def _shared_object_saving_scope(): 141 """Get the current shared object saving scope in a threadsafe manner.""" 142 return getattr(SHARED_OBJECT_SAVING, 'scope', None) 143 144 145class DisableSharedObjectScope(object): 146 """A context manager for disabling handling of shared objects. 147 148 Disables shared object handling for both saving and loading. 149 150 Created primarily for use with `clone_model`, which does extra surgery that 151 is incompatible with shared objects. 152 """ 153 154 def __enter__(self): 155 SHARED_OBJECT_DISABLED.disabled = True 156 self._orig_loading_scope = _shared_object_loading_scope() 157 self._orig_saving_scope = _shared_object_saving_scope() 158 159 def __exit__(self, *args, **kwargs): 160 SHARED_OBJECT_DISABLED.disabled = False 161 SHARED_OBJECT_LOADING.scope = self._orig_loading_scope 162 SHARED_OBJECT_SAVING.scope = self._orig_saving_scope 163 164 165class NoopLoadingScope(object): 166 """The default shared object loading scope. It does nothing. 
167 168 Created to simplify serialization code that doesn't care about shared objects 169 (e.g. when serializing a single object). 170 """ 171 172 def get(self, unused_object_id): 173 return None 174 175 def set(self, object_id, obj): 176 pass 177 178 179class SharedObjectLoadingScope(object): 180 """A context manager for keeping track of loaded objects. 181 182 During the deserialization process, we may come across objects that are 183 shared across multiple layers. In order to accurately restore the network 184 structure to its original state, `SharedObjectLoadingScope` allows us to 185 re-use shared objects rather than cloning them. 186 """ 187 188 def __enter__(self): 189 if _shared_object_disabled(): 190 return NoopLoadingScope() 191 192 global SHARED_OBJECT_LOADING 193 SHARED_OBJECT_LOADING.scope = self 194 self._obj_ids_to_obj = {} 195 return self 196 197 def get(self, object_id): 198 """Given a shared object ID, returns a previously instantiated object. 199 200 Args: 201 object_id: shared object ID to use when attempting to find already-loaded 202 object. 203 204 Returns: 205 The object, if we've seen this ID before. Else, `None`. 206 """ 207 # Explicitly check for `None` internally to make external calling code a 208 # bit cleaner. 209 if object_id is None: 210 return 211 return self._obj_ids_to_obj.get(object_id) 212 213 def set(self, object_id, obj): 214 """Stores an instantiated object for future lookup and sharing.""" 215 if object_id is None: 216 return 217 self._obj_ids_to_obj[object_id] = obj 218 219 def __exit__(self, *args, **kwargs): 220 global SHARED_OBJECT_LOADING 221 SHARED_OBJECT_LOADING.scope = NoopLoadingScope() 222 223 224class SharedObjectConfig(dict): 225 """A configuration container that keeps track of references. 226 227 `SharedObjectConfig` will automatically attach a shared object ID to any 228 configs which are referenced more than once, allowing for proper shared 229 object reconstruction at load time. 
230 231 In most cases, it would be more proper to subclass something like 232 `collections.UserDict` or `collections.Mapping` rather than `dict` directly. 233 Unfortunately, python's json encoder does not support `Mapping`s. This is 234 important functionality to retain, since we are dealing with serialization. 235 236 We should be safe to subclass `dict` here, since we aren't actually 237 overriding any core methods, only augmenting with a new one for reference 238 counting. 239 """ 240 241 def __init__(self, base_config, object_id, **kwargs): 242 self.ref_count = 1 243 self.object_id = object_id 244 super(SharedObjectConfig, self).__init__(base_config, **kwargs) 245 246 def increment_ref_count(self): 247 # As soon as we've seen the object more than once, we want to attach the 248 # shared object ID. This allows us to only attach the shared object ID when 249 # it's strictly necessary, making backwards compatibility breakage less 250 # likely. 251 if self.ref_count == 1: 252 self[SHARED_OBJECT_KEY] = self.object_id 253 self.ref_count += 1 254 255 256class SharedObjectSavingScope(object): 257 """Keeps track of shared object configs when serializing.""" 258 259 def __enter__(self): 260 if _shared_object_disabled(): 261 return None 262 263 global SHARED_OBJECT_SAVING 264 265 # Serialization can happen at a number of layers for a number of reasons. 266 # We may end up with a case where we're opening a saving scope within 267 # another saving scope. In that case, we'd like to use the outermost scope 268 # available and ignore inner scopes, since there is not (yet) a reasonable 269 # use case for having these nested and distinct. 
270 if _shared_object_saving_scope() is not None: 271 self._passthrough = True 272 return _shared_object_saving_scope() 273 else: 274 self._passthrough = False 275 276 SHARED_OBJECT_SAVING.scope = self 277 self._shared_objects_config = weakref.WeakKeyDictionary() 278 self._next_id = 0 279 return self 280 281 def get_config(self, obj): 282 """Gets a `SharedObjectConfig` if one has already been seen for `obj`. 283 284 Args: 285 obj: The object for which to retrieve the `SharedObjectConfig`. 286 287 Returns: 288 The SharedObjectConfig for a given object, if already seen. Else, 289 `None`. 290 """ 291 try: 292 shared_object_config = self._shared_objects_config[obj] 293 except (TypeError, KeyError): 294 # If the object is unhashable (e.g. a subclass of `AbstractBaseClass` 295 # that has not overridden `__hash__`), a `TypeError` will be thrown. 296 # We'll just continue on without shared object support. 297 return None 298 shared_object_config.increment_ref_count() 299 return shared_object_config 300 301 def create_config(self, base_config, obj): 302 """Create a new SharedObjectConfig for a given object.""" 303 shared_object_config = SharedObjectConfig(base_config, self._next_id) 304 self._next_id += 1 305 try: 306 self._shared_objects_config[obj] = shared_object_config 307 except TypeError: 308 # If the object is unhashable (e.g. a subclass of `AbstractBaseClass` 309 # that has not overridden `__hash__`), a `TypeError` will be thrown. 310 # We'll just continue on without shared object support. 
311 pass 312 return shared_object_config 313 314 def __exit__(self, *args, **kwargs): 315 if not getattr(self, '_passthrough', False): 316 global SHARED_OBJECT_SAVING 317 SHARED_OBJECT_SAVING.scope = None 318 319 320def serialize_keras_class_and_config( 321 cls_name, cls_config, obj=None, shared_object_id=None): 322 """Returns the serialization of the class with the given config.""" 323 base_config = {'class_name': cls_name, 'config': cls_config} 324 325 # We call `serialize_keras_class_and_config` for some branches of the load 326 # path. In that case, we may already have a shared object ID we'd like to 327 # retain. 328 if shared_object_id is not None: 329 base_config[SHARED_OBJECT_KEY] = shared_object_id 330 331 # If we have an active `SharedObjectSavingScope`, check whether we've already 332 # serialized this config. If so, just use that config. This will store an 333 # extra ID field in the config, allowing us to re-create the shared object 334 # relationship at load time. 335 if _shared_object_saving_scope() is not None and obj is not None: 336 shared_object_config = _shared_object_saving_scope().get_config(obj) 337 if shared_object_config is None: 338 return _shared_object_saving_scope().create_config(base_config, obj) 339 return shared_object_config 340 341 return base_config 342 343 344@keras_export('keras.utils.register_keras_serializable') 345def register_keras_serializable(package='Custom', name=None): 346 """Registers an object with the Keras serialization framework. 347 348 This decorator injects the decorated class or function into the Keras custom 349 object dictionary, so that it can be serialized and deserialized without 350 needing an entry in the user-provided custom object dict. It also injects a 351 function that Keras will call to get the object's serializable string key. 352 353 Note that to be serialized and deserialized, classes must implement the 354 `get_config()` method. Functions do not have this requirement. 
355 356 The object will be registered under the key 'package>name' where `name`, 357 defaults to the object name if not passed. 358 359 Args: 360 package: The package that this class belongs to. 361 name: The name to serialize this class under in this package. If None, the 362 class' name will be used. 363 364 Returns: 365 A decorator that registers the decorated class with the passed names. 366 """ 367 368 def decorator(arg): 369 """Registers a class with the Keras serialization framework.""" 370 class_name = name if name is not None else arg.__name__ 371 registered_name = package + '>' + class_name 372 373 if tf_inspect.isclass(arg) and not hasattr(arg, 'get_config'): 374 raise ValueError( 375 'Cannot register a class that does not have a get_config() method.') 376 377 if registered_name in _GLOBAL_CUSTOM_OBJECTS: 378 raise ValueError( 379 '%s has already been registered to %s' % 380 (registered_name, _GLOBAL_CUSTOM_OBJECTS[registered_name])) 381 382 if arg in _GLOBAL_CUSTOM_NAMES: 383 raise ValueError('%s has already been registered to %s' % 384 (arg, _GLOBAL_CUSTOM_NAMES[arg])) 385 _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg 386 _GLOBAL_CUSTOM_NAMES[arg] = registered_name 387 388 return arg 389 390 return decorator 391 392 393@keras_export('keras.utils.get_registered_name') 394def get_registered_name(obj): 395 """Returns the name registered to an object within the Keras framework. 396 397 This function is part of the Keras serialization and deserialization 398 framework. It maps objects to the string names associated with those objects 399 for serialization/deserialization. 400 401 Args: 402 obj: The object to look up. 403 404 Returns: 405 The name associated with the object, or the default Python name if the 406 object is not registered. 
407 """ 408 if obj in _GLOBAL_CUSTOM_NAMES: 409 return _GLOBAL_CUSTOM_NAMES[obj] 410 else: 411 return obj.__name__ 412 413 414@tf_contextlib.contextmanager 415def skip_failed_serialization(): 416 global _SKIP_FAILED_SERIALIZATION 417 prev = _SKIP_FAILED_SERIALIZATION 418 try: 419 _SKIP_FAILED_SERIALIZATION = True 420 yield 421 finally: 422 _SKIP_FAILED_SERIALIZATION = prev 423 424 425@keras_export('keras.utils.get_registered_object') 426def get_registered_object(name, custom_objects=None, module_objects=None): 427 """Returns the class associated with `name` if it is registered with Keras. 428 429 This function is part of the Keras serialization and deserialization 430 framework. It maps strings to the objects associated with them for 431 serialization/deserialization. 432 433 Example: 434 ``` 435 def from_config(cls, config, custom_objects=None): 436 if 'my_custom_object_name' in config: 437 config['hidden_cls'] = tf.keras.utils.get_registered_object( 438 config['my_custom_object_name'], custom_objects=custom_objects) 439 ``` 440 441 Args: 442 name: The name to look up. 443 custom_objects: A dictionary of custom objects to look the name up in. 444 Generally, custom_objects is provided by the user. 445 module_objects: A dictionary of custom objects to look the name up in. 446 Generally, module_objects is provided by midlevel library implementers. 447 448 Returns: 449 An instantiable class associated with 'name', or None if no such class 450 exists. 
451 """ 452 if name in _GLOBAL_CUSTOM_OBJECTS: 453 return _GLOBAL_CUSTOM_OBJECTS[name] 454 elif custom_objects and name in custom_objects: 455 return custom_objects[name] 456 elif module_objects and name in module_objects: 457 return module_objects[name] 458 return None 459 460 461# pylint: disable=g-bad-exception-name 462class CustomMaskWarning(Warning): 463 pass 464# pylint: enable=g-bad-exception-name 465 466 467@keras_export('keras.utils.serialize_keras_object') 468def serialize_keras_object(instance): 469 """Serialize a Keras object into a JSON-compatible representation. 470 471 Calls to `serialize_keras_object` while underneath the 472 `SharedObjectSavingScope` context manager will cause any objects re-used 473 across multiple layers to be saved with a special shared object ID. This 474 allows the network to be re-created properly during deserialization. 475 476 Args: 477 instance: The object to serialize. 478 479 Returns: 480 A dict-like, JSON-compatible representation of the object's config. 481 """ 482 _, instance = tf_decorator.unwrap(instance) 483 if instance is None: 484 return None 485 486 # pylint: disable=protected-access 487 # 488 # For v1 layers, checking supports_masking is not enough. We have to also 489 # check whether compute_mask has been overridden. 490 supports_masking = (getattr(instance, 'supports_masking', False) 491 or (hasattr(instance, 'compute_mask') 492 and not is_default(instance.compute_mask))) 493 if supports_masking and is_default(instance.get_config): 494 warnings.warn('Custom mask layers require a config and must override ' 495 'get_config. 
When loading, the custom mask layer must be ' 496 'passed to the custom_objects argument.', 497 category=CustomMaskWarning) 498 # pylint: enable=protected-access 499 500 if hasattr(instance, 'get_config'): 501 name = get_registered_name(instance.__class__) 502 try: 503 config = instance.get_config() 504 except NotImplementedError as e: 505 if _SKIP_FAILED_SERIALIZATION: 506 return serialize_keras_class_and_config( 507 name, {_LAYER_UNDEFINED_CONFIG_KEY: True}) 508 raise e 509 serialization_config = {} 510 for key, item in config.items(): 511 if isinstance(item, str): 512 serialization_config[key] = item 513 continue 514 515 # Any object of a different type needs to be converted to string or dict 516 # for serialization (e.g. custom functions, custom classes) 517 try: 518 serialized_item = serialize_keras_object(item) 519 if isinstance(serialized_item, dict) and not isinstance(item, dict): 520 serialized_item['__passive_serialization__'] = True 521 serialization_config[key] = serialized_item 522 except ValueError: 523 serialization_config[key] = item 524 525 name = get_registered_name(instance.__class__) 526 return serialize_keras_class_and_config( 527 name, serialization_config, instance) 528 if hasattr(instance, '__name__'): 529 return get_registered_name(instance) 530 raise ValueError('Cannot serialize', instance) 531 532 533def get_custom_objects_by_name(item, custom_objects=None): 534 """Returns the item if it is in either local or global custom objects.""" 535 if item in _GLOBAL_CUSTOM_OBJECTS: 536 return _GLOBAL_CUSTOM_OBJECTS[item] 537 elif custom_objects and item in custom_objects: 538 return custom_objects[item] 539 return None 540 541 542def class_and_config_for_serialized_keras_object( 543 config, 544 module_objects=None, 545 custom_objects=None, 546 printable_module_name='object'): 547 """Returns the class name and config for a serialized keras object.""" 548 if (not isinstance(config, dict) 549 or 'class_name' not in config 550 or 'config' not in 
config): 551 raise ValueError('Improper config format: ' + str(config)) 552 553 class_name = config['class_name'] 554 cls = get_registered_object(class_name, custom_objects, module_objects) 555 if cls is None: 556 raise ValueError( 557 'Unknown {}: {}. Please ensure this object is ' 558 'passed to the `custom_objects` argument. See ' 559 'https://www.tensorflow.org/guide/keras/save_and_serialize' 560 '#registering_the_custom_object for details.' 561 .format(printable_module_name, class_name)) 562 563 cls_config = config['config'] 564 # Check if `cls_config` is a list. If it is a list, return the class and the 565 # associated class configs for recursively deserialization. This case will 566 # happen on the old version of sequential model (e.g. `keras_version` == 567 # "2.0.6"), which is serialized in a different structure, for example 568 # "{'class_name': 'Sequential', 569 # 'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}". 570 if isinstance(cls_config, list): 571 return (cls, cls_config) 572 573 deserialized_objects = {} 574 for key, item in cls_config.items(): 575 if key == 'name': 576 # Assume that the value of 'name' is a string that should not be 577 # deserialized as a function. This avoids the corner case where 578 # cls_config['name'] has an identical name to a custom function and 579 # gets converted into that function. 580 deserialized_objects[key] = item 581 elif isinstance(item, dict) and '__passive_serialization__' in item: 582 deserialized_objects[key] = deserialize_keras_object( 583 item, 584 module_objects=module_objects, 585 custom_objects=custom_objects, 586 printable_module_name='config_item') 587 # TODO(momernick): Should this also have 'module_objects'? 588 elif (isinstance(item, str) and 589 tf_inspect.isfunction(get_registered_object(item, custom_objects))): 590 # Handle custom functions here. When saving functions, we only save the 591 # function's name as a string. 
If we find a matching string in the custom 592 # objects during deserialization, we convert the string back to the 593 # original function. 594 # Note that a potential issue is that a string field could have a naming 595 # conflict with a custom function name, but this should be a rare case. 596 # This issue does not occur if a string field has a naming conflict with 597 # a custom object, since the config of an object will always be a dict. 598 deserialized_objects[key] = get_registered_object(item, custom_objects) 599 for key, item in deserialized_objects.items(): 600 cls_config[key] = deserialized_objects[key] 601 602 return (cls, cls_config) 603 604 605@keras_export('keras.utils.deserialize_keras_object') 606def deserialize_keras_object(identifier, 607 module_objects=None, 608 custom_objects=None, 609 printable_module_name='object'): 610 """Turns the serialized form of a Keras object back into an actual object. 611 612 This function is for mid-level library implementers rather than end users. 613 614 Importantly, this utility requires you to provide the dict of `module_objects` 615 to use for looking up the object config; this is not populated by default. 616 If you need a deserialization utility that has preexisting knowledge of 617 built-in Keras objects, use e.g. `keras.layers.deserialize(config)`, 618 `keras.metrics.deserialize(config)`, etc. 619 620 Calling `deserialize_keras_object` while underneath the 621 `SharedObjectLoadingScope` context manager will cause any already-seen shared 622 objects to be returned as-is rather than creating a new object. 623 624 Args: 625 identifier: the serialized form of the object. 626 module_objects: A dictionary of built-in objects to look the name up in. 627 Generally, `module_objects` is provided by midlevel library implementers. 628 custom_objects: A dictionary of custom objects to look the name up in. 629 Generally, `custom_objects` is provided by the end user. 
630 printable_module_name: A human-readable string representing the type of the 631 object. Printed in case of exception. 632 633 Returns: 634 The deserialized object. 635 636 Example: 637 638 A mid-level library implementer might want to implement a utility for 639 retrieving an object from its config, as such: 640 641 ```python 642 def deserialize(config, custom_objects=None): 643 return deserialize_keras_object( 644 identifier, 645 module_objects=globals(), 646 custom_objects=custom_objects, 647 name="MyObjectType", 648 ) 649 ``` 650 651 This is how e.g. `keras.layers.deserialize()` is implemented. 652 """ 653 if identifier is None: 654 return None 655 656 if isinstance(identifier, dict): 657 # In this case we are dealing with a Keras config dictionary. 658 config = identifier 659 (cls, cls_config) = class_and_config_for_serialized_keras_object( 660 config, module_objects, custom_objects, printable_module_name) 661 662 # If this object has already been loaded (i.e. it's shared between multiple 663 # objects), return the already-loaded object. 664 shared_object_id = config.get(SHARED_OBJECT_KEY) 665 shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none 666 if shared_object is not None: 667 return shared_object 668 669 if hasattr(cls, 'from_config'): 670 arg_spec = tf_inspect.getfullargspec(cls.from_config) 671 custom_objects = custom_objects or {} 672 673 if 'custom_objects' in arg_spec.args: 674 deserialized_obj = cls.from_config( 675 cls_config, 676 custom_objects=dict( 677 list(_GLOBAL_CUSTOM_OBJECTS.items()) + 678 list(custom_objects.items()))) 679 else: 680 with CustomObjectScope(custom_objects): 681 deserialized_obj = cls.from_config(cls_config) 682 else: 683 # Then `cls` may be a function returning a class. 684 # in this case by convention `config` holds 685 # the kwargs of the function. 
686 custom_objects = custom_objects or {} 687 with CustomObjectScope(custom_objects): 688 deserialized_obj = cls(**cls_config) 689 690 # Add object to shared objects, in case we find it referenced again. 691 _shared_object_loading_scope().set(shared_object_id, deserialized_obj) 692 693 return deserialized_obj 694 695 elif isinstance(identifier, str): 696 object_name = identifier 697 if custom_objects and object_name in custom_objects: 698 obj = custom_objects.get(object_name) 699 elif object_name in _GLOBAL_CUSTOM_OBJECTS: 700 obj = _GLOBAL_CUSTOM_OBJECTS[object_name] 701 else: 702 obj = module_objects.get(object_name) 703 if obj is None: 704 raise ValueError( 705 'Unknown {}: {}. Please ensure this object is ' 706 'passed to the `custom_objects` argument. See ' 707 'https://www.tensorflow.org/guide/keras/save_and_serialize' 708 '#registering_the_custom_object for details.' 709 .format(printable_module_name, object_name)) 710 711 # Classes passed by name are instantiated with no args, functions are 712 # returned as-is. 713 if tf_inspect.isclass(obj): 714 return obj() 715 return obj 716 elif tf_inspect.isfunction(identifier): 717 # If a function has already been deserialized, return as is. 718 return identifier 719 else: 720 raise ValueError('Could not interpret serialized %s: %s' % 721 (printable_module_name, identifier)) 722 723 724def func_dump(func): 725 """Serializes a user defined function. 726 727 Args: 728 func: the function to serialize. 729 730 Returns: 731 A tuple `(code, defaults, closure)`. 
732 """ 733 if os.name == 'nt': 734 raw_code = marshal.dumps(func.__code__).replace(b'\\', b'/') 735 code = codecs.encode(raw_code, 'base64').decode('ascii') 736 else: 737 raw_code = marshal.dumps(func.__code__) 738 code = codecs.encode(raw_code, 'base64').decode('ascii') 739 defaults = func.__defaults__ 740 if func.__closure__: 741 closure = tuple(c.cell_contents for c in func.__closure__) 742 else: 743 closure = None 744 return code, defaults, closure 745 746 747def func_load(code, defaults=None, closure=None, globs=None): 748 """Deserializes a user defined function. 749 750 Args: 751 code: bytecode of the function. 752 defaults: defaults of the function. 753 closure: closure of the function. 754 globs: dictionary of global objects. 755 756 Returns: 757 A function object. 758 """ 759 if isinstance(code, (tuple, list)): # unpack previous dump 760 code, defaults, closure = code 761 if isinstance(defaults, list): 762 defaults = tuple(defaults) 763 764 def ensure_value_to_cell(value): 765 """Ensures that a value is converted to a python cell object. 
766 767 Args: 768 value: Any value that needs to be casted to the cell type 769 770 Returns: 771 A value wrapped as a cell object (see function "func_load") 772 """ 773 774 def dummy_fn(): 775 # pylint: disable=pointless-statement 776 value # just access it so it gets captured in .__closure__ 777 778 cell_value = dummy_fn.__closure__[0] 779 if not isinstance(value, type(cell_value)): 780 return cell_value 781 return value 782 783 if closure is not None: 784 closure = tuple(ensure_value_to_cell(_) for _ in closure) 785 try: 786 raw_code = codecs.decode(code.encode('ascii'), 'base64') 787 except (UnicodeEncodeError, binascii.Error): 788 raw_code = code.encode('raw_unicode_escape') 789 code = marshal.loads(raw_code) 790 if globs is None: 791 globs = globals() 792 return python_types.FunctionType( 793 code, globs, name=code.co_name, argdefs=defaults, closure=closure) 794 795 796def has_arg(fn, name, accept_all=False): 797 """Checks if a callable accepts a given keyword argument. 798 799 Args: 800 fn: Callable to inspect. 801 name: Check if `fn` can be called with `name` as a keyword argument. 802 accept_all: What to return if there is no parameter called `name` but the 803 function accepts a `**kwargs` argument. 804 805 Returns: 806 bool, whether `fn` accepts a `name` keyword argument. 807 """ 808 arg_spec = tf_inspect.getfullargspec(fn) 809 if accept_all and arg_spec.varkw is not None: 810 return True 811 return name in arg_spec.args or name in arg_spec.kwonlyargs 812 813 814@keras_export('keras.utils.Progbar') 815class Progbar(object): 816 """Displays a progress bar. 817 818 Args: 819 target: Total number of steps expected, None if unknown. 820 width: Progress bar width on screen. 821 verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 822 stateful_metrics: Iterable of string names of metrics that should *not* be 823 averaged over time. Metrics in this list will be displayed as-is. All 824 others will be averaged by the progbar before display. 
825 interval: Minimum visual progress update interval (in seconds). 826 unit_name: Display name for step counts (usually "step" or "sample"). 827 """ 828 829 def __init__(self, 830 target, 831 width=30, 832 verbose=1, 833 interval=0.05, 834 stateful_metrics=None, 835 unit_name='step'): 836 self.target = target 837 self.width = width 838 self.verbose = verbose 839 self.interval = interval 840 self.unit_name = unit_name 841 if stateful_metrics: 842 self.stateful_metrics = set(stateful_metrics) 843 else: 844 self.stateful_metrics = set() 845 846 self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 847 sys.stdout.isatty()) or 848 'ipykernel' in sys.modules or 849 'posix' in sys.modules or 850 'PYCHARM_HOSTED' in os.environ) 851 self._total_width = 0 852 self._seen_so_far = 0 853 # We use a dict + list to avoid garbage collection 854 # issues found in OrderedDict 855 self._values = {} 856 self._values_order = [] 857 self._start = time.time() 858 self._last_update = 0 859 860 self._time_after_first_step = None 861 862 def update(self, current, values=None, finalize=None): 863 """Updates the progress bar. 864 865 Args: 866 current: Index of current step. 867 values: List of tuples: `(name, value_for_last_step)`. If `name` is in 868 `stateful_metrics`, `value_for_last_step` will be displayed as-is. 869 Else, an average of the metric over time will be displayed. 870 finalize: Whether this is the last update for the progress bar. If 871 `None`, defaults to `current >= self.target`. 
872 """ 873 if finalize is None: 874 if self.target is None: 875 finalize = False 876 else: 877 finalize = current >= self.target 878 879 values = values or [] 880 for k, v in values: 881 if k not in self._values_order: 882 self._values_order.append(k) 883 if k not in self.stateful_metrics: 884 # In the case that progress bar doesn't have a target value in the first 885 # epoch, both on_batch_end and on_epoch_end will be called, which will 886 # cause 'current' and 'self._seen_so_far' to have the same value. Force 887 # the minimal value to 1 here, otherwise stateful_metric will be 0s. 888 value_base = max(current - self._seen_so_far, 1) 889 if k not in self._values: 890 self._values[k] = [v * value_base, value_base] 891 else: 892 self._values[k][0] += v * value_base 893 self._values[k][1] += value_base 894 else: 895 # Stateful metrics output a numeric value. This representation 896 # means "take an average from a single value" but keeps the 897 # numeric formatting. 898 self._values[k] = [v, 1] 899 self._seen_so_far = current 900 901 now = time.time() 902 info = ' - %.0fs' % (now - self._start) 903 if self.verbose == 1: 904 if now - self._last_update < self.interval and not finalize: 905 return 906 907 prev_total_width = self._total_width 908 if self._dynamic_display: 909 sys.stdout.write('\b' * prev_total_width) 910 sys.stdout.write('\r') 911 else: 912 sys.stdout.write('\n') 913 914 if self.target is not None: 915 numdigits = int(np.log10(self.target)) + 1 916 bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) 917 prog = float(current) / self.target 918 prog_width = int(self.width * prog) 919 if prog_width > 0: 920 bar += ('=' * (prog_width - 1)) 921 if current < self.target: 922 bar += '>' 923 else: 924 bar += '=' 925 bar += ('.' 
* (self.width - prog_width)) 926 bar += ']' 927 else: 928 bar = '%7d/Unknown' % current 929 930 self._total_width = len(bar) 931 sys.stdout.write(bar) 932 933 time_per_unit = self._estimate_step_duration(current, now) 934 935 if self.target is None or finalize: 936 if time_per_unit >= 1 or time_per_unit == 0: 937 info += ' %.0fs/%s' % (time_per_unit, self.unit_name) 938 elif time_per_unit >= 1e-3: 939 info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) 940 else: 941 info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) 942 else: 943 eta = time_per_unit * (self.target - current) 944 if eta > 3600: 945 eta_format = '%d:%02d:%02d' % (eta // 3600, 946 (eta % 3600) // 60, eta % 60) 947 elif eta > 60: 948 eta_format = '%d:%02d' % (eta // 60, eta % 60) 949 else: 950 eta_format = '%ds' % eta 951 952 info = ' - ETA: %s' % eta_format 953 954 for k in self._values_order: 955 info += ' - %s:' % k 956 if isinstance(self._values[k], list): 957 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 958 if abs(avg) > 1e-3: 959 info += ' %.4f' % avg 960 else: 961 info += ' %.4e' % avg 962 else: 963 info += ' %s' % self._values[k] 964 965 self._total_width += len(info) 966 if prev_total_width > self._total_width: 967 info += (' ' * (prev_total_width - self._total_width)) 968 969 if finalize: 970 info += '\n' 971 972 sys.stdout.write(info) 973 sys.stdout.flush() 974 975 elif self.verbose == 2: 976 if finalize: 977 numdigits = int(np.log10(self.target)) + 1 978 count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) 979 info = count + info 980 for k in self._values_order: 981 info += ' - %s:' % k 982 avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 983 if avg > 1e-3: 984 info += ' %.4f' % avg 985 else: 986 info += ' %.4e' % avg 987 info += '\n' 988 989 sys.stdout.write(info) 990 sys.stdout.flush() 991 992 self._last_update = now 993 994 def add(self, n, values=None): 995 self.update(self._seen_so_far + n, values) 996 997 def 
_estimate_step_duration(self, current, now): 998 """Estimate the duration of a single step. 999 1000 Given the step number `current` and the corresponding time `now` 1001 this function returns an estimate for how long a single step 1002 takes. If this is called before one step has been completed 1003 (i.e. `current == 0`) then zero is given as an estimate. The duration 1004 estimate ignores the duration of the (assumed to be non-representative) 1005 first step for estimates when more steps are available (i.e. `current>1`). 1006 Args: 1007 current: Index of current step. 1008 now: The current time. 1009 Returns: Estimate of the duration of a single step. 1010 """ 1011 if current: 1012 # there are a few special scenarios here: 1013 # 1) somebody is calling the progress bar without ever supplying step 1 1014 # 2) somebody is calling the progress bar and supplies step one mulitple 1015 # times, e.g. as part of a finalizing call 1016 # in these cases, we just fall back to the simple calculation 1017 if self._time_after_first_step is not None and current > 1: 1018 time_per_unit = (now - self._time_after_first_step) / (current - 1) 1019 else: 1020 time_per_unit = (now - self._start) / current 1021 1022 if current == 1: 1023 self._time_after_first_step = now 1024 return time_per_unit 1025 else: 1026 return 0 1027 1028 def _update_stateful_metrics(self, stateful_metrics): 1029 self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) 1030 1031 1032def make_batches(size, batch_size): 1033 """Returns a list of batch indices (tuples of indices). 1034 1035 Args: 1036 size: Integer, total size of the data to slice into batches. 1037 batch_size: Integer, batch size. 1038 1039 Returns: 1040 A list of tuples of array indices. 
1041 """ 1042 num_batches = int(np.ceil(size / float(batch_size))) 1043 return [(i * batch_size, min(size, (i + 1) * batch_size)) 1044 for i in range(0, num_batches)] 1045 1046 1047def slice_arrays(arrays, start=None, stop=None): 1048 """Slice an array or list of arrays. 1049 1050 This takes an array-like, or a list of 1051 array-likes, and outputs: 1052 - arrays[start:stop] if `arrays` is an array-like 1053 - [x[start:stop] for x in arrays] if `arrays` is a list 1054 1055 Can also work on list/array of indices: `slice_arrays(x, indices)` 1056 1057 Args: 1058 arrays: Single array or list of arrays. 1059 start: can be an integer index (start index) or a list/array of indices 1060 stop: integer (stop index); should be None if `start` was a list. 1061 1062 Returns: 1063 A slice of the array(s). 1064 1065 Raises: 1066 ValueError: If the value of start is a list and stop is not None. 1067 """ 1068 if arrays is None: 1069 return [None] 1070 if isinstance(start, list) and stop is not None: 1071 raise ValueError('The stop argument has to be None if the value of start ' 1072 'is a list.') 1073 elif isinstance(arrays, list): 1074 if hasattr(start, '__len__'): 1075 # hdf5 datasets only support list objects as indices 1076 if hasattr(start, 'shape'): 1077 start = start.tolist() 1078 return [None if x is None else x[start] for x in arrays] 1079 return [ 1080 None if x is None else 1081 None if not hasattr(x, '__getitem__') else x[start:stop] for x in arrays 1082 ] 1083 else: 1084 if hasattr(start, '__len__'): 1085 if hasattr(start, 'shape'): 1086 start = start.tolist() 1087 return arrays[start] 1088 if hasattr(start, '__getitem__'): 1089 return arrays[start:stop] 1090 return [None] 1091 1092 1093def to_list(x): 1094 """Normalizes a list/tensor into a list. 1095 1096 If a tensor is passed, we return 1097 a list of size 1 containing the tensor. 1098 1099 Args: 1100 x: target object to be normalized. 1101 1102 Returns: 1103 A list. 
1104 """ 1105 if isinstance(x, list): 1106 return x 1107 return [x] 1108 1109 1110def to_snake_case(name): 1111 intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name) 1112 insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower() 1113 # If the class is private the name starts with "_" which is not secure 1114 # for creating scopes. We prefix the name with "private" in this case. 1115 if insecure[0] != '_': 1116 return insecure 1117 return 'private' + insecure 1118 1119 1120def is_all_none(structure): 1121 iterable = nest.flatten(structure) 1122 # We cannot use Python's `any` because the iterable may return Tensors. 1123 for element in iterable: 1124 if element is not None: 1125 return False 1126 return True 1127 1128 1129def check_for_unexpected_keys(name, input_dict, expected_values): 1130 unknown = set(input_dict.keys()).difference(expected_values) 1131 if unknown: 1132 raise ValueError('Unknown entries in {} dictionary: {}. Only expected ' 1133 'following keys: {}'.format(name, list(unknown), 1134 expected_values)) 1135 1136 1137def validate_kwargs(kwargs, 1138 allowed_kwargs, 1139 error_message='Keyword argument not understood:'): 1140 """Checks that all keyword arguments are in the set of allowed keys.""" 1141 for kwarg in kwargs: 1142 if kwarg not in allowed_kwargs: 1143 raise TypeError(error_message, kwarg) 1144 1145 1146def validate_config(config): 1147 """Determines whether config appears to be a valid layer config.""" 1148 return isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config 1149 1150 1151def default(method): 1152 """Decorates a method to detect overrides in subclasses.""" 1153 method._is_default = True # pylint: disable=protected-access 1154 return method 1155 1156 1157def is_default(method): 1158 """Check if a method is decorated with the `default` wrapper.""" 1159 return getattr(method, '_is_default', False) 1160 1161 1162def populate_dict_with_module_objects(target_dict, modules, obj_filter): 1163 for module in 
modules: 1164 for name in dir(module): 1165 obj = getattr(module, name) 1166 if obj_filter(obj): 1167 target_dict[name] = obj 1168 1169 1170class LazyLoader(python_types.ModuleType): 1171 """Lazily import a module, mainly to avoid pulling in large dependencies.""" 1172 1173 def __init__(self, local_name, parent_module_globals, name): 1174 self._local_name = local_name 1175 self._parent_module_globals = parent_module_globals 1176 super(LazyLoader, self).__init__(name) 1177 1178 def _load(self): 1179 """Load the module and insert it into the parent's globals.""" 1180 # Import the target module and insert it into the parent's namespace 1181 module = importlib.import_module(self.__name__) 1182 self._parent_module_globals[self._local_name] = module 1183 # Update this object's dict so that if someone keeps a reference to the 1184 # LazyLoader, lookups are efficient (__getattr__ is only called on lookups 1185 # that fail). 1186 self.__dict__.update(module.__dict__) 1187 return module 1188 1189 def __getattr__(self, item): 1190 module = self._load() 1191 return getattr(module, item) 1192 1193 1194# Aliases 1195 1196custom_object_scope = CustomObjectScope # pylint: disable=invalid-name 1197