# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for third_party.tensorflow.python.ops.ragged_tensor."""

import functools
from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework.type_utils import fulltypes_for_flat_tensors
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_ragged_conversion_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
from tensorflow.python.ops.ragged.row_partition import RowPartition

from tensorflow.python.platform import googletest
from tensorflow.python.util import nest


def int32array(values):
  """Returns `values` converted to a NumPy array with dtype int32."""
  return np.array(values, dtype=np.int32)


@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for RaggedTensor construction, accessors, shape and str/repr."""

  longMessage = True  # Property in unittest.TestCase. pylint: disable=invalid-name

  #=============================================================================
  # RaggedTensor class docstring examples
  #=============================================================================

  def testClassDocStringExamples(self):
    """Checks the examples given in the RaggedTensor class docstring."""
    # From section: "Component Tensors"
    rt = RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    del rt

    # From section: "Alternative Row-Partitioning Schemes"
    values = [3, 1, 4, 1, 5, 9, 2, 6]
    rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
    rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
    rt3 = RaggedTensor.from_value_rowids(
        values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
    rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
    rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
    # All five partitioning schemes must describe the same ragged tensor.
    for rt in (rt1, rt2, rt3, rt4, rt5):
      self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    del rt1, rt2, rt3, rt4, rt5

    # From section: "Multiple Ragged Dimensions"
    inner_rt = RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    outer_rt = RaggedTensor.from_row_splits(
        values=inner_rt, row_splits=[0, 3, 3, 5])
    self.assertEqual(outer_rt.ragged_rank, 2)
    self.assertAllEqual(outer_rt,
                        [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
    del inner_rt, outer_rt

    # From section: "Multiple Ragged Dimensions"
    rt = RaggedTensor.from_nested_row_splits(
        flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
        nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
    self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
    del rt

    # From section: "Uniform Inner Dimensions"
    rt = RaggedTensor.from_row_splits(
        values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
    self.assertAllEqual(
        rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
    self.assertEqual(rt.shape.as_list(), [2, None, 3])
    del rt

  #=============================================================================
  # RaggedTensorValue Constructor
  #=============================================================================

  def testRaggedTensorValueConstruction(self):
    """Builds RaggedTensorValues of ragged_rank 1 and 2; checks accessors."""
    values = np.array(b'a b c d e f g'.split())
    splits = np.array([0, 2, 5, 6, 6, 7], dtype=np.int64)
    splits2 = np.array([0, 3, 5], dtype=np.int64)

    # Test construction of a RaggedTensorValue with ragged_rank=1.
    rt_value = ragged_tensor_value.RaggedTensorValue(values, splits)
    self.assertEqual(rt_value.row_splits.dtype, np.int64)
    self.assertEqual(rt_value.shape, (5, None))
    self.assertLen(rt_value.nested_row_splits, 1)
    self.assertAllEqual(splits, rt_value.row_splits)
    self.assertAllEqual(values, rt_value.values)
    self.assertAllEqual(splits, rt_value.nested_row_splits[0])
    self.assertAllEqual(values, rt_value.flat_values)

    # Test construction of a RaggedTensorValue with ragged_rank=2.
    rt_value = ragged_tensor_value.RaggedTensorValue(
        values=ragged_tensor_value.RaggedTensorValue(values, splits),
        row_splits=splits2)
    self.assertEqual(rt_value.row_splits.dtype, np.int64)
    self.assertEqual(rt_value.shape, (2, None, None))
    self.assertLen(rt_value.nested_row_splits, 2)
    self.assertAllEqual(splits2, rt_value.row_splits)
    self.assertAllEqual(splits, rt_value.values.row_splits)
    self.assertAllEqual(splits2, rt_value.nested_row_splits[0])
    self.assertAllEqual(splits, rt_value.nested_row_splits[1])
    self.assertAllEqual(values, rt_value.values.values)
    self.assertAllEqual(values, rt_value.flat_values)

  #=============================================================================
  # RaggedTensor Constructor (private)
  #=============================================================================

  def testRaggedTensorConstruction(self):
    """Builds a RaggedTensor through the constructor with internal=True."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    rp = RowPartition.from_row_splits(row_splits)
    rt = RaggedTensor(values=values, row_partition=rp, internal=True)

    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testRaggedTensorConstructionErrors(self):
    """Checks the errors raised for invalid constructor arguments."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    rp = RowPartition.from_row_splits(row_splits)

    # Omitting internal=True is rejected.
    with self.assertRaisesRegex(ValueError,
                                'RaggedTensor constructor is private'):
      RaggedTensor(values=values, row_partition=rp)

    # `values` of an unsupported type (a Python range) is rejected.
    with self.assertRaisesRegex(
        TypeError, r'type\(values\) must be one of: Tensor, RaggedTensor'):
      RaggedTensor(values=range(7), row_partition=rp, internal=True)

    # `row_partition` given as a plain list is rejected.
    with self.assertRaisesRegex(
        TypeError, 'Argument `row_partition` must be a RowPartition'):
      RaggedTensor(
          values=values, row_partition=[0, 2, 2, 5, 6, 7], internal=True)

  #=============================================================================
  # RaggedTensor Factory Ops
  #=============================================================================

  def testFromValueRowIdsWithDerivedNRows(self):
    """from_value_rowids with no explicit nrows and constant row ids."""
    # nrows is known at graph creation time.
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)

    rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [5, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_value_rowids = rt.value_rowids()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
    self.assertAllEqual(rt_value_rowids, value_rowids)
    self.assertAllEqual(rt_nrows, 5)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromValueRowIdsWithDerivedNRowsDynamic(self):
    """from_value_rowids with no explicit nrows and placeholder row ids."""
    # nrows is not known at graph creation time.
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None)

    rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    if context.executing_eagerly():
      self.assertEqual(rt.shape.as_list(), [5, None])
    else:
      self.assertEqual(rt.shape.as_list(), [None, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_value_rowids = rt.value_rowids()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
    self.assertAllEqual(rt_value_rowids, value_rowids)
    self.assertAllEqual(rt_nrows, 5)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromValueRowIdsWithExplicitNRows(self):
    """from_value_rowids with nrows larger than the last row id + 1."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(7, dtypes.int64)

    rt = RaggedTensor.from_value_rowids(
        values, value_rowids, nrows, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [7, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_value_rowids = rt.value_rowids()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
    self.assertIs(rt_nrows, nrows)  # cached_nrows
    self.assertAllEqual(
        rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])

  def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
    """from_value_rowids with nrows equal to the last row id + 1."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(5, dtypes.int64)

    rt = RaggedTensor.from_value_rowids(
        values, value_rowids, nrows, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [5, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_value_rowids = rt.value_rowids()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
    self.assertIs(rt_nrows, nrows)  # cached_nrows
    self.assertAllEqual(rt_value_rowids, value_rowids)
    self.assertAllEqual(rt_nrows, nrows)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromValueRowIdsWithEmptyValues(self):
    """from_value_rowids with empty values and empty row ids."""
    rt = RaggedTensor.from_value_rowids([], [])
    rt_nrows = rt.nrows()
    self.assertEqual(rt.dtype, dtypes.float32)
    self.assertEqual(rt.shape.as_list(), [0, None])
    self.assertEqual(rt.ragged_rank, 1)
    self.assertEqual(rt.values.shape.as_list(), [0])
    self.assertEqual(rt.value_rowids().shape.as_list(), [0])
    self.assertAllEqual(rt_nrows, 0)
    self.assertAllEqual(rt, [])

  def testFromRowSplits(self):
    """from_row_splits returns the given values and row_splits tensors."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

    rt = RaggedTensor.from_row_splits(values, row_splits, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [5, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_row_splits = rt.row_splits
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertIs(rt_row_splits, row_splits)
    self.assertAllEqual(rt_nrows, 5)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromRowSplitsWithDifferentSplitTypes(self):
    """Checks the row_splits dtype produced for each kind of splits input."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    splits1 = [0, 2, 2, 5, 6, 7]
    splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
    splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
    splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
    rt1 = RaggedTensor.from_row_splits(values, splits1)
    rt2 = RaggedTensor.from_row_splits(values, splits2)
    rt3 = RaggedTensor.from_row_splits(values, splits3)
    rt4 = RaggedTensor.from_row_splits(values, splits4)
    rt5 = RaggedTensor.from_row_splits(values, splits5)
    self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
    self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
    self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
    self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
    self.assertEqual(rt5.row_splits.dtype, dtypes.int32)

  def testFromRowSplitsWithEmptySplits(self):
    """from_row_splits raises ValueError for an empty row_splits."""
    err_msg = 'row_splits tensor may not be empty'
    with self.assertRaisesRegex(ValueError, err_msg):
      RaggedTensor.from_row_splits([], [])

  def testFromRowStarts(self):
    """from_row_starts builds the expected tensor and row_starts()."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)

    rt = RaggedTensor.from_row_starts(values, row_starts, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [5, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_row_starts = rt.row_starts()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertAllEqual(rt_nrows, 5)
    self.assertAllEqual(rt_row_starts, row_starts)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromRowLimits(self):
    """from_row_limits builds the expected tensor and row_limits()."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

    rt = RaggedTensor.from_row_limits(values, row_limits, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [5, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_row_limits = rt.row_limits()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertAllEqual(rt_nrows, 5)
    self.assertAllEqual(rt_row_limits, row_limits)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromRowLengths(self):
    """from_row_lengths returns the given values and row_lengths tensors."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)

    rt = RaggedTensor.from_row_lengths(values, row_lengths, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [5, None])
    self.assertEqual(rt.ragged_rank, 1)

    rt_values = rt.values
    rt_row_lengths = rt.row_lengths()
    rt_nrows = rt.nrows()

    self.assertIs(rt_values, values)
    self.assertIs(rt_row_lengths, row_lengths)  # cached_row_lengths
    self.assertAllEqual(rt_nrows, 5)
    self.assertAllEqual(rt_row_lengths, row_lengths)
    self.assertAllEqual(rt,
                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])

  def testFromRowLengthsInt32(self):
    """from_row_lengths over a RaggedTensor built with int32 row lengths."""
    rt = RaggedTensor.from_row_lengths([1, 2, 3, 4],
                                       constant_op.constant([1, 0, 3],
                                                            dtype=dtypes.int32))
    rt2 = RaggedTensor.from_row_lengths(rt, [2, 1, 0])
    self.assertAllEqual([2, 1, 0], rt2.row_lengths())

  def testFromUniformRowLength(self):
    """from_uniform_row_length, with and without nrows, nested three deep."""
    values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]

    a1 = RaggedTensor.from_uniform_row_length(values, 2)
    a2 = RaggedTensor.from_uniform_row_length(values, 2, 8)
    self.assertAllEqual(
        a1,
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
    self.assertAllEqual(a1, a2)
    self.assertEqual(a1.shape.as_list(), [8, 2])
    self.assertEqual(a2.shape.as_list(), [8, 2])

    b1 = RaggedTensor.from_uniform_row_length(a1, 2)
    b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4)
    self.assertAllEqual(b1, [[[1, 2], [3, 4]], [[5, 6], [7, 8]],
                             [[9, 10], [11, 12]], [[13, 14], [15, 16]]])
    self.assertAllEqual(b1, b2)
    self.assertEqual(b1.shape.as_list(), [4, 2, 2])
    self.assertEqual(b2.shape.as_list(), [4, 2, 2])

    c1 = RaggedTensor.from_uniform_row_length(b1, 2)
    c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2)
    self.assertAllEqual(c1, [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
                             [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]])
    self.assertAllEqual(c1, c2)
    self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2])
    self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2])

  def testFromUniformRowLengthWithEmptyValues(self):
    """from_uniform_row_length with empty values and a row length of 0."""
    empty_values = []
    a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10)
    self.assertEqual(a.shape.as_list(), [10, 0])

    b = RaggedTensor.from_uniform_row_length(a, 2)
    self.assertEqual(b.shape.as_list(), [5, 2, 0])

    # Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0.
    c = RaggedTensor.from_uniform_row_length(empty_values, 0)
    self.assertEqual(c.shape.as_list(), [0, 0])
    d = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=0)
    self.assertEqual(d.shape.as_list(), [0, 0])

  def testFromUniformRowLengthWithPlaceholders(self):
    """from_uniform_row_length with placeholder values and/or row length."""
    ph_values = array_ops.placeholder_with_default([1, 2, 3, 4, 5, 6], [None])
    ph_rowlen = array_ops.placeholder_with_default(3, None)
    rt1 = RaggedTensor.from_uniform_row_length(ph_values, 3)
    rt2 = RaggedTensor.from_uniform_row_length(ph_values, ph_rowlen)
    rt3 = RaggedTensor.from_uniform_row_length([1, 2, 3, 4, 5, 6], ph_rowlen)
    self.assertAllEqual(rt1, [[1, 2, 3], [4, 5, 6]])
    self.assertAllEqual(rt2, [[1, 2, 3], [4, 5, 6]])
    self.assertAllEqual(rt3, [[1, 2, 3], [4, 5, 6]])
    if context.executing_eagerly():
      self.assertEqual(rt1.shape.as_list(), [2, 3])
      self.assertEqual(rt2.shape.as_list(), [2, 3])
      self.assertEqual(rt3.shape.as_list(), [2, 3])
    else:
      self.assertEqual(rt1.shape.as_list(), [None, 3])
      self.assertEqual(rt2.shape.as_list(), [None, None])
      self.assertEqual(rt3.shape.as_list(), [None, None])

    b = RaggedTensor.from_uniform_row_length(rt1, 2)
    self.assertAllEqual(b, [[[1, 2, 3], [4, 5, 6]]])

    # Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0.
    ph_empty_values = array_ops.placeholder_with_default(
        array_ops.zeros([0], dtypes.int64), [None])
    ph_zero = array_ops.placeholder_with_default(0, [])
    c = RaggedTensor.from_uniform_row_length(ph_empty_values, ph_zero)
    if context.executing_eagerly():
      self.assertEqual(c.shape.as_list(), [0, 0])
    else:
      self.assertEqual(c.shape.as_list(), [None, None])

  def testFromNestedValueRowIdsWithDerivedNRows(self):
    """from_nested_value_rowids without an explicit nested_nrows."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_value_rowids = [
        constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    ]

    rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [4, None, None])
    self.assertEqual(rt.ragged_rank, 2)

    rt_values = rt.values
    rt_value_rowids = rt.value_rowids()
    rt_values_values = rt_values.values
    rt_values_value_rowids = rt_values.value_rowids()

    self.assertIs(rt_values_values, values)
    self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
    self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
    self.assertAllEqual(
        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])

  def testFromNestedRowPartitions(self):
    """_from_nested_row_partitions with two row-splits RowPartitions."""
    flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_row_splits = [[0, 2, 3, 3, 5], [0, 2, 2, 5, 6, 7]]
    nested_row_partition = [
        RowPartition.from_row_splits(constant_op.constant(x, dtypes.int64))
        for x in nested_row_splits
    ]

    rt = RaggedTensor._from_nested_row_partitions(
        flat_values, nested_row_partition, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [4, None, None])
    self.assertEqual(rt.ragged_rank, 2)
    self.assertAllEqual(
        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])

  def testFromNestedValueRowIdsWithExplicitNRows(self):
    """from_nested_value_rowids with an explicit nrows per ragged dimension."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_value_rowids = [
        constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64),
        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    ]
    nrows = [
        constant_op.constant(6, dtypes.int64),
        constant_op.constant(6, dtypes.int64)
    ]

    rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids,
                                               nrows)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [6, None, None])
    self.assertEqual(rt.ragged_rank, 2)

    rt_values = rt.values
    rt_value_rowids = rt.value_rowids()
    rt_nrows = rt.nrows()
    rt_values_values = rt_values.values
    rt_values_value_rowids = rt_values.value_rowids()
    rt_values_nrows = rt_values.nrows()

    self.assertIs(rt_values_values, values)
    self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
    self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
    self.assertAllEqual(rt_nrows, nrows[0])
    self.assertAllEqual(rt_values_nrows, nrows[1])
    self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
                             [[b'f'], [b'g'], []], [], []])

  def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
    """from_nested_value_rowids rejects nested_nrows of the wrong length."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_value_rowids = [
        constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64),
        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    ]
    nrows = [constant_op.constant(6, dtypes.int64)]
    with self.assertRaisesRegex(
        ValueError, 'Argument `nested_nrows` must have the same length as '
        'argument `nested_value_rowids`'):
      RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, nrows)

  def testFromNestedValueRowIdsWithNonListInput(self):
    """from_nested_value_rowids rejects Tensor (non-list) nested arguments."""
    with self.assertRaisesRegex(
        TypeError, 'Argument `nested_value_rowids` must be a list of Tensors'):
      RaggedTensor.from_nested_value_rowids(
          [1, 2, 3], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
    with self.assertRaisesRegex(
        TypeError, 'Argument `nested_nrows` must be a list of Tensors'):
      RaggedTensor.from_nested_value_rowids([1, 2, 3], [[0, 1, 2], [0, 1, 2]],
                                            constant_op.constant([3, 3]))

  def testFromNestedRowSplits(self):
    """from_nested_row_splits returns the given values and splits tensors."""
    flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_row_splits = [
        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    ]

    rt = RaggedTensor.from_nested_row_splits(
        flat_values, nested_row_splits, validate=False)
    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [4, None, None])
    self.assertEqual(rt.ragged_rank, 2)

    rt_values = rt.values
    rt_row_splits = rt.row_splits
    rt_values_values = rt_values.values
    rt_values_row_splits = rt_values.row_splits

    self.assertIs(rt_values_values, flat_values)
    self.assertIs(rt_row_splits, nested_row_splits[0])
    self.assertIs(rt_values_row_splits, nested_row_splits[1])
    self.assertAllEqual(
        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])

  def testWithRowSplits(self):
    """with_row_splits_dtype(int32) leaves values and splits unchanged."""
    flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_row_splits = [
        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    ]

    rt = RaggedTensor.from_nested_row_splits(
        flat_values, nested_row_splits, validate=False)

    rt = rt.with_row_splits_dtype(dtypes.int32)

    self.assertEqual(rt.dtype, dtypes.string)
    self.assertEqual(rt.shape.as_list(), [4, None, None])
    self.assertEqual(rt.ragged_rank, 2)

    rt_values = rt.values
    rt_row_splits = rt.row_splits
    rt_values_values = rt_values.values
    rt_values_row_splits = rt_values.row_splits

    self.assertAllEqual(rt_values_values, flat_values)
    self.assertAllEqual(rt_row_splits, nested_row_splits[0])
    self.assertAllEqual(rt_values_row_splits, nested_row_splits[1])
    self.assertAllEqual(
        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])

  def testFromNestedRowSplitsWithNonListInput(self):
    """from_nested_row_splits rejects a Tensor (non-list) nested_row_splits."""
    with self.assertRaisesRegex(
        TypeError, '`nested_row_splits` must be a list of Tensors'):
      RaggedTensor.from_nested_row_splits(
          [1, 2], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))

  def testFromValueRowIdsWithBadNRows(self):
    """from_value_rowids rejects invalid nrows and wrongly ranked inputs."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    nrows = constant_op.constant(5, dtypes.int64)

    # Negative nrows.
    with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
      RaggedTensor.from_value_rowids(
          values=values,
          value_rowids=array_ops.placeholder_with_default(value_rowids, None),
          nrows=-2)

    # nrows smaller than value_rowids[-1] + 1.
    with self.assertRaisesRegex(
        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
        r'value_rowids\[-1\]=4'):
      RaggedTensor.from_value_rowids(
          values=values, value_rowids=value_rowids, nrows=2)

    with self.assertRaisesRegex(
        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
        r'value_rowids\[-1\]=4'):
      RaggedTensor.from_value_rowids(
          values=values, value_rowids=value_rowids, nrows=4)

    # value_rowids with rank 2.
    with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
      RaggedTensor.from_value_rowids(
          values=values,
          value_rowids=array_ops.expand_dims(value_rowids, 1),
          nrows=nrows)

    # nrows with rank 1.
    with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
      RaggedTensor.from_value_rowids(
          values=values,
          value_rowids=value_rowids,
          nrows=array_ops.expand_dims(nrows, 0))

  def testCondWithTensorsFromValueIds(self):
    """A from_value_rowids RaggedTensor can be returned from cond branches."""
    # b/141166460
    rt = RaggedTensor.from_value_rowids([1, 2, 3], [0, 0, 2])
    c = array_ops.placeholder_with_default(True, None)
    result = control_flow_ops.cond(c, lambda: rt, lambda: rt)
    self.assertAllEqual(rt, result)

  def testGraphMismatch(self):
    """from_row_splits rejects values and splits from different graphs."""
    # Only exercised in graph mode; in eager mode this test asserts nothing.
    if not context.executing_eagerly():
      with ops.Graph().as_default():
        values = constant_op.constant([1, 2, 3], dtypes.int64)
      with ops.Graph().as_default():
        splits = constant_op.constant([0, 2, 3], dtypes.int64)
      with self.assertRaisesRegex(ValueError,
                                  '.* must be from the same graph as .*'):
        RaggedTensor.from_row_splits(values, splits)

  @parameterized.named_parameters([
      dict(
          testcase_name='Rank0',
          tensor='a'),
      dict(
          testcase_name='Rank1',
          tensor=['a', 'b']),
  ])
  def testFromTensorRankError(self, tensor):
    """from_tensor raises ValueError for rank-0 and rank-1 inputs."""
    with self.assertRaisesRegex(ValueError, 'must be greater than 1'):
      RaggedTensor.from_tensor(tensor)

  #=============================================================================
  # Ragged Value & Row-Partitioning Tensor Accessors
  #=============================================================================

  def testRaggedTensorAccessors_2d(self):
    """Checks every accessor on a 2D RaggedTensor with ragged_rank 1."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    rt1 = RaggedTensor.from_row_splits(values, row_splits)
    rt2 = RaggedTensor.from_value_rowids(values, value_rowids)

    for rt in [rt1, rt2]:
      self.assertAllEqual(
          rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
      self.assertAllEqual(rt.values, [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
      self.assertEqual(rt.values.shape.dims[0].value, 7)
      self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
      self.assertAllEqual(rt.nrows(), 5)
      self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
      self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
      self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
      self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
      self.assertAllEqual(rt.flat_values,
                          [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
      self.assertLen(rt.nested_row_splits, 1)
      self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])

  def testRaggedTensorAccessors_3d_with_ragged_rank_1(self):
    """Checks every accessor on a 3D RaggedTensor with ragged_rank 1."""
    values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    row_lengths = constant_op.constant([2, 0, 3, 1, 1])
    rt1 = RaggedTensor.from_row_splits(values, row_splits)
    rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
    rt3 = RaggedTensor.from_row_lengths(values, row_lengths)

    for rt in [rt1, rt2, rt3]:
      self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]],
                               [[10, 11]], [[12, 13]]])
      self.assertAllEqual(
          rt.values,
          [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
      self.assertEqual(rt.values.shape.dims[0].value, 7)
      self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
      self.assertAllEqual(rt.nrows(), 5)
      self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
      self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
      self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
      self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
      self.assertAllEqual(
          rt.row_lengths(axis=2), [[2, 2], [], [2, 2, 2], [2], [2]])
      self.assertAllEqual(
          rt.flat_values,
          [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
      self.assertLen(rt.nested_row_splits, 1)
      self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
      self.assertLen(rt.nested_value_rowids(), 1)

      self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 2, 2, 2, 3, 4])

  def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
    """Checks every accessor on a 3D RaggedTensor with ragged_rank 2."""
    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    nested_row_splits = [
        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
    ]
    nested_value_rowids = [
        constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
    ]
    rt1 = RaggedTensor.from_nested_row_splits(values, nested_row_splits)
    rt2 = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)

    for rt in [rt1, rt2]:
      self.assertAllEqual(
          rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
      self.assertAllEqual(
          rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
      self.assertEqual(rt.values.shape.dims[0].value, 5)
      self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
      self.assertAllEqual(rt.nrows(), 4)
      self.assertAllEqual(rt.row_splits, [0, 2, 3, 3, 5])
      self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
      self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
      self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
      self.assertAllEqual(rt.flat_values,
                          [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
      self.assertLen(rt.nested_row_splits, 2)
      self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
      self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
      self.assertLen(rt.nested_value_rowids(), 2)
      self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 1, 3, 3])
      self.assertAllEqual(rt.nested_value_rowids()[1], [0, 0, 2, 2, 2, 3, 4])

  #=============================================================================
  # RaggedTensor.shape
  #=============================================================================

  def testShape(self):
    """Tests for RaggedTensor.shape."""
    rt1 = RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
                                       [0, 2, 5, 6, 6, 7])
    self.assertEqual(rt1.shape.as_list(), [5, None])

    rt2 = RaggedTensor.from_row_splits(
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]],
        [0, 2, 5, 6, 6, 7])
    self.assertEqual(rt2.shape.as_list(), [5, None, 2])

    rt3 = RaggedTensor.from_row_splits(
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], [0, 2, 2, 3])
    self.assertEqual(rt3.shape.as_list(), [3, None, 2, 2])

    rt4 = RaggedTensor.from_row_splits(rt3, [0, 1, 3, 3])
    self.assertEqual(rt4.shape.as_list(), [3, None, None, 2, 2])

    # The placeholder-based cases below are only run in graph mode.
    if not context.executing_eagerly():
      rt5 = RaggedTensor.from_row_splits(
          array_ops.placeholder(dtype=dtypes.string), [0, 2, 3, 5])
      self.assertIsNone(rt5.shape.ndims)

      rt6 = RaggedTensor.from_row_splits(
          [1, 2, 3], array_ops.placeholder(dtype=dtypes.int64))
      self.assertEqual(rt6.shape.as_list(), [None, None])

  def testGetShape(self):
    """get_shape() returns the same dimensions as the shape property."""
    rt = RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
                                      [0, 2, 5, 6, 6, 7])
    self.assertEqual(rt.shape.as_list(), rt.get_shape().as_list())

  #=============================================================================
  # RaggedTensor.__str__
  #=============================================================================
  def testRaggedTensorStr(self):
    """Checks str() and repr() of a RaggedTensor in eager and graph mode."""
    values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
    row_splits = [0, 2, 5, 6, 6, 7]
    rt = RaggedTensor.from_row_splits(values, row_splits, validate=False)
    splits_type = 'int64'
    if context.executing_eagerly():
      expected_repr = '<tf.RaggedTensor {}>'.format([[b'a', b'b'],
                                                     [b'c', b'd', b'e'], [b'f'],
                                                     [], [b'g']])
    else:
      expected_repr = (
          'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
          'shape=(7,), dtype=string), '
          'row_splits=Tensor('
          '"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",'
          ' shape=(6,), dtype={}))').format(splits_type)
    self.assertEqual(repr(rt), expected_repr)
    self.assertEqual(str(rt), expected_repr)

  def testRaggedTensorValueStr(self):
    """Checks str() and repr() of a RaggedTensorValue."""
    values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
    row_splits = [0, 2, 5, 6, 6, 7]
    rt = ragged_tensor_value.RaggedTensorValue(
        np.array(values), np.array(row_splits, dtype=np.int64))
    expected_str = '<tf.RaggedTensorValue {}>'.format([[b'a', b'b'],
                                                       [b'c', b'd', b'e'],
                                                       [b'f'], [], [b'g']])
    expected_repr = ("tf.RaggedTensorValue(values=array({}, dtype='|S1'), "
                     'row_splits=array({}))'.format(values, row_splits))
    # Collapse runs of whitespace to single spaces before comparing.
    self.assertEqual(' '.join(str(rt).split()), expected_str)
    self.assertEqual(' '.join(repr(rt).split()), expected_repr)

  def testRaggedTensorStrWithZeroSizeInnerShape(self):
    """Checks repr() when the values have a zero-size inner dimension."""
    # Tests that b/226112826 is fixed.
    # Only exercised in eager mode; in graph mode this test asserts nothing.
    if context.executing_eagerly():
      rt = RaggedTensor.from_row_lengths(array_ops.zeros([9, 0]), [4, 3, 2])
      expected_repr = (
          '<tf.RaggedTensor [[[], [], [], []], [[], [], []], [[], []]]>')
      self.assertEqual(' '.join(repr(rt).split()), expected_repr)

  #=============================================================================
  # RaggedTensor.with_values() and RaggedTensor.with_flat_values().
852 #============================================================================= 853 854 def testWithValues(self): 855 rt1 = ragged_factory_ops.constant([[1, 2], [3, 4, 5], [6], [], [7]]) 856 rt2 = ragged_factory_ops.constant([[[1, 2], [3, 4, 5]], [[6]], [], [[], 857 [7]]]) 858 859 rt1_plus_10 = rt1.with_values(rt1.values + 10) 860 rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10) 861 rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1)) 862 863 self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]]) 864 self.assertAllEqual(rt2_times_10, 865 [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]]) 866 self.assertAllEqual(rt1_expanded, 867 [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]]) 868 869 #============================================================================= 870 # Session.run 871 #============================================================================= 872 def testSessionRun(self): 873 if context.executing_eagerly(): 874 return 875 876 rt1 = ragged_factory_ops.constant([[1, 2, 3], [4]]) 877 rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]]) 878 with self.test_session() as session: 879 result = session.run({'rt1': rt1, 'rt2': rt2}) 880 self.assertCountEqual(result.keys(), ['rt1', 'rt2']) 881 self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]]) 882 self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]]) 883 884 def testSessionRunFeed(self): 885 if context.executing_eagerly(): 886 return 887 888 rt1 = RaggedTensor.from_row_splits( 889 array_ops.placeholder(dtypes.int32), 890 array_ops.placeholder(dtypes.int64)) 891 rt2 = RaggedTensor.from_nested_row_splits( 892 array_ops.placeholder(dtypes.int32), [ 893 array_ops.placeholder(dtypes.int64), 894 array_ops.placeholder(dtypes.int64) 895 ]) 896 897 rt1_feed_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]]) 898 rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]]) 899 900 with self.test_session() as session: 901 
fetches = {'rt1': rt1, 'rt2': rt2} 902 feeds = {rt1: rt1_feed_val, rt2: rt2_feed_val} 903 result = session.run(fetches, feed_dict=feeds) 904 self.assertCountEqual(result.keys(), ['rt1', 'rt2']) 905 self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]]) 906 self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]]) 907 908 def testSessionPartialRunFeed(self): 909 if context.executing_eagerly(): 910 return 911 912 # Placeholder inputs. 913 a = RaggedTensor.from_row_splits( 914 array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'), 915 array_ops.placeholder(dtypes.int64, name='a.row_splits')) 916 b = RaggedTensor.from_row_splits( 917 array_ops.placeholder(dtypes.int32, shape=[None], name='b.values'), 918 array_ops.placeholder(dtypes.int64, name='b.row_splits')) 919 c = array_ops.placeholder(dtypes.int32, shape=[], name='c') 920 921 # Feed values for placeholder inputs. 922 a_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]]) 923 b_val = ragged_factory_ops.constant_value([[5, 4, 3], [2]]) 924 c_val = 3 925 926 # Compute some values. 927 r1 = ragged_math_ops.reduce_sum(a * b, axis=1) 928 r2 = ragged_math_ops.reduce_sum(a + c, axis=1) 929 930 with self.test_session() as session: 931 handle = session.partial_run_setup([r1, r2], [a, b, c]) 932 933 res1 = session.partial_run(handle, r1, feed_dict={a: a_val, b: b_val}) 934 self.assertAllEqual(res1, [22, 8]) 935 936 res2 = session.partial_run(handle, r2, feed_dict={c: c_val}) 937 self.assertAllEqual(res2, [15, 7]) 938 939 # Test case for GitHub issue 24679. 
940 def testEagerForLoop(self): 941 if not context.executing_eagerly(): 942 return 943 944 values = [[1., 2.], [3., 4., 5.], [6.]] 945 r = ragged_factory_ops.constant(values) 946 i = 0 947 for elem in r: 948 self.assertAllEqual(elem, values[i]) 949 i += 1 950 951 def testConsumers(self): 952 if context.executing_eagerly(): 953 return 954 955 a = RaggedTensor.from_row_splits( 956 array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'), 957 array_ops.placeholder(dtypes.int64, name='a.row_splits'), 958 validate=False) 959 ragged_math_ops.reduce_sum(a) 960 self.assertLen(a.consumers(), 1) 961 962 @parameterized.parameters([ 963 { 964 'descr': 'from_value_rowids', 965 'factory': RaggedTensor.from_value_rowids, 966 'test': RaggedTensor.value_rowids, 967 'values': { 968 'values': [1, 2, 3, 4, 5, 6], 969 'value_rowids': [0, 0, 1, 1, 2, 2], 970 }, 971 'tensor_field': 'value_rowids', 972 'value_rowids': [0, 1, 2], 973 'nrows': 10 974 }, 975 { 976 'descr': 'from_row_splits', 977 'factory': RaggedTensor.from_row_splits, 978 # row_splits is a property, not a function. 
979 'test': (lambda rt: rt.row_splits), 980 'values': { 981 'values': [1, 2, 3, 4, 5, 6], 982 'row_splits': [0, 2, 4, 6], 983 }, 984 'tensor_field': 'row_splits', 985 'row_splits': [0, 1, 2, 3] 986 }, 987 { 988 'descr': 'from_row_lengths', 989 'factory': RaggedTensor.from_row_lengths, 990 'test': RaggedTensor.row_lengths, 991 'values': { 992 'values': [1, 2, 3, 4, 5, 6], 993 'row_lengths': [2, 2, 2], 994 }, 995 'tensor_field': 'row_lengths', 996 'row_lengths': [1, 1, 1], 997 }, 998 # from_row_starts 999 { 1000 'descr': 'from_row_starts', 1001 'factory': RaggedTensor.from_row_starts, 1002 'test': RaggedTensor.row_starts, 1003 'values': { 1004 'values': [1, 2, 3, 4, 5, 6], 1005 'row_starts': [0, 2, 4] 1006 }, 1007 'tensor_field': 'row_starts', 1008 'row_starts': [0, 1, 2] 1009 }, 1010 # from_row_limits 1011 { 1012 'descr': 'from_row_limits', 1013 'factory': RaggedTensor.from_row_limits, 1014 'test': RaggedTensor.row_limits, 1015 'values': { 1016 'values': [1, 2, 3, 4, 5, 6], 1017 'row_limits': [2, 4, 6] 1018 }, 1019 'tensor_field': 'row_limits', 1020 'row_limits': [3] 1021 }, 1022 # from_uniform_row_length 1023 { 1024 'descr': 'from_uniform_row_length', 1025 'factory': RaggedTensor.from_uniform_row_length, 1026 # One cannot extract uniform_row_length or nvals, so we return 1027 # nvals//nrows = uniform_row_length, where nvals = 3 1028 'test': (lambda rt: 3 // (rt.shape[0])), 1029 'values': { 1030 'values': [1, 2, 3, 4, 5, 6], 1031 'uniform_row_length': 2 1032 }, 1033 'tensor_field': 'uniform_row_length', 1034 'uniform_row_length': 3 1035 }, 1036 ]) 1037 def testFactoryTypePreference(self, descr, test, factory, values, 1038 tensor_field, **kwargs): 1039 # When input tensors have shape information, some of these errors will be 1040 # detected statically. 
1041 def op_cast(k, v): 1042 if k == tensor_field: 1043 return constant_op.constant(v, dtype=dtypes.int32) 1044 else: 1045 return v 1046 1047 value_copy = {k: op_cast(k, v) for k, v in values.items()} 1048 rt = factory(**value_copy) 1049 1050 kw_copy = {k: v for k, v in kwargs.items()} 1051 kw_copy['values'] = rt 1052 rt2 = factory(**kw_copy) 1053 self.assertAllEqual(kwargs[tensor_field], test(rt2)) 1054 1055 @parameterized.parameters([ 1056 # from_value_rowids 1057 { 1058 'descr': 'bad rank for value_rowids', 1059 'factory': RaggedTensor.from_value_rowids, 1060 'values': [[1, 2], [3, 4]], 1061 'value_rowids': [[1, 2], [3, 4]], 1062 'nrows': 10 1063 }, 1064 { 1065 'descr': 'bad rank for nrows', 1066 'factory': RaggedTensor.from_value_rowids, 1067 'values': [1, 2, 3, 4], 1068 'value_rowids': [1, 2, 3, 4], 1069 'nrows': [10] 1070 }, 1071 { 1072 'descr': 'len(values) != len(value_rowids)', 1073 'factory': RaggedTensor.from_value_rowids, 1074 'values': [1, 2, 3, 4], 1075 'value_rowids': [1, 2, 3, 4, 5], 1076 'nrows': 10 1077 }, 1078 { 1079 'descr': 'negative value_rowid', 1080 'factory': RaggedTensor.from_value_rowids, 1081 'values': [1, 2, 3, 4], 1082 'value_rowids': [-5, 2, 3, 4], 1083 'nrows': 10 1084 }, 1085 { 1086 'descr': 'non-monotonic-increasing value_rowid', 1087 'factory': RaggedTensor.from_value_rowids, 1088 'values': [1, 2, 3, 4], 1089 'value_rowids': [4, 3, 2, 1], 1090 'nrows': 10 1091 }, 1092 { 1093 'descr': 'value_rowid > nrows', 1094 'factory': RaggedTensor.from_value_rowids, 1095 'values': [1, 2, 3, 4], 1096 'value_rowids': [1, 2, 3, 4], 1097 'nrows': 2 1098 }, 1099 { 1100 'descr': 'bad rank for values', 1101 'factory': RaggedTensor.from_value_rowids, 1102 'values': 10, 1103 'value_rowids': [1, 2, 3, 4], 1104 'nrows': 10 1105 }, 1106 1107 # from_row_splits 1108 { 1109 'descr': 'bad rank for row_splits', 1110 'factory': RaggedTensor.from_row_splits, 1111 'values': [[1, 2], [3, 4]], 1112 'row_splits': [[1, 2], [3, 4]] 1113 }, 1114 { 1115 'descr': 
'row_splits[0] != 0', 1116 'factory': RaggedTensor.from_row_splits, 1117 'values': [1, 2, 3, 4], 1118 'row_splits': [2, 3, 4] 1119 }, 1120 { 1121 'descr': 'non-monotonic-increasing row_splits', 1122 'factory': RaggedTensor.from_row_splits, 1123 'values': [1, 2, 3, 4], 1124 'row_splits': [0, 3, 2, 4] 1125 }, 1126 { 1127 'descr': 'row_splits[0] != nvals', 1128 'factory': RaggedTensor.from_row_splits, 1129 'values': [1, 2, 3, 4], 1130 'row_splits': [0, 2, 3, 5] 1131 }, 1132 { 1133 'descr': 'bad rank for values', 1134 'factory': RaggedTensor.from_row_splits, 1135 'values': 10, 1136 'row_splits': [0, 1] 1137 }, 1138 1139 # from_row_lengths 1140 { 1141 'descr': 'bad rank for row_lengths', 1142 'factory': RaggedTensor.from_row_lengths, 1143 'values': [1, 2, 3, 4], 1144 'row_lengths': [[1, 2], [1, 0]] 1145 }, 1146 { 1147 'descr': 'negatve row_lengths', 1148 'factory': RaggedTensor.from_row_lengths, 1149 'values': [1, 2, 3, 4], 1150 'row_lengths': [3, -1, 2] 1151 }, 1152 { 1153 'descr': 'sum(row_lengths) != nvals', 1154 'factory': RaggedTensor.from_row_lengths, 1155 'values': [1, 2, 3, 4], 1156 'row_lengths': [2, 4, 2, 8] 1157 }, 1158 { 1159 'descr': 'bad rank for values', 1160 'factory': RaggedTensor.from_row_lengths, 1161 'values': 10, 1162 'row_lengths': [0, 1] 1163 }, 1164 1165 # from_row_starts 1166 { 1167 'descr': 'bad rank for row_starts', 1168 'factory': RaggedTensor.from_row_starts, 1169 'values': [[1, 2], [3, 4]], 1170 'row_starts': [[1, 2], [3, 4]] 1171 }, 1172 { 1173 'descr': 'row_starts[0] != 0', 1174 'factory': RaggedTensor.from_row_starts, 1175 'values': [1, 2, 3, 4], 1176 'row_starts': [2, 3, 4] 1177 }, 1178 { 1179 'descr': 'non-monotonic-increasing row_starts', 1180 'factory': RaggedTensor.from_row_starts, 1181 'values': [1, 2, 3, 4], 1182 'row_starts': [0, 3, 2, 4] 1183 }, 1184 { 1185 'descr': 'row_starts[0] > nvals', 1186 'factory': RaggedTensor.from_row_starts, 1187 'values': [1, 2, 3, 4], 1188 'row_starts': [0, 2, 3, 5] 1189 }, 1190 { 1191 'descr': 'bad 
rank for values', 1192 'factory': RaggedTensor.from_row_starts, 1193 'values': 10, 1194 'row_starts': [0, 1] 1195 }, 1196 1197 # from_row_limits 1198 { 1199 'descr': 'bad rank for row_limits', 1200 'factory': RaggedTensor.from_row_limits, 1201 'values': [[1, 2], [3, 4]], 1202 'row_limits': [[1, 2], [3, 4]] 1203 }, 1204 { 1205 'descr': 'row_limits[0] < 0', 1206 'factory': RaggedTensor.from_row_limits, 1207 'values': [1, 2, 3, 4], 1208 'row_limits': [-1, 3, 4] 1209 }, 1210 { 1211 'descr': 'non-monotonic-increasing row_limits', 1212 'factory': RaggedTensor.from_row_limits, 1213 'values': [1, 2, 3, 4], 1214 'row_limits': [0, 3, 2, 4] 1215 }, 1216 { 1217 'descr': 'row_limits[0] != nvals', 1218 'factory': RaggedTensor.from_row_limits, 1219 'values': [1, 2, 3, 4], 1220 'row_limits': [0, 2, 3, 5] 1221 }, 1222 { 1223 'descr': 'bad rank for values', 1224 'factory': RaggedTensor.from_row_limits, 1225 'values': 10, 1226 'row_limits': [0, 1] 1227 }, 1228 1229 # from_uniform_row_length 1230 { 1231 'descr': 'rowlen * nrows != nvals (1)', 1232 'factory': RaggedTensor.from_uniform_row_length, 1233 'values': [1, 2, 3, 4, 5], 1234 'uniform_row_length': 3 1235 }, 1236 { 1237 'descr': 'rowlen * nrows != nvals (2)', 1238 'factory': RaggedTensor.from_uniform_row_length, 1239 'values': [1, 2, 3, 4, 5], 1240 'uniform_row_length': 6 1241 }, 1242 { 1243 'descr': 'rowlen * nrows != nvals (3)', 1244 'factory': RaggedTensor.from_uniform_row_length, 1245 'values': [1, 2, 3, 4, 5, 6], 1246 'uniform_row_length': 3, 1247 'nrows': 3 1248 }, 1249 { 1250 'descr': 'rowlen must be a scalar', 1251 'factory': RaggedTensor.from_uniform_row_length, 1252 'values': [1, 2, 3, 4], 1253 'uniform_row_length': [2] 1254 }, 1255 { 1256 'descr': 'rowlen must be nonnegative', 1257 'factory': RaggedTensor.from_uniform_row_length, 1258 'values': [1, 2, 3, 4], 1259 'uniform_row_length': -1 1260 }, 1261 ]) 1262 def testFactoryValidation(self, descr, factory, **kwargs): 1263 # When input tensors have shape information, 
some of these errors will be 1264 # detected statically. 1265 with self.assertRaises((errors.InvalidArgumentError, ValueError)): 1266 self.evaluate(factory(**kwargs)) 1267 1268 # Remove shape information (by wrapping tensors in placeholders), and check 1269 # that we detect the errors when the graph is run. 1270 if not context.executing_eagerly(): 1271 1272 def wrap_arg(v): 1273 return array_ops.placeholder_with_default( 1274 constant_op.constant(v, dtype=dtypes.int64), 1275 tensor_shape.TensorShape(None)) 1276 1277 kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items()) 1278 1279 with self.assertRaises(errors.InvalidArgumentError): 1280 self.evaluate(factory(**kwargs)) 1281 1282 #============================================================================= 1283 # RaggedTensor Variant conversion 1284 #============================================================================= 1285 1286 @parameterized.named_parameters( 1287 { 1288 'testcase_name': 'Shape_5_none', 1289 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 1290 'ragged_rank': 1 1291 }, { 1292 'testcase_name': 'Shape_4_none_2', 1293 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 1294 'ragged_rank': 1 1295 }, { 1296 'testcase_name': 'Shape_1_none_none', 1297 'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]], 1298 'ragged_rank': 2 1299 }) 1300 def testRaggedToVariant(self, ragged_constant, ragged_rank): 1301 rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank) 1302 et = rt._to_variant() 1303 self.assertEqual(et.shape.as_list(), []) 1304 self.assertEqual(et.dtype, dtypes.variant) 1305 1306 @parameterized.parameters( 1307 { 1308 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 1309 'ragged_rank': 1, 1310 'num_batched_elems': 5 1311 }, { 1312 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 1313 'ragged_rank': 1, 1314 'num_batched_elems': 4 1315 }, { 1316 'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]], 1317 'ragged_rank': 2, 1318 'num_batched_elems': 2 1319 }) 
1320 def testRaggedToBatchedVariant(self, ragged_constant, ragged_rank, 1321 num_batched_elems): 1322 rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank) 1323 et = rt._to_variant(batched_input=True) 1324 self.assertEqual(et.shape.as_list(), [num_batched_elems]) 1325 self.assertEqual(et.dtype, dtypes.variant) 1326 1327 @parameterized.parameters( 1328 # 2D test cases. 1329 { 1330 'ragged_constant': [[]], 1331 'ragged_rank': 1, 1332 }, 1333 { 1334 'ragged_constant': [[1]], 1335 'ragged_rank': 1, 1336 }, 1337 { 1338 'ragged_constant': [[1, 2]], 1339 'ragged_rank': 1, 1340 }, 1341 { 1342 'ragged_constant': [[1], [2], [3]], 1343 'ragged_rank': 1, 1344 }, 1345 { 1346 'ragged_constant': [[1, 2, 3], [4, 5, 6], [7, 8, 9]], 1347 'ragged_rank': 1, 1348 }, 1349 { 1350 'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]], 1351 'ragged_rank': 1, 1352 }, 1353 # 3D test cases. 1354 { 1355 'ragged_constant': [[[]]], 1356 'ragged_rank': 2, 1357 }, 1358 { 1359 'ragged_constant': [[[1]]], 1360 'ragged_rank': 2, 1361 }, 1362 { 1363 'ragged_constant': [[[1, 2]]], 1364 'ragged_rank': 2, 1365 }, 1366 { 1367 'ragged_constant': [[[1, 2], [3, 4]]], 1368 'ragged_rank': 2, 1369 }, 1370 { 1371 'ragged_constant': [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]], 1372 'ragged_rank': 2, 1373 }, 1374 { 1375 'ragged_constant': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]], 1376 'ragged_rank': 2, 1377 }, 1378 { 1379 'ragged_constant': [[[1, 2]], [], [[3, 4]], []], 1380 'ragged_rank': 2, 1381 }, 1382 # 4D test cases. 1383 { 1384 'ragged_constant': [[[[1, 2], [3, 4]]], 1385 [[[0, 0], [0, 0]], [[5, 6], [7, 8]]], []], 1386 'ragged_rank': 3, 1387 }, 1388 # dtype `string`. 
1389 { 1390 'ragged_constant': [['a'], ['b'], ['c']], 1391 'ragged_rank': 1, 1392 'dtype': dtypes.string, 1393 }, 1394 { 1395 'ragged_constant': [[['a', 'b'], ['c', 'd']]], 1396 'ragged_rank': 2, 1397 'dtype': dtypes.string, 1398 }, 1399 { 1400 'ragged_constant': [[[['a', 'b'], ['c', 'd']]], 1401 [[['e', 'f'], ['g', 'h']], [['i', 'j'], 1402 ['k', 'l']]], []], 1403 'ragged_rank': 3, 1404 'dtype': dtypes.string, 1405 }) 1406 def testVariantRoundTrip(self, 1407 ragged_constant, 1408 ragged_rank, 1409 dtype=dtypes.int32): 1410 rt = ragged_factory_ops.constant( 1411 ragged_constant, ragged_rank=ragged_rank, dtype=dtype) 1412 et = rt._to_variant() 1413 round_trip_rt = RaggedTensor._from_variant( 1414 et, dtype, output_ragged_rank=ragged_rank) 1415 self.assertAllEqual(rt, round_trip_rt) 1416 1417 def testBatchedVariantRoundTripInputRaggedRankInferred(self): 1418 ragged_rank = 1 1419 rt = ragged_factory_ops.constant( 1420 [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]], 1421 ragged_rank=ragged_rank) 1422 batched_variant = rt._to_variant(batched_input=True) 1423 nested_batched_variant = array_ops.reshape(batched_variant, [5, 2]) 1424 decoded_rt = RaggedTensor._from_variant( 1425 nested_batched_variant, 1426 dtype=dtypes.int32, 1427 output_ragged_rank=ragged_rank + 1) 1428 expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4], 1429 [5]], 1430 [[6], [7]], [[8], [9]]]) 1431 self.assertAllEqual(decoded_rt, expected_rt) 1432 1433 def testBatchedVariantRoundTripWithInputRaggedRank(self): 1434 ragged_rank = 1 1435 rt = ragged_factory_ops.constant( 1436 [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]], 1437 ragged_rank=ragged_rank) 1438 batched_variant = rt._to_variant(batched_input=True) 1439 nested_batched_variant = array_ops.reshape(batched_variant, [5, 2]) 1440 decoded_rt = RaggedTensor._from_variant( 1441 nested_batched_variant, 1442 dtype=dtypes.int32, 1443 output_ragged_rank=ragged_rank + 1, 1444 input_ragged_rank=ragged_rank - 1) 1445 expected_rt = 
ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4], 1446 [5]], 1447 [[6], [7]], [[8], [9]]]) 1448 self.assertAllEqual(decoded_rt, expected_rt) 1449 1450 def testUnbatchVariant(self): # b/141789000 1451 rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]) 1452 batched = rt._to_variant(batched_input=True) 1453 for i in range(4): 1454 row = RaggedTensor._from_variant( 1455 batched[i], dtype=dtypes.int32, output_ragged_rank=0) 1456 self.assertAllEqual(rt[i], row) 1457 1458 def testUnbatchVariantInDataset(self): 1459 rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]) 1460 ds = dataset_ops.Dataset.from_tensor_slices(rt) 1461 if context.executing_eagerly(): 1462 for i, value in enumerate(ds): 1463 self.assertAllEqual(rt[i], value) 1464 else: 1465 it = dataset_ops.make_one_shot_iterator(ds) 1466 out = it.get_next() 1467 with self.cached_session() as sess: 1468 for i in range(3): 1469 self.assertAllEqual(sess.run(rt[i]), out) 1470 1471 def testToVariantInvalidParams(self): 1472 self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), 1473 r'be rank 1 but is rank 0', 1474 gen_ragged_conversion_ops.ragged_tensor_to_variant, 1475 rt_nested_splits=[0, 1, 2], 1476 rt_dense_values=[0, 1, 2], 1477 batched_input=True) 1478 1479 self.assertRaisesRegex((ValueError, errors.InvalidArgumentError), 1480 r'be rank 1 but is rank 2', 1481 gen_ragged_conversion_ops.ragged_tensor_to_variant, 1482 rt_nested_splits=[[[0]], [[1]], [[2]]], 1483 rt_dense_values=[0, 1, 2], 1484 batched_input=True) 1485 1486 def testFromVariantInvalidParams(self): 1487 rt = ragged_factory_ops.constant([[0], [1], [2], [3]]) 1488 batched_variant = rt._to_variant(batched_input=True) 1489 nested_batched_variant = array_ops.reshape(batched_variant, [2, 2]) 1490 with self.assertRaisesRegex(ValueError, 1491 r'`output_ragged_rank` \(1\) must be equal to'): 1492 RaggedTensor._from_variant( 1493 nested_batched_variant, 1494 dtype=dtypes.int32, 1495 output_ragged_rank=1, 
1496 input_ragged_rank=1) 1497 1498 def testUnbatchToTensor(self): 1499 batched = ragged_factory_ops.constant([[0], [1], [2], [3]]) 1500 unbatched = [constant_op.constant(x) for x in [[0], [1], [2], [3]]] 1501 batched_spec = type_spec.type_spec_from_value(batched) 1502 1503 # Note that the unbatched_spec is derived from the batched spec, so it can 1504 # add back a ragged instead of a dense tensor. 1505 unbatched_spec = batched_spec._unbatch() 1506 batched_tensor_list = batched_spec._to_batched_tensor_list(batched) 1507 unbatched_tensor_lists = zip( 1508 *[array_ops.unstack(tensor) for tensor in batched_tensor_list]) 1509 actual_unbatched = [ 1510 batched_spec._unbatch()._from_tensor_list(tensor_list) 1511 for tensor_list in unbatched_tensor_lists] 1512 self.assertLen(actual_unbatched, len(unbatched)) 1513 for x in actual_unbatched: 1514 self.assertTrue(unbatched_spec.is_compatible_with(x)) 1515 1516 for (actual, expected) in zip(actual_unbatched, unbatched): 1517 self.assertAllEqual(actual, expected) 1518 1519 def testDatasetUnbatchTwice(self): 1520 batched = ragged_factory_ops.constant([[[0], [1], [5]], [[2], [3]]]) 1521 ds = dataset_ops.Dataset.from_tensors(batched) 1522 ds2 = ds.unbatch() 1523 ds3 = ds2.unbatch() 1524 if context.executing_eagerly(): 1525 value = next(iter(ds3)) 1526 self.assertAllEqual([0], value) 1527 1528 def testDatasetUnbatchToScalar(self): 1529 batched = ragged_factory_ops.constant([[0], [1], [2], [3]]) 1530 ds = dataset_ops.Dataset.from_tensors(batched) 1531 ds2 = ds.unbatch() 1532 ds3 = ds2.unbatch() 1533 if context.executing_eagerly(): 1534 value = next(iter(ds3)) 1535 self.assertAllEqual(0, value) 1536 1537 def testBatchToTensor(self): 1538 batched = ragged_factory_ops.constant([[0], [1], [2], [3]]) 1539 unbatched = [constant_op.constant(x) for x in [[0], [1], [2], [3]]] 1540 batched_spec = type_spec.type_spec_from_value(batched) 1541 1542 # Note that the unbatched_spec is derived from the batched spec, so it can 1543 # add back a 
ragged instead of a dense tensor. 1544 unbatched_spec = batched_spec._unbatch() 1545 unbatched_tensor_lists = [unbatched_spec._to_tensor_list(x) 1546 for x in unbatched] 1547 batched_tensor_list = [array_ops.stack(tensors) 1548 for tensors in zip(*unbatched_tensor_lists)] 1549 actual_batched = unbatched_spec._batch(4)._from_tensor_list( 1550 batched_tensor_list) 1551 self.assertAllEqual(actual_batched, batched) 1552 1553 def _testGradient(self, func, x, expected_grad, grad_y=None): 1554 x = ragged_factory_ops.constant(x) 1555 if grad_y is not None: 1556 grad_y = ragged_factory_ops.constant(grad_y) 1557 if context.executing_eagerly(): 1558 with backprop.GradientTape() as t: 1559 t.watch(x) 1560 y = func(x) 1561 g = t.gradient(y, x, grad_y) 1562 else: 1563 y = func(x) 1564 g = gradients_impl.gradients(ys=y, xs=x, grad_ys=grad_y)[0] 1565 if expected_grad is None: 1566 self.assertIsNone(g) 1567 else: 1568 g = ragged_tensor.convert_to_tensor_or_ragged_tensor(g) 1569 self.assertAllClose(g, expected_grad) 1570 1571 @parameterized.named_parameters([ 1572 dict( 1573 testcase_name='RaggedInput', 1574 func=lambda x: math_ops.reduce_prod(x, axis=1), 1575 x=[[1., 2.], [3.]], 1576 expected=[[2., 1.], [1.]]), 1577 dict( 1578 testcase_name='RaggedOutput', 1579 func=lambda x: ragged_concat_ops.stack([x, x[:1]]), 1580 x=[3., 2.], 1581 expected=[2., 1.]), 1582 dict( 1583 testcase_name='RaggedInputAndOutput', 1584 func=lambda x: array_ops.stack([x, x * x]), 1585 x=[[1., 2.], [3.]], 1586 expected=[[3., 5.], [7.]]), 1587 dict( 1588 testcase_name='RaggedOutputWithGradYs', 1589 func=lambda x: ragged_concat_ops.stack([x, x[:1]]), 1590 x=[3., 2.], 1591 grad_ys=[[1., 1.], [1.]], 1592 expected=[2., 1.]), 1593 dict( 1594 testcase_name='RaggedInputAndOutputWithGradYs', 1595 func=lambda x: array_ops.stack([x, x * x]), 1596 x=[[1., 2.], [3.]], 1597 grad_ys=[[[1., 1.], [1.]], [[1., 1.], [1.]]], 1598 expected=[[3., 5.], [7.]]), 1599 dict( 1600 testcase_name='RaggedRank3', 1601 func=lambda x: 
ragged_concat_ops.stack([x, (x * x)[:, 1:]]), 1602 x=[[[1., 2.], [3., 4., 5.]], [[6.]]], 1603 expected=[[[1.0, 1.0], [7.0, 9.0, 11.0]], [[1.0]]]), 1604 dict( 1605 testcase_name='RaggedIndexedSlices', 1606 func=lambda x: ragged_gather_ops.gather(x, [0, 2]), 1607 x=[[1., 2.], [3.], [4., 5., 6.]], 1608 expected=[[1., 1.], [0.], [1., 1., 1.]]), 1609 ]) 1610 def testGradient(self, func, x, expected, grad_ys=None): 1611 self._testGradient(func, x, expected, grad_ys) 1612 1613 def testHigherOrderGradient(self): 1614 x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]]) 1615 1616 with backprop.GradientTape() as t2: 1617 t2.watch(x) 1618 with backprop.GradientTape() as t1: 1619 t1.watch(x) 1620 y = x * x * x 1621 dy_dx = t1.gradient(y, x) 1622 d2y_dx2 = t2.gradient(dy_dx, x) 1623 1624 self.assertAllEqual(dy_dx, [[3.0, 12.0], [27.0]]) 1625 self.assertAllEqual(d2y_dx2, [[6.0, 12.0], [18.0]]) 1626 1627 def testUnconnectedGradient(self): 1628 x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]]) 1629 1630 with backprop.GradientTape() as t: 1631 t.watch(x) 1632 y = ragged_factory_ops.constant([[2.0, 4.0], [6.0]]) 1633 self.assertIsNone(t.gradient(y, x)) 1634 1635 def testStopGradient(self): 1636 1637 def func(x): 1638 y = x * constant_op.constant([[1.], [3.]]) 1639 y = y.with_values(array_ops.stop_gradient(y.values)) 1640 z = x * y 1641 return math_ops.reduce_sum(z) 1642 1643 self._testGradient(func, [[1., 2.], [3., 4., 5.]], 1644 [[1., 2.], [9., 12., 15.]]) 1645 1646 def testStopGradientNoneComponent(self): 1647 1648 def func(x): 1649 y = x * constant_op.constant([[1.], [3.]]) 1650 y = y.with_values(array_ops.stop_gradient(y.values)) 1651 return y 1652 1653 self._testGradient(func, [[1., 2], [3, 4, 5]], None) 1654 1655 def testRaggedVariantGradients(self): 1656 1657 def func(x): 1658 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1659 rt2 = rt1 * [[10], [100], [1000]] 1660 v = rt2._to_variant(batched_input=False) 1661 rt3 = RaggedTensor._from_variant(v, 
dtype=rt2.dtype, output_ragged_rank=1) 1662 return rt3.flat_values 1663 1664 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1665 [10., 10., 10., 10., 100., 100., 100., 1000.]) 1666 1667 def testRaggedVariantGradientsEmptyRows(self): 1668 1669 def func(x): 1670 rt1 = RaggedTensor.from_row_splits( 1671 values=x, row_splits=[0, 2, 2, 4, 7, 7, 8]) 1672 rt2 = rt1 * [[10], [20], [30], [40], [50], [60]] 1673 v = rt2._to_variant(batched_input=False) 1674 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1675 return rt3.flat_values 1676 1677 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1678 [10., 10., 30., 30., 40., 40., 40., 60.]) 1679 1680 def testRaggedVariantSteps(self): 1681 x = [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0] 1682 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1683 rt2 = rt1 * [[10], [100], [1000]] 1684 v = rt2._to_variant(batched_input=False) 1685 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1686 self.assertAllClose([30., 10., 40., 10., 100., 0., 200., 1000.], 1687 rt3.flat_values) 1688 1689 def testRaggedVariantGradientsBatched(self): 1690 1691 def func(x): 1692 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1693 rt2 = rt1 * [[10], [100], [1000]] 1694 v = rt2._to_variant(batched_input=True) 1695 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1696 return rt3.flat_values 1697 1698 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1699 [10., 10., 10., 10., 100., 100., 100., 1000.]) 1700 1701 def testRaggedVariantGradientsEmptyRowsBatched(self): 1702 1703 def func(x): 1704 rt1 = RaggedTensor.from_row_splits( 1705 values=x, row_splits=[0, 2, 2, 4, 7, 7, 8]) 1706 rt2 = rt1 * [[10], [20], [30], [40], [50], [60]] 1707 v = rt2._to_variant(batched_input=True) 1708 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1709 return rt3.flat_values 1710 1711 
self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1712 [10., 10., 30., 30., 40., 40., 40., 60.]) 1713 1714 def testRaggedVariantGradientsEmptyOutputBatched(self): 1715 1716 def func(x): 1717 rt1 = RaggedTensor.from_row_splits( 1718 values=x, row_splits=[0, 0, 0, 0, 0, 0, 0]) 1719 rt2 = rt1 * [[10], [20], [30], [40], [50], [60]] 1720 v = rt2._to_variant(batched_input=True) 1721 rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1) 1722 return rt3.flat_values 1723 1724 self._testGradient(func, [], []) 1725 1726 def testRaggedVariantGradientsBatchedAndSliced(self): 1727 1728 def func(x, i): 1729 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1730 rt2 = rt1 * [[10], [100], [1000]] 1731 v_slice = rt2._to_variant(batched_input=True)[i] 1732 return RaggedTensor._from_variant( 1733 v_slice, dtype=rt2.dtype, output_ragged_rank=0) 1734 1735 self._testGradient( 1736 functools.partial(func, i=0), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1737 [10., 10., 10., 10., 0., 0., 0., 0.]) 1738 self._testGradient( 1739 functools.partial(func, i=1), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1740 [0., 0., 0., 0., 100., 100., 100., 0.]) 1741 self._testGradient( 1742 functools.partial(func, i=2), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1743 [0., 0., 0., 0., 0., 0., 0., 1000.]) 1744 1745 def testRaggedVariantGradientsEmptyRowsBatchedAndSliced(self): 1746 1747 def func(x, i): 1748 rt1 = RaggedTensor.from_row_splits( 1749 values=x, row_splits=[0, 2, 2, 4, 7, 7, 8]) 1750 rt2 = rt1 * [[10], [20], [30], [40], [50], [60]] 1751 v_slice = rt2._to_variant(batched_input=True)[i] 1752 return RaggedTensor._from_variant( 1753 v_slice, dtype=rt2.dtype, output_ragged_rank=0) 1754 1755 self._testGradient( 1756 functools.partial(func, i=0), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1757 [10., 10., 0., 0., 0., 0., 0., 0.]) 1758 self._testGradient( 1759 functools.partial(func, i=1), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1760 [0., 0., 0., 0., 
0., 0., 0., 0.]) 1761 self._testGradient( 1762 functools.partial(func, i=2), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1763 [0., 0., 30., 30., 0., 0., 0., 0.]) 1764 self._testGradient( 1765 functools.partial(func, i=3), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1766 [0., 0., 0., 0., 40., 40., 40., 0.]) 1767 self._testGradient( 1768 functools.partial(func, i=4), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1769 [0., 0., 0., 0., 0., 0., 0., 0.]) 1770 self._testGradient( 1771 functools.partial(func, i=5), [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1772 [0., 0., 0., 0., 0., 0., 0., 60.]) 1773 1774 def testRaggedVariantGradientsRaggedRank0(self): 1775 1776 def func(x): 1777 x2 = x * 2 1778 v = gen_ragged_conversion_ops.ragged_tensor_to_variant( 1779 [], x2, batched_input=False) 1780 return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0) 1781 1782 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1783 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) 1784 1785 def testRaggedVariantGradientsRaggedRank3(self): 1786 1787 def func(x): 1788 x2 = x * 2 1789 rt1 = RaggedTensor.from_nested_row_splits( 1790 x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8])) 1791 v = rt1._to_variant(batched_input=False) 1792 rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3) 1793 return rt3.flat_values 1794 1795 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1796 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]) 1797 1798 def testRaggedVariantGradientsViaMapFn(self): 1799 rt = RaggedTensor.from_row_splits( 1800 values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8]) 1801 1802 def func(x): 1803 1804 def transform_row(row): 1805 return math_ops.sqrt( 1806 math_ops.reduce_mean(math_ops.square(row * x), keepdims=True)) 1807 1808 return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt)) 1809 1810 self._testGradient(func, 3.0, 14.653377) 1811 1812 def testRaggedVariantGradientsEmptyRowsViaMapFn(self): 1813 rt = RaggedTensor.from_row_splits( 
1814 values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 2, 2, 4, 7, 7, 8]) 1815 1816 def func(x): 1817 1818 def transform_row(row): 1819 return math_ops.sqrt( 1820 math_ops.reduce_mean(math_ops.square(row * x), keepdims=True)) 1821 1822 return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt)) 1823 1824 self._testGradient(func, 3.0, 17.206844) 1825 1826 def testRaggedVariantGradientsEmptyOutputViaMapFn(self): 1827 rt = RaggedTensor.from_row_splits( 1828 values=[], row_splits=[0, 0, 0, 0]) 1829 1830 def func(x): 1831 1832 def transform_row(row): 1833 return math_ops.sqrt( 1834 math_ops.reduce_mean(math_ops.square(row * x), keepdims=True)) 1835 1836 return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt)) 1837 1838 self._testGradient(func, 3.0, 0.0) 1839 1840 def testRaggedVariantGradientsViaMapFnReduce(self): 1841 1842 def func(x): 1843 rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8]) 1844 return map_fn.map_fn( 1845 math_ops.reduce_max, 1846 rt1, 1847 fn_output_signature=tensor_spec.TensorSpec((), x.dtype)) 1848 1849 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1850 [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) 1851 1852 def testRaggedVariantGradientsEmptyRowsViaMapFnReduce(self): 1853 1854 def func(x): 1855 rt1 = RaggedTensor.from_row_splits( 1856 values=x, row_splits=[0, 2, 2, 4, 7, 7, 8]) 1857 return map_fn.map_fn( 1858 math_ops.reduce_max, 1859 rt1, 1860 fn_output_signature=tensor_spec.TensorSpec((), x.dtype)) 1861 1862 self._testGradient(func, [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0], 1863 [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]) 1864 1865 def testRaggedVariantGradientsEmptyOutputViaMapFnReduce(self): 1866 1867 def func(x): 1868 rt1 = RaggedTensor.from_row_splits( 1869 values=x, row_splits=[0, 0, 0, 0]) 1870 return map_fn.map_fn( 1871 math_ops.reduce_max, 1872 rt1, 1873 fn_output_signature=tensor_spec.TensorSpec((), x.dtype)) 1874 1875 self._testGradient(func, [], []) 1876 1877 def 
testRaggedVariantGradientsErrors(self): 1878 if context.executing_eagerly(): 1879 return 1880 1881 rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2]) 1882 v1 = rt._to_variant() 1883 v2 = array_ops.stack([array_ops.stack([v1])]) 1884 y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3) 1885 1886 with self.assertRaisesRegex( 1887 ValueError, 'Unable to compute gradient: RaggedTensorToVariant ' 1888 'can currently only generate 0D or 1D output.'): 1889 gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values) 1890 1891 def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg): 1892 """Check that two numpy arrays are equal. 1893 1894 For arrays with dtype=object, check values recursively to see if a and b 1895 are equal. (c.f. `np.array_equal`, which checks dtype=object values using 1896 object identity.) 1897 1898 Args: 1899 a: A numpy array. 1900 b: A numpy array. 1901 msg: Message to display if a != b. 1902 """ 1903 if isinstance(a, np.ndarray) and a.dtype == object: 1904 self.assertEqual(a.dtype, b.dtype, msg) 1905 self.assertEqual(a.shape, b.shape, msg) 1906 self.assertLen(a, len(b), msg) 1907 for a_val, b_val in zip(a, b): 1908 self.assertNumpyObjectTensorsRecursivelyEqual(a_val, b_val, msg) 1909 else: 1910 self.assertAllEqual(a, b, msg) 1911 1912 @parameterized.named_parameters([ 1913 ('Shape_2_R', 1914 [[1, 2], [3, 4, 5]], 1915 np.array([int32array([1, 2]), int32array([3, 4, 5])])), 1916 ('Shape_2_2', 1917 [[1, 2], [3, 4]], 1918 np.array([[1, 2], [3, 4]])), 1919 ('Shape_2_R_2', 1920 [[[1, 2], [3, 4]], [[5, 6]]], 1921 np.array([int32array([[1, 2], [3, 4]]), int32array([[5, 6]])])), 1922 ('Shape_3_2_R', 1923 [[[1], []], [[2, 3], [4]], [[], [5, 6, 7]]], 1924 np.array([[int32array([1]), int32array([])], 1925 [int32array([2, 3]), int32array([4])], 1926 [int32array([]), int32array([5, 6, 7])]])), 1927 ('Shape_0_R', 1928 ragged_factory_ops.constant_value([], ragged_rank=1, dtype=np.int32), 1929 np.zeros([0, 0], dtype=np.int32)), 
1930 ('Shape_0_R_2', 1931 ragged_factory_ops.constant_value([], ragged_rank=1, 1932 inner_shape=(2,), dtype=np.int32), 1933 np.zeros([0, 0, 2], dtype=np.int32)), 1934 ]) # pyformat: disable 1935 def testRaggedTensorNumpy(self, rt, expected): 1936 if isinstance(rt, list): 1937 rt = ragged_factory_ops.constant(rt, dtype=dtypes.int32) 1938 else: 1939 rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt) 1940 if context.executing_eagerly(): 1941 actual = rt.numpy() 1942 self.assertNumpyObjectTensorsRecursivelyEqual( 1943 expected, actual, 'Expected %r, got %r' % (expected, actual)) 1944 else: 1945 with self.assertRaisesRegex(ValueError, 'only supported in eager mode'): 1946 rt.numpy() 1947 1948 @parameterized.parameters([ 1949 ([[[1, 2], [3, 4, 5]], [[6]]], 2, None), 1950 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]), 1951 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]), 1952 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None), 1953 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]), 1954 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]), 1955 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]), 1956 ([[[1, 2, 3]]], 1, [1, 1, None]), 1957 ([[[1, 2, 3]]], 1, [1, 1, 3]), 1958 ]) 1959 def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape): 1960 rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank) 1961 rt1._set_shape(shape) 1962 rt1.shape.assert_is_compatible_with(shape) 1963 if shape is not None: 1964 self.assertIsNot(rt1.shape.rank, None) 1965 for a, b in zip(rt1.shape, shape): 1966 if b is not None: 1967 self.assertEqual(a, b) 1968 1969 @parameterized.parameters([ 1970 ([[[1, 2], [3, 4, 5]], [[6]]], 2, None), 1971 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]), 1972 ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]), 1973 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None), 1974 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]), 1975 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]), 
1976 ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]), 1977 ([[[1, 2, 3]]], 1, [1, 1, None]), 1978 ([[[1, 2, 3]]], 1, [1, 1, 3]), 1979 ]) 1980 def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape): 1981 rt2 = nest.map_structure( 1982 lambda x: array_ops.placeholder_with_default(x, None), 1983 ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank), 1984 expand_composites=True) 1985 rt2._set_shape(shape) 1986 rt2.shape.assert_is_compatible_with(shape) 1987 if shape is not None: 1988 self.assertIsNot(rt2.shape.rank, None) 1989 for a, b in zip(rt2.shape, shape): 1990 if b is not None: 1991 self.assertEqual(a, b) 1992 1993 def testRaggedTensorSetShapeUniformRowLength(self): 1994 rt = [[[1], [2], [3]], [[4], [5], [6]]] 1995 1996 rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1) 1997 rt1._set_shape([2, 3, 1]) 1998 1999 rt2 = nest.map_structure( 2000 lambda x: array_ops.placeholder_with_default(x, None), 2001 rt1, 2002 expand_composites=True) 2003 rt2._set_shape([2, 3, 1]) 2004 2005 def testRaggedTensorSetShapeInconsistentShapeError(self): 2006 rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]], 2007 ragged_rank=1) 2008 self.assertEqual(rt.shape.as_list(), [2, 3, 1]) 2009 with self.assertRaises(ValueError): 2010 rt._set_shape([None, None, 5]) 2011 with self.assertRaisesRegex(ValueError, 'Inconsistent size'): 2012 rt._set_shape([None, 5, None]) 2013 with self.assertRaises(ValueError): 2014 rt._set_shape([5, None, None]) 2015 2016 2017@test_util.run_all_in_graph_and_eager_modes 2018class RaggedTensorSpecTest(test_util.TensorFlowTestCase, 2019 parameterized.TestCase): 2020 2021 def assertAllTensorsEqual(self, list1, list2): 2022 self.assertLen(list1, len(list2)) 2023 for (t1, t2) in zip(list1, list2): 2024 self.assertAllEqual(t1, t2) 2025 2026 def testConstruction(self): 2027 spec1 = RaggedTensorSpec(ragged_rank=1) 2028 self.assertIsNone(spec1._shape.rank) 2029 self.assertEqual(spec1._dtype, dtypes.float32) 2030 
self.assertEqual(spec1._row_splits_dtype, dtypes.int64) 2031 self.assertEqual(spec1._ragged_rank, 1) 2032 2033 self.assertIsNone(spec1.shape.rank) 2034 self.assertEqual(spec1.dtype, dtypes.float32) 2035 self.assertEqual(spec1.row_splits_dtype, dtypes.int64) 2036 self.assertEqual(spec1.ragged_rank, 1) 2037 2038 spec2 = RaggedTensorSpec(shape=[None, None, None]) 2039 self.assertEqual(spec2._shape.as_list(), [None, None, None]) 2040 self.assertEqual(spec2._dtype, dtypes.float32) 2041 self.assertEqual(spec2._row_splits_dtype, dtypes.int64) 2042 self.assertEqual(spec2._ragged_rank, 2) 2043 2044 with self.assertRaisesRegex(ValueError, 'Must specify ragged_rank'): 2045 RaggedTensorSpec() 2046 with self.assertRaisesRegex(TypeError, '`ragged_rank` must be an int'): 2047 RaggedTensorSpec(ragged_rank=constant_op.constant(1)) 2048 with self.assertRaisesRegex( 2049 ValueError, 2050 r'Argument `ragged_rank` \(2\) must be less than rank \(2\).'): 2051 RaggedTensorSpec(ragged_rank=2, shape=[None, None]) 2052 2053 def testValueType(self): 2054 spec1 = RaggedTensorSpec(ragged_rank=1) 2055 self.assertEqual(spec1.value_type, RaggedTensor) 2056 spec2 = RaggedTensorSpec(ragged_rank=0) 2057 self.assertEqual(spec2.value_type, ops.Tensor) 2058 2059 @parameterized.parameters([ 2060 (RaggedTensorSpec(ragged_rank=1), 2061 (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64)), 2062 (RaggedTensorSpec(shape=[5, None, None]), 2063 (tensor_shape.TensorShape([5, None, None]), dtypes.float32, 2064 2, dtypes.int64)), 2065 (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.int32), 2066 (tensor_shape.TensorShape([5, None, None]), dtypes.int32, 2, 2067 dtypes.int64)), 2068 (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), 2069 (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int32)), 2070 ]) # pyformat: disable 2071 def testSerialize(self, rt_spec, expected): 2072 serialization = rt_spec._serialize() 2073 # TensorShape has an unconventional definition of 
equality, so we can't use 2074 # assertEqual directly here. But repr() is deterministic and lossless for 2075 # the expected values, so we can use that instead. 2076 self.assertEqual(repr(serialization), repr(expected)) 2077 2078 @parameterized.parameters([ 2079 (RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), [ 2080 tensor_spec.TensorSpec([5, 3], dtypes.float32), 2081 ]), 2082 (RaggedTensorSpec(ragged_rank=1), [ 2083 tensor_spec.TensorSpec(None, dtypes.float32), 2084 tensor_spec.TensorSpec([None], dtypes.int64) 2085 ]), 2086 (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), [ 2087 tensor_spec.TensorSpec(None, dtypes.float32), 2088 tensor_spec.TensorSpec([None], dtypes.int32), 2089 ]), 2090 (RaggedTensorSpec(ragged_rank=2), [ 2091 tensor_spec.TensorSpec(None, dtypes.float32), 2092 tensor_spec.TensorSpec([None], dtypes.int64), 2093 tensor_spec.TensorSpec([None], dtypes.int64), 2094 ]), 2095 (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.string), [ 2096 tensor_spec.TensorSpec([None], dtypes.string), 2097 tensor_spec.TensorSpec([6], dtypes.int64), 2098 tensor_spec.TensorSpec([None], dtypes.int64), 2099 ]), 2100 ]) 2101 def testComponentSpecs(self, rt_spec, expected): 2102 self.assertEqual(rt_spec._component_specs, expected) 2103 2104 @parameterized.parameters([ 2105 { 2106 'rt_spec': RaggedTensorSpec(ragged_rank=0), 2107 'rt': [1.0, 2.0, 3.0], 2108 'components': [[1.0, 2.0, 3.0]] 2109 }, 2110 { 2111 'rt_spec': RaggedTensorSpec(ragged_rank=1), 2112 'rt': [[1.0, 2.0], [3.0]], 2113 'components': [[1.0, 2.0, 3.0], [0, 2, 3]] 2114 }, 2115 { 2116 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 2117 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]], 2118 'components': [[1.0, 2.0, 3.0, 4.0], [0, 2, 4], [0, 2, 3, 3, 4]] 2119 }, 2120 ]) 2121 def testToFromComponents(self, rt_spec, rt, components): 2122 rt = ragged_factory_ops.constant(rt) 2123 actual_components = rt_spec._to_components(rt) 2124 self.assertAllTensorsEqual(actual_components, components) 2125 
rt_reconstructed = rt_spec._from_components(actual_components) 2126 self.assertAllEqual(rt, rt_reconstructed) 2127 2128 @parameterized.parameters([ 2129 { 2130 'flat_value_spec': tensor_spec.TensorSpec(None, dtypes.float32), 2131 'row_splits_spec': tensor_spec.TensorSpec(None, dtypes.int64), 2132 }, 2133 { 2134 'flat_value_spec': tensor_spec.TensorSpec([None,], dtypes.float32), 2135 'row_splits_spec': tensor_spec.TensorSpec(None, dtypes.int64), 2136 }, 2137 { 2138 'flat_value_spec': tensor_spec.TensorSpec(None, dtypes.float32), 2139 'row_splits_spec': tensor_spec.TensorSpec([None,], dtypes.int64), 2140 }, 2141 { 2142 'flat_value_spec': tensor_spec.TensorSpec([None,], dtypes.float32), 2143 'row_splits_spec': tensor_spec.TensorSpec([None,], dtypes.int64), 2144 }, 2145 { 2146 'flat_value_spec': tensor_spec.TensorSpec([4,], dtypes.float32), 2147 'row_splits_spec': tensor_spec.TensorSpec(None, dtypes.int64), 2148 }, 2149 { 2150 'flat_value_spec': tensor_spec.TensorSpec(None, dtypes.float32), 2151 'row_splits_spec': tensor_spec.TensorSpec([3,], dtypes.int64), 2152 }, 2153 ]) 2154 def testToFromComponentsStaticUnknownShape(self, flat_value_spec, 2155 row_splits_spec): 2156 rt_spec = RaggedTensorSpec(shape=[2, None], ragged_rank=1) 2157 tester = self 2158 2159 @def_function.function(input_signature=[flat_value_spec, row_splits_spec]) 2160 def test_fn(flat_value, row_splits): 2161 # Apply static shape information saved in rt_spec to rt. 
2162 rt = rt_spec._from_components([flat_value, row_splits]) 2163 tester.assertEqual(rt.shape.as_list(), [2, None]) 2164 return rt + ragged_factory_ops.constant([[1.0, 1.0, 1.0], [1.0]]) 2165 2166 result = test_fn([1.0, 2.0, 3.0, 4.0], [0, 3, 4]) 2167 expected_result = ragged_factory_ops.constant([[2.0, 3.0, 4.0], [5.0]]) 2168 self.assertAllEqual(result, expected_result) 2169 2170 @test_util.run_v1_only('RaggedTensorValue is deprecated in v2') 2171 def testFromNumpyComponents(self): 2172 spec1 = RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32) 2173 rt1 = spec1._from_components([np.array([1, 2, 3]), np.array([0, 2, 3])]) 2174 self.assertIsInstance(rt1, ragged_tensor_value.RaggedTensorValue) 2175 self.assertAllEqual(rt1, [[1, 2], [3]]) 2176 2177 spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32) 2178 rt2 = spec2._from_components( 2179 [np.array([1, 2, 3]), 2180 np.array([0, 2, 3]), 2181 np.array([0, 0, 2, 3])]) 2182 self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue) 2183 self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]]) 2184 2185 spec3 = RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32) 2186 rt3 = spec3._from_components([np.array([1, 2, 3])]) 2187 self.assertIsInstance(rt3, np.ndarray) 2188 self.assertAllEqual(rt3, [1, 2, 3]) 2189 2190 @parameterized.parameters([ 2191 RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), 2192 RaggedTensorSpec(ragged_rank=1), 2193 RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), 2194 RaggedTensorSpec(ragged_rank=2, dtype=dtypes.string), 2195 RaggedTensorSpec(shape=[5, None, None]), 2196 ]) 2197 def testFlatTensorSpecs(self, rt_spec): 2198 self.assertEqual(rt_spec._flat_tensor_specs, 2199 [tensor_spec.TensorSpec(None, dtypes.variant)]) 2200 2201 @parameterized.parameters([ 2202 (dtypes.float32, full_type_pb2.TFT_FLOAT), 2203 (dtypes.string, full_type_pb2.TFT_STRING), 2204 ]) 2205 def testFullTypesForFlatTensors(self, dt, ft): 2206 rt_spec = RaggedTensorSpec(ragged_rank=2, dtype=dt) 2207 full_type_list 
= fulltypes_for_flat_tensors(rt_spec) 2208 expect = [ 2209 full_type_pb2.FullTypeDef( 2210 type_id=full_type_pb2.TFT_RAGGED, 2211 args=[full_type_pb2.FullTypeDef(type_id=ft)]) 2212 ] 2213 self.assertEqual(len(rt_spec._flat_tensor_specs), len(full_type_list)) 2214 self.assertEqual(expect, full_type_list) 2215 2216 @parameterized.named_parameters([ 2217 { 2218 'testcase_name': 'RaggedRank0', 2219 'rt_spec': RaggedTensorSpec(ragged_rank=0), 2220 'rt': [1.0, 2.0, 3.0], 2221 }, 2222 { 2223 'testcase_name': 'RaggedRank1', 2224 'rt_spec': RaggedTensorSpec(ragged_rank=1), 2225 'rt': [[1.0, 2.0], [3.0]] 2226 }, 2227 { 2228 'testcase_name': 'RaggedRank2', 2229 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 2230 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] 2231 }, 2232 ]) 2233 def testToFromTensorList(self, rt_spec, rt): 2234 rt = ragged_factory_ops.constant(rt) 2235 tensor_list = rt_spec._to_tensor_list(rt) 2236 rt_reconstructed = rt_spec._from_tensor_list(tensor_list) 2237 self.assertAllEqual(rt, rt_reconstructed) 2238 2239 @parameterized.named_parameters([ 2240 # TODO(b/141789000) Test ragged_rank=0 when support is added. 
2241 { 2242 'testcase_name': 'RaggedRank1', 2243 'rt_spec': RaggedTensorSpec(ragged_rank=1), 2244 'rt': [[1.0, 2.0], [3.0]] 2245 }, 2246 { 2247 'testcase_name': 'RaggedRank2', 2248 'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 2249 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] 2250 }, 2251 ]) 2252 def testToFromBatchedTensorList(self, rt_spec, rt): 2253 rt = ragged_factory_ops.constant(rt) 2254 tensor_list = rt_spec._to_batched_tensor_list(rt) 2255 rt_reconstructed = rt_spec._from_tensor_list(tensor_list) 2256 self.assertAllEqual(rt, rt_reconstructed) 2257 first_row = rt_spec._unbatch()._from_tensor_list( 2258 [t[0] for t in tensor_list]) 2259 self.assertAllEqual(rt[0], first_row) 2260 2261 def testToFromBatchedTensorListPreservesUniformRowLengths(self): 2262 rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]), ragged_rank=2) 2263 rt_spec = rt._type_spec 2264 tensor_list = rt_spec._to_batched_tensor_list(rt) 2265 rt_reconstructed = rt_spec._from_tensor_list(tensor_list) 2266 self.assertAllEqual(rt, rt_reconstructed) 2267 self.assertTrue(rt.shape.is_fully_defined()) 2268 self.assertTrue(rt_reconstructed.shape.is_fully_defined()) 2269 self.assertEqual(rt.shape.as_list(), rt_reconstructed.shape.as_list()) 2270 2271 @parameterized.parameters([ 2272 (RaggedTensorSpec([2, None], dtypes.float32, 1), 32, 2273 RaggedTensorSpec([32, 2, None], dtypes.float32, 2)), 2274 (RaggedTensorSpec([4, None], dtypes.float32, 1), None, 2275 RaggedTensorSpec([None, 4, None], dtypes.float32, 2)), 2276 (RaggedTensorSpec([2], dtypes.float32, 2277 -1), 32, RaggedTensorSpec([32, 2], dtypes.float32, 0)), 2278 ]) 2279 def testBatch(self, spec, batch_size, expected): 2280 self.assertEqual(spec._batch(batch_size), expected) 2281 2282 @parameterized.parameters([ 2283 (RaggedTensorSpec([32, None, None], dtypes.float32, 2), 2284 RaggedTensorSpec([None, None], dtypes.float32, 1)), 2285 (RaggedTensorSpec([None, None, None], dtypes.float32, 2), 2286 RaggedTensorSpec([None, None], dtypes.float32, 
1)), 2287 (RaggedTensorSpec([32, 2], dtypes.float32, 0), 2288 RaggedTensorSpec([2], dtypes.float32, -1)), 2289 (RaggedTensorSpec([32, None, 4], dtypes.float32, 1, dtypes.int32), 2290 RaggedTensorSpec([None, 4], dtypes.float32, 0, dtypes.int32)), 2291 ]) # pyformat: disable 2292 def testUnbatch(self, spec, expected): 2293 self.assertEqual(spec._unbatch(), expected) 2294 2295 def testIsCompatibleWith(self): 2296 spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2) 2297 spec2 = RaggedTensorSpec(None, dtypes.float32, 2) 2298 spec3 = RaggedTensorSpec(None, dtypes.int32, 1) 2299 spec4 = RaggedTensorSpec([None], dtypes.int32, 0) 2300 2301 self.assertTrue(spec1.is_compatible_with(spec2)) 2302 self.assertFalse(spec1.is_compatible_with(spec3)) 2303 self.assertFalse(spec1.is_compatible_with(spec4)) 2304 self.assertFalse(spec2.is_compatible_with(spec3)) 2305 self.assertFalse(spec2.is_compatible_with(spec4)) 2306 self.assertFalse(spec3.is_compatible_with(spec4)) 2307 self.assertTrue(spec4.is_compatible_with(constant_op.constant([1, 2, 3]))) 2308 2309 2310if __name__ == '__main__': 2311 googletest.main() 2312