# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object-based saving which use tf.train.* optimizers."""
import os

from tensorflow.python.checkpoint import checkpoint as trackable_utils
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training import adam


class CheckpointingTests(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testDeferredSlotRestoration(self):
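    """Slot variables saved to a checkpoint can be restored lazily.

    Restores are deferred until the corresponding Python objects (here the
    variable and the optimizer with its slot variables) are created and
    tracked.
    """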
    checkpoint_directory = self.get_temp_dir()

    root = trackable_utils.Checkpoint()
    root.var = trackable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = adam.AdamOptimizer(0.1)
    if context.executing_eagerly():
      optimizer.minimize(root.var.read_value)
    else:
      train_op = optimizer.minimize(root.var)
      # Note that `optimizer` has not been added as a dependency of
      # `root`. Create a one-off grouping so that slot variables for `root.var`
      # get initialized too.
      self.evaluate(trackable_utils.gather_initializers(
          trackable_utils.Checkpoint(root=root, optimizer=optimizer)))
      self.evaluate(train_op)
    self.evaluate(state_ops.assign(root.var, 12.))
    no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
    new_root = trackable_utils.Checkpoint()
    # Load the slot-containing checkpoint (deferred), then immediately
    # overwrite the non-slot variable (also deferred).
    slot_status = new_root.restore(slots_path)
    no_slot_status = new_root.restore(no_slots_path)
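    # `new_root.var` does not exist yet, so the value checkpointed for it
    # cannot have been consumed.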
    with self.assertRaises(AssertionError):
      no_slot_status.assert_consumed()
    new_root.var = trackable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = adam.AdamOptimizer(0.1)
    slot_status.assert_existing_objects_matched()
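    # The new optimizer has not run yet, so Adam's non-slot accumulators
    # (beta1_power, beta2_power) do not exist and their checkpointed values
    # remain unconsumed.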
    with self.assertRaisesRegex(AssertionError, "beta1_power"):
      slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.executing_eagerly():
      # Slot variables are only created with restoring initializers when
      # executing eagerly.
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
      self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                    None)
    if context.executing_eagerly():
      new_root.optimizer.minimize(new_root.var.read_value)
    else:
      train_op = new_root.optimizer.minimize(new_root.var)
      # The slot variable now exists; restore() didn't create it, but we should
      # now have a restore op for it.
      slot_status.run_restore_ops()
      self.assertEqual(14., self.evaluate(
          new_root.optimizer.get_slot(name="m", var=new_root.var)))
      self.evaluate(train_op)
    slot_status.assert_consumed()

  def testManySavesGraph(self):
    """Saves after the first should not modify the graph."""
    with context.graph_mode():
      graph = ops.Graph()
      with graph.as_default(), self.session(graph):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        obj = trackable_utils.Checkpoint()
        obj.var = variable_scope.get_variable(name="v", initializer=0.)
        obj.opt = adam.AdamOptimizer(0.1)
        obj.opt.minimize(obj.var.read_value())
        self.evaluate(trackable_utils.gather_initializers(obj))
        obj.save(checkpoint_prefix)
        before_ops = graph.get_operations()
        obj.save(checkpoint_prefix)
        self.assertEqual(before_ops, graph.get_operations())

  def testManyRestoresGraph(self):
    """Restores after the first should not modify the graph."""
    with context.graph_mode():
      graph = ops.Graph()
      with graph.as_default(), self.session(graph):
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        obj = trackable_utils.Checkpoint()
        obj.var = variable_scope.get_variable(name="v", initializer=0.)
        obj.opt = adam.AdamOptimizer(0.1)
        obj.opt.minimize(obj.var.read_value())
        self.evaluate(trackable_utils.gather_initializers(obj))
        save_path = obj.save(checkpoint_prefix)
        obj.restore(save_path)
        before_ops = graph.get_operations()
        obj.restore(save_path)
        self.assertEqual(before_ops, graph.get_operations())

  def testMultipleGraphsNonSlotVariables(self):
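    """One optimizer instance keeps separate non-slot state per graph.

    Saving and restoring in a second graph should not disturb the variables,
    slots, or beta accumulators in the first graph.
    """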
    with context.graph_mode():
      checkpoint_directory = self.get_temp_dir()
      checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
      optimizer = adam.AdamOptimizer(0.001)
      # Construct a model in one graph
      first_graph = ops.Graph()
      first_session = session_lib.Session(graph=first_graph)
      with first_graph.as_default(), first_session.as_default():
        first_variable = resource_variable_ops.ResourceVariable([1.])
        first_root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, variable=first_variable)
        train_op = optimizer.minimize(first_variable.read_value)
        self.evaluate(trackable_utils.gather_initializers(
            first_root_trackable))
        self.evaluate(train_op)
        self.evaluate(first_variable.assign([1.]))
        self.evaluate(optimizer.get_slot(
            var=first_variable, name="m").assign([2.]))
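        # `_get_beta_accumulators()` is AdamOptimizer's private accessor for
        # its non-slot beta1/beta2 power variables in the current graph.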
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(3.))

      # Save and load in a second graph
      second_graph = ops.Graph()
      with second_graph.as_default(), session_lib.Session(graph=second_graph):
        second_variable = resource_variable_ops.ResourceVariable([1.])
        second_root_trackable = trackable_utils.Checkpoint(
            optimizer=optimizer, variable=second_variable)
        train_op = optimizer.minimize(second_variable.read_value)
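        # restore(None) matches nothing, so initialize_or_restore() falls
        # back to running the initializers for all tracked variables.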
        second_root_trackable.restore(None).initialize_or_restore()
        self.evaluate(train_op)
        self.evaluate(second_variable.assign([4.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([5.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.evaluate(beta1_power.assign(6.))
        save_path = second_root_trackable.save(checkpoint_prefix)
        self.evaluate(second_variable.assign([7.]))
        self.evaluate(optimizer.get_slot(
            var=second_variable, name="m").assign([8.]))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))
        status = second_root_trackable.restore(save_path)
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([4.], self.evaluate(second_variable))
        self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
            var=second_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(6., self.evaluate(beta1_power))

      # Check that the first graph is unmolested
      with first_graph.as_default(), first_session.as_default():
        self.assertAllEqual([1.], self.evaluate(first_variable))
        self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
            var=first_variable, name="m")))
        beta1_power, _ = optimizer._get_beta_accumulators()
        self.assertAllEqual(3., self.evaluate(beta1_power))


class _ManualScope(autotrackable.AutoTrackable):
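  """Creates a variable in its own scope and tracks it as a dependency."""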

  def __call__(self):
    with variable_scope.variable_scope("ManualScope") as vs:
      self.variable_scope = vs
      with trackable_utils.capture_dependencies(template=self):
        return self._build()

  def _build(self):
    return variable_scope.get_variable(name="in_manual_scope", shape=[])


class TemplateTests(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_trackable_save_restore(self):
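    """Variables created by templates are saved and restored by object."""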

    def _templated():
      v = variable_scope.get_variable(
          "v", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      v2 = variable_scope.get_variable(
          "v2", shape=[1], initializer=init_ops.zeros_initializer(),
          use_resource=True)
      manual = _ManualScope()
      return v, v + 1., v2, manual, manual()

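    # `make_template` wraps `_templated` so its variables are created on the
    # first call and reused on subsequent calls.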
    save_template = template.make_template("s1", _templated)
    v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
    self.assertCountEqual([
        id(obj) for obj in
        [v1_save, v2_save, manual_scope, manual_scope_v, save_template]
    ], [id(obj) for obj in trackable_utils.list_objects(save_template)])
    self.assertDictEqual({"in_manual_scope": manual_scope_v},
                         manual_scope._trackable_children())
    optimizer = adam.AdamOptimizer(0.0)
    save_root = trackable_utils.Checkpoint(
        my_template=save_template, optimizer=optimizer)
    optimizer.minimize(v1_save.read_value)
    self.evaluate([v.initializer for v in save_template.variables])
    self.evaluate([v.initializer for v in optimizer.variables()])
    self.evaluate(v1_save.assign([12.]))
    self.evaluate(v2_save.assign([14.]))
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    save_path = save_root.save(checkpoint_prefix)

    load_template = template.make_template("s2", _templated)
    load_optimizer = adam.AdamOptimizer(0.0)
    load_root = trackable_utils.Checkpoint(
        my_template=load_template, optimizer=load_optimizer)
    status = load_root.restore(save_path)
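    # Calling the template creates its variables; the deferred restore can
    # then match them up and be fully consumed.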
    var, var_plus_one, var2, _, _ = load_template()
    load_optimizer.minimize(var.read_value)
    self.assertEqual(3, len(load_template._trackable_children()))
    self.assertEqual(set(["v", "v2", "ManualScope"]),
                     load_template._trackable_children().keys())
    status.assert_consumed().run_restore_ops()
    self.assertAllEqual([12.], self.evaluate(var))
    self.assertAllEqual([13.], self.evaluate(var_plus_one))
    self.assertAllEqual([14.], self.evaluate(var2))


if __name__ == "__main__":
  test.main()