# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""Optional helper for gradient handling."""

import collections

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops


def get_gradients_through_compute_gradients(optimizer, loss, activations):
  """Compute gradients to send to TPU embedding.

  Args:
    optimizer: an instance of optimizer.Optimizer, usually a
      CrossShardOptimizer. Used to call compute_gradients().
    loss: a Tensor to call optimizer.compute_gradients() on.
    activations: an OrderedDict mapping feature names to activation Tensors.

  Returns:
    An OrderedDict mapping from feature name Strings to Tensors of gradients of
      the loss wrt the activations of the features.
  """
  activation_list = activations.values()
  grads_and_vars = optimizer.compute_gradients(loss, activation_list)
  grads = [grad for grad, _ in grads_and_vars]
  # compute_gradients() preserves the order of activation_list, so the grads
  # line up with activations.keys() in the zip below.
  feature_to_gradient_dict = collections.OrderedDict(
      zip(activations.keys(), grads))
  return feature_to_gradient_dict
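

# Illustrative usage sketch, not part of the original module: how a caller
# with an existing `tpu_embedding` object, a CrossShardOptimizer `optimizer`,
# a scalar `loss`, and the OrderedDict of `activations` returned by the
# embedding lookup might capture and send back gradients. The
# generate_send_gradients_op() call is an assumption about the caller's
# TPUEmbedding object.
def _example_send_gradients_via_optimizer(tpu_embedding, optimizer, loss,
                                          activations):
  """Sketch: capture activation gradients and send them back."""
  # Gradients of `loss` wrt each feature's activation, keyed by feature name.
  feature_to_gradient_dict = get_gradients_through_compute_gradients(
      optimizer, loss, activations)
  # Hand the gradients back to the TPU embedding engine (assumed API).
  return tpu_embedding.generate_send_gradients_op(feature_to_gradient_dict)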


def create_dummy_table_variables(tpu_embedding):
  """Create dummy embedding table variables.

  The sole purpose of these dummy variables is to trigger gradient
  calculation wrt them so that the gradients wrt the activations can be
  captured and later sent to TPU embedding.

  Args:
    tpu_embedding: TPUEmbedding, dummy table variables will be created for use
      with tpu_embedding.

  Returns:
    A tuple of dummy variables and their initializer.

  Raises:
    RuntimeError: if the collection used to store gradients already exists and
      is not empty.
  """
  dummy_table_variables = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    dummy_table_variables[table] = (
        # Explicitly specifying collections prevents this variable from
        # being added to the GLOBAL_VARIABLES collection, so that Saver()
        # ignores it.
        # However, TensorFlow optimizers create slot variables for these
        # dummy variables, e.g.
        # tpu_embedding_dummy_table_variable_mlp_user/Adam{_1}, which will be
        # in the GLOBAL_VARIABLES collection.
        variable_scope.get_variable(
            'tpu_embedding_dummy_table_variable_{}'.format(table),
            dtype=dtypes.float32,
            shape=[1],
            use_resource=True,
            trainable=True,
            collections=['tpu_embedding_dummy_table_variables']))

    g = ops.get_default_graph()
    table_gradients = g.get_collection_ref(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if table_gradients:
      raise RuntimeError(
          'tpu_embedding_gradients_table_{} is not empty.'.format(table_id))
    num_features = len(tpu_embedding.table_to_features_dict[table])
    table_gradients.extend([None for _ in range(num_features)])

  return (dummy_table_variables,
          variables.variables_initializer(
              dummy_table_variables.values(),
              name='tpu_embedding_dummy_table_variables_init'))
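

# Illustrative usage sketch, not part of the original module: what
# create_dummy_table_variables() returns and the graph state it sets up. The
# `tpu_embedding` argument is an assumed, already-configured TPUEmbedding
# object.
def _example_inspect_dummy_tables(tpu_embedding):
  """Sketch: create dummy tables and inspect the gradient collections."""
  dummy_table_variables, dummy_table_variables_init = (
      create_dummy_table_variables(tpu_embedding))
  g = ops.get_default_graph()
  for table_id, table in enumerate(tpu_embedding.table_to_features_dict):
    # One placeholder slot per feature of the table; the slots are filled in
    # later, when gradients flow through the tpu_embedding_activations ops.
    placeholders = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    logging.info('table %s: dummy variable %s, %d gradient placeholder(s)',
                 table, dummy_table_variables[table].name, len(placeholders))
  # The initializer is typically run alongside the other variable initializers
  # before training starts.
  return dummy_table_variables, dummy_table_variables_init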


def hook_dummy_table_variables_to_activations(tpu_embedding, activations,
                                              dummy_table_variables):
  """Have activations depend on dummy table variables for gradient intercept.

  Args:
    tpu_embedding: TPUEmbedding, activations and dummy_table_variables are from
      tpu_embedding.
    activations: an OrderedDict mapping feature name Strings to activation
      Tensors.
    dummy_table_variables: an OrderedDict mapping table name Strings to dummy
      table variables.

  Returns:
    An OrderedDict mapping feature name Strings to activation Tensors, which
      can be used just as the activations input.
  """
  new_activations = collections.OrderedDict()
  for feature in activations:
    table = tpu_embedding.feature_to_config_dict[feature].table_id
    new_activations[feature] = tpu_ops.tpu_embedding_activations(
        dummy_table_variables[table],
        activations[feature],
        table_id=list(tpu_embedding.table_to_config_dict).index(table),
        lookup_id=tpu_embedding.table_to_features_dict[table].index(feature))
  return new_activations
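

# Illustrative usage sketch, not part of the original module: the activations
# returned by the embedding lookup are rerouted through the dummy table
# variables before the rest of the model consumes them, so that a gradient is
# recorded for every hooked activation during backprop. `tpu_embedding`,
# `activations`, and `dummy_table_variables` are assumed to come from the
# caller's model function.
def _example_hook_activations(tpu_embedding, activations,
                              dummy_table_variables):
  """Sketch: reroute activations through the dummy table variables."""
  hooked_activations = hook_dummy_table_variables_to_activations(
      tpu_embedding, activations, dummy_table_variables)
  # `hooked_activations` is a drop-in replacement for `activations`; the rest
  # of the model should consume it instead of the original dict.
  return hooked_activations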


def get_gradients_through_dummy_table_variables(tpu_embedding):
  """Get gradients wrt the activations of each feature.

  Args:
    tpu_embedding: TPUEmbedding, the object whose dummy table variables were
      created with create_dummy_table_variables().

  Returns:
    An OrderedDict mapping feature name to gradient.
  """
  g = ops.get_default_graph()
  gradients_found = False
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    if any(gradient is None for gradient in table_gradients):
      # TODO(bfontain): create a white-list for optimizers which are compatible
      # with `tf.stop_gradient`.
      logging.warn(
          'Table {} with id {} has undefined gradients: this is probably '
          'because the model asked TPUEmbedding to compute activations that '
          'were not used, or tf.stop_gradient() is applied. Gradients of zeros '
          'are sent back to TPUEmbedding instead. Gradients of zeros and no '
          'gradients are equivalent for SGD, AdaGrad, FTRL, etc., but '
          'might differ for other optimizers due to the implementation of TPU '
          'embedding optimizers.'.format(table, table_id))
    gradients_found = gradients_found or any(
        gradient is not None for gradient in table_gradients)

  if not gradients_found:
    logging.warn(
        'All tables have undefined gradients: this is probably because the '
        'model asked TPUEmbedding to compute activations that were not used. '
        'If all TPUEmbedding features have stop_gradient applied, consider '
        'using INFERENCE mode instead.')

  feature_to_gradient_dict = collections.OrderedDict()
  for table_id, table in enumerate(tpu_embedding.table_to_config_dict):
    table_gradients = g.get_collection(
        'tpu_embedding_gradients_table_{}'.format(table_id))
    for feature, gradient in zip(tpu_embedding.table_to_features_dict[table],
                                 table_gradients):
      if gradient is not None:
        feature_to_gradient_dict[feature] = gradient
      else:
        # No gradient was captured for this feature, so send back zeros of the
        # activation's shape.
        dimension = tpu_embedding.table_to_config_dict[table].dimension
        batch_size = tpu_embedding.batch_size_per_core
        max_sequence_length = (
            tpu_embedding.feature_to_config_dict[feature].max_sequence_length)
        if max_sequence_length:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, max_sequence_length, dimension])
        else:
          feature_to_gradient_dict[feature] = array_ops.zeros(
              [batch_size, dimension])

  return feature_to_gradient_dict
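

# End-to-end illustrative sketch, not part of the original module: how the
# dummy-table path is typically assembled in a TF1-style model function. The
# `tpu_embedding` object, the `build_model` callable, the use of
# optimizer.minimize() to trigger the gradient capture, and
# generate_send_gradients_op() are assumptions about the caller; only the
# helpers defined above are taken from this file.
def _example_dummy_table_gradient_pipeline(tpu_embedding, optimizer,
                                           activations, build_model):
  """Sketch: capture and send embedding gradients via dummy variables."""
  # 1. Create one dummy variable per embedding table, plus empty gradient
  #    placeholders in the per-table graph collections.
  dummy_table_variables, dummy_init = create_dummy_table_variables(
      tpu_embedding)
  # 2. Make every activation depend on its table's dummy variable so that the
  #    backward pass records a gradient for it.
  hooked = hook_dummy_table_variables_to_activations(
      tpu_embedding, activations, dummy_table_variables)
  # 3. Build the rest of the model on the hooked activations and minimize the
  #    loss; computing gradients populates the gradient collections as a side
  #    effect.
  loss = build_model(hooked)
  train_op = optimizer.minimize(loss)
  # 4. Read the captured gradients (zeros for unused features) and send them
  #    back to the TPU embedding engine (assumed API).
  feature_to_gradient_dict = get_gradients_through_dummy_table_variables(
      tpu_embedding)
  send_op = tpu_embedding.generate_send_gradients_op(feature_to_gradient_dict)
  return loss, train_op, send_op, dummy_init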