# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object-based saving which use tf.train.* optimizers."""
import os

from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training import adam


class CheckpointingTests(test.TestCase):
  """Object-based checkpointing of variables trained with `AdamOptimizer`."""

  @test_util.run_in_graph_and_eager_modes
  def testDeferredSlotRestoration(self):
    """Restores are deferred until the objects they target are created.

    Writes one checkpoint without and one with the optimizer attached.
    Restores both into an empty `Checkpoint` before any variables exist.
    Then checks that each saved value is applied once the matching
    variable, optimizer, or slot is created.
    """
    checkpoint_directory = self.get_temp_dir()

    root = trackable_utils.Checkpoint()
    root.var = trackable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.executing_eagerly():
      optimizer.minimize(root.var.read_value)
    else:
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(trackable_utils.gather_initializers(
          trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    self.evaluate(state_ops.assign(root.var, 12.))
    # Saved before the optimizer is attached, so this checkpoint holds no
    # optimizer state.
    no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
    new_root = trackable_utils.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
    # `new_root.var` does not exist yet, so nothing can have been consumed.
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    # The later restore (no_slots, value 12.) wins over the earlier one (13.).
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    slot_status.assert_existing_objects_matched()
    # minimize() has not run on the new optimizer yet, so its `beta1_power`
    # accumulator is still waiting to be restored.
    with self.assertRaisesRegex(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.executing_eagerly():
      new_root.optimizer.minimize(new_root.var.read_value)
    else:
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    slot_status.assert_consumed()

  def testManySavesGraph(self):
    """Saves after the first should not modify the graph."""
    with context.graph_mode():
      graph = ops.Graph()
      with graph.as_default(), self.session(graph):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        obj = trackable_utils.Checkpoint()
        obj.var = variable_scope.get_variable(name="v", initializer=0.)
        obj.opt = adam.AdamOptimizer(0.1)
        obj.opt.minimize(obj.var.read_value())
        self.evaluate(trackable_utils.gather_initializers(obj))
        obj.save(checkpoint_prefix)
        # Snapshot the graph after the first save; a second save must reuse
        # the ops it built rather than adding new ones.
        before_ops = graph.get_operations()
        obj.save(checkpoint_prefix)
        self.assertEqual(before_ops, graph.get_operations())

  def testManyRestoresGraph(self):
    """Restores after the first should not modify the graph."""
    with context.graph_mode():
      graph = ops.Graph()
      with graph.as_default(), self.session(graph):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        obj = trackable_utils.Checkpoint()
        obj.var = variable_scope.get_variable(name="v", initializer=0.)
        obj.opt = adam.AdamOptimizer(0.1)
        obj.opt.minimize(obj.var.read_value())
        self.evaluate(trackable_utils.gather_initializers(obj))
        save_path = obj.save(checkpoint_prefix)
        obj.restore(save_path)
        # Snapshot after the first restore; repeating it must add no ops.
        before_ops = graph.get_operations()
        obj.restore(save_path)
        self.assertEqual(before_ops, graph.get_operations())

  def testMultipleGraphsNonSlotVariables(self):
    """One optimizer used in two graphs keeps per-graph non-slot variables.

    Values are written, saved, and restored in a second graph. The test then
    checks that the first graph's variable, slot, and `beta1_power` values
    are unchanged.
    """
    with context.graph_mode():
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      optimizer = adam.AdamOptimizer(0.001)
      # Construct a model in one graph
      first_graph = ops.Graph()
      first_session = session_lib.Session(graph=first_graph)
      # This session is re-entered at the end of the test, so it cannot live
      # in a single `with` block like the second session. Close it explicitly
      # once the test finishes instead of leaking it.
      self.addCleanup(first_session.close)
      with first_graph.as_default(), first_session.as_default():
        first_variable = resource_variable_ops.ResourceVariable([1.])
        first_root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, variable=first_variable)
        train_op = optimizer.minimize(first_variable.read_value)
        self.evaluate(trackable_utils.gather_initializers(
            first_root_trackable))
        self.evaluate(train_op)
        self.evaluate(first_variable.assign([1.]))
        self.evaluate(optimizer.get_slot(
            var=first_variable, name="m").assign([2.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(3.))

      # Save and load in a second graph
      second_graph = ops.Graph()
      with second_graph.as_default(), session_lib.Session(graph=second_graph):
        second_variable = resource_variable_ops.ResourceVariable([1.])
        second_root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, variable=second_variable)
        train_op = optimizer.minimize(second_variable.read_value)
        second_root_trackable.restore(None).initialize_or_restore()
        self.evaluate(train_op)
        self.evaluate(second_variable.assign([4.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([5.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(6.))
        save_path = second_root_trackable.save(checkpoint_prefix)
        # Clobber the saved values so the restore below is observable.
        self.evaluate(second_variable.assign([7.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([8.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))
        status = second_root_trackable.restore(save_path)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([4.], self.evaluate(second_variable))
        self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
            var=second_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))

      # Check that the first graph is unmolested
      with first_graph.as_default(), first_session.as_default():
        self.assertAllEqual([1.], self.evaluate(first_variable))
        self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
            var=first_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(3., self.evaluate(beta1_power))


class _ManualScope(autotrackable.AutoTrackable):
  """Trackable that builds a variable under an explicitly entered scope.

  Variables created in `_build` are wrapped in
  `capture_dependencies(template=self)`. The template test checks that they
  appear in this object's `_trackable_children()`.
  """

  def __call__(self):
    with variable_scope.variable_scope("ManualScope") as vs:
      # Presumably read by capture_dependencies(template=self) to name the
      # captured variables relative to this scope -- confirm.
      self.variable_scope = vs
      with trackable_utils.capture_dependencies(template=self):
        return self._build()

  def _build(self):
    return variable_scope.get_variable(name="in_manual_scope", shape=[])


class TemplateTests(test.TestCase):
  """Checkpointing of `make_template` templates and their nested trackables."""

  @test_util.run_in_graph_and_eager_modes
  def test_trackable_save_restore(self):
    """Template variables and a nested `_ManualScope` survive a round trip."""

    def _templated():
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    self.assertCountEqual([
        id(obj) for obj in
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template]
    ], [id(obj) for obj in trackable_utils.list_objects(save_template)])
    self.assertDictEqual({"in_manual_scope": manual_scope_v},
                         manual_scope._trackable_children())
    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    # Restore into a differently named template ("s2"). Matching is by
    # object dependency, not by variable name.
    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(3, len(load_template._trackable_children()))
    # Compare as sets so a failure reports a set difference.
    self.assertEqual({"v", "v2", "ManualScope"},
                     set(load_template._trackable_children()))
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))


if __name__ == "__main__":
  test.main()