# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.ragged.cross and tf.ragged.cross_hashed."""

from absl.testing import parameterized

import numpy as np

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_ragged_array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest

# Short aliases used to build the parameterized test inputs below.
ragged_const = ragged_factory_ops.constant_value
dense_const = np.array


def sparse_const(matrix):
  """Builds a `SparseTensorValue` from a nested list of rows.

  Args:
    matrix: A list of rows, where each row is a list of values.  Rows may have
      different lengths.

  Returns:
    A `SparseTensorValue` holding one entry per value in `matrix`, indexed by
    `[row, column]`, whose dense shape is `[len(matrix), longest row length]`
    (or `[0, 0]` when `matrix` is empty).
  """
  indices = [[i, j] for i, row in enumerate(matrix) for j in range(len(row))]
  values = [val for row in matrix for val in row]
  shape = [len(matrix), max(len(row) for row in matrix)] if matrix else [0, 0]
  if not values:
    # No entries at all: use explicitly shaped and typed empty arrays so that
    # `indices` is [0, 2] int64 and `values` is [0] int64 instead of bare
    # empty lists.
    indices = np.zeros([0, 2], dtype=np.int64)
    values = np.zeros([0], dtype=np.int64)
  return sparse_tensor.SparseTensorValue(indices, values, shape)


@test_util.run_all_in_graph_and_eager_modes
class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `ragged_array_ops.cross` and `ragged_array_ops.cross_hashed`."""

  @parameterized.named_parameters([
      dict(
          testcase_name='NoInputs',
          inputs=[],
          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
      dict(
          testcase_name='OneInput_RaggedStr',
          inputs=[ragged_const([['a', 'b'], [], ['c']])],
          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
      dict(
          testcase_name='OneInput_RaggedInt',
          inputs=[ragged_const([[1, 2, 3], [4, 5]])],
          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
      dict(
          testcase_name='OneInput_DenseInt',
          inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
      dict(
          testcase_name='OneInput_SparseStr',
          inputs=[sparse_const([['a', 'b'], [], ['c']])],
          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
      dict(
          testcase_name='TwoInputs_RaggedStr_RaggedStr',
          inputs=[
              ragged_const([['a', 'b'], [], ['c']]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
                                 [b'c_X_g']])),
      dict(
          testcase_name='TwoInputs_RaggedInt_RaggedInt',
          inputs=[
              ragged_const([[1, 2], [], [3]]),
              ragged_const([[4, 5, 6], [], [7]])
          ],
          expected=ragged_const(
              [[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
               [b'3_X_7']])),
      dict(
          # The second input holds ints (not digit strings) so that this case
          # really covers the str x int combination its name describes.
          testcase_name='TwoInputs_RaggedStr_RaggedInt',
          inputs=[
              ragged_const([['a', 'b'], [], ['c']]),
              ragged_const([[1, 2], [3], [4]])
          ],
          expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
                                 [b'c_X_4']])),
      dict(
          testcase_name='TwoInputs_SparseStr_SparseStr',
          inputs=[
              sparse_const([['a', 'b'], [], ['c']]),
              sparse_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
                                 [b'c_X_g']])),
      dict(
          testcase_name='TwoInputs_DenseInt_DenseInt',
          inputs=[dense_const([[1, 2], [3, 4]]),
                  dense_const([[5, 6], [7, 8]])],
          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
      dict(
          testcase_name='TwoInputs_DenseInt_DenseStr',
          inputs=[
              dense_const([[1, 2], [3, 4]]),
              dense_const([[b'5', b'6'], [b'7', b'8']])
          ],
          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
      dict(
          testcase_name='TwoInputs_RaggedInt_DenseInt',
          inputs=[
              ragged_const([[], [], [1, 2], [3]]),
              dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
          ],
          expected=ragged_const([[], [],
                                 [b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
                                 [b'3_X_7', b'3_X_8']])),
      dict(
          # This test exercises `input_order`.
          testcase_name='TwoInputs_DenseInt_RaggedStr',
          inputs=[
              dense_const([[1, 2], [3, 4], [5, 6]]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          # This test exercises `input_order`.
          testcase_name='TwoInputs_SparseInt_RaggedStr',
          inputs=[
              sparse_const([[1, 2], [3, 4], [5, 6]]),
              ragged_const([['d', 'e'], ['f'], ['g']])
          ],
          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
          inputs=[
              ragged_const([[11], [12, 13], [], [14, 15]]),
              ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
              ragged_const([[31], [32, 33], [34, 35], [36, 37]])
          ],
          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
                                 [
                                     b'12_X_23_X_32', b'12_X_23_X_33',
                                     b'13_X_23_X_32', b'13_X_23_X_33'
                                 ], [],
                                 [
                                     b'14_X_26_X_36', b'14_X_26_X_37',
                                     b'14_X_27_X_36', b'14_X_27_X_37',
                                     b'15_X_26_X_36', b'15_X_26_X_37',
                                     b'15_X_27_X_36', b'15_X_27_X_37'
                                 ]])),
      dict(
          testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
          inputs=[
              ragged_const([[11], [12, 13], [], [14, 15]]),
              sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
              dense_const([[31], [32], [33], [34]])
          ],
          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
                                 [
                                     b'12_X_23_X_32',
                                     b'13_X_23_X_32',
                                 ], [],
                                 [
                                     b'14_X_26_X_34',
                                     b'14_X_27_X_34',
                                     b'15_X_26_X_34',
                                     b'15_X_27_X_34',
                                 ]])),
      dict(
          testcase_name='FiveInputs',
          inputs=[
              ragged_const([[1]]),
              dense_const([[2]]),
              ragged_const([[3]]),
              sparse_const([[4]]),
              ragged_const([[5]])
          ],
          expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
          matches_sparse_cross=False  # sparse doesn't preserve input order.
      ),
      dict(
          testcase_name='Permutation_3x3x3',
          inputs=[[['11', '12', '13']], [['21', '22', '23']],
                  [['31', '32', '33']]],
          expected=[[
              b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
              b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
              b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
              b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
              b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
              b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
              b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
              b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
              b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
          ]]),
      dict(
          testcase_name='BatchSizeZero',
          inputs=[
              ragged_const([], ragged_rank=1, dtype=dtypes.int32),
              sparse_const([]),
              np.zeros([0, 3], dtype=np.int32),
          ],
          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
      dict(
          testcase_name='ThreeInputs_OneEmpty',
          inputs=[
              ragged_const([[1, 2]]),
              ragged_const([[]], dtype=dtypes.int32),
              ragged_const([[3, 4]])
          ],
          expected=ragged_const([[]], dtype=dtypes.string)),
      dict(
          testcase_name='ThreeInputs_AllEmpty',
          inputs=[
              ragged_const([[]], dtype=dtypes.int64),
              ragged_const([[]], dtype=dtypes.string),
              ragged_const([[]], dtype=dtypes.int32)
          ],
          expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
      dict(
          testcase_name='HashedZeroBucketsDefaultKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          expected_hashed=ragged_const([[1971693436396284976]])),
      dict(
          testcase_name='Hashed100BucketsDefaultKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          num_buckets=100,
          expected_hashed=ragged_const([[83]])),
      dict(
          testcase_name='HashedZeroBucketsCustomKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
          expected_hashed=ragged_const([[4847552627144134031]])),
      dict(
          testcase_name='Hashed100BucketsCustomKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          num_buckets=100,
          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
          expected_hashed=ragged_const([[31]])),
      dict(
          testcase_name='HashedZeroKey',
          inputs=[
              ragged_const([['batch1-FC1-F1']]),
              ragged_const([['batch1-FC2-F1']]),
              ragged_const([['batch1-FC3-F1']])
          ],
          hash_key=0,
          expected_hashed=ragged_const([[9077905385164735582]]),
          matches_sparse_cross=False  # sparse treats hash_key=0 as None.
      ),
      dict(
          testcase_name='UInt64',
          inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
          expected=ragged_const([[b'-1']])),
  ])
  def testRaggedCross(self,
                      inputs,
                      num_buckets=0,
                      hash_key=None,
                      expected=None,
                      expected_hashed=None,
                      matches_sparse_cross=True):
    """Checks `cross` / `cross_hashed` outputs for one set of inputs.

    Args:
      inputs: List of inputs passed to both `ragged_array_ops.cross` and
        `ragged_array_ops.cross_hashed`.
      num_buckets: `num_buckets` argument forwarded to the hashed ops.
      hash_key: `hash_key` argument forwarded to the hashed ops.
      expected: If not None, the expected result of `cross`.
      expected_hashed: If not None, the expected result of `cross_hashed`.
      matches_sparse_cross: If True, additionally checks that the ragged
        results equal those of `sparse_ops.sparse_cross` and
        `sparse_ops.sparse_cross_hashed` on the same inputs.
    """
    ragged_cross = ragged_array_ops.cross(inputs)
    ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
                                                        hash_key)

    if expected is not None:
      self.assertAllEqual(ragged_cross, expected)
    if expected_hashed is not None:
      self.assertAllEqual(ragged_cross_hashed, expected_hashed)

    if matches_sparse_cross:
      # Both sparse ops consume the same converted inputs, so convert once.
      sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]

      # Check that ragged.cross & sparse.cross match.
      sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
      self.assertAllEqual(ragged_cross,
                          ragged_tensor.RaggedTensor.from_sparse(sparse_cross))

      # Check that ragged.cross_hashed & sparse.cross_hashed match.
      sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
          sparse_inputs, num_buckets, hash_key)
      self.assertAllEqual(
          ragged_cross_hashed,
          ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))

  def testRaggedCrossLargeBatch(self):
    """Crosses ragged, dense and sparse inputs with a batch size of 5000."""
    batch_size = 5000
    inputs = [
        ragged_const([[1, 2, 3]] * batch_size),
        ragged_const([[b'4']] * batch_size),
        dense_const([[5]] * batch_size),
        sparse_const([[6, 7]] * batch_size)
    ]

    expected = [[
        b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
        b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
    ]] * batch_size

    ragged_cross = ragged_array_ops.cross(inputs)

    # Note: we don't use assertAllEqual here because if they don't match,
    # then the code in assertAllEqual that tries to build the error message
    # is very slow, causing the test to timeout.
    # pylint: disable=g-generic-assert
    self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)

  @parameterized.named_parameters([
      dict(
          testcase_name='BadDType',
          inputs=[ragged_const([[1.1], [2.2, 3.3]])],
          message=r'Unexpected dtype for inputs\[0\]'),
      dict(
          testcase_name='StaticBatchSizeMismatch1',
          inputs=[ragged_const([[1]]),
                  ragged_const([[2], [3]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='inputs must all have the same batch dimension size'),
      dict(
          testcase_name='StaticBatchSizeMismatch2',
          inputs=[ragged_const([[1]]),
                  dense_const([[2], [3]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='inputs must all have the same batch dimension size'),
      dict(
          testcase_name='3DDenseTensor',
          inputs=[dense_const([[[1]]])],
          exception=(ValueError, errors.InvalidArgumentError),
          message='tf.ragged.cross only supports inputs with rank=2'),
      dict(
          testcase_name='0DDenseTensor',
          inputs=[dense_const(1)],
          exception=(ValueError, errors.InvalidArgumentError),
          message='tf.ragged.cross only supports inputs with rank=2'),
  ])
  def testStaticError(self, inputs, exception=ValueError, message=None):
    """Checks that calling `cross(inputs)` directly raises `exception`.

    Args:
      inputs: List of inputs passed to `ragged_array_ops.cross`.
      exception: Exception type (or tuple of types) expected to be raised.
      message: Regular expression the exception message must match.
    """
    with self.assertRaisesRegex(exception, message):
      ragged_array_ops.cross(inputs)

  @parameterized.named_parameters([
      dict(
          testcase_name='3DRaggedTensor',
          inputs=[ragged_const([[[1]]], ragged_rank=1)],
          message='tf.ragged.cross only supports inputs with rank=2'),
      dict(
          testcase_name='0DDenseTensor',
          inputs=[dense_const(1)],
          signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
          exception=(ValueError, errors.InvalidArgumentError),
          message='tf.ragged.cross only supports inputs with rank=2'),
      dict(
          testcase_name='1DDenseTensor',
          inputs=[dense_const([1])],
          signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
          exception=(ValueError, errors.InvalidArgumentError),
          message='tf.ragged.cross only supports inputs with rank=2'),
      dict(
          testcase_name='3DDenseTensor',
          inputs=[dense_const([[[1]]])],
          signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
          exception=(ValueError, errors.InvalidArgumentError),
          message='tf.ragged.cross only supports inputs with rank=2'),
  ])
  def testRuntimeError(self,
                       inputs,
                       exception=errors.InvalidArgumentError,
                       message=None,
                       signature=None):
    """Checks that `cross` raises when called inside a `def_function.function`.

    Args:
      inputs: List of inputs passed to the wrapped function.
      exception: Exception type (or tuple of types) expected to be raised.
      message: Regular expression the exception message must match.
      signature: Optional `input_signature` for the wrapped function.
    """
    @def_function.function(input_signature=signature)
    def fn(x):
      return ragged_array_ops.cross(x)

    with self.assertRaisesRegex(exception, message):
      self.evaluate(fn(inputs))

  def _ragged_to_sparse(self, t):
    """Converts `t` into an input for the sparse cross ops.

    Ragged values are converted with `to_sparse()`, sparse values are wrapped
    with `SparseTensor.from_value`, and anything else is passed through
    `ops.convert_to_tensor`.

    Args:
      t: A ragged, sparse, or other tensor-convertible value.

    Returns:
      The converted value.
    """
    if ragged_tensor.is_ragged(t):
      return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
    elif sparse_tensor.is_sparse(t):
      return sparse_tensor.SparseTensor.from_value(t)
    else:
      return ops.convert_to_tensor(t)

  def testSparseValuesAndIndicesMustMatch(self):
    """The raw op rejects sparse indices/values lists of different lengths."""
    with self.assertRaisesRegex(
        (ValueError, errors.InvalidArgumentError),
        'sparse indices and values must have the same length'):
      self.evaluate(gen_ragged_array_ops.RaggedCross(
          ragged_values=[],
          ragged_row_splits=[],
          sparse_indices=[[5]],
          sparse_values=[],
          sparse_shape=[5],
          dense_inputs=[['a']],
          input_order='RD',
          hashed_output=False,
          num_buckets=5,
          hash_key=2,
          out_values_type=dtypes.string,
          out_row_splits_type=dtypes.int64))

  def testRaggedValuesAndSplitsMustMatch(self):
    """The raw op rejects ragged values/splits lists of different lengths."""
    with self.assertRaisesRegex(
        (ValueError, errors.InvalidArgumentError),
        'ragged values and splits must have the same length'):
      self.evaluate(gen_ragged_array_ops.RaggedCross(
          ragged_values=[['a']],
          ragged_row_splits=[],
          sparse_indices=[],
          sparse_values=[],
          sparse_shape=[],
          dense_inputs=[['a']],
          input_order='RD',
          hashed_output=False,
          num_buckets=5,
          hash_key=2,
          out_values_type=dtypes.string,
          out_row_splits_type=dtypes.int64))


if __name__ == '__main__':
  googletest.main()