# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for the functional saver."""

import os

from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import server_lib
from tensorflow.python.training.saving import saveable_object_util


LOCALHOST = "/job:localhost/replica:0/task:0/device:CPU:0"


def _saveables_for(named_variables):
  """Returns the concatenated saveable objects for `(variable, name)` pairs.

  Args:
    named_variables: Iterable of `(variable, name)` tuples; each pair is passed
      to `saveable_object_util.saveable_objects_for_op`.

  Returns:
    A list with the saveables of every pair, in input order.
  """
  saveables = []
  for variable, name in named_variables:
    saveables.extend(
        saveable_object_util.saveable_objects_for_op(variable, name))
  return saveables


class SaverTest(test.TestCase):
  """Save/restore round-trip tests for `functional_saver.MultiDeviceSaver`."""

  def setUp(self):
    super().setUp()
    cpus = config.list_physical_devices("CPU")
    # Split the first physical CPU into 3 logical CPUs so tests can place
    # variables on cpu:0, cpu:1 and cpu:2.
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    # Options that pin checkpoint I/O to the LOCALHOST CPU device.
    self.local_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=LOCALHOST)

  def _assert_io_ops_on_localhost(self, op_types):
    """Checks that every graph op whose type is in `op_types` is on LOCALHOST.

    Does nothing when executing eagerly (there is no default graph to
    inspect). In graph mode it also fails when no matching op exists, so the
    placement check cannot pass vacuously.

    Args:
      op_types: Collection of op type names, e.g. ("SaveV2", "RestoreV2").
    """
    if context.executing_eagerly():
      return
    io_ops = [
        op for op in ops.get_default_graph().get_operations()
        if op.type in op_types
    ]
    self.assertNotEmpty(io_ops)
    for op in io_ops:
      self.assertEqual(LOCALHOST, op.device, msg=op.name)

  @test_util.run_in_graph_and_eager_modes
  def test_resource_variable(self):
    """Saves a variable, restores it, then restores into a second saver."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(v1.initializer)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    self.assertLen(gfile.Glob(prefix + "*"), 2)
    self.evaluate(v1.assign(1.))
    self.evaluate(saver.restore(prefix))
    self.assertEqual(2., self.evaluate(v1))

    # A different variable registered under the same name "x" picks up the
    # saved value.
    v2 = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(v2.initializer)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    self.evaluate(second_saver.restore(prefix))
    self.assertEqual(2., self.evaluate(v2))

  @test_util.run_in_graph_and_eager_modes
  def test_resource_variable_use_localhost(self):
    """Same as `test_resource_variable`, with I/O pinned to LOCALHOST."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    self.evaluate(v1.initializer)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    self.assertLen(gfile.Glob(prefix + "*"), 2)
    self.evaluate(v1.assign(1.))
    self.evaluate(saver.restore(prefix, self.local_options))
    self.assertEqual(2., self.evaluate(v1))

    v2 = resource_variable_ops.ResourceVariable(3.)
    self.evaluate(v2.initializer)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    self.evaluate(second_saver.restore(prefix, self.local_options))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save and restore ops were set to run on
    # localhost.
    self._assert_io_ops_on_localhost(("SaveV2", "RestoreV2"))

  def test_to_proto(self):
    """Saves/restores through the ops named in the saver's `to_proto()`."""
    v1 = resource_variable_ops.ResourceVariable(2.)
    saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v1, "x"))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")

    # Build the proto inside a wrapped graph so its tensor/op names resolve
    # against `wrapped.graph`.
    proto_accumulator = []
    wrapped = wrap_function.wrap_function(
        lambda: proto_accumulator.append(saver.to_proto()), signature=())
    self.assertLen(proto_accumulator, 1)
    proto = proto_accumulator[0]
    save = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_tensor_by_name(proto.save_tensor_name))
    restore = wrapped.prune(
        feeds=wrapped.graph.get_tensor_by_name(proto.filename_tensor_name),
        fetches=wrapped.graph.get_operation_by_name(proto.restore_op_name))
    save_path = save(constant_op.constant(prefix))
    v1.assign(1.)
    restore(constant_op.constant(save_path))
    self.assertEqual(2., self.evaluate(v1))

    v2 = resource_variable_ops.ResourceVariable(3.)
    second_saver = functional_saver.MultiDeviceSaver(
        saveable_object_util.saveable_objects_for_op(v2, "x"))
    second_saver.restore(save_path)
    self.assertEqual(2., self.evaluate(v2))

  @test_util.disable_tfrt("b/171765113: server is not supported in TFRT yet.")
  def test_checkpoint_is_sharded_by_task(self):
    """Saves variables on three remote worker tasks and restores them."""
    servers = [server_lib.Server.create_local_server() for _ in range(3)]
    cluster_spec = server_lib.ClusterSpec({
        "worker": [s.target[len("grpc://"):] for s in servers]})
    remote.connect_to_cluster(cluster_spec)
    with ops.device("/job:worker/task:0/cpu:0"):
      v0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device("/job:worker/task:1/cpu:0"):
      v1 = resource_variable_ops.ResourceVariable(1.)
    with ops.device("/job:worker/task:2/cpu:0"):
      v2 = resource_variable_ops.ResourceVariable(2.)

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saver = functional_saver.MultiDeviceSaver(
        _saveables_for(((v0, "v0"), (v1, "v1"), (v2, "v2"))))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix)))
    # More files than the single-device case: the checkpoint is sharded
    # across the three tasks.
    self.assertLen(gfile.Glob(prefix + "*"), 4)
    self.evaluate(v0.assign(-1.))
    self.evaluate(v1.assign(-1.))
    self.evaluate(v2.assign(-1.))
    self.evaluate(saver.restore(constant_op.constant(prefix)))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

  @test_util.run_in_graph_and_eager_modes
  def test_checkpoint_multi_device_using_localhost(self):
    """Saves variables on three local CPUs with I/O pinned to LOCALHOST."""
    with ops.device("cpu:0"):
      v0 = resource_variable_ops.ResourceVariable(0.)
    with ops.device("cpu:1"):
      v1 = resource_variable_ops.ResourceVariable(1.)
    with ops.device("cpu:2"):
      v2 = resource_variable_ops.ResourceVariable(2.)

    self.evaluate([v0.initializer, v1.initializer, v2.initializer])
    saver = functional_saver.MultiDeviceSaver(
        _saveables_for(((v0, "v0"), (v1, "v1"), (v2, "v2"))))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    self.evaluate(saver.save(constant_op.constant(prefix), self.local_options))
    self.assertLen(gfile.Glob(prefix + "*"), 2)
    self.evaluate(v0.assign(-1.))
    self.evaluate(v1.assign(-1.))
    self.evaluate(v2.assign(-1.))
    self.evaluate(
        saver.restore(constant_op.constant(prefix), self.local_options))
    self.assertEqual(0., self.evaluate(v0))
    self.assertEqual(1., self.evaluate(v1))
    self.assertEqual(2., self.evaluate(v2))

    # In graph mode, verify that the save, restore and merge ops were set to
    # run on localhost.
    self._assert_io_ops_on_localhost(
        ("SaveV2", "RestoreV2", "MergeV2Checkpoints"))


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()