# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Tests for TPU InfeedQueue methods."""


from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_feed


class InfeedTest(test.TestCase):
  """Tests construction, post-construction mutation and freezing of InfeedQueue."""

  def testConstructor(self):
    """Tests that the constructor can be called with different arguments."""
    # Only the element count is given; types, shapes and shards stay unset.
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    self.assertEqual(i.number_of_tuple_elements, 2)
    self.assertIsNone(i.tuple_types)
    self.assertIsNone(i.tuple_shapes)
    self.assertIsNone(i.number_of_shards)

    # The element count is inferred from the length of tuple_types.
    i = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertEqual(i.number_of_tuple_elements, 3)
    self.assertEqual(i.tuple_types,
                     [dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertIsNone(i.tuple_shapes)
    self.assertIsNone(i.number_of_shards)

    # The element count is inferred from the length of tuple_shapes.
    i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
    self.assertEqual(i.number_of_tuple_elements, 2)
    self.assertIsNone(i.tuple_types)
    self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
    self.assertIsNone(i.number_of_shards)

    # The element count is inferred from the length of shard_dimensions, and
    # each dimension ends up on the corresponding sharding policy.
    i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
    self.assertEqual(i.number_of_tuple_elements, 3)
    self.assertIsNone(i.tuple_types)
    self.assertIsNone(i.tuple_shapes)
    self.assertEqual([p.shard_dimension for p in i.sharding_policies],
                     [1, 0, 7])

    # With no arguments at all, construction is rejected.
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue()
    # Each call below supplies arguments whose lengths disagree about the
    # number of tuple elements, so construction is rejected.
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(
          number_of_tuple_elements=2, tuple_types=[dtypes.float32])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])

  def testModification(self):
    """Tests modification of the queue post-construction."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)

    # Types can be replaced freely while the count matches.
    i.set_tuple_types([dtypes.float32, dtypes.int32])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    i.set_tuple_types([dtypes.float32, dtypes.float32])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32])
    with self.assertRaises(ValueError):
      i.set_tuple_types([dtypes.float32])  # Wrong number of elements.

    # Shapes can be replaced freely while the count matches.
    i.set_tuple_shapes([[1], [2, 3]])
    self.assertEqual(i.tuple_shapes, [[1], [2, 3]])
    i.set_tuple_shapes([[1, 2], [3, 4]])
    self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]])
    with self.assertRaises(ValueError):
      i.set_tuple_shapes([[1, 2]])  # Wrong number of elements.

    i.set_number_of_shards(2)
    self.assertEqual(i.number_of_shards, 2)
    i.set_number_of_shards(3)
    self.assertEqual(i.number_of_shards, 3)

    # Configuring from unsharded tensors copies their shapes and dtypes.
    t1 = constant_op.constant(1, dtypes.int32, shape=[6])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
    i.set_configuration_from_input_tensors([t1, t2])
    self.assertEqual(i.tuple_shapes, [[6], [3, 18]])
    self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32])

    # Two shards of ([3, 18], [6]) combined along dimension 0 yield the
    # unsharded shapes ([6, 18], [12]).
    i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[6, 18], [12]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])

    # After sharding [6, 18] on dim 1 and [12] on dim 0, 3 shards is accepted
    # but 4 is rejected (presumably because 4 does not divide 18).
    i.set_shard_dimensions([1, 0])
    i.set_number_of_shards(3)
    with self.assertRaises(ValueError):
      i.set_number_of_shards(4)

  def testFreezing(self):
    """Tests freezing the queue."""
    i = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    t1 = constant_op.constant(1, dtypes.int32, shape=[2])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
    i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(i.shard_dimensions, [0, 0])

    i.freeze()
    # Re-setting a frozen queue to its current values is allowed...
    i.set_number_of_shards(2)
    i.set_tuple_shapes([[4, 4], [4]])
    i.set_tuple_types([dtypes.float32, dtypes.int32])
    i.set_shard_dimensions([0, 0])
    # ...but any attempt to change a value is rejected.
    with self.assertRaises(ValueError):
      i.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      i.set_tuple_shapes([[8, 8], [8]])
    with self.assertRaises(ValueError):
      i.set_tuple_types([dtypes.int32, dtypes.float32])
    with self.assertRaises(ValueError):
      i.set_shard_dimensions([1, 0])

    # The rejected calls must not have modified the configuration.
    self.assertEqual(i.number_of_shards, 2)
    self.assertEqual(i.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(i.shard_dimensions, [0, 0])


if __name__ == '__main__':
  test.main()