# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ShardedVariable."""

import os

from absl.testing import parameterized
import numpy as np
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.client import session as session_lib
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.test_util import get_cluster_def
from tensorflow.python.distribute.test_util import TestClusterParams
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util import nest

# We create one cluster to share between tests. The cluster should be large
# enough to accommodate all the tests. Adjust the following constants as needed
# but be aware of resource limitations in OSS tests.
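# The fields are (cluster, max_num_workers, max_num_ps); the shared cluster is
# created lazily on first use.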
test_cluster_params = TestClusterParams(None, 2, 3)


def _load_and_run(
    model_dir,
    inputs,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
  """Load a SavedModel into a TF 1.x-style graph and run `signature_key`."""
  graph = ops.Graph()
  with graph.as_default(), session_lib.Session() as session:
    meta_graph_def = loader.load(session, [tag_constants.SERVING], model_dir)
    signature = meta_graph_def.signature_def[signature_key]
    feed_dict = {}
    for arg_name in inputs.keys():
      input_tensor = session.graph.get_tensor_by_name(
          signature.inputs[arg_name].name)
      feed_dict[input_tensor] = inputs[arg_name]
    output_dict = {}
    for output_name, output_tensor_info in signature.outputs.items():
      output_dict[output_name] = session.graph.get_tensor_by_name(
          output_tensor_info.name)
    return session.run(output_dict, feed_dict=feed_dict)


class PartitionerTest(test.TestCase):

  def test_fixed_shards_partitioner(self):
    partitioner = sharded_variable.FixedShardsPartitioner(num_shards=2)
    got = partitioner(tensor_shape.TensorShape([10, 3]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])

  def test_min_size_partitioner(self):
    partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=4, max_shards=2)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])

    partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=4, max_shards=10)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [6, 1])

  def test_max_size_partitioner(self):
    partitioner = sharded_variable.MaxSizePartitioner(max_shard_bytes=4)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [6, 1])

    partitioner = sharded_variable.MaxSizePartitioner(
        max_shard_bytes=4, max_shards=2)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])

    partitioner = sharded_variable.MaxSizePartitioner(max_shard_bytes=1024)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [1, 1])


class ShardedVariableTest(test.TestCase, parameterized.TestCase):

  def test_sharded_variable_simple(self):
    v0 = variables_lib.Variable([0])
    v1 = variables_lib.Variable([1])
    s = sharded_variable.ShardedVariable([v0, v1], name='s')
    self.assertEqual(s.variables[0], v0)
    self.assertEqual(s.variables[1], v1)
    self.assertEqual(s.shape.as_list(), [2])
    self.assertEqual(s.dtype, v0.dtype)
    self.assertEqual(s.name, 's')

  def test_assign(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[4, 4]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[5, 5], [6, 6]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[7, 7]])
    self.assertIs(ret, s)

  def test_assign_add(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign_add([[1, 1], [1, 1], [2, 2], [2, 2]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[1, 1]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[2, 2], [4, 4]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[5, 5]])
    self.assertIs(ret, s)

  def test_assign_sub(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign_sub([[0, 0], [1, 1], [1, 1], [3, 3]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[0, 0]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[0, 0], [1, 1]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[0, 0]])
    self.assertIs(ret, s)

  def test_scatter_add_uneven_partition(self):
    v = variables_lib.Variable(array_ops.zeros((32, 1)))
    sparse_delta = indexed_slices.IndexedSlices(
        values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([0, 10, 11, 12, 30, 31]))

    v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    v.scatter_add(sparse_delta)
    sv.scatter_add(sparse_delta)
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      v.scatter_add(sparse_delta)
      sv.scatter_add(sparse_delta)

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

  @parameterized.parameters('scatter_add', 'scatter_div', 'scatter_max',
                            'scatter_min', 'scatter_mul', 'scatter_sub',
                            'scatter_update')
  def test_scatter_ops_even_partition(self, op):
    v = variables_lib.Variable(array_ops.zeros((30, 1)))
    # Make sure the values do not contain 0, since `scatter_div` is among the
    # ops under test.
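    # (A 0 in the delta would make `scatter_div` produce nan for the zero rows,
    # and nan != nan would break the equality check against the dense variable.)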
    sparse_delta = indexed_slices.IndexedSlices(
        values=constant_op.constant([[1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([0, 10, 12, 21, 22]))

    v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    getattr(v, op)(sparse_delta, name='scatter_v')
    getattr(sv, op)(sparse_delta, name='scatter_sv')
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      getattr(v, op)(sparse_delta, name='scatter_v')
      getattr(sv, op)(sparse_delta, name='scatter_sv')

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

  def test_batch_scatter_update(self):
    v = variables_lib.Variable(array_ops.zeros((32, 1)))
    sparse_delta = indexed_slices.IndexedSlices(
        values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([10, 11, 12, 13, 14, 15]))

    v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    v.batch_scatter_update(sparse_delta)
    sv.batch_scatter_update(sparse_delta)
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      v.batch_scatter_update(sparse_delta)
      sv.batch_scatter_update(sparse_delta)

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

  def test_sparse_read(self):
    v = variables_lib.Variable(array_ops.zeros((30, 1)))
    indices = constant_op.constant([0, 10, 12, 21, 22])

    v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    self.assertAllEqual(v.sparse_read(indices), sv.sparse_read(indices))

    @def_function.function
    def func():
      return v.sparse_read(indices), sv.sparse_read(indices)

    got, expect = func()
    self.assertAllEqual(got, expect)

  def test_control_dep_on_assign(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])

    @def_function.function
    def func():
      ret = s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
      with ops.control_dependencies([ret]):
        a = array_ops.ones((1, 1))
      with ops.control_dependencies([control_flow_ops.group(ret)]):
        b = array_ops.ones((1, 1))
      return a, b

    func()

  def test_convert_to_tensor(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    t = ops.convert_to_tensor(s)
    self.assertAllEqual(t, [[0, 0], [1, 1], [2, 2], [3, 3]])

  def test_save_restore(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
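    # Write a checkpoint, mutate the first shard, then restore and check that
    # the original value comes back.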
    cp.write(fname)

    self.evaluate(cp.s.variables[0].assign([4]))
    self.assertEqual(self.evaluate(cp.s.variables[0]), [4])

    cp.restore(fname)
    # Tests that the original weights are restored.
    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])

  def test_save_restore_different_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [variables_lib.Variable([0, 0, 0, 0])]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')

    # Restore from 4 partitions into 1.
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])

    self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
    cp2.write(fname)

    # Restore 1 partition into 4.
    cp.restore(fname)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [5])
    self.assertEqual(self.evaluate(cp.s.variables[1]), [10])
    self.assertEqual(self.evaluate(cp.s.variables[2]), [15])
    self.assertEqual(self.evaluate(cp.s.variables[3]), [20])

  def test_save_restore_4_to_2_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')
    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [
        variables_lib.Variable([0, 0]),
        variables_lib.Variable([0, 0])
    ]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    # Assert that weights from the 4 partitions were loaded here.
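    # The four 1-element shards [0], [1], [2], [3] are merged into two
    # 2-element shards.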
    self.assertLen(cp2.s.variables, 2)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1])
    self.assertAllEqual(self.evaluate(cp2.s.variables[1]), [2, 3])

  def test_delayed_restore(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    model = autotrackable.AutoTrackable()
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    model.s = sharded_variable.ShardedVariable(variables)
    cp = util.Checkpoint(model=model)
    cp.write(fname)

    model2 = autotrackable.AutoTrackable()
    cp2 = util.Checkpoint(model=model2)
    cp2.restore(fname)
    variables2 = [
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0])
    ]
    model2.s = sharded_variable.ShardedVariable(variables2)
    self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0])
    self.assertAllEqual(self.evaluate(model2.s.variables[1]), [1])
    self.assertAllEqual(self.evaluate(model2.s.variables[2]), [2])
    self.assertAllEqual(self.evaluate(model2.s.variables[3]), [3])

  def test_delayed_restore_4_to_2_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    model = autotrackable.AutoTrackable()
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    model.s = sharded_variable.ShardedVariable(variables)
    cp = util.Checkpoint(model=model)
    cp.write(fname)

    model2 = autotrackable.AutoTrackable()
    cp2 = util.Checkpoint(model=model2)
    cp2.restore(fname)
    variables2 = [
        variables_lib.Variable([0, 0]),
        variables_lib.Variable([0, 0])
    ]
    model2.s = sharded_variable.ShardedVariable(variables2)
    self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0, 1])
    self.assertAllEqual(self.evaluate(model2.s.variables[1]), [2, 3])

  def test_save_graph_def(self):
    root = autotrackable.AutoTrackable()
    v1 = variables_lib.Variable([3.])
    v2 = variables_lib.Variable([2.])
    root.v = sharded_variable.ShardedVariable([v1, v2])
    root.train = def_function.function(
        lambda x: embedding_ops.embedding_lookup_v2(root.v.variables, x))
    # TODO(b/144057383): Remove the necessity of root.serve once the saving
    # context is added to the tf.function cache.
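    # root.serve is traced with a fixed input signature so it can be passed to
    # save.save as the serving signature below.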
    root.serve = def_function.function(
        lambda x: embedding_ops.embedding_lookup_v2(root.v.variables[0], x),
        input_signature=[tensor_spec.TensorSpec([2], dtypes.int32, name='x')])

    # Trace and use root.train
    self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save.save(root, save_dir, root.serve)
    self.assertAllEqual([3., 2.],
                        _load_and_run(save_dir, {'x': [0, 1]})['output_0'])

    # Continue using root.train for training
    self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

  def test_validation_errors(self):
    with self.assertRaisesRegex(TypeError, 'should be a non-empty list of'):
      sharded_variable.ShardedVariable(None)

    with self.assertRaisesRegex(TypeError, 'should be a non-empty list of'):
      sharded_variable.ShardedVariable(
          [variables_lib.Variable([0]), 'not-a-variable'])

    with self.assertRaisesRegex(TypeError, 'should be a non-empty list of'):
      sharded_variable.ShardedVariable([])

    with self.assertRaisesRegex(ValueError, 'must have the same dtype'):
      sharded_variable.ShardedVariable([
          variables_lib.Variable([0], dtype='int64'),
          variables_lib.Variable([1], dtype='int32')
      ])

    with self.assertRaisesRegex(ValueError, 'the same shapes except'):
      sharded_variable.ShardedVariable([
          variables_lib.Variable(array_ops.ones((5, 10))),
          variables_lib.Variable(array_ops.ones((5, 20)))
      ])

    with self.assertRaisesRegex(ValueError, '`SaveSliceInfo` should not'):
      v = variables_lib.Variable([0])
      v._set_save_slice_info(
          variables_lib.Variable.SaveSliceInfo(
              full_name='s', full_shape=[2], var_offset=[0], var_shape=[1]))
      sharded_variable.ShardedVariable([v])

  def test_as_function_input(self):
    variables1 = [
        variables_lib.Variable([1]),
        variables_lib.Variable([1]),
    ]
    s = sharded_variable.ShardedVariable(variables1)
    variables2 = [
        variables_lib.Variable([2]),
        variables_lib.Variable([2]),
    ]
    s2 = sharded_variable.ShardedVariable(variables2)

    trace_count = [0]

    @def_function.function
    def func(sharded_var):
      trace_count[0] = trace_count[0] + 1
      sharded_var.assign([0, 0])

    func(s)
    self.assertAllEqual(ops.convert_to_tensor(s), [0, 0])
    self.assertEqual(trace_count[0], 1)
    func(s2)
    self.assertAllEqual(ops.convert_to_tensor(s2), [0, 0])
    self.assertEqual(trace_count[0], 1)

  def test_flatten(self):
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
    ]
    s = sharded_variable.ShardedVariable(variables)

    got = nest.flatten(s)
    self.assertIs(s, got[0])

    got = nest.flatten(s, expand_composites=True)
    expected = nest.flatten(variables, expand_composites=True)
    self.assertEqual(got, expected)

  def test_tf_module(self):

    class Model(module.Module):

      def __init__(self):
        super().__init__()
        variables = [
            variables_lib.Variable([0]),
            variables_lib.Variable([1]),
        ]
        self.w = sharded_variable.ShardedVariable(variables)

    model = Model()

    self.assertLen(model.variables, 2)
    self.assertEqual(model.variables[0], [0])
    self.assertEqual(model.variables[1], [1])
    self.assertAllEqual(model.variables, model.trainable_variables)

    self.assertLen(model._trackable_children(), 1)
    self.assertIs(model._trackable_children().popitem()[1], model.w)

  def test_embedding_lookup(self):
    v = [
        variables_lib.Variable([[1., 2.], [3., 4.]]),
        variables_lib.Variable([[5., 6.], [7., 8.]]),
        variables_lib.Variable([[9., 10.]])
    ]
    sv = sharded_variable.ShardedVariable(v)

    @def_function.function
    def lookup():
      ids = constant_op.constant([0, 3, 4])
      return embedding_ops.embedding_lookup_v2(sv, ids)

    @def_function.function
    def sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, 3, 4, 1],
          dense_shape=[3, 3])
      return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)

    @def_function.function
    def safe_sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, -1, 4, 1],
          dense_shape=[3, 3])
      sp_weights = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[1., 1., -1., 1.],
          dense_shape=[3, 3])
      return embedding_ops.safe_embedding_lookup_sparse_v2(
          sv, sp_ids, sp_weights)

    # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
    # ShardedVariable is converted to a tensor in safe_sparse_lookup.
    for func in [lookup, sparse_lookup]:
      num_gather_ops = 0
      for op in func.get_concrete_function().graph.get_operations():
        if op.type == 'ResourceGather':
          num_gather_ops += 1
      self.assertEqual(
          num_gather_ops, len(v),
          'The number of ResourceGather ops does not match the expected count,'
          ' possibly because the ShardedVariable was accidentally converted to'
          ' a tensor in the embedding_lookup ops.')

    self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])

  def test_slicing(self):
    v = [
        variables_lib.Variable([[1, 2], [3, 4], [5, 6]]),
        variables_lib.Variable([[7, 8], [9, 10], [11, 12]]),
        variables_lib.Variable([[13, 14], [15, 16]])
    ]
    sv = sharded_variable.ShardedVariable(v)
    empty = v[0][0:0]

    # Test cases: positive step
    self.assertAllEqual(sv[:], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-8:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-10:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[5:], [[11, 12], [13, 14], [15, 16]])
    self.assertAllEqual(sv[5:-1], [[11, 12], [13, 14]])
    self.assertAllEqual(sv[::3], [[1, 2], [7, 8], [13, 14]])
    self.assertAllEqual(sv[::5], [[1, 2], [11, 12]])
    self.assertAllEqual(sv[1::6], [[3, 4], [15, 16]])
    self.assertAllEqual(sv[1:5:6], [[3, 4]])
    self.assertAllEqual(sv[1::7], [[3, 4]])
    self.assertAllEqual(sv[2:7], [[5, 6], [7, 8], [9, 10], [11, 12], [13, 14]])
    self.assertAllEqual(sv[2:7:2], [[5, 6], [9, 10], [13, 14]])
    self.assertAllEqual(sv[2:7:3], [[5, 6], [11, 12]])

    # Test cases: negative step
    self.assertAllEqual(
        sv[::-1], array_ops.reverse(array_ops.concat(v, axis=0), axis=[0]))
    self.assertAllEqual(sv[2::-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[2:-8:-1], [[5, 6], [3, 4]])
    self.assertAllEqual(sv[2:-10:-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[4::-1], [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[-1:-3:-1], [[15, 16], [13, 14]])
    self.assertAllEqual(sv[::-5], [[15, 16], [5, 6]])
    self.assertAllEqual(sv[6::-6], [[13, 14], [1, 2]])
    self.assertAllEqual(sv[6:5:-6], [[13, 14]])
    self.assertAllEqual(sv[6::-7], [[13, 14]])
    self.assertAllEqual(sv[7:1:-1],
                        [[15, 16], [13, 14], [11, 12], [9, 10], [7, 8], [5, 6]])
    self.assertAllEqual(sv[7:1:-2], [[15, 16], [11, 12], [7, 8]])
    self.assertAllEqual(sv[7:1:-4], [[15, 16], [7, 8]])

    # Test cases: empty slice
    self.assertAllEqual(sv[0:0], empty)
    self.assertAllEqual(sv[5:3], empty)
    self.assertAllEqual(sv[3:5:-1], empty)
    self.assertAllEqual(sv[-1:0], empty)
    self.assertAllEqual(sv[2:-1:-1], empty)

    # Test cases: slicing other dimensions
    self.assertAllEqual(sv[:, 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[:, 0:1], [[1], [3], [5], [7], [9], [11], [13], [15]])

    # Test cases: normal indexing
    self.assertAllEqual(sv[2], [5, 6])
    self.assertAllEqual(sv[6], [13, 14])
    self.assertAllEqual(sv[2, 1], 6)
    self.assertAllEqual(sv[-2], [13, 14])
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[100]
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[-100]

    # Test cases: Ellipsis
    self.assertAllEqual(sv[...], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[..., 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[0:1, ...], [[1, 2]])

    # Test cases: newaxis
    self.assertAllEqual(
        sv[array_ops.newaxis, ...],
        array_ops.expand_dims_v2(array_ops.concat(v, axis=0), axis=0))

    # Test cases: boolean masks
    self.assertAllEqual(sv[ops.convert_to_tensor(sv) > 10],
                        [11, 12, 13, 14, 15, 16])

    # Test cases: tensor input
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)::]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[:constant_op.constant(1):]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)]

    # Test cases: inside tf.function
    @def_function.function
    def func():
      a = sv[:, 0]
      return a

    self.assertAllEqual(func(), [1, 3, 5, 7, 9, 11, 13, 15])

  def test_operator_overload(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)

    v2 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv2 = sharded_variable.ShardedVariable(v2)

    equal = sv1 == sv2
    self.assertAllEqual(equal, [True, True])
    self.assertAllEqual(sv1 + sv2, [2.0, 4.0])

  def test_shards_have_container_set(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)
    for v in sv1.variables:
      self.assertTrue(hasattr(v, '_sharded_container'))
      self.assertIs(v._sharded_container(), sv1)

  def test_numpy(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)
    sv1_np = sv1.numpy()
    self.assertIsInstance(sv1_np, np.ndarray)
    self.assertAllEqual(sv1_np, np.array([1., 2.]))


class ShardedVariableSaveLoadTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    cluster_def = get_cluster_def(test_cluster_params, num_workers=2, num_ps=3)
    self.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))

  def tearDown(self):
    super().tearDown()
    # Reset context to disconnect from the cluster.
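    # `_reset_context` is a private test-only helper that tears down the eager
    # context so the next test starts from a clean state.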
    context._reset_context()

  def _create_strategy(self, num_shards):
    if num_shards > 1:
      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
          self.cluster_resolver,
          variable_partitioner=sharded_variable.FixedShardsPartitioner(
              num_shards))
    else:
      strategy = ds_context._get_default_strategy()
    return strategy

  @combinations.generate(
      combinations.combine(
          shard_config=[[2, 2], [2, 3], [3, 2], [2, 1], [1, 1]],
      ))
  def testSaveAndLoadSingleVariable(self, shard_config):
    """Test saving and loading ShardedVariable with different numbers of shards.

    Loading a tf.Variable into multiple shards is not yet supported.

    Args:
      shard_config: The number of shards to use before and after loading. For
        example, [2, 1] means to create and save the variable with 2 shards and
        load it into 1 shard (i.e., a regular tf.Variable).
    """
    strategy = self._create_strategy(shard_config[0])

    with strategy.scope():
      var = variables_lib.Variable([1., 2., 3., 4., 5., 6.])

    # Save the variable.
    model_dir = self.get_temp_dir()
    save.save(var, model_dir)

    strategy2 = self._create_strategy(shard_config[1])
    with strategy2.scope():
      # Load the variable.
      loaded = load.load(model_dir)

    # Assert that all values were loaded and are unchanged.
    if shard_config[1] > 1:
      loaded = array_ops.concat(loaded.variables, axis=0)
    self.assertLen(loaded.numpy(), 6)

    if shard_config[0] > 1:
      var = array_ops.concat(var.variables, axis=0)
    self.assertAllClose(var.numpy(), loaded.numpy())

  def testSaveAndLoadModuleUnderStrategy(self):

    class Dense(module.Module):

      def __init__(self):
        self.kernel = variables_lib.Variable(
            random_ops.random_uniform((6, 6)), name='kernel')
        self.bias = variables_lib.Variable(
            random_ops.random_uniform((6,)), name='bias')

      @def_function.function
      def __call__(self, x):
        out = math_ops.matmul(self.kernel, x)
        out = out + self.bias
        return out

    x = constant_op.constant(
        math_ops.range(6, dtype=dtypes.float32), shape=[6, 1])

    strategy = self._create_strategy(2)
    with strategy.scope():
      layer = Dense()
      expect = layer(x)

    model_dir = self.get_temp_dir()
    save.save(layer, model_dir)

    strategy2 = self._create_strategy(3)
    with strategy2.scope():
      loaded_layer = load.load(model_dir)
      # Should fail with an informative error.
      with self.assertRaisesRegex(ValueError, 'run a loaded non-Keras'):
        got = loaded_layer(x)

    # Loading without a strategy should work, because the tf.function is
    # traced with a single variable as input.
    loaded_layer = load.load(model_dir)
    got = loaded_layer(x)
    self.assertAllClose(got, expect)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()