"""Trackable data structures.

Containers (lists, dicts, tuples) assigned to trackable objects are wrapped in
the classes defined here so that their contents become dependencies of the
owning object and are saved/restored with it.
"""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import copy
import operator
import sys

try:
  import wrapt
except ImportError:
  # Fall back to the build-time dependency if the system package is not available.
  from .....third_party import wrapt  # pylint: disable=relative-beyond-top-level

from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.trackable import base
from tensorflow.python.trackable import layer_utils
from tensorflow.python.util import lazy_loader
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


# `tensorflow.python.module.module` is loaded through LazyLoader (a deferred
# import). It is used for the `module.Module` isinstance checks in
# `TrackableDataStructure`.
# NOTE(review): presumably deferred to avoid an import cycle -- confirm.
module = lazy_loader.LazyLoader(
    "module", globals(), "tensorflow.python.module.module")


class NoDependency:
  """Allows attribute assignment to `Trackable` objects with no dependency.

  Example usage:
  ```python
  obj = Trackable()
  obj.has_dependency = tf.Variable(0., name="dep")
  obj.no_dependency = NoDependency(tf.Variable(1., name="nodep"))
  assert obj.no_dependency.name == "nodep:0"
  ```

  `obj` in this example has a dependency on the variable "dep", and both
  attributes contain un-wrapped `Variable` objects.

  `NoDependency` also works with `tf.keras.Model`, but only for checkpoint
  dependencies: wrapping a `Layer` in `NoDependency` will assign the (unwrapped)
  `Layer` to the attribute without a checkpoint dependency, but the `Model` will
  still track the `Layer` (so it will appear in `Model.layers`, and its
  variables will appear in `Model.variables`).
  """

  # The marker only carries the wrapped value; no per-instance __dict__.
  __slots__ = ["value"]

  def __init__(self, value):
    # `value` is handed back unwrapped by `wrap_or_unwrap` /
    # `sticky_attribute_assignment`.
    self.value = value


def _should_wrap_tuple(t):
  """Determine if a tuple has any trackable components.

  Args:
    t: The tuple (or namedtuple) to inspect.

  Returns:
    True if any element is a `NoDependency`, a `Trackable`, an object whose
    exact type is `dict`, `collections.OrderedDict` or `list`, or a nested
    tuple for which this function itself returns True. False otherwise.
  """
  # pylint: disable=unidiomatic-typecheck
  # Exact type checking to avoid mucking up custom logic in list/dict
  # subclasses, e.g. collections.Counter.
  for element in t:
    if isinstance(element, NoDependency):
      return True  # We should remove the NoDependency object from the tuple.
    if isinstance(element, base.Trackable):
      return True
    if type(element) == dict:
      return True
    # OrderedDict is a dict subclass, so the exact `dict` check above misses it.
    if type(element) == collections.OrderedDict:
      return True
    if type(element) == list:
      return True
    # Recurse into nested tuples (including namedtuples).
    if isinstance(element, tuple) and _should_wrap_tuple(element):
      return True
  # There are no trackable elements or data structures. Tuples are immutable, so
  # mutation isn't a concern. Don't wrap.
  return False
  # pylint: enable=unidiomatic-typecheck


@tf_export("__internal__.tracking.wrap", v1=[])
def wrap_or_unwrap(value):
  """Wraps input value into trackable data structures.

  This is mostly useful for containers like list, dict, etc, which could contain
  trackable objects in it. Wrapped data structure will be tracked when
  associated with a `tf.Module`, so that save model/checkpoint can properly
  track the dependency.

  It will also unwrap NoDependency objects.

  Dispatch order (dict/OrderedDict/list use exact type checks, so subclasses
  of those types are returned unchanged):
    * `NoDependency` -> its `.value`, without further wrapping.
    * `Trackable` -> returned as-is.
    * `dict` / `collections.OrderedDict` -> `_DictWrapper`.
    * `list` -> `ListWrapper`.
    * tuple for which `_should_wrap_tuple` is True -> `_TupleWrapper`.
    * anything else -> returned as-is.

  Args:
    value: the input object to be wrapped.

  Returns:
    Wrapped trackable data structure, or `value` itself when no wrapping
    applies.
  """
  # pylint: disable=unidiomatic-typecheck
  # Exact type checking to avoid mucking up custom logic in list/dict
  # subclasses, e.g. collections.Counter.
  if isinstance(value, NoDependency):
    return value.value
  if isinstance(value, base.Trackable):
    return value  # Skip conversion for already trackable objects.
  elif type(value) == dict:
    return _DictWrapper(value)
  elif type(value) == collections.OrderedDict:
    return _DictWrapper(value)
  elif type(value) == list:
    return ListWrapper(value)
  elif isinstance(value, tuple) and _should_wrap_tuple(value):
    # There are trackable elements or data structures. Wrap the tuple.
    return _TupleWrapper(value)
  else:
    return value
  # pylint: enable=unidiomatic-typecheck


@tf_export("__internal__.tracking.sticky_attribute_assignment", v1=[])
def sticky_attribute_assignment(trackable, name, value):
  """Adds dependencies, generally called from __setattr__.

  This behavior is shared between Trackable and Model.

  Respects NoDependency indicators, but otherwise makes trackable objects
  out of common data structures and tracks objects by their attribute names.

  Args:
    trackable: The object to add dependencies to (generally the one having
      an attribute assigned).
    name: The attribute name being assigned.
    value: The value being assigned. Not necessarily a trackable object.

  Returns:
    The value which should be stored in the attribute (unwrapped from a
    NoDependency object if necessary).
  """
  # Decide before unwrapping: once `wrap_or_unwrap` runs, the NoDependency
  # marker is gone.
  if isinstance(value, NoDependency):
    add_dependency = False
  else:
    add_dependency = True
  value = wrap_or_unwrap(value)
  if not add_dependency:
    return value
  # Non-trackable values are returned (possibly wrapped) without a dependency.
  if isinstance(value, base.Trackable):
    trackable._track_trackable(  # pylint: disable=protected-access
        value, name=name,
        # Allow the user to switch the Trackable which is tracked by this
        # name, since assigning a new variable to an attribute has
        # historically been fine (e.g. Adam did this).
        overwrite=True)
  return value


class _UntrackableError(ValueError):
  """Raised by `TrackableDataStructure._track_value` for non-trackable values.

  Subclasses `ValueError`: the list and dict wrappers catch `ValueError` in
  their `_track_value` overrides so they can still store non-trackable values.
  """

  def __init__(self, value):  # pylint: disable=super-init-not-called
    # The rejected value, interpolated lazily into the message by `__str__`.
    self._value = value

  def __str__(self):
    return ("Only trackable objects (such as Layers or Optimizers) may be "
            f"stored in a List object. Got {self._value}, which does not "
            "inherit from Trackable.")


@tf_export("__internal__.tracking.TrackableDataStructure", v1=[])
class TrackableDataStructure(base.Trackable):
  """Base class for data structures which contain trackable objects.

  Subclasses provide `_values` (the contained objects). Values added through
  `_track_value` become dependencies of the structure; `Variable` values are
  additionally recorded directly. Layer-style properties (`layers`,
  `*_weights`, `*_variables`, `updates`, `losses`) aggregate over the
  contained objects. Hashing and equality are by object identity.
  """

  def __init__(self):
    # Attributes prefixed with "_self_" for compatibility with
    # wrapt.ObjectProxy. All additional attrs MUST conform to this pattern, as
    # extending `__slots__` on a subclass of ObjectProxy breaks in a variety of
    # ways.
    self._self_trainable = True
    # Variables stored directly in this structure (see `_track_value`).
    self._self_extra_variables = []
    self._self_attribute_sentinel = layer_utils.AttributeSentinel(True)

  @property
  def _attribute_sentinel(self):
    return self._self_attribute_sentinel

  @property
  def trainable(self):
    return self._self_trainable

  @trainable.setter
  def trainable(self, value):
    self._self_trainable = value

  def _track_value(self, value, name):
    """Add a dependency on `value`.

    Args:
      value: The value being stored (may be a `NoDependency` or a plain
        container, which `sticky_attribute_assignment` unwraps/wraps).
      name: The dependency name to track it under.

    Returns:
      The (possibly wrapped or unwrapped) value to actually store.

    Raises:
      _UntrackableError: if the resulting value is not `Trackable`.
    """
    value = sticky_attribute_assignment(
        trackable=self, value=value, name=name)
    # Recorded before the trackable check below.
    if isinstance(value, variables.Variable):
      self._self_extra_variables.append(value)
    if not isinstance(value, base.Trackable):
      raise _UntrackableError(value)
    if hasattr(value, "_use_resource_variables"):
      # In subclassed models, legacy layers (tf.layers) must always use
      # resource variables.
      value._use_resource_variables = True  # pylint: disable=protected-access
    # Link the child's attribute cache to ours so that invalidation can
    # propagate from child to parent.
    value_attribute_sentinel = getattr(value, "_attribute_sentinel", None)
    if value_attribute_sentinel:
      value_attribute_sentinel.add_parent(self._attribute_sentinel)
    return value

  @property
  def _values(self):
    """An iterable/sequence which may contain trackable objects."""
    raise NotImplementedError("Abstract method")

  @property
  def _layers(self):
    """All Layers and Layer containers, including empty containers."""
    # Filter objects on demand so that wrapper objects use values from the thing
    # they're wrapping if out of sync.
    collected = []
    for obj in self._values:
      if (isinstance(obj, TrackableDataStructure)
          or layer_utils.is_layer(obj)
          or layer_utils.has_weights(obj)):
        collected.append(obj)
    return collected

  @property
  def layers(self):
    """`_layers` with empty layer containers filtered out."""
    return list(layer_utils.filter_empty_layer_containers(self._layers))

  @property
  def trainable_weights(self):
    """Trainable variables of contained structures/Modules, then own extras.

    Empty when this structure is not trainable; those variables are reported
    by `non_trainable_weights` instead.
    """
    if not self._self_trainable:
      return []
    trainable_variables = []
    for obj in self._values:
      if isinstance(obj, (TrackableDataStructure, module.Module)):
        trainable_variables += obj.trainable_variables
    trainable_extra_variables = [
        v for v in self._self_extra_variables if v.trainable
    ]
    return trainable_variables + trainable_extra_variables

  @property
  def non_trainable_weights(self):
    """Non-trainable variables; includes everything when not trainable."""
    trainable_extra_variables = [
        v for v in self._self_extra_variables if v.trainable
    ]
    non_trainable_extra_variables = [
        v for v in self._self_extra_variables if not v.trainable
    ]
    non_trainable_variables = []
    for obj in self._values:
      if isinstance(obj, (TrackableDataStructure, module.Module)):
        non_trainable_variables += obj.non_trainable_variables

    if not self._self_trainable:
      # Return order is all trainable vars, then all non-trainable vars.
      trainable_variables = []
      for obj in self._values:
        if isinstance(obj, (TrackableDataStructure, module.Module)):
          trainable_variables += obj.trainable_variables

      non_trainable_variables = (
          trainable_variables + trainable_extra_variables +
          non_trainable_variables + non_trainable_extra_variables)
    else:
      non_trainable_variables = (
          non_trainable_variables + non_trainable_extra_variables)

    return non_trainable_variables

  @property
  def weights(self):
    """All weights: trainable first, then non-trainable."""
    return self.trainable_weights + self.non_trainable_weights

  @property
  def trainable_variables(self):
    """Alias of `trainable_weights`."""
    return self.trainable_weights

  @property
  def non_trainable_variables(self):
    """Alias of `non_trainable_weights`."""
    return self.non_trainable_weights

  @property
  def variables(self):
    """Alias of `weights`."""
    return self.weights

  @property
  def updates(self):
    """Aggregate updates from any `Layer` instances."""
    # Updates and conditional losses are forwarded as-is rather than being
    # filtered based on inputs, since this is just a container and won't ever
    # have any inputs.
    aggregated = []
    for layer in self.layers:
      if hasattr(layer, "updates"):
        aggregated += layer.updates
    return aggregated

  @property
  def losses(self):
    """Aggregate losses from any `Layer` instances."""
    aggregated = []
    for layer in self.layers:
      if hasattr(layer, "losses"):
        aggregated += layer.losses
    return aggregated

  def __hash__(self):
    # Support object-identity hashing, so these structures can be used as keys
    # in sets/dicts.
    return id(self)

  def __eq__(self, other):
    # Similar to Tensors, trackable data structures use object-identity
    # equality to support set/dict membership.
    return self is other


class List(TrackableDataStructure, collections_abc.Sequence):
  """An append-only sequence type which is trackable.

  Maintains checkpoint dependencies on its contents (which must also be
  trackable), and forwards any `Layer` metadata such as updates and losses.

  Note that `List` is purely a container. It lets a `tf.keras.Model` or
  other trackable object know about its contents, but does not call any
  `Layer` instances which are added to it. To indicate a sequence of `Layer`
  instances which should be called sequentially, use `tf.keras.Sequential`.

  Example usage:
  ```python
  class HasList(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.layer_list = List([layers.Dense(3)])
      self.layer_list.append(layers.Dense(4))

    def call(self, x):
      aggregation = 0.
      for l in self.layer_list:
        x = l(x)
        aggregation += tf.reduce_sum(x)
      return aggregation
  ```

  This kind of wrapping is necessary because `Trackable` objects do not
  (yet) deeply inspect regular Python data structures, so for example assigning
  a regular list (`self.layer_list = [layers.Dense(3)]`) does not create a
  checkpoint dependency and does not add the `Layer` instance's weights to its
  parent `Model`.
  """

  def __init__(self, *args, **kwargs):
    """Construct a new sequence. Arguments are passed to `list()`.

    Each initial element is tracked under its index (see `_name_element`) and
    replaced in storage by the value `_track_value` returns.
    """
    super().__init__()
    self._storage = self._make_storage(*args, **kwargs)
    for index, element in enumerate(self._storage):
      self._storage[index] = self._track_value(
          element, name=self._name_element(index))

  def copy(self):
    """Returns a new instance of the same type over a shallow storage copy."""
    return type(self)(copy.copy(self._storage))

  def __copy__(self):
    return self.copy()

  def __deepcopy__(self, memo):
    return type(self)(copy.deepcopy(self._storage, memo))

  def _make_storage(self, *args, **kwargs):
    """Determines the backing storage (overridden in subclasses)."""
    return list(*args, **kwargs)

  def _name_element(self, index):
    # Dependency names are the decimal element index, e.g. "0", "1".
    return "%d" % (index,)

  @property
  def _values(self):
    """Collect values for TrackableDataStructure."""
    return self

  def append(self, value):
    """Add a new trackable value."""
    # Named after the index the value is about to occupy.
    value = self._track_value(value, self._name_element(len(self._storage)))
    self._storage.append(value)

  def extend(self, values):
    """Add a sequence of trackable values."""
    for value in values:
      self.append(value)

  def __iadd__(self, values):
    self.extend(values)
    return self

  def __add__(self, other):
    # Concatenates the backing storage; the result is not a tracked structure.
    return self._storage + getattr(other, "_storage", other)

  def __imul__(self, y):
    # In-place repetition is implemented by appending copies of the existing
    # elements, so every repeat is tracked under a new index name. Repeating
    # by y <= 0 would remove elements, which an append-only List forbids.
    if y <= 0:
      raise ValueError(
          f"List only supports append, multiplying in place by {y} removes "
          "elements.")

    n = len(self._storage)
    for _ in range(y - 1):
      for i in range(n):
        self.append(self._storage[i])

    return self

  def __mul__(self, n):
    # Repeats the backing storage; the result is not a tracked structure.
    return self._storage * n

  def __rmul__(self, n):
    return self * n

  def __radd__(self, other):
    return other + self._storage

  def __getitem__(self, key):
    return self._storage[key]

  def __getslice__(self, i, j):
    # Python 2 slicing hook; Python 3 routes slices through __getitem__.
    return self._storage[slice(i, j)]

  def __len__(self):
    return len(self._storage)

  def __repr__(self):
    return "List(%s)" % (repr(self._storage),)

  def __sizeof__(self):
    # Include the backing storage; sys.getsizeof is shallow.
    return super().__sizeof__() + sys.getsizeof(self._storage)


# TODO(tomhennigan) Update to collections.UserList?
# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
# Python 3.4 support (may still be tricky).
class ListWrapper(
    List,
    collections_abc.MutableSequence,
    # Shadowed, but there for isinstance checks.
    list):
  """Wraps the built-in `list` to support restore-on-create for variables.

  Unlike `List`, this sequence type is mutable in the same ways built-in lists
  are. Instead of throwing an error immediately like `List`, it records
  problematic mutations (e.g. assigning a new element to a position already
  occupied, meaning both elements get the same names at different times) and
  refuses to save.

  On assignment to an attribute of a Model or Trackable object, Python
  lists are replaced with ListWrapper. Wrapping a list in a
  `NoDependency` object prevents this.
  """

  def __init__(self, wrapped_list):
    """Construct a new list wrapper.

    Args:
      wrapped_list: The initial value of the data structure. A shallow copy may
        be maintained for error checking. `wrapped_list` itself should not be
        modified directly after constructing the `ListWrapper`, and if changes
        are detected the `ListWrapper` will throw an exception on save.
    """
    # Monotonic flags which indicate this object would not be restored properly,
    # and therefore should throw an error on save to avoid giving the impression
    # that restoring it will work.
    # The backing fields are set directly rather than through the property
    # setters: those call `self._attribute_sentinel.invalidate_all()`, and the
    # sentinel is only created by `TrackableDataStructure.__init__`.
    self._non_append_mutation_value = False
    self._external_modification_value = False
    super().__init__(wrapped_list)
    # Baseline against which `_check_external_modification` compares.
    self._last_wrapped_list_snapshot = list(self._storage)

  @property
  def _non_append_mutation(self):
    return self._non_append_mutation_value

  @_non_append_mutation.setter
  def _non_append_mutation(self, value):
    # Trackable only cares that a mutation occurred at some point; when
    # attempting to save it checks whether a mutation occurred and the object is
    # in a "dirty" state but otherwise the specifics of how it got to that state
    # are ignored. By contrast, the attribute cache needs to signal the mutation
    # immediately since a caller could query the value of an attribute (And
    # should not hit the cached value since the mutation may have affected the
    # result.)
    self._attribute_sentinel.invalidate_all()
    self._non_append_mutation_value = value

  @property
  def _external_modification(self):
    return self._external_modification_value

  @_external_modification.setter
  def _external_modification(self, value):
    # Invalidate for the same reason as `_non_append_mutation`
    self._attribute_sentinel.invalidate_all()
    self._external_modification_value = value

  # pylint: disable=protected-access
  def __copy__(self):
    # Copies carry over the "unsaveable" flags of the original.
    copied = super().__copy__()
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied

  def __deepcopy__(self, memo):
    copied = super().__deepcopy__(memo)
    copied._non_append_mutation = self._non_append_mutation
    copied._external_modification = self._external_modification
    return copied
  # pylint: enable=protected-access

  def __reduce_ex__(self, protocol):
    # Pickle as a fresh wrapper over the current storage; the mutation flags
    # are not part of the pickled state.
    return (self.__class__,
            (self._storage,))

  def _make_storage(self, wrapped_list):
    """Use the user's original list for storage.

    The caller's list object is aliased, not copied, which is why changes made
    to it directly can later be detected against the snapshot.
    """
    return wrapped_list

def _check_external_modification(self): 538 """Checks for any changes to the wrapped list not through the wrapper.""" 539 if self._external_modification or self._non_append_mutation: 540 return 541 if self._storage != self._last_wrapped_list_snapshot: 542 self._external_modification = True 543 self._last_wrapped_list_snapshot = None 544 545 def _update_snapshot(self): 546 """Acknowledges tracked changes to the wrapped list.""" 547 548 # Mutation tracking for attributes reuses the same infrastructure as 549 # Trackable mutation tracking. 550 self._attribute_sentinel.invalidate_all() 551 if self._external_modification or self._non_append_mutation: 552 return 553 self._last_wrapped_list_snapshot = list(self._storage) 554 555 def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs): 556 self._check_external_modification() 557 if self._non_append_mutation: 558 raise ValueError( 559 f"Unable to save the object {self} (a list wrapper constructed to " 560 "track trackable TensorFlow objects). A list element was replaced " 561 "(__setitem__, __setslice__), deleted (__delitem__, __delslice__), " 562 "or moved (sort). In order to support restoration on object " 563 "creation, tracking is exclusively for append-only data structures." 564 "\n\nIf you don't need this list checkpointed, wrap it in a " 565 "non-trackable object; it will be subsequently ignored.") 566 if self._external_modification: 567 raise ValueError( 568 f"Unable to save the object {self} (a list wrapper constructed to " 569 "track trackable TensorFlow objects). 
The wrapped list was modified " 570 f"outside the wrapper (its final value was {self._storage}, its value" 571 " when a checkpoint dependency was added was " 572 f"{self._last_wrapped_list_snapshot}), which breaks " 573 "restoration on object creation.\n\nIf you don't need this list " 574 "checkpointed, wrap it in a NoDependency object; it will be " 575 "subsequently ignored.") 576 children = super()._trackable_children(save_type, **kwargs) 577 578 if save_type == base.SaveType.SAVEDMODEL: 579 # Add functions to be serialized. 580 children.update({ 581 str(key): value 582 for key, value in enumerate(self) 583 if _is_function(value) 584 }) 585 586 return children 587 588 def _has_mutation_or_trackable(self): 589 """Short-circuits a check for trackables if there's already a mutation.""" 590 if self._non_append_mutation: 591 return True 592 return any(isinstance(element, base.Trackable) for element in self._storage) 593 594 def __delitem__(self, key): 595 self._check_external_modification() 596 if self._has_mutation_or_trackable(): 597 self._non_append_mutation = True 598 del self._storage[key] 599 self._update_snapshot() 600 601 def __setitem__(self, key, value): 602 self._check_external_modification() 603 604 if isinstance(key, slice): 605 # Note: this is quite inefficient, but the list API supports a broad range 606 # of slice setters (e.g. truncate, extend, replace) and imitating this 607 # for a range of Python versions is non-trivial. 
608 storage_copy = list(self._storage) 609 self._storage[key] = value 610 611 len_before = len(storage_copy) 612 len_now = len(self._storage) 613 for i in range(max(len_before, len_now)): 614 value_now = self._storage[i] if i < len_now else None 615 value_before = storage_copy[i] if i < len_before else None 616 617 if isinstance(value_before, base.Trackable): 618 self._non_append_mutation = True 619 620 if value_now is not None and value_now != value_before: 621 self._storage[i] = self._track_value(self._storage[i], 622 self._name_element(i)) 623 624 else: 625 if isinstance(self._storage[key], base.Trackable): 626 self._non_append_mutation = True 627 self._storage[key] = self._track_value(value, self._name_element(key)) 628 629 self._update_snapshot() 630 631 def append(self, value): 632 """Add a new trackable value.""" 633 self._check_external_modification() 634 super().append(value) 635 self._update_snapshot() 636 637 def extend(self, values): 638 """Add a sequence of trackable values.""" 639 self._check_external_modification() 640 super().extend(values) 641 self._update_snapshot() 642 643 def __imul__(self, y): 644 if y <= 0: 645 self._check_external_modification() 646 if self._has_mutation_or_trackable(): 647 self._non_append_mutation = True 648 self._storage *= y 649 self._update_snapshot() 650 return self 651 652 # Relies on super() calling append, which updates the snapshot. 
653 return super().__imul__(y) 654 655 def __eq__(self, other): 656 return self._storage == getattr(other, "_storage", other) 657 658 def __ne__(self, other): 659 return self._storage != getattr(other, "_storage", other) 660 661 def __lt__(self, other): 662 return self._storage < getattr(other, "_storage", other) 663 664 def __le__(self, other): 665 return self._storage <= getattr(other, "_storage", other) 666 667 def __gt__(self, other): 668 return self._storage > getattr(other, "_storage", other) 669 670 def __ge__(self, other): 671 return self._storage >= getattr(other, "_storage", other) 672 673 def __hash__(self): 674 # List wrappers need to compare like regular lists, and so like regular 675 # lists they don't belong in hash tables. 676 raise TypeError("unhashable type: 'ListWrapper'") 677 678 def insert(self, index, obj): 679 self._check_external_modification() 680 if (self._has_mutation_or_trackable() or isinstance(obj, base.Trackable)): 681 self._non_append_mutation = True 682 self._storage.insert(index, obj) 683 self._update_snapshot() 684 685 def sort(self): 686 self._check_external_modification() 687 if self._has_mutation_or_trackable(): 688 self._non_append_mutation = True 689 self._storage.sort() 690 self._update_snapshot() 691 692 def __setslice__(self, i, j, y): 693 self.__setitem__(slice(i, j), y) 694 695 def __delslice__(self, i, j): 696 self._check_external_modification() 697 if self._has_mutation_or_trackable(): 698 self._non_append_mutation = True 699 del self._storage[slice(i, j)] 700 self._update_snapshot() 701 702 def _track_value(self, value, name): 703 """Allows storage of non-trackable objects.""" 704 try: 705 value = super()._track_value(value=value, name=name) 706 except ValueError: 707 # Even if this value isn't trackable, we need to make sure 708 # NoDependency objects get unwrapped. 
709 value = sticky_attribute_assignment( 710 trackable=self, value=value, name=name) 711 return value 712 713 def __repr__(self): 714 return "ListWrapper(%s)" % (repr(self._storage),) 715 716 717class Mapping(TrackableDataStructure, collections_abc.Mapping): 718 """An append-only trackable mapping data structure with string keys. 719 720 Maintains checkpoint dependencies on its contents (which must also be 721 trackable), named based on its keys. 722 723 Note that once a key has been added, it may not be deleted or replaced. 724 """ 725 726 def __init__(self, *args, **kwargs): 727 """Construct a new sequence. Arguments are passed to `dict()`.""" 728 super().__init__() 729 self._storage = self._make_storage(*args, **kwargs) 730 self._storage.update( 731 {key: self._track_value( 732 value, name=self._name_element(key)) 733 for key, value in self._storage.items()}) 734 735 def __copy__(self): 736 return type(self)(copy.copy(self._storage)) 737 738 def __deepcopy__(self, memo): 739 return type(self)(copy.deepcopy(self._storage, memo)) 740 741 def _make_storage(self, *args, **kwargs): 742 return dict(*args, **kwargs) 743 744 @property 745 def _values(self): 746 """Collect values for TrackableDataStructure.""" 747 # Sort items deterministically by key 748 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 749 if ordered: 750 return ordered[1] 751 return [] 752 753 def _name_element(self, key): 754 if not isinstance(key, str): 755 raise TypeError( 756 f"Mapping accepts only string keys, but got a key {repr(key)}.") 757 return str(key) 758 759 def __setitem__(self, key, value): 760 name = self._name_element(key) 761 value = self._track_value(value, name=name) 762 current_value = self._storage.setdefault(key, value) 763 if current_value is not value: 764 raise ValueError( 765 "Mappings are an append-only data structure. 
Tried to overwrite the " 766 f"key '{key}' with value {value}, but it already contains " 767 f"{current_value}") 768 769 def update(self, *args, **kwargs): 770 for key, value in dict(*args, **kwargs).items(): 771 self[key] = value 772 773 def __getitem__(self, key): 774 return self._storage[key] 775 776 def __len__(self): 777 return len(self._storage) 778 779 def __repr__(self): 780 return "Mapping(%s)" % (repr(self._storage),) 781 782 def __iter__(self): 783 return iter(self._storage) 784 785 786class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy): 787 """Wraps built-in dicts to support restore-on-create for variables. 788 789 _DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping, 790 _DictWrapper allows non-string keys and values and arbitrary mutations (delete 791 keys, reassign values). Like ListWrapper, these mutations mean that 792 _DictWrapper will raise an exception on save. 793 """ 794 795 def __init__(self, wrapped_dict=None): 796 if wrapped_dict is None: 797 # Allow zero-argument construction, e.g. from session.run's re-wrapping. 798 wrapped_dict = {} 799 if not isinstance(wrapped_dict, collections_abc.Mapping): 800 # Allow construction from a sequence, e.g. from nest.pack_sequence_as. 801 wrapped_dict = dict(wrapped_dict) 802 wrapt.ObjectProxy.__init__(self, wrapped_dict) 803 TrackableDataStructure.__init__(self) 804 self._self_non_string_key = False 805 self._self_external_modification = False 806 self.__wrapped__.update( 807 {key: self._track_value( 808 value, name=self._name_element(key)) 809 for key, value in self.__wrapped__.items()}) 810 self._update_snapshot() 811 812 def __reduce_ex__(self, protocol): 813 return (self.__class__, 814 (self.__wrapped__,)) 815 816 def __getattribute__(self, name): 817 if (hasattr(type(self), name) 818 and isinstance(getattr(type(self), name), property)): 819 # Bypass ObjectProxy for properties. 
Whether this workaround is necessary 820 # appears to depend on the Python version but not the wrapt version: 3.4 821 # in particular seems to look up properties on the wrapped object instead 822 # of the wrapper without this logic. 823 return object.__getattribute__(self, name) 824 else: 825 return super().__getattribute__(name) 826 827 def copy(self): 828 return copy.copy(self) 829 830 # pylint: disable=protected-access 831 def __copy__(self): 832 copied = _DictWrapper(copy.copy(self.__wrapped__)) 833 copied._self_external_modification = self._self_external_modification 834 copied._self_non_string_key = self._self_non_string_key 835 return copied 836 837 def __deepcopy__(self, memo): 838 copied = _DictWrapper(copy.deepcopy(self.__wrapped__, memo)) 839 copied._self_external_modification = self._self_external_modification 840 copied._self_non_string_key = self._self_non_string_key 841 return copied 842 # pylint: enable=protected-access 843 844 @property 845 def _values(self): 846 """Collect values for TrackableDataStructure.""" 847 # Sort items deterministically by key 848 ordered = list(zip(*sorted(self.items(), key=lambda it: it[0]))) 849 if ordered: 850 return ordered[1] 851 return [] 852 853 def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs): 854 """Check that the object is saveable before listing its dependencies.""" 855 self._check_self_external_modification() 856 if self._self_non_string_key: 857 raise ValueError( 858 f"Unable to save the object {self} (a dictionary wrapper constructed " 859 "automatically on attribute assignment). 
The wrapped dictionary " 860 "contains a non-string key which maps to a trackable object or " 861 "mutable data structure.\n\nIf you don't need this dictionary " 862 "checkpointed, wrap it in a non-trackable " 863 "object; it will be subsequently ignored.") 864 if self._self_external_modification: 865 raise ValueError( 866 f"Unable to save the object {self} (a dictionary wrapper constructed " 867 "automatically on attribute assignment). The wrapped dictionary was " 868 f"modified outside the wrapper (its final value was {self}, its value" 869 " when a checkpoint dependency was added was " 870 f"{self._self_last_wrapped_dict_snapshot}), which breaks " 871 "restoration on object creation.\n\nIf you don't need this " 872 "dictionary checkpointed, wrap it in a " 873 "non-trackable object; it will be subsequently ignored.") 874 assert not self._dirty # Any reason for dirtiness should have an exception. 875 children = super()._trackable_children(save_type, **kwargs) 876 877 if save_type == base.SaveType.SAVEDMODEL: 878 # Add functions to be serialized. 
# For SavedModel saves, function-valued entries (as classified by
# `_is_function`) are also exposed as children, named by their dictionary key.
      children.update(
          {key: value for key, value in self.items() if _is_function(value)})

    return children

  @property
  def _dirty(self):
    """Check if there has already been a mutation which prevents saving.

    Returns:
      True once either an external (non-wrapper) modification of the wrapped
      dict was detected, or a non-string key was mapped to a tracked value.
      Either condition makes the wrapper unsaveable.
    """
    return (self._self_external_modification
            or self._self_non_string_key)

  def _check_self_external_modification(self):
    """Checks for any changes to the wrapped dict not through the wrapper.

    Compares the current contents against the snapshot recorded by the last
    wrapper-mediated mutation. On a mismatch the wrapper is flagged as
    externally modified and the snapshot is dropped. Once the wrapper is dirty
    this check is skipped entirely, so the flag is never cleared here.
    """
    if self._dirty:
      return
    if self != self._self_last_wrapped_dict_snapshot:
      self._self_external_modification = True
      self._self_last_wrapped_dict_snapshot = None

  def _update_snapshot(self):
    """Acknowledges tracked changes to the wrapped dict.

    Calls `invalidate_all()` on `self._attribute_sentinel` (defined elsewhere;
    presumably invalidates cached attribute-derived state). Unless the wrapper
    is already dirty, it then records the current contents as the new baseline
    for `_check_self_external_modification`.
    """
    self._attribute_sentinel.invalidate_all()
    if self._dirty:
      return
    # Shallow copy: in-place mutation of a nested value keeps the same object
    # in both the dict and the snapshot, so such changes are not detected.
    self._self_last_wrapped_dict_snapshot = dict(self)

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Args:
      value: The value being stored. It may be wrapped in `NoDependency`.
      name: The dictionary key. Every non-string key is tracked under the
        single placeholder name "-non_string_key".

    Returns:
      The value to store in the wrapped dict, as returned by the superclass's
      `_track_value` or, when that raises `ValueError`, by
      `sticky_attribute_assignment`.
    """
    if isinstance(name, str):
      string_key = True
    else:
      name = "-non_string_key"
      string_key = False
    try:
      no_dependency = isinstance(value, NoDependency)
      value = super()._track_value(value=value, name=name)
      if not (string_key or no_dependency):
        # A non-string key maps to a trackable value. This data structure
        # is not saveable.
        self._self_non_string_key = True
      return value
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      return sticky_attribute_assignment(
          trackable=self, value=value, name=name)

  def _name_element(self, key):
    """Tells TrackableDataStructure to use keys as names as-is."""
    return key

  def __setitem__(self, key, value):
    """Allow any modifications, but possibly mark the wrapper as unsaveable.

    String keys go through `_track_value`. Non-string keys are wrapped or
    unwrapped with `wrap_or_unwrap`. If the result is a `base.Trackable` that
    was not wrapped in `NoDependency`, the wrapper is marked unsaveable.
    """
    # Detect external edits *before* this mutation changes the contents.
    self._check_self_external_modification()
    # Defined elsewhere; presumably sets up trackable state lazily.
    self._maybe_initialize_trackable()
    no_dep = isinstance(value, NoDependency)
    if isinstance(key, str):
      value = self._track_value(value, name=key)
    else:
      value = wrap_or_unwrap(value)
      if not no_dep and isinstance(value, base.Trackable):
        # Non-string keys are OK as long as we have no reason to add a
        # dependency on the value (either because the value is not
        # trackable, or because it was wrapped in a NoDependency object).
        self._self_non_string_key = True
    self.__wrapped__[key] = value

    # Record the new contents as the baseline for external-edit detection.
    self._update_snapshot()

  def __delitem__(self, key):
    """Deletes `key` from the wrapped dict, keeping snapshot bookkeeping."""
    self._check_self_external_modification()
    del self.__wrapped__[key]
    self._update_snapshot()

  def __repr__(self):
    return "DictWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Mirror `dict`: the wrapper is mutable and therefore unhashable.
    raise TypeError("unhashable type: 'DictWrapper'")

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def update(self, *args, **kwargs):
    """Like `dict.update`, but routes every entry through `__setitem__`.

    The arguments are first merged into a temporary `dict`, so later
    duplicates win just as with `dict.update`. Each entry then gets the same
    tracking and snapshot bookkeeping as a single assignment.
    """
    for key, value in dict(*args, **kwargs).items():
      self[key] = value


class _TupleWrapper(TrackableDataStructure, wrapt.ObjectProxy):
  """Trackable wrapper for tuples and namedtuples.

  Proxies a tuple rebuilt from the (wrapped/unwrapped) elements. Elements are
  tracked by decimal index ("0", "1", ...) and, for namedtuples, also by field
  name. Elements originally wrapped in `NoDependency` are stored but not
  tracked.
  """

  def __init__(self, original_wrapped_tuple=()):
    """Wraps `original_wrapped_tuple`.

    Args:
      original_wrapped_tuple: A tuple or namedtuple. Its elements may be
        wrapped in `NoDependency` to exclude them from dependency tracking.
    """
    # Parallel lists: whether each element gets a dependency, and the element
    # after `wrap_or_unwrap` (which presumably strips `NoDependency`).
    add_dependency = []
    substituted_wrapped_tuple = []
    for element in original_wrapped_tuple:
      if isinstance(element, NoDependency):
        add_dependency.append(False)
      else:
        add_dependency.append(True)
      substituted_wrapped_tuple.append(wrap_or_unwrap(element))
    try:
      fields = original_wrapped_tuple._fields
    except AttributeError:
      # Not a namedtuple
      is_namedtuple = False
    else:
      is_namedtuple = True
    original_type = type(original_wrapped_tuple)
    # Flag to poison saving if we can't re-construct a namedtupled because its
    # __new__ takes different keyword arguments than its _fields.
    # (Set before proxy init; per wrapt convention `_self_`-prefixed
    # attributes live on the proxy itself, not the wrapped object.)
    self._self_tuple_is_constructable = True
    if is_namedtuple:
      try:
        # NamedTuples take N arguments, unlike tuple which takes a sequence.
        substituted_wrapped_tuple = original_type(
            **dict(zip(fields, substituted_wrapped_tuple)))
      except TypeError:
        # Fall back to proxying the original tuple unchanged. No elements
        # are tracked, and `_trackable_children` will refuse to save.
        wrapt.ObjectProxy.__init__(self, original_wrapped_tuple)
        TrackableDataStructure.__init__(self)
        self._self_tuple_is_constructable = False
        return
    else:
      substituted_wrapped_tuple = original_type(substituted_wrapped_tuple)
    wrapt.ObjectProxy.__init__(self, substituted_wrapped_tuple)
    TrackableDataStructure.__init__(self)

    if is_namedtuple:
      # For namedtuples, also track by names for compatibility with
      # dictionaries.
      for name, should_depend, element in zip(
          fields, add_dependency, substituted_wrapped_tuple):
        if should_depend:
          self._track_value(element, name=name)

    # Track by index as well, for compatibility with lists.
    for index, (should_depend, element) in enumerate(
        zip(add_dependency, substituted_wrapped_tuple)):
      if should_depend:
        self._track_value(element, name="%d" % (index,))

  @property
  def _values(self):
    """Collect values for TrackableDataStructure.

    The proxy itself is returned; it iterates like the wrapped tuple.
    """
    return self

  def _track_value(self, value, name):
    """Allows storage of non-trackable objects.

    Returns:
      The superclass's tracked value, or the result of
      `sticky_attribute_assignment` if the superclass raises `ValueError`.
    """
    try:
      value = super()._track_value(value=value, name=name)
    except ValueError:
      # Even if this value isn't trackable, we need to make sure
      # NoDependency objects get unwrapped.
      value = sticky_attribute_assignment(
          trackable=self, value=value, name=name)
    return value

  def __repr__(self):
    return "_TupleWrapper(%s)" % (repr(self.__wrapped__),)

  def __hash__(self):
    # Override the TrackableDataStructure hash forwarding and go back to
    # the wrapt implementation.
    return hash(self.__wrapped__)

  def __eq__(self, other):
    # Override the TrackableDataStructure "== -> is" forwarding and go back to
    # the wrapt implementation.
    return self.__wrapped__ == other

  def __copy__(self):
    # A fresh wrapper around a shallow copy of the wrapped tuple.
    return _TupleWrapper(copy.copy(self.__wrapped__))

  def __deepcopy__(self, memo):
    # A fresh wrapper around a deep copy of the wrapped tuple. Note that the
    # new wrapper itself is not recorded in `memo`.
    return _TupleWrapper(copy.deepcopy(self.__wrapped__, memo))

  def __reduce_ex__(self, protocol):
    # Pickle as "call the class with the wrapped tuple"; `protocol` is unused.
    return (self.__class__,
            (self.__wrapped__,))

  # imul and iadd are the only tuple-relevant in-place operators. They need to
  # be special-cased to avoid mutating the original proxy object.
  def __imul__(self, y):
    """Avoid running self.__wrapped__ *= y, which mutates `self`.

    Returns the result of `self.__wrapped__ * y`, not a `_TupleWrapper`.
    """
    return self.__wrapped__ * y

  def __iadd__(self, y):
    """Avoid running self.__wrapped__ += y, which mutates `self`.

    Returns the result of `self.__wrapped__ + y`, not a `_TupleWrapper`.
    """
    return self.__wrapped__ + y

  def _trackable_children(self, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns the superclass's children, refusing unconstructable tuples.

    Raises:
      ValueError: If this wraps a namedtuple that could not be rebuilt from
        its `_fields` as keyword arguments in `__init__`.
    """
    if not self._self_tuple_is_constructable:
      raise ValueError(
          f"Unable to save because the namedtuple {self.__wrapped__} is not "
          "constructable from its _fields (i.e. __new__ is overridden). "
          f"Expected keyword arguments {self.__wrapped__._fields}. If you do "
          "not need to save this object, consider wrapping it in a custom "
          "object that does not inherit from tuple.")
    return super()._trackable_children(save_type, **kwargs)

  def __getattribute__(self, name):
    """Attribute lookup with the wrapped tuple taking precedence.

    Order: attributes of the wrapped object (e.g. namedtuple fields) first,
    then properties defined on the wrapper class, then normal proxy lookup.
    """
    if name != "__wrapped__" and hasattr(self.__wrapped__, name):
      # Prefer attributes on the wrapped object when they conflict with
      # attributes on the wrapper object.
      return getattr(self.__wrapped__, name)

    if (hasattr(type(self), name)
        and isinstance(getattr(type(self), name), property)):
      # Bypass ObjectProxy for properties. Whether this workaround is necessary
      # appears to depend on the Python version but not the wrapt version: 3.4
      # in particular seems to look up properties on the wrapped object instead
      # of the wrapper without this logic.
      return object.__getattribute__(self, name)
    else:
      return super().__getattribute__(name)


def _is_function(x):
  """Returns True if `x` is a `def_function.Function` or a ConcreteFunction."""
  return isinstance(x, (def_function.Function, defun.ConcreteFunction))


# Dict wrappers are revived as an empty `_DictWrapper` whose entries are then
# filled in with `operator.setitem` (presumably called by the loader as
# setter(obj, name, child)).
revived_types.register_revived_type(
    "trackable_dict_wrapper",
    lambda obj: isinstance(obj, _DictWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _DictWrapper({}),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=operator.setitem)])


def _set_list_item(list_object, index_string, value):
  """Revival setter: stores `value` at position `int(index_string)`.

  The list is padded with None when the index is past its current end, so
  children may arrive in any order.
  """
  item_index = int(index_string)
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


revived_types.register_revived_type(
    "trackable_list_wrapper",
    lambda obj: isinstance(obj, ListWrapper),
    versions=[revived_types.VersionedTypeRegistration(
        object_factory=lambda proto: ListWrapper([]),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_set_list_item)])


def _set_tuple_item(list_object, index_string, value):
  """Revival setter for tuple wrappers.

  Same as `_set_list_item`, except that names that are not integers are
  skipped. `_TupleWrapper` also tracks namedtuple elements under their field
  names, and those duplicates are ignored here.
  """
  try:
    item_index = int(index_string)
  except ValueError:
    # Ignore namedtuple fields.
    return
  if len(list_object) <= item_index:
    list_object.extend([None] * (1 + item_index - len(list_object)))
  list_object[item_index] = value


# Revive tuples as lists so we can append any dependencies during loading.
# Children are placed only by their index names via `_set_tuple_item`;
# namedtuple field-name children are skipped, so the revived `ListWrapper`
# holds the elements in positional order.
revived_types.register_revived_type(
    "trackable_tuple_wrapper",
    lambda obj: isinstance(obj, _TupleWrapper),
    versions=[
        revived_types.VersionedTypeRegistration(
            version=1,
            min_producer_version=1,
            min_consumer_version=1,
            object_factory=lambda proto: ListWrapper([]),
            setter=_set_tuple_item,
        ),
    ],
)