"""Manages a Trackable object graph."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import weakref

from tensorflow.python.trackable import base
from tensorflow.python.trackable import converter
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.TrackableView", v1=[])
class TrackableView(object):
  """Gathers and serializes a trackable view.

  Example usage:

  >>> class SimpleModule(tf.Module):
  ...   def __init__(self, name=None):
  ...     super().__init__(name=name)
  ...     self.a_var = tf.Variable(5.0)
  ...     self.b_var = tf.Variable(4.0)
  ...     self.vars = [tf.Variable(1.0), tf.Variable(2.0)]

  >>> root = SimpleModule(name="root")
  >>> root.leaf = SimpleModule(name="leaf")
  >>> trackable_view = tf.train.TrackableView(root)

  Pass root to tf.train.TrackableView.children() to get the dictionary of all
  children directly linked to root by name.
  >>> trackable_view_children = trackable_view.children(root)
  >>> for item in trackable_view_children.items():
  ...   print(item)
  ('a_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>)
  ('b_var', <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=4.0>)
  ('vars', ListWrapper([<tf.Variable 'Variable:0' shape=() dtype=float32,
  numpy=1.0>, <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>]))
  ('leaf', ...)

  """

  def __init__(self, root):
    """Configure the trackable view.

    Args:
      root: A `Trackable` object whose variables (including the variables of
        dependencies, recursively) should be saved. May be a weak reference.
    """
    # TrackableView should never contain a strong reference to root, since it
    # may result in a cycle:
    #   root -> deferred dependencies -> CheckpointPosition
    #   -> CheckpointRestoreCoordinator -> TrackableView -> root
    self._root_ref = (root if isinstance(root, weakref.ref)
                      else weakref.ref(root))

  @classmethod
  def children(cls, obj, save_type=base.SaveType.CHECKPOINT, **kwargs):
    """Returns all child trackables attached to obj.

    Args:
      obj: A `Trackable` object.
      save_type: A `base.SaveType` value selecting which children to list,
        i.e. 'savedmodel' or 'checkpoint'. Defaults to
        `base.SaveType.CHECKPOINT`. Forwarded to `obj._trackable_children`.
      **kwargs: kwargs to use when retrieving the object's children.

    Returns:
      Dictionary mapping each child's name to its trackable. Every child is
      passed through `converter.convert_to_trackable` with `obj` as parent.
    """
    # pylint: disable=protected-access
    obj._maybe_initialize_trackable()
    return {
        name: converter.convert_to_trackable(ref, parent=obj)
        for name, ref in obj._trackable_children(save_type, **kwargs).items()
    }

  @property
  def root(self):
    """The root `Trackable` of this view.

    Returns:
      The object referenced by `_root_ref`, dereferenced when it is a weak
      reference.

    Raises:
      AssertionError: If the weakly referenced root has been garbage
        collected.
    """
    if isinstance(self._root_ref, weakref.ref):
      derefed = self._root_ref()
      # Raise explicitly instead of using `assert`: assert statements are
      # stripped under `python -O`, which would silently hand callers `None`.
      # The exception type is kept as AssertionError for compatibility.
      if derefed is None:
        raise AssertionError(
            "The root object of this TrackableView has been garbage "
            "collected.")
      return derefed
    # `_root_ref` is not a weakref here (presumably replaced by a subclass);
    # return it unchanged.
    return self._root_ref

  def descendants(self):
    """Returns a list of all nodes reachable from self.root.

    Nodes are listed in breadth-first order, starting with the root.
    """
    return self._descendants_with_paths()[0]

  def _descendants_with_paths(self):
    """Returns all nodes reachable from self.root and their paths.

    Traversal is breadth first. Nodes are tracked in an
    `ObjectIdentityDictionary`, so each object is visited only once even if
    it is reachable through several parents.

    Returns:
      A tuple `(bfs_sorted, node_paths)`. `bfs_sorted` is a list of nodes in
      breadth-first order, starting with the root. `node_paths` maps each
      node to a tuple of `base.TrackableReference`s describing the first
      path (in BFS order, hence a shortest one) from the root to that node;
      the root maps to `()`.
    """
    # Dereference the root once; the deque below keeps it alive afterwards.
    root = self.root
    bfs_sorted = []
    to_visit = collections.deque([root])
    node_paths = object_identity.ObjectIdentityDictionary()
    node_paths[root] = ()
    while to_visit:
      current_trackable = to_visit.popleft()
      bfs_sorted.append(current_trackable)
      for name, dependency in self.children(current_trackable).items():
        # Only the first path that reaches a node is recorded.
        if dependency not in node_paths:
          node_paths[dependency] = (
              node_paths[current_trackable] +
              (base.TrackableReference(name, dependency),))
          to_visit.append(dependency)
    return bfs_sorted, node_paths