1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base Class for TPU Embedding tests.""" 16 17import os 18 19from absl import flags 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python.data.ops import dataset_ops 24from tensorflow.python.distribute import tpu_strategy 25from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver 26from tensorflow.python.eager import remote 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import sparse_tensor 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gen_math_ops 32from tensorflow.python.ops import init_ops_v2 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops.ragged import ragged_tensor 35from tensorflow.python.platform import test 36from tensorflow.python.tpu import tpu_embedding_v2 37from tensorflow.python.tpu import tpu_embedding_v2_utils 38from tensorflow.python.tpu import tpu_strategy_util 39from tensorflow.python.util import nest 40 41FLAGS = flags.FLAGS 42flags.DEFINE_string('tpu', '', 'Name of TPU to connect to.') 43flags.DEFINE_string('project', None, 'Name of GCP project with TPU.') 44flags.DEFINE_string('zone', None, 'Name of GCP zone with TPU.') 
flags.DEFINE_string('model_dir', os.environ.get('TEST_TMPDIR'),
                    'A temporary directory.')


class TPUEmbeddingBaseTest(parameterized.TestCase, test.TestCase):
  """Base class for TPU embedding mid level API tests.

  Provides shared fixtures (table/feature configs plus sparse, ragged and
  dense input data in both 2D and high dimensional form), helpers to build a
  `TPUStrategy` and a `TPUEmbedding` mid level API object, and numpy
  reference computations used to validate embedding activations, gradients
  and optimizer slot variables.
  """

  def skip_if_oss(self):
    """Skips the current test when targeting a cloud TPU (OSS run)."""
    if FLAGS.project is not None or FLAGS.zone is not None:
      self.skipTest(
          'Skipping tests for oss as it is slow to run every test in cloud tpu.'
      )

  def setUp(self):
    super(TPUEmbeddingBaseTest, self).setUp()
    self.embedding_values = np.array(list(range(32)), dtype=np.float64)
    self.initializer = init_ops_v2.Constant(self.embedding_values)
    # Embedding for video initialized to
    # 0 1 2 3
    # 4 5 6 7
    # ...
    self.table_video = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=8,
        dim=4,
        initializer=self.initializer,
        combiner='sum',
        name='video')
    # Embedding for user initialized to
    # 0 1
    # 2 3
    # 4 5
    # 6 7
    # ...
    self.table_user = tpu_embedding_v2_utils.TableConfig(
        vocabulary_size=16,
        dim=2,
        initializer=self.initializer,
        combiner='mean',
        name='user')
    self.feature_config = (tpu_embedding_v2_utils.FeatureConfig(
        table=self.table_video, name='watched'),
                           tpu_embedding_v2_utils.FeatureConfig(
                               table=self.table_video, name='favorited'),
                           tpu_embedding_v2_utils.FeatureConfig(
                               table=self.table_user, name='friends'))

    self.batch_size = 2
    self.data_batch_size = 4

    # One (global) batch of inputs
    # sparse tensor for watched:
    # row 0: 0
    # row 1: 0, 1
    # row 2: 0, 1
    # row 3: 1
    self.feature_watched_indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1],
                                    [3, 0]]
    self.feature_watched_values = [0, 0, 1, 0, 1, 1]
    self.feature_watched_row_lengths = [1, 2, 2, 1]
    # sparse tensor for favorited:
    # row 0: 0, 1
    # row 1: 1
    # row 2: 0
    # row 3: 0, 1
    self.feature_favorited_indices = [[0, 0], [0, 1], [1, 0], [2, 0], [3, 0],
                                      [3, 1]]
    self.feature_favorited_values = [0, 1, 1, 0, 0, 1]
    self.feature_favorited_row_lengths = [2, 1, 1, 2]
    # sparse tensor for friends:
    # row 0: 3
    # row 1: 0, 1, 2
    # row 2: 3
    # row 3: 0, 1, 2
    self.feature_friends_indices = [[0, 0], [1, 0], [1, 1], [1, 2], [2, 0],
                                    [3, 0], [3, 1], [3, 2]]
    self.feature_friends_values = [3, 0, 1, 2, 3, 0, 1, 2]
    self.feature_friends_row_lengths = [1, 3, 1, 3]
    self.resolver = None

    # Expand the dims of the old 2D feature by 1 and repeat it
    # `data_batch_size` times along the new leading dimension.
    def create_high_dimensional_indices(indices):
      indices = np.array(indices, dtype=np.int32)
      batch_size_index = np.repeat(
          np.arange(self.data_batch_size), len(indices)).reshape(-1, 1)
      repeated_indices = np.tile(indices, (self.data_batch_size, 1))
      return np.concatenate([batch_size_index, repeated_indices], axis=1)

    # Create high dimensional features with shape(4, 4, 2)
    self.feature_watched_indices_high_dimensional = (
        create_high_dimensional_indices(self.feature_watched_indices))
    self.feature_watched_values_high_dimensional = (
        self.feature_watched_values * self.data_batch_size)
    self.feature_watched_row_lengths_high_dimensional = (
        self.feature_watched_row_lengths * self.data_batch_size)

    # Create high dimensional features with shape(4, 4, 2)
    self.feature_favorited_indices_high_dimensional = (
        create_high_dimensional_indices(self.feature_favorited_indices))
    self.feature_favorited_values_high_dimensional = (
        self.feature_favorited_values * self.data_batch_size)
    self.feature_favorited_row_lengths_high_dimensional = (
        self.feature_favorited_row_lengths * self.data_batch_size)

    # Create high dimensional features with shape(4, 4, 3)
    self.feature_friends_indices_high_dimensional = (
        create_high_dimensional_indices(self.feature_friends_indices))
    self.feature_friends_values_high_dimensional = (
        self.feature_friends_values * self.data_batch_size)
    self.feature_friends_row_lengths_high_dimensional = (
        self.feature_friends_row_lengths * self.data_batch_size)

  def _get_strategy(self):
    """Connects to the TPU cluster and returns a `TPUStrategy` for it."""
    self.resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
    if hasattr(self.resolver, '_cloud_tpu_client'):
      self.resolver._cloud_tpu_client.configure_tpu_version(
          version='nightly', restart_type='always')
    remote.connect_to_cluster(self.resolver)
    tpu_strategy_util.initialize_tpu_system(self.resolver)
    return tpu_strategy.TPUStrategy(self.resolver)

  def _create_mid_level(self, optimizer=None):
    """Creates a `TPUEmbedding` object; defaults to an SGD optimizer."""
    if optimizer is None:
      optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)

    return tpu_embedding_v2.TPUEmbedding(
        feature_config=self.feature_config, optimizer=optimizer)

  def _create_strategy_and_mid_level(self, optimizer_name):
    """Creates a strategy and mid level API with the named optimizer.

    Args:
      optimizer_name: One of 'sgd', 'adagrad', 'adam', 'ftrl' or
        'adagrad_momentum'.

    Returns:
      A `(strategy, mid_level_api, optimizer)` tuple.

    Raises:
      ValueError: if `optimizer_name` is not recognized.
    """
    strategy = self._get_strategy()

    with strategy.scope():
      if optimizer_name == 'sgd':
        optimizer = tpu_embedding_v2_utils.SGD(learning_rate=0.1)
      elif optimizer_name == 'adagrad':
        optimizer = tpu_embedding_v2_utils.Adagrad(learning_rate=0.1)
      elif optimizer_name == 'adam':
        optimizer = tpu_embedding_v2_utils.Adam(learning_rate=0.1)
      elif optimizer_name == 'ftrl':
        optimizer = tpu_embedding_v2_utils.FTRL(learning_rate=0.1)
      elif optimizer_name == 'adagrad_momentum':
        optimizer = tpu_embedding_v2_utils.AdagradMomentum(
            learning_rate=0.1,
            momentum=0.9,
            use_nesterov=True,
            exponent=3.0,
            epsilon=0.1,
            beta2=0.9)
      else:
        # Format the name into the message rather than passing a tuple of
        # args to ValueError, which yields a malformed error string.
        raise ValueError(
            'optimizer is not recognized: {}'.format(optimizer_name))
      mid_level_api = self._create_mid_level(optimizer=optimizer)

    return strategy, mid_level_api, optimizer

  def _rebatch_to_global(self, dataset, strategy):
    """Rebatches a `data_batch_size`-batched dataset to the global batch size."""
    return dataset.unbatch().repeat().batch(
        self.batch_size * strategy.num_replicas_in_sync, drop_remainder=True)

  def _create_sparse_data(self, include_weights, weight=0.5):
    """Builds the (watched, favorited, friends) `SparseTensor` inputs.

    Args:
      include_weights: If True, also builds per-id weight `SparseTensor`s
        mirroring the feature tensors.
      weight: Constant weight value used for every id.

    Returns:
      A tuple of three `SparseTensor`s, or `(features, weights)` when
      `include_weights` is True.
    """
    sparse_features = (sparse_tensor.SparseTensor(
        indices=self.feature_watched_indices,
        values=self.feature_watched_values,
        dense_shape=[self.data_batch_size, 2]),
                       sparse_tensor.SparseTensor(
                           indices=self.feature_favorited_indices,
                           values=self.feature_favorited_values,
                           dense_shape=[self.data_batch_size, 2]),
                       sparse_tensor.SparseTensor(
                           indices=self.feature_friends_indices,
                           values=self.feature_friends_values,
                           dense_shape=[self.data_batch_size, 3]))
    if include_weights:
      weights = []
      for sparse in sparse_features:
        values = (
            array_ops.ones_like(sparse.values, dtype=dtypes.float32) * weight)
        weights.append(
            sparse_tensor.SparseTensor(
                indices=sparse.indices,
                values=values,
                dense_shape=sparse.dense_shape))
      sparse_features = (sparse_features, tuple(weights))
    return sparse_features

  def _create_sparse_dataset(self, strategy, include_weights=False, weight=0.5):
    """Creates a sparse dataset for the enqueue operation."""
    sparse_features = self._create_sparse_data(include_weights, weight)

    dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)

    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return self._rebatch_to_global(dataset, strategy)

  def _create_high_dimensional_sparse_dataset(self,
                                              strategy,
                                              include_weights=False,
                                              weight=0.5):
    """Creates a dataset of rank-3 `SparseTensor`s for the enqueue op."""
    sparse_features = (
        sparse_tensor.SparseTensor(
            indices=self.feature_watched_indices_high_dimensional,
            values=self.feature_watched_values_high_dimensional,
            dense_shape=[self.data_batch_size, self.data_batch_size, 2]),
        sparse_tensor.SparseTensor(
            indices=self.feature_favorited_indices_high_dimensional,
            values=self.feature_favorited_values_high_dimensional,
            dense_shape=[self.data_batch_size, self.data_batch_size, 2]),
        sparse_tensor.SparseTensor(
            indices=self.feature_friends_indices_high_dimensional,
            values=self.feature_friends_values_high_dimensional,
            dense_shape=[self.data_batch_size, self.data_batch_size, 3]))
    if include_weights:
      weights = []
      for sparse in sparse_features:
        values = (
            array_ops.ones_like(sparse.values, dtype=dtypes.float32) * weight)
        weights.append(
            sparse_tensor.SparseTensor(
                indices=sparse.indices,
                values=values,
                dense_shape=sparse.dense_shape))
      sparse_features = (sparse_features, tuple(weights))

    dataset = dataset_ops.DatasetV2.from_tensors(sparse_features)
    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return self._rebatch_to_global(dataset, strategy)

  def _create_high_dimensional_ragged_dataset(self,
                                              strategy,
                                              include_weights=False,
                                              weight=0.5):
    """Creates a dataset of high dimensional `RaggedTensor`s."""
    ragged_features = (
        ragged_tensor.RaggedTensor.from_row_lengths(
            row_lengths=self.feature_watched_row_lengths_high_dimensional,
            values=self.feature_watched_values_high_dimensional),
        ragged_tensor.RaggedTensor.from_row_lengths(
            row_lengths=self.feature_favorited_row_lengths_high_dimensional,
            values=self.feature_favorited_values_high_dimensional),
        ragged_tensor.RaggedTensor.from_row_lengths(
            row_lengths=self.feature_friends_row_lengths_high_dimensional,
            values=self.feature_friends_values_high_dimensional))
    if include_weights:
      weights = []
      for ragged in ragged_features:
        values = (
            array_ops.ones_like(ragged.values, dtype=dtypes.float32) * weight)
        # Use the public factory: RaggedTensor's constructor is private and
        # raises when called directly.
        weights.append(
            ragged_tensor.RaggedTensor.from_row_lengths(
                values=values, row_lengths=ragged.row_lengths()))
      ragged_features = (ragged_features, tuple(weights))

    dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)
    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return self._rebatch_to_global(dataset, strategy)

  def _create_ragged_dataset(self, strategy, include_weights=False, weight=0.5):
    """Creates a ragged dataset for the enqueue operation."""
    sparse_features = self._create_sparse_data(include_weights, weight)
    ragged_features = nest.map_structure(ragged_tensor.RaggedTensor.from_sparse,
                                         sparse_features)

    dataset = dataset_ops.DatasetV2.from_tensors(ragged_features)

    # Data is batched to self.data_batch_size, rebatch to global batch size.
    return self._rebatch_to_global(dataset, strategy)

  def _create_dense_dataset(self, strategy, include_weights=False, weight=0.5):
    """Creates a dense dataset (one id per sample) for the enqueue op."""
    features = (constant_op.constant(
        self.feature_watched_values[:self.data_batch_size], dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_favorited_values[:self.data_batch_size],
                    dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_friends_values[:self.data_batch_size],
                    dtype=dtypes.int32))
    if include_weights:
      weights = [
          array_ops.ones_like(t, dtype=dtypes.float32) * weight
          for t in features
      ]
      features = (features, tuple(weights))

    dataset = dataset_ops.DatasetV2.from_tensors(features)
    return self._rebatch_to_global(dataset, strategy)

  def _create_high_dimensional_dense_dataset(self,
                                             strategy,
                                             include_weights=False,
                                             weight=0.5):
    """Creates a rank-3 dense dataset for the enqueue operation."""
    dense_size = self.data_batch_size * self.data_batch_size
    features = (constant_op.constant(
        self.feature_watched_values_high_dimensional[:dense_size],
        shape=(self.data_batch_size, self.data_batch_size, 1),
        dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_favorited_values_high_dimensional[:dense_size],
                    shape=(self.data_batch_size, self.data_batch_size, 1),
                    dtype=dtypes.int32),
                constant_op.constant(
                    self.feature_friends_values_high_dimensional[:dense_size],
                    shape=(self.data_batch_size, self.data_batch_size, 1),
                    dtype=dtypes.int32))
    if include_weights:
      weights = [
          array_ops.ones_like(t, dtype=dtypes.float32) * weight
          for t in features
      ]
      features = (features, tuple(weights))
    dataset = dataset_ops.DatasetV2.from_tensors(features)
    return self._rebatch_to_global(dataset, strategy)

  def _check_results(self, strategy, shard_out_val, training, input_data,
                     table_to_variable, optimizer, is_high_dimensional):
    """Checks activations, loss and (when training) variable updates.

    Args:
      strategy: The `TPUStrategy` used to run the step.
      shard_out_val: Per-replica `(loss, watched, favorited, friends)`
        outputs returned by `strategy.run()`.
      training: If True, also verifies embedding tables and slot variables
        against numpy-computed gradient updates.
      input_data: The `(watched, favorited, friends)` `SparseTensor` inputs
        used for the step.
      table_to_variable: Mapping from table name to its variables.
      optimizer: The embedding optimizer under test.
      is_high_dimensional: If True, expected activations are stacked to the
        high dimensional input shape.
    """
    num_replicas = strategy.num_replicas_in_sync

    # Unpack the values `strategy.run()` returns.
    loss = self._unpack(strategy, shard_out_val[0])
    activation_watched = self._unpack(strategy, shard_out_val[1])
    activation_favorited = self._unpack(strategy, shard_out_val[2])
    activation_friends = self._unpack(strategy, shard_out_val[3])

    # Core 0:
    # Calculate the values of embedding activations.
    activation_watched_gold0 = np.array([[0, 1, 2, 3], [4, 6, 8, 10]])
    activation_favorited_gold0 = np.array([[4, 6, 8, 10], [4, 5, 6, 7]])
    # Second row of `activation_friends_gold0` is the mean of the following.
    # row 0: 0 1
    # row 1: 2 3
    # row 2: 4 5
    activation_friends_gold0 = np.array([[6, 7], [2, 3]])

    loss_gold0 = self._compute_loss(activation_watched_gold0,
                                    activation_favorited_gold0,
                                    activation_friends_gold0)

    # Add on values from other cores:
    # Activations for watched are an alternating sequence of
    # activation_watched_gold0 and activation_favorited_gold0.
    # For favorited it is the same but in the opposite order.
    activation_watched_gold = np.concatenate(
        (activation_watched_gold0, activation_favorited_gold0))
    activation_favorited_gold = np.concatenate(
        (activation_favorited_gold0, activation_watched_gold0))
    activation_friends_gold = np.concatenate(
        (activation_friends_gold0, activation_friends_gold0))

    if is_high_dimensional:
      activation_watched_gold = np.stack([activation_watched_gold] *
                                         self.batch_size * num_replicas)

      activation_favorited_gold = np.stack([activation_favorited_gold] *
                                           self.batch_size * num_replicas)

      activation_friends_gold = np.stack([activation_friends_gold] *
                                         self.batch_size * num_replicas)
    else:
      if num_replicas == 1:
        activation_watched_gold = activation_watched_gold0
        activation_favorited_gold = activation_favorited_gold0
        activation_friends_gold = activation_friends_gold0
      else:
        activation_watched_gold = np.concatenate(
            [activation_watched_gold] * (num_replicas // self.batch_size))
        activation_favorited_gold = np.concatenate(
            [activation_favorited_gold] * (num_replicas // self.batch_size))
        activation_friends_gold = np.concatenate(
            [activation_friends_gold] * (num_replicas // self.batch_size))

    loss_gold = [loss_gold0] * num_replicas

    # Test values.
    self.assertAllClose(activation_watched_gold, activation_watched)
    self.assertAllClose(activation_favorited_gold, activation_favorited)
    self.assertAllClose(activation_friends_gold, activation_friends)

    self.assertAllClose(loss_gold, loss)

    embedding_table_video_before = np.copy(
        np.reshape(self.embedding_values, [8, 4]))
    embedding_table_user_before = np.copy(
        np.reshape(self.embedding_values, [16, 2]))
    if is_high_dimensional:
      global_batch_size = self.batch_size * self.data_batch_size * num_replicas
    else:
      global_batch_size = self.batch_size * num_replicas
    if training:
      gradient_wrt_watched_gold = (2 * activation_watched_gold /
                                   global_batch_size)
      gradient_wrt_favorited_gold = (2 * activation_favorited_gold /
                                     global_batch_size)
      gradient_wrt_friends_gold = (2 * activation_friends_gold /
                                   global_batch_size)

      # Calculate gradients wrt embedding tables.
      gradients_wrt_user = (
          self._compute_gradients_wrt_embedding_table(
              gradient_wrt_friends_gold, embedding_table_user_before,
              input_data[2].indices.numpy(), input_data[2].values.numpy(),
              self.table_user.combiner))
      gradients_wrt_video = (
          self._compute_gradients_wrt_embedding_table(
              gradient_wrt_favorited_gold, embedding_table_video_before,
              input_data[1].indices.numpy(), input_data[1].values.numpy(),
              self.table_video.combiner) +
          self._compute_gradients_wrt_embedding_table(
              gradient_wrt_watched_gold, embedding_table_video_before,
              input_data[0].indices.numpy(), input_data[0].values.numpy(),
              self.table_video.combiner))

      self._check_embedding_and_slot_variables(embedding_table_user_before,
                                               gradients_wrt_user,
                                               embedding_table_video_before,
                                               gradients_wrt_video, optimizer,
                                               table_to_variable)

  def _check_embedding_and_slot_variables(self, embedding_table_user_before,
                                          gradients_wrt_user,
                                          embedding_table_video_before,
                                          gradients_wrt_video, optimizer,
                                          table_to_variable):
    """Dispatches to the optimizer-specific embedding/slot variable check.

    Raises:
      ValueError: if `optimizer` is not a supported optimizer type.
    """
    if isinstance(optimizer, tpu_embedding_v2_utils.SGD):
      check_fn = self._check_embedding_and_slot_variables_for_sgd
    elif isinstance(optimizer, tpu_embedding_v2_utils.Adagrad):
      check_fn = self._check_embedding_and_slot_variables_for_adagrad
    elif isinstance(optimizer, tpu_embedding_v2_utils.AdagradMomentum):
      check_fn = self._check_embedding_and_slot_variables_for_adagrad_momentum
    elif isinstance(optimizer, tpu_embedding_v2_utils.Adam):
      check_fn = self._check_embedding_and_slot_variables_for_adam
    elif isinstance(optimizer, tpu_embedding_v2_utils.FTRL):
      check_fn = self._check_embedding_and_slot_variables_for_ftrl
    else:
      # Format the type into the message rather than passing a tuple of
      # args to ValueError.
      raise ValueError(
          'optimizer is not recognized: {}'.format(type(optimizer)))
    check_fn(embedding_table_user_before, gradients_wrt_user, optimizer,
             table_to_variable[self.table_user.name])
    check_fn(embedding_table_video_before, gradients_wrt_video, optimizer,
             table_to_variable[self.table_video.name])

  def _check_embedding_and_slot_variables_for_sgd(self, embedding_table_before,
                                                  gradients, optimizer,
                                                  variable):
    """Verifies an SGD update: table -= lr * sum(gradients)."""
    embedding_table = np.copy(embedding_table_before)
    embedding_table -= optimizer.learning_rate * np.sum(gradients, axis=0)
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(), embedding_table)

  def _check_embedding_and_slot_variables_for_adagrad(self,
                                                      embedding_table_before,
                                                      gradients, optimizer,
                                                      variable):
    """Verifies an Adagrad update of the table and accumulator slot."""
    embedding_table = np.copy(embedding_table_before)
    accumulator = (
        optimizer.initial_accumulator_value + np.sum(gradients, axis=0)**2)
    embedding_table -= (
        optimizer.learning_rate * np.sum(gradients, axis=0) /
        np.sqrt(accumulator))
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(), embedding_table)
    self.assertAllClose(
        self._get_variable(variable['accumulators']).numpy(), accumulator)

  def _check_embedding_and_slot_variables_for_adagrad_momentum(
      self, embedding_table_before, gradients, optimizer, variable):
    """Verifies an Adagrad-with-momentum update of the table and slots."""
    embedding_table = np.copy(embedding_table_before)
    accumulator = np.zeros(self._get_variable(variable['accumulators']).shape)
    momenta = np.zeros(self._get_variable(variable['momenta']).shape)
    gradients = np.sum(gradients, axis=0)
    if optimizer.beta2 == 1.0:
      accumulator += gradients**2
    else:
      accumulator = optimizer.beta2 * accumulator + (
          1 - optimizer.beta2) * gradients**2
    accumulator_power = np.power(accumulator + optimizer.epsilon,
                                 -1.0 / optimizer.exponent)
    momenta = optimizer.momentum * momenta + gradients * accumulator_power
    if optimizer.use_nesterov:
      update = optimizer.momentum * momenta + gradients * accumulator_power
    else:
      update = momenta
    embedding_table -= optimizer.learning_rate * update
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(),
        embedding_table,
        rtol=1e-3)
    self.assertAllClose(
        self._get_variable(variable['accumulators']).numpy(),
        accumulator,
        rtol=1e-3)
    self.assertAllClose(
        self._get_variable(variable['momenta']).numpy(), momenta, rtol=1e-3)

  def _check_embedding_and_slot_variables_for_adam(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variable):
    """Verifies a (single step) Adam update of the table and slots."""
    embedding_table = np.copy(embedding_table_before)
    g = np.sum(gradients, axis=0)
    v = g**2 * (1 - optimizer.beta_2)
    m = g * (1 - optimizer.beta_1)
    epsilon = optimizer.epsilon
    # TPU Embeddings don't have the LR decay factor for Adam.
    lr_modifier = 1
    embedding_table -= (
        m * optimizer.learning_rate * lr_modifier / (np.sqrt(v) + epsilon))
    self.assertAllClose(
        self._get_variable(variable['parameters']).numpy(),
        embedding_table,
        rtol=1e-4)
    self.assertAllClose(
        self._get_variable(variable['momenta']).numpy(), m, rtol=1e-4)
    self.assertAllClose(
        self._get_variable(variable['velocities']).numpy(), v, rtol=1e-4)

  def _check_embedding_and_slot_variables_for_ftrl(self, embedding_table_before,
                                                   gradients, optimizer,
                                                   variable):
    """Verifies an FTRL update of the table, linear and accumulator slots."""
    embedding_table = np.copy(embedding_table_before)
    neg_lr_p = -optimizer.learning_rate_power
    accumulator = (
        optimizer.initial_accumulator_value + np.sum(gradients, axis=0)**2)
    sigma = (accumulator**neg_lr_p - optimizer.initial_accumulator_value**
             neg_lr_p) / optimizer.learning_rate
    linear = np.sum(gradients, axis=0) - sigma * embedding_table
    quadratic = accumulator**neg_lr_p / optimizer.learning_rate
    embedding_table = -linear / quadratic
    actual_parameters = self._get_variable(variable['parameters']).numpy()
    # For entries where `linear` == 0, it is not worth comparing since the
    # initial values have not been touched yet and they will not agree with
    # what the actual values should be.
    actual_parameters *= (linear != 0.0)
    # FTRL has a bit more precision diff on parameters.
    self.assertAllClose(actual_parameters, embedding_table, rtol=5e-5)
    self.assertAllClose(
        self._get_variable(variable['linears']).numpy(), linear, rtol=5e-4)
    self.assertAllClose(
        self._get_variable(variable['accumulators']).numpy(), accumulator)

  def _get_replica_numpy(self, structured, strategy, replica_id):
    """Extracts one replica's values from a per-replica structure as numpy."""

    def select_replica(x):
      x = strategy.experimental_local_results(x)
      if len(x) == 1:
        # `experimental_local_results` returns a tuple; unwrap the single
        # element before converting to numpy (the tuple itself has no
        # `.numpy()` method).
        return x[0].numpy()
      return x[replica_id].numpy()

    return nest.map_structure(select_replica, structured)

  def _compute_gradients_wrt_embedding_table(self, gradient_wrt_activation,
                                             embedding_table, feature_indices,
                                             feature_values, combiner):
    """Compute gradients wrt embedding_table.

    Args:
      gradient_wrt_activation: `np.array` with shape `batch_size` by embedding
        `dimension`.
      embedding_table: `np.array` with shape `vocabulary_size` by embedding
        `dimension`.
      feature_indices: `indices` as used to construct `SparseTensor`.
      feature_values: `values` as used to construct `SparseTensor`.
      combiner: `String`, 'mean' or 'sum'.

    Returns:
      Gradients wrt `embedding_table`, an `np.array`s with shape
      `batch_size` by `vocabulary_size` by
      embedding `dimension`.

    Raises:
      ValueError: if `combiner` is not one of 'mean' or 'sum'.
    """
    if combiner not in ('mean', 'sum'):
      raise ValueError(
          '`combiner` must be mean or sum; got {}.'.format(combiner))
    grads_shape = gradient_wrt_activation.shape[:-1] + embedding_table.shape
    grads = np.zeros(shape=grads_shape)
    count = np.zeros(shape=grads_shape)
    for feature_index, vocabulary_id in zip(feature_indices, feature_values):
      # Drop the last (within-sample) coordinate to address the sample.
      batch_index = tuple(feature_index[:-1])
      grads[batch_index][vocabulary_id] += gradient_wrt_activation[batch_index]
      count[batch_index] += 1
    count[count == 0] = 1
    if combiner == 'mean':
      grads = grads / count
    return np.reshape(grads, (-1, *embedding_table.shape))

  def _unpack(self, strategy, per_replica_output):
    """Concatenates per-replica outputs into a single numpy array."""
    per_replica_output = strategy.experimental_local_results(per_replica_output)
    per_replica_output = array_ops.concat(per_replica_output, axis=0).numpy()
    return per_replica_output

  def _get_total_loss_tensor(self, activations):
    """Returns sum over activations of mean squared activation, shape [1]."""
    losses = []
    for activation in activations:
      losses.append(
          math_ops.reduce_mean(
              math_ops.reduce_sum(
                  gen_math_ops.squared_difference(activation, 0), axis=-1)))
    total_loss = array_ops.expand_dims_v2(sum(losses), 0)
    return total_loss

  def _compute_loss(self, activation_watched, activation_favorited,
                    activation_friends):
    """Numpy reference for `_get_total_loss_tensor`."""
    watched_loss = np.mean(np.sum(activation_watched**2, axis=-1))
    favorited_loss = np.mean(np.sum(activation_favorited**2, axis=-1))
    friends_loss = np.mean(np.sum(activation_friends**2, axis=-1))
    loss = watched_loss + favorited_loss + friends_loss
    return loss

  def _get_variable(self, variable):
    """Unwraps a `TPUEmbeddingVariable` to its first shard, if applicable."""
    if isinstance(variable, tpu_embedding_v2.TPUEmbeddingVariable):
      return variable.variables[0]
    return variable

  def _get_tmpdir(self, name, subdir=''):
    """Returns a path under `--model_dir` for the given test name/subdir."""
    segments = [FLAGS.model_dir, name] + ([subdir] if subdir else [])
    return os.path.join(*segments)