# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.framework.extension_type."""

import contextlib
import copy
import pickle
import tempfile
import typing

from absl.testing import parameterized
import typing_extensions

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import extension_type_field
from tensorflow.python.framework import immutable_dict
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.framework.type_utils import fulltypes_for_flat_tensors
from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.saving import save as keras_save
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


# Shorthands for the parameter kinds used when building expected signatures.
POSITIONAL_OR_KEYWORD = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
KEYWORD_ONLY = tf_inspect.Parameter.KEYWORD_ONLY


class MaskedTensorV1(extension_type.ExtensionType):
  """Example subclass of ExtensionType, used for testing."""
  values: ops.Tensor
  mask: ops.Tensor


class MaskedTensorV2(extension_type.ExtensionType):
  """Example subclass of ExtensionType, used for testing.

  This version adds methods, classmethod, staticmethod, and properties, and
  customizes `__repr__` and `__validate__`.  It also adds a `__name__` field,
  which enables serialization.
  """
  __name__ = 'tf.test.MaskedTensorV2'

  values: ops.Tensor
  mask: ops.Tensor

  def __repr__(self):
    # Only tensors exposing `.numpy()` (eager tensors) can have their contents
    # rendered; otherwise fall back to the default ExtensionType repr.
    if hasattr(self.values, 'numpy') and hasattr(self.mask, 'numpy'):
      return '<MaskedTensorV2 %s>' % _masked_array_repr(self.values.numpy(),
                                                        self.mask.numpy())
    else:
      return super(MaskedTensorV2, self).__repr__()

  @property
  def shape(self):
    return self.values.shape

  @property
  def dtype(self):
    return self.values.dtype

  @classmethod
  def from_full_tensor(cls, values):
    """Returns a MaskedTensorV2 wrapping `values` with an all-True mask."""
    return cls(values, array_ops.ones_like(values, dtype=dtypes.bool))

  # A dummy example to test support of staticmethod
  @staticmethod
  def doc_link():
    return 'http://example.com/masked_tensor'

  def __validate__(self):
    # `values` and `mask` must have compatible shapes.
    self.values.shape.assert_is_compatible_with(self.mask.shape)

  def with_default(self, default):
    """Returns `values`, with masked-out positions replaced by `default`."""
    return array_ops.where_v2(self.mask, self.values, default)

  # `x + y` and `x - y` call math_ops.add / math_ops.subtract on the operands.
  __add__ = math_ops.add
  __sub__ = math_ops.subtract


def _masked_array_repr(values, mask):
  """Returns a string representation for a masked numpy array.

  Masked-out elements (where `mask` is false) are rendered as `_`.

  Args:
    values: A numpy array (or numpy scalar) of values.
    mask: A boolean numpy array (or numpy scalar) with the same shape as
      `values`.

  Returns:
    A string such as `'[1, 2, _, 4]'`, with nested brackets for arrays of
    rank > 1.  A rank-0 input is rendered as a single unbracketed element.
  """
  if not values.shape:
    # Rank-0 input (e.g. the `.numpy()` of a scalar tensor): `len()` is not
    # defined for unsized arrays, so render the single element directly.
    return repr(values) if mask else '_'
  assert len(values) == len(mask)
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [_masked_array_repr(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)


class MaskedTensorV3(extension_type.BatchableExtensionType):
  """Example subclass of ExtensionType, used for testing.

  This version adds Keras required properties to MaskedTensor and its Spec
  class, to test Keras integration.
  """
  __name__ = 'tf.test.MaskedTensorV3.Spec'

  values: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor]
  mask: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor]

  def __init__(self, values, mask):
    if isinstance(values, ragged_tensor.RaggedTensor):
      # Ragged values require a ragged boolean mask; no conversion is done.
      assert isinstance(mask, ragged_tensor.RaggedTensor)
      assert mask.dtype == dtypes.bool
    else:
      values = ops.convert_to_tensor(values)
      mask = ops.convert_to_tensor(mask, dtypes.bool)
    self.values = values
    self.mask = mask

  # Required by assert_input_compatibility in keras/engine/input_spec.py
  @property
  def shape(self):
    return self.values.shape

  @property
  def dtype(self):
    return self.values.dtype

  class Spec:

    # Required by KerasTensor.shape in keras/engine/keras_tensor.py
    @property
    def _shape(self):
      return self.values._shape


# Type whose annotations use string forward references (to itself and to
# ForwardRefB, which is defined later in the file).
class ForwardRefA(extension_type.ExtensionType):
  x: typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], ...]
166 y: 'ForwardRefB' 167 168 169class ForwardRefB(extension_type.ExtensionType): 170 z: 'ForwardRefB' 171 n: ops.Tensor 172 173 174class ExtensionTypeWithTensorDefault(extension_type.ExtensionType): 175 x: ops.Tensor = 5 176 y: ops.Tensor = ['a', 'b', 'c'] 177 178 179@test_util.run_all_in_graph_and_eager_modes 180class ExtensionTypeTest(test_util.TensorFlowTestCase, parameterized.TestCase): 181 182 def testAttributeAccessors(self): 183 mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 184 mt2 = extension_type.pack(mt1) 185 186 for mt in [mt1, mt2]: 187 self.assertIsInstance(mt.values, ops.Tensor) 188 self.assertAllEqual(mt.values, [1, 2, 3, 4]) 189 self.assertIsInstance(mt.mask, ops.Tensor) 190 self.assertAllEqual(mt.mask, [True, True, False, True]) 191 192 def testAttributesAreImmutable(self): 193 mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 194 mt2 = extension_type.pack(mt1) 195 196 for mt in [mt1, mt2]: 197 with self.assertRaisesRegex( 198 AttributeError, 199 'Cannot mutate attribute `score` outside the custom constructor of ExtensionType' 200 ): 201 mt.score = 12 202 with self.assertRaisesRegex( 203 AttributeError, 204 'Cannot mutate attribute `values` outside the custom constructor of ExtensionType' 205 ): 206 mt.values = constant_op.constant([4, 3, 2, 1]) 207 with self.assertRaisesRegex( 208 AttributeError, 209 'Cannot mutate attribute `values` outside the custom constructor of ExtensionType' 210 ): 211 del mt.values 212 213 def testClassAndStaticMethod(self): 214 mt = MaskedTensorV2.from_full_tensor([1, 2, 3, 4]) 215 self.assertAllEqual(mt.mask, [True, True, True, True]) 216 self.assertEqual(mt.doc_link(), 'http://example.com/masked_tensor') 217 218 def testRepr(self): 219 values = constant_op.constant([1, 2, 3, 4]) 220 mask = constant_op.constant([True, True, False, True]) 221 mt = MaskedTensorV1(values, mask) 222 expected = f'MaskedTensorV1(values={values!r}, mask={mask!r})' 223 self.assertEqual(expected, repr(mt)) 224 225 def 
testEagerRepr(self): 226 values = constant_op.constant([1, 2, 3, 4]) 227 mask = constant_op.constant([True, True, False, True]) 228 mt = MaskedTensorV2(values, mask) 229 if context.executing_eagerly(): 230 expected = '<MaskedTensorV2 [1, 2, _, 4]>' 231 else: 232 expected = f'MaskedTensorV2(values={values!r}, mask={mask!r})' 233 234 self.assertEqual(expected, repr(mt)) 235 self.assertEqual(expected, repr(mt)) 236 237 def testConstructorSignature(self): 238 239 class MyType(extension_type.ExtensionType): 240 x: ops.Tensor 241 y: ops.Tensor 242 z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3] 243 244 expected_parameters = [ 245 tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), 246 tf_inspect.Parameter('x', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), 247 tf_inspect.Parameter('y', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), 248 tf_inspect.Parameter( 249 'z', 250 POSITIONAL_OR_KEYWORD, 251 annotation=typing.Tuple[typing.Union[int, str], ...], 252 default=(1, 'two', 3)), 253 ] 254 expected_sig = tf_inspect.Signature( 255 expected_parameters, return_annotation=MyType) 256 self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__)) 257 258 def testConstructorSignatureWithKeywordOnlyArgs(self): 259 260 class MyType(extension_type.ExtensionType): 261 a: int 262 b: str = 'Hello world' 263 c: ops.Tensor 264 265 expected_parameters = [ 266 tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), 267 tf_inspect.Parameter('a', POSITIONAL_OR_KEYWORD, annotation=int), 268 tf_inspect.Parameter( 269 'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world'), 270 tf_inspect.Parameter('c', KEYWORD_ONLY, annotation=ops.Tensor), 271 ] 272 expected_sig = tf_inspect.Signature( 273 expected_parameters, return_annotation=MyType) 274 self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__)) 275 276 def testConstructorSignatureWithDefaultForTensorField(self): 277 a = ExtensionTypeWithTensorDefault() 278 279 # Check that the default values were 
*not* converted to Tensors: 280 sig = tf_inspect.signature(ExtensionTypeWithTensorDefault.__init__) 281 self.assertIsInstance(sig.parameters['x'].default, int) 282 self.assertIsInstance(sig.parameters['y'].default, list) 283 284 # The following would fail with "RuntimeError: Attempting to capture an 285 # EagerTensor without building a function" if we converted the default 286 # value to a Tensor when we built the type. 287 self.assertAllEqual(a.x + constant_op.constant(3), 8) 288 289 def testConstructorSignatureWithAnnotatedTensorField(self): 290 291 class MyType(extension_type.ExtensionType): 292 a: typing_extensions.Annotated[ops.Tensor, 'metadata'] 293 b: typing_extensions.Annotated[str, 'metadata'] = 'Hello world' 294 c: typing.Optional[typing_extensions.Annotated[int, 'metadata']] = None 295 296 expected_parameters = [ 297 tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), 298 tf_inspect.Parameter('a', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor), 299 tf_inspect.Parameter( 300 'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world'), 301 tf_inspect.Parameter( 302 'c', 303 POSITIONAL_OR_KEYWORD, 304 annotation=typing.Optional[int], 305 default=None), 306 ] 307 expected_sig = tf_inspect.Signature( 308 expected_parameters, return_annotation=MyType) 309 self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__)) 310 311 def testEmptyType(self): 312 313 class EmptyType(extension_type.ExtensionType): 314 pass 315 316 self.assertEmpty(EmptyType._tf_extension_type_fields()) 317 x = EmptyType() 318 self.assertEqual( 319 repr(x), 'ExtensionTypeTest.testEmptyType.<locals>.EmptyType()') 320 321 def testCustomConstrutor(self): 322 323 class SummarizedTensor(extension_type.ExtensionType): 324 values: ops.Tensor 325 mean: ops.Tensor 326 max: ops.Tensor 327 328 def __init__(self, values): 329 self.values = ops.convert_to_tensor(values) 330 self.mean = math_ops.reduce_mean(values) 331 self.max = math_ops.reduce_max(values) 332 333 x = 
SummarizedTensor([[1.0, 2, 3], [4, 5, 6]]) 334 self.assertAllEqual(x.values, [[1.0, 2, 3], [4, 5, 6]]) 335 self.assertAllEqual(x.mean, 3.5) 336 self.assertAllEqual(x.max, 6) 337 338 class Node(extension_type.ExtensionType): 339 x: ops.Tensor 340 y: typing.Optional[str] = None 341 children: typing.Tuple['ExtensionTypeTest.Node', ...] = () 342 343 def testConstructorWithDefaultValues(self): 344 a = ExtensionTypeTest.Node(5) 345 self.assertAllEqual(a.x, 5) 346 self.assertIsNone(a.y) 347 self.assertEqual(a.children, ()) 348 349 b = ExtensionTypeTest.Node(6, 'blue') 350 self.assertAllEqual(b.x, 6) 351 self.assertEqual(b.y, 'blue') 352 self.assertEqual(b.children, ()) 353 354 c = ExtensionTypeTest.Node(7, children=(a, b)) 355 self.assertAllEqual(c.x, 7) 356 self.assertIsNone(c.y) 357 self.assertEqual(c.children, (a, b)) 358 359 def testCustomConstrutorCantMutateNestedValues(self): 360 361 class Foo(extension_type.ExtensionType): 362 x: int 363 364 class Bar(extension_type.ExtensionType): 365 foo: Foo 366 367 def __init__(self, foo): 368 foo.x = 33 # This raises an exception 369 370 with self.assertRaisesRegex( 371 AttributeError, 372 'Cannot mutate attribute `x` outside the custom constructor of ExtensionType' 373 ): 374 Bar(Foo(12)) 375 376 def testCustomValidate(self): 377 378 class AlignedTensors(extension_type.ExtensionType): 379 x: ops.Tensor 380 y: ops.Tensor 381 382 def __validate__(self): 383 self.x.shape.assert_is_compatible_with(self.y.shape) 384 385 aligned = AlignedTensors([1, 2, 3], ['a', 'b', 'c']) 386 self.assertAllEqual(aligned.x, [1, 2, 3]) 387 self.assertAllEqual(aligned.y, [b'a', b'b', b'c']) 388 389 with self.assertRaises(ValueError): 390 AlignedTensors([1, 2, 3], ['a', 'b', 'c', 'd']) 391 392 def testEquals(self): 393 394 class MyType(extension_type.ExtensionType): 395 values: ops.Tensor 396 score: ops.Tensor 397 flavor: str 398 399 x1 = MyType([1, 2], 8, 'blue') 400 x2 = MyType([1, 2], 8, 'blue') 401 y = MyType([1, 2], 8, 'red') 402 z = MyType([1, 
2], 7, 'blue') 403 self.assertAllEqual(x1 == x2, True) 404 self.assertAllEqual(x1 != x2, False) 405 self.assertAllEqual(x1 == y, False) 406 self.assertAllEqual(x1 != y, True) 407 self.assertAllEqual(x1 == z, False) 408 self.assertAllEqual(y == z, False) 409 410 # These are not equal, even though their values are broadcast-compatible 411 # and elements are all equal when we broadcast. Shapes must match. 412 a = MyType([1, 1, 1, 1], 0, 'x') 413 b = MyType([[1, 1, 1, 1]], 0, 'x') 414 c = MyType([[1, 1], [1, 1]], 0, 'x') 415 self.assertAllEqual(a == b, False) 416 self.assertAllEqual(a == c, False) 417 self.assertAllEqual(b == c, False) 418 419 # Test with unknown shapes (executes a different codepath). 420 a_ph = replace_tensors_with_placeholders(a) 421 b_ph = replace_tensors_with_placeholders(b) 422 c_ph = replace_tensors_with_placeholders(c) 423 self.assertAllEqual(a_ph == b_ph, False) 424 self.assertAllEqual(a_ph == c_ph, False) 425 self.assertAllEqual(b_ph == c_ph, False) 426 427 def testPassIntoTfFunction(self): 428 429 @def_function.function 430 def fn(x): 431 return x.with_default(99) 432 433 mt = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 434 self.assertAllEqual([1, 2, 99, 4], fn(mt)) 435 self.assertAllEqual([1, 2, 99, 4], fn(extension_type.pack(mt))) 436 437 def testReturnFromTfFunction(self): 438 439 @def_function.function 440 def mask_neg_values(x): 441 return MaskedTensorV2(x, x > 0) 442 443 @def_function.function 444 def mask_neg_values_packed(x): 445 return extension_type.pack(MaskedTensorV2(x, x > 0)) 446 447 expected = MaskedTensorV2([5, 8, -3, 9], [True, True, False, True]) 448 449 actual1 = mask_neg_values(constant_op.constant([5, 8, -3, 9])) 450 self.assertIsInstance(actual1, MaskedTensorV2) 451 self.assertAllEqual(expected.values, actual1.values) 452 self.assertAllEqual(expected.mask, actual1.mask) 453 454 actual2 = mask_neg_values_packed(constant_op.constant([5, 8, -3, 9])) 455 self.assertIsInstance(actual2, MaskedTensorV2) 456 
self.assertTrue(extension_type.is_packed(actual2)) 457 self.assertAllEqual(expected.values, actual2.values) 458 self.assertAllEqual(expected.mask, actual2.mask) 459 460 def testCaptureByTfFunction(self): 461 x = MaskedTensorV2( 462 values=[[1, 2, 3], [4, 5, 6]], 463 mask=[[True, True, True], [True, False, True]]) 464 465 @def_function.function 466 def add_to_x(y): 467 return MaskedTensorV2(x.values + y.values, x.mask & y.mask) 468 469 actual = add_to_x(MaskedTensorV2([10, 20, 30], [False, True, True])) 470 expected = MaskedTensorV2( 471 values=[[11, 22, 33], [14, 25, 36]], 472 mask=[[False, True, True], [False, False, True]]) 473 self.assertIsInstance(actual, MaskedTensorV2) 474 self.assertAllEqual(expected.values, actual.values) 475 self.assertAllEqual(expected.mask, actual.mask) 476 477 def testTfFunctionArgMutationError(self): 478 479 @def_function.function 480 def fn_with_side_effect(mts): 481 mts.append(MaskedTensorV1(mts[0].values * 2, mts[0].mask)) 482 483 with self.assertRaisesRegex(ValueError, 'should not modify'): 484 fn_with_side_effect([MaskedTensorV1([10, 20, 30], [False, True, True])]) 485 486 def testNestPackUnpack(self): 487 488 class CandyStore(extension_type.ExtensionType): 489 name: ops.Tensor 490 prices: typing.Mapping[str, ops.Tensor] 491 492 store = CandyStore('Yum', {'gum': [0.42, 0.48], 'chocolate': [0.83, 1.02]}) 493 components = nest.flatten(store, expand_composites=True) 494 repacked_1 = nest.pack_sequence_as( 495 store, components, expand_composites=True) 496 repacked_2 = nest.pack_sequence_as( 497 store._type_spec, components, expand_composites=True) 498 499 # Note: dicts get sorted by key. 
500 self.assertLen(components, 3) 501 self.assertAllEqual(components[0], b'Yum') 502 self.assertAllClose(components[1], [0.83, 1.02]) 503 self.assertAllClose(components[2], [0.42, 0.48]) 504 505 for repacked in [repacked_1, repacked_2]: 506 self.assertAllEqual(repacked.name, b'Yum') 507 self.assertAllClose(repacked.prices['gum'], [0.42, 0.48]) 508 self.assertAllClose(repacked.prices['chocolate'], [0.83, 1.02]) 509 510 def testSimpleCond(self): 511 x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 512 y = MaskedTensorV1([5, 6, 7, 8], [False, True, True, False]) 513 514 x_2 = control_flow_ops.cond( 515 constant_op.constant(True), lambda: x, lambda: y) 516 y_2 = control_flow_ops.cond( 517 constant_op.constant(False), lambda: x, lambda: y) 518 519 self.assertAllEqual(x.values, x_2.values) 520 self.assertAllEqual(x.mask, x_2.mask) 521 self.assertAllEqual(y.values, y_2.values) 522 self.assertAllEqual(y.mask, y_2.mask) 523 524 def testComplexCond(self): 525 mt = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 526 527 def true_fn(): 528 return MaskedTensorV1( 529 array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3) 530 531 def false_fn(): 532 return MaskedTensorV1( 533 array_ops.where_v2(mt.mask, 100, mt.values * 2), 534 math_ops.logical_not(mt.mask)) 535 536 x = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn) 537 y = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn) 538 539 self.assertAllEqual(x.values, [1, -1, 3, -1]) 540 self.assertAllEqual(x.mask, [False, False, False, True]) 541 self.assertAllEqual(y.values, [100, 4, 100, 8]) 542 self.assertAllEqual(y.mask, [False, True, False, True]) 543 544 def testCondAutograph(self): 545 546 @def_function.function 547 def fn(mt): 548 if mt.values[3] > 3: 549 return MaskedTensorV1( 550 array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3) 551 else: 552 return MaskedTensorV1( 553 array_ops.where_v2(mt.mask, 100, mt.values * 2), not mt.mask) 554 555 x = 
fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])) 556 self.assertAllEqual(x.values, [1, -1, 3, -1]) 557 self.assertAllEqual(x.mask, [False, False, False, True]) 558 559 def testCondTypeMismatch(self): 560 if context.executing_eagerly: 561 # In eager mode, tf.cond eagerly runs either true_fn or false_fn, and 562 # ignores the other one; so it doesn't detect any type mismatches 563 # between the two outcomes. (See _eager_cond_implementation in 564 # control_flow_ops.py.) 565 return 566 567 a = lambda: MaskedTensorV1([1, 2, 3], [True, True, False]) 568 b = lambda: MaskedTensorV1(['a', 'b', 'c'], [False, True, True]) 569 c = lambda: MaskedTensorV2([4, 5, 6], [True, True, False]) 570 d = lambda: constant_op.constant([7, 8, 9]) 571 572 with self.assertRaisesRegex( 573 ValueError, 574 'Incompatible return values of true_fn and false_fn: The two ' 575 "structures don't have the same nested structure"): 576 control_flow_ops.cond(constant_op.constant(True), a, b) 577 with self.assertRaisesRegex( 578 TypeError, 'Incompatible return types of true_fn and false_fn: The two ' 579 "structures don't have the same nested structure"): 580 control_flow_ops.cond(constant_op.constant(True), a, c) 581 with self.assertRaisesRegex( 582 ValueError, 583 'Incompatible return values of true_fn and false_fn: The two ' 584 "structures don't have the same nested structure"): 585 control_flow_ops.cond(constant_op.constant(True), a, d) 586 587 def testCondPacked(self): 588 x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 589 y = MaskedTensorV2([5, 6, 7, 8], [False, True, True, False]) 590 x = extension_type.pack(x) 591 y = extension_type.pack(y) 592 593 x_2 = control_flow_ops.cond( 594 constant_op.constant(True), lambda: x, lambda: y) 595 y_2 = control_flow_ops.cond( 596 constant_op.constant(False), lambda: x, lambda: y) 597 598 self.assertAllEqual(x.values, x_2.values) 599 self.assertAllEqual(x.mask, x_2.mask) 600 self.assertAllEqual(y.values, y_2.values) 601 
self.assertAllEqual(y.mask, y_2.mask) 602 603 a = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 604 b = extension_type.pack(a) 605 b = control_flow_ops.cond( 606 constant_op.constant(True), lambda: array_ops.size(a.mask), 607 lambda: array_ops.size(a.values)) 608 self.assertAllEqual(b, 4) 609 610 # Note: the following example would fail (with `Retval[0] does not have a 611 # value`) if `ExtensionType.__getattr__` cached the results of unpacking 612 # the value. See the comment in `ExtensionType.__getattr__` for details. 613 c = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 614 c = extension_type.pack(c) 615 d = control_flow_ops.cond( 616 constant_op.constant(False), lambda: array_ops.size(c.mask), 617 lambda: array_ops.size(c.values)) 618 self.assertAllEqual(d, 4) 619 620 def testWhileLoop(self): 621 x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 622 623 cond = lambda i, x: i < 10 624 body = lambda i, x: (i + 1, MaskedTensorV1(x.values * 2, x.mask)) 625 _, y = control_flow_ops.while_loop_v2(cond, body, [0, x]) 626 627 self.assertIsInstance(y, MaskedTensorV1) 628 self.assertAllEqual(y.values, [1024, 2048, 3072, 4096]) 629 self.assertAllEqual(y.mask, [True, False, True, False]) 630 631 def testWhileLoopAutograph(self): 632 633 @def_function.function 634 def fn(x, n): 635 for _ in math_ops.range(n): 636 x = MaskedTensorV1(x.values * 2, x.mask) 637 return x 638 639 y = fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]), 10) 640 self.assertIsInstance(y, MaskedTensorV1) 641 self.assertAllEqual(y.values, [1024, 2048, 3072, 4096]) 642 self.assertAllEqual(y.mask, [True, False, True, False]) 643 644 def testWhileLoopTypeMismatch(self): 645 x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 646 647 cond = lambda i, x: i < 10 648 649 def body(i, x): 650 if isinstance(x, MaskedTensorV1): 651 return x.values * 2 652 else: 653 return MaskedTensorV1(x, x > i) 654 655 with self.assertRaisesRegex( 656 ValueError, "The two 
structures don't have the same nested structure"): 657 control_flow_ops.while_loop_v2(cond, body, [0, x]) 658 659 def testWhileLoopPacked(self): 660 x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 661 x = extension_type.pack(x) 662 cond = lambda i, x: i < 10 663 664 def body(i, x): 665 return i + 1, extension_type.pack(MaskedTensorV2(x.values * 2, x.mask)) 666 667 _, y = control_flow_ops.while_loop_v2(cond, body, [0, x]) 668 self.assertIsInstance(y, MaskedTensorV2) 669 self.assertAllEqual(y.values, [1024, 2048, 3072, 4096]) 670 self.assertAllEqual(y.mask, [True, False, True, False]) 671 672 def testNestedFields(self): 673 PossiblyRaggedTensor = typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] 674 ToyFeatures = typing.Mapping[str, PossiblyRaggedTensor] 675 676 class ToyInfo(extension_type.ExtensionType): 677 version: str 678 toys: typing.Tuple[typing.Tuple[str, ops.Tensor, ToyFeatures], ...] 679 boxes: typing.Mapping[str, ops.Tensor] 680 681 authors = [[b'A', b'Aardvark'], [b'Z', b'Zhook']] 682 toys = [('car', 1.0, { 683 'size': [8, 3, 2], 684 'color': [0.3, 0.2, 0.8] 685 }), ('book', 3.7, { 686 'authors': ragged_factory_ops.constant(authors) 687 })] 688 boxes = {'green': ['car'], 'blue': ['car', 'book', 'book']} 689 toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes) 690 691 self.assertEqual(toy_info.version, '1.0 alpha') 692 self.assertEqual(toy_info.toys[0][0], 'car') 693 self.assertIsInstance(toy_info.toys[0][1], ops.Tensor) 694 self.assertAllEqual(toy_info.toys[0][1], 1.0) 695 self.assertEqual(set(toy_info.toys[0][2].keys()), {'size', 'color'}) 696 self.assertIsInstance(toy_info.toys[0][2]['size'], ops.Tensor) 697 self.assertAllEqual(toy_info.toys[0][2]['size'], [8, 3, 2]) 698 self.assertIsInstance(toy_info.toys[1][2]['authors'], 699 ragged_tensor.RaggedTensor) 700 self.assertAllEqual(toy_info.toys[1][2]['authors'], authors) 701 self.assertAllEqual(toy_info.boxes['green'], [b'car']) 702 self.assertAllEqual(toy_info.boxes['blue'], 
['car', 'book', 'book']) 703 704 expected_repr = ( 705 r"ToyInfo\(version='1.0 alpha', toys=\(" 706 r"\('car', <tf.Tensor[^>]*>, ImmutableDict\(" 707 r"{'size': <tf.Tensor[^>]*>, 'color': <tf.Tensor[^>]*>}\)\), " 708 r"\('book', <tf.Tensor[^>]*>, ImmutableDict\(" 709 r"{'authors': (<tf.RaggedTensor[^>]*>|tf.RaggedTensor\(.*\))}\)\)\), " 710 r'boxes=ImmutableDict\(' 711 r"{'green': <tf.Tensor[^>]*>, 'blue': <tf.Tensor[^>]*>}\)\)") 712 713 self.assertRegex(repr(toy_info), expected_repr) 714 715 def testNestedExtensionTypes(self): 716 PossiblyMaskedTensor = typing.Union[ops.Tensor, MaskedTensorV1] 717 718 class Toy(extension_type.ExtensionType): 719 name: str 720 price: ops.Tensor 721 features: typing.Mapping[str, PossiblyMaskedTensor] 722 723 class Box(extension_type.ExtensionType): 724 contents: ops.Tensor 725 726 class ToyInfo(extension_type.ExtensionType): 727 version: str 728 toys: typing.Tuple[Toy, ...] 729 boxes: typing.Mapping[str, Box] 730 731 authors = MaskedTensorV1( 732 values=[[b'A', b'Quincy', b'Aardvark'], [b'Z', b'Zhook', b'']], 733 mask=[[True, True, True], [True, True, False]]) 734 toys = [ 735 Toy('car', 1.0, { 736 'size': [8, 3, 2], 737 'color': [0.3, 0.2, 0.8] 738 }), 739 Toy(name='book', price=3.7, features={'authors': authors}) 740 ] 741 boxes = { 742 'green': Box(['car']), 743 'blue': Box(contents=['car', 'book', 'book']) 744 } 745 toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes) 746 747 @def_function.function 748 def fn(info): 749 prices = [toy.price for toy in info.toys] 750 return math_ops.reduce_sum(array_ops.stack(prices)) 751 752 self.assertAllClose(fn(toy_info), 4.7) 753 754 def testNestedCustomConstructor(self): 755 756 class Toy(extension_type.ExtensionType): 757 name: str 758 price: ops.Tensor 759 760 def __init__(self, name, price, discount=0): 761 if discount: 762 name += ' (discounted)' 763 price *= (1 - discount) 764 self.name = name 765 self.price = price 766 767 class ToyBox(extension_type.ExtensionType): 768 
toys: typing.Tuple[Toy, ...] 769 770 def __init__(self, name_to_price, name_to_discount): 771 self.toys = [ 772 Toy(name, price, name_to_discount.get(name, 0)) 773 for (name, price) in name_to_price.items() 774 ] 775 776 toy_box = ToyBox({ 777 'car': 8.3, 778 'truck': 5.9, 779 'puzzle': 5.3, 780 'jacks': 2.8 781 }, { 782 'puzzle': .2, 783 'truck': .3 784 }) 785 self.assertLen(toy_box.toys, 4) 786 self.assertEqual( 787 set(toy.name for toy in toy_box.toys), 788 {'car', 'truck (discounted)', 'puzzle (discounted)', 'jacks'}) 789 790 def testExtensionTypeWithMathOperators(self): 791 792 def masked_add(x, y, name=None): 793 del name 794 if not isinstance(x, MaskedTensorV2) and isinstance(y, MaskedTensorV2): 795 return dispatch.OpDispatcher.NOT_SUPPORTED 796 return MaskedTensorV2(x.values + y.values, x.mask & y.mask) 797 798 with temporarily_add_dispatch(math_ops.add, MaskedTensorV2, masked_add): 799 x = MaskedTensorV2([[1, 2], [3, 4]], [[True, False], [True, True]]) 800 y = MaskedTensorV2([[3, 4], [5, 6]], [[True, True], [False, True]]) 801 z = x + y 802 self.assertAllEqual(z.values, [[4, 6], [8, 10]]) 803 self.assertAllEqual(z.mask, [[True, False], [False, True]]) 804 805 def testGetExtensionTypeFields(self): 806 807 # Can be called on a type or an instance: 808 fields_1 = MaskedTensorV1._tf_extension_type_fields() 809 fields_2 = MaskedTensorV1([0], [True])._tf_extension_type_fields() 810 811 for fields in [fields_1, fields_2]: 812 self.assertLen(fields, 2) 813 self.assertEqual(fields[0].name, 'values') 814 self.assertEqual(fields[0].value_type, ops.Tensor) 815 self.assertEqual(fields[0].default, fields[0].NO_DEFAULT) 816 self.assertEqual(fields[1].name, 'mask') 817 self.assertEqual(fields[1].value_type, ops.Tensor) 818 self.assertEqual(fields[1].default, fields[0].NO_DEFAULT) 819 820 def testHasExtensionTypeField(self): 821 822 self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('values')) 823 
# NOTE: this span opens with the tail of a test method whose `def` line
# precedes it; those statements are reproduced unchanged.
    self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('mask'))
    self.assertFalse(MaskedTensorV1._tf_extension_type_has_field('labels'))

    # The same field queries are also made on an instance.
    mt = MaskedTensorV1([0], [True])
    self.assertTrue(mt._tf_extension_type_has_field('values'))
    self.assertTrue(mt._tf_extension_type_has_field('mask'))
    self.assertFalse(mt._tf_extension_type_has_field('labels'))

  def testForwardReferences(self):
    # Fields declared with string forward references are resolved to the
    # actual classes in `_tf_extension_type_fields()`, while the generated
    # `__init__` signature keeps the original string annotations.
    A, B = ForwardRefA, ForwardRefB

    self.assertEqual(A._tf_extension_type_fields(),
                     (extension_type_field.ExtensionTypeField(
                         'x', typing.Tuple[typing.Union[A, B], ...]),
                      extension_type_field.ExtensionTypeField('y', B)))
    self.assertEqual(B._tf_extension_type_fields(),
                     (extension_type_field.ExtensionTypeField('z', B),
                      extension_type_field.ExtensionTypeField('n', ops.Tensor)))

    # Check the signature.
    expected_parameters = [
        tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
        tf_inspect.Parameter(
            'x',
            POSITIONAL_OR_KEYWORD,
            annotation=typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'],
                                    ...]),
        tf_inspect.Parameter(
            'y', POSITIONAL_OR_KEYWORD, annotation='ForwardRefB'),
    ]
    expected_sig = tf_inspect.Signature(
        expected_parameters, return_annotation=A)
    self.assertEqual(tf_inspect.signature(A.__init__), expected_sig)

  def testUnresolvedForwardReference(self):
    # Defining a class whose annotation names a nonexistent type does not
    # raise by itself; constructing an instance raises TypeError.

    class Broken(extension_type.ExtensionType):
      x: 'Cra'  # note: intentional typo for Car.

    class Car(extension_type.ExtensionType):
      speed: float

    with self.assertRaises(TypeError):
      Broken(x=Car(3.8))

  def testUnsupportedAnnotations(self):
    # Unsupported field annotations are rejected with TypeError while the
    # class statement itself is executing.
    with self.assertRaisesRegex(
        TypeError, "In field 'values': Unsupported type annotation"):

      class MyType1(extension_type.ExtensionType):  # pylint: disable=unused-variable
        values: typing.List[ops.Tensor]

    with self.assertRaisesRegex(TypeError,
                                "In field 'xyz': Unsupported type annotation"):

      class MyType2(extension_type.ExtensionType):  # pylint: disable=unused-variable
        xyz: typing.Union[typing.Tuple[complex, ...], int]

  def testCantUseReservedName(self):
    # Names that collide with reserved ExtensionType names are rejected with
    # ValueError, whether declared as a field annotation or as a method.
    with self.assertRaisesRegex(
        ValueError, 'The field annotations for MyType1 are invalid. '
        "Field '_to_components' is reserved"):

      class MyType1(extension_type.ExtensionType):  # pylint: disable=unused-variable
        _to_components: int

    with self.assertRaisesRegex(
        ValueError, 'The field annotations for MyType2 are invalid. '
        "Field '_tf_extension_type_foo' is reserved"):

      class MyType2(extension_type.ExtensionType):  # pylint: disable=unused-variable
        _tf_extension_type_foo: int

    with self.assertRaisesRegex(
        ValueError, 'The field annotations for MyType3 are invalid. '
        "Field 'is_compatible_with' is reserved"):

      class MyType3(extension_type.ExtensionType):  # pylint: disable=unused-variable

        def is_compatible_with(self, other):
          return False

  def testExtensionTypeBaseClassHasNoSpec(self):
    # Only concrete subclasses get a nested `Spec` class.
    self.assertFalse(hasattr(extension_type.ExtensionType, 'Spec'))

  def testExtensionTypeBaseConstructorRaisesException(self):
    # ExtensionType itself is abstract: instantiating it directly raises.
    with self.assertRaisesRegex(AssertionError,
                                'ExtensionType is an abstract base class.'):
      extension_type.ExtensionType()

  # Named ExtensionType used by testSavedModelSupport.  The `__name__` field
  # is what enables serialization of the type.
  class ExtensionTypeWithName(extension_type.ExtensionType):
    __name__ = 'tf.__test__.ExtensionTypeWithName'  # For SavedModel
    x: typing.Tuple[ops.Tensor, int]
    y: ops.Tensor

  def testSavedModelSupport(self):
    # Traces `f` for two differently-typed ExtensionTypeWithName values, saves
    # the module as a SavedModel, reloads it, and checks both traces still run.

    class TestModule(module.Module):

      @def_function.function
      def f(self, s):
        return s.x[0] + s.x[1] + s.y

    s1 = self.ExtensionTypeWithName((1, 2), 3)
    s2 = self.ExtensionTypeWithName((1.0, 2), [3.0, 4.0])

    m = TestModule()
    m.f.get_concrete_function(s1)
    m.f.get_concrete_function(s2)

    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
    save.save(m, path)
    loaded = load.load(path)

    self.assertAllEqual(loaded.f(s1), 6)
    self.assertAllEqual(loaded.f(s2), [6.0, 7.0])

  def testPackedEncoding(self):
    # `pack` collapses the two component tensors into a single flattened
    # component; `unpack` restores the original two-component encoding.
    # Field values are the same in all three encodings.
    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
    self.assertLen(nest.flatten(mt1, expand_composites=True), 2)

    mt2 = extension_type.pack(mt1)
    self.assertLen(nest.flatten(mt2, expand_composites=True), 1)
    self.assertIsInstance(mt2.values, ops.Tensor)
    self.assertAllEqual(mt2.values, [1, 2, 3, 4])
    self.assertIsInstance(mt2.mask, ops.Tensor)
    self.assertAllEqual(mt2.mask, [True, True, False, True])

    mt3 = extension_type.unpack(mt2)
    self.assertLen(nest.flatten(mt3, expand_composites=True), 2)
    self.assertIsInstance(mt3.values, ops.Tensor)
    self.assertAllEqual(mt3.values, [1, 2, 3, 4])
    self.assertIsInstance(mt3.mask, ops.Tensor)
    self.assertAllEqual(mt3.mask, [True, True, False, True])

    # Packed and unpacked values are distinct nest structures.
    nest.assert_same_structure(mt1, mt3, expand_composites=True)
    with self.assertRaisesRegex(ValueError, "don't have the same"):  # pylint: disable=g-error-prone-assert-raises
      nest.assert_same_structure(mt1, mt2, expand_composites=True)

    # Types without a `__name__` field cannot be packed.
    mt4 = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True])
    with self.assertRaisesRegex(
        ValueError,
        'ExtensionTypes must have a __name__ field in order to be packed.'):
      extension_type.pack(mt4)

  def testSubclassing(self):
    # Subclasses inherit fields, may add new ones, and may override an
    # inherited field's default value and type.  Inherited fields keep their
    # original position in `__init__`.  The expected signatures below show
    # that a field without a default that follows a defaulted field (and
    # every field after it) becomes keyword-only.

    class Instrument(extension_type.ExtensionType):
      name: ops.Tensor
      weight: ops.Tensor
      needs_case: bool

    class StringInstrument(Instrument):
      num_strings: int  # Add a new field
      needs_case: bool = True  # Override default value.

    class Violin(StringInstrument):
      maker: ops.Tensor
      num_strings: int = 4  # Override default value.
      name: str = 'violin'  # Override field type and default value.

    self.assertEqual(
        list(
            tf_inspect.signature(
                StringInstrument.__init__).parameters.values()), [
                    tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
                    tf_inspect.Parameter(
                        'name', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
                    tf_inspect.Parameter(
                        'weight', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
                    tf_inspect.Parameter(
                        'needs_case',
                        POSITIONAL_OR_KEYWORD,
                        annotation=bool,
                        default=True),
                    tf_inspect.Parameter(
                        'num_strings', KEYWORD_ONLY, annotation=int),
                ])
    self.assertEqual(
        list(tf_inspect.signature(Violin.__init__).parameters.values()), [
            tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
            tf_inspect.Parameter(
                'name', POSITIONAL_OR_KEYWORD, annotation=str,
                default='violin'),
            tf_inspect.Parameter('weight', KEYWORD_ONLY, annotation=ops.Tensor),
            tf_inspect.Parameter(
                'needs_case', KEYWORD_ONLY, annotation=bool, default=True),
            tf_inspect.Parameter(
                'num_strings', KEYWORD_ONLY, annotation=int, default=4),
            tf_inspect.Parameter('maker', KEYWORD_ONLY, annotation=ops.Tensor),
        ])

    # Only the required keyword-only fields need to be supplied.
    violin = Violin(weight=28, maker='Amati')
    self.assertAllEqual(violin.name, 'violin')
    self.assertAllEqual(violin.weight, 28)
    self.assertAllEqual(violin.needs_case, True)
    self.assertAllEqual(violin.num_strings, 4)
    self.assertAllEqual(violin.maker, 'Amati')


# Integration tests checking compatibility with high-level APIs such as
# tf.data Datasets and Keras.
class ExtensionTypeIntegrationTest(test_util.TensorFlowTestCase):

  @test_util.run_v2_only
  def testDataset(self):
    # An ExtensionType value round-trips through Dataset.from_tensors.
    mt = MaskedTensorV3([[1], [2], [3]], [[True], [False], [True]])
    ds = dataset_ops.DatasetV2.from_tensors(mt)
    self.assertEqual(next(iter(ds)), mt)

  @test_util.run_v2_only
  def testDatasetBatch(self):
    # unbatch/from_tensor_slices yield the first row (x0); batching 3 rows
    # reassembles the original value.
    xs = MaskedTensorV3([[1], [2], [3]], [[True], [False], [True]])
    x0 = MaskedTensorV3(xs.values[0], xs.mask[0])

    ds = dataset_ops.DatasetV2.from_tensors(xs)
    self.assertEqual(next(iter(ds)), xs)
    ds = ds.unbatch()
    self.assertEqual(next(iter(ds)), x0)

    ds = dataset_ops.DatasetV2.from_tensor_slices(xs)
    self.assertEqual(next(iter(ds)), x0)
    ds = ds.batch(3, drop_remainder=True)
    self.assertEqual(next(iter(ds)), xs)

  @test_util.run_v2_only
  def testDatasetBatchRagged(self):
    # Same as testDatasetBatch, but with RaggedTensor fields.
    xs = MaskedTensorV3(
        ragged_factory_ops.constant([[1], [2, 3], [4]]),
        ragged_factory_ops.constant([[True], [False], [True]]))
    x0 = MaskedTensorV3(xs.values[0], xs.mask[0])

    ds = dataset_ops.DatasetV2.from_tensors(xs)
    self.assertEqual(next(iter(ds)), xs)
    ds = ds.unbatch()
    self.assertEqual(next(iter(ds)), x0)

    ds = dataset_ops.DatasetV2.from_tensor_slices(xs)
    self.assertEqual(next(iter(ds)), x0)
    ds = ds.batch(3, drop_remainder=True)
    self.assertEqual(next(iter(ds)), xs)

  @test_util.run_v2_only
  def testDistributedDataset(self):
    # Distributes a batched ExtensionType dataset over two replicas and checks
    # the first replica's slice of the first batch.
    # NOTE(review): hard-codes 'GPU:0' and 'GPU:1'; presumably relies on the
    # test environment providing (possibly virtual) GPUs -- confirm.
    strategy = mirrored_strategy.MirroredStrategy(['GPU:0', 'GPU:1'])
    mt = MaskedTensorV3([[1], [2], [3], [4]], [[True], [False], [True], [True]])
    ds = dataset_ops.DatasetV2.from_tensor_slices(mt).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(ds)
    expect = MaskedTensorV3([[1]], [[True]])
    per_replica_result = next(iter(dist_dataset))
    self.assertEqual(per_replica_result.values[0].values, expect.values[0])
    self.assertEqual(per_replica_result.values[0].mask, expect.mask[0])

  # TODO(edloper): Move this test to Keras.
1077 @test_util.run_v2_only 1078 def testKerasModel(self): 1079 mt_spec = MaskedTensorV3.Spec( 1080 tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.int32), 1081 tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.bool), 1082 ) 1083 model_input = input_layer.Input(type_spec=mt_spec) 1084 model_output = array_ops.identity(model_input, name='output') 1085 model = training.Model(inputs=model_input, outputs=model_output) 1086 mt = MaskedTensorV3([[1], [2], [3]], [[True], [False], [True]]) 1087 self.assertEqual(model(mt), mt) 1088 ds = dataset_ops.DatasetV2.from_tensors(mt) 1089 self.assertEqual(model.predict(ds), mt) 1090 1091 with self.subTest('keras save'): 1092 path = self.create_tempdir().full_path 1093 model.save(path) 1094 loaded_model = keras_save.load_model(path) 1095 self.assertEqual(loaded_model.input.type_spec, mt_spec) 1096 self.assertEqual(loaded_model(mt), mt) 1097 1098 loaded_fn = load.load(path) 1099 self.assertEqual(loaded_fn(mt), mt) 1100 with self.assertRaisesRegex( 1101 ValueError, 1102 'Could not find matching concrete function to call ' 1103 'loaded from the SavedModel', 1104 ): 1105 loaded_fn(MaskedTensorV3([1, 2, 3], [True, False, True])) 1106 1107 # The serving_fn use flatten signature 1108 serving_fn = loaded_fn.signatures['serving_default'] 1109 self.assertEqual( 1110 serving_fn(args_0=mt.values, args_0_1=mt.mask)['tf.identity'], mt) 1111 1112 1113@test_util.run_all_in_graph_and_eager_modes 1114class ExtensionTypeSpecTest(test_util.TensorFlowTestCase, 1115 parameterized.TestCase): 1116 1117 def testSpecConstructor(self): 1118 values_spec = tensor_spec.TensorSpec([4], dtypes.float32) 1119 mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) 1120 mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) 1121 self.assertEqual(mt_spec.values, values_spec) 1122 self.assertEqual(mt_spec.mask, mask_spec) 1123 1124 mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True]) 1125 self.assertEqual(mt._type_spec, mt_spec) 1126 1127 def 
testSpecConstructorSignature(self): 1128 1129 class MyType(extension_type.ExtensionType): 1130 x: ops.Tensor 1131 y: ops.Tensor 1132 z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3] 1133 1134 expected_parameters = [ 1135 tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD), 1136 tf_inspect.Parameter('x', POSITIONAL_OR_KEYWORD), 1137 tf_inspect.Parameter('y', POSITIONAL_OR_KEYWORD), 1138 tf_inspect.Parameter('z', POSITIONAL_OR_KEYWORD), 1139 ] 1140 expected_sig = tf_inspect.Signature( 1141 expected_parameters, return_annotation=MyType.Spec) 1142 self.assertEqual(expected_sig, tf_inspect.signature(MyType.Spec.__init__)) 1143 1144 def testSpecAttributesAreImmutable(self): 1145 mt = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True]) 1146 mt_spec = MaskedTensorV1.Spec.from_value(mt) 1147 with self.assertRaisesRegex( 1148 AttributeError, 'Cannot mutate attribute `score` ' 1149 'outside the custom constructor of ExtensionTypeSpec'): 1150 mt_spec.score = 12 1151 with self.assertRaisesRegex( 1152 AttributeError, 'Cannot mutate attribute `values` ' 1153 'outside the custom constructor of ExtensionTypeSpec'): 1154 mt_spec.values = constant_op.constant([4, 3, 2, 1]) 1155 with self.assertRaisesRegex( 1156 AttributeError, 'Cannot mutate attribute `values` ' 1157 'outside the custom constructor of ExtensionTypeSpec'): 1158 del mt_spec.values 1159 1160 def testSpecFromValue(self): 1161 mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True]) 1162 mt_spec = MaskedTensorV1.Spec.from_value(mt) 1163 1164 expected_values_spec = tensor_spec.TensorSpec([4], dtypes.float32) 1165 expected_mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) 1166 self.assertEqual(mt_spec.values, expected_values_spec) 1167 self.assertEqual(mt_spec.mask, expected_mask_spec) 1168 1169 def testSpecSerialize(self): 1170 1171 class Zoo(extension_type.ExtensionType): 1172 zookeepers: typing.Tuple[str, ...] 
1173 animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] 1174 1175 featurespec = { 1176 'size': tensor_spec.TensorSpec([3]), 1177 'weight': tensor_spec.TensorSpec([]) 1178 } 1179 zoo_spec = Zoo.Spec( 1180 zookeepers=['Zoey', 'Zack'], 1181 animals={ 1182 'tiger': featurespec, 1183 'elephant': featurespec 1184 }) 1185 1186 serialized = zoo_spec._serialize() 1187 self.assertEqual(serialized, 1188 (('zookeepers', ('Zoey', 'Zack')), ('animals', { 1189 'tiger': featurespec, 1190 'elephant': featurespec 1191 }))) 1192 restored = Zoo.Spec._deserialize(serialized) 1193 self.assertEqual(zoo_spec, restored) 1194 1195 # ImmutableDict is used for the field, but dict for the serialization: 1196 self.assertIsInstance(zoo_spec.animals, immutable_dict.ImmutableDict) 1197 serialized_field_name, serialized_field_value = serialized[1] 1198 self.assertEqual(serialized_field_name, 'animals') 1199 self.assertIsInstance(serialized_field_value, dict) 1200 1201 def testSpecComponents(self): 1202 1203 class Zoo(extension_type.ExtensionType): 1204 zookeepers: typing.Tuple[str, ...] 
1205 animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] 1206 1207 zoo = Zoo( 1208 ['Zoey', 'Zack'], { 1209 'elephant': { 1210 'size': [25, 30, 20], 1211 'weight': 2000.0 1212 }, 1213 'tiger': { 1214 'hunger': 3.2, 1215 'size': [3, 8, 2], 1216 'weight': 87.3 1217 } 1218 }) 1219 zoo_spec = Zoo.Spec.from_value(zoo) 1220 1221 components = zoo_spec._to_components(zoo) 1222 self.assertLen(components, 5) 1223 self.assertAllClose(components[0], [25, 30, 20]) 1224 self.assertAllClose(components[1], 2000.0) 1225 self.assertAllClose(components[2], 3.2) 1226 self.assertAllClose(components[3], [3, 8, 2]) 1227 self.assertAllClose(components[4], 87.3) 1228 1229 restored = zoo_spec._from_components(components) 1230 self.assertAllEqual(zoo == restored, True) 1231 1232 self.assertEqual(zoo_spec._component_specs, 1233 (tensor_spec.TensorSpec([3], dtypes.int32), 1234 tensor_spec.TensorSpec([], dtypes.float32), 1235 tensor_spec.TensorSpec([], dtypes.float32), 1236 tensor_spec.TensorSpec([3], dtypes.int32), 1237 tensor_spec.TensorSpec([], dtypes.float32))) 1238 1239 def testCopyAndPickle(self): 1240 values_spec = tensor_spec.TensorSpec([4], dtypes.float32) 1241 mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) 1242 mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) 1243 self.assertEqual(copy.copy(mt_spec), mt_spec) 1244 self.assertEqual(copy.deepcopy(mt_spec), mt_spec) 1245 self.assertEqual(pickle.loads(pickle.dumps(mt_spec)), mt_spec) 1246 1247 def testCustomizeSpecTest(self): 1248 1249 class WeightedTensor(extension_type.ExtensionType): 1250 """ExtensionType with a customized TypeSpec. 1251 1252 * Custom constructor. 1253 * Custom __validate__. 1254 * Add properties (shape, dtype, weight_dtype). 1255 * Add method (with_shape). 
1256 """ 1257 values: ops.Tensor 1258 weight: ops.Tensor # scalar 1259 1260 shape = property(lambda self: self.shape) 1261 dtype = property(lambda self: self.dtype) 1262 weight_dtype = property(lambda self: self.weight.dtype) 1263 1264 def __validate__(self): 1265 self.weight.shape.assert_has_rank(0) 1266 1267 class Spec: 1268 1269 def __init__(self, shape, dtype, weight_dtype=dtypes.float32): 1270 self.values = tensor_spec.TensorSpec(shape, dtype) 1271 self.weight = tensor_spec.TensorSpec([], weight_dtype) 1272 1273 def __validate__(self): 1274 self.weight.shape.assert_has_rank(0) 1275 1276 shape = property(lambda self: self.values.shape) 1277 dtype = property(lambda self: self.values.dtype) 1278 weight_dtype = property(lambda self: self.weight.dtype) 1279 1280 def with_shape(self, shape): 1281 return WeightedTensor.Spec(shape, self.dtype, self.weight_dtype) 1282 1283 wt = WeightedTensor([1, 2], 0.3) 1284 wt_spec = WeightedTensor.Spec.from_value(wt) 1285 self.assertEqual(wt_spec.shape, tensor_shape.TensorShape([2])) 1286 self.assertEqual(wt_spec.dtype, dtypes.int32) 1287 1288 self.assertEqual(wt_spec, WeightedTensor.Spec([2], dtypes.int32)) 1289 1290 wt2 = WeightedTensor([[1, 2], [3, 4]], 0.5) 1291 wt2_spec = WeightedTensor.Spec.from_value(wt2) 1292 self.assertEqual(wt_spec.with_shape([2, 2]), wt2_spec) 1293 1294 def testNestedSpecMustBeAClass(self): 1295 with self.assertRaisesRegex( 1296 ValueError, 1297 r'BrokenExtensionType\.Spec must be a nested class; got 12.'): 1298 1299 class BrokenExtensionType(extension_type.ExtensionType): 1300 1301 Spec = 12 # pylint: disable=invalid-name 1302 1303 del BrokenExtensionType 1304 1305 def testNestedSpecMayNotHaveBaseClasses(self): 1306 with self.assertRaisesRegex( 1307 ValueError, r'BrokenExtensionType\.Spec must be directly subclassed ' 1308 'from tf.TypeSpec.'): 1309 1310 class BrokenExtensionType(extension_type.ExtensionType): 1311 1312 class Spec(type_spec.BatchableTypeSpec): 1313 pass 1314 1315 del BrokenExtensionType 
@test_util.run_all_in_graph_and_eager_modes
class AnonymousExtensionTypeTest(test_util.TensorFlowTestCase,
                                 parameterized.TestCase):
  """Tests for extension_type.AnonymousExtensionType and `reinterpret`."""

  # Callables are used for the Tensor cases so that the tensors are built when
  # the test body runs rather than when the decorator is evaluated.
  @parameterized.parameters([
      [dict(i=5, f=3.2, b=True, n=None)],
      [dict(x=(1, 2), y={
          3: 4,
          5: 6
      })],
      [lambda: dict(t=constant_op.constant(123))],
      [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
  ])
  def testConstruction(self, fields):
    if callable(fields):
      fields = fields()
    extension_type.AnonymousExtensionType(**fields)

  # Lists, sets, and containers nesting them are rejected, as are reserved
  # field names.
  @parameterized.parameters([
      [dict(x=[1, 2, 3]), 'unsupported `value` argument'],
      [dict(x=set([1, 2])), 'unsupported `value` argument'],
      [dict(x=(1, dict([(2, [])]))), 'unsupported `value` argument'],
      [
          dict(_tf_extension_type_xyz=5),
          'Reserved field name .*_tf_extension_type_xyz.*'
      ],
  ])
  def testConstructionErrors(self, fields, error):
    with self.assertRaisesRegex(ValueError, error):
      extension_type.AnonymousExtensionType(**fields)

  @parameterized.parameters([
      [dict(i=5, f=3.2, b=True, n=None)],
      [dict(x=(1, 2), y={
          3: 4,
          5: 6
      })],
      [lambda: dict(t=constant_op.constant(123))],
      [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
  ])
  def testAttributeAccessors(self, fields):
    # Every keyword passed to the constructor is readable as an attribute.
    if callable(fields):
      fields = fields()
    s = extension_type.AnonymousExtensionType(**fields)
    for (name, value) in fields.items():
      actual = getattr(s, name)
      if isinstance(actual, (ops.Tensor, ragged_tensor.RaggedTensor)):
        self.assertAllEqual(actual, value)
      else:
        self.assertEqual(actual, value)

  def testAttributeAccessorsAreImmutable(self):
    # Attributes cannot be set or deleted, and dict-valued fields cannot be
    # mutated in place.
    s = extension_type.AnonymousExtensionType(x=12, y={'x': 55})
    with self.assertRaisesRegex(AttributeError, 'Cannot set attribute `x`'):
      s.x = 22
    with self.assertRaisesRegex(AttributeError, 'Cannot delete attribute `y`'):
      del s.y
    with self.assertRaisesRegex(TypeError, 'does not support item assignment'):
      s.y['x'] = 66

  def testReinterpret(self):
    # A value can be reinterpreted as an anonymous type, back to its original
    # type, and as a different type with compatible fields.
    x = MaskedTensorV2([4, 5], [True, False])
    anon_x = extension_type.reinterpret(x,
                                        extension_type.AnonymousExtensionType)
    self.assertAllEqual(anon_x.values, [4, 5])
    self.assertAllEqual(anon_x.mask, [True, False])

    round_trip_x = extension_type.reinterpret(anon_x, MaskedTensorV2)
    self.assertAllEqual(round_trip_x.values, [4, 5])
    self.assertAllEqual(round_trip_x.mask, [True, False])

    converted_x = extension_type.reinterpret(anon_x, MaskedTensorV1)
    self.assertAllEqual(converted_x.values, [4, 5])
    self.assertAllEqual(converted_x.mask, [True, False])

  # pylint: disable=g-long-lambda
  @parameterized.parameters([
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([1, 2, 3])), MaskedTensorV2,
          "Missing required fields: {'mask'}"
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=(1, 2, 3), mask=None), MaskedTensorV2,
          'mask: expected a Tensor, got None'
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([1, 2, 3]),
              mask=constant_op.constant([True, False])), MaskedTensorV2,
          'Shapes .* are incompatible'
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([1, 2, 3])), ops.Tensor,
          'reinterpret expects `new_type` to be a subclass of '
          'tf.ExtensionType; '
          'got .*.Tensor.*'
      ],
      [
          lambda: constant_op.constant([1, 2, 3]),
          extension_type.AnonymousExtensionType,
          'reinterpret expects `value` to be a tf.ExtensionType instance; '
          'got.*.Tensor.*'
      ],
  ])
  def testReinterpretErrors(self, value, new_type, error):
    if callable(value):
      value = value()
    with self.assertRaisesRegex((TypeError, ValueError), error):
      extension_type.reinterpret(value, new_type)

  def testLoadSavedModelWithUnregisteredExtensionType(self):
    # Saves a model while MaskedTensorV1.Spec is registered, then loads it
    # after the registration has been removed.  Values of the unregistered
    # type come back as AnonymousExtensionType and can be reinterpreted.

    def f(x, y):
      x_values = x.values if isinstance(x, MaskedTensorV1) else x
      y_values = y.values if isinstance(y, MaskedTensorV1) else y
      x_mask = x.mask if isinstance(x, MaskedTensorV1) else True
      y_mask = y.mask if isinstance(y, MaskedTensorV1) else True
      return MaskedTensorV1(x_values + y_values, x_mask & y_mask)

    t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
    b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
    mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec)
    model = module.Module()
    model.f = def_function.function(f)
    model.f.get_concrete_function(t_spec, t_spec)
    model.f.get_concrete_function(t_spec, mt_spec)
    model.f.get_concrete_function(mt_spec, t_spec)
    model.f.get_concrete_function(mt_spec, mt_spec)

    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
    with temporarily_register_type_spec('tf.test.MaskedTensorV1.Spec',
                                        MaskedTensorV1.Spec):
      save.save(model, path)
    loaded_model = load.load(path)

    # The exact name registered above must no longer resolve.  (Looking up a
    # name that was never registered would pass trivially.)
    with self.assertRaises(ValueError):
      type_spec.lookup('tf.test.MaskedTensorV1.Spec')

    t = constant_op.constant([10, 20, 30])
    v1 = loaded_model.f(t, t)
    self.assertIsInstance(v1, extension_type.AnonymousExtensionType)
    self.assertAllEqual(v1.values, [20, 40, 60])
    self.assertAllEqual(v1.mask, True)

    v2 = loaded_model.f(v1, v1)
    self.assertIsInstance(v2, extension_type.AnonymousExtensionType)
    self.assertAllEqual(v2.values, [40, 80, 120])
    self.assertAllEqual(v2.mask, True)

    mt = MaskedTensorV1([1, 2, 3], [True, True, False])
    v3 = loaded_model.f(
        t, extension_type.reinterpret(mt,
                                      extension_type.AnonymousExtensionType))
    self.assertIsInstance(v3, extension_type.AnonymousExtensionType)
    self.assertAllEqual(v3.values, [11, 22, 33])
    self.assertAllEqual(v3.mask, [True, True, False])

    v4 = extension_type.reinterpret(v3, MaskedTensorV1)
    self.assertIsInstance(v4, MaskedTensorV1)
    self.assertAllEqual(v4.values, [11, 22, 33])
    self.assertAllEqual(v4.mask, [True, True, False])

  def testFlatTensorSpecs(self):
    # One flat TensorSpec per component, in field order (values, mask).
    x = MaskedTensorV2([4, 5], [True, False])
    spec = type_spec.type_spec_from_value(x)
    flat_specs = spec._flat_tensor_specs
    self.assertEqual(flat_specs, [
        tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32, name=None),
        tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.bool, name=None)
    ])

  def testFullTypesForFlatTensors(self):
    # Each flat tensor gets a FullTypeDef, all TFT_UNSET for this type.
    x = MaskedTensorV2([4, 5], [True, False])
    spec = type_spec.type_spec_from_value(x)
    full_type_list = fulltypes_for_flat_tensors(spec)
    expect = [
        full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET),
        full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
    ]
    self.assertEqual(len(spec._flat_tensor_specs), len(full_type_list))
    self.assertEqual(expect, full_type_list)


def replace_tensors_with_placeholders(value):
  """Returns `value` with each tf.Tensor wrapped in a shapeless placeholder.

  Every `ops.Tensor` in `value` is replaced by
  `placeholder_with_default(x, shape=None)`.  This includes tensors nested
  inside composite tensors, because `expand_composites=True` is used.  All
  other leaves are returned unchanged.
  """

  def repl(x):
    if isinstance(x, ops.Tensor):
      return array_ops.placeholder_with_default(x, shape=None)
    else:
      return x

  return nest.map_structure(repl, value, expand_composites=True)


@contextlib.contextmanager
def temporarily_add_dispatch(op, typ, fn):
  """Context manager that temporarily registers `fn` as a dispatcher for `op`.

  `dispatch.dispatch_for_types(op, typ)(fn)` is expected to append exactly one
  entry to `op._tf_fallback_dispatchers`.  That entry is removed on exit, even
  if the body of the `with` statement raises, so a failing test does not leak
  the dispatcher into later tests.
  """
  n = len(op._tf_fallback_dispatchers)
  dispatch.dispatch_for_types(op, typ)(fn)
  try:
    yield
  finally:
    # Removing the last entry is only correct if nothing else was added or
    # removed while the context was active.
    assert len(op._tf_fallback_dispatchers) == n + 1
    del op._tf_fallback_dispatchers[-1]


@contextlib.contextmanager
def temporarily_register_type_spec(name, cls):
  """Context manager for making temporary changes to the TypeSpec registry.

  Registers `cls` under `name` on entry.  On exit it removes both registry
  entries (`_TYPE_SPEC_TO_NAME[cls]` and `_NAME_TO_TYPE_SPEC[name]`), even if
  the body of the `with` statement raises.
  """
  type_spec.register(name)(cls)
  try:
    yield
  finally:
    # The pops are kept out of `assert` statements so the cleanup still runs
    # under `python -O`, which strips asserts along with their side effects.
    registered_name = type_spec._TYPE_SPEC_TO_NAME.pop(cls)
    registered_cls = type_spec._NAME_TO_TYPE_SPEC.pop(name)
    assert registered_name == name
    assert registered_cls is cls
if __name__ == '__main__':
  # Hand control to googletest's entry point when run as a script.
  googletest.main()