xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/testing_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for unit-testing Keras."""
16
17import collections
18import contextlib
19import functools
20import itertools
21import threading
22
23import numpy as np
24
25from tensorflow.python import tf2
26from tensorflow.python.eager import context
27from tensorflow.python.framework import config
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import test_util
33from tensorflow.python.keras import backend
34from tensorflow.python.keras import layers
35from tensorflow.python.keras import models
36from tensorflow.python.keras.engine import base_layer_utils
37from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
38from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
39from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
40from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
41from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
42from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
43from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
44from tensorflow.python.keras.utils import tf_contextlib
45from tensorflow.python.keras.utils import tf_inspect
46from tensorflow.python.util import tf_decorator
47
48
def string_test(actual, expected):
  """Checks that `actual` equals `expected` exactly, element by element.

  Used for string-typed layer outputs, where no numeric tolerance applies.

  Raises:
    AssertionError: if the arrays differ in shape or in any element.
  """
  check_exact = np.testing.assert_array_equal
  check_exact(actual, expected)
51
52
def numeric_test(actual, expected):
  """Checks that `actual` is numerically close to `expected`.

  Uses a relative tolerance of 1e-3 and an absolute tolerance of 1e-6.

  Raises:
    AssertionError: if any element is outside the tolerance.
  """
  tolerances = dict(rtol=1e-3, atol=1e-6)
  np.testing.assert_allclose(actual, expected, **tolerances)
55
56
def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes,
                  random_seed=None):
  """Generates test data to train a model on.

  Each class gets a random "template" input. Every sample is the template of
  its (randomly drawn) class plus unit-variance Gaussian noise.

  Args:
    train_samples: Integer, how many training samples to generate.
    test_samples: Integer, how many test samples to generate.
    input_shape: Shape of a single input: a tuple (or any iterable) of
      integers, or a single integer for 1-D inputs.
    num_classes: Integer, number of classes for the data and targets.
    random_seed: Integer, random seed used by numpy to generate data.

  Returns:
    A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
  """
  if random_seed is not None:
    np.random.seed(random_seed)
  # The shape is concatenated with tuples below, so normalize lists/ints to a
  # tuple up front (previously a list or int raised a TypeError).
  if isinstance(input_shape, (int, np.integer)):
    input_shape = (int(input_shape),)
  else:
    input_shape = tuple(input_shape)
  num_sample = train_samples + test_samples
  templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
  y = np.random.randint(0, num_classes, size=(num_sample,))
  x = np.zeros((num_sample,) + input_shape, dtype=np.float32)
  for i in range(num_sample):
    x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape)
  return ((x[:train_samples], y[:train_samples]),
          (x[train_samples:], y[train_samples:]))
84
85
@test_util.disable_cudnn_autotune
def layer_test(layer_cls,
               kwargs=None,
               input_shape=None,
               input_dtype=None,
               input_data=None,
               expected_output=None,
               expected_output_dtype=None,
               expected_output_shape=None,
               validate_training=True,
               adapt_data=None,
               custom_objects=None,
               test_harness=None,
               supports_masking=None):
  """Test routine for a layer with a single input and single output.

  The layer is exercised in a functional model, where the checks cover:
    * output dtype and shape;
    * `compute_output_shape` / `compute_output_signature`;
    * a config round-trip;
    * optionally, one training step.
  It is then exercised again as the first layer of a Sequential model rebuilt
  from the layer config.

  Args:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer. It is copied and never mutated.
    input_shape: Input shape tuple.
    input_dtype: Data type of the input data.
    input_data: Numpy array of input data.
    expected_output: Numpy array of the expected output.
    expected_output_dtype: Data type expected for the output.
    expected_output_shape: Shape tuple for the expected shape of the output.
    validate_training: Whether to attempt to validate training on this layer.
      This might be set to False for non-differentiable layers that output
      string or integer values.
    adapt_data: Optional data for an 'adapt' call. If None, adapt() will not
      be tested for this layer. This is only relevant for PreprocessingLayers.
    custom_objects: Optional dictionary mapping name strings to custom objects
      in the layer class. This is helpful for testing custom layers.
    test_harness: The Tensorflow test, if any, that this function is being
      called in.
    supports_masking: Optional boolean to check the `supports_masking` property
      of the layer. If None, the check will not be performed.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.

  Raises:
    ValueError: if both `input_data` and `input_shape` are None.
    AssertionError: if an output dtype, shape, value or `supports_masking`
      check fails.
  """
  if input_data is None:
    if input_shape is None:
      raise ValueError('input_shape is None')
    if not input_dtype:
      input_dtype = 'float32'
    # Replace unknown (None) dimensions with a small random size.
    input_data_shape = list(input_shape)
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  # String outputs must match exactly; numeric outputs use a tolerance.
  if dtypes.as_dtype(expected_output_dtype) == dtypes.string:
    if test_harness:
      assert_equal = test_harness.assertAllEqual
    else:
      assert_equal = string_test
  else:
    if test_harness:
      assert_equal = test_harness.assertAllClose
    else:
      assert_equal = numeric_test

  # instantiation. Work on a copy so the caller's dict is never mutated
  # (e.g. by the `weights` injection below).
  kwargs = dict(kwargs or {})
  layer = layer_cls(**kwargs)

  if (supports_masking is not None
      and layer.supports_masking != supports_masking):
    raise AssertionError(
        'When testing layer %s, the `supports_masking` property is %r '
        'but expected to be %r.\nFull kwargs: %s' %
        (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs))

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  # test get_weights , set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test instantiation from weights
  # NOTE(review): `in` on the ArgSpec namedtuple compares 'weights' against
  # its fields (args list, varargs name, keywords name, defaults), not against
  # the argument names. It is therefore only true when `*args`/`**kwargs` is
  # literally named `weights`. `.args` may have been intended; confirm before
  # changing, since that would start passing weights to more layers.
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  if backend.dtype(y) != expected_output_dtype:
    raise AssertionError('When testing layer %s, for input %s, found output '
                         'dtype=%s but expected to find %s.\nFull kwargs: %s' %
                         (layer_cls.__name__, x, backend.dtype(y),
                          expected_output_dtype, kwargs))

  def assert_shapes_equal(expected, actual):
    """Asserts that the output shape from the layer matches the actual shape.

    Unknown (None) dimensions in `expected` match any actual dimension.
    """
    if len(expected) != len(actual):
      raise AssertionError(
          'When testing layer %s, for input %s, found output_shape='
          '%s but expected to find %s.\nFull kwargs: %s' %
          (layer_cls.__name__, x, actual, expected, kwargs))

    for expected_dim, actual_dim in zip(expected, actual):
      if isinstance(expected_dim, tensor_shape.Dimension):
        expected_dim = expected_dim.value
      if isinstance(actual_dim, tensor_shape.Dimension):
        actual_dim = actual_dim.value
      if expected_dim is not None and expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s, for input %s, found output_shape='
            '%s but expected to find %s.\nFull kwargs: %s' %
            (layer_cls.__name__, x, actual, expected, kwargs))

  if expected_output_shape is not None:
    assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape),
                        y.shape)

  # check shape inference
  model = models.Model(x, y)
  computed_output_shape = tuple(
      layer.compute_output_shape(
          tensor_shape.TensorShape(input_shape)).as_list())
  computed_output_signature = layer.compute_output_signature(
      tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype))
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  assert_shapes_equal(computed_output_shape, actual_output_shape)
  assert_shapes_equal(computed_output_signature.shape, actual_output_shape)
  if computed_output_signature.dtype != actual_output.dtype:
    raise AssertionError(
        'When testing layer %s, for input %s, found output_dtype='
        '%s but expected to find %s.\nFull kwargs: %s' %
        (layer_cls.__name__, x, actual_output.dtype,
         computed_output_signature.dtype, kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = models.Model.from_config(model_config, custom_objects)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    assert_equal(output, actual_output)

  # test training mode (e.g. useful for dropout tests)
  # Rebuild the model to avoid the graph being reused between predict() and
  # training. See b/120160788 for more details. This should be mitigated
  # after 2.0.
  layer_weights = layer.get_weights()  # Get the layer weights BEFORE training.
  if validate_training:
    model = models.Model(x, layer(x))
    if _thread_local_data.run_eagerly is not None:
      model.compile(
          'rmsprop',
          'mse',
          weighted_metrics=['acc'],
          run_eagerly=should_run_eagerly())
    else:
      model.compile('rmsprop', 'mse', weighted_metrics=['acc'])
    model.train_on_batch(input_data, actual_output)

  # test as first layer in Sequential API
  layer_config = layer.get_config()
  layer_config['batch_input_shape'] = input_shape
  layer = layer.__class__.from_config(layer_config)

  # Test adapt, if data was passed.
  if adapt_data is not None:
    layer.adapt(adapt_data)

  model = models.Sequential()
  model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype))
  model.add(layer)

  # Restore the pre-training weights so the output matches the first pass.
  layer.set_weights(layer_weights)
  actual_output = model.predict(input_data)
  actual_output_shape = actual_output.shape
  for expected_dim, actual_dim in zip(computed_output_shape,
                                      actual_output_shape):
    if expected_dim is not None:
      if expected_dim != actual_dim:
        raise AssertionError(
            'When testing layer %s **after deserialization**, '
            'for input %s, found output_shape='
            '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
            (layer_cls.__name__,
             x,
             actual_output_shape,
             computed_output_shape,
             kwargs))
  if expected_output is not None:
    assert_equal(actual_output, expected_output)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = models.Sequential.from_config(model_config, custom_objects)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    assert_equal(output, actual_output)

  # for further checks in the caller function
  return actual_output
306
307
class _ThreadLocalData(threading.local):
  """Per-thread test configuration read by the scope/getter helpers.

  `threading.local` runs `__init__` again in each thread the first time the
  object is used there. So every thread starts with all four settings set to
  None. Assigning the attributes on a plain `threading.local()` at import time
  would define them only in the importing thread, and other threads would hit
  AttributeError.
  """

  def __init__(self):
    super(_ThreadLocalData, self).__init__()
    self.model_type = None
    self.run_eagerly = None
    self.saved_model_format = None
    self.save_kwargs = None


_thread_local_data = _ThreadLocalData()
313
314
@tf_contextlib.contextmanager
def model_type_scope(value):
  """Temporarily makes `value` the model type under test.

  Whatever model type was active before entering is put back on exit, even if
  the body raises, so scopes can be nested.

  Args:
     value: model type value

  Yields:
    The provided value.
  """
  saved_model_type = _thread_local_data.model_type
  _thread_local_data.model_type = value
  try:
    yield value
  finally:
    # Put back the model type that was active before this scope.
    _thread_local_data.model_type = saved_model_type
334
335
@tf_contextlib.contextmanager
def run_eagerly_scope(value):
  """Provides a scope within which we compile models to run eagerly or not.

  The boolean gets restored to its original value upon exiting the scope,
  even if the body raises.

  Args:
     value: Bool specifying if we should run models eagerly in the active test.
     Should be True or False.

  Yields:
    The provided value.
  """
  previous_value = _thread_local_data.run_eagerly
  try:
    _thread_local_data.run_eagerly = value
    yield value
  finally:
    # Restore the run_eagerly setting to its initial value.
    _thread_local_data.run_eagerly = previous_value
356
357
def should_run_eagerly():
  """Returns whether the models we are testing should be run eagerly.

  The result is truthy only if the active `run_eagerly_scope` requested eager
  execution AND `context.executing_eagerly()` is currently true.

  Raises:
    ValueError: if called outside a `run_eagerly_scope()` /
      `run_all_keras_modes` context.
  """
  requested = _thread_local_data.run_eagerly
  if requested is None:
    raise ValueError('Cannot call `should_run_eagerly()` outside of a '
                     '`run_eagerly_scope()` or `run_all_keras_modes` '
                     'decorator.')
  return requested and context.executing_eagerly()
366
367
@tf_contextlib.contextmanager
def saved_model_format_scope(value, **kwargs):
  """Provides a scope within which the saved model format to test is `value`.

  The saved model format and the save kwargs get restored to their original
  values upon exiting the scope, even if the body raises.

  Args:
     value: saved model format value
     **kwargs: optional kwargs to pass to the save function.

  Yields:
    The provided value.
  """
  previous_format = _thread_local_data.saved_model_format
  previous_kwargs = _thread_local_data.save_kwargs
  try:
    _thread_local_data.saved_model_format = value
    _thread_local_data.save_kwargs = kwargs
    # Yield the value (this previously yielded None). This matches the
    # docstring and the other *_scope helpers.
    yield value
  finally:
    # Restore saved model format and save kwargs to their initial values.
    _thread_local_data.saved_model_format = previous_format
    _thread_local_data.save_kwargs = previous_kwargs
392
393
def get_save_format():
  """Returns the saved model format selected by the active scope.

  Raises:
    ValueError: if called outside a `saved_model_format_scope()` /
      `run_with_all_saved_model_formats` context.
  """
  save_format = _thread_local_data.saved_model_format
  if save_format is not None:
    return save_format
  raise ValueError(
      'Cannot call `get_save_format()` outside of a '
      '`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
      'decorator.')
401
402
def get_save_kwargs():
  """Returns the extra save() kwargs registered by the active scope.

  Raises:
    ValueError: if called outside a `saved_model_format_scope()` /
      `run_with_all_saved_model_formats` context.
  """
  save_kwargs = _thread_local_data.save_kwargs
  if save_kwargs is None:
    raise ValueError(
        'Cannot call `get_save_kwargs()` outside of a '
        '`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
        'decorator.')
  return save_kwargs or {}
410
411
def get_model_type():
  """Gets the model type that should be tested.

  Raises:
    ValueError: if called outside a `model_type_scope()` /
      `run_with_all_model_types` context.
  """
  current_type = _thread_local_data.model_type
  if current_type is not None:
    return current_type
  raise ValueError('Cannot call `get_model_type()` outside of a '
                   '`model_type_scope()` or `run_with_all_model_types` '
                   'decorator.')
420
421
def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None):
  """Builds a two-layer Sequential MLP.

  The hidden Dense layer uses relu. The output layer uses sigmoid for a single
  class and softmax otherwise. `input_dim` is only forwarded when truthy.
  """
  model = models.Sequential()
  hidden_kwargs = {'activation': 'relu'}
  if input_dim:
    hidden_kwargs['input_dim'] = input_dim
  model.add(layers.Dense(num_hidden, **hidden_kwargs))
  output_activation = 'softmax' if num_classes != 1 else 'sigmoid'
  model.add(layers.Dense(num_classes, activation=output_activation))
  return model
431
432
def get_small_functional_mlp(num_hidden, num_classes, input_dim):
  """Builds a two-layer functional MLP over inputs of shape `(input_dim,)`."""
  model_input = layers.Input(shape=(input_dim,))
  hidden = layers.Dense(num_hidden, activation='relu')(model_input)
  output_activation = 'softmax' if num_classes != 1 else 'sigmoid'
  prediction = layers.Dense(num_classes, activation=output_activation)(hidden)
  return models.Model(model_input, prediction)
439
440
class SmallSubclassMLP(models.Model):
  """A subclass model based small MLP.

  Layers: Dense(num_hidden, relu) -> [Dropout(0.5) if use_dp] ->
  [BatchNormalization if use_bn] -> Dense(num_classes), where the final
  activation is sigmoid when `num_classes == 1` and softmax otherwise.
  """

  def __init__(self,
               num_hidden,
               num_classes,
               use_bn=False,
               use_dp=False,
               **kwargs):
    # Default the name to 'test_model' but let callers override it. Passing
    # name='test_model' explicitly alongside a caller-supplied `name` in
    # **kwargs raised a duplicate-keyword TypeError.
    kwargs.setdefault('name', 'test_model')
    super(SmallSubclassMLP, self).__init__(**kwargs)
    self.use_bn = use_bn
    self.use_dp = use_dp

    self.layer_a = layers.Dense(num_hidden, activation='relu')
    activation = 'sigmoid' if num_classes == 1 else 'softmax'
    self.layer_b = layers.Dense(num_classes, activation=activation)
    if self.use_dp:
      self.dp = layers.Dropout(0.5)
    if self.use_bn:
      self.bn = layers.BatchNormalization(axis=-1)

  def call(self, inputs, **kwargs):
    x = self.layer_a(inputs)
    if self.use_dp:
      x = self.dp(x)
    if self.use_bn:
      x = self.bn(x)
    return self.layer_b(x)
469
470
class _SmallSubclassMLPCustomBuild(models.Model):
  """A small subclass MLP whose Dense layers are only created in `build`."""

  def __init__(self, num_hidden, num_classes):
    super(_SmallSubclassMLPCustomBuild, self).__init__()
    # Placeholders; the real layers are created lazily in build().
    self.layer_a = None
    self.layer_b = None
    self.num_hidden = num_hidden
    self.num_classes = num_classes

  def build(self, input_shape):
    self.layer_a = layers.Dense(self.num_hidden, activation='relu')
    output_activation = 'softmax' if self.num_classes != 1 else 'sigmoid'
    self.layer_b = layers.Dense(self.num_classes, activation=output_activation)

  def call(self, inputs, **kwargs):
    return self.layer_b(self.layer_a(inputs))
489
490
def get_small_subclass_mlp(num_hidden, num_classes):
  """Returns a `SmallSubclassMLP` with its defaults (no BN, no dropout)."""
  return SmallSubclassMLP(num_hidden=num_hidden, num_classes=num_classes)
493
494
def get_small_subclass_mlp_with_custom_build(num_hidden, num_classes):
  """Returns a small subclass MLP that creates its layers in `build`."""
  return _SmallSubclassMLPCustomBuild(num_hidden=num_hidden,
                                      num_classes=num_classes)
497
498
def get_small_mlp(num_hidden, num_classes, input_dim):
  """Get a small mlp of the model type specified by `get_model_type`.

  Raises:
    ValueError: if the active model type is not one of 'subclass',
      'subclass_custom_build', 'sequential' or 'functional'.
  """
  model_type = get_model_type()
  builders = (
      ('subclass',
       lambda: get_small_subclass_mlp(num_hidden, num_classes)),
      ('subclass_custom_build',
       lambda: get_small_subclass_mlp_with_custom_build(num_hidden,
                                                        num_classes)),
      ('sequential',
       lambda: get_small_sequential_mlp(num_hidden, num_classes, input_dim)),
      ('functional',
       lambda: get_small_functional_mlp(num_hidden, num_classes, input_dim)),
  )
  for known_type, build_model in builders:
    if model_type == known_type:
      return build_model()
  raise ValueError('Unknown model type {}'.format(model_type))
511
512
class _SubclassModel(models.Model):
  """A Keras subclass model that applies a list of layers in sequence."""

  def __init__(self, model_layers, *args, **kwargs):
    """Instantiate a model.

    Args:
      model_layers: a list of layers to be added to the model.
      *args: Model's args
      **kwargs: Model's keyword args. The optional `input_tensor` entry is
        removed before calling the base constructor; if not None it is passed
        to `_set_inputs`. It is needed for ragged/sparse input.
    """
    input_tensor = kwargs.pop('input_tensor', None)
    super(_SubclassModel, self).__init__(*args, **kwargs)
    # clone_and_build does not support lists of layers in subclassed models,
    # so each layer is stored under its own attribute (layer0, layer1, ...).
    for index, model_layer in enumerate(model_layers):
      setattr(self, self._layer_name_for_i(index), model_layer)

    self.num_layers = len(model_layers)

    if input_tensor is not None:
      self._set_inputs(input_tensor)

  def _layer_name_for_i(self, i):
    """Returns the attribute name under which layer `i` is stored."""
    return 'layer%d' % i

  def call(self, inputs, **kwargs):
    outputs = inputs
    for index in range(self.num_layers):
      outputs = getattr(self, self._layer_name_for_i(index))(outputs)
    return outputs
547
548
class _SubclassModelCustomBuild(models.Model):
  """A Keras subclass model whose layers are obtained lazily in `build`."""

  def __init__(self, layer_generating_func, *args, **kwargs):
    super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs)
    self.all_layers = None
    # Zero-argument callable returning an iterable of layers; invoked in build.
    self._layer_generating_func = layer_generating_func

  def build(self, input_shape):
    self.all_layers = list(self._layer_generating_func())

  def call(self, inputs, **kwargs):
    outputs = inputs
    for model_layer in self.all_layers:
      outputs = model_layer(outputs)
    return outputs
568
569
def get_model_from_layers(model_layers,
                          input_shape=None,
                          input_dtype=None,
                          name=None,
                          input_ragged=None,
                          input_sparse=None,
                          model_type=None):
  """Builds a model of the requested type that applies `model_layers` in order.

  Args:
    model_layers: The layers used to build the network.
    input_shape: Shape tuple of the input or 'TensorShape' instance.
    input_dtype: Datatype of the input.
    name: Name for the model.
    input_ragged: Boolean, whether the input data is a ragged tensor.
    input_sparse: Boolean, whether the input data is a sparse tensor.
    model_type: One of "subclass", "subclass_custom_build", "sequential", or
      "functional". When None, defaults to `get_model_type`.

  Returns:
    A Keras model.

  Raises:
    ValueError: if `model_type` is unknown, or a functional model is requested
      without an `input_shape`.
  """
  if model_type is None:
    model_type = get_model_type()

  if model_type == 'subclass':
    # Subclassed models only get an explicit input for ragged/sparse data.
    input_tensor = None
    if input_ragged or input_sparse:
      input_tensor = layers.Input(
          shape=input_shape,
          dtype=input_dtype,
          ragged=input_ragged,
          sparse=input_sparse)
    return _SubclassModel(model_layers, name=name, input_tensor=input_tensor)

  if model_type == 'subclass_custom_build':
    # The layers are handed over lazily and attached in build().
    return _SubclassModelCustomBuild(lambda: model_layers, name=name)

  if model_type == 'sequential':
    sequential_model = models.Sequential(name=name)
    if input_shape:
      input_layer = layers.InputLayer(
          input_shape=input_shape,
          dtype=input_dtype,
          ragged=input_ragged,
          sparse=input_sparse)
      sequential_model.add(input_layer)
    for model_layer in model_layers:
      sequential_model.add(model_layer)
    return sequential_model

  if model_type == 'functional':
    if not input_shape:
      raise ValueError('Cannot create a functional model from layers with no '
                       'input shape.')
    model_input = layers.Input(
        shape=input_shape,
        dtype=input_dtype,
        ragged=input_ragged,
        sparse=input_sparse)
    tensor = model_input
    for model_layer in model_layers:
      tensor = model_layer(tensor)
    return models.Model(model_input, tensor, name=name)

  raise ValueError('Unknown model type {}'.format(model_type))
636
637
class Bias(layers.Layer):
  """Adds a single zero-initialized bias variable of shape (1,) to its input."""

  def build(self, input_shape):
    # `add_weight` is the supported API; `add_variable` is a deprecated alias
    # that forwards to it with the same arguments.
    self.bias = self.add_weight('bias', (1,), initializer='zeros')

  def call(self, inputs):
    return inputs + self.bias
645
646
class _MultiIOSubclassModel(models.Model):
  """Multi IO Keras subclass model.

  Takes either one input (when a shared input branch is given) or two inputs
  (a 2-sequence, or a dict with keys 'input_1'/'input_2'). It applies branch a
  and branch b and returns `[a, b]`, or the shared output branch applied to
  that list.
  """

  def __init__(self, branch_a, branch_b, shared_input_branch=None,
               shared_output_branch=None, name=None):
    super(_MultiIOSubclassModel, self).__init__(name=name)
    self._shared_input_branch = shared_input_branch
    self._branch_a = branch_a
    self._branch_b = branch_b
    self._shared_output_branch = shared_output_branch

  def call(self, inputs, **kwargs):
    if self._shared_input_branch:
      shared = inputs
      for shared_layer in self._shared_input_branch:
        shared = shared_layer(shared)
      a = b = shared
    elif isinstance(inputs, dict):
      a, b = inputs['input_1'], inputs['input_2']
    else:
      a, b = inputs

    for branch_layer in self._branch_a:
      a = branch_layer(a)
    for branch_layer in self._branch_b:
      b = branch_layer(b)
    merged = [a, b]

    if self._shared_output_branch:
      for merge_layer in self._shared_output_branch:
        merged = merge_layer(merged)

    return merged
681
682
class _MultiIOSubclassModelCustomBuild(models.Model):
  """Multi IO Keras subclass model that uses a custom build method.

  Each branch is supplied as a zero-argument callable returning a list of
  layers. The callables are invoked in `build`.
  """

  def __init__(self, branch_a_func, branch_b_func,
               shared_input_branch_func=None,
               shared_output_branch_func=None):
    super(_MultiIOSubclassModelCustomBuild, self).__init__()
    self._shared_input_branch_func = shared_input_branch_func
    self._branch_a_func = branch_a_func
    self._branch_b_func = branch_b_func
    self._shared_output_branch_func = shared_output_branch_func

    self._shared_input_branch = None
    self._branch_a = None
    self._branch_b = None
    self._shared_output_branch = None

  def build(self, input_shape):
    # The shared-branch factories are optional (default None). Previously
    # they were called unconditionally, which raised TypeError for None.
    # Each factory is now called at most once; it used to be called twice.
    if self._shared_input_branch_func is not None:
      shared_input_branch = self._shared_input_branch_func()
      if shared_input_branch:
        self._shared_input_branch = shared_input_branch
    self._branch_a = self._branch_a_func()
    self._branch_b = self._branch_b_func()

    if self._shared_output_branch_func is not None:
      shared_output_branch = self._shared_output_branch_func()
      if shared_output_branch:
        self._shared_output_branch = shared_output_branch

  def call(self, inputs, **kwargs):
    if self._shared_input_branch:
      for layer in self._shared_input_branch:
        inputs = layer(inputs)
      a = inputs
      b = inputs
    else:
      a, b = inputs

    for layer in self._branch_a:
      a = layer(a)
    for layer in self._branch_b:
      b = layer(b)
    outs = a, b

    if self._shared_output_branch:
      for layer in self._shared_output_branch:
        outs = layer(outs)

    return outs
729
730
def get_multi_io_model(
    branch_a,
    branch_b,
    shared_input_branch=None,
    shared_output_branch=None):
  """Builds a multi-io model that contains two branches.

  The produced model will be of the type specified by `get_model_type`.

  To build a two-input, two-output model:
    Specify a list of layers for branch a and branch b, but do not specify any
    shared input branch or shared output branch. The resulting model will apply
    each branch to a different input, to produce two outputs.

    The first value in branch_a must be the Keras 'Input' layer for branch a,
    and the first value in branch_b must be the Keras 'Input' layer for
    branch b.

    example usage:
    ```
    branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()]
    branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()]

    model = get_multi_io_model(branch_a, branch_b)
    ```

  To build a two-input, one-output model:
    Specify a list of layers for branch a and branch b, and specify a
    shared output branch. The resulting model will apply
    each branch to a different input. It will then apply the shared output
    branch to a tuple containing the intermediate outputs of each branch,
    to produce a single output. The first layer in the shared_output_branch
    must be able to merge a tuple of two tensors.

    The first value in branch_a must be the Keras 'Input' layer for branch a,
    and the first value in branch_b must be the Keras 'Input' layer for
    branch b.

    example usage:
    ```
    input_branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()]
    input_branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()]
    shared_output_branch = [Concatenate(), Dense(), Dense()]

    model = get_multi_io_model(input_branch_a, input_branch_b,
                               shared_output_branch=shared_output_branch)
    ```

  To build a one-input, two-output model:
    Specify a list of layers for branch a and branch b, and specify a
    shared input branch. The resulting model will take one input, and apply
    the shared input branch to it. It will then respectively apply each branch
    to that intermediate result in parallel, to produce two outputs.

    The first value in the shared_input_branch must be the Keras 'Input' layer
    for the whole model. Branch a and branch b should not contain any Input
    layers.

    example usage:
    ```
    shared_input_branch = [Input(shape=(2,), name='in'), Dense(), Dense()]
    output_branch_a = [Dense(), Dense()]
    output_branch_b = [Dense(), Dense()]

    model = get_multi_io_model(output_branch_a, output_branch_b,
                               shared_input_branch=shared_input_branch)
    ```

  For the subclass model types, the leading Input entries described above are
  stripped and not used; only the remaining layers are passed to the model.

  Args:
    branch_a: A sequence of layers for branch a of the model.
    branch_b: A sequence of layers for branch b of the model.
    shared_input_branch: An optional sequence of layers to apply to a single
      input, before applying both branches to that intermediate result. If set,
      the model will take only one input instead of two. Defaults to None.
    shared_output_branch: An optional sequence of layers to merge the
      intermediate results produced by branch a and branch b. If set,
      the model will produce only one output instead of two. Defaults to None.

  Returns:
    A multi-io model of the type specified by `get_model_type`, specified
    by the different branches.

  Raises:
    ValueError: if the model type is 'sequential' or unknown.
  """
  # Extract the functional inputs from the layer lists
  if shared_input_branch:
    inputs = shared_input_branch[0]
    shared_input_branch = shared_input_branch[1:]
  else:
    inputs = branch_a[0], branch_b[0]
    branch_a = branch_a[1:]
    branch_b = branch_b[1:]

  model_type = get_model_type()
  if model_type == 'subclass':
    return _MultiIOSubclassModel(branch_a, branch_b, shared_input_branch,
                                 shared_output_branch)

  if model_type == 'subclass_custom_build':
    return _MultiIOSubclassModelCustomBuild((lambda: branch_a),
                                            (lambda: branch_b),
                                            (lambda: shared_input_branch),
                                            (lambda: shared_output_branch))

  if model_type == 'sequential':
    raise ValueError('Cannot use `get_multi_io_model` to construct '
                     'sequential models')

  if model_type == 'functional':
    if shared_input_branch:
      a_and_b = inputs
      for layer in shared_input_branch:
        a_and_b = layer(a_and_b)
      a = a_and_b
      b = a_and_b
    else:
      a, b = inputs

    for layer in branch_a:
      a = layer(a)
    for layer in branch_b:
      b = layer(b)
    outputs = a, b

    if shared_output_branch:
      for layer in shared_output_branch:
        outputs = layer(outputs)

    return models.Model(inputs, outputs)

  raise ValueError('Unknown model type {}'.format(model_type))
860
861
# Maps the lowercase string name of each supported Keras v2 optimizer to its
# class. Insertion order is significant: it is the order in which valid
# choices are listed in `get_v2_optimizer`'s error message.
_V2_OPTIMIZER_MAP = dict(
    adadelta=adadelta_v2.Adadelta,
    adagrad=adagrad_v2.Adagrad,
    adam=adam_v2.Adam,
    adamax=adamax_v2.Adamax,
    nadam=nadam_v2.Nadam,
    rmsprop=rmsprop_v2.RMSprop,
    sgd=gradient_descent_v2.SGD,
)
871
872
def get_v2_optimizer(name, **kwargs):
  """Get the v2 optimizer requested.

  This is only necessary until v2 are the default, as we are testing in Eager,
  and Eager + v1 optimizers fail tests. When we are in v2, the strings alone
  should be sufficient, and this mapping can theoretically be removed.

  Args:
    name: string name of Keras v2 optimizer (a key of `_V2_OPTIMIZER_MAP`).
    **kwargs: any kwargs to pass to the optimizer constructor.

  Returns:
    Initialized Keras v2 optimizer.

  Raises:
    ValueError: if an unknown name was passed. Errors raised by the optimizer
      constructor itself propagate unchanged.
  """
  # Keep the `try` limited to the lookup. If it also covered the constructor
  # call, a KeyError raised inside an optimizer's __init__ would be misreported
  # as an unknown optimizer name.
  try:
    optimizer_cls = _V2_OPTIMIZER_MAP[name]
  except KeyError:
    # `from None`: the KeyError adds nothing beyond this message.
    raise ValueError(
        'Could not find requested v2 optimizer: {}\nValid choices: {}'.format(
            name, list(_V2_OPTIMIZER_MAP.keys()))) from None
  return optimizer_cls(**kwargs)
896
897
def get_expected_metric_variable_names(var_names, name_suffix=''):
  """Returns the metric variable names expected in the current execution mode.

  Args:
    var_names: Iterable of base variable names.
    name_suffix: Suffix that V1 graph mode appends to make variable names
      unique. It is ignored when TF2 behavior is enabled or when executing
      eagerly.

  Returns:
    A list with one entry per element of `var_names`, each ending in ':0'.
  """
  # Only V1 graph mode uniquifies variable names. V2, and V1 eager mode, leave
  # them as-is, so no suffix applies there.
  is_v1_graph_mode = not (tf2.enabled() or context.executing_eagerly())
  suffix = name_suffix if is_v1_graph_mode else ''
  return [base_name + suffix + ':0' for base_name in var_names]
905
906
def enable_v2_dtype_behavior(fn):
  """Decorator that runs the test `fn` with the layer V2 dtype behavior on."""
  return _set_v2_dtype_behavior(fn, enabled=True)
910
911
def disable_v2_dtype_behavior(fn):
  """Decorator that runs the test `fn` with the layer V2 dtype behavior off."""
  return _set_v2_dtype_behavior(fn, enabled=False)
915
916
def _set_v2_dtype_behavior(fn, enabled):
  """Wraps `fn` so it runs with `V2_DTYPE_BEHAVIOR` temporarily set.

  While the returned callable runs, `base_layer_utils.V2_DTYPE_BEHAVIOR` is
  set to `enabled`. The value that was in effect at call time is restored
  afterwards, whether `fn` returns or raises.

  Args:
    fn: The callable (typically a test method) to wrap.
    enabled: Value to assign to `base_layer_utils.V2_DTYPE_BEHAVIOR` during
      the call.

  Returns:
    The wrapped callable, as produced by `tf_decorator.make_decorator`.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # Save the global at call time, not decoration time, so nested or
    # sequential overrides compose correctly.
    previous_value = base_layer_utils.V2_DTYPE_BEHAVIOR
    base_layer_utils.V2_DTYPE_BEHAVIOR = enabled
    try:
      return fn(*args, **kwargs)
    finally:
      base_layer_utils.V2_DTYPE_BEHAVIOR = previous_value

  decorated = tf_decorator.make_decorator(fn, wrapper)
  return decorated
929
930
@contextlib.contextmanager
def device(should_use_gpu):
  """Context manager that places ops on GPU:0 or CPU:0.

  Args:
    should_use_gpu: If truthy, and `test_util.is_gpu_available()` reports a
      GPU, ops are placed on '/device:GPU:0'. Otherwise they are placed on
      '/device:CPU:0'.

  Yields:
    None, while the `ops.device` scope is active.
  """
  # `and` short-circuits, so GPU availability is only queried when requested.
  gpu_selected = should_use_gpu and test_util.is_gpu_available()
  with ops.device('/device:GPU:0' if gpu_selected else '/device:CPU:0'):
    yield
940
941
@contextlib.contextmanager
def use_gpu():
  """Context manager that places ops on GPU:0 if available, else on CPU:0.

  Equivalent to `device(should_use_gpu=True)`.
  """
  with device(should_use_gpu=True):
    yield
947
948
949def for_all_test_methods(decorator, *args, **kwargs):
950  """Generate class-level decorator from given method-level decorator.
951
952  It is expected for the given decorator to take some arguments and return
953  a method that is then called on the test method to produce a decorated
954  method.
955
956  Args:
957    decorator: The decorator to apply.
958    *args: Positional arguments
959    **kwargs: Keyword arguments
960  Returns: Function that will decorate a given classes test methods with the
961    decorator.
962  """
963
964  def all_test_methods_impl(cls):
965    """Apply decorator to all test methods in class."""
966    for name in dir(cls):
967      value = getattr(cls, name)
968      if callable(value) and name.startswith('test') and (name !=
969                                                          'test_session'):
970        setattr(cls, name, decorator(*args, **kwargs)(value))
971    return cls
972
973  return all_test_methods_impl
974
975
976# The description is just for documentation purposes.
def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute test with TensorFloat-32 disabled.

  While almost every real-world deep learning model runs fine with
  TensorFloat-32, many tests use assertAllClose or similar methods.
  TensorFloat-32 matmuls typically will cause such methods to fail with the
  default tolerances.

  Args:
    description: A description used for documentation purposes, describing why
      the test requires TensorFloat-32 to be disabled. It is not used at
      runtime.

  Returns:
    Decorator which runs a test with TensorFloat-32 disabled. The decorated
    test's return value is passed through. The TensorFloat-32 setting that was
    active before the call is restored afterwards, even if the test raises.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      allowed = config.tensor_float_32_execution_enabled()
      try:
        config.enable_tensor_float_32_execution(False)
        # Propagate f's result, as a general-purpose decorator should.
        return f(self, *args, **kwargs)
      finally:
        # Restore whatever setting was active before the call.
        config.enable_tensor_float_32_execution(allowed)

    return decorated

  return decorator
1007
1008
1009# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Class decorator that runs every test method with TensorFloat-32 disabled.

  Applies `run_without_tensor_float_32(description)` to each test method of
  the decorated class via `for_all_test_methods`.
  """
  method_decorator_factory = run_without_tensor_float_32
  return for_all_test_methods(method_decorator_factory, description)
1013
1014
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  See go/tf-test-decorator-cheatsheet for the decorators to use in different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: if the decorator is applied to a class rather than a method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError('`run_v2_only` only supports test methods.')

    # Preserve the test's __name__/__doc__ (and expose __wrapped__) so test
    # reporting and other decorators see the original method.
    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      if not tf2.enabled():
        # skipTest raises, so the call below only runs in v2 mode.
        self.skipTest('Test is only compatible with v2')

      return f(self, *args, **kwargs)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1049
1050
def generate_combinations_with_testcase_name(**kwargs):
  """Generates every combination of the given options, each with a test name.

  Each keyword argument is an option. Its value is either a list of
  possibilities or a single possibility; any non-list value, including a
  tuple, counts as one possibility. The result is the Cartesian product of the
  options, taken in sorted keyword order. Each combination also gets a
  'testcase_name' entry, which named parameterized tests require.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    A list of `collections.OrderedDict`s, one per combination. Each maps every
    option name to one of its possibilities, followed by 'testcase_name'. The
    name is '_test' plus '_<option>_<value>' for each option, with
    non-alphanumeric characters stripped from the option name and from
    `str(value)`.
  """

  def _alnum_only(text):
    # Keep only alphanumeric characters so names are valid identifiers.
    return ''.join(ch for ch in text if ch.isalnum())

  # Sorting by option name makes combinations and names independent of the
  # order in which keyword arguments were passed.
  choices_per_option = []
  for option in sorted(kwargs):
    possibilities = kwargs[option]
    if not isinstance(possibilities, list):
      possibilities = [possibilities]
    choices_per_option.append([(option, choice) for choice in possibilities])

  named_combinations = []
  for pairs in itertools.product(*choices_per_option):
    combination = collections.OrderedDict(pairs)
    suffix = ''.join(
        '_{}_{}'.format(_alnum_only(option), _alnum_only(str(choice)))
        for option, choice in combination.items())
    combination['testcase_name'] = '_test{}'.format(suffix)
    named_combinations.append(combination)

  return named_combinations
1090