# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for tpu_sharding helpers."""


from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_sharding


class ShardingTest(test.TestCase):
  """Unit tests for `tpu_sharding.ShardingPolicy`."""

  def testFreeze(self):
    """Tests that freezing a policy applies default values."""
    # An untouched policy picks up the module defaults when frozen.
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    self.assertEqual(p1.number_of_shards,
                     tpu_sharding._DEFAULT_NUMBER_OF_SHARDS)
    self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION)
    # Explicitly set values survive freezing unchanged.
    p2 = tpu_sharding.ShardingPolicy()
    p2.set_number_of_shards(17)
    p2.set_shard_dimension(23)
    p2.freeze()
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)

  def testFrozen(self):
    """Tests that frozen policies can't be changed."""
    p1 = tpu_sharding.ShardingPolicy()
    p1.freeze()
    with self.assertRaises(ValueError):
      p1.set_number_of_shards(17)
    with self.assertRaises(ValueError):
      p1.set_shard_dimension(22)

  def testStr(self):
    """Tests the string representation."""
    # The policy prints as "unset" until BOTH the shard count and the shard
    # dimension have been set.
    p1 = tpu_sharding.ShardingPolicy()
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_number_of_shards(17)
    self.assertEqual(str(p1), "ShardingPolicy(unset)")
    p1.set_shard_dimension(8)
    self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)")

  def testMerge(self):
    """Tests that merging works."""
    # Merging a fully specified policy into an empty one copies both fields.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p1.set_shard_dimension(23)
    p2 = tpu_sharding.ShardingPolicy()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 23)
    # Merging a partially specified policy only overrides the fields it sets.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_shard_dimension(12)
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # Merging compatible values into a frozen policy leaves it unchanged.
    p2.freeze()
    p2.merge(p1)
    self.assertEqual(p2.number_of_shards, 17)
    self.assertEqual(p2.shard_dimension, 12)
    # A conflicting shard count cannot be merged into a frozen policy.
    p1.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      p2.merge(p1)
    # Neither can a conflicting shard dimension.
    p1 = tpu_sharding.ShardingPolicy()
    p1.set_number_of_shards(17)
    p2.merge(p1)
    p1.set_shard_dimension(2)
    with self.assertRaises(ValueError):
      p2.merge(p1)

  def testGetShardedShape(self):
    """Tests getting a sharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    # Dimension 1 (size 9) is split 3 ways.
    self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3])
    p.freeze()
    with self.assertRaises(ValueError):
      p.set_shard_dimension(0)
    # shard_index must lie in [0, number_of_shards).
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=4)
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 9], shard_index=-1)
    with self.assertRaises(TypeError):
      _ = p.get_sharded_shape("not_a_shape")
    # A shape of unknown rank cannot be sharded.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape(tensor_shape.TensorShape(None))
    # Size 10 along dimension 1 is not divisible into 3 shards. No shard_index
    # is passed, so an out-of-range index (covered above) cannot be what
    # triggers the error here.
    with self.assertRaises(ValueError):
      _ = p.get_sharded_shape([4, 10])

  def testGetUnpartitionedShape(self):
    """Tests getting an unpartitioned shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(3)
    p.set_shard_dimension(1)
    p.set_number_of_partitions(4)
    # Dimension 1 is multiplied by the number of partitions (5 * 4 == 20).
    self.assertEqual(p.get_unpartitioned_shape([3, 5]), [3, 20])
    p.freeze()
    # An unknown size along the shard dimension is rejected.
    with self.assertRaises(ValueError):
      _ = p.get_unpartitioned_shape([3, None])

  def testGetUnshardedShape(self):
    """Tests getting an unsharded shape."""
    p = tpu_sharding.ShardingPolicy()
    p.set_number_of_shards(2)
    p.set_shard_dimension(1)
    # Two [4, 3] shards concatenate along dimension 1 into [4, 6].
    self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6])
    # The number of shard shapes must equal number_of_shards.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3]])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]])
    # All shard shapes must agree.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[4, 3], [4, 2]])
    with self.assertRaises(TypeError):
      _ = p.get_unsharded_shape([[4, 3], "not_a_shape"])
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([None, [4, 3]])
    # Shards of different rank are rejected.
    with self.assertRaises(ValueError):
      _ = p.get_unsharded_shape([[2], [4, 3]])

  def testScalar(self):
    """Tests sharding and unsharding scalars."""
    # With the frozen default policy, a scalar shape round-trips unchanged.
    p = tpu_sharding.ShardingPolicy()
    p.freeze()
    self.assertEqual(p.get_sharded_shape([]), [])
    self.assertEqual(p.get_unsharded_shape([[]]), [])


if __name__ == "__main__":
  test.main()