# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that sparse tensors work on GPU, e.g. placement of int and string values.

The sparse tensors are fed through a distributed dataset. GPUs do not support
string tensors, so sparse tensors holding string values should always be
placed on the CPU.
"""

from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test


def _sparse_with_values(values):
  """Wraps `values` in an 8x4 `SparseTensor` with one entry per row.

  Row `r` stores its entry in column `r % 4`, i.e. the indices are
  [0, 0], [1, 1], [2, 2], [3, 3], [4, 0], [5, 1], [6, 2], [7, 3].

  Args:
    values: Eight values, one per row, in row order.

  Returns:
    A `SparseTensor` with `dense_shape=[8, 4]`.
  """
  entry_indices = [[row, row % 4] for row in range(8)]
  return sparse_tensor.SparseTensor(
      indices=entry_indices, values=values, dense_shape=[8, 4])


def sparse_int64():
  """Returns the 8x4 test `SparseTensor` holding int64 values 1..8."""
  int_values = constant_op.constant(
      [1, 2, 3, 4, 5, 6, 7, 8], dtype=dtypes.int64)
  return _sparse_with_values(int_values)


def sparse_str():
  """Returns the 8x4 test `SparseTensor` holding string values '1'..'8'."""
  str_values = constant_op.constant(['1', '2', '3', '4', '5', '6', '7', '8'])
  return _sparse_with_values(str_values)


class FactoryOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Feeds int and string sparse tensors through a distributed dataset."""

  @parameterized.parameters(
      (sparse_int64,),
      (sparse_str,),
  )
  @test_util.run_gpu_only
  def testSparseWithDistributedDataset(self, sparse_factory):
    """Checks row 0 of the input against row 0 of replica 0's first batch."""

    @def_function.function
    def first_replica_batch(sp_input):
      # Build the strategy and distributed dataset inside the traced function,
      # keep replica 0's local value from the first step, then drain the
      # iterator to its end before returning.
      strategy = mirrored_strategy.MirroredStrategy(['GPU:0', 'GPU:1'])
      batched = dataset_ops.Dataset.from_tensor_slices(sp_input).batch(2)
      iterator = iter(strategy.experimental_distribute_dataset(batched))
      per_replica = next(iterator)
      first_local = strategy.experimental_local_results(per_replica)[0]
      for _ in iterator:
        pass
      return first_local

    sparse_input = sparse_factory()
    replica_value = first_replica_batch(sparse_input)

    expected_row = self.evaluate(
        sparse_ops.sparse_tensor_to_dense(sparse_input)[0])
    actual_row = self.evaluate(
        sparse_ops.sparse_tensor_to_dense(replica_value)[0])
    self.assertAllEqual(expected_row, actual_row)


if __name__ == '__main__':
  test.main()