xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/ragged/ragged_cross_op_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.ragged.cross and tf.ragged.cross_hashed."""
16
17from absl.testing import parameterized
18
19import numpy as np
20
21from tensorflow.python.eager import def_function
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import sparse_tensor
26from tensorflow.python.framework import tensor_spec
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import gen_ragged_array_ops
29from tensorflow.python.ops import sparse_ops
30from tensorflow.python.ops.ragged import ragged_array_ops
31from tensorflow.python.ops.ragged import ragged_factory_ops
32from tensorflow.python.ops.ragged import ragged_tensor
33from tensorflow.python.platform import googletest
34
# Shorthand constructors for test inputs: `ragged_const` is an alias for
# `ragged_factory_ops.constant_value` and `dense_const` is an alias for
# `np.array`.
ragged_const = ragged_factory_ops.constant_value
dense_const = np.array
37
38
def sparse_const(matrix):
  """Builds a `SparseTensorValue` from a nested list of rows.

  Every element of every row becomes one sparse entry whose index is its
  `[row, column]` position; entries are emitted in row-major order.

  Args:
    matrix: A list of rows, where each row is a (possibly empty) list of
      values.  Rows may have different lengths.

  Returns:
    A `sparse_tensor.SparseTensorValue` whose dense shape is
    `[len(matrix), length of the longest row]`, or `[0, 0]` when `matrix` is
    empty.  When there are no values at all, the indices and values are empty
    int64 numpy arrays with shapes `[0, 2]` and `[0]` respectively.
  """
  coords = [[r, c] for r, row in enumerate(matrix) for c, _ in enumerate(row)]
  entries = [item for row in matrix for item in row]

  if matrix:
    dense_shape = [len(matrix), max(len(row) for row in matrix)]
  else:
    dense_shape = [0, 0]

  if not entries:
    # Use explicitly-shaped empty arrays when there is nothing to store.
    coords = np.zeros([0, 2], dtype=np.int64)
    entries = np.zeros([0], dtype=np.int64)

  return sparse_tensor.SparseTensorValue(coords, entries, dense_shape)
51
52
53@test_util.run_all_in_graph_and_eager_modes
54class RaggedCrossOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
55
56  @parameterized.named_parameters([
57      dict(
58          testcase_name='NoInputs',
59          inputs=[],
60          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
61      dict(
62          testcase_name='OneInput_RaggedStr',
63          inputs=[ragged_const([['a', 'b'], [], ['c']])],
64          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
65      dict(
66          testcase_name='OneInput_RaggedInt',
67          inputs=[ragged_const([[1, 2, 3], [4, 5]])],
68          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5']])),
69      dict(
70          testcase_name='OneInput_DenseInt',
71          inputs=[dense_const([[1, 2, 3], [4, 5, 6]])],
72          expected=ragged_const([[b'1', b'2', b'3'], [b'4', b'5', b'6']])),
73      dict(
74          testcase_name='OneInput_SparseStr',
75          inputs=[sparse_const([['a', 'b'], [], ['c']])],
76          expected=ragged_const([[b'a', b'b'], [], [b'c']])),
77      dict(
78          testcase_name='TwoInputs_RaggedStr_RaggedStr',
79          inputs=[
80              ragged_const([['a', 'b'], [], ['c']]),
81              ragged_const([['d', 'e'], ['f'], ['g']])
82          ],
83          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
84                                 [b'c_X_g']])),
85      dict(
86          testcase_name='TwoInputs_RaggedInt_RaggedInt',
87          inputs=[
88              ragged_const([[1, 2], [], [3]]),
89              ragged_const([[4, 5, 6], [], [7]])
90          ],
91          expected=ragged_const(
92              [[b'1_X_4', b'1_X_5', b'1_X_6', b'2_X_4', b'2_X_5', b'2_X_6'], [],
93               [b'3_X_7']])),
94      dict(
95          testcase_name='TwoInputs_RaggedStr_RaggedInt',
96          inputs=[
97              ragged_const([['a', 'b'], [], ['c']]),
98              ragged_const([['1', '2'], ['3'], ['4']])
99          ],
100          expected=ragged_const([[b'a_X_1', b'a_X_2', b'b_X_1', b'b_X_2'], [],
101                                 [b'c_X_4']])),
102      dict(
103          testcase_name='TwoInputs_SparseStr_SparseStr',
104          inputs=[
105              sparse_const([['a', 'b'], [], ['c']]),
106              sparse_const([['d', 'e'], ['f'], ['g']])
107          ],
108          expected=ragged_const([[b'a_X_d', b'a_X_e', b'b_X_d', b'b_X_e'], [],
109                                 [b'c_X_g']])),
110      dict(
111          testcase_name='TwoInputs_DenseInt_DenseInt',
112          inputs=[dense_const([[1, 2], [3, 4]]),
113                  dense_const([[5, 6], [7, 8]])],
114          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
115                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
116      dict(
117          testcase_name='TwoInputs_DenseInt_DenseStr',
118          inputs=[
119              dense_const([[1, 2], [3, 4]]),
120              dense_const([[b'5', b'6'], [b'7', b'8']])
121          ],
122          expected=ragged_const([[b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
123                                 [b'3_X_7', b'3_X_8', b'4_X_7', b'4_X_8']])),
124      dict(
125          testcase_name='TwoInputs_RaggedInt_DenseInt',
126          inputs=[
127              ragged_const([[], [], [1, 2], [3]]),
128              dense_const([[1, 2], [3, 4], [5, 6], [7, 8]])
129          ],
130          expected=ragged_const([[], [],
131                                 [b'1_X_5', b'1_X_6', b'2_X_5', b'2_X_6'],
132                                 [b'3_X_7', b'3_X_8']])),
133      dict(
134          # This test exercises `input_order`.
135          testcase_name='TwoInputs_DenseInt_RaggedStr',
136          inputs=[
137              dense_const([[1, 2], [3, 4], [5, 6]]),
138              ragged_const([['d', 'e'], ['f'], ['g']])
139          ],
140          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
141                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
142          matches_sparse_cross=False  # sparse doesn't preserve input order.
143      ),
144      dict(
145          # This test exercises `input_order`.
146          testcase_name='TwoInputs_SparseInt_RaggedStr',
147          inputs=[
148              sparse_const([[1, 2], [3, 4], [5, 6]]),
149              ragged_const([['d', 'e'], ['f'], ['g']])
150          ],
151          expected=ragged_const([[b'1_X_d', b'1_X_e', b'2_X_d', b'2_X_e'],
152                                 [b'3_X_f', b'4_X_f'], [b'5_X_g', b'6_X_g']]),
153          matches_sparse_cross=False  # sparse doesn't preserve input order.
154      ),
155      dict(
156          testcase_name='ThreeInputs_RaggedInt_RaggedInt_RaggedInt',
157          inputs=[
158              ragged_const([[11], [12, 13], [], [14, 15]]),
159              ragged_const([[21, 22], [23], [24, 25], [26, 27]]),
160              ragged_const([[31], [32, 33], [34, 35], [36, 37]])
161          ],
162          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
163                                 [
164                                     b'12_X_23_X_32', b'12_X_23_X_33',
165                                     b'13_X_23_X_32', b'13_X_23_X_33'
166                                 ], [],
167                                 [
168                                     b'14_X_26_X_36', b'14_X_26_X_37',
169                                     b'14_X_27_X_36', b'14_X_27_X_37',
170                                     b'15_X_26_X_36', b'15_X_26_X_37',
171                                     b'15_X_27_X_36', b'15_X_27_X_37'
172                                 ]])),
173      dict(
174          testcase_name='ThreeInputs_RaggedInt_SparseInt_DenseInt',
175          inputs=[
176              ragged_const([[11], [12, 13], [], [14, 15]]),
177              sparse_const([[21, 22], [23], [24, 25], [26, 27]]),
178              dense_const([[31], [32], [33], [34]])
179          ],
180          expected=ragged_const([[b'11_X_21_X_31', b'11_X_22_X_31'],
181                                 [
182                                     b'12_X_23_X_32',
183                                     b'13_X_23_X_32',
184                                 ], [],
185                                 [
186                                     b'14_X_26_X_34',
187                                     b'14_X_27_X_34',
188                                     b'15_X_26_X_34',
189                                     b'15_X_27_X_34',
190                                 ]])),
191      dict(
192          testcase_name='FiveInputs',
193          inputs=[
194              ragged_const([[1]]),
195              dense_const([[2]]),
196              ragged_const([[3]]),
197              sparse_const([[4]]),
198              ragged_const([[5]])
199          ],
200          expected=ragged_const([[b'1_X_2_X_3_X_4_X_5']]),
201          matches_sparse_cross=False  # sparse doesn't preserve input order.
202      ),
203      dict(
204          testcase_name='Permutation_3x3x3',
205          inputs=[[['11', '12', '13']], [['21', '22', '23']],
206                  [['31', '32', '33']]],
207          expected=[[
208              b'11_X_21_X_31', b'11_X_21_X_32', b'11_X_21_X_33',
209              b'11_X_22_X_31', b'11_X_22_X_32', b'11_X_22_X_33',
210              b'11_X_23_X_31', b'11_X_23_X_32', b'11_X_23_X_33',
211              b'12_X_21_X_31', b'12_X_21_X_32', b'12_X_21_X_33',
212              b'12_X_22_X_31', b'12_X_22_X_32', b'12_X_22_X_33',
213              b'12_X_23_X_31', b'12_X_23_X_32', b'12_X_23_X_33',
214              b'13_X_21_X_31', b'13_X_21_X_32', b'13_X_21_X_33',
215              b'13_X_22_X_31', b'13_X_22_X_32', b'13_X_22_X_33',
216              b'13_X_23_X_31', b'13_X_23_X_32', b'13_X_23_X_33'
217          ]]),
218      dict(
219          testcase_name='BatchSizeZero',
220          inputs=[
221              ragged_const([], ragged_rank=1, dtype=dtypes.int32),
222              sparse_const([]),
223              np.zeros([0, 3], dtype=np.int32),
224          ],
225          expected=ragged_const([], ragged_rank=1, dtype=dtypes.int32)),
226      dict(
227          testcase_name='ThreeInputs_OneEmpty',
228          inputs=[
229              ragged_const([[1, 2]]),
230              ragged_const([[]], dtype=dtypes.int32),
231              ragged_const([[3, 4]])
232          ],
233          expected=ragged_const([[]], dtype=dtypes.string)),
234      dict(
235          testcase_name='ThreeInputs_AllEmpty',
236          inputs=[
237              ragged_const([[]], dtype=dtypes.int64),
238              ragged_const([[]], dtype=dtypes.string),
239              ragged_const([[]], dtype=dtypes.int32)
240          ],
241          expected=ragged_const([[]], ragged_rank=1, dtype=dtypes.string)),
242      dict(
243          testcase_name='HashedZeroBucketsDefaultKey',
244          inputs=[
245              ragged_const([['batch1-FC1-F1']]),
246              ragged_const([['batch1-FC2-F1']]),
247              ragged_const([['batch1-FC3-F1']])
248          ],
249          expected_hashed=ragged_const([[1971693436396284976]])),
250      dict(
251          testcase_name='Hashed100BucketsDefaultKey',
252          inputs=[
253              ragged_const([['batch1-FC1-F1']]),
254              ragged_const([['batch1-FC2-F1']]),
255              ragged_const([['batch1-FC3-F1']])
256          ],
257          num_buckets=100,
258          expected_hashed=ragged_const([[83]])),
259      dict(
260          testcase_name='HashedZeroBucketsCustomKey',
261          inputs=[
262              ragged_const([['batch1-FC1-F1']]),
263              ragged_const([['batch1-FC2-F1']]),
264              ragged_const([['batch1-FC3-F1']])
265          ],
266          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
267          expected_hashed=ragged_const([[4847552627144134031]])),
268      dict(
269          testcase_name='Hashed100BucketsCustomKey',
270          inputs=[
271              ragged_const([['batch1-FC1-F1']]),
272              ragged_const([['batch1-FC2-F1']]),
273              ragged_const([['batch1-FC3-F1']])
274          ],
275          num_buckets=100,
276          hash_key=ragged_array_ops._DEFAULT_CROSS_HASH_KEY + 1,
277          expected_hashed=ragged_const([[31]])),
278      dict(
279          testcase_name='HashedZeroKey',
280          inputs=[
281              ragged_const([['batch1-FC1-F1']]),
282              ragged_const([['batch1-FC2-F1']]),
283              ragged_const([['batch1-FC3-F1']])
284          ],
285          hash_key=0,
286          expected_hashed=ragged_const([[9077905385164735582]]),
287          matches_sparse_cross=False  # sparse treats hash_key=0 as None.
288      ),
289      dict(
290          testcase_name='UInt64',
291          inputs=[ragged_const([[2**64 - 1]], dtype=dtypes.uint64)],
292          expected=ragged_const([[b'-1']])),
293  ])
294  def testRaggedCross(self,
295                      inputs,
296                      num_buckets=0,
297                      hash_key=None,
298                      expected=None,
299                      expected_hashed=None,
300                      matches_sparse_cross=True):
301    ragged_cross = ragged_array_ops.cross(inputs)
302    ragged_cross_hashed = ragged_array_ops.cross_hashed(inputs, num_buckets,
303                                                        hash_key)
304
305    if expected is not None:
306      self.assertAllEqual(ragged_cross, expected)
307    if expected_hashed is not None:
308      self.assertAllEqual(ragged_cross_hashed, expected_hashed)
309
310    if matches_sparse_cross:
311      # Check that ragged.cross & sparse.cross match.
312      sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
313      sparse_cross = sparse_ops.sparse_cross(sparse_inputs)
314      self.assertAllEqual(ragged_cross,
315                          ragged_tensor.RaggedTensor.from_sparse(sparse_cross))
316
317      # Check that ragged.cross_hashed & sparse.cross_hashed match.
318      sparse_inputs = [self._ragged_to_sparse(t) for t in inputs]
319      sparse_cross_hashed = sparse_ops.sparse_cross_hashed(
320          sparse_inputs, num_buckets, hash_key)
321      self.assertAllEqual(
322          ragged_cross_hashed,
323          ragged_tensor.RaggedTensor.from_sparse(sparse_cross_hashed))
324
325  def testRaggedCrossLargeBatch(self):
326    batch_size = 5000
327    inputs = [
328        ragged_const([[1, 2, 3]] * batch_size),
329        ragged_const([[b'4']] * batch_size),
330        dense_const([[5]] * batch_size),
331        sparse_const([[6, 7]] * batch_size)
332    ]
333
334    expected = [[
335        b'1_X_4_X_5_X_6', b'1_X_4_X_5_X_7', b'2_X_4_X_5_X_6', b'2_X_4_X_5_X_7',
336        b'3_X_4_X_5_X_6', b'3_X_4_X_5_X_7'
337    ]] * batch_size
338
339    ragged_cross = ragged_array_ops.cross(inputs)
340
341    # Note: we don't use assertAllEqual here because if they don't match,
342    # then the code in assertAllEqual that tries to build the error message
343    # is very slow, causing the test to timeout.
344    # pylint: disable=g-generic-assert
345    self.assertTrue(self.evaluate(ragged_cross).to_list() == expected)
346
347  @parameterized.named_parameters([
348      dict(
349          testcase_name='BadDType',
350          inputs=[ragged_const([[1.1], [2.2, 3.3]])],
351          message=r'Unexpected dtype for inputs\[0\]'),
352      dict(
353          testcase_name='StaticBatchSizeMismatch1',
354          inputs=[ragged_const([[1]]),
355                  ragged_const([[2], [3]])],
356          exception=(ValueError, errors.InvalidArgumentError),
357          message='inputs must all have the same batch dimension size'),
358      dict(
359          testcase_name='StaticBatchSizeMismatch2',
360          inputs=[ragged_const([[1]]),
361                  dense_const([[2], [3]])],
362          exception=(ValueError, errors.InvalidArgumentError),
363          message='inputs must all have the same batch dimension size'),
364      dict(
365          testcase_name='3DDenseTensor',
366          inputs=[dense_const([[[1]]])],
367          exception=(ValueError, errors.InvalidArgumentError),
368          message='tf.ragged.cross only supports inputs with rank=2'),
369      dict(
370          testcase_name='0DDenseTensor',
371          inputs=[dense_const(1)],
372          exception=(ValueError, errors.InvalidArgumentError),
373          message='tf.ragged.cross only supports inputs with rank=2'),
374  ])
375  def testStaticError(self, inputs, exception=ValueError, message=None):
376    with self.assertRaisesRegex(exception, message):
377      ragged_array_ops.cross(inputs)
378
379  @parameterized.named_parameters([
380      dict(
381          testcase_name='3DRaggedTensor',
382          inputs=[ragged_const([[[1]]], ragged_rank=1)],
383          message='tf.ragged.cross only supports inputs with rank=2'),
384      dict(
385          testcase_name='0DDenseTensor',
386          inputs=[dense_const(1)],
387          signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
388          exception=(ValueError, errors.InvalidArgumentError),
389          message='tf.ragged.cross only supports inputs with rank=2'),
390      dict(
391          testcase_name='1DDenseTensor',
392          inputs=[dense_const([1])],
393          signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
394          exception=(ValueError, errors.InvalidArgumentError),
395          message='tf.ragged.cross only supports inputs with rank=2'),
396      dict(
397          testcase_name='3DDenseTensor',
398          inputs=[dense_const([[[1]]])],
399          signature=[[tensor_spec.TensorSpec(None, dtypes.int32)]],
400          exception=(ValueError, errors.InvalidArgumentError),
401          message='tf.ragged.cross only supports inputs with rank=2'),
402  ])
403  def testRuntimeError(self,
404                       inputs,
405                       exception=errors.InvalidArgumentError,
406                       message=None,
407                       signature=None):
408    @def_function.function(input_signature=signature)
409    def fn(x):
410      return ragged_array_ops.cross(x)
411
412    with self.assertRaisesRegex(exception, message):
413      self.evaluate(fn(inputs))
414
415  def _ragged_to_sparse(self, t):
416    if ragged_tensor.is_ragged(t):
417      return ragged_tensor.convert_to_tensor_or_ragged_tensor(t).to_sparse()
418    elif sparse_tensor.is_sparse(t):
419      return sparse_tensor.SparseTensor.from_value(t)
420    else:
421      return ops.convert_to_tensor(t)
422
423  def testSparseValuesAndIndicesMustMatch(self):
424    with self.assertRaisesRegex(
425        (ValueError, errors.InvalidArgumentError),
426        'sparse indices and values must have the same length'):
427      self.evaluate(gen_ragged_array_ops.RaggedCross(
428          ragged_values=[],
429          ragged_row_splits=[],
430          sparse_indices=[[5]],
431          sparse_values=[],
432          sparse_shape=[5],
433          dense_inputs=[['a']],
434          input_order='RD',
435          hashed_output=False,
436          num_buckets=5,
437          hash_key=2,
438          out_values_type=dtypes.string,
439          out_row_splits_type=dtypes.int64))
440
441  def testRaggedValuesAndSplitsMustMatch(self):
442    with self.assertRaisesRegex(
443        (ValueError, errors.InvalidArgumentError),
444        'ragged values and splits must have the same length'):
445      self.evaluate(gen_ragged_array_ops.RaggedCross(
446          ragged_values=[['a']],
447          ragged_row_splits=[],
448          sparse_indices=[],
449          sparse_values=[],
450          sparse_shape=[],
451          dense_inputs=[['a']],
452          input_order='RD',
453          hashed_output=False,
454          num_buckets=5,
455          hash_key=2,
456          out_values_type=dtypes.string,
457          out_row_splits_type=dtypes.int64))
458
459
# Run the tests through googletest when this file is executed as a script.
if __name__ == '__main__':
  googletest.main()
462