1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for unit-testing Keras.""" 16 17import collections 18import contextlib 19import functools 20import itertools 21import threading 22 23import numpy as np 24 25from tensorflow.python import tf2 26from tensorflow.python.eager import context 27from tensorflow.python.framework import config 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.framework import tensor_spec 32from tensorflow.python.framework import test_util 33from tensorflow.python.keras import backend 34from tensorflow.python.keras import layers 35from tensorflow.python.keras import models 36from tensorflow.python.keras.engine import base_layer_utils 37from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2 38from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2 39from tensorflow.python.keras.optimizer_v2 import adam as adam_v2 40from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 41from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 42from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 43from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 44from tensorflow.python.keras.utils import tf_contextlib 45from tensorflow.python.keras.utils import tf_inspect 46from tensorflow.python.util import tf_decorator 47 48 49def string_test(actual, expected): 50 np.testing.assert_array_equal(actual, expected) 51 52 53def numeric_test(actual, expected): 54 np.testing.assert_allclose(actual, expected, rtol=1e-3, atol=1e-6) 55 56 57def get_test_data(train_samples, 58 test_samples, 59 input_shape, 60 num_classes, 61 random_seed=None): 62 """Generates test data to train a model on. 63 64 Args: 65 train_samples: Integer, how many training samples to generate. 66 test_samples: Integer, how many test samples to generate. 67 input_shape: Tuple of integers, shape of the inputs. 68 num_classes: Integer, number of classes for the data and targets. 69 random_seed: Integer, random seed used by numpy to generate data. 70 71 Returns: 72 A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 73 """ 74 if random_seed is not None: 75 np.random.seed(random_seed) 76 num_sample = train_samples + test_samples 77 templates = 2 * num_classes * np.random.random((num_classes,) + input_shape) 78 y = np.random.randint(0, num_classes, size=(num_sample,)) 79 x = np.zeros((num_sample,) + input_shape, dtype=np.float32) 80 for i in range(num_sample): 81 x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape) 82 return ((x[:train_samples], y[:train_samples]), 83 (x[train_samples:], y[train_samples:])) 84 85 86@test_util.disable_cudnn_autotune 87def layer_test(layer_cls, 88 kwargs=None, 89 input_shape=None, 90 input_dtype=None, 91 input_data=None, 92 expected_output=None, 93 expected_output_dtype=None, 94 expected_output_shape=None, 95 validate_training=True, 96 adapt_data=None, 97 custom_objects=None, 98 test_harness=None, 99 supports_masking=None): 100 """Test routine for a layer with a single input and single output. 101 102 Args: 103 layer_cls: Layer class object. 104 kwargs: Optional dictionary of keyword arguments for instantiating the 105 layer. 106 input_shape: Input shape tuple. 107 input_dtype: Data type of the input data. 108 input_data: Numpy array of input data. 109 expected_output: Numpy array of the expected output. 110 expected_output_dtype: Data type expected for the output. 111 expected_output_shape: Shape tuple for the expected shape of the output. 112 validate_training: Whether to attempt to validate training on this layer. 113 This might be set to False for non-differentiable layers that output 114 string or integer values. 115 adapt_data: Optional data for an 'adapt' call. If None, adapt() will not 116 be tested for this layer. This is only relevant for PreprocessingLayers. 117 custom_objects: Optional dictionary mapping name strings to custom objects 118 in the layer class. This is helpful for testing custom layers. 119 test_harness: The Tensorflow test, if any, that this function is being 120 called in. 121 supports_masking: Optional boolean to check the `supports_masking` property 122 of the layer. If None, the check will not be performed. 123 124 Returns: 125 The output data (Numpy array) returned by the layer, for additional 126 checks to be done by the calling code. 127 128 Raises: 129 ValueError: if `input_shape is None`. 130 """ 131 if input_data is None: 132 if input_shape is None: 133 raise ValueError('input_shape is None') 134 if not input_dtype: 135 input_dtype = 'float32' 136 input_data_shape = list(input_shape) 137 for i, e in enumerate(input_data_shape): 138 if e is None: 139 input_data_shape[i] = np.random.randint(1, 4) 140 input_data = 10 * np.random.random(input_data_shape) 141 if input_dtype[:5] == 'float': 142 input_data -= 0.5 143 input_data = input_data.astype(input_dtype) 144 elif input_shape is None: 145 input_shape = input_data.shape 146 if input_dtype is None: 147 input_dtype = input_data.dtype 148 if expected_output_dtype is None: 149 expected_output_dtype = input_dtype 150 151 if dtypes.as_dtype(expected_output_dtype) == dtypes.string: 152 if test_harness: 153 assert_equal = test_harness.assertAllEqual 154 else: 155 assert_equal = string_test 156 else: 157 if test_harness: 158 assert_equal = test_harness.assertAllClose 159 else: 160 assert_equal = numeric_test 161 162 # instantiation 163 kwargs = kwargs or {} 164 layer = layer_cls(**kwargs) 165 166 if (supports_masking is not None 167 and layer.supports_masking != supports_masking): 168 raise AssertionError( 169 'When testing layer %s, the `supports_masking` property is %r' 170 'but expected to be %r.\nFull kwargs: %s' % 171 (layer_cls.__name__, layer.supports_masking, supports_masking, kwargs)) 172 173 # Test adapt, if data was passed. 174 if adapt_data is not None: 175 layer.adapt(adapt_data) 176 177 # test get_weights , set_weights at layer level 178 weights = layer.get_weights() 179 layer.set_weights(weights) 180 181 # test and instantiation from weights 182 if 'weights' in tf_inspect.getargspec(layer_cls.__init__): 183 kwargs['weights'] = weights 184 layer = layer_cls(**kwargs) 185 186 # test in functional API 187 x = layers.Input(shape=input_shape[1:], dtype=input_dtype) 188 y = layer(x) 189 if backend.dtype(y) != expected_output_dtype: 190 raise AssertionError('When testing layer %s, for input %s, found output ' 191 'dtype=%s but expected to find %s.\nFull kwargs: %s' % 192 (layer_cls.__name__, x, backend.dtype(y), 193 expected_output_dtype, kwargs)) 194 195 def assert_shapes_equal(expected, actual): 196 """Asserts that the output shape from the layer matches the actual shape.""" 197 if len(expected) != len(actual): 198 raise AssertionError( 199 'When testing layer %s, for input %s, found output_shape=' 200 '%s but expected to find %s.\nFull kwargs: %s' % 201 (layer_cls.__name__, x, actual, expected, kwargs)) 202 203 for expected_dim, actual_dim in zip(expected, actual): 204 if isinstance(expected_dim, tensor_shape.Dimension): 205 expected_dim = expected_dim.value 206 if isinstance(actual_dim, tensor_shape.Dimension): 207 actual_dim = actual_dim.value 208 if expected_dim is not None and expected_dim != actual_dim: 209 raise AssertionError( 210 'When testing layer %s, for input %s, found output_shape=' 211 '%s but expected to find %s.\nFull kwargs: %s' % 212 (layer_cls.__name__, x, actual, expected, kwargs)) 213 214 if expected_output_shape is not None: 215 assert_shapes_equal(tensor_shape.TensorShape(expected_output_shape), 216 y.shape) 217 218 # check shape inference 219 model = models.Model(x, y) 220 computed_output_shape = tuple( 221 layer.compute_output_shape( 222 tensor_shape.TensorShape(input_shape)).as_list()) 223 computed_output_signature = layer.compute_output_signature( 224 tensor_spec.TensorSpec(shape=input_shape, dtype=input_dtype)) 225 actual_output = model.predict(input_data) 226 actual_output_shape = actual_output.shape 227 assert_shapes_equal(computed_output_shape, actual_output_shape) 228 assert_shapes_equal(computed_output_signature.shape, actual_output_shape) 229 if computed_output_signature.dtype != actual_output.dtype: 230 raise AssertionError( 231 'When testing layer %s, for input %s, found output_dtype=' 232 '%s but expected to find %s.\nFull kwargs: %s' % 233 (layer_cls.__name__, x, actual_output.dtype, 234 computed_output_signature.dtype, kwargs)) 235 if expected_output is not None: 236 assert_equal(actual_output, expected_output) 237 238 # test serialization, weight setting at model level 239 model_config = model.get_config() 240 recovered_model = models.Model.from_config(model_config, custom_objects) 241 if model.weights: 242 weights = model.get_weights() 243 recovered_model.set_weights(weights) 244 output = recovered_model.predict(input_data) 245 assert_equal(output, actual_output) 246 247 # test training mode (e.g. useful for dropout tests) 248 # Rebuild the model to avoid the graph being reused between predict() and 249 # See b/120160788 for more details. This should be mitigated after 2.0. 250 layer_weights = layer.get_weights() # Get the layer weights BEFORE training. 251 if validate_training: 252 model = models.Model(x, layer(x)) 253 if _thread_local_data.run_eagerly is not None: 254 model.compile( 255 'rmsprop', 256 'mse', 257 weighted_metrics=['acc'], 258 run_eagerly=should_run_eagerly()) 259 else: 260 model.compile('rmsprop', 'mse', weighted_metrics=['acc']) 261 model.train_on_batch(input_data, actual_output) 262 263 # test as first layer in Sequential API 264 layer_config = layer.get_config() 265 layer_config['batch_input_shape'] = input_shape 266 layer = layer.__class__.from_config(layer_config) 267 268 # Test adapt, if data was passed. 269 if adapt_data is not None: 270 layer.adapt(adapt_data) 271 272 model = models.Sequential() 273 model.add(layers.Input(shape=input_shape[1:], dtype=input_dtype)) 274 model.add(layer) 275 276 layer.set_weights(layer_weights) 277 actual_output = model.predict(input_data) 278 actual_output_shape = actual_output.shape 279 for expected_dim, actual_dim in zip(computed_output_shape, 280 actual_output_shape): 281 if expected_dim is not None: 282 if expected_dim != actual_dim: 283 raise AssertionError( 284 'When testing layer %s **after deserialization**, ' 285 'for input %s, found output_shape=' 286 '%s but expected to find inferred shape %s.\nFull kwargs: %s' % 287 (layer_cls.__name__, 288 x, 289 actual_output_shape, 290 computed_output_shape, 291 kwargs)) 292 if expected_output is not None: 293 assert_equal(actual_output, expected_output) 294 295 # test serialization, weight setting at model level 296 model_config = model.get_config() 297 recovered_model = models.Sequential.from_config(model_config, custom_objects) 298 if model.weights: 299 weights = model.get_weights() 300 recovered_model.set_weights(weights) 301 output = recovered_model.predict(input_data) 302 assert_equal(output, actual_output) 303 304 # for further checks in the caller function 305 return actual_output 306 307 308_thread_local_data = threading.local() 309_thread_local_data.model_type = None 310_thread_local_data.run_eagerly = None 311_thread_local_data.saved_model_format = None 312_thread_local_data.save_kwargs = None 313 314 315@tf_contextlib.contextmanager 316def model_type_scope(value): 317 """Provides a scope within which the model type to test is equal to `value`. 318 319 The model type gets restored to its original value upon exiting the scope. 320 321 Args: 322 value: model type value 323 324 Yields: 325 The provided value. 326 """ 327 previous_value = _thread_local_data.model_type 328 try: 329 _thread_local_data.model_type = value 330 yield value 331 finally: 332 # Restore model type to initial value. 333 _thread_local_data.model_type = previous_value 334 335 336@tf_contextlib.contextmanager 337def run_eagerly_scope(value): 338 """Provides a scope within which we compile models to run eagerly or not. 339 340 The boolean gets restored to its original value upon exiting the scope. 341 342 Args: 343 value: Bool specifying if we should run models eagerly in the active test. 344 Should be True or False. 345 346 Yields: 347 The provided value. 348 """ 349 previous_value = _thread_local_data.run_eagerly 350 try: 351 _thread_local_data.run_eagerly = value 352 yield value 353 finally: 354 # Restore model type to initial value. 355 _thread_local_data.run_eagerly = previous_value 356 357 358def should_run_eagerly(): 359 """Returns whether the models we are testing should be run eagerly.""" 360 if _thread_local_data.run_eagerly is None: 361 raise ValueError('Cannot call `should_run_eagerly()` outside of a ' 362 '`run_eagerly_scope()` or `run_all_keras_modes` ' 363 'decorator.') 364 365 return _thread_local_data.run_eagerly and context.executing_eagerly() 366 367 368@tf_contextlib.contextmanager 369def saved_model_format_scope(value, **kwargs): 370 """Provides a scope within which the savde model format to test is `value`. 371 372 The saved model format gets restored to its original value upon exiting the 373 scope. 374 375 Args: 376 value: saved model format value 377 **kwargs: optional kwargs to pass to the save function. 378 379 Yields: 380 The provided value. 381 """ 382 previous_format = _thread_local_data.saved_model_format 383 previous_kwargs = _thread_local_data.save_kwargs 384 try: 385 _thread_local_data.saved_model_format = value 386 _thread_local_data.save_kwargs = kwargs 387 yield 388 finally: 389 # Restore saved model format to initial value. 390 _thread_local_data.saved_model_format = previous_format 391 _thread_local_data.save_kwargs = previous_kwargs 392 393 394def get_save_format(): 395 if _thread_local_data.saved_model_format is None: 396 raise ValueError( 397 'Cannot call `get_save_format()` outside of a ' 398 '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 399 'decorator.') 400 return _thread_local_data.saved_model_format 401 402 403def get_save_kwargs(): 404 if _thread_local_data.save_kwargs is None: 405 raise ValueError( 406 'Cannot call `get_save_kwargs()` outside of a ' 407 '`saved_model_format_scope()` or `run_with_all_saved_model_formats` ' 408 'decorator.') 409 return _thread_local_data.save_kwargs or {} 410 411 412def get_model_type(): 413 """Gets the model type that should be tested.""" 414 if _thread_local_data.model_type is None: 415 raise ValueError('Cannot call `get_model_type()` outside of a ' 416 '`model_type_scope()` or `run_with_all_model_types` ' 417 'decorator.') 418 419 return _thread_local_data.model_type 420 421 422def get_small_sequential_mlp(num_hidden, num_classes, input_dim=None): 423 model = models.Sequential() 424 if input_dim: 425 model.add(layers.Dense(num_hidden, activation='relu', input_dim=input_dim)) 426 else: 427 model.add(layers.Dense(num_hidden, activation='relu')) 428 activation = 'sigmoid' if num_classes == 1 else 'softmax' 429 model.add(layers.Dense(num_classes, activation=activation)) 430 return model 431 432 433def get_small_functional_mlp(num_hidden, num_classes, input_dim): 434 inputs = layers.Input(shape=(input_dim,)) 435 outputs = layers.Dense(num_hidden, activation='relu')(inputs) 436 activation = 'sigmoid' if num_classes == 1 else 'softmax' 437 outputs = layers.Dense(num_classes, activation=activation)(outputs) 438 return models.Model(inputs, outputs) 439 440 441class SmallSubclassMLP(models.Model): 442 """A subclass model based small MLP.""" 443 444 def __init__(self, 445 num_hidden, 446 num_classes, 447 use_bn=False, 448 use_dp=False, 449 **kwargs): 450 super(SmallSubclassMLP, self).__init__(name='test_model', **kwargs) 451 self.use_bn = use_bn 452 self.use_dp = use_dp 453 454 self.layer_a = layers.Dense(num_hidden, activation='relu') 455 activation = 'sigmoid' if num_classes == 1 else 'softmax' 456 self.layer_b = layers.Dense(num_classes, activation=activation) 457 if self.use_dp: 458 self.dp = layers.Dropout(0.5) 459 if self.use_bn: 460 self.bn = layers.BatchNormalization(axis=-1) 461 462 def call(self, inputs, **kwargs): 463 x = self.layer_a(inputs) 464 if self.use_dp: 465 x = self.dp(x) 466 if self.use_bn: 467 x = self.bn(x) 468 return self.layer_b(x) 469 470 471class _SmallSubclassMLPCustomBuild(models.Model): 472 """A subclass model small MLP that uses a custom build method.""" 473 474 def __init__(self, num_hidden, num_classes): 475 super(_SmallSubclassMLPCustomBuild, self).__init__() 476 self.layer_a = None 477 self.layer_b = None 478 self.num_hidden = num_hidden 479 self.num_classes = num_classes 480 481 def build(self, input_shape): 482 self.layer_a = layers.Dense(self.num_hidden, activation='relu') 483 activation = 'sigmoid' if self.num_classes == 1 else 'softmax' 484 self.layer_b = layers.Dense(self.num_classes, activation=activation) 485 486 def call(self, inputs, **kwargs): 487 x = self.layer_a(inputs) 488 return self.layer_b(x) 489 490 491def get_small_subclass_mlp(num_hidden, num_classes): 492 return SmallSubclassMLP(num_hidden, num_classes) 493 494 495def get_small_subclass_mlp_with_custom_build(num_hidden, num_classes): 496 return _SmallSubclassMLPCustomBuild(num_hidden, num_classes) 497 498 499def get_small_mlp(num_hidden, num_classes, input_dim): 500 """Get a small mlp of the model type specified by `get_model_type`.""" 501 model_type = get_model_type() 502 if model_type == 'subclass': 503 return get_small_subclass_mlp(num_hidden, num_classes) 504 if model_type == 'subclass_custom_build': 505 return get_small_subclass_mlp_with_custom_build(num_hidden, num_classes) 506 if model_type == 'sequential': 507 return get_small_sequential_mlp(num_hidden, num_classes, input_dim) 508 if model_type == 'functional': 509 return get_small_functional_mlp(num_hidden, num_classes, input_dim) 510 raise ValueError('Unknown model type {}'.format(model_type)) 511 512 513class _SubclassModel(models.Model): 514 """A Keras subclass model.""" 515 516 def __init__(self, model_layers, *args, **kwargs): 517 """Instantiate a model. 518 519 Args: 520 model_layers: a list of layers to be added to the model. 521 *args: Model's args 522 **kwargs: Model's keyword args, at most one of input_tensor -> the input 523 tensor required for ragged/sparse input. 524 """ 525 526 inputs = kwargs.pop('input_tensor', None) 527 super(_SubclassModel, self).__init__(*args, **kwargs) 528 # Note that clone and build doesn't support lists of layers in subclassed 529 # models. Adding each layer directly here. 530 for i, layer in enumerate(model_layers): 531 setattr(self, self._layer_name_for_i(i), layer) 532 533 self.num_layers = len(model_layers) 534 535 if inputs is not None: 536 self._set_inputs(inputs) 537 538 def _layer_name_for_i(self, i): 539 return 'layer{}'.format(i) 540 541 def call(self, inputs, **kwargs): 542 x = inputs 543 for i in range(self.num_layers): 544 layer = getattr(self, self._layer_name_for_i(i)) 545 x = layer(x) 546 return x 547 548 549class _SubclassModelCustomBuild(models.Model): 550 """A Keras subclass model that uses a custom build method.""" 551 552 def __init__(self, layer_generating_func, *args, **kwargs): 553 super(_SubclassModelCustomBuild, self).__init__(*args, **kwargs) 554 self.all_layers = None 555 self._layer_generating_func = layer_generating_func 556 557 def build(self, input_shape): 558 model_layers = [] 559 for layer in self._layer_generating_func(): 560 model_layers.append(layer) 561 self.all_layers = model_layers 562 563 def call(self, inputs, **kwargs): 564 x = inputs 565 for layer in self.all_layers: 566 x = layer(x) 567 return x 568 569 570def get_model_from_layers(model_layers, 571 input_shape=None, 572 input_dtype=None, 573 name=None, 574 input_ragged=None, 575 input_sparse=None, 576 model_type=None): 577 """Builds a model from a sequence of layers. 578 579 Args: 580 model_layers: The layers used to build the network. 581 input_shape: Shape tuple of the input or 'TensorShape' instance. 582 input_dtype: Datatype of the input. 583 name: Name for the model. 584 input_ragged: Boolean, whether the input data is a ragged tensor. 585 input_sparse: Boolean, whether the input data is a sparse tensor. 586 model_type: One of "subclass", "subclass_custom_build", "sequential", or 587 "functional". When None, defaults to `get_model_type`. 588 589 Returns: 590 A Keras model. 591 """ 592 if model_type is None: 593 model_type = get_model_type() 594 if model_type == 'subclass': 595 inputs = None 596 if input_ragged or input_sparse: 597 inputs = layers.Input( 598 shape=input_shape, 599 dtype=input_dtype, 600 ragged=input_ragged, 601 sparse=input_sparse) 602 return _SubclassModel(model_layers, name=name, input_tensor=inputs) 603 604 if model_type == 'subclass_custom_build': 605 layer_generating_func = lambda: model_layers 606 return _SubclassModelCustomBuild(layer_generating_func, name=name) 607 608 if model_type == 'sequential': 609 model = models.Sequential(name=name) 610 if input_shape: 611 model.add( 612 layers.InputLayer( 613 input_shape=input_shape, 614 dtype=input_dtype, 615 ragged=input_ragged, 616 sparse=input_sparse)) 617 for layer in model_layers: 618 model.add(layer) 619 return model 620 621 if model_type == 'functional': 622 if not input_shape: 623 raise ValueError('Cannot create a functional model from layers with no ' 624 'input shape.') 625 inputs = layers.Input( 626 shape=input_shape, 627 dtype=input_dtype, 628 ragged=input_ragged, 629 sparse=input_sparse) 630 outputs = inputs 631 for layer in model_layers: 632 outputs = layer(outputs) 633 return models.Model(inputs, outputs, name=name) 634 635 raise ValueError('Unknown model type {}'.format(model_type)) 636 637 638class Bias(layers.Layer): 639 640 def build(self, input_shape): 641 self.bias = self.add_variable('bias', (1,), initializer='zeros') 642 643 def call(self, inputs): 644 return inputs + self.bias 645 646 647class _MultiIOSubclassModel(models.Model): 648 """Multi IO Keras subclass model.""" 649 650 def __init__(self, branch_a, branch_b, shared_input_branch=None, 651 shared_output_branch=None, name=None): 652 super(_MultiIOSubclassModel, self).__init__(name=name) 653 self._shared_input_branch = shared_input_branch 654 self._branch_a = branch_a 655 self._branch_b = branch_b 656 self._shared_output_branch = shared_output_branch 657 658 def call(self, inputs, **kwargs): 659 if self._shared_input_branch: 660 for layer in self._shared_input_branch: 661 inputs = layer(inputs) 662 a = inputs 663 b = inputs 664 elif isinstance(inputs, dict): 665 a = inputs['input_1'] 666 b = inputs['input_2'] 667 else: 668 a, b = inputs 669 670 for layer in self._branch_a: 671 a = layer(a) 672 for layer in self._branch_b: 673 b = layer(b) 674 outs = [a, b] 675 676 if self._shared_output_branch: 677 for layer in self._shared_output_branch: 678 outs = layer(outs) 679 680 return outs 681 682 683class _MultiIOSubclassModelCustomBuild(models.Model): 684 """Multi IO Keras subclass model that uses a custom build method.""" 685 686 def __init__(self, branch_a_func, branch_b_func, 687 shared_input_branch_func=None, 688 shared_output_branch_func=None): 689 super(_MultiIOSubclassModelCustomBuild, self).__init__() 690 self._shared_input_branch_func = shared_input_branch_func 691 self._branch_a_func = branch_a_func 692 self._branch_b_func = branch_b_func 693 self._shared_output_branch_func = shared_output_branch_func 694 695 self._shared_input_branch = None 696 self._branch_a = None 697 self._branch_b = None 698 self._shared_output_branch = None 699 700 def build(self, input_shape): 701 if self._shared_input_branch_func(): 702 self._shared_input_branch = self._shared_input_branch_func() 703 self._branch_a = self._branch_a_func() 704 self._branch_b = self._branch_b_func() 705 706 if self._shared_output_branch_func(): 707 self._shared_output_branch = self._shared_output_branch_func() 708 709 def call(self, inputs, **kwargs): 710 if self._shared_input_branch: 711 for layer in self._shared_input_branch: 712 inputs = layer(inputs) 713 a = inputs 714 b = inputs 715 else: 716 a, b = inputs 717 718 for layer in self._branch_a: 719 a = layer(a) 720 for layer in self._branch_b: 721 b = layer(b) 722 outs = a, b 723 724 if self._shared_output_branch: 725 for layer in self._shared_output_branch: 726 outs = layer(outs) 727 728 return outs 729 730 731def get_multi_io_model( 732 branch_a, 733 branch_b, 734 shared_input_branch=None, 735 shared_output_branch=None): 736 """Builds a multi-io model that contains two branches. 737 738 The produced model will be of the type specified by `get_model_type`. 739 740 To build a two-input, two-output model: 741 Specify a list of layers for branch a and branch b, but do not specify any 742 shared input branch or shared output branch. The resulting model will apply 743 each branch to a different input, to produce two outputs. 744 745 The first value in branch_a must be the Keras 'Input' layer for branch a, 746 and the first value in branch_b must be the Keras 'Input' layer for 747 branch b. 748 749 example usage: 750 ``` 751 branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()] 752 branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()] 753 754 model = get_multi_io_model(branch_a, branch_b) 755 ``` 756 757 To build a two-input, one-output model: 758 Specify a list of layers for branch a and branch b, and specify a 759 shared output branch. The resulting model will apply 760 each branch to a different input. It will then apply the shared output 761 branch to a tuple containing the intermediate outputs of each branch, 762 to produce a single output. The first layer in the shared_output_branch 763 must be able to merge a tuple of two tensors. 764 765 The first value in branch_a must be the Keras 'Input' layer for branch a, 766 and the first value in branch_b must be the Keras 'Input' layer for 767 branch b. 768 769 example usage: 770 ``` 771 input_branch_a = [Input(shape=(2,), name='a'), Dense(), Dense()] 772 input_branch_b = [Input(shape=(3,), name='b'), Dense(), Dense()] 773 shared_output_branch = [Concatenate(), Dense(), Dense()] 774 775 model = get_multi_io_model(input_branch_a, input_branch_b, 776 shared_output_branch=shared_output_branch) 777 ``` 778 To build a one-input, two-output model: 779 Specify a list of layers for branch a and branch b, and specify a 780 shared input branch. The resulting model will take one input, and apply 781 the shared input branch to it. It will then respectively apply each branch 782 to that intermediate result in parallel, to produce two outputs. 783 784 The first value in the shared_input_branch must be the Keras 'Input' layer 785 for the whole model. Branch a and branch b should not contain any Input 786 layers. 787 788 example usage: 789 ``` 790 shared_input_branch = [Input(shape=(2,), name='in'), Dense(), Dense()] 791 output_branch_a = [Dense(), Dense()] 792 output_branch_b = [Dense(), Dense()] 793 794 795 model = get_multi_io_model(output__branch_a, output_branch_b, 796 shared_input_branch=shared_input_branch) 797 ``` 798 799 Args: 800 branch_a: A sequence of layers for branch a of the model. 801 branch_b: A sequence of layers for branch b of the model. 802 shared_input_branch: An optional sequence of layers to apply to a single 803 input, before applying both branches to that intermediate result. If set, 804 the model will take only one input instead of two. Defaults to None. 805 shared_output_branch: An optional sequence of layers to merge the 806 intermediate results produced by branch a and branch b. If set, 807 the model will produce only one output instead of two. Defaults to None. 808 809 Returns: 810 A multi-io model of the type specified by `get_model_type`, specified 811 by the different branches. 812 """ 813 # Extract the functional inputs from the layer lists 814 if shared_input_branch: 815 inputs = shared_input_branch[0] 816 shared_input_branch = shared_input_branch[1:] 817 else: 818 inputs = branch_a[0], branch_b[0] 819 branch_a = branch_a[1:] 820 branch_b = branch_b[1:] 821 822 model_type = get_model_type() 823 if model_type == 'subclass': 824 return _MultiIOSubclassModel(branch_a, branch_b, shared_input_branch, 825 shared_output_branch) 826 827 if model_type == 'subclass_custom_build': 828 return _MultiIOSubclassModelCustomBuild((lambda: branch_a), 829 (lambda: branch_b), 830 (lambda: shared_input_branch), 831 (lambda: shared_output_branch)) 832 833 if model_type == 'sequential': 834 raise ValueError('Cannot use `get_multi_io_model` to construct ' 835 'sequential models') 836 837 if model_type == 'functional': 838 if shared_input_branch: 839 a_and_b = inputs 840 for layer in shared_input_branch: 841 a_and_b = layer(a_and_b) 842 a = a_and_b 843 b = a_and_b 844 else: 845 a, b = inputs 846 847 for layer in branch_a: 848 a = layer(a) 849 for layer in branch_b: 850 b = layer(b) 851 outputs = a, b 852 853 if shared_output_branch: 854 for layer in shared_output_branch: 855 outputs = layer(outputs) 856 857 return models.Model(inputs, outputs) 858 859 raise ValueError('Unknown model type {}'.format(model_type)) 860 861 862_V2_OPTIMIZER_MAP = { 863 'adadelta': adadelta_v2.Adadelta, 864 'adagrad': adagrad_v2.Adagrad, 865 'adam': adam_v2.Adam, 866 'adamax': adamax_v2.Adamax, 867 'nadam': nadam_v2.Nadam, 868 'rmsprop': rmsprop_v2.RMSprop, 869 'sgd': gradient_descent_v2.SGD 870} 871 872 873def get_v2_optimizer(name, **kwargs): 874 """Get the v2 optimizer requested. 875 876 This is only necessary until v2 are the default, as we are testing in Eager, 877 and Eager + v1 optimizers fail tests. When we are in v2, the strings alone 878 should be sufficient, and this mapping can theoretically be removed. 879 880 Args: 881 name: string name of Keras v2 optimizer. 882 **kwargs: any kwargs to pass to the optimizer constructor. 883 884 Returns: 885 Initialized Keras v2 optimizer. 886 887 Raises: 888 ValueError: if an unknown name was passed. 889 """ 890 try: 891 return _V2_OPTIMIZER_MAP[name](**kwargs) 892 except KeyError: 893 raise ValueError( 894 'Could not find requested v2 optimizer: {}\nValid choices: {}'.format( 895 name, list(_V2_OPTIMIZER_MAP.keys()))) 896 897 898def get_expected_metric_variable_names(var_names, name_suffix=''): 899 """Returns expected metric variable names given names and prefix/suffix.""" 900 if tf2.enabled() or context.executing_eagerly(): 901 # In V1 eager mode and V2 variable names are not made unique. 902 return [n + ':0' for n in var_names] 903 # In V1 graph mode variable names are made unique using a suffix. 904 return [n + name_suffix + ':0' for n in var_names] 905 906 907def enable_v2_dtype_behavior(fn): 908 """Decorator for enabling the layer V2 dtype behavior on a test.""" 909 return _set_v2_dtype_behavior(fn, True) 910 911 912def disable_v2_dtype_behavior(fn): 913 """Decorator for disabling the layer V2 dtype behavior on a test.""" 914 return _set_v2_dtype_behavior(fn, False) 915 916 917def _set_v2_dtype_behavior(fn, enabled): 918 """Returns version of 'fn' that runs with v2 dtype behavior on or off.""" 919 @functools.wraps(fn) 920 def wrapper(*args, **kwargs): 921 v2_dtype_behavior = base_layer_utils.V2_DTYPE_BEHAVIOR 922 base_layer_utils.V2_DTYPE_BEHAVIOR = enabled 923 try: 924 return fn(*args, **kwargs) 925 finally: 926 base_layer_utils.V2_DTYPE_BEHAVIOR = v2_dtype_behavior 927 928 return tf_decorator.make_decorator(fn, wrapper) 929 930 931@contextlib.contextmanager 932def device(should_use_gpu): 933 """Uses gpu when requested and available.""" 934 if should_use_gpu and test_util.is_gpu_available(): 935 dev = '/device:GPU:0' 936 else: 937 dev = '/device:CPU:0' 938 with ops.device(dev): 939 yield 940 941 942@contextlib.contextmanager 943def use_gpu(): 944 """Uses gpu when requested and available.""" 945 with device(should_use_gpu=True): 946 yield 947 948 949def for_all_test_methods(decorator, *args, **kwargs): 950 """Generate class-level decorator from given method-level decorator. 951 952 It is expected for the given decorator to take some arguments and return 953 a method that is then called on the test method to produce a decorated 954 method. 955 956 Args: 957 decorator: The decorator to apply. 958 *args: Positional arguments 959 **kwargs: Keyword arguments 960 Returns: Function that will decorate a given classes test methods with the 961 decorator. 962 """ 963 964 def all_test_methods_impl(cls): 965 """Apply decorator to all test methods in class.""" 966 for name in dir(cls): 967 value = getattr(cls, name) 968 if callable(value) and name.startswith('test') and (name != 969 'test_session'): 970 setattr(cls, name, decorator(*args, **kwargs)(value)) 971 return cls 972 973 return all_test_methods_impl 974 975 976# The description is just for documentation purposes. 977def run_without_tensor_float_32(description): # pylint: disable=unused-argument 978 """Execute test with TensorFloat-32 disabled. 979 980 While almost every real-world deep learning model runs fine with 981 TensorFloat-32, many tests use assertAllClose or similar methods. 982 TensorFloat-32 matmuls typically will cause such methods to fail with the 983 default tolerances. 984 985 Args: 986 description: A description used for documentation purposes, describing why 987 the test requires TensorFloat-32 to be disabled. 988 989 Returns: 990 Decorator which runs a test with TensorFloat-32 disabled. 991 """ 992 993 def decorator(f): 994 995 @functools.wraps(f) 996 def decorated(self, *args, **kwargs): 997 allowed = config.tensor_float_32_execution_enabled() 998 try: 999 config.enable_tensor_float_32_execution(False) 1000 f(self, *args, **kwargs) 1001 finally: 1002 config.enable_tensor_float_32_execution(allowed) 1003 1004 return decorated 1005 1006 return decorator 1007 1008 1009# The description is just for documentation purposes. 1010def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument 1011 """Execute all tests in a class with TensorFloat-32 disabled.""" 1012 return for_all_test_methods(run_without_tensor_float_32, description) 1013 1014 1015def run_v2_only(func=None): 1016 """Execute the decorated test only if running in v2 mode. 1017 1018 This function is intended to be applied to tests that exercise v2 only 1019 functionality. If the test is run in v1 mode it will simply be skipped. 1020 1021 See go/tf-test-decorator-cheatsheet for the decorators to use in different 1022 v1/v2/eager/graph combinations. 1023 1024 Args: 1025 func: function to be annotated. If `func` is None, this method returns a 1026 decorator the can be applied to a function. If `func` is not None this 1027 returns the decorator applied to `func`. 1028 1029 Returns: 1030 Returns a decorator that will conditionally skip the decorated test method. 1031 """ 1032 1033 def decorator(f): 1034 if tf_inspect.isclass(f): 1035 raise ValueError('`run_v2_only` only supports test methods.') 1036 1037 def decorated(self, *args, **kwargs): 1038 if not tf2.enabled(): 1039 self.skipTest('Test is only compatible with v2') 1040 1041 return f(self, *args, **kwargs) 1042 1043 return decorated 1044 1045 if func is not None: 1046 return decorator(func) 1047 1048 return decorator 1049 1050 1051def generate_combinations_with_testcase_name(**kwargs): 1052 """Generate combinations based on its keyword arguments using combine(). 1053 1054 This function calls combine() and appends a testcase name to the list of 1055 dictionaries returned. The 'testcase_name' key is a required for named 1056 parameterized tests. 1057 1058 Args: 1059 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 1060 `option=the_only_possibility`. 1061 1062 Returns: 1063 a list of dictionaries for each combination. Keys in the dictionaries are 1064 the keyword argument names. Each key has one value - one of the 1065 corresponding keyword argument values. 1066 """ 1067 sort_by_key = lambda k: k[0] 1068 combinations = [] 1069 for key, values in sorted(kwargs.items(), key=sort_by_key): 1070 if not isinstance(values, list): 1071 values = [values] 1072 combinations.append([(key, value) for value in values]) 1073 1074 combinations = [collections.OrderedDict(result) 1075 for result in itertools.product(*combinations)] 1076 named_combinations = [] 1077 for combination in combinations: 1078 assert isinstance(combination, collections.OrderedDict) 1079 name = ''.join([ 1080 '_{}_{}'.format(''.join(filter(str.isalnum, key)), 1081 ''.join(filter(str.isalnum, str(value)))) 1082 for key, value in combination.items() 1083 ]) 1084 named_combinations.append( 1085 collections.OrderedDict( 1086 list(combination.items()) + 1087 [('testcase_name', '_test{}'.format(name))])) 1088 1089 return named_combinations 1090