xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/extension_type_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.framework.extension_type."""
16
17import contextlib
18import copy
19import pickle
20import tempfile
21import typing
22
23from absl.testing import parameterized
24import typing_extensions
25
26from tensorflow.core.framework import full_type_pb2
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.distribute import mirrored_strategy
29from tensorflow.python.eager import context
30from tensorflow.python.eager import def_function
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import extension_type
34from tensorflow.python.framework import extension_type_field
35from tensorflow.python.framework import immutable_dict
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.framework import test_util
40from tensorflow.python.framework import type_spec
41from tensorflow.python.framework.type_utils import fulltypes_for_flat_tensors
42from tensorflow.python.keras.engine import input_layer
43from tensorflow.python.keras.engine import training
44from tensorflow.python.keras.saving import save as keras_save
45from tensorflow.python.module import module
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import control_flow_ops
48from tensorflow.python.ops import math_ops
49from tensorflow.python.ops.ragged import ragged_factory_ops
50from tensorflow.python.ops.ragged import ragged_tensor
51from tensorflow.python.platform import googletest
52from tensorflow.python.platform import test
53from tensorflow.python.saved_model import load
54from tensorflow.python.saved_model import save
55from tensorflow.python.util import dispatch
56from tensorflow.python.util import nest
57from tensorflow.python.util import tf_inspect
58
59
# Shorthands for the parameter kinds used when building expected signatures.
POSITIONAL_OR_KEYWORD, KEYWORD_ONLY = (
    tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
    tf_inspect.Parameter.KEYWORD_ONLY,
)
62
63
class MaskedTensorV1(extension_type.ExtensionType):
  """Example subclass of ExtensionType, used for testing.

  The simplest masked-tensor variant: just two tensor fields, with no
  methods, no custom `__repr__`/`__validate__`, and no `__name__` (so the
  type is not registered for serialization).
  """
  values: ops.Tensor  # The wrapped values.
  mask: ops.Tensor  # Mask tensor (boolean in these tests); not shape-checked.
68
69
class MaskedTensorV2(extension_type.ExtensionType):
  """Example subclass of ExtensionType, used for testing.

  Compared with MaskedTensorV1, this variant also defines methods, a
  classmethod, a staticmethod and properties, and overrides `__repr__` and
  `__validate__`.  Setting `__name__` registers the type, which enables
  serialization.
  """
  __name__ = 'tf.test.MaskedTensorV2'

  values: ops.Tensor
  mask: ops.Tensor

  def __repr__(self):
    # The compact masked repr needs concrete values via `.numpy()`; when
    # either field lacks it, fall back to the default ExtensionType repr.
    if not (hasattr(self.values, 'numpy') and hasattr(self.mask, 'numpy')):
      return super(MaskedTensorV2, self).__repr__()
    rendered = _masked_array_repr(self.values.numpy(), self.mask.numpy())
    return '<MaskedTensorV2 %s>' % rendered

  @property
  def shape(self):
    """Static shape of `values`."""
    return self.values.shape

  @property
  def dtype(self):
    """Dtype of `values`."""
    return self.values.dtype

  @classmethod
  def from_full_tensor(cls, values):
    """Wraps `values` with an all-True boolean mask of the same shape."""
    all_true = array_ops.ones_like(values, dtype=dtypes.bool)
    return cls(values, all_true)

  # A dummy example to test support of staticmethod.
  @staticmethod
  def doc_link():
    return 'http://example.com/masked_tensor'

  def __validate__(self):
    # Reject values/mask pairs whose static shapes are incompatible.
    self.values.shape.assert_is_compatible_with(self.mask.shape)

  def with_default(self, default):
    """Returns `values`, with positions where `mask` is False set to `default`."""
    return array_ops.where_v2(self.mask, self.values, default)

  # Operators delegate straight to the TF math ops.
  __add__ = math_ops.add
  __sub__ = math_ops.subtract
114
115
116def _masked_array_repr(values, mask):
117  """Returns a string representation for a masked numpy array."""
118  assert len(values) == len(mask)
119  if len(values.shape) == 1:
120    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
121  else:
122    items = [_masked_array_repr(v, m) for (v, m) in zip(values, mask)]
123  return '[%s]' % ', '.join(items)
124
125
class MaskedTensorV3(extension_type.BatchableExtensionType):
  """Example subclass of BatchableExtensionType, used for testing.

  Adds the properties Keras requires on MaskedTensor and on its Spec class,
  so that Keras integration can be tested.  Both fields may be dense Tensors
  or RaggedTensors.
  """
  # NOTE(review): unlike MaskedTensorV2's registered name, this one ends in
  # `.Spec`; confirm that is intentional (changing it would change the
  # serialized type name).
  __name__ = 'tf.test.MaskedTensorV3.Spec'

  values: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor]
  mask: typing.Union[ops.Tensor, ragged_tensor.RaggedTensor]

  def __init__(self, values, mask):
    # Non-ragged inputs are converted to dense tensors (mask as bool).  Ragged
    # inputs are kept as-is, after asserting the mask is a bool RaggedTensor.
    if not isinstance(values, ragged_tensor.RaggedTensor):
      values = ops.convert_to_tensor(values)
      mask = ops.convert_to_tensor(mask, dtypes.bool)
    else:
      assert isinstance(mask, ragged_tensor.RaggedTensor)
      assert mask.dtype == dtypes.bool
    self.values = values
    self.mask = mask

  # Required by assert_input_compatibility in keras/engine/input_spec.py
  @property
  def shape(self):
    """Shape of `values`."""
    return self.values.shape

  @property
  def dtype(self):
    """Dtype of `values`."""
    return self.values.dtype

  class Spec:
    """Extra members for this type's TypeSpec; Keras reads `_shape`."""

    # Required by KerasTensor.shape in keras/engine/keras_tensor.py
    @property
    def _shape(self):
      return self.values._shape
162
163
class ForwardRefA(extension_type.ExtensionType):
  """ExtensionType whose fields use string forward references."""
  # Variable-length tuple whose items may be ForwardRefA or ForwardRefB (the
  # latter is only defined further down the module).
  x: typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], ...]
  # Forward reference to a class defined after this one.
  y: 'ForwardRefB'
167
168
class ForwardRefB(extension_type.ExtensionType):
  """Self-referential ExtensionType used to test forward-reference fields."""
  # NOTE(review): `z` is a required ForwardRefB with no default, so building a
  # ForwardRefB needs an existing one; the type appears to be used only to
  # test field/annotation resolution -- confirm.
  z: 'ForwardRefB'
  n: ops.Tensor
172
173
class ExtensionTypeWithTensorDefault(extension_type.ExtensionType):
  """ExtensionType whose Tensor-typed fields have plain-Python defaults."""
  # The defaults are deliberately not Tensors (an int and a list): the
  # generated constructor signature is expected to keep them as Python values.
  x: ops.Tensor = 5
  y: ops.Tensor = ['a', 'b', 'c']
177
178
179@test_util.run_all_in_graph_and_eager_modes
180class ExtensionTypeTest(test_util.TensorFlowTestCase, parameterized.TestCase):
181
182  def testAttributeAccessors(self):
183    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
184    mt2 = extension_type.pack(mt1)
185
186    for mt in [mt1, mt2]:
187      self.assertIsInstance(mt.values, ops.Tensor)
188      self.assertAllEqual(mt.values, [1, 2, 3, 4])
189      self.assertIsInstance(mt.mask, ops.Tensor)
190      self.assertAllEqual(mt.mask, [True, True, False, True])
191
192  def testAttributesAreImmutable(self):
193    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
194    mt2 = extension_type.pack(mt1)
195
196    for mt in [mt1, mt2]:
197      with self.assertRaisesRegex(
198          AttributeError,
199          'Cannot mutate attribute `score` outside the custom constructor of ExtensionType'
200      ):
201        mt.score = 12
202      with self.assertRaisesRegex(
203          AttributeError,
204          'Cannot mutate attribute `values` outside the custom constructor of ExtensionType'
205      ):
206        mt.values = constant_op.constant([4, 3, 2, 1])
207      with self.assertRaisesRegex(
208          AttributeError,
209          'Cannot mutate attribute `values` outside the custom constructor of ExtensionType'
210      ):
211        del mt.values
212
213  def testClassAndStaticMethod(self):
214    mt = MaskedTensorV2.from_full_tensor([1, 2, 3, 4])
215    self.assertAllEqual(mt.mask, [True, True, True, True])
216    self.assertEqual(mt.doc_link(), 'http://example.com/masked_tensor')
217
218  def testRepr(self):
219    values = constant_op.constant([1, 2, 3, 4])
220    mask = constant_op.constant([True, True, False, True])
221    mt = MaskedTensorV1(values, mask)
222    expected = f'MaskedTensorV1(values={values!r}, mask={mask!r})'
223    self.assertEqual(expected, repr(mt))
224
225  def testEagerRepr(self):
226    values = constant_op.constant([1, 2, 3, 4])
227    mask = constant_op.constant([True, True, False, True])
228    mt = MaskedTensorV2(values, mask)
229    if context.executing_eagerly():
230      expected = '<MaskedTensorV2 [1, 2, _, 4]>'
231    else:
232      expected = f'MaskedTensorV2(values={values!r}, mask={mask!r})'
233
234    self.assertEqual(expected, repr(mt))
235    self.assertEqual(expected, repr(mt))
236
237  def testConstructorSignature(self):
238
239    class MyType(extension_type.ExtensionType):
240      x: ops.Tensor
241      y: ops.Tensor
242      z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]
243
244    expected_parameters = [
245        tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
246        tf_inspect.Parameter('x', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
247        tf_inspect.Parameter('y', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
248        tf_inspect.Parameter(
249            'z',
250            POSITIONAL_OR_KEYWORD,
251            annotation=typing.Tuple[typing.Union[int, str], ...],
252            default=(1, 'two', 3)),
253    ]
254    expected_sig = tf_inspect.Signature(
255        expected_parameters, return_annotation=MyType)
256    self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__))
257
258  def testConstructorSignatureWithKeywordOnlyArgs(self):
259
260    class MyType(extension_type.ExtensionType):
261      a: int
262      b: str = 'Hello world'
263      c: ops.Tensor
264
265    expected_parameters = [
266        tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
267        tf_inspect.Parameter('a', POSITIONAL_OR_KEYWORD, annotation=int),
268        tf_inspect.Parameter(
269            'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world'),
270        tf_inspect.Parameter('c', KEYWORD_ONLY, annotation=ops.Tensor),
271    ]
272    expected_sig = tf_inspect.Signature(
273        expected_parameters, return_annotation=MyType)
274    self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__))
275
276  def testConstructorSignatureWithDefaultForTensorField(self):
277    a = ExtensionTypeWithTensorDefault()
278
279    # Check that the default values were *not* converted to Tensors:
280    sig = tf_inspect.signature(ExtensionTypeWithTensorDefault.__init__)
281    self.assertIsInstance(sig.parameters['x'].default, int)
282    self.assertIsInstance(sig.parameters['y'].default, list)
283
284    # The following would fail with "RuntimeError: Attempting to capture an
285    # EagerTensor without building a function" if we converted the default
286    # value to a Tensor when we built the type.
287    self.assertAllEqual(a.x + constant_op.constant(3), 8)
288
289  def testConstructorSignatureWithAnnotatedTensorField(self):
290
291    class MyType(extension_type.ExtensionType):
292      a: typing_extensions.Annotated[ops.Tensor, 'metadata']
293      b: typing_extensions.Annotated[str, 'metadata'] = 'Hello world'
294      c: typing.Optional[typing_extensions.Annotated[int, 'metadata']] = None
295
296    expected_parameters = [
297        tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
298        tf_inspect.Parameter('a', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
299        tf_inspect.Parameter(
300            'b', POSITIONAL_OR_KEYWORD, annotation=str, default='Hello world'),
301        tf_inspect.Parameter(
302            'c',
303            POSITIONAL_OR_KEYWORD,
304            annotation=typing.Optional[int],
305            default=None),
306    ]
307    expected_sig = tf_inspect.Signature(
308        expected_parameters, return_annotation=MyType)
309    self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__))
310
311  def testEmptyType(self):
312
313    class EmptyType(extension_type.ExtensionType):
314      pass
315
316    self.assertEmpty(EmptyType._tf_extension_type_fields())
317    x = EmptyType()
318    self.assertEqual(
319        repr(x), 'ExtensionTypeTest.testEmptyType.<locals>.EmptyType()')
320
321  def testCustomConstrutor(self):
322
323    class SummarizedTensor(extension_type.ExtensionType):
324      values: ops.Tensor
325      mean: ops.Tensor
326      max: ops.Tensor
327
328      def __init__(self, values):
329        self.values = ops.convert_to_tensor(values)
330        self.mean = math_ops.reduce_mean(values)
331        self.max = math_ops.reduce_max(values)
332
333    x = SummarizedTensor([[1.0, 2, 3], [4, 5, 6]])
334    self.assertAllEqual(x.values, [[1.0, 2, 3], [4, 5, 6]])
335    self.assertAllEqual(x.mean, 3.5)
336    self.assertAllEqual(x.max, 6)
337
  class Node(extension_type.ExtensionType):
    """Tree node whose children are Nodes; all fields but `x` have defaults."""
    x: ops.Tensor
    y: typing.Optional[str] = None
    # String forward reference to this (nested) class, by qualified name.
    children: typing.Tuple['ExtensionTypeTest.Node', ...] = ()
342
343  def testConstructorWithDefaultValues(self):
344    a = ExtensionTypeTest.Node(5)
345    self.assertAllEqual(a.x, 5)
346    self.assertIsNone(a.y)
347    self.assertEqual(a.children, ())
348
349    b = ExtensionTypeTest.Node(6, 'blue')
350    self.assertAllEqual(b.x, 6)
351    self.assertEqual(b.y, 'blue')
352    self.assertEqual(b.children, ())
353
354    c = ExtensionTypeTest.Node(7, children=(a, b))
355    self.assertAllEqual(c.x, 7)
356    self.assertIsNone(c.y)
357    self.assertEqual(c.children, (a, b))
358
359  def testCustomConstrutorCantMutateNestedValues(self):
360
361    class Foo(extension_type.ExtensionType):
362      x: int
363
364    class Bar(extension_type.ExtensionType):
365      foo: Foo
366
367      def __init__(self, foo):
368        foo.x = 33  # This raises an exception
369
370    with self.assertRaisesRegex(
371        AttributeError,
372        'Cannot mutate attribute `x` outside the custom constructor of ExtensionType'
373    ):
374      Bar(Foo(12))
375
376  def testCustomValidate(self):
377
378    class AlignedTensors(extension_type.ExtensionType):
379      x: ops.Tensor
380      y: ops.Tensor
381
382      def __validate__(self):
383        self.x.shape.assert_is_compatible_with(self.y.shape)
384
385    aligned = AlignedTensors([1, 2, 3], ['a', 'b', 'c'])
386    self.assertAllEqual(aligned.x, [1, 2, 3])
387    self.assertAllEqual(aligned.y, [b'a', b'b', b'c'])
388
389    with self.assertRaises(ValueError):
390      AlignedTensors([1, 2, 3], ['a', 'b', 'c', 'd'])
391
392  def testEquals(self):
393
394    class MyType(extension_type.ExtensionType):
395      values: ops.Tensor
396      score: ops.Tensor
397      flavor: str
398
399    x1 = MyType([1, 2], 8, 'blue')
400    x2 = MyType([1, 2], 8, 'blue')
401    y = MyType([1, 2], 8, 'red')
402    z = MyType([1, 2], 7, 'blue')
403    self.assertAllEqual(x1 == x2, True)
404    self.assertAllEqual(x1 != x2, False)
405    self.assertAllEqual(x1 == y, False)
406    self.assertAllEqual(x1 != y, True)
407    self.assertAllEqual(x1 == z, False)
408    self.assertAllEqual(y == z, False)
409
410    # These are not equal, even though their values are broadcast-compatible
411    # and elements are all equal when we broadcast.  Shapes must match.
412    a = MyType([1, 1, 1, 1], 0, 'x')
413    b = MyType([[1, 1, 1, 1]], 0, 'x')
414    c = MyType([[1, 1], [1, 1]], 0, 'x')
415    self.assertAllEqual(a == b, False)
416    self.assertAllEqual(a == c, False)
417    self.assertAllEqual(b == c, False)
418
419    # Test with unknown shapes (executes a different codepath).
420    a_ph = replace_tensors_with_placeholders(a)
421    b_ph = replace_tensors_with_placeholders(b)
422    c_ph = replace_tensors_with_placeholders(c)
423    self.assertAllEqual(a_ph == b_ph, False)
424    self.assertAllEqual(a_ph == c_ph, False)
425    self.assertAllEqual(b_ph == c_ph, False)
426
427  def testPassIntoTfFunction(self):
428
429    @def_function.function
430    def fn(x):
431      return x.with_default(99)
432
433    mt = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
434    self.assertAllEqual([1, 2, 99, 4], fn(mt))
435    self.assertAllEqual([1, 2, 99, 4], fn(extension_type.pack(mt)))
436
437  def testReturnFromTfFunction(self):
438
439    @def_function.function
440    def mask_neg_values(x):
441      return MaskedTensorV2(x, x > 0)
442
443    @def_function.function
444    def mask_neg_values_packed(x):
445      return extension_type.pack(MaskedTensorV2(x, x > 0))
446
447    expected = MaskedTensorV2([5, 8, -3, 9], [True, True, False, True])
448
449    actual1 = mask_neg_values(constant_op.constant([5, 8, -3, 9]))
450    self.assertIsInstance(actual1, MaskedTensorV2)
451    self.assertAllEqual(expected.values, actual1.values)
452    self.assertAllEqual(expected.mask, actual1.mask)
453
454    actual2 = mask_neg_values_packed(constant_op.constant([5, 8, -3, 9]))
455    self.assertIsInstance(actual2, MaskedTensorV2)
456    self.assertTrue(extension_type.is_packed(actual2))
457    self.assertAllEqual(expected.values, actual2.values)
458    self.assertAllEqual(expected.mask, actual2.mask)
459
460  def testCaptureByTfFunction(self):
461    x = MaskedTensorV2(
462        values=[[1, 2, 3], [4, 5, 6]],
463        mask=[[True, True, True], [True, False, True]])
464
465    @def_function.function
466    def add_to_x(y):
467      return MaskedTensorV2(x.values + y.values, x.mask & y.mask)
468
469    actual = add_to_x(MaskedTensorV2([10, 20, 30], [False, True, True]))
470    expected = MaskedTensorV2(
471        values=[[11, 22, 33], [14, 25, 36]],
472        mask=[[False, True, True], [False, False, True]])
473    self.assertIsInstance(actual, MaskedTensorV2)
474    self.assertAllEqual(expected.values, actual.values)
475    self.assertAllEqual(expected.mask, actual.mask)
476
477  def testTfFunctionArgMutationError(self):
478
479    @def_function.function
480    def fn_with_side_effect(mts):
481      mts.append(MaskedTensorV1(mts[0].values * 2, mts[0].mask))
482
483    with self.assertRaisesRegex(ValueError, 'should not modify'):
484      fn_with_side_effect([MaskedTensorV1([10, 20, 30], [False, True, True])])
485
486  def testNestPackUnpack(self):
487
488    class CandyStore(extension_type.ExtensionType):
489      name: ops.Tensor
490      prices: typing.Mapping[str, ops.Tensor]
491
492    store = CandyStore('Yum', {'gum': [0.42, 0.48], 'chocolate': [0.83, 1.02]})
493    components = nest.flatten(store, expand_composites=True)
494    repacked_1 = nest.pack_sequence_as(
495        store, components, expand_composites=True)
496    repacked_2 = nest.pack_sequence_as(
497        store._type_spec, components, expand_composites=True)
498
499    # Note: dicts get sorted by key.
500    self.assertLen(components, 3)
501    self.assertAllEqual(components[0], b'Yum')
502    self.assertAllClose(components[1], [0.83, 1.02])
503    self.assertAllClose(components[2], [0.42, 0.48])
504
505    for repacked in [repacked_1, repacked_2]:
506      self.assertAllEqual(repacked.name, b'Yum')
507      self.assertAllClose(repacked.prices['gum'], [0.42, 0.48])
508      self.assertAllClose(repacked.prices['chocolate'], [0.83, 1.02])
509
510  def testSimpleCond(self):
511    x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])
512    y = MaskedTensorV1([5, 6, 7, 8], [False, True, True, False])
513
514    x_2 = control_flow_ops.cond(
515        constant_op.constant(True), lambda: x, lambda: y)
516    y_2 = control_flow_ops.cond(
517        constant_op.constant(False), lambda: x, lambda: y)
518
519    self.assertAllEqual(x.values, x_2.values)
520    self.assertAllEqual(x.mask, x_2.mask)
521    self.assertAllEqual(y.values, y_2.values)
522    self.assertAllEqual(y.mask, y_2.mask)
523
524  def testComplexCond(self):
525    mt = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])
526
527    def true_fn():
528      return MaskedTensorV1(
529          array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3)
530
531    def false_fn():
532      return MaskedTensorV1(
533          array_ops.where_v2(mt.mask, 100, mt.values * 2),
534          math_ops.logical_not(mt.mask))
535
536    x = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
537    y = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
538
539    self.assertAllEqual(x.values, [1, -1, 3, -1])
540    self.assertAllEqual(x.mask, [False, False, False, True])
541    self.assertAllEqual(y.values, [100, 4, 100, 8])
542    self.assertAllEqual(y.mask, [False, True, False, True])
543
  def testCondAutograph(self):
    """A tensor-valued `if` inside tf.function can return ExtensionTypes."""

    @def_function.function
    def fn(mt):
      # The condition is a tensor, so this `if` is expected to be converted
      # by autograph into a cond over the two branches.
      if mt.values[3] > 3:
        return MaskedTensorV1(
            array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3)
      else:
        # NOTE(review): `not mt.mask` relies on autograph rewriting `not` on a
        # tensor; confirm this branch traces as intended.
        return MaskedTensorV1(
            array_ops.where_v2(mt.mask, 100, mt.values * 2), not mt.mask)

    # values[3] == 4 > 3, so the first branch's result is expected.
    x = fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]))
    self.assertAllEqual(x.values, [1, -1, 3, -1])
    self.assertAllEqual(x.mask, [False, False, False, True])
558
559  def testCondTypeMismatch(self):
560    if context.executing_eagerly:
561      # In eager mode, tf.cond eagerly runs either true_fn or false_fn, and
562      # ignores the other one; so it doesn't detect any type mismatches
563      # between the two outcomes.  (See _eager_cond_implementation in
564      # control_flow_ops.py.)
565      return
566
567    a = lambda: MaskedTensorV1([1, 2, 3], [True, True, False])
568    b = lambda: MaskedTensorV1(['a', 'b', 'c'], [False, True, True])
569    c = lambda: MaskedTensorV2([4, 5, 6], [True, True, False])
570    d = lambda: constant_op.constant([7, 8, 9])
571
572    with self.assertRaisesRegex(
573        ValueError,
574        'Incompatible return values of true_fn and false_fn: The two '
575        "structures don't have the same nested structure"):
576      control_flow_ops.cond(constant_op.constant(True), a, b)
577    with self.assertRaisesRegex(
578        TypeError, 'Incompatible return types of true_fn and false_fn: The two '
579        "structures don't have the same nested structure"):
580      control_flow_ops.cond(constant_op.constant(True), a, c)
581    with self.assertRaisesRegex(
582        ValueError,
583        'Incompatible return values of true_fn and false_fn: The two '
584        "structures don't have the same nested structure"):
585      control_flow_ops.cond(constant_op.constant(True), a, d)
586
587  def testCondPacked(self):
588    x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
589    y = MaskedTensorV2([5, 6, 7, 8], [False, True, True, False])
590    x = extension_type.pack(x)
591    y = extension_type.pack(y)
592
593    x_2 = control_flow_ops.cond(
594        constant_op.constant(True), lambda: x, lambda: y)
595    y_2 = control_flow_ops.cond(
596        constant_op.constant(False), lambda: x, lambda: y)
597
598    self.assertAllEqual(x.values, x_2.values)
599    self.assertAllEqual(x.mask, x_2.mask)
600    self.assertAllEqual(y.values, y_2.values)
601    self.assertAllEqual(y.mask, y_2.mask)
602
603    a = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
604    b = extension_type.pack(a)
605    b = control_flow_ops.cond(
606        constant_op.constant(True), lambda: array_ops.size(a.mask),
607        lambda: array_ops.size(a.values))
608    self.assertAllEqual(b, 4)
609
610    # Note: the following example would fail (with `Retval[0] does not have a
611    # value`) if `ExtensionType.__getattr__` cached the results of unpacking
612    # the value.  See the comment in `ExtensionType.__getattr__` for details.
613    c = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
614    c = extension_type.pack(c)
615    d = control_flow_ops.cond(
616        constant_op.constant(False), lambda: array_ops.size(c.mask),
617        lambda: array_ops.size(c.values))
618    self.assertAllEqual(d, 4)
619
620  def testWhileLoop(self):
621    x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])
622
623    cond = lambda i, x: i < 10
624    body = lambda i, x: (i + 1, MaskedTensorV1(x.values * 2, x.mask))
625    _, y = control_flow_ops.while_loop_v2(cond, body, [0, x])
626
627    self.assertIsInstance(y, MaskedTensorV1)
628    self.assertAllEqual(y.values, [1024, 2048, 3072, 4096])
629    self.assertAllEqual(y.mask, [True, False, True, False])
630
  def testWhileLoopAutograph(self):
    """A tensor-range `for` loop in tf.function can carry an ExtensionType."""

    @def_function.function
    def fn(x, n):
      # Iterating over a tensor range is expected to be converted by autograph
      # into a while loop, with `x` as the loop-carried value.
      for _ in math_ops.range(n):
        x = MaskedTensorV1(x.values * 2, x.mask)
      return x

    # Ten doublings multiply each value by 2**10 == 1024; the mask is unchanged.
    y = fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]), 10)
    self.assertIsInstance(y, MaskedTensorV1)
    self.assertAllEqual(y.values, [1024, 2048, 3072, 4096])
    self.assertAllEqual(y.mask, [True, False, True, False])
643
  def testWhileLoopTypeMismatch(self):
    """while_loop_v2 rejects a body whose outputs differ in structure."""
    x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])

    cond = lambda i, x: i < 10

    def body(i, x):
      # On the first iteration `x` is a MaskedTensorV1, so a plain Tensor is
      # returned, which does not match the loop-variable structure.
      # NOTE(review): this returns a single value rather than an (i, x) pair,
      # so the mismatch also involves the number of loop variables, and the
      # `else` branch is likely never reached -- confirm this is the intended
      # scenario.
      if isinstance(x, MaskedTensorV1):
        return x.values * 2
      else:
        return MaskedTensorV1(x, x > i)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure"):
      control_flow_ops.while_loop_v2(cond, body, [0, x])
658
659  def testWhileLoopPacked(self):
660    x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
661    x = extension_type.pack(x)
662    cond = lambda i, x: i < 10
663
664    def body(i, x):
665      return i + 1, extension_type.pack(MaskedTensorV2(x.values * 2, x.mask))
666
667    _, y = control_flow_ops.while_loop_v2(cond, body, [0, x])
668    self.assertIsInstance(y, MaskedTensorV2)
669    self.assertAllEqual(y.values, [1024, 2048, 3072, 4096])
670    self.assertAllEqual(y.mask, [True, False, True, False])
671
672  def testNestedFields(self):
673    PossiblyRaggedTensor = typing.Union[ops.Tensor, ragged_tensor.RaggedTensor]
674    ToyFeatures = typing.Mapping[str, PossiblyRaggedTensor]
675
676    class ToyInfo(extension_type.ExtensionType):
677      version: str
678      toys: typing.Tuple[typing.Tuple[str, ops.Tensor, ToyFeatures], ...]
679      boxes: typing.Mapping[str, ops.Tensor]
680
681    authors = [[b'A', b'Aardvark'], [b'Z', b'Zhook']]
682    toys = [('car', 1.0, {
683        'size': [8, 3, 2],
684        'color': [0.3, 0.2, 0.8]
685    }), ('book', 3.7, {
686        'authors': ragged_factory_ops.constant(authors)
687    })]
688    boxes = {'green': ['car'], 'blue': ['car', 'book', 'book']}
689    toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes)
690
691    self.assertEqual(toy_info.version, '1.0 alpha')
692    self.assertEqual(toy_info.toys[0][0], 'car')
693    self.assertIsInstance(toy_info.toys[0][1], ops.Tensor)
694    self.assertAllEqual(toy_info.toys[0][1], 1.0)
695    self.assertEqual(set(toy_info.toys[0][2].keys()), {'size', 'color'})
696    self.assertIsInstance(toy_info.toys[0][2]['size'], ops.Tensor)
697    self.assertAllEqual(toy_info.toys[0][2]['size'], [8, 3, 2])
698    self.assertIsInstance(toy_info.toys[1][2]['authors'],
699                          ragged_tensor.RaggedTensor)
700    self.assertAllEqual(toy_info.toys[1][2]['authors'], authors)
701    self.assertAllEqual(toy_info.boxes['green'], [b'car'])
702    self.assertAllEqual(toy_info.boxes['blue'], ['car', 'book', 'book'])
703
704    expected_repr = (
705        r"ToyInfo\(version='1.0 alpha', toys=\("
706        r"\('car', <tf.Tensor[^>]*>, ImmutableDict\("
707        r"{'size': <tf.Tensor[^>]*>, 'color': <tf.Tensor[^>]*>}\)\), "
708        r"\('book', <tf.Tensor[^>]*>, ImmutableDict\("
709        r"{'authors': (<tf.RaggedTensor[^>]*>|tf.RaggedTensor\(.*\))}\)\)\), "
710        r'boxes=ImmutableDict\('
711        r"{'green': <tf.Tensor[^>]*>, 'blue': <tf.Tensor[^>]*>}\)\)")
712
713    self.assertRegex(repr(toy_info), expected_repr)
714
715  def testNestedExtensionTypes(self):
716    PossiblyMaskedTensor = typing.Union[ops.Tensor, MaskedTensorV1]
717
718    class Toy(extension_type.ExtensionType):
719      name: str
720      price: ops.Tensor
721      features: typing.Mapping[str, PossiblyMaskedTensor]
722
723    class Box(extension_type.ExtensionType):
724      contents: ops.Tensor
725
726    class ToyInfo(extension_type.ExtensionType):
727      version: str
728      toys: typing.Tuple[Toy, ...]
729      boxes: typing.Mapping[str, Box]
730
731    authors = MaskedTensorV1(
732        values=[[b'A', b'Quincy', b'Aardvark'], [b'Z', b'Zhook', b'']],
733        mask=[[True, True, True], [True, True, False]])
734    toys = [
735        Toy('car', 1.0, {
736            'size': [8, 3, 2],
737            'color': [0.3, 0.2, 0.8]
738        }),
739        Toy(name='book', price=3.7, features={'authors': authors})
740    ]
741    boxes = {
742        'green': Box(['car']),
743        'blue': Box(contents=['car', 'book', 'book'])
744    }
745    toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes)
746
747    @def_function.function
748    def fn(info):
749      prices = [toy.price for toy in info.toys]
750      return math_ops.reduce_sum(array_ops.stack(prices))
751
752    self.assertAllClose(fn(toy_info), 4.7)
753
754  def testNestedCustomConstructor(self):
755
756    class Toy(extension_type.ExtensionType):
757      name: str
758      price: ops.Tensor
759
760      def __init__(self, name, price, discount=0):
761        if discount:
762          name += ' (discounted)'
763          price *= (1 - discount)
764        self.name = name
765        self.price = price
766
767    class ToyBox(extension_type.ExtensionType):
768      toys: typing.Tuple[Toy, ...]
769
770      def __init__(self, name_to_price, name_to_discount):
771        self.toys = [
772            Toy(name, price, name_to_discount.get(name, 0))
773            for (name, price) in name_to_price.items()
774        ]
775
776    toy_box = ToyBox({
777        'car': 8.3,
778        'truck': 5.9,
779        'puzzle': 5.3,
780        'jacks': 2.8
781    }, {
782        'puzzle': .2,
783        'truck': .3
784    })
785    self.assertLen(toy_box.toys, 4)
786    self.assertEqual(
787        set(toy.name for toy in toy_box.toys),
788        {'car', 'truck (discounted)', 'puzzle (discounted)', 'jacks'})
789
790  def testExtensionTypeWithMathOperators(self):
791
792    def masked_add(x, y, name=None):
793      del name
794      if not isinstance(x, MaskedTensorV2) and isinstance(y, MaskedTensorV2):
795        return dispatch.OpDispatcher.NOT_SUPPORTED
796      return MaskedTensorV2(x.values + y.values, x.mask & y.mask)
797
798    with temporarily_add_dispatch(math_ops.add, MaskedTensorV2, masked_add):
799      x = MaskedTensorV2([[1, 2], [3, 4]], [[True, False], [True, True]])
800      y = MaskedTensorV2([[3, 4], [5, 6]], [[True, True], [False, True]])
801      z = x + y
802      self.assertAllEqual(z.values, [[4, 6], [8, 10]])
803      self.assertAllEqual(z.mask, [[True, False], [False, True]])
804
805  def testGetExtensionTypeFields(self):
806
807    # Can be called on a type or an instance:
808    fields_1 = MaskedTensorV1._tf_extension_type_fields()
809    fields_2 = MaskedTensorV1([0], [True])._tf_extension_type_fields()
810
811    for fields in [fields_1, fields_2]:
812      self.assertLen(fields, 2)
813      self.assertEqual(fields[0].name, 'values')
814      self.assertEqual(fields[0].value_type, ops.Tensor)
815      self.assertEqual(fields[0].default, fields[0].NO_DEFAULT)
816      self.assertEqual(fields[1].name, 'mask')
817      self.assertEqual(fields[1].value_type, ops.Tensor)
818      self.assertEqual(fields[1].default, fields[0].NO_DEFAULT)
819
820  def testHasExtensionTypeField(self):
821
822    self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('values'))
823    self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('mask'))
824    self.assertFalse(MaskedTensorV1._tf_extension_type_has_field('labels'))
825
826    mt = MaskedTensorV1([0], [True])
827    self.assertTrue(mt._tf_extension_type_has_field('values'))
828    self.assertTrue(mt._tf_extension_type_has_field('mask'))
829    self.assertFalse(mt._tf_extension_type_has_field('labels'))
830
  def testForwardReferences(self):
    """String (forward-reference) annotations resolve to the real classes.

    The recorded fields hold the resolved types, while the generated
    `__init__` signature keeps the original string annotations.
    """
    A, B = ForwardRefA, ForwardRefB

    # Fields: forward references (including ones nested inside Tuple/Union)
    # are resolved to the actual ForwardRefA/ForwardRefB classes.
    self.assertEqual(A._tf_extension_type_fields(),
                     (extension_type_field.ExtensionTypeField(
                         'x', typing.Tuple[typing.Union[A, B], ...]),
                      extension_type_field.ExtensionTypeField('y', B)))
    self.assertEqual(B._tf_extension_type_fields(),
                     (extension_type_field.ExtensionTypeField('z', B),
                      extension_type_field.ExtensionTypeField('n', ops.Tensor)))

    # Check the signature.  Unlike the fields above, the parameter
    # annotations are left as the unresolved strings.
    expected_parameters = [
        tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
        tf_inspect.Parameter(
            'x',
            POSITIONAL_OR_KEYWORD,
            annotation=typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'],
                                    ...]),
        tf_inspect.Parameter(
            'y', POSITIONAL_OR_KEYWORD, annotation='ForwardRefB'),
    ]
    expected_sig = tf_inspect.Signature(
        expected_parameters, return_annotation=A)
    self.assertEqual(tf_inspect.signature(A.__init__), expected_sig)
856
  def testUnresolvedForwardReference(self):
    """A string annotation naming a nonexistent type fails at construction."""

    class Broken(extension_type.ExtensionType):
      x: 'Cra'  # note: intentional typo for Car.

    class Car(extension_type.ExtensionType):
      speed: float

    # Defining `Broken` (above, outside any assertRaises) must succeed; the
    # unresolvable 'Cra' annotation surfaces as a TypeError only once an
    # instance is constructed.
    with self.assertRaises(TypeError):
      Broken(x=Car(3.8))
867
  def testUnsupportedAnnotations(self):
    """Unsupported field annotations are rejected at class-definition time."""
    # `typing.List[...]` is not an accepted field annotation.
    with self.assertRaisesRegex(
        TypeError, "In field 'values': Unsupported type annotation"):

      class MyType1(extension_type.ExtensionType):  # pylint: disable=unused-variable
        values: typing.List[ops.Tensor]

    # A Union is rejected when one of its members is unsupported (presumably
    # `complex` here -- `int` is accepted elsewhere in this file).
    with self.assertRaisesRegex(TypeError,
                                "In field 'xyz': Unsupported type annotation"):

      class MyType2(extension_type.ExtensionType):  # pylint: disable=unused-variable
        xyz: typing.Union[typing.Tuple[complex, ...], int]
880
  def testCantUseReservedName(self):
    """Names used by the ExtensionType machinery cannot be redefined."""
    # A field annotation that collides with an internal method name.
    with self.assertRaisesRegex(
        ValueError, 'The field annotations for MyType1 are invalid. '
        "Field '_to_components' is reserved"):

      class MyType1(extension_type.ExtensionType):  # pylint: disable=unused-variable
        _to_components: int

    # Any name with the `_tf_extension_type` prefix is reserved.
    with self.assertRaisesRegex(
        ValueError, 'The field annotations for MyType2 are invalid. '
        "Field '_tf_extension_type_foo' is reserved"):

      class MyType2(extension_type.ExtensionType):  # pylint: disable=unused-variable
        _tf_extension_type_foo: int

    # Reserved names are rejected even when defined as a method rather than
    # annotated as a field.
    with self.assertRaisesRegex(
        ValueError, 'The field annotations for MyType3 are invalid. '
        "Field 'is_compatible_with' is reserved"):

      class MyType3(extension_type.ExtensionType):  # pylint: disable=unused-variable

        def is_compatible_with(self, other):
          return False
904
905  def testExtensionTypeBaseClassHasNoSpec(self):
906    self.assertFalse(hasattr(extension_type.ExtensionType, 'Spec'))
907
908  def testExtensionTypeBaseConstructorRaisesException(self):
909    with self.assertRaisesRegex(AssertionError,
910                                'ExtensionType is an abstract base class.'):
911      extension_type.ExtensionType()
912
  class ExtensionTypeWithName(extension_type.ExtensionType):
    """ExtensionType with an explicit `__name__`, used in SavedModel tests."""
    __name__ = 'tf.__test__.ExtensionTypeWithName'  # For SavedModel
    x: typing.Tuple[ops.Tensor, int]  # (tensor, python int) pair
    y: ops.Tensor
917
918  def testSavedModelSupport(self):
919
920    class TestModule(module.Module):
921
922      @def_function.function
923      def f(self, s):
924        return s.x[0] + s.x[1] + s.y
925
926    s1 = self.ExtensionTypeWithName((1, 2), 3)
927    s2 = self.ExtensionTypeWithName((1.0, 2), [3.0, 4.0])
928
929    m = TestModule()
930    m.f.get_concrete_function(s1)
931    m.f.get_concrete_function(s2)
932
933    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
934    save.save(m, path)
935    loaded = load.load(path)
936
937    self.assertAllEqual(loaded.f(s1), 6)
938    self.assertAllEqual(loaded.f(s2), [6.0, 7.0])
939
940  def testPackedEncoding(self):
941    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
942    self.assertLen(nest.flatten(mt1, expand_composites=True), 2)
943
944    mt2 = extension_type.pack(mt1)
945    self.assertLen(nest.flatten(mt2, expand_composites=True), 1)
946    self.assertIsInstance(mt2.values, ops.Tensor)
947    self.assertAllEqual(mt2.values, [1, 2, 3, 4])
948    self.assertIsInstance(mt2.mask, ops.Tensor)
949    self.assertAllEqual(mt2.mask, [True, True, False, True])
950
951    mt3 = extension_type.unpack(mt2)
952    self.assertLen(nest.flatten(mt3, expand_composites=True), 2)
953    self.assertIsInstance(mt3.values, ops.Tensor)
954    self.assertAllEqual(mt3.values, [1, 2, 3, 4])
955    self.assertIsInstance(mt3.mask, ops.Tensor)
956    self.assertAllEqual(mt3.mask, [True, True, False, True])
957
958    nest.assert_same_structure(mt1, mt3, expand_composites=True)
959    with self.assertRaisesRegex(ValueError, "don't have the same"):  # pylint: disable=g-error-prone-assert-raises
960      nest.assert_same_structure(mt1, mt2, expand_composites=True)
961
962    mt4 = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True])
963    with self.assertRaisesRegex(
964        ValueError,
965        'ExtensionTypes must have a __name__ field in order to be packed.'):
966      extension_type.pack(mt4)
967
  def testSubclassing(self):
    """Subclasses inherit, extend and override fields of their parents."""

    class Instrument(extension_type.ExtensionType):
      name: ops.Tensor
      weight: ops.Tensor
      needs_case: bool

    class StringInstrument(Instrument):
      num_strings: int  # Add a new field
      needs_case: bool = True  # Override default value.

    class Violin(StringInstrument):
      maker: ops.Tensor
      num_strings: int = 4  # Override default value.
      name: str = 'violin'  # Override field type and default value.

    # Inherited fields keep their original position (overriding `needs_case`
    # does not move it); new fields are appended.  Every field after the
    # first one with a default becomes keyword-only.
    self.assertEqual(
        list(
            tf_inspect.signature(
                StringInstrument.__init__).parameters.values()), [
                    tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
                    tf_inspect.Parameter(
                        'name', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
                    tf_inspect.Parameter(
                        'weight', POSITIONAL_OR_KEYWORD, annotation=ops.Tensor),
                    tf_inspect.Parameter(
                        'needs_case',
                        POSITIONAL_OR_KEYWORD,
                        annotation=bool,
                        default=True),
                    tf_inspect.Parameter(
                        'num_strings', KEYWORD_ONLY, annotation=int),
                ])
    # In Violin the very first field (`name`) has a default, so all later
    # fields are keyword-only.
    self.assertEqual(
        list(tf_inspect.signature(Violin.__init__).parameters.values()), [
            tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
            tf_inspect.Parameter(
                'name', POSITIONAL_OR_KEYWORD, annotation=str,
                default='violin'),
            tf_inspect.Parameter('weight', KEYWORD_ONLY, annotation=ops.Tensor),
            tf_inspect.Parameter(
                'needs_case', KEYWORD_ONLY, annotation=bool, default=True),
            tf_inspect.Parameter(
                'num_strings', KEYWORD_ONLY, annotation=int, default=4),
            tf_inspect.Parameter('maker', KEYWORD_ONLY, annotation=ops.Tensor),
        ])

    # Defaults inherited from every level of the hierarchy are applied.
    violin = Violin(weight=28, maker='Amati')
    self.assertAllEqual(violin.name, 'violin')
    self.assertAllEqual(violin.weight, 28)
    self.assertAllEqual(violin.needs_case, True)
    self.assertAllEqual(violin.num_strings, 4)
    self.assertAllEqual(violin.maker, 'Amati')
1021
1022
1023# integration test to test compatibility with high level api like Dataset
1024# and Keras
1025class ExtensionTypeIntegrationTest(test_util.TensorFlowTestCase):
1026
1027  @test_util.run_v2_only
1028  def testDataset(self):
1029    mt = MaskedTensorV3([[1], [2], [3]], [[True], [False], [True]])
1030    ds = dataset_ops.DatasetV2.from_tensors(mt)
1031    self.assertEqual(next(iter(ds)), mt)
1032
1033  @test_util.run_v2_only
1034  def testDatasetBatch(self):
1035    xs = MaskedTensorV3([[1], [2], [3]], [[True], [False], [True]])
1036    x0 = MaskedTensorV3(xs.values[0], xs.mask[0])
1037
1038    ds = dataset_ops.DatasetV2.from_tensors(xs)
1039    self.assertEqual(next(iter(ds)), xs)
1040    ds = ds.unbatch()
1041    self.assertEqual(next(iter(ds)), x0)
1042
1043    ds = dataset_ops.DatasetV2.from_tensor_slices(xs)
1044    self.assertEqual(next(iter(ds)), x0)
1045    ds = ds.batch(3, drop_remainder=True)
1046    self.assertEqual(next(iter(ds)), xs)
1047
1048  @test_util.run_v2_only
1049  def testDatasetBatchRagged(self):
1050    xs = MaskedTensorV3(
1051        ragged_factory_ops.constant([[1], [2, 3], [4]]),
1052        ragged_factory_ops.constant([[True], [False], [True]]))
1053    x0 = MaskedTensorV3(xs.values[0], xs.mask[0])
1054
1055    ds = dataset_ops.DatasetV2.from_tensors(xs)
1056    self.assertEqual(next(iter(ds)), xs)
1057    ds = ds.unbatch()
1058    self.assertEqual(next(iter(ds)), x0)
1059
1060    ds = dataset_ops.DatasetV2.from_tensor_slices(xs)
1061    self.assertEqual(next(iter(ds)), x0)
1062    ds = ds.batch(3, drop_remainder=True)
1063    self.assertEqual(next(iter(ds)), xs)
1064
1065  @test_util.run_v2_only
1066  def testDistributedDataset(self):
1067    strategy = mirrored_strategy.MirroredStrategy(['GPU:0', 'GPU:1'])
1068    mt = MaskedTensorV3([[1], [2], [3], [4]], [[True], [False], [True], [True]])
1069    ds = dataset_ops.DatasetV2.from_tensor_slices(mt).batch(2)
1070    dist_dataset = strategy.experimental_distribute_dataset(ds)
1071    expect = MaskedTensorV3([[1]], [[True]])
1072    per_replica_result = next(iter(dist_dataset))
1073    self.assertEqual(per_replica_result.values[0].values, expect.values[0])
1074    self.assertEqual(per_replica_result.values[0].mask, expect.mask[0])
1075
1076  # TODO(edloper): Move this test to Keras.
1077  @test_util.run_v2_only
1078  def testKerasModel(self):
1079    mt_spec = MaskedTensorV3.Spec(
1080        tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.int32),
1081        tensor_spec.TensorSpec(shape=[None, 1], dtype=dtypes.bool),
1082    )
1083    model_input = input_layer.Input(type_spec=mt_spec)
1084    model_output = array_ops.identity(model_input, name='output')
1085    model = training.Model(inputs=model_input, outputs=model_output)
1086    mt = MaskedTensorV3([[1], [2], [3]], [[True], [False], [True]])
1087    self.assertEqual(model(mt), mt)
1088    ds = dataset_ops.DatasetV2.from_tensors(mt)
1089    self.assertEqual(model.predict(ds), mt)
1090
1091    with self.subTest('keras save'):
1092      path = self.create_tempdir().full_path
1093      model.save(path)
1094      loaded_model = keras_save.load_model(path)
1095      self.assertEqual(loaded_model.input.type_spec, mt_spec)
1096      self.assertEqual(loaded_model(mt), mt)
1097
1098      loaded_fn = load.load(path)
1099      self.assertEqual(loaded_fn(mt), mt)
1100      with self.assertRaisesRegex(
1101          ValueError,
1102          'Could not find matching concrete function to call '
1103          'loaded from the SavedModel',
1104      ):
1105        loaded_fn(MaskedTensorV3([1, 2, 3], [True, False, True]))
1106
1107      # The serving_fn use flatten signature
1108      serving_fn = loaded_fn.signatures['serving_default']
1109      self.assertEqual(
1110          serving_fn(args_0=mt.values, args_0_1=mt.mask)['tf.identity'], mt)
1111
1112
1113@test_util.run_all_in_graph_and_eager_modes
1114class ExtensionTypeSpecTest(test_util.TensorFlowTestCase,
1115                            parameterized.TestCase):
1116
1117  def testSpecConstructor(self):
1118    values_spec = tensor_spec.TensorSpec([4], dtypes.float32)
1119    mask_spec = tensor_spec.TensorSpec([4], dtypes.bool)
1120    mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec)
1121    self.assertEqual(mt_spec.values, values_spec)
1122    self.assertEqual(mt_spec.mask, mask_spec)
1123
1124    mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True])
1125    self.assertEqual(mt._type_spec, mt_spec)
1126
1127  def testSpecConstructorSignature(self):
1128
1129    class MyType(extension_type.ExtensionType):
1130      x: ops.Tensor
1131      y: ops.Tensor
1132      z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]
1133
1134    expected_parameters = [
1135        tf_inspect.Parameter('self', POSITIONAL_OR_KEYWORD),
1136        tf_inspect.Parameter('x', POSITIONAL_OR_KEYWORD),
1137        tf_inspect.Parameter('y', POSITIONAL_OR_KEYWORD),
1138        tf_inspect.Parameter('z', POSITIONAL_OR_KEYWORD),
1139    ]
1140    expected_sig = tf_inspect.Signature(
1141        expected_parameters, return_annotation=MyType.Spec)
1142    self.assertEqual(expected_sig, tf_inspect.signature(MyType.Spec.__init__))
1143
1144  def testSpecAttributesAreImmutable(self):
1145    mt = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True])
1146    mt_spec = MaskedTensorV1.Spec.from_value(mt)
1147    with self.assertRaisesRegex(
1148        AttributeError, 'Cannot mutate attribute `score` '
1149        'outside the custom constructor of ExtensionTypeSpec'):
1150      mt_spec.score = 12
1151    with self.assertRaisesRegex(
1152        AttributeError, 'Cannot mutate attribute `values` '
1153        'outside the custom constructor of ExtensionTypeSpec'):
1154      mt_spec.values = constant_op.constant([4, 3, 2, 1])
1155    with self.assertRaisesRegex(
1156        AttributeError, 'Cannot mutate attribute `values` '
1157        'outside the custom constructor of ExtensionTypeSpec'):
1158      del mt_spec.values
1159
1160  def testSpecFromValue(self):
1161    mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True])
1162    mt_spec = MaskedTensorV1.Spec.from_value(mt)
1163
1164    expected_values_spec = tensor_spec.TensorSpec([4], dtypes.float32)
1165    expected_mask_spec = tensor_spec.TensorSpec([4], dtypes.bool)
1166    self.assertEqual(mt_spec.values, expected_values_spec)
1167    self.assertEqual(mt_spec.mask, expected_mask_spec)
1168
1169  def testSpecSerialize(self):
1170
1171    class Zoo(extension_type.ExtensionType):
1172      zookeepers: typing.Tuple[str, ...]
1173      animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]]
1174
1175    featurespec = {
1176        'size': tensor_spec.TensorSpec([3]),
1177        'weight': tensor_spec.TensorSpec([])
1178    }
1179    zoo_spec = Zoo.Spec(
1180        zookeepers=['Zoey', 'Zack'],
1181        animals={
1182            'tiger': featurespec,
1183            'elephant': featurespec
1184        })
1185
1186    serialized = zoo_spec._serialize()
1187    self.assertEqual(serialized,
1188                     (('zookeepers', ('Zoey', 'Zack')), ('animals', {
1189                         'tiger': featurespec,
1190                         'elephant': featurespec
1191                     })))
1192    restored = Zoo.Spec._deserialize(serialized)
1193    self.assertEqual(zoo_spec, restored)
1194
1195    # ImmutableDict is used for the field, but dict for the serialization:
1196    self.assertIsInstance(zoo_spec.animals, immutable_dict.ImmutableDict)
1197    serialized_field_name, serialized_field_value = serialized[1]
1198    self.assertEqual(serialized_field_name, 'animals')
1199    self.assertIsInstance(serialized_field_value, dict)
1200
1201  def testSpecComponents(self):
1202
1203    class Zoo(extension_type.ExtensionType):
1204      zookeepers: typing.Tuple[str, ...]
1205      animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]]
1206
1207    zoo = Zoo(
1208        ['Zoey', 'Zack'], {
1209            'elephant': {
1210                'size': [25, 30, 20],
1211                'weight': 2000.0
1212            },
1213            'tiger': {
1214                'hunger': 3.2,
1215                'size': [3, 8, 2],
1216                'weight': 87.3
1217            }
1218        })
1219    zoo_spec = Zoo.Spec.from_value(zoo)
1220
1221    components = zoo_spec._to_components(zoo)
1222    self.assertLen(components, 5)
1223    self.assertAllClose(components[0], [25, 30, 20])
1224    self.assertAllClose(components[1], 2000.0)
1225    self.assertAllClose(components[2], 3.2)
1226    self.assertAllClose(components[3], [3, 8, 2])
1227    self.assertAllClose(components[4], 87.3)
1228
1229    restored = zoo_spec._from_components(components)
1230    self.assertAllEqual(zoo == restored, True)
1231
1232    self.assertEqual(zoo_spec._component_specs,
1233                     (tensor_spec.TensorSpec([3], dtypes.int32),
1234                      tensor_spec.TensorSpec([], dtypes.float32),
1235                      tensor_spec.TensorSpec([], dtypes.float32),
1236                      tensor_spec.TensorSpec([3], dtypes.int32),
1237                      tensor_spec.TensorSpec([], dtypes.float32)))
1238
1239  def testCopyAndPickle(self):
1240    values_spec = tensor_spec.TensorSpec([4], dtypes.float32)
1241    mask_spec = tensor_spec.TensorSpec([4], dtypes.bool)
1242    mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec)
1243    self.assertEqual(copy.copy(mt_spec), mt_spec)
1244    self.assertEqual(copy.deepcopy(mt_spec), mt_spec)
1245    self.assertEqual(pickle.loads(pickle.dumps(mt_spec)), mt_spec)
1246
1247  def testCustomizeSpecTest(self):
1248
1249    class WeightedTensor(extension_type.ExtensionType):
1250      """ExtensionType with a customized TypeSpec.
1251
1252      * Custom constructor.
1253      * Custom __validate__.
1254      * Add properties (shape, dtype, weight_dtype).
1255      * Add method (with_shape).
1256      """
1257      values: ops.Tensor
1258      weight: ops.Tensor  # scalar
1259
1260      shape = property(lambda self: self.shape)
1261      dtype = property(lambda self: self.dtype)
1262      weight_dtype = property(lambda self: self.weight.dtype)
1263
1264      def __validate__(self):
1265        self.weight.shape.assert_has_rank(0)
1266
1267      class Spec:
1268
1269        def __init__(self, shape, dtype, weight_dtype=dtypes.float32):
1270          self.values = tensor_spec.TensorSpec(shape, dtype)
1271          self.weight = tensor_spec.TensorSpec([], weight_dtype)
1272
1273        def __validate__(self):
1274          self.weight.shape.assert_has_rank(0)
1275
1276        shape = property(lambda self: self.values.shape)
1277        dtype = property(lambda self: self.values.dtype)
1278        weight_dtype = property(lambda self: self.weight.dtype)
1279
1280        def with_shape(self, shape):
1281          return WeightedTensor.Spec(shape, self.dtype, self.weight_dtype)
1282
1283    wt = WeightedTensor([1, 2], 0.3)
1284    wt_spec = WeightedTensor.Spec.from_value(wt)
1285    self.assertEqual(wt_spec.shape, tensor_shape.TensorShape([2]))
1286    self.assertEqual(wt_spec.dtype, dtypes.int32)
1287
1288    self.assertEqual(wt_spec, WeightedTensor.Spec([2], dtypes.int32))
1289
1290    wt2 = WeightedTensor([[1, 2], [3, 4]], 0.5)
1291    wt2_spec = WeightedTensor.Spec.from_value(wt2)
1292    self.assertEqual(wt_spec.with_shape([2, 2]), wt2_spec)
1293
1294  def testNestedSpecMustBeAClass(self):
1295    with self.assertRaisesRegex(
1296        ValueError,
1297        r'BrokenExtensionType\.Spec must be a nested class; got 12.'):
1298
1299      class BrokenExtensionType(extension_type.ExtensionType):
1300
1301        Spec = 12  # pylint: disable=invalid-name
1302
1303      del BrokenExtensionType
1304
1305  def testNestedSpecMayNotHaveBaseClasses(self):
1306    with self.assertRaisesRegex(
1307        ValueError, r'BrokenExtensionType\.Spec must be directly subclassed '
1308        'from tf.TypeSpec.'):
1309
1310      class BrokenExtensionType(extension_type.ExtensionType):
1311
1312        class Spec(type_spec.BatchableTypeSpec):
1313          pass
1314
1315      del BrokenExtensionType
1316
1317
1318@test_util.run_all_in_graph_and_eager_modes
1319class AnonymousExtensionTypeTest(test_util.TensorFlowTestCase,
1320                                 parameterized.TestCase):
1321
1322  @parameterized.parameters([
1323      [dict(i=5, f=3.2, b=True, n=None)],
1324      [dict(x=(1, 2), y={
1325          3: 4,
1326          5: 6
1327      })],
1328      [lambda: dict(t=constant_op.constant(123))],
1329      [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
1330  ])
1331  def testConstruction(self, fields):
1332    if callable(fields):
1333      fields = fields()
1334    extension_type.AnonymousExtensionType(**fields)
1335
1336  @parameterized.parameters([
1337      [dict(x=[1, 2, 3]), 'unsupported `value` argument'],
1338      [dict(x=set([1, 2])), 'unsupported `value` argument'],
1339      [dict(x=(1, dict([(2, [])]))), 'unsupported `value` argument'],
1340      [
1341          dict(_tf_extension_type_xyz=5),
1342          'Reserved field name .*_tf_extension_type_xyz.*'
1343      ],
1344  ])
1345  def testConstructionErrors(self, fields, error):
1346    with self.assertRaisesRegex(ValueError, error):
1347      extension_type.AnonymousExtensionType(**fields)
1348
1349  @parameterized.parameters([
1350      [dict(i=5, f=3.2, b=True, n=None)],
1351      [dict(x=(1, 2), y={
1352          3: 4,
1353          5: 6
1354      })],
1355      [lambda: dict(t=constant_op.constant(123))],
1356      [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
1357  ])
1358  def testAttributeAccessors(self, fields):
1359    if callable(fields):
1360      fields = fields()
1361    s = extension_type.AnonymousExtensionType(**fields)
1362    for (name, value) in fields.items():
1363      actual = getattr(s, name)
1364      if isinstance(actual, (ops.Tensor, ragged_tensor.RaggedTensor)):
1365        self.assertAllEqual(actual, value)
1366      else:
1367        self.assertEqual(actual, value)
1368
1369  def testAttributeAccessorsAreImmutable(self):
1370    s = extension_type.AnonymousExtensionType(x=12, y={'x': 55})
1371    with self.assertRaisesRegex(AttributeError, 'Cannot set attribute `x`'):
1372      s.x = 22
1373    with self.assertRaisesRegex(AttributeError, 'Cannot delete attribute `y`'):
1374      del s.y
1375    with self.assertRaisesRegex(TypeError, 'does not support item assignment'):
1376      s.y['x'] = 66
1377
1378  def testReinterpret(self):
1379    x = MaskedTensorV2([4, 5], [True, False])
1380    anon_x = extension_type.reinterpret(x,
1381                                        extension_type.AnonymousExtensionType)
1382    self.assertAllEqual(anon_x.values, [4, 5])
1383    self.assertAllEqual(anon_x.mask, [True, False])
1384
1385    round_trip_x = extension_type.reinterpret(anon_x, MaskedTensorV2)
1386    self.assertAllEqual(round_trip_x.values, [4, 5])
1387    self.assertAllEqual(round_trip_x.mask, [True, False])
1388
1389    converted_x = extension_type.reinterpret(anon_x, MaskedTensorV1)
1390    self.assertAllEqual(converted_x.values, [4, 5])
1391    self.assertAllEqual(converted_x.mask, [True, False])
1392
1393  # pylint: disable=g-long-lambda
1394  @parameterized.parameters([
1395      [
1396          lambda: extension_type.AnonymousExtensionType(
1397              values=constant_op.constant([1, 2, 3])), MaskedTensorV2,
1398          "Missing required fields: {'mask'}"
1399      ],
1400      [
1401          lambda: extension_type.AnonymousExtensionType(
1402              values=(1, 2, 3), mask=None), MaskedTensorV2,
1403          'mask: expected a Tensor, got None'
1404      ],
1405      [
1406          lambda: extension_type.AnonymousExtensionType(
1407              values=constant_op.constant([1, 2, 3]),
1408              mask=constant_op.constant([True, False])), MaskedTensorV2,
1409          'Shapes .* are incompatible'
1410      ],
1411      [
1412          lambda: extension_type.AnonymousExtensionType(
1413              values=constant_op.constant([1, 2, 3])), ops.Tensor,
1414          'reinterpret expects `new_type` to be a subclass of '
1415          'tf.ExtensionType; '
1416          'got .*.Tensor.*'
1417      ],
1418      [
1419          lambda: constant_op.constant([1, 2, 3]),
1420          extension_type.AnonymousExtensionType,
1421          'reinterpret expects `value` to be a tf.ExtensionType instance; '
1422          'got.*.Tensor.*'
1423      ],
1424  ])
1425  def testReinterpretErrors(self, value, new_type, error):
1426    if callable(value):
1427      value = value()
1428    with self.assertRaisesRegex((TypeError, ValueError), error):
1429      extension_type.reinterpret(value, new_type)
1430
1431  def testLoadSavedModelWithUnregisteredExtensionType(self):
1432
1433    def f(x, y):
1434      x_values = x.values if isinstance(x, MaskedTensorV1) else x
1435      y_values = y.values if isinstance(y, MaskedTensorV1) else y
1436      x_mask = x.mask if isinstance(x, MaskedTensorV1) else True
1437      y_mask = y.mask if isinstance(y, MaskedTensorV1) else True
1438      return MaskedTensorV1(x_values + y_values, x_mask & y_mask)
1439
1440    t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
1441    b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
1442    mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec)
1443    model = module.Module()
1444    model.f = def_function.function(f)
1445    model.f.get_concrete_function(t_spec, t_spec)
1446    model.f.get_concrete_function(t_spec, mt_spec)
1447    model.f.get_concrete_function(mt_spec, t_spec)
1448    model.f.get_concrete_function(mt_spec, mt_spec)
1449
1450    path = tempfile.mkdtemp(prefix=test.get_temp_dir())
1451    with temporarily_register_type_spec('tf.test.MaskedTensorV1.Spec',
1452                                        MaskedTensorV1.Spec):
1453      save.save(model, path)
1454    loaded_model = load.load(path)
1455
1456    with self.assertRaises(ValueError):
1457      type_spec.lookup('tf.test.MaskedTensorV1')
1458
1459    t = constant_op.constant([10, 20, 30])
1460    v1 = loaded_model.f(t, t)
1461    self.assertIsInstance(v1, extension_type.AnonymousExtensionType)
1462    self.assertAllEqual(v1.values, [20, 40, 60])
1463    self.assertAllEqual(v1.mask, True)
1464
1465    v2 = loaded_model.f(v1, v1)
1466    self.assertIsInstance(v2, extension_type.AnonymousExtensionType)
1467    self.assertAllEqual(v2.values, [40, 80, 120])
1468    self.assertAllEqual(v2.mask, True)
1469
1470    mt = MaskedTensorV1([1, 2, 3], [True, True, False])
1471    v3 = loaded_model.f(
1472        t, extension_type.reinterpret(mt,
1473                                      extension_type.AnonymousExtensionType))
1474    self.assertIsInstance(v3, extension_type.AnonymousExtensionType)
1475    self.assertAllEqual(v3.values, [11, 22, 33])
1476    self.assertAllEqual(v3.mask, [True, True, False])
1477
1478    v4 = extension_type.reinterpret(v3, MaskedTensorV1)
1479    self.assertIsInstance(v4, MaskedTensorV1)
1480    self.assertAllEqual(v4.values, [11, 22, 33])
1481    self.assertAllEqual(v4.mask, [True, True, False])
1482
1483  def testFlatTensorSpecs(self):
1484    x = MaskedTensorV2([4, 5], [True, False])
1485    spec = type_spec.type_spec_from_value(x)
1486    flat_specs = spec._flat_tensor_specs
1487    self.assertEqual(flat_specs, [
1488        tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32, name=None),
1489        tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.bool, name=None)
1490    ])
1491
1492  def testFullTypesForFlatTensors(self):
1493    x = MaskedTensorV2([4, 5], [True, False])
1494    spec = type_spec.type_spec_from_value(x)
1495    full_type_list = fulltypes_for_flat_tensors(spec)
1496    expect = [
1497        full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET),
1498        full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_UNSET)
1499    ]
1500    self.assertEqual(len(spec._flat_tensor_specs), len(full_type_list))
1501    self.assertEqual(expect, full_type_list)
1502
1503
def replace_tensors_with_placeholders(value):
  """Swaps every tensor in `value` for a placeholder_with_default.

  Each placeholder defaults to the tensor it replaces and has an unspecified
  shape (`shape=None`). Composite values are expanded, so their component
  tensors are replaced too. Non-tensor leaves are passed through unchanged.
  """

  def _to_placeholder(leaf):
    if not isinstance(leaf, ops.Tensor):
      return leaf
    return array_ops.placeholder_with_default(leaf, shape=None)

  return nest.map_structure(_to_placeholder, value, expand_composites=True)
1513
1514
1515@contextlib.contextmanager
1516def temporarily_add_dispatch(op, typ, fn):
1517  n = len(op._tf_fallback_dispatchers)
1518  dispatch.dispatch_for_types(op, typ)(fn)
1519  yield
1520  assert len(op._tf_fallback_dispatchers) == n + 1
1521  del op._tf_fallback_dispatchers[-1]
1522
1523
1524@contextlib.contextmanager
1525def temporarily_register_type_spec(name, cls):
1526  """Context manager for making temporary changes to the TypeSpec registry."""
1527  type_spec.register(name)(cls)
1528  yield
1529  assert type_spec._TYPE_SPEC_TO_NAME.pop(cls) == name
1530  assert type_spec._NAME_TO_TYPE_SPEC.pop(name) is cls
1531
1532
if __name__ == "__main__":
  # Entry point when this module is executed directly: hand off to the runner.
  googletest.main()
1535