xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_sharding_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
"""Tests for tpu_sharding helpers."""
17
18
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.platform import test
21from tensorflow.python.tpu import tpu_sharding
22
23
class ShardingTest(test.TestCase):
  """Unit tests for `tpu_sharding.ShardingPolicy`."""

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    # An unconfigured policy picks up the module-level defaults on freeze().
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
    # Explicitly configured values are kept as-is by freeze().
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    # Every setter must refuse to mutate a frozen policy.
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    # The policy still reads as unset until both fields are configured.
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    # Merging a fully configured policy copies both fields.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)
    # Merging a partially configured policy overrides only the field it sets.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # Merging into a frozen policy succeeds when the values agree...
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # ...but raises when the number of shards conflicts...
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)
    # ...or when the shard dimension conflicts.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    # Dimension 1 (size 9) is split evenly across the 3 shards.
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    # Out-of-range shard indices are rejected.
    with self.assertRaises(ValueError):
      p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      p.get_sharded_shape([4, 9], shard_index=-1)
    # Inputs that are not shapes, or shapes of unknown rank, are rejected.
    with self.assertRaises(TypeError):
      p.get_sharded_shape("not_a_shape")
    with self.assertRaises(ValueError):
      p.get_sharded_shape(tensor_shape.TensorShape(None))
    # The shard dimension (size 10) is not divisible by 3 shards. No
    # shard_index is passed, so this exercises the divisibility check rather
    # than repeating the out-of-range shard_index check above.
    with self.assertRaises(ValueError):
      p.get_sharded_shape([4, 10])

  def testGetUnpartitionedShape(self):
    """Tests getting an unpartitioned shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    p.set_number_of_partitions(4)
    # The shard dimension is scaled by the partition count (5 * 4 = 20);
    # the shard count plays no part in the result.
    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
    p.freeze()
    # An unknown size in the shard dimension cannot be unpartitioned.
    with self.assertRaises(ValueError):
      p.get_unpartitioned_shape([3, None])

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    # Two [4, 3] shards concatenate along dimension 1 into [4, 6].
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    # The number of per-shard shapes must equal number_of_shards.
    with self.assertRaises(ValueError):
      p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    # All per-shard shapes must match.
    with self.assertRaises(ValueError):
      p.get_unsharded_shape([[4, 3], [4, 2]])
    # Every element must be a valid, known shape of the same rank.
    with self.assertRaises(TypeError):
      p.get_unsharded_shape([[4, 3], "not_a_shape"])
    with self.assertRaises(ValueError):
      p.get_unsharded_shape([None, [4, 3]])
    with self.assertRaises(ValueError):
      p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    # A frozen default policy leaves scalar shapes unchanged both ways.
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])
142
143
if __name__ == "__main__":
  # When run as a script, hand off to tensorflow.python.platform.test.main().
  test.main()
146