# xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/sparse_ops/sparse_concat_op_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for SparseConcat."""
16
17import numpy as np
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import sparse_tensor
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import sparse_ops
25from tensorflow.python.platform import test
26
27
class SparseConcatTest(test.TestCase):
  """Tests for `sparse_ops.sparse_concat`.

  The fixture helpers below build small sparse inputs. Their docstrings and
  comments draw the dense layout of each input, with blanks standing for
  entries that are not stored.
  """

  def _SparseTensor_UnknownShape(self,
                                 ind_shape=None,
                                 val_shape=None,
                                 shape_shape=None):
    """Returns a SparseTensor whose three components are placeholders.

    Used by the graph-mode (`run_deprecated_v1`) shape-inference test to give
    `sparse_concat` inputs whose static shapes are only partially known.

    Args:
      ind_shape: Static shape of the int64 `indices` placeholder; None leaves
        it unspecified.
      val_shape: Static shape of the float32 `values` placeholder.
      shape_shape: Static shape of the int64 `dense_shape` placeholder.
    """
    return sparse_tensor.SparseTensor(
        array_ops.placeholder(dtypes.int64, shape=ind_shape),
        array_ops.placeholder(dtypes.float32, shape=val_shape),
        array_ops.placeholder(dtypes.int64, shape=shape_shape))

  def _MakeSparseTensor(self, ind, val, shape, val_dtype=dtypes.float32):
    """Builds a SparseTensor from constant components.

    Args:
      ind: Array-like of indices; converted to an int64 constant.
      val: Array-like of values; converted to a `val_dtype` constant.
      shape: Array-like dense shape; converted to an int64 constant.
      val_dtype: DType of the values tensor (float32 unless overridden).

    Returns:
      A `sparse_tensor.SparseTensor` backed by constant tensors.
    """
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, val_dtype),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensorValue_3x3(self):
    """Returns a 3x3 float32 SparseTensorValue with four non-zeros.

    Dense layout:
      [    1]
      [2    ]
      [3   4]
    """
    return sparse_tensor.SparseTensorValue(
        np.array([[0, 2], [1, 0], [2, 0], [2, 2]], np.int64),
        np.array([1, 2, 3, 4], np.float32),
        np.array([3, 3], np.int64))

  def _SparseTensor_3x3(self):
    """Returns `_SparseTensorValue_3x3` converted to a SparseTensor."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x3())

  def _SparseTensorValue_3x5(self):
    """Returns a 3x5 float32 SparseTensorValue with four stored entries.

    Dense layout (the 0 at [2, 4] is an explicitly stored entry):
      [         ]
      [  1      ]
      [2     1 0]
    """
    return sparse_tensor.SparseTensorValue(
        np.array([[1, 1], [2, 0], [2, 3], [2, 4]], np.int64),
        np.array([1, 2, 1, 0], np.float32),
        np.array([3, 5], np.int64))

  def _SparseTensor_3x5(self):
    """Returns `_SparseTensorValue_3x5` converted to a SparseTensor."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x5())

  def _SparseTensor_3x2(self):
    """Returns a 3x2 float32 SparseTensor.

    Dense layout:
      [   ]
      [1  ]
      [2  ]
    """
    return self._MakeSparseTensor(
        ind=np.array([[1, 0], [2, 0]]),
        val=np.array([1, 2]),
        shape=np.array([3, 2]))

  def _SparseTensor_2x3(self):
    """Returns a 2x3 float32 SparseTensor.

    Dense layout:
      [  1  ]
      [1   2]
    """
    return self._MakeSparseTensor(
        ind=np.array([[0, 1], [1, 0], [1, 2]]),
        val=np.array([1, 1, 2]),
        shape=np.array([2, 3]))

  def _SparseTensor_2x3x4(self):
    """Returns a 2x3x4 float32 SparseTensor with seven non-zeros.

    Each stored value encodes its own position: the entry at [i, j, k] holds
    100 * i + 10 * j + k.
    """
    return self._MakeSparseTensor(
        ind=np.array([
            [0, 0, 1],
            [0, 1, 0], [0, 1, 2],
            [1, 0, 3],
            [1, 1, 1], [1, 1, 3],
            [1, 2, 2]]),
        val=np.array([1, 10, 12, 103, 111, 113, 122]),
        shape=np.array([2, 3, 4]))

  def _SparseTensor_NoNonZeros(self, dense_shape):
    """Returns a float32 SparseTensor of `dense_shape` with no stored entries.

    The empty index/value arrays are created directly with int64/float32
    dtypes, so no float64 -> int64 coercion is needed for the indices.
    """
    return self._MakeSparseTensor(
        ind=np.empty((0, len(dense_shape)), np.int64),
        val=np.empty((0,), np.float32),
        shape=np.array(dense_shape))

  def _SparseTensor_String3x3(self):
    """Returns a 3x3 string SparseTensor.

    Dense layout:
      [    a]
      [b    ]
      [c   d]
    """
    return self._MakeSparseTensor(
        ind=np.array([[0, 2], [1, 0], [2, 0], [2, 2]]),
        val=np.array(["a", "b", "c", "d"]),
        shape=np.array([3, 3]),
        val_dtype=dtypes.string)

  def _SparseTensor_String3x5(self):
    """Returns a 3x5 string SparseTensor.

    Dense layout:
      [         ]
      [  e      ]
      [f     g h]
    """
    return self._MakeSparseTensor(
        ind=np.array([[1, 1], [2, 0], [2, 3], [2, 4]]),
        val=np.array(["e", "f", "g", "h"]),
        shape=np.array([3, 5]),
        val_dtype=dtypes.string)

  def testConcat1(self):
    """A single input is returned as-is, even for out-of-range concat_dim."""
    with self.session():
      # concat(A):
      # [    1]
      # [2    ]
      # [3   4]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        # Note that we ignore concat_dim in this case since we short-circuit the
        # single-input case in python.
        for concat_dim in (-2000, 1, 2000):
          sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a])

          self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
          self.assertEqual(sp_concat.values.get_shape(), [4])
          self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

          concat_out = self.evaluate(sp_concat)

          self.assertAllEqual(concat_out.indices,
                              [[0, 2], [1, 0], [2, 0], [2, 2]])
          self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
          self.assertAllEqual(concat_out.dense_shape, [3, 3])

  def testConcat2(self):
    """Concatenates 3x3 and 3x5 inputs along the last axis."""
    with self.session():
      # concat(A, B):
      # [    1          ]
      # [2       1      ]
      # [3   4 2     1 0]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()):
          for concat_dim in (-1, 1):
            sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

            self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
            self.assertEqual(sp_concat.values.get_shape(), [8])
            self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

            concat_out = self.evaluate(sp_concat)

            self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4],
                                                     [2, 0], [2, 2], [2, 3],
                                                     [2, 6], [2, 7]])
            self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
            self.assertAllEqual(concat_out.dense_shape, [3, 8])

  def testConcatDim0(self):
    """Concatenates 3x3 and 2x3 inputs along the first axis."""
    with self.session():
      # concat(A, D):
      # [    1]
      # [2    ]
      # [3   4]
      # [  1  ]
      # [1   2]
      sp_a = self._SparseTensor_3x3()
      sp_d = self._SparseTensor_2x3()

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_d])

        self.assertEqual(sp_concat.indices.get_shape(), [7, 2])
        self.assertEqual(sp_concat.values.get_shape(), [7])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
        self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
        self.assertAllEqual(concat_out.dense_shape, np.array([5, 3]))

  def testConcat3(self):
    """Concatenates three inputs (3x3, 3x5, 3x2) along the last axis."""
    with self.session():
      # concat(A, B, C):
      # [    1              ]
      # [2       1       1  ]
      # [3   4 2     1 0 2  ]
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

        self.assertEqual(sp_concat.indices.get_shape(), [10, 2])
        self.assertEqual(sp_concat.values.get_shape(), [10])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], [1, 8],
                                                 [2, 0], [2, 2], [2, 3], [2, 6],
                                                 [2, 7], [2, 8]])
        self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(concat_out.dense_shape, [3, 10])

  def testConcatNoNonZeros(self):
    """Concatenating inputs with no stored entries yields an empty result."""
    sp_a = self._SparseTensor_NoNonZeros((2, 3, 4))
    sp_b = self._SparseTensor_NoNonZeros((2, 7, 4))
    sp_c = self._SparseTensor_NoNonZeros((2, 5, 4))

    with self.session():
      concat_dim = 1
      sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [0, 3])
      self.assertEqual(sp_concat.values.get_shape(), [0])
      self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

      concat_out = self.evaluate(sp_concat)

      self.assertEqual(concat_out.indices.shape, (0, 3))
      self.assertEqual(concat_out.values.shape, (0,))
      self.assertAllEqual(concat_out.dense_shape, [2, 15, 4])

  def testConcatSomeNoNonZeros(self):
    """Only the middle input has entries; they are shifted by A's extent."""
    sp_a = self._SparseTensor_NoNonZeros((2, 7, 4))
    sp_b = self._SparseTensor_2x3x4()
    sp_c = self._SparseTensor_NoNonZeros((2, 5, 4))
    output_nnz = sp_b.indices.get_shape()[0]

    with self.session():
      concat_dim = 1
      sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [output_nnz, 3])
      self.assertEqual(sp_concat.values.get_shape(), [output_nnz])
      self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

      concat_out = self.evaluate(sp_concat)

      # B's entries move along dim 1 by A's size in that dimension.
      self.assertAllEqual(concat_out.indices,
                          sp_b.indices + [0, sp_a.dense_shape[1], 0])
      self.assertAllEqual(concat_out.values, sp_b.values)
      self.assertAllEqual(concat_out.dense_shape, [2, 15, 4])

  def testConcatNonNumeric(self):
    """Concatenates string-valued inputs along the last axis (CPU only)."""
    with self.session(use_gpu=False):
      # concat(A, B):
      # [    a          ]
      # [b       e      ]
      # [c   d f     g h]
      sp_a = self._SparseTensor_String3x3()
      sp_b = self._SparseTensor_String3x5()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

        self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
        self.assertEqual(sp_concat.values.get_shape(), [8])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
        self.assertAllEqual(concat_out.values,
                            [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"])
        self.assertAllEqual(concat_out.dense_shape, [3, 8])

  @test_util.run_deprecated_v1
  def testMismatchedRank(self):
    """Inputs of different rank are rejected while building the graph."""
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches can be caught at shape-inference time
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(concat_dim, [sp_a, sp_e])

  @test_util.run_deprecated_v1
  def testMismatchedRankExpandNonconcatDim(self):
    """Rank mismatches are rejected even with expand_nonconcat_dim=True."""
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches should be caught at shape-inference time, even for
      # expand_nonconcat_dim=True.
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(
              concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True)

  @test_util.run_deprecated_v1
  def testMismatchedShapes(self):
    """Mismatched non-concat dimensions fail when the op runs."""
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()
      sp_d = self._SparseTensor_2x3()
      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim,
                                             [sp_a, sp_b, sp_c, sp_d])

        # Shape mismatches can only be caught when the op is run
        with self.assertRaisesOpError("Input shapes must match"):
          self.evaluate(sp_concat)

  def testMismatchedShapesExpandNonconcatDim(self):
    """expand_nonconcat_dim=True pads every input to the largest extent."""
    with self.session():
      sp_inputs = [
          self._SparseTensor_3x3(),
          self._SparseTensor_3x5(),
          self._SparseTensor_3x2(),
          self._SparseTensor_2x3(),
      ]

      # The two concat axes are independent, so each is checked in its own
      # loop over its positive/negative aliases.
      for concat_dim0 in (-2, 0):
        sp_concat_dim0_out = self.evaluate(
            sparse_ops.sparse_concat(
                concat_dim0, sp_inputs, expand_nonconcat_dim=True))

        # Columns expand to the widest input (5).
        self.assertAllEqual(sp_concat_dim0_out.indices,
                            [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0],
                             [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0],
                             [10, 2]])
        self.assertAllEqual(sp_concat_dim0_out.values,
                            [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2])
        self.assertAllEqual(sp_concat_dim0_out.dense_shape, [11, 5])

      for concat_dim1 in (-1, 1):
        sp_concat_dim1_out = self.evaluate(
            sparse_ops.sparse_concat(
                concat_dim1, sp_inputs, expand_nonconcat_dim=True))

        # Rows expand to the tallest input (3).
        self.assertAllEqual(sp_concat_dim1_out.indices,
                            [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10],
                             [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7],
                             [2, 8]])
        self.assertAllEqual(sp_concat_dim1_out.values,
                            [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(sp_concat_dim1_out.dense_shape, [3, 13])

  @test_util.run_deprecated_v1
  def testShapeInferenceUnknownShapes(self):
    """Static shapes are inferred from whichever input components know them."""
    with self.session():
      sp_inputs = [
          self._SparseTensor_UnknownShape(),
          self._SparseTensor_UnknownShape(val_shape=[3]),
          self._SparseTensor_UnknownShape(ind_shape=[1, 3]),
          self._SparseTensor_UnknownShape(shape_shape=[3])
      ]

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, sp_inputs)

        self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3])
        self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

  def testConcatShape(self):
    """The result's static shape is the concatenated dense shape."""
    # Test case for GitHub 21964.
    x = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    y = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    z = sparse_ops.sparse_concat(-1, [x, y])
    self.assertEqual(z.get_shape().as_list(), [2, 4])
398
399
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
402