1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for StructuredTensor.""" 16 17import textwrap 18 19from absl.testing import parameterized 20import numpy as np 21 22from tensorflow.python.eager import context 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import extension_type 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import sparse_tensor 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_spec 30from tensorflow.python.framework import test_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops.ragged import ragged_factory_ops 33from tensorflow.python.ops.ragged import ragged_tensor 34from tensorflow.python.ops.ragged import row_partition 35from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape 36 37# TODO(b/173144447): remove when structured_array_ops is included in init. 38from tensorflow.python.ops.structured import structured_array_ops # pylint: disable=unused-import 39 40from tensorflow.python.ops.structured import structured_tensor 41from tensorflow.python.ops.structured import structured_tensor_dynamic 42from tensorflow.python.ops.structured.structured_tensor import StructuredTensor 43from tensorflow.python.platform import googletest 44from tensorflow.python.util import dispatch 45 46 47class _PrivateSpecialType(extension_type.ExtensionType): 48 ragged: ragged_tensor.RaggedTensor 49 50 51@dispatch.dispatch_for_types(array_ops.shape_v2, _PrivateSpecialType) 52def shape_v2_special(input: _PrivateSpecialType, out_type=dtypes.int32, # pylint: disable=redefined-builtin 53 name=None): 54 """Returns a DynamicRaggedShape containing the shape of the input.""" 55 del name 56 return array_ops.shape_v2(input.ragged, out_type) # pylint: disable=protected-access 57 58 59class _PrivateBrokenType(extension_type.ExtensionType): 60 ragged: ragged_tensor.RaggedTensor 61 62 63@dispatch.dispatch_for_types(array_ops.shape_v2, _PrivateBrokenType) 64def shape_v2_broken(input: _PrivateBrokenType, out_type=dtypes.int32, # pylint: disable=redefined-builtin 65 name=None): 66 """Returns a DynamicRaggedShape containing the shape of the input.""" 67 del name 68 del input 69 del out_type 70 return { 71 "foo": "This is not a shape", 72 "bar": "But if I put a string here, it becomes a vector" 73 } 74 75 76# pylint: disable=g-long-lambda 77@test_util.run_all_in_graph_and_eager_modes 78class StructuredTensorTest(test_util.TensorFlowTestCase, 79 parameterized.TestCase): 80 81 def assertAllEqual(self, a, b, msg=None): 82 if not (isinstance(a, structured_tensor.StructuredTensor) or 83 isinstance(b, structured_tensor.StructuredTensor)): 84 return super(StructuredTensorTest, self).assertAllEqual(a, b, msg) 85 if not isinstance(a, structured_tensor.StructuredTensor): 86 a = structured_tensor.StructuredTensor.from_pyval(a) 87 self._assertStructuredEqual(a, b, msg, False) 88 elif not isinstance(b, structured_tensor.StructuredTensor): 89 b = structured_tensor.StructuredTensor.from_pyval(b) 90 self._assertStructuredEqual(a, b, msg, False) 91 else: 92 self._assertStructuredEqual(a, b, msg, True) 93 94 def _assertStructuredEqual(self, a, b, msg, check_shape): 95 if check_shape: 96 self.assertEqual(repr(a.shape), repr(b.shape)) 97 self.assertEqual(set(a.field_names()), set(b.field_names())) 98 for field in a.field_names(): 99 a_value = a.field_value(field) 100 b_value = b.field_value(field) 101 self.assertIs(type(a_value), type(b_value)) 102 if isinstance(a_value, structured_tensor.StructuredTensor): 103 self._assertStructuredEqual(a_value, b_value, msg, check_shape) 104 else: 105 self.assertAllEqual(a_value, b_value, msg) 106 107 @parameterized.named_parameters([ 108 # Scalar (rank=0) StructuredTensors. 109 { 110 "testcase_name": "Rank0_WithTensorFields", 111 "rank": 0, 112 "fields": {"Foo": 5, "Bar": [1, 2, 3]}, 113 "expected_shape": [] 114 }, 115 { 116 "testcase_name": "Rank0_WithRaggedFields", 117 "fields": { 118 # note: fields have varying rank & ragged_rank. 119 "p": ragged_factory_ops.constant_value([[1, 2], [3]]), 120 "q": ragged_factory_ops.constant_value([[[4]], [], [[5, 6]]]), 121 "r": ragged_factory_ops.constant_value([[[4]], [], [[5]]], 122 ragged_rank=1), 123 "s": ragged_factory_ops.constant_value([[[4]], [], [[5]]], 124 ragged_rank=2), 125 }, 126 "rank": 0, 127 "expected_shape": [], 128 }, 129 { 130 "testcase_name": "Rank0_WithStructuredFields", 131 "fields": lambda: { 132 "foo": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}), 133 "bar": StructuredTensor.from_pyval( 134 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 135 }, 136 "rank": 0, 137 "expected_shape": [], 138 }, 139 { 140 "testcase_name": "Rank0_WithMixedFields", 141 "fields": lambda: { 142 # TODO(martinz): should handle this, but can't. 143 "f1": 5, 144 "f2": [1, 2, 3], 145 "f3": ragged_factory_ops.constant_value([[1, 2], [3]]), 146 "f4": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}), 147 }, 148 "rank": 0, 149 "expected_shape": [], 150 }, 151 # Vector (rank=1) StructuredTensors. 152 { 153 "testcase_name": "Rank1_WithExplicitNrows", 154 "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]}, 155 "rank": 1, 156 "expected_shape": [2], 157 }, 158 { 159 "testcase_name": "Rank1_WithTensorFields", 160 "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]}, 161 "rank": 1, 162 "expected_shape": [2], 163 164 }, 165 { 166 "testcase_name": "Rank1_WithRaggedFields", 167 "fields": { 168 # note: fields have varying rank & ragged_rank. 169 "p": ragged_factory_ops.constant_value([[1, 2], [3]]), 170 "q": ragged_factory_ops.constant_value([[[4]], [[5, 6], [7]]]), 171 "r": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]]), 172 "s": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]], 173 ragged_rank=1), 174 "t": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]], 175 ragged_rank=2), 176 }, 177 "rank": 1, 178 "expected_shape": [2], 179 }, 180 { 181 "testcase_name": "Rank1_WithStructuredFields", 182 "fields": lambda: { 183 "foo": StructuredTensor.from_pyval( 184 [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]), 185 "bar": StructuredTensor.from_pyval( 186 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 187 }, 188 "rank": 1, 189 "expected_shape": [2], 190 }, 191 { 192 "testcase_name": "Rank1_WithMixedFields", 193 "fields": lambda: { 194 "x": [1, 2], 195 "y": [[1, 2], [3, 4]], 196 "r": ragged_factory_ops.constant_value([[1, 2], [3]]), 197 "s": StructuredTensor.from_pyval( 198 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 199 }, 200 "rank": 1, 201 "expected_shape": [2], 202 }, 203 { 204 "testcase_name": "Rank1_WithNoElements", 205 "fields": lambda: { 206 "x": [], 207 "y": np.zeros([0, 8]), 208 "r": ragged_factory_ops.constant([], ragged_rank=1), 209 "s": StructuredTensor.from_pyval([]), 210 }, 211 "rank": 1, 212 "expected_shape": [0], # Note: could also be [None] (?) 213 }, 214 { 215 "testcase_name": "Rank1_InferDimSize", 216 "fields": lambda: { 217 "x": [1, 2], 218 "y": [[1, 2], [3, 4]], 219 "r": ragged_factory_ops.constant_value([[1, 2], [3]]), 220 "p": ragged_factory_ops.constant_value([[4], [5, 6, 7]]), 221 "foo": StructuredTensor.from_pyval( 222 [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]), 223 "bar": StructuredTensor.from_pyval( 224 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 225 }, 226 "rank": 1, 227 "expected_shape": [2], # inferred from field values. 228 }, 229 # Matrix (rank=2) StructuredTensors. 230 { 231 "testcase_name": "Rank2_WithTensorFields", 232 "fields": { 233 "x": [[1, 2, 3], [4, 5, 6]], 234 "y": np.ones([2, 3, 8]) 235 }, 236 "rank": 2, 237 "expected_shape": [2, 3], # inferred from field values. 238 }, 239 { 240 "testcase_name": "Rank2_WithRaggedFields", 241 "fields": { 242 # Note: fields must have identical row_splits. 243 "a": ragged_factory_ops.constant_value([[1, 2], [3]]), 244 "b": ragged_factory_ops.constant_value([[4, 5], [6]]), 245 "c": ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]), 246 "d": ragged_factory_ops.constant_value( 247 [[[[1, 2], [3]], [[4], [], [5]]], [[[6, 7, 8], []]]]), 248 }, 249 "rank": 2, 250 "expected_shape": [2, None], 251 }, 252 { 253 "testcase_name": "Rank2_WithStructuredFields", 254 "fields": lambda: { 255 # Note: fields must have identical row_splits. 256 "a": StructuredTensor.from_pyval( 257 [[{"x": 1}], [{"x": 2}, {"x": 3}]]), 258 "b": StructuredTensor.from_pyval( 259 [[[{"y": 1}]], [[], [{"y": 2}, {"y": 3}]]]), 260 }, 261 "rank": 2, 262 "expected_shape": [2, None], # ragged shape = [[*], [*, *]] 263 }, 264 { 265 "testcase_name": "Rank2_WithMixedFields", 266 "fields": lambda: { 267 "a": [[1, 2], [3, 4]], 268 "b": ragged_factory_ops.constant_value([[1, 2], [3, 4]]), 269 "c": StructuredTensor.from_pyval( 270 [[[{"y": 1}], []], [[], [{"y": 2}, {"y": 3}]]]), 271 "d": ragged_factory_ops.constant_value( 272 [[[1, 2], []], [[3], [4]]]), 273 }, 274 "rank": 2, 275 "expected_shape": [2, 2], 276 }, 277 # Rank=4 StructuredTensors. 278 { 279 "testcase_name": "Rank4_WithMixedFields", 280 "fields": lambda: { 281 "a": np.ones([1, 2, 3, 1]), 282 "b": np.ones([1, 2, 3, 1, 5]), 283 "c": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1])), 284 "d": ragged_factory_ops.constant( 285 np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1), 286 "e": ragged_factory_ops.constant( 287 np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2), 288 "f": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3])), 289 "g": StructuredTensor.from_pyval( 290 [[[[{"x": j, "y": k}] for k in range(3)] 291 for j in range(2)]]), 292 "h": StructuredTensor.from_pyval( 293 [[[[[{"x": j, "y": k, "z": z} for z in range(j)]] 294 for k in range(3)] 295 for j in range(2)]]), 296 }, 297 "rank": 4, 298 "expected_shape": [1, 2, 3, 1], # inferred from field values. 299 }, 300 ]) # pyformat: disable 301 def testFromFieldsAndRank(self, fields, rank, expected_shape): 302 if callable(fields): 303 fields = fields() # deferred construction: fields may include tensors. 304 305 struct = StructuredTensor.from_fields_and_rank(fields, rank) 306 self.assertEqual(struct.shape.as_list(), expected_shape) 307 308 @parameterized.named_parameters([ 309 { 310 "testcase_name": "NoFields", 311 "rank": 1, 312 "fields": {}, 313 "msg": "Must provide at least one field" 314 }, 315 { 316 "testcase_name": "IntegerRank", 317 "rank": 0.5, 318 "fields": { 319 "foo": [1] 320 }, 321 "msg": "rank must be an integer" 322 }, 323 { 324 "testcase_name": "NonNegativeRank", 325 "rank": -1, 326 "fields": { 327 "bar": [1, 2, 3] 328 }, 329 "msg": "rank must be nonnegative" 330 }, 331 ]) 332 def testFromFieldsAndRankError(self, fields, rank, msg): 333 if callable(fields): 334 fields = fields() # deferred construction: fields may include tensors. 335 with self.assertRaisesRegex(ValueError, msg): 336 StructuredTensor.from_fields_and_rank(fields, rank) 337 338 @parameterized.named_parameters([ 339 # Scalar (rank=0) StructuredTensors. 340 { 341 "testcase_name": "Rank0_WithNoFields", 342 "shape": [], 343 "fields": {}, 344 }, 345 { 346 "testcase_name": "Rank0_WithTensorFields", 347 "shape": [], 348 "fields": {"Foo": 5, "Bar": [1, 2, 3]}, 349 }, 350 { 351 "testcase_name": "Rank0_WithRaggedFields", 352 "shape": [], 353 "fields": { 354 # note: fields have varying rank & ragged_rank. 355 "p": ragged_factory_ops.constant_value([[1, 2], [3]]), 356 "q": ragged_factory_ops.constant_value([[[4]], [], [[5, 6]]]), 357 "r": ragged_factory_ops.constant_value([[[4]], [], [[5]]], 358 ragged_rank=1), 359 "s": ragged_factory_ops.constant_value([[[4]], [], [[5]]], 360 ragged_rank=2), 361 }, 362 }, 363 { 364 "testcase_name": "Rank0_WithStructuredFields", 365 "shape": [], 366 "fields": lambda: { 367 "foo": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}), 368 "bar": StructuredTensor.from_pyval( 369 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 370 }, 371 }, 372 { 373 "testcase_name": "Rank0_WithMixedFields", 374 "shape": [], 375 "fields": lambda: { 376 "f1": 5, 377 "f2": [1, 2, 3], 378 "f3": ragged_factory_ops.constant_value([[1, 2], [3]]), 379 "f4": StructuredTensor.from_pyval({"a": 1, "b": [1, 2, 3]}), 380 }, 381 }, 382 # Vector (rank=1) StructuredTensors. 383 { 384 "testcase_name": "Rank1_WithNoFields", 385 "shape": [2], 386 "fields": {}, 387 }, 388 { 389 "testcase_name": "Rank1_WithExplicitNrows", 390 "shape": [None], 391 "nrows": 2, 392 "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]}, 393 "expected_shape": [2], 394 }, 395 { 396 "testcase_name": "Rank1_WithTensorFields", 397 "shape": [2], 398 "fields": {"x": [1, 2], "y": [[1, 2], [3, 4]]}, 399 }, 400 { 401 "testcase_name": "Rank1_WithRaggedFields", 402 "shape": [2], 403 "fields": { 404 # note: fields have varying rank & ragged_rank. 405 "p": ragged_factory_ops.constant_value([[1, 2], [3]]), 406 "q": ragged_factory_ops.constant_value([[[4]], [[5, 6], [7]]]), 407 "r": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]]), 408 "s": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]], 409 ragged_rank=1), 410 "t": ragged_factory_ops.constant_value([[], [[[12]], [[13]]]], 411 ragged_rank=2), 412 }, 413 }, 414 { 415 "testcase_name": "Rank1_WithStructuredFields", 416 "shape": [2], 417 "fields": lambda: { 418 "foo": StructuredTensor.from_pyval( 419 [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]), 420 "bar": StructuredTensor.from_pyval( 421 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 422 }, 423 }, 424 { 425 "testcase_name": "Rank1_WithMixedFields", 426 "shape": [2], 427 "fields": lambda: { 428 "x": [1, 2], 429 "y": [[1, 2], [3, 4]], 430 "r": ragged_factory_ops.constant_value([[1, 2], [3]]), 431 "s": StructuredTensor.from_pyval( 432 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 433 }, 434 }, 435 { 436 "testcase_name": "Rank1_WithNoElements", 437 "shape": [0], 438 "fields": lambda: { 439 "x": [], 440 "y": np.zeros([0, 8]), 441 "r": ragged_factory_ops.constant([], ragged_rank=1), 442 "s": StructuredTensor.from_pyval([]), 443 }, 444 }, 445 { 446 "testcase_name": "Rank1_InferDimSize", 447 "shape": [None], 448 "fields": lambda: { 449 "x": [1, 2], 450 "y": [[1, 2], [3, 4]], 451 "r": ragged_factory_ops.constant_value([[1, 2], [3]]), 452 "p": ragged_factory_ops.constant_value([[4], [5, 6, 7]]), 453 "foo": StructuredTensor.from_pyval( 454 [{"a": 1, "b": [1, 2, 3]}, {"a": 2, "b": []}]), 455 "bar": StructuredTensor.from_pyval( 456 [[{"x": 12}], [{"x": 13}, {"x": 14}]]), 457 }, 458 "expected_shape": [2], # inferred from field values. 459 }, 460 # Matrix (rank=2) StructuredTensors. 461 { 462 "testcase_name": "Rank2_WithNoFields", 463 "shape": [2, 8], 464 "fields": {}, 465 }, 466 { 467 "testcase_name": "Rank2_WithNoFieldsAndExplicitRowPartitions", 468 "shape": [2, None], 469 "row_partitions": 470 lambda: [row_partition.RowPartition.from_row_lengths([3, 7])], 471 "fields": {}, 472 }, 473 { 474 "testcase_name": "Rank2_WithTensorFields", 475 "shape": [None, None], 476 "fields": { 477 "x": [[1, 2, 3], [4, 5, 6]], 478 "y": np.ones([2, 3, 8]) 479 }, 480 "expected_shape": [2, 3], # inferred from field values. 481 }, 482 { 483 "testcase_name": "Rank2_WithRaggedFields", 484 "shape": [2, None], # ragged shape = [[*, *], [*]] 485 "fields": { 486 # Note: fields must have identical row_splits. 487 "a": ragged_factory_ops.constant_value([[1, 2], [3]]), 488 "b": ragged_factory_ops.constant_value([[4, 5], [6]]), 489 "c": ragged_factory_ops.constant_value([[[1, 2], [3]], [[4, 5]]]), 490 "d": ragged_factory_ops.constant_value( 491 [[[[1, 2], [3]], [[4], [], [5]]], [[[6, 7, 8], []]]]), 492 }, 493 }, 494 { 495 "testcase_name": "Rank2_WithStructuredFields", 496 "shape": [2, None], # ragged shape = [[*], [*, *]] 497 "fields": lambda: { 498 # Note: fields must have identical row_splits. 499 "a": StructuredTensor.from_pyval( 500 [[{"x": 1}], [{"x": 2}, {"x": 3}]]), 501 "b": StructuredTensor.from_pyval( 502 [[[{"y": 1}]], [[], [{"y": 2}, {"y": 3}]]]), 503 }, 504 }, 505 { 506 "testcase_name": "Rank2_WithMixedFields", 507 "shape": [2, None], 508 "fields": lambda: { 509 "a": [[1, 2], [3, 4]], 510 "b": ragged_factory_ops.constant_value([[1, 2], [3, 4]]), 511 "c": StructuredTensor.from_pyval( 512 [[[{"y": 1}], []], [[], [{"y": 2}, {"y": 3}]]]), 513 "d": ragged_factory_ops.constant_value( 514 [[[1, 2], []], [[3], [4]]]), 515 }, 516 "expected_shape": [2, 2], 517 }, 518 # Rank=4 StructuredTensors. 519 { 520 "testcase_name": "Rank4_WithNoFields", 521 "shape": [1, None, None, 3], 522 "fields": {}, 523 "row_partitions": lambda: [ 524 row_partition.RowPartition.from_row_lengths([3]), 525 row_partition.RowPartition.from_row_lengths([2, 0, 1]), 526 row_partition.RowPartition.from_uniform_row_length(3, nvals=9) 527 ] 528 }, 529 { 530 "testcase_name": "Rank4_WithMixedFields", 531 "shape": [1, None, None, 1], 532 "fields": lambda: { 533 "a": np.ones([1, 2, 3, 1]), 534 "b": np.ones([1, 2, 3, 1, 5]), 535 "c": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1])), 536 "d": ragged_factory_ops.constant( 537 np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1), 538 "e": ragged_factory_ops.constant( 539 np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2), 540 "f": ragged_factory_ops.constant(np.zeros([1, 2, 3, 1, 3])), 541 "g": StructuredTensor.from_pyval( 542 [[[[{"x": j, "y": k}] for k in range(3)] 543 for j in range(2)]]), 544 "h": StructuredTensor.from_pyval( 545 [[[[[{"x": j, "y": k, "z": z} for z in range(j)]] 546 for k in range(3)] 547 for j in range(2)]]), 548 }, 549 "expected_shape": [1, 2, 3, 1], # inferred from field values. 550 }, 551 ]) # pyformat: disable 552 def testFromFields(self, 553 shape, 554 fields, 555 expected_shape=None, 556 nrows=None, 557 row_partitions=None): 558 if callable(fields): 559 fields = fields() # deferred construction: fields may include tensors. 560 if callable(nrows): 561 nrows = nrows() # deferred construction. 562 if callable(row_partitions): 563 row_partitions = row_partitions() # deferred construction. 564 for validate in (True, False): 565 struct = StructuredTensor.from_fields( 566 fields, 567 shape, 568 nrows=nrows, 569 row_partitions=row_partitions, 570 validate=validate) 571 if expected_shape is None: 572 expected_shape = shape 573 self.assertEqual(struct.shape.as_list(), expected_shape) 574 self.assertLen(expected_shape, struct.rank) 575 self.assertCountEqual(struct.field_names(), tuple(fields.keys())) 576 for field, value in fields.items(): 577 self.assertIsInstance( 578 struct.field_value(field), 579 (ops.Tensor, structured_tensor.StructuredTensor, 580 ragged_tensor.RaggedTensor)) 581 self.assertAllEqual(struct.field_value(field), value) 582 583 @parameterized.parameters([ 584 dict(fields={}, shape=object(), err=TypeError), 585 dict( 586 fields=object(), 587 shape=[], 588 err=TypeError, 589 msg="fields must be a dictionary"), 590 dict( 591 fields={1: 2}, shape=[], err=TypeError, 592 msg="Unexpected type for key"), 593 dict( 594 fields={"x": object()}, 595 shape=[], 596 err=(TypeError, ValueError), 597 msg="Error with shape of x|Unexpected type for value"), 598 dict( 599 fields={}, 600 shape=None, 601 err=ValueError, 602 msg="StructuredTensor's shape must have known rank"), 603 dict( 604 fields={"f": 5}, 605 shape=[5], 606 err=ValueError, 607 msg=r"Field f has shape \(\), which is incompatible with the shape " 608 r"that was specified or inferred from other fields: \(5,\)|Shapes"), 609 dict( 610 fields=dict(x=[1], y=[]), 611 shape=[None], 612 err=ValueError, 613 msg=r"Error in shape of y"), 614 dict( 615 fields={"": 5}, 616 shape=[], 617 err=ValueError, 618 msg="Field name '' is not currently allowed."), 619 dict( 620 fields={"_": 5}, 621 shape=[], 622 err=ValueError, 623 msg="Field name '_' is not currently allowed."), 624 dict( 625 fields={ 626 "r1": ragged_factory_ops.constant_value([[1, 2], [3]]), 627 "r2": ragged_factory_ops.constant_value([[1, 2, 3], [4]]) 628 }, 629 shape=[2, None], 630 validate=True, 631 err=ValueError, 632 msg=r"Error in shape of r2", 633 ), 634 dict( 635 fields={}, 636 shape=(), 637 nrows=5, 638 err=ValueError, 639 msg="nrows must be None if shape.rank==0"), 640 dict( 641 fields={}, 642 shape=(), 643 row_partitions=[0], 644 err=ValueError, 645 msg=r"row_partitions must be None or \[\] if shape.rank<2"), 646 dict( 647 fields={}, 648 shape=(None, None, None), 649 row_partitions=[], 650 err=ValueError, 651 msg=r"len\(row_partitions\) must be shape.rank-1"), 652 dict( 653 fields={}, 654 shape=[None], 655 err=ValueError, 656 msg="Must specify `nrows`, a fully specified `shape`, " 657 "or have `fields` if `rank=1`"), 658 dict( 659 fields={}, 660 shape=[None, None], 661 err=ValueError, 662 msg="Must specify row_partitions, a fully specified shape, " 663 "or have fields if rank > 1"), 664 dict( 665 fields={}, 666 shape=[None, None], 667 nrows=lambda: constant_op.constant(2, dtypes.int32), 668 row_partitions=lambda: 669 [row_partition.RowPartition.from_row_lengths([3, 4])], 670 err=ValueError, 671 msg="row_partition dtypes are inconsistent"), 672 dict( 673 fields=lambda: { 674 "a": 675 ragged_factory_ops.constant([[1]], 676 row_splits_dtype=dtypes.int32), 677 "b": 678 ragged_factory_ops.constant([[1]], 679 row_splits_dtype=dtypes.int64) 680 }, 681 shape=[None, None], 682 err=ValueError, 683 msg="field values have incompatible row_partition dtypes"), 684 ]) 685 def testFromFieldsErrors(self, 686 fields, 687 shape, 688 nrows=None, 689 row_partitions=None, 690 validate=False, 691 err=ValueError, 692 msg=None, 693 test_in_eager=True): 694 if not test_in_eager and context.executing_eagerly(): 695 return 696 if callable(fields): 697 fields = fields() # deferred construction. 698 if callable(nrows): 699 nrows = nrows() # deferred construction. 700 if callable(row_partitions): 701 row_partitions = row_partitions() # deferred construction. 702 with self.assertRaisesRegex(err, msg): 703 struct = StructuredTensor.from_fields( 704 fields=fields, 705 shape=shape, 706 nrows=nrows, 707 row_partitions=row_partitions, 708 validate=validate) 709 for field_name in struct.field_names(): 710 self.evaluate(struct.field_value(field_name)) 711 self.evaluate(struct.nrows()) 712 713 def testMergeNrowsErrors(self): 714 nrows = constant_op.constant(5) 715 static_nrows = tensor_shape.Dimension(5) 716 value = constant_op.constant([1, 2, 3]) 717 with self.assertRaisesRegex(ValueError, "fields have incompatible nrows"): 718 structured_tensor._merge_nrows( 719 nrows, static_nrows, value, dtypes.int32, validate=False) 720 721 def testNestedStructConstruction(self): 722 rt = ragged_factory_ops.constant([[1, 2], [3]]) 723 struct1 = StructuredTensor.from_fields(shape=[], fields={"x": [1, 2]}) 724 struct2 = StructuredTensor.from_fields(shape=[2], fields={"x": [1, 2]}) 725 struct3 = StructuredTensor.from_fields( 726 shape=[], fields={ 727 "r": rt, 728 "s": struct1 729 }) 730 struct4 = StructuredTensor.from_fields( 731 shape=[2], fields={ 732 "r": rt, 733 "s": struct2 734 }) 735 736 self.assertEqual(struct3.shape.as_list(), []) 737 self.assertEqual(struct3.rank, 0) 738 self.assertEqual(set(struct3.field_names()), set(["r", "s"])) 739 self.assertAllEqual(struct3.field_value("r"), rt) 740 self.assertAllEqual(struct3.field_value("s"), struct1) 741 742 self.assertEqual(struct4.shape.as_list(), [2]) 743 self.assertEqual(struct4.rank, 1) 744 self.assertEqual(set(struct4.field_names()), set(["r", "s"])) 745 self.assertAllEqual(struct4.field_value("r"), rt) 746 self.assertAllEqual(struct4.field_value("s"), struct2) 747 748 def testPartitionOuterDims(self): 749 a = dict(x=1, y=[1, 2]) 750 b = dict(x=2, y=[3, 4]) 751 c = dict(x=3, y=[5, 6]) 752 d = dict(x=4, y=[7, 8]) 753 st1 = StructuredTensor.from_pyval([a, b, c, d]) 754 755 st2 = st1.partition_outer_dimension( 756 row_partition.RowPartition.from_row_splits([0, 2, 2, 3, 4])) 757 self.assertAllEqual(st2, [[a, b], [], [c], [d]]) 758 759 st3 = st2.partition_outer_dimension( 760 row_partition.RowPartition.from_row_lengths([1, 0, 3, 0])) 761 self.assertAllEqual(st3, [[[a, b]], [], [[], [c], [d]], []]) 762 763 # If we partition with uniform_row_lengths, then `x` is partitioned into 764 # a Tensor (not a RaggedTensor). 765 st4 = st1.partition_outer_dimension( 766 row_partition.RowPartition.from_uniform_row_length( 767 uniform_row_length=2, nvals=4, nrows=2)) 768 self.assertAllEqual( 769 st4, 770 structured_tensor.StructuredTensor.from_pyval( 771 [[a, b], [c, d]], 772 structured_tensor.StructuredTensor.Spec( 773 _ragged_shape=DynamicRaggedShape.Spec( 774 row_partitions=[], 775 static_inner_shape=[2, 2], 776 dtype=dtypes.int64), 777 _fields={ 778 "x": 779 tensor_spec.TensorSpec([2, 2], dtypes.int32), 780 "y": 781 ragged_tensor.RaggedTensorSpec([2, 2, None], 782 dtypes.int32) 783 }))) 784 785 def testPartitionOuterDimension3(self): 786 rt = ragged_tensor.RaggedTensor.from_value_rowids( 787 array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) 788 struct = structured_tensor.StructuredTensor.from_fields({"r": rt}, [2]) 789 struct_2 = struct.partition_outer_dimension( 790 row_partition.RowPartition.from_row_splits([0, 1, 2])) 791 struct_3 = struct_2.partition_outer_dimension( 792 row_partition.RowPartition.from_row_splits([0, 1, 2])) 793 self.assertEqual(3, struct_3.rank) 794 795 def testWithPrivateSpecialType(self): 796 rt = ragged_tensor.RaggedTensor.from_value_rowids( 797 array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) 798 pst = _PrivateSpecialType(rt) 799 pst_shape = array_ops.shape_v2(pst) 800 st = structured_tensor.StructuredTensor.from_fields_and_rank({"r": pst}, 1) 801 st_shape = st._ragged_shape 802 self.assertEqual(1, st.rank) 803 self.assertAllEqual(pst_shape[0], st_shape[0]) 804 805 def testWithPrivateBrokenType(self): 806 rt = ragged_tensor.RaggedTensor.from_value_rowids( 807 array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) 808 pbt = _PrivateBrokenType(rt) 809 810 with self.assertRaisesRegex(ValueError, "Error in shape of r"): 811 structured_tensor.StructuredTensor.from_fields_and_rank({"r": pbt}, 1) 812 813 def testPartitionOuterDimsErrors(self): 814 st = StructuredTensor.from_fields({}) 815 partition = row_partition.RowPartition.from_row_splits([0]) 816 with self.assertRaisesRegex(ValueError, 817 r"Shape \(\) must have rank at least 1"): 818 st.partition_outer_dimension(partition) 819 820 with self.assertRaisesRegex(TypeError, 821 "row_partition must be a RowPartition"): 822 st.partition_outer_dimension(10) 823 824 @parameterized.named_parameters([ 825 { 826 "testcase_name": "ScalarEmpty", 827 "pyval": {}, 828 "expected": lambda: StructuredTensor.from_fields(shape=[], fields={}) 829 }, 830 { 831 "testcase_name": "ScalarSimple", 832 "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]}, 833 "expected": lambda: StructuredTensor.from_fields(shape=[], fields={ 834 "a": 12, 835 "b": [1, 2, 3], 836 "c": ragged_factory_ops.constant([[1, 2], [3]])}) 837 }, 838 { 839 "testcase_name": "ScalarSimpleWithTypeSpec", 840 "pyval": {"a": 12, "b": [1, 2, 3], "c": [[1, 2], [3]]}, 841 "type_spec": StructuredTensor.Spec._from_fields_and_rank( 842 fields={ 843 "a": tensor_spec.TensorSpec([], dtypes.int32), 844 "b": tensor_spec.TensorSpec([None], dtypes.int32), 845 "c": ragged_tensor.RaggedTensorSpec([None, None], 846 dtypes.int32)}, 847 rank=0), 848 "expected": lambda: StructuredTensor.from_fields(shape=[], fields={ 849 "a": 12, 850 "b": [1, 2, 3], 851 "c": ragged_factory_ops.constant([[1, 2], [3]])}) 852 }, 853 { 854 "testcase_name": "ScalarWithNestedStruct", 855 "pyval": {"a": 12, "b": [1, 2, 3], "c": {"x": b"Z", "y": [10, 20]}}, 856 "expected": lambda: StructuredTensor.from_fields(shape=[], fields={ 857 "a": 12, 858 "b": [1, 2, 3], 859 "c": StructuredTensor.from_fields(shape=[], fields={ 860 "x": "Z", 861 "y": [10, 20]})}) 862 }, 863 { 864 "testcase_name": "EmptyList", 865 "pyval": [], 866 "expected": lambda: [], 867 }, 868 { 869 "testcase_name": "ListOfEmptyList", 870 "pyval": [[], []], 871 "expected": lambda: [[], []], 872 }, 873 { 874 "testcase_name": "EmptyListWithTypeSpecAndFields", 875 "pyval": [], 876 "type_spec": structured_tensor.StructuredTensor.Spec._from_fields_and_rank( 877 fields={"a": tensor_spec.TensorSpec([0], dtypes.int32)}, 878 rank=1), 879 "expected": lambda: StructuredTensor.from_fields(shape=[0], fields={ 880 "a": []}) 881 }, 882 { 883 "testcase_name": "EmptyListWithTypeSpecNoFieldsShape0_5", 884 "pyval": [], 885 "type_spec": StructuredTensor.Spec._from_shape(DynamicRaggedShape.Spec( 886 row_partitions=[], 887 static_inner_shape=[0, 5], 888 dtype=dtypes.int64)), 889 "expected": lambda: StructuredTensor.from_fields(shape=[0, 5], 890 fields={}) 891 }, 892 { 893 "testcase_name": "EmptyListWithTypeSpecNoFieldsShape1_0", 894 "pyval": [[]], 895 "type_spec": StructuredTensor.Spec._from_shape( 896 DynamicRaggedShape.Spec( 897 row_partitions=[], 898 static_inner_shape=[1, 0], 899 dtype=dtypes.int64)), 900 "expected": lambda: StructuredTensor.from_shape( 901 DynamicRaggedShape.from_lengths([1, 0])) 902 }, 903 { 904 "testcase_name": "VectorOfDict", 905 "pyval": [{"a": 1}, {"a": 2}], 906 "expected": lambda: StructuredTensor.from_fields(shape=[2], fields={ 907 "a": [1, 2]}) 908 }, 909 { 910 "testcase_name": "VectorOfDictWithNestedStructScalar", 911 "pyval": [{"a": 1, "b": {"x": [1, 2]}}, 912 {"a": 2, "b": {"x": [3]}}], 913 "expected": lambda: StructuredTensor.from_fields(shape=[2], fields={ 914 "a": [1, 2], 915 "b": StructuredTensor.from_fields(shape=[2], fields={ 916 "x": ragged_factory_ops.constant([[1, 2], [3]])})}), 917 }, 918 { 919 "testcase_name": "VectorOfDictWithNestedStructVector", 920 "pyval": [{"a": 1, "b": [{"x": [1, 2]}, {"x": [5]}]}, 921 {"a": 2, "b": [{"x": [3]}]}], 922 "expected": lambda: StructuredTensor.from_fields(shape=[2], fields={ 923 "a": [1, 2], 924 "b": StructuredTensor.from_fields(shape=[2, None], fields={ 925 "x": ragged_factory_ops.constant([[[1, 2], [5]], [[3]]])})}), 926 }, 927 { 928 "testcase_name": "Ragged2DOfDict", 929 "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},], 930 [{"a": 4}, {"a": 5}]], 931 "expected": lambda: StructuredTensor.from_fields( 932 shape=[2, None], 933 fields={ 934 "a": ragged_factory_ops.constant([[1, 2, 3], [4, 5]])}) 935 }, 936 { 937 # With no type-spec, all tensors>1D are encoded as ragged: 938 "testcase_name": "MatrixOfDictWithoutTypeSpec", 939 "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},], 940 [{"a": 4}, {"a": 5}, {"a": 6}]], 941 "expected": lambda: StructuredTensor.from_fields( 942 shape=[2, None], fields={ 943 "a": ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])}) 944 }, 945 { 946 # TypeSpec can be used to specify StructuredTensor shape. 947 "testcase_name": "MatrixOfDictWithTypeSpec", 948 "pyval": [[{"a": 1}, {"a": 2}, {"a": 3},], 949 [{"a": 4}, {"a": 5}, {"a": 6}]], 950 "type_spec": structured_tensor.StructuredTensorSpec([2, 3], { 951 "a": tensor_spec.TensorSpec(None, dtypes.int32)}), 952 "expected": lambda: StructuredTensor.from_fields( 953 shape=[2, 3], fields={"a": [[1, 2, 3], [4, 5, 6]]}) 954 }, 955 ]) # pyformat: disable 956 def testPyvalConversion(self, pyval, expected, type_spec=None): 957 expected = expected() # Deferred init because it creates tensors. 958 actual = structured_tensor.StructuredTensor.from_pyval(pyval, type_spec) 959 self.assertAllEqual(actual, expected) 960 if isinstance(actual, structured_tensor.StructuredTensor): 961 if context.executing_eagerly(): # to_pyval only available in eager. 962 self.assertEqual(actual.to_pyval(), pyval) 963 964 def testStructuredTensorSpecFactory(self): 965 spec = StructuredTensor.Spec._from_fields_and_rank( 966 fields={ 967 "a": tensor_spec.TensorSpec([], dtypes.int32), 968 "b": tensor_spec.TensorSpec([None], dtypes.int32), 969 "c": ragged_tensor.RaggedTensorSpec([None, None], dtypes.int32)}, 970 rank=0) 971 self.assertEqual(spec.rank, 0) 972 973 @parameterized.named_parameters([ 974 dict( 975 testcase_name="NoFieldsRaggedRank0", 976 st=lambda: StructuredTensor.from_fields({}, (3,)), 977 expected=[{}, {}, {}]), 978 dict( 979 testcase_name="NoFieldsRaggedRank1", 980 st=lambda: StructuredTensor.from_fields( 981 {}, (2, None), 982 row_partitions=[ 983 row_partition.RowPartition.from_row_lengths([3, 2])]), 984 expected=[[{}, {}, {}], [{}, {}]]), 985 dict( 986 testcase_name="NoFieldsRaggedRank2", 987 st=lambda: StructuredTensor.from_fields( 988 {}, (2, None, None), 989 row_partitions=[ 990 row_partition.RowPartition.from_row_lengths([2, 1]), 991 row_partition.RowPartition.from_row_lengths([2, 3, 1])]), 992 expected=[[[{}, {}], [{}, {}, {}]], [[{}]]]), 993 dict( 994 testcase_name="NoFieldsRaggedRank2NoDicts", 995 st=lambda: StructuredTensor.from_fields( 996 {}, (1, None, None), 997 row_partitions=[ 998 row_partition.RowPartition.from_row_lengths([2]), 999 row_partition.RowPartition.from_row_lengths([0, 0])]), 1000 expected=[[[], []]]), 1001 dict( 1002 testcase_name="NestedStructTensorWithNoFields", 1003 st=lambda: StructuredTensor.from_fields( 1004 { 1005 "foo": ragged_factory_ops.constant([[[], []]]), 1006 "bar": StructuredTensor.from_fields( 1007 {}, (1, None, None, None), row_partitions=[ 1008 row_partition.RowPartition.from_row_lengths([2]), 1009 row_partition.RowPartition.from_row_lengths([0, 0]), 1010 row_partition.RowPartition.from_row_lengths([]), 1011 ]) 1012 1013 }, (1, None, None),), 1014 expected=[[[], []]]), 1015 ]) # pyformat: disable 1016 def testToPyval(self, st, expected): 1017 if context.executing_eagerly(): # to_pyval only available in eager. 1018 st = st() # Deferred init because it creates tensors. 1019 self.assertEqual(st.to_pyval(), expected) 1020 1021 @parameterized.named_parameters([ 1022 dict(testcase_name="MissingKeys", 1023 pyval=[{"a": [1, 2]}, {"b": [3, 4]}], 1024 err=KeyError, 1025 msg="'b'"), 1026 dict(testcase_name="TypeSpecMismatch_DictKey", 1027 pyval={"a": 1}, 1028 type_spec=StructuredTensor.Spec._from_fields_and_rank( 1029 fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)}, 1030 rank=1), 1031 msg=r"Value at \(\) does not match typespec"), 1032 dict(testcase_name="TypeSpecMismatch_ListDictKey", 1033 pyval=[{"a": 1}], 1034 type_spec=StructuredTensor.Spec._from_fields_and_rank( 1035 fields={"b": tensor_spec.TensorSpec([1], dtypes.int32)}, 1036 rank=1), 1037 msg=r"Value at \(\) does not match typespec"), 1038 dict(testcase_name="TypeSpecMismatch_RankMismatch", 1039 pyval=[{"a": 1}], 1040 type_spec=StructuredTensor.Spec._from_fields_and_rank( 1041 fields={"a": tensor_spec.TensorSpec([], dtypes.int32)}, 1042 rank=0), 1043 msg=r"Value at \(\) does not match typespec \(rank mismatch\)"), 1044 dict(testcase_name="TypeSpecMismatch_Scalar", 1045 pyval=0, 1046 type_spec=StructuredTensor.Spec._from_shape( 1047 DynamicRaggedShape.Spec( 1048 row_partitions=[], 1049 static_inner_shape=[], 1050 dtype=dtypes.int64)), 1051 msg=r"Value at \(\) does not match typespec"), 1052 dict(testcase_name="TypeSpecMismatch_ListTensor", 1053 pyval={"a": [[1]]}, 1054 type_spec=StructuredTensor.Spec._from_fields_and_rank( 1055 fields={"a": tensor_spec.TensorSpec([], dtypes.int32)}, 1056 rank=0), 1057 msg=r"Value at \('a',\) does not match typespec"), 1058 dict(testcase_name="TypeSpecMismatch_ListTensorDeep", 1059 pyval={"a": {"b": [[1]]}}, 1060 type_spec=StructuredTensor.Spec._from_fields_and_rank( 1061 fields={"a": StructuredTensor.Spec._from_fields_and_rank( 1062 fields={"b": tensor_spec.TensorSpec([], dtypes.int32)}, 1063 rank=0 1064 )}, 1065 rank=0), 1066 msg=r"Value at \('a', 'b'\) does not match typespec"), 1067 dict(testcase_name="TypeSpecMismatch_ListTensorDeep_infer", 1068 pyval={"a": [{"b": [[1]]}, {"b": [["c"]]}]}, 1069 type_spec=None, 1070 msg=r"Error parsing path \('a', 'b'\)"), 1071 dict(testcase_name="TypeSpecMismatch_ListTensorDeep_infer2", 1072 pyval=[{"a": 1}, {"a": "c"}], 1073 type_spec=None, 1074 msg=r"Error parsing path \('a',\)"), 1075 dict(testcase_name="TypeSpecMismatch_ListSparse", 1076 pyval=[1, 2], 1077 type_spec=sparse_tensor.SparseTensorSpec([None], dtypes.int32), 1078 msg=r"Value at \(\) does not match typespec"), 1079 dict(testcase_name="TypeSpecMismatch_ListStruct", 1080 pyval=[[1]], 1081 type_spec=StructuredTensor.Spec._from_fields_and_rank( 1082 fields={"a": tensor_spec.TensorSpec([1, 1], dtypes.int32)}, 1083 rank=2), 1084 msg=r"Value at \(\) does not match typespec"), 1085 dict(testcase_name="InconsistentDictionaryDepth", 1086 pyval=[{}, [{}]], 1087 msg="Inconsistent depth of dictionaries"), 1088 dict(testcase_name="FOO", 1089 pyval=[[{}], 5], 1090 msg="Expected dict or nested list/tuple of dict"), 1091 1092 ]) # pyformat: disable 1093 def testFromPyvalError(self, pyval, err=ValueError, type_spec=None, msg=None): 1094 with self.assertRaisesRegex(err, msg): 1095 structured_tensor.StructuredTensor.from_pyval(pyval, type_spec) 1096 1097 def testToPyvalRequiresEagerMode(self): 1098 st = structured_tensor.StructuredTensor.from_pyval({"a": 5}) 1099 if not context.executing_eagerly(): 1100 with self.assertRaisesRegex(ValueError, "only supported in eager mode."): 1101 st.to_pyval() 1102 1103 @parameterized.named_parameters([ 1104 ( 1105 "Rank0", 1106 [], 1107 ), 1108 ( 1109 "Rank1", 1110 [5, 3], 1111 ), 1112 ( 1113 "Rank2", 1114 [5, 8, 3], 1115 ), 1116 ( 1117 "Rank5", 1118 [1, 2, 3, 4, 5], 1119 ), 1120 ]) 1121 def testRowPartitionsFromUniformShape(self, shape): 1122 for rank in range(len(shape)): 1123 partitions = structured_tensor._row_partitions_for_uniform_shape( 1124 ops.convert_to_tensor(shape), rank) 1125 self.assertLen(partitions, max(0, rank - 1)) 1126 if partitions: 1127 self.assertAllEqual(shape[0], partitions[0].nrows()) 1128 for (dim, partition) in enumerate(partitions): 1129 self.assertAllEqual(shape[dim + 1], partition.uniform_row_length()) 1130 1131 @parameterized.named_parameters([ 1132 # For shapes: U = uniform dimension; R = ragged dimension. 1133 dict( 1134 testcase_name="Shape_UR_Rank2", 1135 rt=[[1, 2], [], [3]], 1136 rt_ragged_rank=1, 1137 rank=2, 1138 expected_row_lengths=[[2, 0, 1]]), 1139 dict( 1140 testcase_name="Shape_URR_Rank2", 1141 rt=[[[1, 2], []], [[3]]], 1142 rt_ragged_rank=2, 1143 rank=2, 1144 expected_row_lengths=[[2, 1]]), 1145 dict( 1146 testcase_name="Shape_URU_Rank2", 1147 rt=[[[1], [2]], [[3]]], 1148 rt_ragged_rank=1, 1149 rank=2, 1150 expected_row_lengths=[[2, 1]]), 1151 dict( 1152 testcase_name="Shape_URR_Rank3", 1153 rt=[[[1, 2], []], [[3]]], 1154 rt_ragged_rank=2, 1155 rank=3, 1156 expected_row_lengths=[[2, 1], [2, 0, 1]]), 1157 dict( 1158 testcase_name="Shape_URU_Rank3", 1159 rt=[[[1], [2]], [[3]]], 1160 rt_ragged_rank=1, 1161 rank=3, 1162 expected_row_lengths=[[2, 1], [1, 1, 1]]), 1163 dict( 1164 testcase_name="Shape_URRUU_Rank2", 1165 rt=[[[[[1, 2]]]]], 1166 rt_ragged_rank=2, 1167 rank=2, 1168 expected_row_lengths=[[1]]), 1169 dict( 1170 testcase_name="Shape_URRUU_Rank3", 1171 rt=[[[[[1, 2]]]]], 1172 rt_ragged_rank=2, 1173 rank=3, 1174 expected_row_lengths=[[1], [1]]), 1175 dict( 1176 testcase_name="Shape_URRUU_Rank4", 1177 rt=[[[[[1, 2]]]]], 1178 rt_ragged_rank=2, 1179 rank=4, 1180 expected_row_lengths=[[1], [1], [1]]), 1181 dict( 1182 testcase_name="Shape_URRUU_Rank5", 1183 rt=[[[[[1, 2]]]]], 1184 rt_ragged_rank=2, 1185 rank=5, 1186 expected_row_lengths=[[1], [1], [1], [2]]), 1187 ]) 1188 def testRowPartitionsForRaggedTensor(self, rt, rt_ragged_rank, rank, 1189 expected_row_lengths): 1190 rt = ragged_factory_ops.constant(rt, rt_ragged_rank) 1191 partitions = structured_tensor._row_partitions_for_ragged_tensor( 1192 rt, rank, dtypes.int64) 1193 self.assertLen(partitions, rank - 1) 1194 self.assertLen(partitions, len(expected_row_lengths)) 1195 for partition, expected in zip(partitions, expected_row_lengths): 1196 self.assertAllEqual(partition.row_lengths(), expected) 1197 1198 @parameterized.named_parameters([ 1199 dict( 1200 testcase_name="2D_0_1", 1201 st=[[{"x": 1}, {"x": 2}], [{"x": 3}]], 1202 outer_axis=0, inner_axis=1, 1203 expected=[{"x": 1}, {"x": 2}, {"x": 3}]), 1204 dict( 1205 testcase_name="3D_0_1", 1206 st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1207 outer_axis=0, inner_axis=1, 1208 expected=[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}]]), 1209 dict( 1210 testcase_name="3D_1_2", 1211 st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1212 outer_axis=1, inner_axis=2, 1213 expected=[[{"x": 1}, {"x": 2}, {"x": 3}], [{"x": 4}]]), 1214 dict( 1215 testcase_name="3D_0_2", 1216 st=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1217 outer_axis=0, inner_axis=2, 1218 expected=[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}]), 1219 dict( 1220 testcase_name="4D_0_1", 1221 st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1222 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]], 1223 outer_axis=0, inner_axis=1, 1224 expected=[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]], 1225 [[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]), 1226 dict( 1227 testcase_name="4D_0_2", 1228 st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1229 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]], 1230 outer_axis=0, inner_axis=2, 1231 expected=[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}], 1232 [{"x": 5}], [{"x": 6}], [{"x": 7}]]), 1233 dict( 1234 testcase_name="4D_0_3", 1235 st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1236 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]], 1237 outer_axis=0, inner_axis=3, 1238 expected=[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}, 1239 {"x": 5}, {"x": 6}, {"x": 7}]), 1240 dict( 1241 testcase_name="4D_1_2", 1242 st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1243 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]], 1244 outer_axis=1, inner_axis=2, 1245 expected=[[[{"x": 1}, {"x": 2}], [{"x": 3}], [{"x": 4}]], 1246 [[{"x": 5}], [{"x": 6}], [{"x": 7}]]]), 1247 dict( 1248 testcase_name="4D_1_3", 1249 st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1250 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]], 1251 outer_axis=1, inner_axis=3, 1252 expected=[[{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}], 1253 [{"x": 5}, {"x": 6}, {"x": 7}]]), 1254 dict( 1255 testcase_name="4D_2_3", 1256 st=[[[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]], 1257 [[[{"x": 5}]], [[{"x": 6}], [{"x": 7}]]]], 1258 outer_axis=2, inner_axis=3, 1259 expected=[[[{"x": 1}, {"x": 2}, {"x": 3}], [{"x": 4}]], 1260 [[{"x": 5}], [{"x": 6}, {"x": 7}]]]), 1261 ]) # pyformat: disable 1262 def testMergeDims(self, st, outer_axis, inner_axis, expected): 1263 st = StructuredTensor.from_pyval(st) 1264 result = st.merge_dims(outer_axis, inner_axis) 1265 self.assertAllEqual(result, expected) 1266 1267 def testMergeDimsDetail_3D_0_1(self): 1268 st = StructuredTensor.from_pyval( 1269 [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]) 1270 result = st.merge_dims(0, 1) 1271 expected_shape = tensor_shape.TensorShape([3, None]) 1272 self.assertTrue(expected_shape.is_compatible_with(result.shape)) 1273 1274 def testMergeDims_0_1(self): 1275 rt = ragged_tensor.RaggedTensor.from_value_rowids( 1276 array_ops.constant([[1, 2], [3, 4], [5, 6]]), [0, 0, 1]) 1277 struct = StructuredTensor.from_fields({"r": rt}, [2]) 1278 struct_2 = struct.partition_outer_dimension( 1279 row_partition.RowPartition.from_row_splits([0, 1, 2])) 1280 struct_3 = struct_2.partition_outer_dimension( 1281 row_partition.RowPartition.from_row_splits([0, 1, 2])) 1282 self.assertLen(struct_3.row_partitions, 2) 1283 merged = struct_3.merge_dims(0, 1) 1284 self.assertLen(merged.row_partitions, 1) 1285 1286 def testMergeDimsError(self): 1287 st = StructuredTensor.from_pyval([[[{"a": 5}]]]) 1288 with self.assertRaisesRegex( 1289 ValueError, r"Expected outer_axis \(2\) to be less than " 1290 r"or equal to inner_axis \(1\)"): 1291 st.merge_dims(2, 1) 1292 1293 def testTupleFieldValue(self): 1294 st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}}) 1295 self.assertAllEqual(st.field_value(("a",)), 5) 1296 self.assertAllEqual(st.field_value(("b", "c")), [1, 2, 3]) 1297 expected = r"Field path \(.*a.*,.*b.*\) not found in .*" 1298 with self.assertRaisesRegex(KeyError, expected): 1299 st.field_value(("a", "b")) 1300 1301 @parameterized.named_parameters([ 1302 dict( 1303 testcase_name="scalar_scalar_scalar", 1304 st={"b": {"a": 5}}, 1305 source_path=("b", "a"), 1306 new_field_name="new_field", 1307 expected={"b": {"a": 5}, "new_field": 5},), 1308 dict( 1309 testcase_name="scalar_scalar_repeated", 1310 st={"b": {"a": [5, 3]}}, 1311 source_path=("b", "a"), 1312 new_field_name="new_field", 1313 expected={"b": {"a": [5, 3]}, "new_field": [5, 3]}), 1314 dict( 1315 testcase_name="scalar_scalar_repeated2", 1316 st={"b": {"a": [[7], [5, 3]]}}, 1317 source_path=("b", "a"), 1318 new_field_name="new_field", 1319 expected={"b": {"a": [[7], [5, 3]]}, "new_field": [[7], [5, 3]]}), 1320 dict( 1321 testcase_name="repeated_scalar_repeated", 1322 st=[{"b": {"a": [7]}}, 1323 {"b": {"a": [5, 3]}}], 1324 source_path=("b", "a"), 1325 new_field_name="new_field", 1326 expected=[{"b": {"a": [7]}, "new_field": [7]}, 1327 {"b": {"a": [5, 3]}, "new_field": [5, 3]}]), 1328 dict( 1329 testcase_name="repeated_scalar_repeated2", 1330 st=[{"b": {"a": [[5, 7], []]}}, 1331 {"b": {"a": [[5, 1], [3]]}}], 1332 source_path=("b", "a"), 1333 new_field_name="new_field", 1334 expected=[{"b": {"a": [[5, 7], []]}, 1335 "new_field": [[5, 7], []]}, 1336 {"b": {"a": [[5, 1], [3]]}, 1337 "new_field": [[5, 1], [3]]}]), 1338 dict( 1339 testcase_name="scalar_scalar_scalar_scalar", 1340 st={"a": {"b": {"c": 7}}}, 1341 source_path=("a", "b", "c"), 1342 new_field_name="new_field", 1343 expected={"a": {"b": {"c": 7}, "new_field": 7}}), 1344 dict( 1345 testcase_name="repeated_scalar_scalar_scalar", 1346 st=[{"a": {"b": {"c": 7}}}, 1347 {"a": {"b": {"c": 5}}}], 1348 source_path=("a", "b", "c"), 1349 new_field_name="new_field", 1350 expected=[{"a": {"b": {"c": 7}, "new_field": 7}}, 1351 {"a": {"b": {"c": 5}, "new_field": 5}}],), 1352 dict( 1353 testcase_name="repeated_repeated_scalar_scalar", 1354 st=[{"a": [{"b": {"c": 7}}, {"b": {"c": 3}}]}, 1355 {"a": [{"b": {"c": 5}}]}], 1356 source_path=("a", "b", "c"), 1357 new_field_name="new_field", 1358 expected=[{"a": [{"b": {"c": 7}, "new_field": 7}, 1359 {"b": {"c": 3}, "new_field": 3}]}, 1360 {"a": [{"b": {"c": 5}, "new_field": 5}]}]), 1361 dict( 1362 testcase_name="docs_tokens", 1363 st=[{"docs": [{"tokens": [7, 17]}, {"tokens": [3, 13]}]}, 1364 {"docs": [{"tokens": [5, 15]}]}], 1365 source_path=("docs", "tokens"), 1366 new_field_name="docs_tokens", 1367 expected=[{"docs": [{"tokens": [7, 17]}, {"tokens": [3, 13]}], 1368 "docs_tokens": [7, 17, 3, 13]}, 1369 {"docs": [{"tokens": [5, 15]}], 1370 "docs_tokens": [5, 15]}], 1371 ), 1372 dict( 1373 testcase_name="repeated_repeated_scalar_repeated", 1374 st=[{"a": [{"b": {"c": [7, 17]}}, {"b": {"c": [3, 13]}}]}, 1375 {"a": [{"b": {"c": [5, 15]}}]}], 1376 source_path=("a", "b", "c"), 1377 new_field_name="new_field", 1378 expected=[{"a": [{"b": {"c": [7, 17]}, "new_field": [7, 17]}, 1379 {"b": {"c": [3, 13]}, "new_field": [3, 13]}]}, 1380 {"a": [{"b": {"c": [5, 15]}, "new_field": [5, 15]}]}]), 1381 dict( 1382 testcase_name="scalar_scalar_scalar_repeated", 1383 st={"a": {"b": {"c": [7, 3, 5]}}}, 1384 source_path=("a", "b", "c"), 1385 new_field_name="new_field", 1386 expected={"a": {"b": {"c": [7, 3, 5]}, "new_field": [7, 3, 5]}}), 1387 dict( 1388 testcase_name="repeated_repeated_scalar_repeated2", 1389 st=[{"a": [{"b": {"c": [[7, 3], [17]]}}, {"b": {"c": [[3, 13]]}}]}, 1390 {"a": [{"b": {"c": [[5, 15]]}}]}], 1391 source_path=("a", "b", "c"), 1392 new_field_name="new_field", 1393 expected=[{"a": [{"b": {"c": [[7, 3], [17]]}, 1394 "new_field": [[7, 3], [17]]}, 1395 {"b": {"c": [[3, 13]]}, 1396 "new_field": [[3, 13]]}]}, 1397 {"a": [{"b": {"c": [[5, 15]]}, 1398 "new_field": [[5, 15]]}]}]), 1399 dict(testcase_name="example_4_promote_of_labeled_vector", 1400 st=[{"user_info": [{"gaia_id": {"vec": [0, 1, 2]}}]}, 1401 {"user_info": [{"gaia_id": {"vec": [3, 4, 5]}}]}], 1402 source_path=("user_info", "gaia_id"), 1403 new_field_name="user_info_gaia_id", 1404 expected=[{"user_info": [{"gaia_id": {"vec": [0, 1, 2]}}], 1405 "user_info_gaia_id": [{"vec": [0, 1, 2]}]}, 1406 {"user_info": [{"gaia_id": {"vec": [3, 4, 5]}}], 1407 "user_info_gaia_id": [{"vec": [3, 4, 5]}]}]), 1408 dict( 1409 testcase_name="promote_structure", 1410 st=[{"a": [{"aa": [{"b": {"c": 1}}, {"b": {"c": 8}}]}],}, 1411 {"a": [{"aa": [{"b": {"c": 12}}]}],}], 1412 source_path=("a", "aa", "b"), 1413 new_field_name="new_field", 1414 expected=[{"a": [{"aa": [{"b": {"c": 1}}, {"b": {"c": 8}}], 1415 "new_field": [{"c": 1}, {"c": 8}]}]}, 1416 {"a": [{"aa": [{"b": {"c": 12}}], 1417 "new_field": [{"c": 12}]}]}])]) # pyformat: disable 1418 def testPromote(self, st, source_path, new_field_name, expected): 1419 st2 = StructuredTensor.from_pyval(st) 1420 expected2 = StructuredTensor.from_pyval(expected) 1421 result = st2.promote(source_path, new_field_name) 1422 self.assertAllEqual(result, expected2) 1423 1424 def testPromoteDense(self): 1425 st = StructuredTensor.from_fields( 1426 { 1427 "a": 1428 StructuredTensor.from_fields( 1429 {"b": [[[1, 11], [2, 12]], [[3, 13], [4, 14]]]}, 1430 shape=[2, 2, 2]) 1431 }, 1432 shape=[2]) 1433 result = st.promote(("a", "b"), "new_field") 1434 self.assertEqual(st.rank, 1) 1435 self.assertEqual(st.field_value("a").rank, 3) 1436 self.assertAllEqual( 1437 result.field_value("new_field"), [[1, 11, 2, 12], [3, 13, 4, 14]]) 1438 1439 def testMergeDimsGeneric(self): 1440 """This is an example of a dense tensor being merged, when outer=rank. 1441 1442 Note that outer=rank is equivalent to outer=rank - 1. And yet, from the 1443 perspective of promote, it is nice to be able to have this functionality 1444 directly available, because sometimes the rank of the parent equals the 1445 rank of the child. 1446 1447 Finally, note that merge_dims for Ragged and StructuredTensor would not 1448 accept this as a valid argument. 1449 1450 Note: _merge_dims_generic is private, but these unit tests help to 1451 discuss the proper API definition. 1452 """ 1453 t = array_ops.constant([[[1, 11], [2, 12]], [[3, 13], [4, 14]]]) 1454 t2 = structured_tensor._merge_dims_generic(t, 1, 3) 1455 self.assertAllEqual(t2, [[1, 11, 2, 12], [3, 13, 4, 14]]) 1456 1457 def testMergeDimsGenericNoop(self): 1458 """This is an example of a dense tensor being merged, when outer=inner. 1459 1460 Sometimes, when promoting, the parent and grandparent ranks are equal. 1461 Finally, note that merge_dims for Ragged and StructuredTensor would not 1462 accept this as a valid argument. This should be aligned. 1463 """ 1464 t = array_ops.constant([[[1, 11], [2, 12]], [[3, 13], [4, 14]]]) 1465 t2 = structured_tensor._merge_dims_generic(t, 2, 2) 1466 self.assertAllEqual(t2, [[[1, 11], [2, 12]], [[3, 13], [4, 14]]]) 1467 1468 def testRepr(self): 1469 st = StructuredTensor.from_pyval({"a": 5, "b": {"c": [1, 2, 3]}}) 1470 if context.executing_eagerly(): 1471 expected = textwrap.dedent(""" 1472 <StructuredTensor( 1473 fields={ 1474 "a": tf.Tensor(5, shape=(), dtype=int32), 1475 "b": <StructuredTensor( 1476 fields={ 1477 "c": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 1478 shape=())>}, 1479 shape=())>""")[1:] 1480 else: 1481 expected = textwrap.dedent(""" 1482 <StructuredTensor( 1483 fields={ 1484 "a": Tensor("Const:0", shape=(), dtype=int32), 1485 "b": <StructuredTensor( 1486 fields={ 1487 "c": Tensor("RaggedConstant/Const:0", shape=(3,), dtype=int32)}, 1488 shape=())>}, 1489 shape=())>""")[1:] 1490 self.assertEqual(repr(st), expected) 1491 1492 def testPartitionOuterDimension2DDenseField(self): 1493 struct = structured_tensor.StructuredTensor.from_fields( 1494 fields={"r": array_ops.constant([[1, 2], [3, 4]])}, shape=[2]) 1495 1496 result = struct.partition_outer_dimension( 1497 row_partition.RowPartition.from_uniform_row_length(2, 2)) 1498 r = result.field_value("r") 1499 self.assertAllEqual(r, [[[1, 2], [3, 4]]]) 1500 1501 @parameterized.parameters([ 1502 # Simple example. 1503 ( 1504 {"a": 12, "b": 23}, 1505 {"a": 7}, 1506 ), 1507 # New field. 1508 ( 1509 {"a": 12}, 1510 {("b",): 13}, 1511 ), 1512 # Nested example. 1513 ( 1514 {"a": 12, "b": {"c": 23}}, 1515 {("b", "c"): 7}, 1516 ), 1517 # Multipe updates. 1518 ( 1519 {"a": 12, "b": {"c": 23}}, 1520 {"a": 3, ("b", "c"): 7}, 1521 ), 1522 # Deep updates. 1523 ( 1524 {"a": 12, "b": {"c": 23, "d": {"e": 11}}}, 1525 {("b", "c"): 7, ("b", "d", "e"): 13}, 1526 ), 1527 # Multiple updates to the same substructure. 1528 ( 1529 {"a": 12, "b": {"c": 23, "d": {"e": 11}}}, 1530 {("b", "c"): 7, ("b", "f"): 13}, 1531 ), 1532 # Scalar to non-scalar elements. Shape remains unchanged. 1533 ( 1534 {"a": 5}, 1535 {"a": ragged_factory_ops.constant_value([[51, 52], [61, 62, 63]])}, 1536 ), 1537 # Non-scalar element to scalar. 1538 ( 1539 {"c": {"a": [5, 3], "b": 2}}, 1540 {("c", "a"): 5}, 1541 ), 1542 # Rank-1 StructuredTensor: shape is preserved and an item is added. 1543 ( 1544 [{"a": 5}, {"a": 6}], 1545 {"a": [15, 16], "b": np.array([0.9, 1.1])}, 1546 ), 1547 # Non-scalar ragged elements, within a rank-2 StructuredTensor: elements 1548 # rows (inner dimensions) are changed, but StructuredTensor shape 1549 # (outer dimensions) are preserved. 1550 ( 1551 [[{"a": [5]}], [{"a": [3, 4]}, {"a": [8]}]], 1552 {"a": ragged_factory_ops.constant_value([[[50, 60]], [[30], []]])}, 1553 ), 1554 ]) # pyformat: disable 1555 def testWithUpdatesValues(self, pyval, updates): 1556 st = StructuredTensor.from_pyval(pyval) 1557 updated_st = st.with_updates(updates, validate=False) 1558 for key, value in updates.items(): 1559 got = updated_st.field_value(key) 1560 self.assertAllEqual( 1561 value, got, 1562 "Update failed: key={}, value={}, got={}".format(key, value, got)) 1563 1564 def testWithUpdatesFunctions(self): 1565 pyval = {"a": 12, "b": {"c": 23, "d": {"e": 11}}} 1566 st = StructuredTensor.from_pyval(pyval) 1567 st_updated = st.with_updates( 1568 { 1569 "a": lambda x: x + 1, 1570 ("b", "d", "e"): lambda x: x + 7 1571 }, validate=True) 1572 # Updated values. 1573 self.assertAllEqual(st_updated.field_value("a"), 13) 1574 self.assertAllEqual(st_updated.field_value(("b", "d", "e")), 18) 1575 # Unchanged value. 1576 self.assertAllEqual(st_updated.field_value(("b", "c")), 23) 1577 1578 def test_from_pyval_list_of_empty(self): 1579 """See b/183245576.""" 1580 st = structured_tensor.StructuredTensor.from_pyval([{}]) 1581 self.assertAllEqual([1], st.shape.as_list()) 1582 1583 def test_from_pyval_list_of_empty_three(self): 1584 """See b/183245576.""" 1585 st = structured_tensor.StructuredTensor.from_pyval([{}, {}, {}]) 1586 self.assertAllEqual([3], st.shape.as_list()) 1587 self.assertEmpty(st.field_names()) 1588 1589 def test_from_pyval_deep_list_of_empty(self): 1590 """See b/183245576.""" 1591 st = structured_tensor.StructuredTensor.from_pyval([[{ 1592 "a": {}, 1593 "b": [3, 4] 1594 }, { 1595 "a": {}, 1596 "b": [5] 1597 }], [{ 1598 "a": {}, 1599 "b": [7, 8, 9] 1600 }]]) 1601 self.assertAllEqual(2, st.rank) 1602 self.assertEqual(2, st.shape[0]) 1603 self.assertEmpty(st.field_value("a").field_names()) 1604 1605 def testWithUpdatesChecks(self): 1606 pyval = {"a": 12, "b": {"c": 23, "d": {"e": 11}}} 1607 st = StructuredTensor.from_pyval(pyval) 1608 1609 # Try to set non-existant sub-structure. 1610 with self.assertRaisesRegex( 1611 ValueError, r"cannot create new sub-field.*\('b', 'x'\).*is not set"): 1612 st.with_updates({("b", "x", "e"): 5}) 1613 1614 # Try to set with path to a non-sub-structure. 1615 with self.assertRaisesRegex( 1616 ValueError, r"cannot create new sub-field.*\('b', 'c'\).*is not a " 1617 r"`StructuredTensor`"): 1618 st.with_updates({("b", "c", "e"): 5}) 1619 1620 # Try to apply function to non-existing value. 1621 with self.assertRaisesRegex( 1622 ValueError, r"cannot update.*\('b', 'd', 'x'\).*does not already " 1623 r"exist"): 1624 st.with_updates({("b", "d", "x"): lambda x: x + 1}) 1625 1626 # Empty names not allowed. 1627 with self.assertRaisesRegex(ValueError, r"does not allow empty names"): 1628 st.with_updates({(): lambda x: x + 1}) 1629 with self.assertRaisesRegex(ValueError, r"does not allow empty names"): 1630 st.with_updates({("b", ""): lambda x: x + 1}) 1631 1632 # Parent and child nodes cannot be updated simultaneously. 1633 with self.assertRaisesRegex( 1634 ValueError, r"does not allow both parent and child nodes.*" 1635 r"parent=\('b'.*child=\('b', 'd'"): 1636 st.with_updates({("b", "d"): lambda x: x + 1, "a": 3, "b": 10}) 1637 1638 # Invalid shape change. 1639 with self.assertRaisesRegex( 1640 ValueError, 1641 r"`StructuredTensor.with_updates` failed for field \('c',\)"): 1642 st_with_shape = StructuredTensor.from_pyval([[{ 1643 "c": { 1644 "a": 5, 1645 "b": 2 1646 } 1647 }], [{ 1648 "c": { 1649 "a": 3, 1650 "b": 1 1651 } 1652 }, { 1653 "c": { 1654 "a": 8, 1655 "b": 18 1656 } 1657 }]]) 1658 st_with_shape.with_updates({("c", "a"): 3}) 1659 1660 def testWithUpdatesDelete(self): 1661 pyval = {"a": 12, "b": {"c": 23, "d": {"e": 11}}} 1662 st = StructuredTensor.from_pyval(pyval) 1663 updated_st = st.with_updates({("b", "c"): None}, validate=True) 1664 self.assertNotIn("c", updated_st.field_value("b").field_names()) 1665 with self.assertRaisesRegex(ValueError, 1666 r"cannot delete.*\('b', 'x'\).*not present"): 1667 st.with_updates({("b", "x"): None}, validate=True) 1668 with self.assertRaisesRegex(ValueError, 1669 r"cannot delete.*\'x'.*not present"): 1670 st.with_updates({"x": None}, validate=False) 1671 1672 # Test that nrows() and rowpartitions() is preserved after removal. 1673 pyval = [[{"a": 1}, {"a": 2}], [{"a": 3}]] 1674 st = StructuredTensor.from_pyval(pyval) 1675 self.assertLen(st.row_partitions, 1) 1676 self.assertAllEqual(st.nrows(), 2) 1677 self.assertAllEqual(st.row_partitions[0].row_lengths(), [2, 1]) 1678 updated_st = st.with_updates({("a",): None}, validate=True) 1679 self.assertLen(updated_st.row_partitions, 1) 1680 self.assertAllEqual(updated_st.nrows(), 2) 1681 self.assertAllEqual(updated_st.row_partitions[0].row_lengths(), [2, 1]) 1682 1683 # Test that it works also for rank-1 and rank-0 empty results. 1684 pyval = [{"a": 1}, {"a": 2}] 1685 st = StructuredTensor.from_pyval(pyval) 1686 self.assertEqual(st.rank, 1) 1687 updated_st = st.with_updates({("a",): None}, validate=True) 1688 self.assertEqual(updated_st.rank, 1) 1689 1690 # assertEqual won't work because nrows() returns a tensor, and 1691 # assertEqual doesn't do the magic to convert them to numbers in a 1692 # way that works in eager/non-eager mode. 1693 self.assertAllEqual(updated_st.nrows(), 2) 1694 pyval = {"a": [0, 1]} 1695 st = StructuredTensor.from_pyval(pyval) 1696 self.assertEqual(st.rank, 0) 1697 updated_st = st.with_updates({("a",): None}, validate=True) 1698 self.assertEqual(updated_st.rank, 0) 1699 self.assertFalse(updated_st.row_partitions) 1700 self.assertIsNone(updated_st.nrows()) 1701 1702 def test_from_pyval_deep_row_partitions(self): 1703 """See b/179195750.""" 1704 st = structured_tensor.StructuredTensor.from_pyval([{ 1705 "foo": [{ 1706 "bar": [{ 1707 "baz": [b"FW"] 1708 }] 1709 }] 1710 }]) 1711 st2 = st.field_value(("foo", "bar")) 1712 self.assertLen(st2.row_partitions, st2.rank - 1) 1713 1714 def test_from_fields_deep_row_partitions(self): 1715 """Test a field with its own row_partition. See b/179195750.""" 1716 st = structured_tensor.StructuredTensor.from_pyval([[[{"baz": [b"FW"]}]]]) 1717 self.assertLen(st.row_partitions, st.rank - 1) 1718 st2 = structured_tensor.StructuredTensor.from_fields( 1719 fields={"bar": st}, shape=(None, None), validate=False) 1720 st3 = st2.field_value("bar") 1721 self.assertLen(st3.row_partitions, st3.rank - 1) 1722 1723 def test_structured_tensor_spec_shape_property(self): 1724 spec = StructuredTensor.Spec._from_shape(DynamicRaggedShape.Spec( 1725 row_partitions=[], 1726 static_inner_shape=[1, 2], 1727 dtype=dtypes.int64)) 1728 self.assertEqual(spec.shape.as_list(), [1, 2]) 1729 spec = StructuredTensor.Spec._from_shape(DynamicRaggedShape.Spec( 1730 row_partitions=[], 1731 static_inner_shape=[None], 1732 dtype=dtypes.int64)) 1733 self.assertEqual(spec.shape.as_list(), [None]) 1734 1735 def test_dynamic_ragged_shape_init_vector(self): 1736 x = constant_op.constant([1, 2, 3, 4]) 1737 y = constant_op.constant([[1, 2], [3, 4], [5, 6], [7, 8]]) 1738 fields = {"x": x, "y": y} 1739 nrows = constant_op.constant(4) 1740 shape = tensor_shape.TensorShape((4,)) 1741 row_partitions = () 1742 rs = structured_tensor_dynamic._dynamic_ragged_shape_init( 1743 fields, shape, nrows, row_partitions) 1744 self.assertEqual( 1745 repr(rs._to_tensor_shape()), repr(tensor_shape.TensorShape((4,)))) 1746 1747 def test_dynamic_ragged_shape_init_scalar(self): 1748 x = constant_op.constant([1, 2, 3, 4]) 1749 y = constant_op.constant([[1, 2], [3, 4], [5, 6], [7, 8]]) 1750 fields = {"x": x, "y": y} 1751 nrows = None 1752 shape = tensor_shape.TensorShape(()) 1753 row_partitions = () 1754 1755 rs = structured_tensor_dynamic._dynamic_ragged_shape_init( 1756 fields, shape, nrows, row_partitions) 1757 self.assertEqual( 1758 repr(rs._to_tensor_shape()), repr(tensor_shape.TensorShape(()))) 1759 1760 def test_dynamic_ragged_shape_init_ragged(self): 1761 x = ragged_factory_ops.constant_value([[1, 2, 3], [4]]) 1762 fields = {"x": x} 1763 nrows = constant_op.constant(2, dtype=dtypes.int64) 1764 shape = tensor_shape.TensorShape([2, None]) 1765 row_partitions = tuple(x._nested_row_partitions) 1766 rs = structured_tensor_dynamic._dynamic_ragged_shape_init( 1767 fields, shape, nrows, row_partitions) 1768 self.assertEqual( 1769 repr(rs._to_tensor_shape()), repr(tensor_shape.TensorShape((2, None)))) 1770 1771 1772if __name__ == "__main__": 1773 googletest.main() 1774