xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_infeed_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
15
16"""Tests for TPU InfeedQueue methods."""
17
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.platform import test
22from tensorflow.python.tpu import tpu_feed
23
24
class InfeedTest(test.TestCase):
  """Unit tests for configuring a `tpu_feed.InfeedQueue`.

  The tests only construct and reconfigure queue objects and inspect their
  properties; none of them calls an enqueue or dequeue method.
  """

  def testConstructor(self):
    """Tests that the constructor can be called with different arguments."""
    # Only the arity is given: everything else stays unset.
    queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    self.assertEqual(queue.number_of_tuple_elements, 2)
    self.assertIsNone(queue.tuple_types)
    self.assertIsNone(queue.tuple_shapes)
    self.assertIsNone(queue.number_of_shards)

    # Only the types are given: the arity follows from the list length.
    queue = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertEqual(queue.number_of_tuple_elements, 3)
    self.assertEqual(queue.tuple_types,
                     [dtypes.float32, dtypes.int32, dtypes.int32])
    self.assertIsNone(queue.tuple_shapes)
    self.assertIsNone(queue.number_of_shards)

    # Only the shapes are given.
    queue = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]])
    self.assertEqual(queue.number_of_tuple_elements, 2)
    self.assertIsNone(queue.tuple_types)
    self.assertEqual(queue.tuple_shapes, [[1], [2, 3]])
    self.assertIsNone(queue.number_of_shards)

    # Only the shard dimensions are given; they are read back through the
    # per-element sharding policies.
    queue = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7])
    self.assertEqual(queue.number_of_tuple_elements, 3)
    self.assertIsNone(queue.tuple_types)
    self.assertIsNone(queue.tuple_shapes)
    self.assertEqual([policy.shard_dimension
                      for policy in queue.sharding_policies], [1, 0, 7])

    # No arguments at all, or arguments whose lengths disagree with each
    # other, must be rejected. The constructor is expected to raise, so its
    # result is deliberately not bound to a name.
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue()
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(
          number_of_tuple_elements=2, tuple_types=[dtypes.float32])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1])
    with self.assertRaises(ValueError):
      tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1])

  def testModification(self):
    """Tests modification of the queue post-construction."""
    queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)

    # Types can be overwritten repeatedly, but a list of the wrong length is
    # rejected.
    queue.set_tuple_types([dtypes.float32, dtypes.int32])
    self.assertEqual(queue.tuple_types, [dtypes.float32, dtypes.int32])
    queue.set_tuple_types([dtypes.float32, dtypes.float32])
    self.assertEqual(queue.tuple_types, [dtypes.float32, dtypes.float32])
    with self.assertRaises(ValueError):
      queue.set_tuple_types([dtypes.float32])

    # Same for shapes.
    queue.set_tuple_shapes([[1], [2, 3]])
    self.assertEqual(queue.tuple_shapes, [[1], [2, 3]])
    queue.set_tuple_shapes([[1, 2], [3, 4]])
    self.assertEqual(queue.tuple_shapes, [[1, 2], [3, 4]])
    with self.assertRaises(ValueError):
      queue.set_tuple_shapes([[1, 2]])

    # The shard count can be overwritten as well.
    queue.set_number_of_shards(2)
    self.assertEqual(queue.number_of_shards, 2)
    queue.set_number_of_shards(3)
    self.assertEqual(queue.number_of_shards, 3)

    # Shapes and types are taken from the supplied tensors.
    t1 = constant_op.constant(1, dtypes.int32, shape=[6])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18])
    queue.set_configuration_from_input_tensors([t1, t2])
    self.assertEqual(queue.tuple_shapes, [[6], [3, 18]])
    self.assertEqual(queue.tuple_types, [dtypes.int32, dtypes.float32])

    # Two shards of (t2, t1): the shard count becomes 2 and the expected
    # shapes have dimension 0 doubled relative to a single shard
    # ([3, 18] -> [6, 18] and [6] -> [12]).
    queue.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(queue.number_of_shards, 2)
    self.assertEqual(queue.tuple_shapes, [[6, 18], [12]])
    self.assertEqual(queue.tuple_types, [dtypes.float32, dtypes.int32])

    # With shard dimensions [1, 0] the sharded extents are 18 and 12: three
    # shards are accepted while four are rejected.
    # NOTE(review): presumably because 18 is not divisible by 4 -- confirm
    # against the InfeedQueue implementation.
    queue.set_shard_dimensions([1, 0])
    queue.set_number_of_shards(3)
    with self.assertRaises(ValueError):
      queue.set_number_of_shards(4)

  def testFreezing(self):
    """Tests freezing the queue."""
    queue = tpu_feed.InfeedQueue(number_of_tuple_elements=2)
    t1 = constant_op.constant(1, dtypes.int32, shape=[2])
    t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4])
    queue.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]])
    self.assertEqual(queue.number_of_shards, 2)
    self.assertEqual(queue.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(queue.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(queue.shard_dimensions, [0, 0])

    queue.freeze()

    # Re-applying the values the queue already holds is still accepted after
    # freezing.
    queue.set_number_of_shards(2)
    queue.set_tuple_shapes([[4, 4], [4]])
    queue.set_tuple_types([dtypes.float32, dtypes.int32])
    queue.set_shard_dimensions([0, 0])

    # Any value that differs from the frozen configuration is rejected.
    with self.assertRaises(ValueError):
      queue.set_number_of_shards(1)
    with self.assertRaises(ValueError):
      queue.set_tuple_shapes([[8, 8], [8]])
    with self.assertRaises(ValueError):
      queue.set_tuple_types([dtypes.int32, dtypes.float32])
    with self.assertRaises(ValueError):
      queue.set_shard_dimensions([1, 0])

    # The rejected calls must have left the configuration untouched.
    self.assertEqual(queue.number_of_shards, 2)
    self.assertEqual(queue.tuple_shapes, [[4, 4], [4]])
    self.assertEqual(queue.tuple_types, [dtypes.float32, dtypes.int32])
    self.assertEqual(queue.shard_dimensions, [0, 0])
124
# Run the test cases defined above only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
  test.main()
127