# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
  """Compute gradients to send to TPU embedding.

  Args:
    optimizer: a subclass of optimizer.Optimizer, usually CrossShardOptimizer.
      Used to call compute_gradients().
    loss: a Tensor to call optimizer.compute_gradients() on.
    activations: an OrderedDict mapping feature_name to Tensors of activations.

  Returns:
    An OrderedDict mapping from feature name Strings to Tensors of gradients of
    the loss wrt the activations of the features.
  """
  activation_list = activations.values()
  grads_and_vars = optimizer.compute_gradients(loss, activation_list)
  grads = [grad for grad, _ in grads_and_vars]
  feature_to_gradient_dict = collections.OrderedDict(
      zip(activations.keys(), grads))
  return feature_to_gradient_dict
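

# A minimal usage sketch (illustrative only, not part of the library API).
# It assumes a TPUEstimator-style model_fn in which `activations` is the
# OrderedDict of per-feature activations produced by a TPUEmbedding object and
# `optimizer` is the same (usually CrossShardOptimizer-wrapped) optimizer used
# for the dense variables. `send_gradients_fn` is a stand-in for whatever
# send-op builder the surrounding code uses (e.g. TPUEmbedding's
# generate_send_gradients_op); it is an assumption here, not a fixed API.
def _example_compute_gradients_path(optimizer, loss, activations,
                                    send_gradients_fn):
  """Illustrative only: capture activation gradients and build the send op."""
  # Differentiate the loss wrt each embedding activation tensor.
  feature_to_gradient_dict = get_gradients_through_compute_gradients(
      optimizer, loss, activations)
  # Hand the per-feature gradients to the caller-provided send-op builder.
  return send_gradients_fn(feature_to_gradient_dict)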


def create_dummy_table_variables(tpu_embedding):
  """Create dummy embedding table variables.

  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt the activations can be
  captured and later sent to TPU embedding.

  Args:
    tpu_embedding: TPUEmbedding, dummy table variables will be created for use
      with tpu_embedding.

  Returns:
    A tuple of dummy variables and their initializer.

  Raises:
    RuntimeError: if the collection used to store gradients already exists and
      is not empty.
  """
  dummy_table_variables = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    dummy_table_variables[table] = (
        # Explicitly specifying collections prevents this variable from
        # being added to the GLOBAL_VARIABLES collection, so that Saver()
        # ignores it.
        # However, TensorFlow optimizers create slot variables for these dummy
        # variables, e.g. tpu_embedding_dummy_table_variable_mlp_user/Adam{_1},
        # and those slot variables do end up in the GLOBAL_VARIABLES
        # collection.
        variable_scope.get_variable(
            'tpu_embedding_dummy_table_variable_{}'.format(table),
            dtype=dtypes.float32,
            shape=[1],
            use_resource=True,
            trainable=True,
            collections=['tpu_embedding_dummy_table_variables']))

    g = ops.get_default_graph()
    table_gradients = g.get_collection_ref(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if table_gradients:
      raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
    num_features = len(tpu_embedding.table_to_features_dict[table])
    table_gradients.extend([None for _ in range(num_features)])

  return (dummy_table_variables,
          variables.variables_initializer(
              dummy_table_variables.values(),
              name='tpu_embedding_dummy_table_variables_init'))


def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: An OrderedDict of feature name String to activation tensors.
    dummy_table_variables: An OrderedDict of table name String to dummy table
      variables.

  Returns:
    An OrderedDict of feature name String to activation tensors, which can be
    used just as the activations input.
  """
  new_activations = collections.OrderedDict()
  for feature in activations:
    table = tpu_embedding.feature_to_config_dict[feature].table_id
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activations[feature],
        table_id=list(tpu_embedding.table_to_config_dict).index(table),
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations
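

# A minimal sketch of how the two helpers above are meant to be combined
# (illustrative only, not part of the library API): create the dummy table
# variables, rewrite the activations so they depend on those variables, and
# build the model loss from the rewritten activations. When the optimizer
# later computes gradients, the gradient registered for the
# tpu_embedding_activations op records the per-feature gradients in the
# 'tpu_embedding_gradients_table_{table_id}' collections, where
# get_gradients_through_dummy_table_variables (below) picks them up.
def _example_hook_activations(tpu_embedding_obj, activations):
  """Illustrative only: return hooked activations and the dummy-var init op."""
  dummy_table_variables, dummy_variables_init = create_dummy_table_variables(
      tpu_embedding_obj)
  hooked_activations = hook_dummy_table_variables_to_activations(
      tpu_embedding_obj, activations, dummy_table_variables)
  # The caller should run `dummy_variables_init` once and build the loss from
  # `hooked_activations` instead of the original `activations`.
  return hooked_activations, dummy_variables_init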


def get_gradients_through_dummy_table_variables(tpu_embedding):
  """Get gradients wrt the activations of each feature.

  Args:
    tpu_embedding: TPUEmbedding, create dummy table variable to be used with
      tpu_embedding.

  Returns:
    An OrderedDict mapping feature name to gradient.
  """
  g = ops.get_default_graph()
  gradients_found = False
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if any(gradient is None for gradient in table_gradients):
      # TODO(bfontain): create a white-list for optimizers which are compatible
      # with `tf.stop_gradient`.
      logging.warn(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used, or tf.stop_gradient() is applied. Gradients of '
          'zeros are sent back to TPUEmbedding instead. Gradients of zeros '
          'and no gradients are equivalent for SGD, AdaGrad, FTRL, etc., but '
          'might differ for other optimizers due to the implementation of '
          'the TPU embedding optimizers.'.format(table, table_id))
    gradients_found = gradients_found or any(
        gradient is not None for gradient in table_gradients)

  if not gradients_found:
    logging.warn(
        'All tables have undefined gradients: this is probably because the '
        'model asked TPUEmbedding to compute activations that were not used. '
        'If all TPUEmbedding features have stop_gradient applied, consider '
        'using the INFERENCE mode instead.')

  feature_to_gradient_dict = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
                                 table_gradients):
      if gradient is not None:
        feature_to_gradient_dict[feature] = gradient
      else:
        # Substitute a zero gradient of the right shape for features whose
        # activations did not receive a gradient.
        dimension = tpu_embedding.table_to_config_dict[table].dimension
        batch_size = tpu_embedding.batch_size_per_core
        max_sequence_length = (
            tpu_embedding.feature_to_config_dict[feature].max_sequence_length)
        if max_sequence_length:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, max_sequence_length, dimension])
        else:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, dimension])

  return feature_to_gradient_dict
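

# Hypothetical tail of the dummy-table path (illustrative only, not part of
# the library API): once the optimizer has differentiated the loss built from
# the hooked activations, the captured gradients can be read back and
# forwarded to the embedding engine. `generate_send_gradients_op` is assumed
# to be the TPUEmbedding method that enqueues gradients; verify the exact name
# and signature against tpu_embedding.py before relying on it.
def _example_send_gradients_via_dummy_tables(tpu_embedding_obj):
  """Illustrative only: read captured gradients and build the send op."""
  feature_to_gradient_dict = get_gradients_through_dummy_table_variables(
      tpu_embedding_obj)
  return tpu_embedding_obj.generate_send_gradients_op(feature_to_gradient_dict)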