# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple network to use in tests and examples."""

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import step_fn
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def single_loss_example(optimizer_fn, distribution, use_bias=False,
                        iterations_per_step=1):
  """Builds a one-unit dense network wrapped in a `StandardSingleLossStep`.

  Args:
    optimizer_fn: Zero-argument callable that returns the optimizer to use.
    distribution: Distribution strategy forwarded to the step.
    use_bias: Whether the single `Dense` layer has a bias term.
    iterations_per_step: Forwarded to `StandardSingleLossStep`.

  Returns:
    A `(step, layer)` tuple. The layer is exposed so tests can inspect its
    kernel.
  """

  def make_dataset():
    # An endless stream of the single element [[1.]].
    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()

  opt = optimizer_fn()
  dense = core.Dense(1, use_bias=use_bias)

  def squared_error(ctx, x):
    del ctx  # The step context is not used by this loss.
    residual = array_ops.reshape(dense(x), []) - constant_op.constant(1.)
    return residual * residual

  step = step_fn.StandardSingleLossStep(
      make_dataset, squared_error, opt, distribution, iterations_per_step)
  return step, dense


def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True):
  """Example of non-distribution-aware legacy code.

  Args:
    optimizer: Optimizer instance used to minimize the loss.
    use_bias: Whether the single `Dense` layer has a bias term.
    use_callable_loss: For optimizers that are not V2 instances, whether
      `minimize` receives the loss as a callable (True) or as the value
      returned by calling it (False). Not consulted for V2 optimizers,
      which always receive the callable plus a callable variable list.

  Returns:
    A `(model_fn, dataset_fn, layer)` tuple.
  """

  def dataset_fn():
    repeated = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
    # TODO(isaprykin): batch with drop_remainder causes shapes to be
    # fully defined for TPU. Remove this when XLA supports dynamic shapes.
    return repeated.batch(1, drop_remainder=True)

  dense = core.Dense(1, use_bias=use_bias)

  def model_fn(x):
    """A very simple model written by the user."""

    def compute_loss():
      residual = array_ops.reshape(dense(x), []) - constant_op.constant(1.)
      return residual * residual

    if strategy_test_lib.is_optimizer_v2_instance(optimizer):
      # V2 optimizers are given the variable list explicitly, as a callable.
      return optimizer.minimize(compute_loss,
                                lambda: dense.trainable_variables)
    # Other optimizers get either the callable or the evaluated loss.
    loss = compute_loss if use_callable_loss else compute_loss()
    return optimizer.minimize(loss)

  return model_fn, dataset_fn, dense


def batchnorm_example(optimizer_fn,
                      batch_per_epoch=1,
                      momentum=0.9,
                      renorm=False,
                      update_ops_in_replica_mode=False):
  """Example of non-distribution-aware legacy code with batch normalization.

  Args:
    optimizer_fn: Zero-argument callable that returns the optimizer.
    batch_per_epoch: Number of distinct [16, 8] elements in the dataset
      before it repeats.
    momentum: Momentum passed to the `BatchNormalization` layer.
    renorm: Whether batch renormalization is enabled on that layer.
    update_ops_in_replica_mode: If True, the loss is built under a control
      dependency on the ops in the `UPDATE_OPS` graph collection.

  Returns:
    A `(model_fn, dataset_fn, batchnorm)` tuple.
  """

  def dataset_fn():
    # Each element is a [16, 8] matrix whose values increase along both
    # dimensions; element `z` is offset by `z * 100`.
    batches = [
        [[float(x * 8 + y + z * 100) for y in range(8)] for x in range(16)]
        for z in range(batch_per_epoch)
    ]
    return dataset_ops.Dataset.from_tensor_slices(batches).repeat()

  optimizer = optimizer_fn()
  batchnorm = normalization.BatchNormalization(
      renorm=renorm, momentum=momentum, fused=False)
  dense = core.Dense(1, use_bias=False)

  def model_fn(x):
    """A model that uses batchnorm."""

    def compute_loss():
      normalized = batchnorm(x, training=True)
      # The collection is read only after batchnorm has been applied, so
      # any update ops it registered are included.
      if update_ops_in_replica_mode:
        dependencies = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
      else:
        dependencies = []
      with ops.control_dependencies(dependencies):
        loss = math_ops.reduce_mean(
            math_ops.reduce_sum(dense(normalized)) - constant_op.constant(1.))
      # `x` and `normalized` will be fetched by the gradient computation,
      # but not `loss`.
      return loss

    if strategy_test_lib.is_optimizer_v2_instance(optimizer):
      return optimizer.minimize(compute_loss, lambda: dense.trainable_variables)
    # Other optimizers receive the loss as a callable.
    return optimizer.minimize(compute_loss)

  return model_fn, dataset_fn, batchnorm