# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ShardedVariable."""

import os

from absl.testing import parameterized
import numpy as np
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.client import session as session_lib
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.test_util import get_cluster_def
from tensorflow.python.distribute.test_util import TestClusterParams
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.trackable import autotrackable
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util import nest

# We create one cluster to share between tests. The cluster should be large
# enough to accommodate all the tests. Adjust the following constants as needed
# but be aware of resource limitations in OSS tests.
test_cluster_params = TestClusterParams(None, 2, 3)


def _load_and_run(
    model_dir,
    inputs,
    signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
  """Load a SavedModel into a TF 1.x-style graph and run `signature_key`."""
  graph = ops.Graph()
  with graph.as_default(), session_lib.Session() as session:
    meta_graph_def = loader.load(session, [tag_constants.SERVING], model_dir)
    signature = meta_graph_def.signature_def[signature_key]
    feed_dict = {}
    for arg_name in inputs.keys():
      input_tensor = session.graph.get_tensor_by_name(
          signature.inputs[arg_name].name)
      feed_dict[input_tensor] = inputs[arg_name]
    output_dict = {}
    for output_name, output_tensor_info in signature.outputs.items():
      output_dict[output_name] = session.graph.get_tensor_by_name(
          output_tensor_info.name)
    return session.run(output_dict, feed_dict=feed_dict)


class PartitionerTest(test.TestCase):

  def test_fixed_shards_partitioner(self):
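    # A partitioner returns the number of partitions along each axis. Splitting
    # a [10, 3] tensor into 2 shards along axis 0 (leaving axis 1 whole)
    # yields [2, 1].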
    partitioner = sharded_variable.FixedShardsPartitioner(num_shards=2)
    got = partitioner(tensor_shape.TensorShape([10, 3]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])

  def test_min_size_partitioner(self):
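    # A [6, 1] float32 tensor occupies 6 * 4 = 24 bytes. MinSizePartitioner
    # creates as many shards as possible while keeping each shard at least
    # min_shard_bytes, capped at max_shards: 24 / 4 = 6 possible shards, so
    # the two cases below expect [2, 1] and [6, 1].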
    partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=4, max_shards=2)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])

    partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=4, max_shards=10)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [6, 1])

  def test_max_size_partitioner(self):
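    # MaxSizePartitioner creates as few shards as possible while keeping each
    # shard at most max_shard_bytes. A 4-byte cap forces one element (4 bytes)
    # per shard unless max_shards intervenes; a 1024-byte cap fits the whole
    # 24-byte tensor in a single shard.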
    partitioner = sharded_variable.MaxSizePartitioner(max_shard_bytes=4)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [6, 1])

    partitioner = sharded_variable.MaxSizePartitioner(
        max_shard_bytes=4, max_shards=2)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [2, 1])

    partitioner = sharded_variable.MaxSizePartitioner(max_shard_bytes=1024)
    got = partitioner(tensor_shape.TensorShape([6, 1]), dtypes.float32)
    self.assertAllEqual(got, [1, 1])


class ShardedVariableTest(test.TestCase, parameterized.TestCase):

  def test_sharded_variable_simple(self):
    v0 = variables_lib.Variable([0])
    v1 = variables_lib.Variable([1])
    s = sharded_variable.ShardedVariable([v0, v1], name='s')
    self.assertEqual(s.variables[0], v0)
    self.assertEqual(s.variables[1], v1)
    self.assertEqual(s.shape.as_list(), [2])
    self.assertEqual(s.dtype, v0.dtype)
    self.assertEqual(s.name, 's')

  def test_assign(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[4, 4]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[5, 5], [6, 6]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[7, 7]])
    self.assertIs(ret, s)

  def test_assign_add(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign_add([[1, 1], [1, 1], [2, 2], [2, 2]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[1, 1]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[2, 2], [4, 4]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[5, 5]])
    self.assertIs(ret, s)

  def test_assign_sub(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    ret = s.assign_sub([[0, 0], [1, 1], [1, 1], [3, 3]])
    self.assertAllEqual(self.evaluate(s.variables[0]), [[0, 0]])
    self.assertAllEqual(self.evaluate(s.variables[1]), [[0, 0], [1, 1]])
    self.assertAllEqual(self.evaluate(s.variables[2]), [[0, 0]])
    self.assertIs(ret, s)

  def test_scatter_add_uneven_partition(self):
    v = variables_lib.Variable(array_ops.zeros((32, 1)))
    sparse_delta = indexed_slices.IndexedSlices(
        values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([0, 10, 11, 12, 30, 31]))
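    # The shards below cover rows [0, 10], [11, 21], and [22, 31], so the
    # update indices deliberately straddle shard boundaries. The sharded
    # scatter must match the dense scatter on `v`.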

    v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    v.scatter_add(sparse_delta)
    sv.scatter_add(sparse_delta)
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      v.scatter_add(sparse_delta)
      sv.scatter_add(sparse_delta)

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

  @parameterized.parameters('scatter_add', 'scatter_div', 'scatter_max',
                            'scatter_min', 'scatter_mul', 'scatter_sub',
                            'scatter_update')
  def test_scatter_ops_even_partition(self, op):
    v = variables_lib.Variable(array_ops.zeros((30, 1)))
    # Make sure `values` contains no zeros, since `scatter_div` is among the
    # ops being tested.
    sparse_delta = indexed_slices.IndexedSlices(
        values=constant_op.constant([[1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([0, 10, 12, 21, 22]))

    v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    getattr(v, op)(sparse_delta, name='scatter_v')
    getattr(sv, op)(sparse_delta, name='scatter_sv')
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      getattr(v, op)(sparse_delta, name='scatter_v')
      getattr(sv, op)(sparse_delta, name='scatter_sv')

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

  def test_batch_scatter_update(self):
    v = variables_lib.Variable(array_ops.zeros((32, 1)))
    sparse_delta = indexed_slices.IndexedSlices(
        values=constant_op.constant([[0.], [1.], [2.], [3.], [4.], [5.]]),
        indices=constant_op.constant([10, 11, 12, 13, 14, 15]))

    v0 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((11, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    v.batch_scatter_update(sparse_delta)
    sv.batch_scatter_update(sparse_delta)
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

    @def_function.function
    def func():
      v.batch_scatter_update(sparse_delta)
      sv.batch_scatter_update(sparse_delta)

    func()
    self.assertAllEqual(v, ops.convert_to_tensor(sv))

  def test_sparse_read(self):
    v = variables_lib.Variable(array_ops.zeros((30, 1)))
    indices = constant_op.constant([0, 10, 12, 21, 22])

    v0 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v1 = variables_lib.Variable(array_ops.zeros((10, 1)))
    v2 = variables_lib.Variable(array_ops.zeros((10, 1)))
    sv = sharded_variable.ShardedVariable([v0, v1, v2])

    self.assertAllEqual(v.sparse_read(indices), sv.sparse_read(indices))

    @def_function.function
    def func():
      return v.sparse_read(indices), sv.sparse_read(indices)

    got, expect = func()
    self.assertAllEqual(got, expect)

  def test_control_dep_on_assign(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])

    @def_function.function
    def func():
      ret = s.assign([[4, 4], [5, 5], [6, 6], [7, 7]])
      with ops.control_dependencies([ret]):
        a = array_ops.ones((1, 1))
      with ops.control_dependencies([control_flow_ops.group(ret)]):
        b = array_ops.ones((1, 1))
      return a, b

    func()

  def test_convert_to_tensor(self):
    v0 = variables_lib.Variable([[0, 0]])
    v1 = variables_lib.Variable([[1, 1], [2, 2]])
    v2 = variables_lib.Variable([[3, 3]])
    s = sharded_variable.ShardedVariable([v0, v1, v2])
    t = ops.convert_to_tensor(s)
    self.assertAllEqual(t, [[0, 0], [1, 1], [2, 2], [3, 3]])

  def test_save_restore(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])
    cp.write(fname)

    self.evaluate(cp.s.variables[0].assign([4]))
    self.assertEqual(self.evaluate(cp.s.variables[0]), [4])

    cp.restore(fname)
    # Tests that the original weights are restored.
    self.assertEqual(self.evaluate(cp.s.variables[0]), [0])

  def test_save_restore_different_partitions(self):
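    # A ShardedVariable is checkpointed as a single logical variable, so a
    # checkpoint written under one partitioning can be restored under another
    # (here 4 shards -> 1 shard, then 1 shard -> 4 shards).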
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')

    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [variables_lib.Variable([0, 0, 0, 0])]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')

    # Restore from 4 partitions into 1.
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1, 2, 3])

    self.evaluate(cp2.s.variables[0].assign([5, 10, 15, 20]))
    cp2.write(fname)

    # Restore 1 partition into 4.
    cp.restore(fname)
    self.assertEqual(self.evaluate(cp.s.variables[0]), [5])
    self.assertEqual(self.evaluate(cp.s.variables[1]), [10])
    self.assertEqual(self.evaluate(cp.s.variables[2]), [15])
    self.assertEqual(self.evaluate(cp.s.variables[3]), [20])

  def test_save_restore_4_to_2_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    s = sharded_variable.ShardedVariable(variables, name='s')
    cp = util.Checkpoint(s=s)
    cp.write(fname)

    variables2 = [
        variables_lib.Variable([0, 0]),
        variables_lib.Variable([0, 0])
    ]
    s2 = sharded_variable.ShardedVariable(variables2, name='s')
    cp2 = util.Checkpoint(s=s2)
    cp2.restore(fname)
    # Check that the weights saved from 4 partitions were restored into 2.
    self.assertLen(cp2.s.variables, 2)
    self.assertAllEqual(self.evaluate(cp2.s.variables[0]), [0, 1])
    self.assertAllEqual(self.evaluate(cp2.s.variables[1]), [2, 3])

  def test_delayed_restore(self):
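    # restore() is called before the ShardedVariable exists on the model; the
    # checkpointed values should be applied once the variable is attached.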
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    model = autotrackable.AutoTrackable()
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    model.s = sharded_variable.ShardedVariable(variables)
    cp = util.Checkpoint(model=model)
    cp.write(fname)

    model2 = autotrackable.AutoTrackable()
    cp2 = util.Checkpoint(model=model2)
    cp2.restore(fname)
    variables2 = [
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0]),
        variables_lib.Variable([0])
    ]
    model2.s = sharded_variable.ShardedVariable(variables2)
    self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0])
    self.assertAllEqual(self.evaluate(model2.s.variables[1]), [1])
    self.assertAllEqual(self.evaluate(model2.s.variables[2]), [2])
    self.assertAllEqual(self.evaluate(model2.s.variables[3]), [3])

  def test_delayed_restore_4_to_2_partitions(self):
    fname = os.path.join(self.get_temp_dir(), 'checkpoint')
    model = autotrackable.AutoTrackable()
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
        variables_lib.Variable([2]),
        variables_lib.Variable([3])
    ]
    model.s = sharded_variable.ShardedVariable(variables)
    cp = util.Checkpoint(model=model)
    cp.write(fname)

    model2 = autotrackable.AutoTrackable()
    cp2 = util.Checkpoint(model=model2)
    cp2.restore(fname)
    variables2 = [
        variables_lib.Variable([0, 0]),
        variables_lib.Variable([0, 0])
    ]
    model2.s = sharded_variable.ShardedVariable(variables2)
    self.assertAllEqual(self.evaluate(model2.s.variables[0]), [0, 1])
    self.assertAllEqual(self.evaluate(model2.s.variables[1]), [2, 3])

  def test_save_graph_def(self):
    root = autotrackable.AutoTrackable()
    v1 = variables_lib.Variable([3.])
    v2 = variables_lib.Variable([2.])
    root.v = sharded_variable.ShardedVariable([v1, v2])
    root.train = def_function.function(
        lambda x: embedding_ops.embedding_lookup_v2(root.v.variables, x))
    # TODO(b/144057383): Remove the need for root.serve once the saving
    # context is added to the tf.function cache.
    root.serve = def_function.function(
        lambda x: embedding_ops.embedding_lookup_v2(root.v.variables[0], x),
        input_signature=[tensor_spec.TensorSpec([2], dtypes.int32, name='x')])

    # Trace and use root.train
    self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
    save.save(root, save_dir, root.serve)
    self.assertAllEqual([3., 2.],
                        _load_and_run(save_dir, {'x': [0, 1]})['output_0'])

    # Continue using root.train for training
    self.assertAllEqual([3., 2.], root.train([0, 1]).numpy())

  def test_validation_errors(self):
    with self.assertRaisesRegex(TypeError, 'should be a non-empty list of'):
      sharded_variable.ShardedVariable(None)

    with self.assertRaisesRegex(TypeError, 'should be a non-empty list of'):
      sharded_variable.ShardedVariable(
          [variables_lib.Variable([0]), 'not-a-variable'])

    with self.assertRaisesRegex(TypeError, 'should be a non-empty list of'):
      sharded_variable.ShardedVariable([])

    with self.assertRaisesRegex(ValueError, 'must have the same dtype'):
      sharded_variable.ShardedVariable([
          variables_lib.Variable([0], dtype='int64'),
          variables_lib.Variable([1], dtype='int32')
      ])

    with self.assertRaisesRegex(ValueError, 'the same shapes except'):
      sharded_variable.ShardedVariable([
          variables_lib.Variable(array_ops.ones((5, 10))),
          variables_lib.Variable(array_ops.ones((5, 20)))
      ])

    with self.assertRaisesRegex(ValueError, '`SaveSliceInfo` should not'):
      v = variables_lib.Variable([0])
      v._set_save_slice_info(
          variables_lib.Variable.SaveSliceInfo(
              full_name='s', full_shape=[2], var_offset=[0], var_shape=[1]))
      sharded_variable.ShardedVariable([v])

  def test_as_function_input(self):
    variables1 = [
        variables_lib.Variable([1]),
        variables_lib.Variable([1]),
    ]
    s = sharded_variable.ShardedVariable(variables1)
    variables2 = [
        variables_lib.Variable([2]),
        variables_lib.Variable([2]),
    ]
    s2 = sharded_variable.ShardedVariable(variables2)
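    # `s2` has the same structure as `s`, so calling the function with `s2`
    # should reuse the trace created for `s` instead of retracing.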

    trace_count = [0]

    @def_function.function
    def func(sharded_var):
      trace_count[0] = trace_count[0] + 1
      sharded_var.assign([0, 0])

    func(s)
    self.assertAllEqual(ops.convert_to_tensor(s), [0, 0])
    self.assertEqual(trace_count[0], 1)
    func(s2)
    self.assertAllEqual(ops.convert_to_tensor(s2), [0, 0])
    self.assertEqual(trace_count[0], 1)

  def test_flatten(self):
    variables = [
        variables_lib.Variable([0]),
        variables_lib.Variable([1]),
    ]
    s = sharded_variable.ShardedVariable(variables)

    got = nest.flatten(s)
    self.assertIs(s, got[0])

    got = nest.flatten(s, expand_composites=True)
    expected = nest.flatten(variables, expand_composites=True)
    self.assertEqual(got, expected)

  def test_tf_module(self):

    class Model(module.Module):

      def __init__(self):
        super().__init__()
        variables = [
            variables_lib.Variable([0]),
            variables_lib.Variable([1]),
        ]
        self.w = sharded_variable.ShardedVariable(variables)

    model = Model()

    self.assertLen(model.variables, 2)
    self.assertEqual(model.variables[0], [0])
    self.assertEqual(model.variables[1], [1])
    self.assertAllEqual(model.variables, model.trainable_variables)

    self.assertLen(model._trackable_children(), 1)
    self.assertIs(model._trackable_children().popitem()[1], model.w)

  def test_embedding_lookup(self):
    v = [
        variables_lib.Variable([[1., 2.], [3., 4.]]),
        variables_lib.Variable([[5., 6.], [7., 8.]]),
        variables_lib.Variable([[9., 10.]])
    ]
    sv = sharded_variable.ShardedVariable(v)

    @def_function.function
    def lookup():
      ids = constant_op.constant([0, 3, 4])
      return embedding_ops.embedding_lookup_v2(sv, ids)

    @def_function.function
    def sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, 3, 4, 1],
          dense_shape=[3, 3])
      return embedding_ops.embedding_lookup_sparse_v2(sv, sp_ids, None)

    @def_function.function
    def safe_sparse_lookup():
      sp_ids = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[0, -1, 4, 1],
          dense_shape=[3, 3])
      sp_weights = sparse_tensor.SparseTensor(
          indices=[[0, 0], [0, 1], [1, 0], [2, 2]],
          values=[1., 1., -1., 1.],
          dense_shape=[3, 3])
      return embedding_ops.safe_embedding_lookup_sparse_v2(
          sv, sp_ids, sp_weights)

    # TODO(chenkai): Add safe_sparse_lookup to the list. Currently
    # ShardedVariable is converted to a tensor in safe_sparse_lookup.
    for func in [lookup, sparse_lookup]:
      num_gather_ops = 0
      for op in func.get_concrete_function().graph.get_operations():
        if op.type == 'ResourceGather':
          num_gather_ops += 1
      self.assertEqual(
          num_gather_ops, len(v),
          'Number of ResourceGather ops does not match the expected count,'
          ' possibly because the ShardedVariable was accidentally converted'
          ' to a tensor in the embedding_lookup ops.')

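    # Expected values, treating the table as the concatenation of the shards
    # (row i holds [2i + 1, 2i + 2]):
    # - lookup: ids [0, 3, 4] -> rows [1, 2], [7, 8], [9, 10].
    # - sparse_lookup: the default combiner is 'mean', so row 0 averages ids
    #   {0, 3} into [4, 5]; rows 1 and 2 look up ids 4 and 1.
    # - safe_sparse_lookup: invalid ids (< 0) and non-positive weights are
    #   pruned, so row 1 becomes empty and defaults to zeros.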
    self.assertAllEqual(lookup(), [[1., 2.], [7., 8.], [9., 10.]])
    self.assertAllClose(sparse_lookup(), [[4., 5.], [9., 10.], [3., 4.]])
    self.assertAllClose(safe_sparse_lookup(), [[1., 2.], [0., 0.], [3., 4.]])

  def test_slicing(self):
    v = [
        variables_lib.Variable([[1, 2], [3, 4], [5, 6]]),
        variables_lib.Variable([[7, 8], [9, 10], [11, 12]]),
        variables_lib.Variable([[13, 14], [15, 16]])
    ]
    sv = sharded_variable.ShardedVariable(v)
    empty = v[0][0:0]
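    # `sv` behaves like the concatenation of its shards: an 8-row tensor whose
    # rows are [1, 2] through [15, 16]. Every slice below must match Python
    # slicing semantics on that concatenated value.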

    # Test cases: positive step
    self.assertAllEqual(sv[:], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-8:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[-10:2], [[1, 2], [3, 4]])
    self.assertAllEqual(sv[5:], [[11, 12], [13, 14], [15, 16]])
    self.assertAllEqual(sv[5:-1], [[11, 12], [13, 14]])
    self.assertAllEqual(sv[::3], [[1, 2], [7, 8], [13, 14]])
    self.assertAllEqual(sv[::5], [[1, 2], [11, 12]])
    self.assertAllEqual(sv[1::6], [[3, 4], [15, 16]])
    self.assertAllEqual(sv[1:5:6], [[3, 4]])
    self.assertAllEqual(sv[1::7], [[3, 4]])
    self.assertAllEqual(sv[2:7], [[5, 6], [7, 8], [9, 10], [11, 12], [13, 14]])
    self.assertAllEqual(sv[2:7:2], [[5, 6], [9, 10], [13, 14]])
    self.assertAllEqual(sv[2:7:3], [[5, 6], [11, 12]])

    # Test cases: negative step
    self.assertAllEqual(
        sv[::-1], array_ops.reverse(array_ops.concat(v, axis=0), axis=[0]))
    self.assertAllEqual(sv[2::-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[2:-8:-1], [[5, 6], [3, 4]])
    self.assertAllEqual(sv[2:-10:-1], [[5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[4::-1], [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]])
    self.assertAllEqual(sv[-1:-3:-1], [[15, 16], [13, 14]])
    self.assertAllEqual(sv[::-5], [[15, 16], [5, 6]])
    self.assertAllEqual(sv[6::-6], [[13, 14], [1, 2]])
    self.assertAllEqual(sv[6:5:-6], [[13, 14]])
    self.assertAllEqual(sv[6::-7], [[13, 14]])
    self.assertAllEqual(sv[7:1:-1],
                        [[15, 16], [13, 14], [11, 12], [9, 10], [7, 8], [5, 6]])
    self.assertAllEqual(sv[7:1:-2], [[15, 16], [11, 12], [7, 8]])
    self.assertAllEqual(sv[7:1:-4], [[15, 16], [7, 8]])

    # Test cases: empty slice
    self.assertAllEqual(sv[0:0], empty)
    self.assertAllEqual(sv[5:3], empty)
    self.assertAllEqual(sv[3:5:-1], empty)
    self.assertAllEqual(sv[-1:0], empty)
    self.assertAllEqual(sv[2:-1:-1], empty)

    # Test cases: slicing other dimensions
    self.assertAllEqual(sv[:, 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[:, 0:1], [[1], [3], [5], [7], [9], [11], [13], [15]])

    # Test cases: normal indexing
    self.assertAllEqual(sv[2], [5, 6])
    self.assertAllEqual(sv[6], [13, 14])
    self.assertAllEqual(sv[2, 1], 6)
    self.assertAllEqual(sv[-2], [13, 14])
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[100]
    with self.assertRaisesRegex(IndexError, 'out of bounds'):
      _ = sv[-100]

    # Test cases: Ellipsis
    self.assertAllEqual(sv[...], array_ops.concat(v, axis=0))
    self.assertAllEqual(sv[..., 0], [1, 3, 5, 7, 9, 11, 13, 15])
    self.assertAllEqual(sv[0:1, ...], [[1, 2]])

    # Test cases: newaxis
    self.assertAllEqual(
        sv[array_ops.newaxis, ...],
        array_ops.expand_dims_v2(array_ops.concat(v, axis=0), axis=0))

    # Test cases: boolean masks
    self.assertAllEqual(sv[ops.convert_to_tensor(sv) > 10],
                        [11, 12, 13, 14, 15, 16])

    # Test cases: tensor input
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)::]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[:constant_op.constant(1):]
    with self.assertRaisesRegex(TypeError, 'not allowed'):
      _ = sv[constant_op.constant(1)]

    # Test cases: inside tf.function
    @def_function.function
    def func():
      a = sv[:, 0]
      return a

    self.assertAllEqual(func(), [1, 3, 5, 7, 9, 11, 13, 15])

  def test_operator_overload(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)

    v2 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv2 = sharded_variable.ShardedVariable(v2)
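    # Overloaded operators apply elementwise to the full (concatenated) value
    # of each ShardedVariable.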

    equal = sv1 == sv2
    self.assertAllEqual(equal, [True, True])
    self.assertAllEqual(sv1 + sv2, [2.0, 4.0])

  def test_shards_have_container_set(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)
    for v in sv1.variables:
      self.assertTrue(hasattr(v, '_sharded_container'))
      self.assertIs(v._sharded_container(), sv1)

  def test_numpy(self):
    v1 = [
        variables_lib.Variable([1.]),
        variables_lib.Variable([2.]),
    ]
    sv1 = sharded_variable.ShardedVariable(v1)
    sv1_np = sv1.numpy()
    self.assertIsInstance(sv1_np, np.ndarray)
    self.assertAllEqual(sv1_np, np.array([1., 2.]))


class ShardedVariableSaveLoadTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    cluster_def = get_cluster_def(test_cluster_params, num_workers=2, num_ps=3)
    self.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))

  def tearDown(self):
    super().tearDown()
    # Reset context to disconnect from the cluster.
    context._reset_context()

  def _create_strategy(self, num_shards):
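    # With a FixedShardsPartitioner, variables created under the strategy's
    # scope become ShardedVariables with `num_shards` shards; with a single
    # shard, the default strategy is used and variables stay unsharded.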
    if num_shards > 1:
      strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
          self.cluster_resolver,
          variable_partitioner=sharded_variable.FixedShardsPartitioner(
              num_shards))
    else:
      strategy = ds_context._get_default_strategy()
    return strategy

  @combinations.generate(
      combinations.combine(
          shard_config=[[2, 2], [2, 3], [3, 2], [2, 1], [1, 1]],
      ))
  def testSaveAndLoadSingleVariable(self, shard_config):
    """Test saving and loading ShardedVariable with different numbers of shards.

    Note that loading a tf.Variable into multiple shards is not yet supported.

    Args:
      shard_config: The number of shards to use before and after loading. For
        example, [2, 1] means to create and save the variable with 2 shards and
        load it into 1 shard (i.e., a regular tf.Variable).
    """
    strategy = self._create_strategy(shard_config[0])

    with strategy.scope():
      var = variables_lib.Variable([1., 2., 3., 4., 5., 6.])

    # Save variable
    model_dir = self.get_temp_dir()
    save.save(var, model_dir)

    strategy2 = self._create_strategy(shard_config[1])
    with strategy2.scope():
      # Load variable
      loaded = load.load(model_dir)

    # Assert that all values were loaded and that they match the originals.
    if shard_config[1] > 1:
      loaded = array_ops.concat(loaded.variables, axis=0)
    self.assertLen(loaded.numpy(), 6)

    if shard_config[0] > 1:
      var = array_ops.concat(var.variables, axis=0)
    self.assertAllClose(var.numpy(), loaded.numpy())

  def testSaveAndLoadModuleUnderStrategy(self):

    class Dense(module.Module):

      def __init__(self):
        self.kernel = variables_lib.Variable(
            random_ops.random_uniform((6, 6)), name='kernel')
        self.bias = variables_lib.Variable(
            random_ops.random_uniform((6,)), name='bias')

      @def_function.function
      def __call__(self, x):
        out = math_ops.matmul(self.kernel, x)
        out = out + self.bias
        return out

    x = constant_op.constant(
        math_ops.range(6, dtype=dtypes.float32), shape=[6, 1])

    strategy = self._create_strategy(2)
    with strategy.scope():
      layer = Dense()
      expect = layer(x)

    model_dir = self.get_temp_dir()
    save.save(layer, model_dir)

    strategy2 = self._create_strategy(3)
    with strategy2.scope():
      loaded_layer = load.load(model_dir)
      # Should fail with an informative error.
      with self.assertRaisesRegex(ValueError, 'run a loaded non-Keras'):
        got = loaded_layer(x)

    # Loading without a strategy should work, because the tf.function is
    # traced with a single variable as input.
    loaded_layer = load.load(model_dir)
    got = loaded_layer(x)
    self.assertAllClose(got, expect)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()