# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SparseConcat."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test


class SparseConcatTest(test.TestCase):
  """Tests for `sparse_ops.sparse_concat`.

  In the ASCII sketches below, blank cells are implicit (unstored) entries;
  a literal `0` is an explicitly stored zero value.
  """

  def _MakeSparseTensor(self, ind, val, shape, values_dtype=dtypes.float32):
    """Builds a SparseTensor whose components are constant ops.

    Args:
      ind: array-like of indices, converted to an int64 constant.
      val: array-like of values, converted to a `values_dtype` constant.
      shape: array-like dense shape, converted to an int64 constant.
      values_dtype: dtype of the values tensor.

    Returns:
      A `sparse_tensor.SparseTensor`.
    """
    return sparse_tensor.SparseTensor(
        constant_op.constant(ind, dtypes.int64),
        constant_op.constant(val, values_dtype),
        constant_op.constant(shape, dtypes.int64))

  def _SparseTensor_UnknownShape(self,
                                 ind_shape=None,
                                 val_shape=None,
                                 shape_shape=None):
    """Returns a float32 SparseTensor whose components are placeholders.

    Each argument is the (possibly partial or None) static shape given to the
    corresponding placeholder. Placeholders only work in graph mode, so
    callers must run under `run_deprecated_v1`.
    """
    return sparse_tensor.SparseTensor(
        array_ops.placeholder(dtypes.int64, shape=ind_shape),
        array_ops.placeholder(dtypes.float32, shape=val_shape),
        array_ops.placeholder(dtypes.int64, shape=shape_shape))

  def _SparseTensorValue_3x3(self):
    """Returns a 3x3 float32 SparseTensorValue (numpy-backed, not a Tensor)."""
    # [    1]
    # [2    ]
    # [3   4]
    ind = np.array([[0, 2], [1, 0], [2, 0], [2, 2]], np.int64)
    val = np.array([1, 2, 3, 4], np.float32)
    shape = np.array([3, 3], np.int64)
    return sparse_tensor.SparseTensorValue(ind, val, shape)

  def _SparseTensor_3x3(self):
    """Returns the 3x3 fixture as a SparseTensor."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x3())

  def _SparseTensorValue_3x5(self):
    """Returns a 3x5 float32 SparseTensorValue (numpy-backed, not a Tensor)."""
    # [         ]
    # [  1      ]
    # [2     1 0]
    ind = np.array([[1, 1], [2, 0], [2, 3], [2, 4]], np.int64)
    val = np.array([1, 2, 1, 0], np.float32)
    shape = np.array([3, 5], np.int64)
    return sparse_tensor.SparseTensorValue(ind, val, shape)

  def _SparseTensor_3x5(self):
    """Returns the 3x5 fixture as a SparseTensor."""
    return sparse_tensor.SparseTensor.from_value(self._SparseTensorValue_3x5())

  def _SparseTensor_3x2(self):
    """Returns a 3x2 float32 SparseTensor."""
    # [   ]
    # [1  ]
    # [2  ]
    return self._MakeSparseTensor(
        ind=[[1, 0], [2, 0]], val=[1, 2], shape=[3, 2])

  def _SparseTensor_2x3(self):
    """Returns a 2x3 float32 SparseTensor."""
    # [  1  ]
    # [1   2]
    return self._MakeSparseTensor(
        ind=[[0, 1], [1, 0], [1, 2]], val=[1, 1, 2], shape=[2, 3])

  def _SparseTensor_2x3x4(self):
    """Returns a rank-3 (2x3x4) float32 SparseTensor with 7 entries.

    Each value encodes its own index as a decimal number `ijk` (e.g. the
    entry at [1, 1, 3] has value 113).
    """
    ind = [
        [0, 0, 1],
        [0, 1, 0], [0, 1, 2],
        [1, 0, 3],
        [1, 1, 1], [1, 1, 3],
        [1, 2, 2]]
    val = [1, 10, 12, 103, 111, 113, 122]
    return self._MakeSparseTensor(ind=ind, val=val, shape=[2, 3, 4])

  def _SparseTensor_NoNonZeros(self, dense_shape):
    """Returns a float32 SparseTensor of shape `dense_shape` with no entries."""
    # Explicit dtypes so the empty arrays need no implicit float64 -> int64
    # cast when converted to constants.
    ind = np.empty(shape=(0, len(dense_shape)), dtype=np.int64)
    val = np.array([], dtype=np.float32)
    return self._MakeSparseTensor(ind=ind, val=val, shape=dense_shape)

  def _SparseTensor_String3x3(self):
    """Returns a 3x3 string SparseTensor."""
    # [    a]
    # [b    ]
    # [c   d]
    return self._MakeSparseTensor(
        ind=[[0, 2], [1, 0], [2, 0], [2, 2]],
        val=np.array(["a", "b", "c", "d"]),
        shape=[3, 3],
        values_dtype=dtypes.string)

  def _SparseTensor_String3x5(self):
    """Returns a 3x5 string SparseTensor."""
    # [         ]
    # [  e      ]
    # [f     g h]
    return self._MakeSparseTensor(
        ind=[[1, 1], [2, 0], [2, 3], [2, 4]],
        val=np.array(["e", "f", "g", "h"]),
        shape=[3, 5],
        values_dtype=dtypes.string)

  def testConcat1(self):
    with self.session():
      # concat(A):
      # [    1]
      # [2    ]
      # [3   4]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        # Note that we ignore concat_dim in this case since we short-circuit the
        # single-input case in python.
        for concat_dim in (-2000, 1, 2000):
          sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a])

          self.assertEqual(sp_concat.indices.get_shape(), [4, 2])
          self.assertEqual(sp_concat.values.get_shape(), [4])
          self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

          concat_out = self.evaluate(sp_concat)

          self.assertAllEqual(concat_out.indices,
                              [[0, 2], [1, 0], [2, 0], [2, 2]])
          self.assertAllEqual(concat_out.values, [1, 2, 3, 4])
          self.assertAllEqual(concat_out.dense_shape, [3, 3])

  def testConcat2(self):
    with self.session():
      # concat(A, B):
      # [    1          ]
      # [2       1      ]
      # [3   4 2     1 0]
      for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
        for sp_b in (self._SparseTensorValue_3x5(), self._SparseTensor_3x5()):
          for concat_dim in (-1, 1):
            sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

            self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
            self.assertEqual(sp_concat.values.get_shape(), [8])
            self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

            concat_out = self.evaluate(sp_concat)

            self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4],
                                                     [2, 0], [2, 2], [2, 3],
                                                     [2, 6], [2, 7]])
            self.assertAllEqual(concat_out.values, [1, 2, 1, 3, 4, 2, 1, 0])
            self.assertAllEqual(concat_out.dense_shape, [3, 8])

  def testConcatDim0(self):
    with self.session():
      # concat(A, D):
      # [    1]
      # [2    ]
      # [3   4]
      # [  1  ]
      # [1   2]
      sp_a = self._SparseTensor_3x3()
      sp_d = self._SparseTensor_2x3()

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_d])

        self.assertEqual(sp_concat.indices.get_shape(), [7, 2])
        self.assertEqual(sp_concat.values.get_shape(), [7])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [2, 0], [2, 2], [3, 1], [4, 0], [4, 2]])
        self.assertAllEqual(concat_out.values, np.array([1, 2, 3, 4, 1, 1, 2]))
        self.assertAllEqual(concat_out.dense_shape, np.array([5, 3]))

  def testConcat3(self):
    with self.session():
      # concat(A, B, C):
      # [    1              ]
      # [2       1       1  ]
      # [3   4 2     1 0 2  ]
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

        self.assertEqual(sp_concat.indices.get_shape(), [10, 2])
        self.assertEqual(sp_concat.values.get_shape(), [10])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(concat_out.indices, [[0, 2], [1, 0], [1, 4], [1, 8],
                                                 [2, 0], [2, 2], [2, 3], [2, 6],
                                                 [2, 7], [2, 8]])
        self.assertAllEqual(concat_out.values, [1, 2, 1, 1, 3, 4, 2, 1, 0, 2])
        self.assertAllEqual(concat_out.dense_shape, [3, 10])

  def testConcatNoNonZeros(self):
    sp_a = self._SparseTensor_NoNonZeros((2, 3, 4))
    sp_b = self._SparseTensor_NoNonZeros((2, 7, 4))
    sp_c = self._SparseTensor_NoNonZeros((2, 5, 4))

    with self.session():
      concat_dim = 1
      sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [0, 3])
      self.assertEqual(sp_concat.values.get_shape(), [0])
      self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

      concat_out = self.evaluate(sp_concat)

      self.assertEqual(concat_out.indices.shape, (0, 3))
      self.assertEqual(concat_out.values.shape, (0,))
      self.assertAllEqual(concat_out.dense_shape, [2, 15, 4])

  def testConcatSomeNoNonZeros(self):
    sp_a = self._SparseTensor_NoNonZeros((2, 7, 4))
    sp_b = self._SparseTensor_2x3x4()
    sp_c = self._SparseTensor_NoNonZeros((2, 5, 4))
    output_nnz = sp_b.indices.get_shape()[0]

    with self.session():
      concat_dim = 1
      sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b, sp_c])

      self.assertEqual(sp_concat.indices.get_shape(), [output_nnz, 3])
      self.assertEqual(sp_concat.values.get_shape(), [output_nnz])
      self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

      concat_out = self.evaluate(sp_concat)

      # Only sp_b contributes entries; they are shifted along concat_dim by
      # the size of sp_a in that dimension.
      self.assertAllEqual(concat_out.indices,
                          sp_b.indices + [0, sp_a.dense_shape[1], 0])
      self.assertAllEqual(concat_out.values, sp_b.values)
      self.assertAllEqual(concat_out.dense_shape, [2, 15, 4])

  def testConcatNonNumeric(self):
    with self.session(use_gpu=False):
      # concat(A, B):
      # [    a          ]
      # [b       e      ]
      # [c   d f     g h]
      sp_a = self._SparseTensor_String3x3()
      sp_b = self._SparseTensor_String3x5()

      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim, [sp_a, sp_b])

        self.assertEqual(sp_concat.indices.get_shape(), [8, 2])
        self.assertEqual(sp_concat.values.get_shape(), [8])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [2])

        concat_out = self.evaluate(sp_concat)

        self.assertAllEqual(
            concat_out.indices,
            [[0, 2], [1, 0], [1, 4], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7]])
        self.assertAllEqual(concat_out.values,
                            [b"a", b"b", b"e", b"c", b"d", b"f", b"g", b"h"])
        self.assertAllEqual(concat_out.dense_shape, [3, 8])

  @test_util.run_deprecated_v1
  def testMismatchedRank(self):
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches can be caught at shape-inference time
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(concat_dim, [sp_a, sp_e])

  @test_util.run_deprecated_v1
  def testMismatchedRankExpandNonconcatDim(self):
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_e = self._SparseTensor_2x3x4()

      # Rank mismatches should be caught at shape-inference time, even for
      # expand_nonconcat_dim=True.
      for concat_dim in (-1, 1):
        with self.assertRaises(ValueError):
          sparse_ops.sparse_concat(
              concat_dim, [sp_a, sp_e], expand_nonconcat_dim=True)

  @test_util.run_deprecated_v1
  def testMismatchedShapes(self):
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()
      sp_d = self._SparseTensor_2x3()
      for concat_dim in (-1, 1):
        sp_concat = sparse_ops.sparse_concat(concat_dim,
                                             [sp_a, sp_b, sp_c, sp_d])

        # Shape mismatches can only be caught when the op is run
        with self.assertRaisesOpError("Input shapes must match"):
          self.evaluate(sp_concat)

  def testMismatchedShapesExpandNonconcatDim(self):
    with self.session():
      sp_a = self._SparseTensor_3x3()
      sp_b = self._SparseTensor_3x5()
      sp_c = self._SparseTensor_3x2()
      sp_d = self._SparseTensor_2x3()
      for concat_dim0 in (-2, 0):
        for concat_dim1 in (-1, 1):
          sp_concat_dim0 = sparse_ops.sparse_concat(
              concat_dim0, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True)
          sp_concat_dim1 = sparse_ops.sparse_concat(
              concat_dim1, [sp_a, sp_b, sp_c, sp_d], expand_nonconcat_dim=True)

          sp_concat_dim0_out = self.evaluate(sp_concat_dim0)
          sp_concat_dim1_out = self.evaluate(sp_concat_dim1)

          self.assertAllEqual(sp_concat_dim0_out.indices,
                              [[0, 2], [1, 0], [2, 0], [2, 2], [4, 1], [5, 0],
                               [5, 3], [5, 4], [7, 0], [8, 0], [9, 1], [10, 0],
                               [10, 2]])
          self.assertAllEqual(sp_concat_dim0_out.values,
                              [1, 2, 3, 4, 1, 2, 1, 0, 1, 2, 1, 1, 2])
          self.assertAllEqual(sp_concat_dim0_out.dense_shape, [11, 5])

          self.assertAllEqual(sp_concat_dim1_out.indices,
                              [[0, 2], [0, 11], [1, 0], [1, 4], [1, 8], [1, 10],
                               [1, 12], [2, 0], [2, 2], [2, 3], [2, 6], [2, 7],
                               [2, 8]])
          self.assertAllEqual(sp_concat_dim1_out.values,
                              [1, 1, 2, 1, 1, 1, 2, 3, 4, 2, 1, 0, 2])
          self.assertAllEqual(sp_concat_dim1_out.dense_shape, [3, 13])

  @test_util.run_deprecated_v1
  def testShapeInferenceUnknownShapes(self):
    with self.session():
      sp_inputs = [
          self._SparseTensor_UnknownShape(),
          self._SparseTensor_UnknownShape(val_shape=[3]),
          self._SparseTensor_UnknownShape(ind_shape=[1, 3]),
          self._SparseTensor_UnknownShape(shape_shape=[3])
      ]

      for concat_dim in (-2, 0):
        sp_concat = sparse_ops.sparse_concat(concat_dim, sp_inputs)

        self.assertEqual(sp_concat.indices.get_shape().as_list(), [None, 3])
        self.assertEqual(sp_concat.values.get_shape().as_list(), [None])
        self.assertEqual(sp_concat.dense_shape.get_shape(), [3])

  def testConcatShape(self):
    # Test case for GitHub 21964.
    x = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    y = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[2, 2])
    z = sparse_ops.sparse_concat(-1, [x, y])
    self.assertEqual(z.get_shape().as_list(), [2, 4])


if __name__ == "__main__":
  test.main()