xref: /aosp_15_r20/external/tensorflow/tensorflow/python/checkpoint/functional_saver_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15"""Tests for the functional saver."""
16
17import os
18
19from tensorflow.python.checkpoint import checkpoint_options
20from tensorflow.python.checkpoint import functional_saver
21from tensorflow.python.eager import context
22from tensorflow.python.eager import remote
23from tensorflow.python.eager import test
24from tensorflow.python.eager import wrap_function
25from tensorflow.python.framework import config
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.platform import gfile
31from tensorflow.python.training import server_lib
32from tensorflow.python.training.saving import saveable_object_util
33
34
# Fully qualified device string for CPU:0 of the local task. The tests pass it
# as `experimental_io_device` and compare it against graph op placements.
LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0"
36
37
class SaverTest(test.TestCase):
  """Save/restore tests for `functional_saver.MultiDeviceSaver`."""

  # Checkpoint I/O op types that the localhost tests expect to be placed on
  # the device passed as `experimental_io_device`. Both localhost tests use
  # the same set so their placement assertions agree.
  _IO_OP_TYPES = ("SaveV2", "RestoreV2", "MergeV2Checkpoints")

  def setUp(self):
    super().setUp()
    cpus = config.list_physical_devices("CPU")
    # Split the first physical CPU into 3 logical CPUs so tests can place
    # variables on "cpu:0", "cpu:1" and "cpu:2".
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    # Options that route checkpoint file I/O through LOCALHOST.
    self.local_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=LOCALHOST)

  def _assert_io_ops_on_localhost(self):
    """Asserts that checkpoint I/O ops in the default graph use LOCALHOST.

    Skipped when executing eagerly: the check inspects the operations of the
    default graph, which only records these ops in graph mode.
    """
    if context.executing_eagerly():
      return
    io_ops = [
        op for op in ops.get_default_graph().get_operations()
        if op.type in self._IO_OP_TYPES
    ]
    # Guard against a vacuous pass when no matching ops were created at all.
    self.assertTrue(
        io_ops, "No ops of types %s found in the graph." % (self._IO_OP_TYPES,))
    for op in io_ops:
      self.assertEqual(
          LOCALHOST, op.device,
          "Op %s (%s) is not placed on %s." % (op.name, op.type, LOCALHOST))

  @test_util.run_in_graph_and_eager_modes
  def test_resource_variable(self):
    """Saves a variable, then restores it into itself and into a new one."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(v1.initializer)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    # Presumably the index file plus a single data shard.
    self.assertEqual(2, len(gfile.Glob(prefix + "*")))
    self.evaluate(v1.assign(1.))
    self.evaluate(saver.restore(prefix))
    self.assertEqual(2., self.evaluate(v1))

    # A second saver built for a different variable under the same checkpoint
    # key "x" must restore the saved value.
    v2 = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(v2.initializer)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    self.evaluate(second_saver.restore(prefix))
    self.assertEqual(2., self.evaluate(v2))

  @test_util.run_in_graph_and_eager_modes
  def test_resource_variable_use_localhost(self):
    """Same as test_resource_variable, with I/O routed through LOCALHOST."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(v1.initializer)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    self.assertEqual(2, len(gfile.Glob(prefix + "*")))
    self.evaluate(v1.assign(1.))
    self.evaluate(saver.restore(prefix, self.local_options))
    self.assertEqual(2., self.evaluate(v1))

    v2 = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(v2.initializer)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    self.evaluate(second_saver.restore(prefix, self.local_options))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save, merge and restore ops were set to
    # run on localhost.
    self._assert_io_ops_on_localhost()

  def test_to_proto(self):
    """Round-trips a variable through the proto returned by `to_proto()`."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Build the proto inside a wrapped function so the tensor/op names it
    # records refer to `wrapped.graph`.
    proto_accumulator = []
    wrapped = wrap_function.wrap_function(
        lambda: proto_accumulator.append(saver.to_proto()), signature=())
    self.assertEqual(1, len(proto_accumulator))
    proto = proto_accumulator[0]
    # Prune the wrapped graph into standalone save and restore callables,
    # using the names recorded in the proto as feeds and fetches.
    save = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name))
    restore = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name))
    save_path = save(constant_op.constant(prefix))
    v1.assign(1.)
    restore(constant_op.constant(save_path))
    self.assertEqual(2., self.evaluate(v1))

    # The checkpoint written via the pruned function is readable by a regular
    # saver as well.
    v2 = resource_variable_ops.ResourceVariable(3.)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(v2))

  @test_util.disable_tfrt("b/171765113: server is not supported in TFRT yet.")
  def test_checkpoint_is_sharded_by_task(self):
    """Saves and restores variables placed on three remote worker tasks."""
    # Start three in-process servers and connect to them as the "worker" job.
    servers = [server_lib.Server.create_local_server() for _ in range(3)]
    cluster_spec = server_lib.ClusterSpec({
        "worker": [s.target[len("grpc://"):] for s in servers]})
    remote.connect_to_cluster(cluster_spec)
    with ops.device("/job:worker/task:0/cpu:0"):
      v0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device("/job:worker/task:1/cpu:0"):
      v1 = resource_variable_ops.ResourceVariable(1.)
    with ops.device("/job:worker/task:2/cpu:0"):
      v2 = resource_variable_ops.ResourceVariable(2.)

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saver = functional_saver.MultiDeviceSaver(
        list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
        list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
        list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    # Presumably one data shard per task (3) plus the index file.
    self.assertEqual(4, len(gfile.Glob(prefix + "*")))
    self.evaluate(v0.assign(-1.))
    self.evaluate(v1.assign(-1.))
    self.evaluate(v2.assign(-1.))
    self.evaluate(saver.restore(constant_op.constant(prefix)))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

  @test_util.run_in_graph_and_eager_modes
  def test_checkpoint_multi_device_using_localhost(self):
    """Saves variables on three logical CPUs with I/O through LOCALHOST."""
    with ops.device("cpu:0"):
      v0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device("cpu:1"):
      v1 = resource_variable_ops.ResourceVariable(1.)
    with ops.device("cpu:2"):
      v2 = resource_variable_ops.ResourceVariable(2.)

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saver = functional_saver.MultiDeviceSaver(
        list(saveable_object_util.saveable_objects_for_op(v0, "v0")) +
        list(saveable_object_util.saveable_objects_for_op(v1, "v1")) +
        list(saveable_object_util.saveable_objects_for_op(v2, "v2")))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    # Unlike the multi-task test, only two files are expected here (presumably
    # the index plus a single data shard, since all CPUs share one task).
    self.assertEqual(2, len(gfile.Glob(prefix + "*")))
    self.evaluate(v0.assign(-1.))
    self.evaluate(v1.assign(-1.))
    self.evaluate(v2.assign(-1.))
    self.evaluate(
        saver.restore(constant_op.constant(prefix), self.local_options))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save, merge and restore ops were set to
    # run on localhost.
    self._assert_io_ops_on_localhost()
188
189
def _main():
  """Switches to eager execution, then hands off to the TF test runner."""
  ops.enable_eager_execution()
  test.main()


if __name__ == "__main__":
  _main()
193