xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/factory_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests that sparse tensors work with GPU, such as placement of int and string.
16
Tests sparse tensors used with a distributed dataset. Since GPUs do not
support strings, sparse tensors containing strings should always be placed
on the CPU.
20"""
21
22from absl.testing import parameterized
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.distribute import mirrored_strategy
25from tensorflow.python.eager import def_function
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import sparse_ops
31from tensorflow.python.platform import test
32
33
def sparse_int64():
  """Builds an 8x4 int64 `SparseTensor` with exactly one entry per row.

  Row `i` stores the value `i + 1` in column `i % 4`, i.e. the diagonal
  pattern repeats every four rows.

  Returns:
    A `SparseTensor` with dense shape `[8, 4]` and int64 values 1..8.
  """
  positions = [[row, row % 4] for row in range(8)]
  entries = constant_op.constant(list(range(1, 9)), dtype=dtypes.int64)
  return sparse_tensor.SparseTensor(
      indices=positions, values=entries, dense_shape=[8, 4])
39
40
def sparse_str():
  """Builds an 8x4 string `SparseTensor` with exactly one entry per row.

  Row `i` stores the string `str(i + 1)` in column `i % 4`. The string dtype
  is what the module docstring says must stay on the CPU.

  Returns:
    A `SparseTensor` with dense shape `[8, 4]` and values '1'..'8'.
  """
  positions = [[row, row % 4] for row in range(8)]
  entries = constant_op.constant([str(v) for v in range(1, 9)])
  return sparse_tensor.SparseTensor(
      indices=positions, values=entries, dense_shape=[8, 4])
46
47
class FactoryOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests sparse tensors fed through a `MirroredStrategy` dataset."""

  @parameterized.parameters(
      (sparse_int64,),
      (sparse_str,),
  )
  @test_util.run_gpu_only
  def testSparseWithDistributedDataset(self, sparse_factory):
    """Distributes a sparse dataset over two GPUs and compares the first row.

    Args:
      sparse_factory: zero-argument callable returning the `SparseTensor`
        under test.
    """

    @def_function.function
    def distributed_dataset_producer(sparse_input):
      # NOTE(review): two GPU devices are hard-coded here; confirm that
      # `run_gpu_only` / the test target guarantees both exist.
      mirrored = mirrored_strategy.MirroredStrategy(['GPU:0', 'GPU:1'])
      batched = dataset_ops.Dataset.from_tensor_slices(sparse_input).batch(2)
      iterator = iter(mirrored.experimental_distribute_dataset(batched))
      # Keep only the first local replica's share of the first batch.
      first_local = mirrored.experimental_local_results(next(iterator))[0]
      # Drain the remaining batches so the iterator reaches its end.
      for _ in iterator:
        pass
      return first_local

    source = sparse_factory()
    produced = distributed_dataset_producer(source)

    expected_row = self.evaluate(sparse_ops.sparse_tensor_to_dense(source)[0])
    actual_row = self.evaluate(sparse_ops.sparse_tensor_to_dense(produced)[0])
    self.assertAllEqual(expected_row, actual_row)
75
76
if __name__ == '__main__':
  # Only start the test runner when this file is executed directly,
  # not when it is imported.
  test.main()
79