# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Class for TPU Embedding tests."""

import os

from absl import flags
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_strategy_util
from tensorflow.python.util import nest

FLAGS = flags.FLAGS
flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.')
flags.DEFINE_string('project', None, 'Name of GCP project with TPU.')
flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.')
flags.DEFINE_string('model_dir', os.environ.get('TEST_TMPDIR'),
                    'A temporary directory.')


class TPUEmbeddingBaseTest(parameterized.TestCase, test.TestCase):

  def skip_if_oss(self):
    if FLAGS.project is not None or FLAGS.zone is not None:
      self.skipTest(
          'Skipping tests for OSS as it is slow to run every test on Cloud TPU.'
      )

  def setUp(self):
    super(TPUEmbeddingBaseTest, self).setUp()
    self.embedding_values = np.array(list(range(32)), dtype=np.float64)
    self.initializer = init_ops_v2.Constant(self.embedding_values)
    # Embedding for video initialized to
    # 0 1 2 3
    # 4 5 6 7
    # ...
    self.table_video = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=8,
        dim=4,
        initializer=self.initializer,
        combiner='sum',
        name='video')
    # Embedding for user initialized to
    # 0 1
    # 2 3
    # 4 5
    # 6 7
    # ...
    self.table_user = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=16,
        dim=2,
        initializer=self.initializer,
        combiner='mean',
        name='user')
    self.feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_video, name='watched'),
                           tpu_embedding_v2_utils.FeatureConfig(
                               table=self.table_video, name='favorited'),
                           tpu_embedding_v2_utils.FeatureConfig(
                               table=self.table_user, name='friends'))

    self.batch_size = 2
    self.data_batch_size = 4
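    # `batch_size` is the per-replica batch size used when the datasets below
    # are rebatched to the global batch size; `data_batch_size` is the number
    # of rows in the canned feature data defined next.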

    # One (global) batch of inputs
    # sparse tensor for watched:
    # row 0: 0
    # row 1: 0, 1
    # row 2: 0, 1
    # row 3: 1
    self.feature_watched_indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1],
                                    [3, 0]]
    self.feature_watched_values = [0, 0, 1, 0, 1, 1]
    self.feature_watched_row_lengths = [1, 2, 2, 1]
    # sparse tensor for favorited:
    # row 0: 0, 1
    # row 1: 1
    # row 2: 0
    # row 3: 0, 1
    self.feature_favorited_indices = [[0, 0], [0, 1], [1, 0], [2, 0], [3, 0],
                                      [3, 1]]
    self.feature_favorited_values = [0, 1, 1, 0, 0, 1]
    self.feature_favorited_row_lengths = [2, 1, 1, 2]
    # sparse tensor for friends:
    # row 0: 3
    # row 1: 0, 1, 2
    # row 2: 3
    # row 3: 0, 1, 2
    self.feature_friends_indices = [[0, 0], [1, 0], [1, 1], [1, 2], [2, 0],
                                    [3, 0], [3, 1], [3, 2]]
    self.feature_friends_values = [3, 0, 1, 2, 3, 0, 1, 2]
    self.feature_friends_row_lengths = [1, 3, 1, 3]
    self.resolver = None

    # Basically we expand the dims of the old feature by 1 and repeat it
    # `data_batch_size` times along the new leading dimension.
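    # For example, with `data_batch_size` = 4 the 2-D indices [[0, 0], [1, 0]]
    # become the 3-D indices
    # [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0],
    #  [2, 0, 0], [2, 1, 0], [3, 0, 0], [3, 1, 0]].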
    def create_high_dimensional_indices(indices):
      indices = np.array(indices, dtype=np.int32)
      batch_size_index = np.repeat(
          np.arange(self.data_batch_size), len(indices)).reshape(-1, 1)
      repeated_indices = np.tile(indices, (self.data_batch_size, 1))
      return np.concatenate([batch_size_index, repeated_indices], axis=1)

    # Create high dimensional features with shape (4, 4, 2).
    self.feature_watched_indices_high_dimensional = create_high_dimensional_indices(
        self.feature_watched_indices)
    self.feature_watched_values_high_dimensional = self.feature_watched_values * self.data_batch_size
    self.feature_watched_row_lengths_high_dimensional = self.feature_watched_row_lengths * self.data_batch_size

    # Create high dimensional features with shape (4, 4, 2).
    self.feature_favorited_indices_high_dimensional = create_high_dimensional_indices(
        self.feature_favorited_indices)
    self.feature_favorited_values_high_dimensional = self.feature_favorited_values * self.data_batch_size
    self.feature_favorited_row_lengths_high_dimensional = self.feature_favorited_row_lengths * self.data_batch_size

    # Create high dimensional features with shape (4, 4, 3).
    self.feature_friends_indices_high_dimensional = create_high_dimensional_indices(
        self.feature_friends_indices)
    self.feature_friends_values_high_dimensional = self.feature_friends_values * self.data_batch_size
    self.feature_friends_row_lengths_high_dimensional = self.feature_friends_row_lengths * self.data_batch_size

  def _get_strategy(self):
    self.resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
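    # If the resolver exposes a Cloud TPU client, (re)configure the TPU
    # runtime to the nightly version, restarting the TPU workers.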
    if hasattr(self.resolver, '_cloud_tpu_client'):
      self.resolver._cloud_tpu_client.configure_tpu_version(
          version='nightly', restart_type='always')
    remote.connect_to_cluster(self.resolver)
    tpu_strategy_util.initialize_tpu_system(self.resolver)
    return tpu_strategy.TPUStrategy(self.resolver)

  def _create_mid_level(self, optimizer=None):
    # Create `TPUEmbedding` object.
    if optimizer is None:
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)

    return tpu_embedding_v2.TPUEmbedding(
        feature_config=self.feature_config, optimizer=optimizer)

  def _create_strategy_and_mid_level(self, optimizer_name):
    strategy = self._get_strategy()

    with strategy.scope():
      if optimizer_name == 'sgd':
        optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      elif optimizer_name == 'adagrad':
        optimizer = tpu_embedding_v2_utils.Adagrad(learning_rate=0.1)
      elif optimizer_name == 'adam':
        optimizer = tpu_embedding_v2_utils.Adam(learning_rate=0.1)
      elif optimizer_name == 'ftrl':
        optimizer = tpu_embedding_v2_utils.FTRL(learning_rate=0.1)
      elif optimizer_name == 'adagrad_momentum':
        optimizer = tpu_embedding_v2_utils.AdagradMomentum(
            learning_rate=0.1,
            momentum=0.9,
            use_nesterov=True,
            exponent=3.0,
            epsilon=0.1,
            beta2=0.9)
      else:
        raise ValueError(
            'optimizer is not recognized: {}'.format(optimizer_name))
      mid_level_api = self._create_mid_level(optimizer=optimizer)

    return strategy, mid_level_api, optimizer

  def _create_sparse_data(self, include_weights, weight=0.5):
    sparse_features = (sparse_tensor.SparseTensor(
        indices=self.feature_watched_indices,
        values=self.feature_watched_values,
        dense_shape=[self.data_batch_size, 2]),
                       sparse_tensor.SparseTensor(
                           indices=self.feature_favorited_indices,
                           values=self.feature_favorited_values,
                           dense_shape=[self.data_batch_size, 2]),
                       sparse_tensor.SparseTensor(
                           indices=self.feature_friends_indices,
                           values=self.feature_friends_values,
                           dense_shape=[self.data_batch_size, 3]))
    if include_weights:
      weights = []
      for sparse in sparse_features:
        values = (
            array_ops.ones_like(sparse.values, dtype=dtypes.float32) * weight)
        weights.append(
            sparse_tensor.SparseTensor(
                indices=sparse.indices,
                values=values,
                dense_shape=sparse.dense_shape))
      sparse_features = (sparse_features, tuple(weights))
    return sparse_features

  def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
    # Create dataset for enqueue operation
    sparse_features = self._create_sparse_data(include_weights, weight)

    dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)

    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _create_high_dimensional_sparse_dataset(self,
                                              strategy,
                                              include_weights=False,
                                              weight=0.5):
    sparse_features = (
        sparse_tensor.SparseTensor(
            indices=self.feature_watched_indices_high_dimensional,
            values=self.feature_watched_values_high_dimensional,
            dense_shape=[self.data_batch_size, self.data_batch_size, 2]),
        sparse_tensor.SparseTensor(
            indices=self.feature_favorited_indices_high_dimensional,
            values=self.feature_favorited_values_high_dimensional,
            dense_shape=[self.data_batch_size, self.data_batch_size, 2]),
        sparse_tensor.SparseTensor(
            indices=self.feature_friends_indices_high_dimensional,
            values=self.feature_friends_values_high_dimensional,
            dense_shape=[self.data_batch_size, self.data_batch_size, 3]))
    if include_weights:
      weights = []
      for sparse in sparse_features:
        values = (
            array_ops.ones_like(sparse.values, dtype=dtypes.float32) * weight)
        weights.append(
            sparse_tensor.SparseTensor(
                indices=sparse.indices,
                values=values,
                dense_shape=sparse.dense_shape))
      sparse_features = (sparse_features, tuple(weights))

    dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _create_high_dimensional_ragged_dataset(self,
                                              strategy,
                                              include_weights=False,
                                              weight=0.5):
    ragged_features = (
        ragged_tensor.RaggedTensor.from_row_lengths(
            row_lengths=self.feature_watched_row_lengths_high_dimensional,
            values=self.feature_watched_values_high_dimensional),
        ragged_tensor.RaggedTensor.from_row_lengths(
            row_lengths=self.feature_favorited_row_lengths_high_dimensional,
            values=self.feature_favorited_values_high_dimensional),
        ragged_tensor.RaggedTensor.from_row_lengths(
            row_lengths=self.feature_friends_row_lengths_high_dimensional,
            values=self.feature_friends_values_high_dimensional))
    if include_weights:
      weights = []
      for ragged in ragged_features:
        values = (
            array_ops.ones_like(ragged.values, dtype=dtypes.float32) * weight)
        weights.append(
            ragged_tensor.RaggedTensor.from_row_lengths(
                row_lengths=ragged.row_lengths(), values=values))
      ragged_features = (ragged_features, tuple(weights))

    dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
    # Create dataset for enqueue operation
    sparse_features = self._create_sparse_data(include_weights, weight)
    ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
                                         sparse_features)

    dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)

    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _create_dense_dataset(self, strategy, include_weights=False, weight=0.5):

    features = (constant_op.constant(
        self.feature_watched_values[:self.data_batch_size], dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_favorited_values[:self.data_batch_size],
                    dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_friends_values[:self.data_batch_size],
                    dtype=dtypes.int32))
    if include_weights:
      weights = [
          array_ops.ones_like(t, dtype=dtypes.float32) * weight
          for t in features
      ]
      features = (features, tuple(weights))

    dataset = dataset_ops.DatasetV2.from_tensors(features)
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _create_high_dimensional_dense_dataset(self,
                                             strategy,
                                             include_weights=False,
                                             weight=0.5):

    dense_size = self.data_batch_size * self.data_batch_size
    features = (constant_op.constant(
        self.feature_watched_values_high_dimensional[:dense_size],
        shape=(self.data_batch_size, self.data_batch_size, 1),
        dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_favorited_values_high_dimensional[:dense_size],
                    shape=(self.data_batch_size, self.data_batch_size, 1),
                    dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_friends_values_high_dimensional[:dense_size],
                    shape=(self.data_batch_size, self.data_batch_size, 1),
                    dtype=dtypes.int32))
    if include_weights:
      weights = [
          array_ops.ones_like(t, dtype=dtypes.float32) * weight
          for t in features
      ]
      features = (features, tuple(weights))
    dataset = dataset_ops.DatasetV2.from_tensors(features)
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _check_results(self, strategy, shard_out_val, training, input_data,
                     table_to_variable, optimizer, is_high_dimensional):
    num_replicas = strategy.num_replicas_in_sync

    # Unpack the values `strategy.run()` returns.
    loss = self._unpack(strategy, shard_out_val[0])
    activation_watched = self._unpack(strategy, shard_out_val[1])
    activation_favorited = self._unpack(strategy, shard_out_val[2])
    activation_friends = self._unpack(strategy, shard_out_val[3])

    # Core 0:
    # Calculate the values of embedding activations.
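    # With batch_size=2, core 0 sees rows 0 and 1 of each feature above.
    # watched (sum over the video table): row 0 has id 0 -> [0, 1, 2, 3];
    # row 1 has ids 0 and 1 -> [4, 6, 8, 10].
    # favorited (same table): row 0 has ids 0 and 1 -> [4, 6, 8, 10];
    # row 1 has id 1 -> [4, 5, 6, 7].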
    activation_watched_gold0 = np.array([[0, 1, 2, 3], [4, 6, 8, 10]])
    activation_favorited_gold0 = np.array([[4, 6, 8, 10], [4, 5, 6, 7]])
    # Second row of `activation_friends_gold0` is the mean of the following.
    # row 0: 0 1
    # row 1: 2 3
    # row 2: 4 5
    activation_friends_gold0 = np.array([[6, 7], [2, 3]])

    loss_gold0 = self._compute_loss(activation_watched_gold0,
                                    activation_favorited_gold0,
                                    activation_friends_gold0)

    # Add on values from other cores:
    # Activations for watched are an alternating sequence of
    # activation_watched_gold0 and activation_favorited_gold0.
    # For favorited it is the same but in the opposite order.
    activation_watched_gold = np.concatenate(
        (activation_watched_gold0, activation_favorited_gold0))
    activation_favorited_gold = np.concatenate(
        (activation_favorited_gold0, activation_watched_gold0))
    activation_friends_gold = np.concatenate(
        (activation_friends_gold0, activation_friends_gold0))

    if is_high_dimensional:
      activation_watched_gold = np.stack([activation_watched_gold] *
                                         self.batch_size * num_replicas)

      activation_favorited_gold = np.stack([activation_favorited_gold] *
                                           self.batch_size * num_replicas)

      activation_friends_gold = np.stack([activation_friends_gold] *
                                         self.batch_size * num_replicas)
    else:
      if num_replicas == 1:
        activation_watched_gold = activation_watched_gold0
        activation_favorited_gold = activation_favorited_gold0
        activation_friends_gold = activation_friends_gold0
      else:
        activation_watched_gold = np.concatenate(
            [activation_watched_gold] * (num_replicas // self.batch_size))
        activation_favorited_gold = np.concatenate(
            [activation_favorited_gold] * (num_replicas // self.batch_size))
        activation_friends_gold = np.concatenate(
            [activation_friends_gold] * (num_replicas // self.batch_size))

    loss_gold = [loss_gold0] * num_replicas

    # Test values.
    self.assertAllClose(activation_watched_gold, activation_watched)
    self.assertAllClose(activation_favorited_gold, activation_favorited)
    self.assertAllClose(activation_friends_gold, activation_friends)

    self.assertAllClose(loss_gold, loss)

    embedding_table_video_before = np.copy(
        np.reshape(self.embedding_values, [8, 4]))
    embedding_table_user_before = np.copy(
        np.reshape(self.embedding_values, [16, 2]))
    if is_high_dimensional:
      global_batch_size = self.batch_size * self.data_batch_size * num_replicas
    else:
      global_batch_size = self.batch_size * num_replicas
    if training:
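      # Each feature's loss is the mean over the global batch of the sum of
      # squared activations, so d(loss)/d(activation) = 2 * activation /
      # global_batch_size.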
      gradient_wrt_watched_gold = (2 * activation_watched_gold /
                                   global_batch_size)
      gradient_wrt_favorited_gold = (2 * activation_favorited_gold /
                                     global_batch_size)
      gradient_wrt_friends_gold = (2 * activation_friends_gold /
                                   global_batch_size)

      # Calculate gradients wrt embedding tables.
      gradients_wrt_user = (
          self._compute_gradients_wrt_embedding_table(
              gradient_wrt_friends_gold, embedding_table_user_before,
              input_data[2].indices.numpy(), input_data[2].values.numpy(),
              self.table_user.combiner))
      gradients_wrt_video = (
          self._compute_gradients_wrt_embedding_table(
              gradient_wrt_favorited_gold, embedding_table_video_before,
              input_data[1].indices.numpy(), input_data[1].values.numpy(),
              self.table_video.combiner) +
          self._compute_gradients_wrt_embedding_table(
              gradient_wrt_watched_gold, embedding_table_video_before,
              input_data[0].indices.numpy(), input_data[0].values.numpy(),
              self.table_video.combiner))

      self._check_embedding_and_slot_variables(embedding_table_user_before,
                                               gradients_wrt_user,
                                               embedding_table_video_before,
                                               gradients_wrt_video, optimizer,
                                               table_to_variable)

  def _check_embedding_and_slot_variables(self, embedding_table_user_before,
                                          gradients_wrt_user,
                                          embedding_table_video_before,
                                          gradients_wrt_video, optimizer,
                                          table_to_variable):
    if isinstance(optimizer, tpu_embedding_v2_utils.SGD):
      check_fn = self._check_embedding_and_slot_variables_for_sgd
    elif isinstance(optimizer, tpu_embedding_v2_utils.Adagrad):
      check_fn = self._check_embedding_and_slot_variables_for_adagrad
    elif isinstance(optimizer, tpu_embedding_v2_utils.AdagradMomentum):
      check_fn = self._check_embedding_and_slot_variables_for_adagrad_momentum
    elif isinstance(optimizer, tpu_embedding_v2_utils.Adam):
      check_fn = self._check_embedding_and_slot_variables_for_adam
    elif isinstance(optimizer, tpu_embedding_v2_utils.FTRL):
      check_fn = self._check_embedding_and_slot_variables_for_ftrl
    else:
      raise ValueError(
          'optimizer is not recognized: {}'.format(type(optimizer)))
    check_fn(embedding_table_user_before, gradients_wrt_user, optimizer,
             table_to_variable[self.table_user.name])
    check_fn(embedding_table_video_before, gradients_wrt_video, optimizer,
             table_to_variable[self.table_video.name])

  def _check_embedding_and_slot_variables_for_sgd(self, embedding_table_before,
                                                  gradients, optimizer,
                                                  variables):
    embedding_table = np.copy(embedding_table_before)
    embedding_table -= optimizer.learning_rate * np.sum(gradients, axis=0)
    self.assertAllClose(
        self._get_variable(variables['parameters']).numpy(), embedding_table)

  def _check_embedding_and_slot_variables_for_adagrad(self,
                                                      embedding_table_before,
                                                      gradients, optimizer,
                                                      variable):
    embedding_table = np.copy(embedding_table_before)
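    # One Adagrad step starting from the initial accumulator value, where g is
    # the gradient summed over all batch positions:
    #   accumulator = initial_accumulator_value + g**2
    #   embedding_table -= learning_rate * g / sqrt(accumulator)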
    accumulator = (
        optimizer.initial_accumulator_value + np.sum(gradients, axis=0)**2)
    embedding_table -= (
        optimizer.learning_rate * np.sum(gradients, axis=0) /
        np.sqrt(accumulator))
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(), embedding_table)
    self.assertAllClose(
        self._get_variable(variable['accumulators']).numpy(), accumulator)

  def _check_embedding_and_slot_variables_for_adagrad_momentum(
      self, embedding_table_before, gradients, optimizer, variable):
    embedding_table = np.copy(embedding_table_before)
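    # One AdagradMomentum step from zero-initialized slots:
    #   accumulator = beta2 * accumulator + (1 - beta2) * g**2
    #                 (or accumulator += g**2 when beta2 == 1)
    #   momenta = momentum * momenta
    #             + g * (accumulator + epsilon)**(-1 / exponent)
    #   update = momenta, plus a Nesterov correction when use_nesterov is True
    #   embedding_table -= learning_rate * update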
    accumulator = np.zeros(self._get_variable(variable['accumulators']).shape)
    momenta = np.zeros(self._get_variable(variable['momenta']).shape)
    gradients = np.sum(gradients, axis=0)
    if optimizer.beta2 == 1.0:
      accumulator += gradients**2
    else:
      accumulator = optimizer.beta2 * accumulator + (
          1 - optimizer.beta2) * gradients**2
    accumulator_power = np.power(accumulator + optimizer.epsilon,
                                 -1.0 / optimizer.exponent)
    momenta = optimizer.momentum * momenta + gradients * accumulator_power
    if optimizer.use_nesterov:
      update = optimizer.momentum * momenta + gradients * accumulator_power
    else:
      update = momenta
    embedding_table -= optimizer.learning_rate * update
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(),
        embedding_table,
        rtol=1e-3)
    self.assertAllClose(
        self._get_variable(variable['accumulators']).numpy(),
        accumulator,
        rtol=1e-3)
    self.assertAllClose(
        self._get_variable(variable['momenta']).numpy(), momenta, rtol=1e-3)

  def _check_embedding_and_slot_variables_for_adam(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variable):
    embedding_table = np.copy(embedding_table_before)
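    # One Adam step from zero-initialized slots, as computed below (no bias
    # correction is applied in this reference):
    #   m = (1 - beta_1) * g
    #   v = (1 - beta_2) * g**2
    #   embedding_table -= learning_rate * m / (sqrt(v) + epsilon)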
    g = np.sum(gradients, axis=0)
    v = g**2 * (1 - optimizer.beta_2)
    m = g * (1 - optimizer.beta_1)
    epsilon = optimizer.epsilon
    # TPU Embeddings don't have the LR decay factor for Adam.
    lr_modifier = 1
    embedding_table -= (
        m * optimizer.learning_rate * lr_modifier / (np.sqrt(v) + epsilon))
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(),
        embedding_table,
        rtol=1e-4)
    self.assertAllClose(
        self._get_variable(variable['momenta']).numpy(), m, rtol=1e-4)
    self.assertAllClose(
        self._get_variable(variable['velocities']).numpy(), v, rtol=1e-4)

  def _check_embedding_and_slot_variables_for_ftrl(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variable):
    embedding_table = np.copy(embedding_table_before)
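    # Closed-form result of a single FTRL step starting from the initial
    # accumulator and zero linear slots, with p = -learning_rate_power:
    #   accumulator = initial_accumulator_value + g**2
    #   sigma = (accumulator**p - initial_accumulator_value**p) / learning_rate
    #   linear = g - sigma * initial_parameters
    #   parameters = -linear / (accumulator**p / learning_rate)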
    neg_lr_p = -optimizer.learning_rate_power
    accumulator = (
        optimizer.initial_accumulator_value + np.sum(gradients, axis=0)**2)
    sigma = (accumulator**neg_lr_p - optimizer.initial_accumulator_value**
             neg_lr_p) / optimizer.learning_rate
    linear = np.sum(gradients, axis=0) - sigma * embedding_table
    quadratic = accumulator**neg_lr_p / optimizer.learning_rate
    embedding_table = -linear / quadratic
    actual_parameters = self._get_variable(variable['parameters']).numpy()
    # Entries where `linear` == 0 were never updated, so the actual parameters
    # still hold their initial values and would not match the closed-form
    # solution above; mask them out before comparing.
    actual_parameters *= (linear != 0.0)
    # FTRL shows a slightly larger numerical difference on the parameters.
    self.assertAllClose(actual_parameters, embedding_table, rtol=5e-5)
    self.assertAllClose(
        self._get_variable(variable['linears']).numpy(), linear, rtol=5e-4)
    self.assertAllClose(
        self._get_variable(variable['accumulators']).numpy(), accumulator)

  def _get_replica_numpy(self, structured, strategy, replica_id):

    def select_replica(x):
      x = strategy.experimental_local_results(x)
      if len(x) == 1:
        return x[0].numpy()
      return x[replica_id].numpy()

    return nest.map_structure(select_replica, structured)

  def _compute_gradients_wrt_embedding_table(self, gradient_wrt_activation,
                                             embedding_table, feature_indices,
                                             feature_values, combiner):
    """Compute gradients wrt embedding_table.

    Args:
      gradient_wrt_activation: `np.array` with shape `batch_size` by embedding
        `dimension`.
      embedding_table: `np.array` with shape `vocabulary_size` by embedding
        `dimension`.
      feature_indices: `indices` as used to construct `SparseTensor`.
      feature_values: `values` as used to construct `SparseTensor`.
      combiner: `String`, 'mean' or 'sum'.

    Returns:
      Gradients wrt `embedding_table`, an `np.array` with shape
        `batch_size` by `vocabulary_size` by embedding `dimension`.

    Raises:
      ValueError: if `combiner` is not one of 'mean' or 'sum'.
    """
    if combiner not in ('mean', 'sum'):
      raise ValueError(
          '`combiner` must be mean or sum; got {}.'.format(combiner))
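    # Accumulate one gradient slab of shape `embedding_table.shape` for every
    # batch position, i.e. for all leading dimensions of the activation
    # gradient.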
    grads_shape = gradient_wrt_activation.shape[:-1] + embedding_table.shape
    grads = np.zeros(shape=grads_shape)
    count = np.zeros(shape=grads_shape)
    for feature_index, vocabulary_id in zip(feature_indices, feature_values):
      batch_index = tuple(feature_index[:-1])
      grads[batch_index][vocabulary_id] += gradient_wrt_activation[batch_index]
      count[batch_index] += 1
    count[count == 0] = 1
    if combiner == 'mean':
      grads = grads / count
    return np.reshape(grads, (-1, *embedding_table.shape))

  def _unpack(self, strategy, per_replica_output):
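    # Gather the per-replica outputs and concatenate them along the batch
    # dimension into a single numpy array.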
    per_replica_output = strategy.experimental_local_results(per_replica_output)
    per_replica_output = array_ops.concat(per_replica_output, axis=0).numpy()
    return per_replica_output

  def _get_total_loss_tensor(self, activations):
    losses = []
    for activation in activations:
      losses.append(
          math_ops.reduce_mean(
              math_ops.reduce_sum(
                  gen_math_ops.squared_difference(activation, 0), axis=-1)))
    total_loss = array_ops.expand_dims_v2(sum(losses), 0)
    return total_loss

  def _compute_loss(self, activation_watched, activation_favorited,
                    activation_friends):
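    # Numpy reference for `_get_total_loss_tensor`: per feature, the mean over
    # the batch of the sum of squared activations; the three feature losses
    # are added together.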
    watched_loss = np.mean(np.sum(activation_watched**2, axis=-1))
    favorited_loss = np.mean(np.sum(activation_favorited**2, axis=-1))
    friends_loss = np.mean(np.sum(activation_friends**2, axis=-1))
    loss = watched_loss + favorited_loss + friends_loss
    return loss

  def _get_variable(self, variable):
    if isinstance(variable, tpu_embedding_v2.TPUEmbeddingVariable):
      return variable.variables[0]
    return variable

  def _get_tmpdir(self, name, subdir=''):
    segments = [FLAGS.model_dir, name] + ([subdir] if subdir else [])
    return os.path.join(*segments)
