# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU embedding APIs."""

import collections
import copy
import math
import re
from typing import Optional

from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util.tf_export import tf_export

TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE


# TODO(shizhiw): a more future-proof way is to have optimization_parameter such
# as AdagradParameters etc instead of learning_rate.
class TableConfig(
    collections.namedtuple('TableConfig', [
        'vocabulary_size',
        'dimension',
        'initializer',
        'combiner',
        'hot_id_replication',
        'learning_rate',
        'learning_rate_fn',
        'optimization_parameters',
    ])):
  """Embedding table configuration."""

  def __new__(cls,
              vocabulary_size,
              dimension,
              initializer=None,
              combiner='mean',
              hot_id_replication=False,
              learning_rate=None,
              learning_rate_fn=None,
              optimization_parameters=None):
    """Embedding table configuration.

    Args:
      vocabulary_size: Number of vocabulary entries (rows) in the table.
      dimension: The embedding dimension.
      initializer: A variable initializer function to be used in embedding
        variable initialization. If not specified, defaults to
        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and
        standard deviation `1/sqrt(dimension)`.
      combiner: A string specifying how to reduce if there are multiple entries
        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
        supported, with 'mean' the default. 'sqrtn' often achieves good
        accuracy, in particular with bag-of-words columns. For more
        information, see `tf.nn.embedding_lookup_sparse`. None is only valid
        for dense rather than sparse tensors.
      hot_id_replication: If true, enables hot id replication, which can make
        embedding lookups faster if there are some hot rows in the table.
      learning_rate: float, static learning rate for this table. If
        learning_rate and learning_rate_fn are both `None`, the static
        learning rate as specified in the local `optimization_parameters` will
        be used. In case the local `optimization_parameters` is `None`, the
        global `optimization_parameters` in the `TPUEmbedding` constructor
        will be used. `learning_rate_fn` must be `None` if `learning_rate` is
        not `None`.
      learning_rate_fn: A function for a dynamic learning rate. The function
        will be passed the current global step and must return the learning
        rate to use. If learning_rate and learning_rate_fn are both `None`,
        the static learning rate as specified in `optimization_parameters` is
        used. `learning_rate` must be `None` if `learning_rate_fn` is not
        `None`.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Specifies a table level
        optimizer. If it's `None`, the global optimizer in the `TPUEmbedding`
        constructor is used.

    Returns:
      `TableConfig`.

    Raises:
      ValueError: if `vocabulary_size` is not a positive integer.
      ValueError: if `dimension` is not a positive integer.
      ValueError: if `initializer` is specified and is not callable.
      ValueError: if `combiner` is not supported.
      ValueError: if `learning_rate` and `learning_rate_fn` are both not
        `None`.
    """
    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
      raise ValueError(f'vocabulary_size must be >= 1. '
                       f'Received: {vocabulary_size}.')

    if not isinstance(dimension, int) or dimension < 1:
      raise ValueError(
          f'dimension must be a positive int. Received: {dimension}.')

    if (initializer is not None) and (not callable(initializer)):
      raise ValueError(f'initializer must be callable if specified. '
                       f'Received: {initializer}.')
    if initializer is None:
      initializer = init_ops.truncated_normal_initializer(
          mean=0.0, stddev=1 / math.sqrt(dimension))

    if combiner not in ('mean', 'sum', 'sqrtn', None):
      raise ValueError(f'combiner must be "mean", "sum", "sqrtn" or None. '
                       f'Received: {combiner}.')

    if learning_rate is not None and learning_rate_fn is not None:
      raise ValueError('At most one of learning_rate and learning_rate_fn '
                       'can be set. Received: {} and {}'.format(
                           learning_rate, learning_rate_fn))

    if optimization_parameters is not None:
      if not isinstance(optimization_parameters, _OptimizationParameters):
        raise ValueError(f'`optimization_parameters` must inherit from '
                         f'`_OptimizationParameters`. '
                         f'Received: `type(optimization_parameters)`='
                         f'{type(optimization_parameters)}.')

    return super().__new__(cls, vocabulary_size, dimension, initializer,
                           combiner, hot_id_replication, learning_rate,
                           learning_rate_fn, optimization_parameters)

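# A brief illustrative sketch (placeholder values, not part of this module's
# API surface): a `TableConfig` with an explicit initializer and hot id
# replication. The name `video_table_config` is hypothetical.
#
#   video_table_config = TableConfig(
#       vocabulary_size=100000,
#       dimension=64,
#       initializer=tf.compat.v1.truncated_normal_initializer(stddev=0.125),
#       combiner='sum',
#       hot_id_replication=True)
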
class FeatureConfig(
    collections.namedtuple('FeatureConfig',
                           ['table_id', 'max_sequence_length', 'weight_key'])):
  """Feature configuration."""

  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
    """Feature configuration.

    Args:
      table_id: Which table the feature uses for embedding lookups.
      max_sequence_length: If positive, the feature is a sequence feature with
        the corresponding maximum sequence length. If the sequence is longer
        than this, it will be truncated. If 0, the feature is not a sequence
        feature.
      weight_key: If using weights for the combiner, this key specifies which
        input feature contains the weights.

    Returns:
      `FeatureConfig`.

    Raises:
      ValueError: if `max_sequence_length` is not an integer or is negative.
    """
    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
      raise ValueError(f'max_sequence_length must be zero or a positive int, '
                       f'got {max_sequence_length}.')

    return super().__new__(cls, table_id, max_sequence_length, weight_key)


class EnqueueData(
    collections.namedtuple(
        'EnqueueData',
        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
  """Data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              sample_indices=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
        and int64 are allowed and will be converted to int32 internally.
      sample_indices: A rank 2 Tensor specifying the training example to which
        the corresponding embedding_indices and aggregation_weights values
        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
        If it is None, we assume each embedding_indices belongs to a different
        sample. Both int32 and int64 are allowed and will be converted to
        int32 internally.
      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
        None, we assume all weights are 1. Both float32 and float64 are allowed
        and will be converted to float32 internally.

    Returns:
      An EnqueueData tuple.
    """
    return super().__new__(cls, embedding_indices, sample_indices,
                           aggregation_weights)

  @staticmethod
  def from_sparse_tensor(sp_tensor, weights=None):
    return EnqueueData(
        sp_tensor.values,
        sp_tensor.indices,
        aggregation_weights=weights.values if weights is not None else None)


class RaggedEnqueueData(
    collections.namedtuple(
        'RaggedEnqueueData',
        ['embedding_indices', 'row_splits', 'aggregation_weights'])):
  """RaggedTensor data to be enqueued through generate_enqueue_ops()."""

  def __new__(cls,
              embedding_indices,
              row_splits=None,
              aggregation_weights=None):
    """Data to be enqueued through generate_enqueue_ops().

    Args:
      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
        corresponds to ids.values in embedding_lookup(), when ids is a
        RaggedTensor. Both int32 and int64 are allowed and will be converted to
        int32 internally.
      row_splits: A rank 1 Tensor specifying the break points for splitting
        embedding_indices and aggregation_weights. It corresponds to
        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
        int32 and int64 are allowed and will be converted to int32 internally.
      aggregation_weights: A rank 1 Tensor containing per training example
        aggregation weights. It corresponds to the values field of a
        RaggedTensor with the same row_splits as ids in embedding_lookup(),
        when ids is a RaggedTensor.

    Returns:
      A RaggedEnqueueData tuple.
    """
    return super().__new__(cls, embedding_indices, row_splits,
                           aggregation_weights)

  @staticmethod
  def from_ragged_tensor(rg_tensor, weights=None):
    return RaggedEnqueueData(
        rg_tensor.values,
        rg_tensor.row_splits,
        aggregation_weights=weights.values if weights is not None else None)


def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
  """Convenience function for generate_enqueue_ops().

  Args:
    sp_tensors_list: a list of dictionaries mapping from string of feature
      names to SparseTensor. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.

  Returns:
    enqueue_datas_list: a list of dictionaries mapping from string
      of feature names to EnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      in the list.
  """
  enqueue_datas_list = []
  for sp_tensors in sp_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, EnqueueData.from_sparse_tensor(v)) for k, v in sp_tensors.items())
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list


def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
  """Convenience function for generate_enqueue_ops().

  Args:
    rg_tensors_list: a list of dictionaries mapping from string of feature
      names to RaggedTensor. Each dictionary is for one TPU core. Dictionaries
      for the same host should be contiguous in the list.

  Returns:
    enqueue_datas_list: a list of dictionaries mapping from string
      of feature names to RaggedEnqueueData. Each dictionary is for one
      TPU core. Dictionaries for the same host should be contiguous
      in the list.
  """
  enqueue_datas_list = []
  for rg_tensors in rg_tensors_list:
    enqueue_datas = collections.OrderedDict(
        (k, RaggedEnqueueData.from_ragged_tensor(v))
        for k, v in rg_tensors.items())
    enqueue_datas_list.append(enqueue_datas)
  return enqueue_datas_list

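# A brief illustrative sketch of the enqueue path (hypothetical names such as
# `sp_tensors_per_core` and `embedding`; assumes `embedding` is a constructed
# `TPUEmbedding` and each dictionary holds one TPU core's SparseTensors,
# grouped by host):
#
#   sp_tensors_list = [sp_tensors_per_core[i] for i in range(num_cores)]
#   enqueue_datas_list = get_enqueue_datas_list_from_sparse_tensors_list(
#       sp_tensors_list)
#   enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
#
# Per-feature weights (see `FeatureConfig.weight_key`) can be attached by
# building the dictionaries manually with
# `EnqueueData.from_sparse_tensor(sp_ids, weights=sp_weights)` instead.
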
AdamSlotVariableNames = collections.namedtuple(
    'AdamSlotVariableNames', ['m', 'v'])

AdagradSlotVariableNames = collections.namedtuple(
    'AdagradSlotVariableNames', ['accumulator'])

MomentumSlotVariableNames = collections.namedtuple(
    'MomentumSlotVariableNames', ['momenta'])

AdagradMomentumSlotVariableNames = collections.namedtuple(
    'AdagradMomentumSlotVariableNames', ['accumulator', 'momenta'])

RMSPropSlotVariableNames = collections.namedtuple(
    'RMSPropSlotVariableNames', ['ms', 'mom'])

ProximalAdagradSlotVariableNames = collections.namedtuple(
    'ProximalAdagradSlotVariableNames', ['accumulator'])

FtrlSlotVariableNames = collections.namedtuple(
    'FtrlSlotVariableNames', ['accumulator', 'linear'])

ProximalYogiSlotVariableNames = collections.namedtuple(
    'ProximalYogiSlotVariableNames', ['v', 'm'])

FrequencyEstimatorSlotVariableNames = collections.namedtuple(
    'FrequencyEstimatorSlotVariableNames', ['last_hit_step'])

AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])

MomentumSlotVariables = collections.namedtuple(
    'MomentumSlotVariables', ['momenta'])

AdagradMomentumSlotVariables = collections.namedtuple(
    'AdagradMomentumSlotVariables', ['accumulator', 'momenta'])

RMSPropSlotVariables = collections.namedtuple(
    'RMSPropSlotVariables', ['ms', 'mom'])

AdagradSlotVariables = collections.namedtuple(
    'AdagradSlotVariables', ['accumulator'])

ProximalAdagradSlotVariables = collections.namedtuple(
    'ProximalAdagradSlotVariables', ['accumulator'])

FtrlSlotVariable = collections.namedtuple(
    'FtrlSlotVariable', ['accumulator', 'linear'])

ProximalYogiSlotVariables = collections.namedtuple(
    'ProximalYogiSlotVariables', ['v', 'm'])

FrequencyEstimatorSlotVariables = collections.namedtuple(
    'FrequencyEstimatorSlotVariables', ['last_hit_step'])

VariablesAndOps = collections.namedtuple('VariablesAndOps', [
    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
    'retrieve_ops'
])


class _OptimizationParameters:
  """Parameters common to all optimizations."""

  def __init__(
      self,
      learning_rate: float,
      use_gradient_accumulation: bool,
      clip_weight_min: Optional[float],
      clip_weight_max: Optional[float],
      weight_decay_factor: Optional[float],
      multiply_weight_decay_factor_by_learning_rate: Optional[bool],
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max
    self.weight_decay_factor = weight_decay_factor
    self.multiply_weight_decay_factor_by_learning_rate = (
        multiply_weight_decay_factor_by_learning_rate)
    self.clip_gradient_min = clip_gradient_min
    self.clip_gradient_max = clip_gradient_max

    if not use_gradient_accumulation and (clip_gradient_min is not None or
                                          clip_gradient_max is not None):
      raise ValueError('When using gradient clipping limits, gradient '
                       'accumulation must be enabled.')


@tf_export(v1=['tpu.experimental.AdagradParameters'])
class AdagradParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      initial_accumulator: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: used for updating the embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if initial_accumulator <= 0:
      raise ValueError(
          f'Adagrad initial_accumulator must be greater than zero. '
          f'Received: {initial_accumulator}.')
    self.initial_accumulator = initial_accumulator


class AdagradMomentumParameters(_OptimizationParameters):
  """Optimization parameters for Adagrad + Momentum with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdagradMomentumParameters(
              0.1, momentum=0.9),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      momentum: float,
      use_nesterov: bool = False,
      exponent: float = 2,
      beta2: float = 1,
      epsilon: float = 1e-10,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adagrad + Momentum.

    Args:
      learning_rate: used for updating the embedding table.
      momentum: Moving average parameter for the momentum accumulator.
      use_nesterov: Whether to use the Nesterov variant of momentum. See
        Sutskever et al., 2013.
      exponent: Exponent for the Adagrad accumulator.
      beta2: Moving average parameter for the Adagrad accumulator.
      epsilon: A small constant for numerical stability in the Adagrad
        accumulator.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if epsilon <= 0:
      raise ValueError('Adagrad momentum: epsilon must be positive')
    if exponent <= 0:
      raise ValueError('Adagrad momentum: precondition exponent must be > 0')
    self.momentum = momentum
    self.use_nesterov = use_nesterov
    self.exponent = exponent
    self.beta2 = beta2
    self.epsilon = epsilon


class ProximalAdagradParameters(_OptimizationParameters):
  """Optimization parameters for ProximalAdagrad with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
  """

  def __init__(
      self,
      learning_rate: float,
      initial_accumulator: float = 0.1,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for ProximalAdagrad.

    Args:
      learning_rate: used for updating the embedding table.
      initial_accumulator: initial accumulator for Adagrad.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if initial_accumulator <= 0:
      raise ValueError(f'Adagrad initial_accumulator must be positive. '
                       f'Received: {initial_accumulator}.')
    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.initial_accumulator = initial_accumulator
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength


@tf_export(v1=['tpu.experimental.AdamParameters'])
class AdamParameters(_OptimizationParameters):
  """Optimization parameters for Adam with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      beta1: float = 0.9,
      beta2: float = 0.999,
      epsilon: float = 1e-08,
      lazy_adam: bool = True,
      sum_inside_sqrt: bool = True,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Adam.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
        `optimization_parameters.proto` for details.
      sum_inside_sqrt: This improves training speed. Please see
        `optimization_parameters.proto` for details.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          'When disabling Lazy Adam, gradient accumulation must be used.')

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt


@tf_export(v1=['tpu.experimental.FtrlParameters'])
class FtrlParameters(_OptimizationParameters):
  """Optimization parameters for Ftrl with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      learning_rate_power: float = -0.5,
      initial_accumulator_value: float = 0.1,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      multiply_linear_by_learning_rate: bool = False,
      beta: float = 0,
      allow_zero_accumulator: bool = False,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Ftrl.

    Implements FTRL as described in the following [paper](
    https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)

    Args:
      learning_rate: a floating point value. The learning rate.
      learning_rate_power: A float value, must be less than or equal to zero.
        Controls how the learning rate decreases during training. Use zero for
        a fixed learning rate. See section 3.1 in the
        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      multiply_linear_by_learning_rate: When true, multiplies the usages of the
        linear slot in the weight update by the learning rate. This is useful
        when ramping up learning rate from 0 (which would normally produce
        NaNs).
      beta: The beta parameter for FTRL.
      allow_zero_accumulator: Changes the implementation of the square root to
        allow for the case of initial_accumulator_value being zero. This will
        cause a slight performance drop.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if learning_rate_power > 0.:
      raise ValueError('learning_rate_power must be less than or equal to 0. '
                       'got {}.'.format(learning_rate_power))

    if initial_accumulator_value < 0.:
      raise ValueError('initial_accumulator_value must be greater than or '
                       'equal to 0. got {}.'.format(initial_accumulator_value))

    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))

    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.learning_rate_power = learning_rate_power
    self.initial_accumulator_value = initial_accumulator_value
    self.initial_linear_value = 0.0
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
    self.beta = beta
    self.allow_zero_accumulator = allow_zero_accumulator


class ProximalYogiParameters(_OptimizationParameters):
  # pylint: disable=line-too-long
  """Optimization parameters for Proximal Yogi with TPU embeddings.

  Implements the Yogi optimizer as described in
  [Adaptive Methods for Nonconvex
  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.
  """

  # pylint: enable=line-too-long

  def __init__(
      self,
      learning_rate: float = 0.01,
      beta1: float = 0.9,
      beta2: float = 0.999,
      epsilon: float = 1e-3,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      initial_accumulator_value: float = 1e-6,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for Proximal Yogi.

    Args:
      learning_rate: a floating point value. The learning rate.
      beta1: A float value. The exponential decay rate for the 1st moment
        estimates.
      beta2: A float value. The exponential decay rate for the 2nd moment
        estimates.
      epsilon: A small constant for numerical stability.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      initial_accumulator_value: The starting value for accumulators. Only zero
        or positive values are allowed.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    if beta1 < 0. or beta1 >= 1.:
      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
    if beta2 < 0. or beta2 >= 1.:
      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
    if epsilon <= 0.:
      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
    if l1_regularization_strength < 0.:
      raise ValueError('l1_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l1_regularization_strength))
    if l2_regularization_strength < 0.:
      raise ValueError('l2_regularization_strength must be greater than or '
                       'equal to 0. got {}.'.format(l2_regularization_strength))

    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.initial_accumulator_value = initial_accumulator_value


class MomentumParameters(_OptimizationParameters):
  """Optimization parameters for Momentum with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.MomentumParameters(
              0.1, momentum=0.9),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      momentum: float,
      use_nesterov: bool = False,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for momentum.

    Args:
      learning_rate: a floating point value. The learning rate.
      momentum: a floating point value. The momentum.
      use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al.,
        2013). This implementation always computes gradients at the value of
        the variable(s) passed to the optimizer. Using Nesterov Momentum makes
        the variable(s) track the values called `theta_t + mu*v_t` in the
        paper. This implementation is an approximation of the original formula,
        valid for high values of momentum. It will compute the "adjusted
        gradient" in NAG by assuming that the new gradient will be estimated by
        the current average gradient plus the product of momentum and the
        change in the average gradient.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    self.momentum = momentum
    self.use_nesterov = use_nesterov


class RMSPropParameters(_OptimizationParameters):
  """Optimization parameters for RMSProp with TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=tf.tpu.experimental.RMSPropParameters(
              learning_rate=0.1, rho=0.9, momentum=0.9, epsilon=1e-7),
          ...))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      rho: float,
      momentum: float,
      epsilon: float,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for RMS prop.

    Args:
      learning_rate: a floating point value. The learning rate.
      rho: Discounting factor for the history/coming gradient.
      momentum: A float value. The momentum.
      epsilon: Small value to avoid zero denominator.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
        Gradient accumulation must be set to true if this is set.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
        Gradient accumulation must be set to true if this is set.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )
    self.rho = rho
    self.momentum = momentum
    self.epsilon = epsilon

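# A brief illustrative sketch (placeholder values): the gradient clipping,
# weight clipping, and weight decay options shown below are shared by all of
# the optimizer parameter classes in this module, since they are defined on
# `_OptimizationParameters`.
#
#   optimization_parameters = AdagradParameters(
#       learning_rate=0.1,
#       use_gradient_accumulation=True,
#       clip_gradient_min=-10.0,
#       clip_gradient_max=10.0,
#       weight_decay_factor=0.01,
#       multiply_weight_decay_factor_by_learning_rate=True)
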
@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
class StochasticGradientDescentParameters(_OptimizationParameters):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=(
              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
  ```

  """

  def __init__(
      self,
      learning_rate: float,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clip_gradient_min: Optional[float] = None,
      clip_gradient_max: Optional[float] = None,
  ):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: a floating point value. The learning rate.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster. Please see
        `optimization_parameters.proto` for details.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clip_gradient_min: the minimum value to clip by; None means -infinity.
      clip_gradient_max: the maximum value to clip by; None means +infinity.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clip_gradient_min=clip_gradient_min,
        clip_gradient_max=clip_gradient_max,
    )


class FrequencyEstimatorParameters(_OptimizationParameters):
  """Optimization parameters for Frequency Estimator TPU embeddings.

  This is a non-standard optimizer, which returns the estimated frequency of
  lookup for the feature passed to it. It should only be used on a table of
  width 1. The gradient fed back to the TPU embedding should always be zero.
  This can be accomplished by using `tf.stop_gradient` on the feature before
  using it.

  You must use the dynamic learning rate mechanism to set the 'learning rate'
  for this table to be a float32 cast of the global training step counter.

  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
  details on this optimizer.

  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
  `optimization_parameters` argument to set the optimizer and its parameters.
  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
  for more details.

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=FrequencyEstimatorParameters(...),
          ...))
  ```

  """

  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
               weight_exponent: float):
    """Optimization parameters for frequency estimator.

    Args:
      tau: Learning rate between (0, 1) that is used to update the array.
      max_delta: Maximum value of delta, the difference between the current
        global step and the last global step at which the row was sampled.
      outlier_threshold: Threshold used to determine whether the current update
        is an outlier.
      weight_exponent: The weight exponent used to transform the estimated
        delta into weights.
    """
    super().__init__(
        learning_rate=1.0,
        use_gradient_accumulation=True,
        clip_weight_min=None,
        clip_weight_max=None,
        weight_decay_factor=None,
        multiply_weight_decay_factor_by_learning_rate=None,
    )
    self.tau = tau
    self.max_delta = max_delta
    self.outlier_threshold = outlier_threshold
    self.weight_exponent = weight_exponent


DeviceConfig = collections.namedtuple('DeviceConfig',
                                      ['num_hosts', 'num_cores', 'job_name'])


class TPUEmbedding:
  """API for using TPU for embedding.

  Example:
  ```
  table_config_video = tpu_embedding.TableConfig(
      vocabulary_size=8, dimension=2,
      initializer=initializer, combiner='mean')
  table_config_user = tpu_embedding.TableConfig(
      vocabulary_size=4, dimension=2,
      initializer=initializer, combiner='mean')
  table_to_config_dict = {'video': table_config_video,
                          'user': table_config_user}
  feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
                            'favorited': tpu_embedding.FeatureConfig('video'),
                            'friends': tpu_embedding.FeatureConfig('user')}
  batch_size = 4
  optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
  mode = tpu_embedding.TRAINING
  embedding = tpu_embedding.TPUEmbedding(
      table_to_config_dict, feature_to_config_dict,
      batch_size, mode, master='',
      optimization_parameters=optimization_parameters)

  batch_size_per_core = embedding.batch_size_per_core
  sparse_features_list = []
  for host in embedding.hosts:
    with ops.device(host):
      for _ in range(embedding.num_cores_per_host):
        sparse_features = {}
        sparse_features['watched'] = sparse_tensor.SparseTensor(...)
        sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
        sparse_features['friends'] = sparse_tensor.SparseTensor(...)
        sparse_features_list.append(sparse_features)

  enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
  embedding_variables_and_ops = embedding.create_variables_and_ops()

  def computation():
    activations = embedding.get_activations()
    loss = compute_loss(activations)

    base_optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=1)
    cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
        base_optimizer)

    train_op = cross_shard_optimizer.minimize(loss)
    gradients = (
        tpu_embedding_gradient.get_gradients_through_compute_gradients(
            cross_shard_optimizer, loss, activations))
    send_gradients_op = embedding.generate_send_gradients_op(gradients)
    with ops.control_dependencies([train_op, send_gradients_op]):
      loss = array_ops.identity(loss)

  loss = tpu.shard(computation,
                   num_shards=embedding.num_cores)

  with self.test_session() as sess:
    sess.run(tpu.initialize_system(embedding_config=
                                   embedding.config_proto))
    sess.run(variables.global_variables_initializer())
    sess.run(embedding_variables_and_ops.load_ops())
    sess.run(enqueue_ops)
    loss_val = sess.run(loss)
  ```

  Example with weight decay:

  >>> def learning_rate_fn(global_step):
  ...   return tf.compat.v1.train.polynomial_decay(
  ...     learning_rate=5e-5,
  ...     global_step=global_step,
  ...     decay_steps=100000,
  ...     end_learning_rate=0.0)
  >>> wordpiece_table_config = TableConfig(
  ...   vocabulary_size=119547,
  ...   dimension=256,
  ...   learning_rate_fn=learning_rate_fn)
  >>> wordpiece_feature_config = FeatureConfig(
  ...   table_id='bert/embeddings/word_embeddings',
  ...   max_sequence_length=512)
  >>> optimization_parameters = AdamParameters(
  ...   learning_rate=5e-5,
  ...   epsilon=1e-6,
  ...   weight_decay_factor=0.01,
  ...   multiply_weight_decay_factor_by_learning_rate=True)
  >>> tpu_embedding = TPUEmbedding(
  ...   table_to_config_dict={
  ...       'bert/embeddings/word_embeddings': wordpiece_table_config,
  ...   },
  ...   feature_to_config_dict={'input_ids': wordpiece_feature_config},
  ...   batch_size=128,
  ...   mode=TRAINING,
  ...   optimization_parameters=optimization_parameters,
  ...   master='')
  >>> with tf.Graph().as_default():
  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
  ...       embedding_config=tpu_embedding.config_proto)
  ...   tf.compat.v1.Session().run(init_tpu_op)
  """

  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
  # the feature should not be used to update embedding table (cr/204852758,
  # cr/204940540). Also, this can support different combiners for different
  # features within the same table.
  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
  # to `FeatureConfig`?

  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
  # respectively?

  # TODO(shizhiw): Consider adding `input_fn` as an option to remove
  # boilerplate for-loops around construction of inputs.

  # `optimization_parameters` applies to all tables by default. A table can
  # override it by setting `optimization_parameters` in its `TableConfig`.
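  # A brief illustrative sketch of a per-table override (hypothetical names;
  # values are placeholders): the 'user' table uses its own Adagrad settings
  # while all other tables fall back to the global `optimization_parameters`
  # passed to the constructor.
  #
  #   table_to_config_dict = {
  #       'user': TableConfig(
  #           vocabulary_size=4, dimension=2,
  #           optimization_parameters=AdagradParameters(
  #               learning_rate=0.05, initial_accumulator=0.1)),
  #       'video': TableConfig(vocabulary_size=8, dimension=2),
  #   }
  #   embedding = TPUEmbedding(
  #       table_to_config_dict, feature_to_config_dict, batch_size=64,
  #       mode=TRAINING, master='',
  #       optimization_parameters=AdagradParameters(learning_rate=0.01))
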
  def __init__(self,
               table_to_config_dict,
               feature_to_config_dict,
               batch_size,
               mode,
               master=None,
               optimization_parameters=None,
               cluster_def=None,
               pipeline_execution_with_tensor_core=False,
               partition_strategy='div',
               profile_data_directory=None,
               device_config=None,
               master_job_name=None):
    """API for using TPU for embedding lookups.

    Args:
      table_to_config_dict: A dictionary mapping from string of table name to
        `TableConfig`. Table refers to an embedding table, e.g. `params`
        argument to `tf.nn.embedding_lookup_sparse()`.
      feature_to_config_dict: A dictionary mapping from string of feature name
        to `FeatureConfig`. Feature refers to ids to lookup in an embedding
        table, e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
      batch_size: An `int` representing the global batch size.
      mode: `TRAINING` or `INFERENCE`.
      master: A `string` representing the TensorFlow master to use.
      optimization_parameters: `AdagradParameters`, `AdamParameters`,
        `StochasticGradientDescentParameters`. Must be set in training unless
        all tables specify their own optimizers, and must be `None` in
        inference.
      cluster_def: A ClusterDef object describing the TPU cluster.
      pipeline_execution_with_tensor_core: setting this to `True` makes
        training faster, but the trained model will be different if step N and
        step N+1 involve the same set of embedding IDs. Please see
        `tpu_embedding_configuration.proto` for details.
      partition_strategy: A string, either 'mod' or 'div', specifying how to
        map the lookup id to the embedding tensor. For more information see
        `tf.nn.embedding_lookup_sparse`.
      profile_data_directory: Directory where embedding lookup statistics are
        stored. These statistics summarize information about the inputs to the
        embedding lookup operation, in particular, the average number of
        embedding IDs per example and how well the embedding IDs are load
        balanced across the system. The lookup statistics are used during TPU
        initialization for embedding table partitioning. Collection of lookup
        statistics is done at runtime by profiling the embedding inputs: only
        a small fraction of input samples are profiled to minimize host CPU
        overhead. Once a suitable number of samples are profiled, the lookup
        statistics are saved to table-specific files in the profile data
        directory, generally at the end of a TPU training loop. The filename
        corresponding to each table is obtained by hashing table specific
        parameters (e.g., table name and number of features) and global
        configuration parameters (e.g., sharding strategy and task count). The
        same profile data directory can be shared among several models to
        reuse embedding lookup statistics.
      device_config: A DeviceConfig instance, used when `master` and
        `cluster_def` are both `None`.
      master_job_name: if set, overrides the master job name used to schedule
        embedding ops.

    Raises:
      ValueError: if any input is invalid.
    """
    if partition_strategy not in ('div', 'mod'):
      raise ValueError(f'partition_strategy must be "div" or "mod". '
                       f'Received: {partition_strategy}.')
    self._partition_strategy = partition_strategy

    self._profile_data_directory = profile_data_directory

    _validate_table_to_config_dict(table_to_config_dict)
    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)

    _validate_feature_to_config_dict(table_to_config_dict,
                                     feature_to_config_dict)
    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
    self._table_to_features_dict = (
        _create_table_to_features_dict(self._feature_to_config_dict))
    self._combiners = _create_combiners(self._table_to_config_dict,
                                        self._table_to_features_dict)

    self._batch_size = batch_size

    if master is None and cluster_def is None:
      if device_config is None:
        raise ValueError('When master and cluster_def are both None, '
                         'device_config must be set but is not.')
      if device_config.num_cores % device_config.num_hosts:
        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
                         'but does not.'.format(device_config.num_hosts,
                                                device_config.num_cores))
      self._num_hosts = device_config.num_hosts
      self._num_cores = device_config.num_cores
      self._num_cores_per_host = self._num_cores // self._num_hosts
      self._hosts = [
          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
          for i in range(self._num_hosts)
      ]
    else:
      tpu_system_metadata = (
          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
              master,
              cluster_def=cluster_def))
      if tpu_system_metadata.num_cores == 0:
        raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
                         'TPUs.'.format(master))
      self._num_hosts = tpu_system_metadata.num_hosts
      if master_job_name is None:
        try:
          master_job_name = tpu_system_metadata_lib.master_job(
              master, cluster_def)
        except ValueError as e:
          raise ValueError(str(e) + ' Please specify a master_job_name.')
      self._hosts = []
      for device in tpu_system_metadata.devices:
        if 'device:CPU:' in device.name and (master_job_name is None or
                                             master_job_name in device.name):
          self._hosts.append(device.name)
      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
      self._num_cores = tpu_system_metadata.num_cores

    _validate_batch_size(self._batch_size, self._num_cores)
    self._batch_size_per_core = self._batch_size // self._num_cores

    # TODO(shizhiw): remove `mode`?
    if mode == TRAINING:
      _validate_optimization_parameters(optimization_parameters,
                                        self._table_to_config_dict)
      self._optimization_parameters = optimization_parameters
    elif mode == INFERENCE:
      if optimization_parameters is not None:
        raise ValueError(f'`optimization_parameters` should be `None` '
                         f'for inference mode. '
                         f'Received: {optimization_parameters}.')
      self._optimization_parameters = (StochasticGradientDescentParameters(1.))
    else:
      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
          TRAINING, INFERENCE, mode))
    self._mode = mode

    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
    # and create special handler for inference that inherits from
    # StochasticGradientDescentHandler with more user-friendly error message
    # on get_slot().
1444 self._optimizer_handler_dict = self._get_optimizer_handler_by_table() 1445 1446 self._pipeline_execution_with_tensor_core = ( 1447 pipeline_execution_with_tensor_core) 1448 self._learning_rate_fn = list( 1449 set(c.learning_rate_fn 1450 for c in self._table_to_config_dict.values() 1451 if c.learning_rate_fn is not None)) 1452 self._learning_rate_fn_to_tag = { 1453 fn: id for id, fn in enumerate(self._learning_rate_fn) 1454 } 1455 1456 self._config_proto = self._create_config_proto() 1457 1458 @property 1459 def hosts(self): 1460 """A list of device names for CPU hosts. 1461 1462 Returns: 1463 A list of device names for CPU hosts. 1464 """ 1465 return copy.copy(self._hosts) 1466 1467 # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and 1468 # to be consistent with `tpu_embedding_configuration.proto`. 1469 @property 1470 def num_cores_per_host(self): 1471 """Number of TPU cores on a CPU host. 1472 1473 Returns: 1474 Number of TPU cores on a CPU host. 1475 """ 1476 return self._num_cores_per_host 1477 1478 @property 1479 def num_cores(self): 1480 """Total number of TPU cores on all hosts. 1481 1482 Returns: 1483 Total number of TPU cores on all hosts. 1484 """ 1485 return self._num_cores 1486 1487 @property 1488 def batch_size_per_core(self): 1489 """Batch size for each TPU core. 1490 1491 The sparse tensors in `sparse_features_list` to `generate_enqueue_ops` 1492 must have batch dimension equal to this. 1493 1494 Returns: 1495 Batch size for each TPU core. 1496 """ 1497 return self._batch_size_per_core 1498 1499 @property 1500 def config_proto(self): 1501 """Create embedding config proto for `tpu.initialize_system()`. 1502 1503 Returns: 1504 an `TPUEmbeddingConfiguration` proto describing the desired 1505 configuration of the hardware embedding lookup tables, which 1506 is passed to `tpu.initialize_system()`. 1507 """ 1508 return self._config_proto 1509 1510 @property 1511 def table_to_config_dict(self): 1512 return copy.copy(self._table_to_config_dict) 1513 1514 @property 1515 def feature_to_config_dict(self): 1516 return copy.copy(self._feature_to_config_dict) 1517 1518 @property 1519 def table_to_features_dict(self): 1520 return copy.copy(self._table_to_features_dict) 1521 1522 @property 1523 def optimization_parameters(self): 1524 return self._optimization_parameters 1525 1526 def _create_config_proto(self): 1527 """Create `TPUEmbeddingConfiguration`.""" 1528 config_proto = elc.TPUEmbeddingConfiguration() 1529 for table in self._table_to_config_dict: 1530 table_descriptor = config_proto.table_descriptor.add() 1531 table_descriptor.name = table 1532 1533 table_config = self._table_to_config_dict[table] 1534 # For small tables, we pad to the number of hosts so that at least one 1535 # id will be assigned to each host. 
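      # E.g. a table with vocabulary_size=3 on a 4-host system is recorded as
      # having 4 rows, so that the sharding below never leaves a host without
      # at least one row of the table.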
1536 table_descriptor.vocabulary_size = max(table_config.vocabulary_size, 1537 len(self.hosts)) 1538 table_descriptor.dimension = table_config.dimension 1539 1540 optimization_parameters = ( 1541 self._optimizer_handler_dict[table].get_optimization_parameters()) 1542 1543 parameters = table_descriptor.optimization_parameters 1544 if table_config.learning_rate: 1545 parameters.learning_rate.constant = table_config.learning_rate 1546 elif table_config.learning_rate_fn: 1547 parameters.learning_rate.dynamic.tag = ( 1548 self._learning_rate_fn_to_tag[table_config.learning_rate_fn]) 1549 else: 1550 parameters.learning_rate.constant = ( 1551 optimization_parameters.learning_rate) 1552 parameters.gradient_accumulation_status = ( 1553 optimization_parameters_pb2.GradientAccumulationStatus.ENABLED 1554 if optimization_parameters.use_gradient_accumulation else 1555 optimization_parameters_pb2.GradientAccumulationStatus.DISABLED) 1556 1557 if optimization_parameters.clip_gradient_min is not None: 1558 parameters.gradient_clipping_limits.lower.value = ( 1559 optimization_parameters.clip_gradient_min) 1560 if optimization_parameters.clip_gradient_max is not None: 1561 parameters.gradient_clipping_limits.upper.value = ( 1562 optimization_parameters.clip_gradient_max) 1563 1564 if optimization_parameters.clip_weight_min is not None: 1565 parameters.clipping_limits.lower.value = ( 1566 optimization_parameters.clip_weight_min) 1567 if optimization_parameters.clip_weight_max is not None: 1568 parameters.clipping_limits.upper.value = ( 1569 optimization_parameters.clip_weight_max) 1570 if optimization_parameters.weight_decay_factor: 1571 parameters.weight_decay_factor = ( 1572 optimization_parameters.weight_decay_factor) 1573 if (optimization_parameters 1574 .multiply_weight_decay_factor_by_learning_rate): 1575 parameters.multiply_weight_decay_factor_by_learning_rate = True 1576 if table_config.hot_id_replication: 1577 parameters.hot_id_replication_configuration.status = ( 1578 optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED) 1579 optimizer_handler = self._optimizer_handler_dict[table] 1580 optimizer_handler.set_optimization_parameters(table_descriptor) 1581 1582 table_to_id = { 1583 table: i for i, table in enumerate(self._table_to_config_dict) 1584 } 1585 1586 # Set feature descriptor field in the config proto. 
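    # Each feature contributes one descriptor whose input_shape is the
    # per-core batch size, with max_sequence_length appended as a second
    # dimension for sequence features (max_sequence_length > 0). Features are
    # emitted grouped by table, in the same order later assumed by
    # `get_activations()` and `generate_send_gradients_op()`.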
1587 for table in self._table_to_features_dict: 1588 features = self._table_to_features_dict[table] 1589 for feature in features: 1590 feature_descriptor = config_proto.feature_descriptor.add() 1591 1592 feature_descriptor.table_id = table_to_id[ 1593 self._feature_to_config_dict[feature].table_id] 1594 if self._feature_to_config_dict[feature].max_sequence_length > 0: 1595 feature_descriptor.input_shape.extend([ 1596 self._batch_size_per_core, 1597 self._feature_to_config_dict[feature].max_sequence_length 1598 ]) 1599 else: 1600 feature_descriptor.input_shape.extend([self._batch_size_per_core]) 1601 1602 config_proto.mode = self._mode 1603 config_proto.num_hosts = self._num_hosts 1604 config_proto.num_tensor_cores = self._num_cores 1605 config_proto.sharding_strategy = ( 1606 elc.TPUEmbeddingConfiguration.DIV_DEFAULT if self._partition_strategy 1607 == 'div' else elc.TPUEmbeddingConfiguration.MOD) 1608 config_proto.pipeline_execution_with_tensor_core = ( 1609 self._pipeline_execution_with_tensor_core) 1610 if self._profile_data_directory: 1611 config_proto.profile_data_directory = self._profile_data_directory 1612 1613 return config_proto 1614 1615 def create_variables_and_ops(self, 1616 embedding_variable_name_by_table=None, 1617 slot_variable_names_by_table=None): 1618 """Create embedding and slot variables, with ops to load and retrieve them. 1619 1620 N.B.: the retrieve embedding variables (including slot variables) ops are 1621 returned as lambda fn, as the call side might want to impose control 1622 dependencies between the TPU computation and retrieving actions. For 1623 example, the following code snippet ensures the TPU computation finishes 1624 first, and then we pull the variables back from TPU to CPU. 1625 1626 ``` 1627 updates_ops = [] 1628 with ops.control_dependencies([loss]): 1629 for op_fn in retrieve_parameters_op_fns: 1630 update_ops.append(op_fn()) 1631 ``` 1632 1633 Args: 1634 embedding_variable_name_by_table: A dictionary mapping from string of 1635 table name to string of embedding variable name. If `None`, defaults 1636 from `get_default_slot_variable_names()` will be used. 1637 slot_variable_names_by_table: A dictionary mapping from string of table 1638 name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If 1639 `None`, defaults from `get_default_slot_variable_names()` will be used. 1640 1641 Returns: 1642 `tpu_embedding.VariablesAndOps` with: 1643 A dictionary mapping from string of table name to embedding variables, 1644 A dictionary mapping from string of table name to AdagradSlotVariables, 1645 AdamSlotVariables etc with slot variables, 1646 A function which returns a list of ops to load embedding and slot 1647 variables from CPU to TPU. 1648 A function which returns a list of ops to retrieve embedding and slot 1649 variables from TPU to CPU. 
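
    A sketch of typical session-based usage (here `embedding` stands for this
    `TPUEmbedding` object and `sess` for a session on the TPU worker; both are
    placeholders):

    ```
    _, _, load_ops_fn, retrieve_ops_fn = embedding.create_variables_and_ops()
    sess.run(load_ops_fn())      # copy initialized tables from CPU to TPU
    ...                          # run training steps
    sess.run(retrieve_ops_fn())  # copy updated tables back before saving
    ```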
1650 """ 1651 embedding_variables_by_table = {} 1652 slot_variables_by_table = {} 1653 load_op_fns = [] 1654 retrieve_op_fns = [] 1655 1656 for i, table in enumerate(self._table_to_config_dict): 1657 if embedding_variable_name_by_table: 1658 embedding_variable_name = embedding_variable_name_by_table[table] 1659 else: 1660 embedding_variable_name = table 1661 if slot_variable_names_by_table: 1662 slot_variable_names = slot_variable_names_by_table[table] 1663 else: 1664 optimizer_handler = self._optimizer_handler_dict[table] 1665 slot_variable_names = ( 1666 optimizer_handler.get_default_slot_variable_names(table)) 1667 1668 # TODO(b/139144091): Multi-host support for mid-level API in 1669 # eager context (TF 2.0) 1670 # Workaround below allows single-host use case in TF 2.0 1671 if context.executing_eagerly(): 1672 device = '' 1673 else: 1674 device = _create_device_fn(self._hosts) 1675 1676 with ops.device(device): 1677 table_variables = _create_partitioned_variables( 1678 name=embedding_variable_name, 1679 num_hosts=self._num_hosts, 1680 vocabulary_size=self._table_to_config_dict[table].vocabulary_size, 1681 embedding_dimension=self._table_to_config_dict[table].dimension, 1682 initializer=self._table_to_config_dict[table].initializer, 1683 collections=[ops.GraphKeys.GLOBAL_VARIABLES]) 1684 embedding_variables_by_table[table] = table_variables 1685 1686 # Only loads embedding config to load/retrieve nodes for the first table 1687 # on the first host, other nodes would use config from the first node. 1688 config = None if i else self.config_proto.SerializeToString() 1689 slot_variables_for_table, load_ops_fn, retrieve_ops_fn = ( 1690 self._optimizer_handler_dict[table].create_variables_and_ops( 1691 table, slot_variable_names, self._num_hosts, 1692 self._table_to_config_dict[table], table_variables, config)) 1693 slot_variables_by_table[table] = slot_variables_for_table 1694 load_op_fns.append(load_ops_fn) 1695 retrieve_op_fns.append(retrieve_ops_fn) 1696 1697 def load_ops(): 1698 """Calls and returns the load ops for each embedding table. 1699 1700 Returns: 1701 A list of ops to load embedding and slot variables from CPU to TPU. 1702 """ 1703 load_ops_list = [] 1704 for load_op_fn in load_op_fns: 1705 load_ops_list.extend(load_op_fn()) 1706 return load_ops_list 1707 1708 def retrieve_ops(): 1709 """Calls and returns the retrieve ops for each embedding table. 1710 1711 Returns: 1712 A list of ops to retrieve embedding and slot variables from TPU to CPU. 1713 """ 1714 retrieve_ops_list = [] 1715 for retrieve_op_fn in retrieve_op_fns: 1716 retrieve_ops_list.extend(retrieve_op_fn()) 1717 return retrieve_ops_list 1718 1719 return VariablesAndOps(embedding_variables_by_table, 1720 slot_variables_by_table, load_ops, retrieve_ops) 1721 1722 def generate_enqueue_ops( 1723 self, 1724 enqueue_datas_list, 1725 mode_override=None, 1726 ragged=False, 1727 ): 1728 """Generate enqueue ops. 1729 1730 Args: 1731 enqueue_datas_list: a list of dictionary mapping from string of feature 1732 names to EnqueueData. Each dictionary is for one TPU core. Dictionaries 1733 for the same host should be contiguous in the list. 1734 mode_override: A string input that overrides the mode specified in the 1735 TPUEmbeddingConfiguration. Supported values are {'unspecified', 1736 'inference', 'training', 'backward_pass_only'}. When set to 1737 'unspecified', the mode set in TPUEmbeddingConfiguration is used, 1738 otherwise mode_override is used (optional). 
1739 ragged: If True, creates RaggedTensor enqueue ops rather than 1740 SparseTensor. 1741 1742 Returns: 1743 Ops to enqueue to TPU for embedding. 1744 """ 1745 self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list) 1746 return [ 1747 self._generate_enqueue_op( # pylint: disable=g-complex-comprehension 1748 enqueue_datas, 1749 device_ordinal=i % self._num_cores_per_host, 1750 mode_override=mode_override, 1751 ragged=ragged, 1752 ) for i, enqueue_datas in enumerate(enqueue_datas_list) 1753 ] 1754 1755 def _validate_generate_enqueue_ops_enqueue_datas_list(self, 1756 enqueue_datas_list): 1757 """Validate `enqueue_datas_list`.""" 1758 1759 def _check_agreement(data, name, feature, enqueue_data): 1760 """Helper function to check device agreement.""" 1761 if (data is not None and 1762 data.device != enqueue_data.embedding_indices.device): 1763 raise ValueError('Device of {0} does not agree with that of' 1764 'embedding_indices for feature {1}.'.format( 1765 name, feature)) 1766 1767 feature_set = set(self._feature_to_config_dict.keys()) 1768 contiguous_device = None 1769 for i, enqueue_datas in enumerate(enqueue_datas_list): 1770 used_feature_set = set(enqueue_datas.keys()) 1771 1772 # Check features are valid. 1773 missing_feature_set = feature_set - used_feature_set 1774 if missing_feature_set: 1775 raise ValueError('`enqueue_datas_list[{}]` misses a feature that is ' 1776 'in `feature_to_config_dict`: {}.'.format( 1777 i, missing_feature_set)) 1778 1779 extra_feature_set = used_feature_set - feature_set 1780 if extra_feature_set: 1781 raise ValueError('`enqueue_datas_list[{}]` has a feature that is not ' 1782 'in `feature_to_config_dict`: {}.'.format( 1783 i, extra_feature_set)) 1784 1785 device = None 1786 device_feature = None 1787 for feature, enqueue_data in enqueue_datas.items(): 1788 combiner = self._table_to_config_dict[ 1789 self._feature_to_config_dict[feature].table_id].combiner 1790 1791 if isinstance(enqueue_data, EnqueueData): 1792 if enqueue_data.sample_indices is None and combiner: 1793 logging.warn( 1794 'No sample indices set for features %f table %f but ' 1795 'combiner is set to %s.', feature, 1796 self._feature_to_config_dict[feature].table_id, combiner) 1797 _check_agreement(enqueue_data.sample_indices, 'sample_indices', 1798 feature, enqueue_data) 1799 _check_agreement(enqueue_data.aggregation_weights, 1800 'aggregation_weights', feature, enqueue_data) 1801 1802 elif isinstance(enqueue_data, RaggedEnqueueData): 1803 if enqueue_data.row_splits is None and combiner: 1804 logging.warn( 1805 'No row splits set for features %f table %f but ' 1806 'combiner is set to %s.', feature, 1807 self._feature_to_config_dict[feature].table_id, combiner) 1808 _check_agreement(enqueue_data.row_splits, 'row_splits', feature, 1809 enqueue_data) 1810 _check_agreement(enqueue_data.aggregation_weights, 1811 'aggregation_weights', feature, enqueue_data) 1812 else: 1813 raise ValueError( 1814 '`enqueue_datas_list[{}]` has a feature that is not mapped to ' 1815 '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format( 1816 i, feature)) 1817 # Check all features are on the same device. 
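        # All tensors for one core must live on the same host CPU device so
        # that the single enqueue op built in `_generate_enqueue_op()` can be
        # colocated with them.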
1818 if device is None: 1819 device = enqueue_data.embedding_indices.device 1820 device_feature = feature 1821 else: 1822 if device != enqueue_data.embedding_indices.device: 1823 raise ValueError('Devices are different between features in ' 1824 '`enqueue_datas_list[{}]`; ' 1825 'devices: {}, {}; features: {}, {}.'.format( 1826 i, device, 1827 enqueue_data.embedding_indices.device, feature, 1828 device_feature)) 1829 1830 if i % self._num_cores_per_host: 1831 if device != contiguous_device: 1832 raise ValueError('We expect the `enqueue_datas` which are on the ' 1833 'same host to be contiguous in ' 1834 '`enqueue_datas_list`, ' 1835 '`enqueue_datas_list[{}]` is on device {}, ' 1836 'but is expected to be on device {}.'.format( 1837 i, device, contiguous_device)) 1838 else: 1839 contiguous_device = device 1840 1841 def _generate_enqueue_op(self, 1842 enqueue_datas, 1843 device_ordinal, 1844 mode_override=None, 1845 ragged=False): 1846 """Creates op for enqueuing batch to TPU.""" 1847 enqueue_data0 = list(enqueue_datas.values())[0] 1848 with ops.colocate_with(enqueue_data0.embedding_indices): 1849 return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch( 1850 device_ordinal=device_ordinal, 1851 combiners=self._combiners, 1852 mode_override=mode_override, 1853 **self._format_for_tpu_embedding_arbitrary_tensor_batch( 1854 enqueue_datas, ragged)) 1855 1856 def _format_for_tpu_embedding_arbitrary_tensor_batch(self, enqueue_datas, 1857 ragged): 1858 """Format features for `enqueue_tpu_embedding_arbitrary_tensor_batch()`. 1859 1860 Args: 1861 enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding. 1862 ragged: If True, extract row splits from the data rather than sample 1863 indices. 1864 1865 Returns: 1866 Dict of arguments for `enqueue_tpu_embedding_arbitrary_tensor_batch()`. 1867 """ 1868 1869 kwargs = { 1870 'sample_indices_or_row_splits': [], 1871 'embedding_indices': [], 1872 'aggregation_weights': [], 1873 } 1874 int_zeros = array_ops.zeros((0,), dtype=dtypes.int64) 1875 float_zeros = array_ops.zeros((0,), dtype=dtypes.float32) 1876 for table in self._table_to_features_dict: 1877 features = self._table_to_features_dict[table] 1878 for feature in features: 1879 enqueue_data = enqueue_datas[feature] 1880 if ragged: 1881 kwargs['sample_indices_or_row_splits'].append( 1882 enqueue_data.row_splits if enqueue_data 1883 .row_splits is not None else int_zeros) 1884 else: 1885 if (self._feature_to_config_dict[feature].max_sequence_length > 0 and 1886 enqueue_data.sample_indices is not None and 1887 enqueue_data.sample_indices.shape[1] == 2): 1888 # Pad the sample indices as if the enqueued sparse tensor is rank 2. 1889 sample_indices = array_ops.pad( 1890 enqueue_data.sample_indices, paddings=[[0, 0], [0, 1]]) 1891 kwargs['sample_indices_or_row_splits'].append(sample_indices) 1892 else: 1893 # If the sample_indices is rank 1 or not present, treat it as dense 1894 # tensor. 1895 if (enqueue_data.sample_indices is None or 1896 enqueue_data.sample_indices.shape[1] == 1): 1897 kwargs['sample_indices_or_row_splits'].append(int_zeros) 1898 else: 1899 kwargs['sample_indices_or_row_splits'].append( 1900 enqueue_data.sample_indices) 1901 1902 kwargs['aggregation_weights'].append( 1903 enqueue_data.aggregation_weights if enqueue_data 1904 .aggregation_weights is not None else float_zeros) 1905 1906 kwargs['embedding_indices'].append(enqueue_data.embedding_indices) 1907 return kwargs 1908 1909 def get_activations(self): 1910 """Get activations for features. 
1911 1912 This should be called within `computation` that is passed to 1913 `tpu.replicate` and friends. 1914 1915 Returns: 1916 A dictionary mapping from `String` of feature name to `Tensor` 1917 of activation. 1918 """ 1919 recv_activations = tpu_ops.recv_tpu_embedding_activations( 1920 num_outputs=len(self._feature_to_config_dict), 1921 config=self._config_proto.SerializeToString()) 1922 1923 activations = collections.OrderedDict() 1924 index = 0 1925 for table in self._table_to_features_dict: 1926 for feature in self._table_to_features_dict[table]: 1927 activations[feature] = recv_activations[index] 1928 index += 1 1929 return activations 1930 1931 def generate_send_gradients_op(self, feature_to_gradient_dict, step=None): 1932 """Send gradient to TPU embedding. 1933 1934 Args: 1935 feature_to_gradient_dict: dict mapping feature names to gradient wrt 1936 activations. 1937 step: the current global step, used for dynamic learning rate. 1938 1939 Returns: 1940 SendTPUEmbeddingGradients Op. 1941 1942 Raises: 1943 RuntimeError: If `mode` is not `TRAINING`. 1944 """ 1945 if self._mode != TRAINING: 1946 raise RuntimeError('Only in training mode gradients need to ' 1947 'be sent to TPU embedding; got mode {}.'.format( 1948 self._mode)) 1949 if step is None and self._learning_rate_fn: 1950 raise ValueError('There are dynamic learning rates but step is None.') 1951 1952 gradients = [] 1953 for table in self._table_to_features_dict: 1954 for feature in self._table_to_features_dict[table]: 1955 gradients.append(feature_to_gradient_dict[feature]) 1956 1957 return tpu_ops.send_tpu_embedding_gradients( 1958 inputs=gradients, 1959 learning_rates=[ 1960 math_ops.cast(fn(step), dtype=dtypes.float32) 1961 for fn in self._learning_rate_fn 1962 ], 1963 config=self.config_proto.SerializeToString()) 1964 1965 def _get_optimizer_handler_by_table(self): 1966 optimizer_handlers = {} 1967 for table, table_config in self.table_to_config_dict.items(): 1968 if table_config.optimization_parameters is not None: 1969 optimizer = table_config.optimization_parameters 1970 else: 1971 optimizer = self._optimization_parameters 1972 optimizer_handlers[table] = _get_optimization_handler(optimizer) 1973 1974 return optimizer_handlers 1975 1976 1977def _validate_table_to_config_dict(table_to_config_dict): 1978 """Validate `table_to_config_dict`.""" 1979 for k, v in table_to_config_dict.items(): 1980 if not isinstance(v, TableConfig): 1981 raise ValueError('Value of `table_to_config_dict` must be of type ' 1982 '`TableConfig`, got {} for {}.'.format(type(v), k)) 1983 1984 1985def _validate_feature_to_config_dict(table_to_config_dict, 1986 feature_to_config_dict): 1987 """Validate `feature_to_config_dict`.""" 1988 used_table_set = set( 1989 [feature.table_id for feature in feature_to_config_dict.values()]) 1990 table_set = set(table_to_config_dict.keys()) 1991 1992 unused_table_set = table_set - used_table_set 1993 if unused_table_set: 1994 raise ValueError( 1995 '`table_to_config_dict` specifies table that is not ' 1996 'used in `feature_to_config_dict`: {}.'.format(unused_table_set)) 1997 1998 extra_table_set = used_table_set - table_set 1999 if extra_table_set: 2000 raise ValueError( 2001 '`feature_to_config_dict` refers to a table that is not ' 2002 'specified in `table_to_config_dict`: {}.'.format(extra_table_set)) 2003 2004 2005def _validate_batch_size(batch_size, num_cores): 2006 if batch_size % num_cores: 2007 raise ValueError('`batch_size` is not a multiple of number of ' 2008 'cores. 
`batch_size`={}, `_num_cores`={}.'.format( 2009 batch_size, num_cores)) 2010 2011 2012def _validate_optimization_parameters(optimization_parameters, 2013 table_to_config_dict): 2014 """Validate global optimization_parameters and per table optimizers. 2015 2016 If global optimizer is `None`, all table optimizers should be non `None`. 2017 2018 Args: 2019 optimization_parameters: global optimizer provided in `TPUEmbedding` 2020 constructor. 2021 table_to_config_dict: A dictionary mapping from string of table name to 2022 `TableConfig`. 2023 """ 2024 tbl_optimizer_missing = False 2025 for _, table_config in table_to_config_dict.items(): 2026 if table_config.optimization_parameters is None: 2027 tbl_optimizer_missing = True 2028 break 2029 2030 if optimization_parameters: 2031 if not isinstance(optimization_parameters, _OptimizationParameters): 2032 raise ValueError('`optimization_parameters` must inherit from ' 2033 '`_OptimizationParameters`. ' 2034 '`type(optimization_parameters)`={}'.format( 2035 type(optimization_parameters))) 2036 else: 2037 # Missing global optimization_parameters. 2038 if tbl_optimizer_missing: 2039 raise ValueError('`optimization_parameters` is missing.') 2040 2041 2042class _OptimizerHandler: 2043 """Interface class for handling optimizer specific logic.""" 2044 2045 def __init__(self, optimization_parameters): 2046 self._optimization_parameters = optimization_parameters 2047 2048 def get_optimization_parameters(self): 2049 return self._optimization_parameters 2050 2051 def set_optimization_parameters(self, table_descriptor): 2052 raise NotImplementedError() 2053 2054 def get_default_slot_variable_names(self, table): 2055 raise NotImplementedError() 2056 2057 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2058 table_config, table_variables, config_proto): 2059 raise NotImplementedError() 2060 2061 2062class _AdagradHandler(_OptimizerHandler): 2063 """Handles Adagrad specific logic.""" 2064 2065 def set_optimization_parameters(self, table_descriptor): 2066 table_descriptor.optimization_parameters.adagrad.SetInParent() 2067 2068 def get_default_slot_variable_names(self, table): 2069 return AdagradSlotVariableNames('{}/{}'.format(table, 'Adagrad')) 2070 2071 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2072 table_config, table_variables, config_proto): 2073 accumulator_initializer = init_ops.constant_initializer( 2074 self._optimization_parameters.initial_accumulator) 2075 accumulator_variables = _create_partitioned_variables( 2076 name=slot_variable_names.accumulator, 2077 num_hosts=num_hosts, 2078 vocabulary_size=table_config.vocabulary_size, 2079 embedding_dimension=table_config.dimension, 2080 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2081 initializer=accumulator_initializer) 2082 slot_variables = AdagradSlotVariables(accumulator_variables) 2083 2084 def load_ops_fn(): 2085 """Returns the retrieve ops for AdaGrad embedding tables. 2086 2087 Returns: 2088 A list of ops to load embedding and slot variables from CPU to TPU. 
2089 """ 2090 config = config_proto 2091 load_op_list = [] 2092 for host_id, table_variable, accumulator_variable in zip( 2093 range(num_hosts), table_variables, accumulator_variables): 2094 with ops.colocate_with(table_variable): 2095 load_parameters_op = ( 2096 tpu_ops.load_tpu_embedding_adagrad_parameters( 2097 parameters=table_variable, 2098 accumulators=accumulator_variable, 2099 table_name=table, 2100 num_shards=num_hosts, 2101 shard_id=host_id, 2102 config=config)) 2103 config = None 2104 load_op_list.append(load_parameters_op) 2105 return load_op_list 2106 2107 def retrieve_ops_fn(): 2108 """Returns the retrieve ops for AdaGrad embedding tables. 2109 2110 Returns: 2111 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2112 """ 2113 config = config_proto 2114 retrieve_op_list = [] 2115 for host_id, table_variable, accumulator_variable in (zip( 2116 range(num_hosts), table_variables, accumulator_variables)): 2117 with ops.colocate_with(table_variable): 2118 retrieved_table, retrieved_accumulator = ( 2119 tpu_ops.retrieve_tpu_embedding_adagrad_parameters( 2120 table_name=table, 2121 num_shards=num_hosts, 2122 shard_id=host_id, 2123 config=config)) 2124 retrieve_parameters_op = control_flow_ops.group( 2125 state_ops.assign(table_variable, retrieved_table), 2126 state_ops.assign(accumulator_variable, retrieved_accumulator)) 2127 config = None 2128 retrieve_op_list.append(retrieve_parameters_op) 2129 return retrieve_op_list 2130 2131 return slot_variables, load_ops_fn, retrieve_ops_fn 2132 2133 2134class _AdagradMomentumHandler(_OptimizerHandler): 2135 """Handles Adagrad with Momentum specific logic. 2136 2137 Creates slot variables and defines their initializers. Defines load/retrieve 2138 operations to be used for loading variables into TPU memory (from host memory) 2139 and retrieving variables from TPU memory (into host memory). 
2140 """ 2141 2142 def set_optimization_parameters(self, table_descriptor): 2143 table_descriptor.optimization_parameters.adagrad_momentum.SetInParent() 2144 table_descriptor.optimization_parameters.adagrad_momentum.momentum = ( 2145 self._optimization_parameters.momentum) 2146 table_descriptor.optimization_parameters.adagrad_momentum.use_nesterov = ( 2147 self._optimization_parameters.use_nesterov) 2148 table_descriptor.optimization_parameters.adagrad_momentum.exponent = ( 2149 self._optimization_parameters.exponent) 2150 table_descriptor.optimization_parameters.adagrad_momentum.beta2 = ( 2151 self._optimization_parameters.beta2) 2152 table_descriptor.optimization_parameters.adagrad_momentum.epsilon = ( 2153 self._optimization_parameters.epsilon) 2154 2155 def get_default_slot_variable_names(self, table): 2156 return AdagradMomentumSlotVariableNames( 2157 '{}/{}/Accumulator'.format(table, 'AdagradMomentum'), 2158 '{}/{}/Momentum'.format(table, 'AdagradMomentum')) 2159 2160 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2161 table_config, table_variables, config_proto): 2162 accumulator_initializer = init_ops.zeros_initializer() 2163 accumulator_variables = _create_partitioned_variables( 2164 name=slot_variable_names.accumulator, 2165 num_hosts=num_hosts, 2166 vocabulary_size=table_config.vocabulary_size, 2167 embedding_dimension=table_config.dimension, 2168 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2169 initializer=accumulator_initializer) 2170 momenta_initializer = init_ops.zeros_initializer() 2171 momenta_variables = _create_partitioned_variables( 2172 name=slot_variable_names.momenta, 2173 num_hosts=num_hosts, 2174 vocabulary_size=table_config.vocabulary_size, 2175 embedding_dimension=table_config.dimension, 2176 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2177 initializer=momenta_initializer) 2178 slot_variables = AdagradMomentumSlotVariables(accumulator_variables, 2179 momenta_variables) 2180 2181 def load_ops_fn(): 2182 """Returns the load ops for AdaGrad with momentum embedding tables. 2183 2184 Returns: 2185 A list of ops to load embedding and slot variables from CPU to TPU. 2186 """ 2187 config = config_proto 2188 load_op_list = [] 2189 for host_id, table_variable, accumulator_variable, momenta_variable in zip( 2190 range(num_hosts), table_variables, accumulator_variables, 2191 momenta_variables): 2192 with ops.colocate_with(table_variable): 2193 load_parameters_op = ( 2194 tpu_ops.load_tpu_embedding_adagrad_momentum_parameters( 2195 parameters=table_variable, 2196 accumulators=accumulator_variable, 2197 momenta=momenta_variable, 2198 table_name=table, 2199 num_shards=num_hosts, 2200 shard_id=host_id, 2201 config=config)) 2202 config = None 2203 load_op_list.append(load_parameters_op) 2204 return load_op_list 2205 2206 def retrieve_ops_fn(): 2207 """Returns the retrieve ops for AdaGrad with momentum embedding tables. 2208 2209 Returns: 2210 A list of ops to retrieve embedding and slot variables from TPU to CPU. 
2211 """ 2212 config = config_proto 2213 retrieve_op_list = [] 2214 for host_id, table_variable, accumulator_variable, momenta_variable in ( 2215 zip( 2216 range(num_hosts), table_variables, accumulator_variables, 2217 momenta_variables)): 2218 with ops.colocate_with(table_variable): 2219 retrieved_table, retrieved_accumulator, retrieved_momenta = ( 2220 tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters( 2221 table_name=table, 2222 num_shards=num_hosts, 2223 shard_id=host_id, 2224 config=config)) 2225 retrieve_parameters_op = control_flow_ops.group( 2226 state_ops.assign(table_variable, retrieved_table), 2227 state_ops.assign(accumulator_variable, retrieved_accumulator), 2228 state_ops.assign(momenta_variable, retrieved_momenta)) 2229 config = None 2230 retrieve_op_list.append(retrieve_parameters_op) 2231 return retrieve_op_list 2232 2233 return slot_variables, load_ops_fn, retrieve_ops_fn 2234 2235 2236class _ProximalAdagradHandler(_OptimizerHandler): 2237 """Handles ProximalAdagrad specific logic.""" 2238 2239 def set_optimization_parameters(self, table_descriptor): 2240 table_descriptor.optimization_parameters.proximal_adagrad.SetInParent() 2241 table_descriptor.optimization_parameters.proximal_adagrad.l1 = ( 2242 self._optimization_parameters.l1_regularization_strength) 2243 table_descriptor.optimization_parameters.proximal_adagrad.l2 = ( 2244 self._optimization_parameters.l2_regularization_strength) 2245 2246 def get_default_slot_variable_names(self, table): 2247 return ProximalAdagradSlotVariableNames('{}/{}'.format( 2248 table, 'ProximalAdagrad')) 2249 2250 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2251 table_config, table_variables, config_proto): 2252 accumulator_initializer = init_ops.constant_initializer( 2253 self._optimization_parameters.initial_accumulator) 2254 accumulator_variables = _create_partitioned_variables( 2255 name=slot_variable_names.accumulator, 2256 num_hosts=num_hosts, 2257 vocabulary_size=table_config.vocabulary_size, 2258 embedding_dimension=table_config.dimension, 2259 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2260 initializer=accumulator_initializer) 2261 slot_variables = ProximalAdagradSlotVariables(accumulator_variables) 2262 2263 def load_ops_fn(): 2264 """Returns the retrieve ops for Proximal AdaGrad embedding tables. 2265 2266 Returns: 2267 A list of ops to load embedding and slot variables from CPU to TPU. 2268 """ 2269 config = config_proto 2270 load_op_list = [] 2271 for host_id, table_variable, accumulator_variable in zip( 2272 range(num_hosts), table_variables, accumulator_variables): 2273 with ops.colocate_with(table_variable): 2274 load_parameters_op = ( 2275 tpu_ops.load_tpu_embedding_proximal_adagrad_parameters( 2276 parameters=table_variable, 2277 accumulators=accumulator_variable, 2278 table_name=table, 2279 num_shards=num_hosts, 2280 shard_id=host_id, 2281 config=config)) 2282 config = None 2283 load_op_list.append(load_parameters_op) 2284 return load_op_list 2285 2286 def retrieve_ops_fn(): 2287 """Returns the retrieve ops for Proximal AdaGrad embedding tables. 2288 2289 Returns: 2290 A list of ops to retrieve embedding and slot variables from TPU to CPU. 
2291 """ 2292 config = config_proto 2293 retrieve_op_list = [] 2294 for host_id, table_variable, accumulator_variable in (zip( 2295 range(num_hosts), table_variables, accumulator_variables)): 2296 with ops.colocate_with(table_variable): 2297 retrieved_table, retrieved_accumulator = ( 2298 tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters( 2299 table_name=table, 2300 num_shards=num_hosts, 2301 shard_id=host_id, 2302 config=config)) 2303 retrieve_parameters_op = control_flow_ops.group( 2304 state_ops.assign(table_variable, retrieved_table), 2305 state_ops.assign(accumulator_variable, retrieved_accumulator)) 2306 config = None 2307 retrieve_op_list.append(retrieve_parameters_op) 2308 return retrieve_op_list 2309 2310 return slot_variables, load_ops_fn, retrieve_ops_fn 2311 2312 2313class _AdamHandler(_OptimizerHandler): 2314 """Handles Adam specific logic.""" 2315 2316 def set_optimization_parameters(self, table_descriptor): 2317 table_descriptor.optimization_parameters.adam.beta1 = ( 2318 self._optimization_parameters.beta1) 2319 table_descriptor.optimization_parameters.adam.beta2 = ( 2320 self._optimization_parameters.beta2) 2321 table_descriptor.optimization_parameters.adam.epsilon = ( 2322 self._optimization_parameters.epsilon) 2323 table_descriptor.optimization_parameters.adam.use_non_lazy_adam = ( 2324 not self._optimization_parameters.lazy_adam) 2325 table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = ( 2326 self._optimization_parameters.sum_inside_sqrt) 2327 2328 def get_default_slot_variable_names(self, table): 2329 return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'), 2330 '{}/{}/v'.format(table, 'Adam')) 2331 2332 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2333 table_config, table_variables, config_proto): 2334 m_initializer = init_ops.zeros_initializer() 2335 m_variables = _create_partitioned_variables( 2336 name=slot_variable_names.m, 2337 num_hosts=num_hosts, 2338 vocabulary_size=table_config.vocabulary_size, 2339 embedding_dimension=table_config.dimension, 2340 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2341 initializer=m_initializer) 2342 v_initializer = init_ops.zeros_initializer() 2343 v_variables = _create_partitioned_variables( 2344 name=slot_variable_names.v, 2345 num_hosts=num_hosts, 2346 vocabulary_size=table_config.vocabulary_size, 2347 embedding_dimension=table_config.dimension, 2348 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2349 initializer=v_initializer) 2350 slot_variables = AdamSlotVariables(m_variables, v_variables) 2351 2352 def load_ops_fn(): 2353 """Returns the retrieve ops for AdaGrad embedding tables. 2354 2355 Returns: 2356 A list of ops to load embedding and slot variables from CPU to TPU. 2357 """ 2358 load_op_list = [] 2359 config = config_proto 2360 for host_id, table_variable, m_variable, v_variable in (zip( 2361 range(num_hosts), table_variables, m_variables, v_variables)): 2362 with ops.colocate_with(table_variable): 2363 load_parameters_op = ( 2364 tpu_ops.load_tpu_embedding_adam_parameters( 2365 parameters=table_variable, 2366 momenta=m_variable, 2367 velocities=v_variable, 2368 table_name=table, 2369 num_shards=num_hosts, 2370 shard_id=host_id, 2371 config=config)) 2372 # Set config to None to enforce that config is only loaded to the first 2373 # table. 2374 config = None 2375 load_op_list.append(load_parameters_op) 2376 return load_op_list 2377 2378 def retrieve_ops_fn(): 2379 """Returns the retrieve ops for Adam embedding tables. 
2380 2381 Returns: 2382 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2383 """ 2384 retrieve_op_list = [] 2385 config = config_proto 2386 for host_id, table_variable, m_variable, v_variable in (zip( 2387 range(num_hosts), table_variables, m_variables, v_variables)): 2388 with ops.colocate_with(table_variable): 2389 retrieved_table, retrieved_m, retrieved_v = ( 2390 tpu_ops.retrieve_tpu_embedding_adam_parameters( 2391 table_name=table, 2392 num_shards=num_hosts, 2393 shard_id=host_id, 2394 config=config)) 2395 retrieve_parameters_op = control_flow_ops.group( 2396 state_ops.assign(table_variable, retrieved_table), 2397 state_ops.assign(m_variable, retrieved_m), 2398 state_ops.assign(v_variable, retrieved_v)) 2399 config = None 2400 retrieve_op_list.append(retrieve_parameters_op) 2401 return retrieve_op_list 2402 2403 return slot_variables, load_ops_fn, retrieve_ops_fn 2404 2405 2406class _FtrlHandler(_OptimizerHandler): 2407 """Handles Ftrl specific logic.""" 2408 2409 def set_optimization_parameters(self, table_descriptor): 2410 table_descriptor.optimization_parameters.ftrl.lr_power = ( 2411 self._optimization_parameters.learning_rate_power) 2412 table_descriptor.optimization_parameters.ftrl.l1 = ( 2413 self._optimization_parameters.l1_regularization_strength) 2414 table_descriptor.optimization_parameters.ftrl.l2 = ( 2415 self._optimization_parameters.l2_regularization_strength) 2416 table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = ( 2417 self._optimization_parameters.multiply_linear_by_learning_rate) 2418 table_descriptor.optimization_parameters.ftrl.beta = ( 2419 self._optimization_parameters.beta) 2420 table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = ( 2421 self._optimization_parameters.allow_zero_accumulator) 2422 2423 def get_default_slot_variable_names(self, table): 2424 # These match the default slot variable names created by 2425 # tf.train.FtrlOptimizer. 2426 return FtrlSlotVariableNames( 2427 '{}/{}'.format(table, 'Ftrl'), # accumulator 2428 '{}/{}'.format(table, 'Ftrl_1')) # linear 2429 2430 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2431 table_config, table_variables, config_proto): 2432 accumulator_initializer = init_ops.constant_initializer( 2433 self._optimization_parameters.initial_accumulator_value) 2434 accumulator_variables = _create_partitioned_variables( 2435 name=slot_variable_names.accumulator, 2436 num_hosts=num_hosts, 2437 vocabulary_size=table_config.vocabulary_size, 2438 embedding_dimension=table_config.dimension, 2439 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2440 initializer=accumulator_initializer) 2441 linear_initializer = init_ops.constant_initializer( 2442 self._optimization_parameters.initial_linear_value) 2443 linear_variables = _create_partitioned_variables( 2444 name=slot_variable_names.linear, 2445 num_hosts=num_hosts, 2446 vocabulary_size=table_config.vocabulary_size, 2447 embedding_dimension=table_config.dimension, 2448 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2449 initializer=linear_initializer) 2450 slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables) 2451 2452 def load_ops_fn(): 2453 """Returns the retrieve ops for Ftrl embedding tables. 2454 2455 Returns: 2456 A list of ops to load embedding and slot variables from CPU to TPU. 
2457 """ 2458 config = config_proto 2459 load_op_list = [] 2460 for host_id, table_variable, accumulator_variable, linear_variable in zip( 2461 range(num_hosts), table_variables, accumulator_variables, 2462 linear_variables): 2463 with ops.colocate_with(table_variable): 2464 load_parameters_op = ( 2465 tpu_ops.load_tpu_embedding_ftrl_parameters( 2466 parameters=table_variable, 2467 accumulators=accumulator_variable, 2468 linears=linear_variable, 2469 table_name=table, 2470 num_shards=num_hosts, 2471 shard_id=host_id, 2472 config=config)) 2473 config = None 2474 load_op_list.append(load_parameters_op) 2475 return load_op_list 2476 2477 def retrieve_ops_fn(): 2478 """Returns the retrieve ops for Ftrl embedding tables. 2479 2480 Returns: 2481 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2482 """ 2483 config = config_proto 2484 retrieve_op_list = [] 2485 for host_id, table_variable, accumulator_variable, linear_variable in zip( 2486 range(num_hosts), table_variables, accumulator_variables, 2487 linear_variables): 2488 with ops.colocate_with(table_variable): 2489 retrieved_table, retrieved_accumulator, retrieved_linear = ( 2490 tpu_ops.retrieve_tpu_embedding_ftrl_parameters( 2491 table_name=table, 2492 num_shards=num_hosts, 2493 shard_id=host_id, 2494 config=config)) 2495 retrieve_parameters_op = control_flow_ops.group( 2496 state_ops.assign(table_variable, retrieved_table), 2497 state_ops.assign(accumulator_variable, retrieved_accumulator), 2498 state_ops.assign(linear_variable, retrieved_linear)) 2499 config = None 2500 retrieve_op_list.append(retrieve_parameters_op) 2501 return retrieve_op_list 2502 2503 return slot_variables, load_ops_fn, retrieve_ops_fn 2504 2505 2506class _ProximalYogiHandler(_OptimizerHandler): 2507 """Handles Proximal Yogi specific logic.""" 2508 2509 def set_optimization_parameters(self, table_descriptor): 2510 table_descriptor.optimization_parameters.proximal_yogi.SetInParent() 2511 table_descriptor.optimization_parameters.proximal_yogi.beta1 = ( 2512 self._optimization_parameters.beta1) 2513 table_descriptor.optimization_parameters.proximal_yogi.beta2 = ( 2514 self._optimization_parameters.beta2) 2515 table_descriptor.optimization_parameters.proximal_yogi.epsilon = ( 2516 self._optimization_parameters.epsilon) 2517 table_descriptor.optimization_parameters.proximal_yogi.l1 = ( 2518 self._optimization_parameters.l1_regularization_strength) 2519 table_descriptor.optimization_parameters.proximal_yogi.l2 = ( 2520 self._optimization_parameters.l2_regularization_strength) 2521 2522 def get_default_slot_variable_names(self, table): 2523 return ProximalYogiSlotVariableNames( 2524 '{}/{}'.format(table, 'ProximalYogi'), # v 2525 '{}/{}_1'.format(table, 'ProximalYogi')) # m 2526 2527 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2528 table_config, table_variables, config_proto): 2529 v_initializer = init_ops.constant_initializer( 2530 self._optimization_parameters.initial_accumulator_value) 2531 v_variables = _create_partitioned_variables( 2532 name=slot_variable_names.v, 2533 num_hosts=num_hosts, 2534 vocabulary_size=table_config.vocabulary_size, 2535 embedding_dimension=table_config.dimension, 2536 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2537 initializer=v_initializer) 2538 m_initializer = init_ops.zeros_initializer() 2539 m_variables = _create_partitioned_variables( 2540 name=slot_variable_names.m, 2541 num_hosts=num_hosts, 2542 vocabulary_size=table_config.vocabulary_size, 2543 
embedding_dimension=table_config.dimension, 2544 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2545 initializer=m_initializer) 2546 slot_variables = ProximalYogiSlotVariables(v_variables, m_variables) 2547 2548 def load_ops_fn(): 2549 """Returns the load ops for Proximal Yogi embedding tables. 2550 2551 Returns: 2552 A list of ops to load embedding and slot variables from CPU to TPU. 2553 """ 2554 load_op_list = [] 2555 config = config_proto 2556 for host_id, table_variable, v_variable, m_variable in (zip( 2557 range(num_hosts), table_variables, v_variables, m_variables)): 2558 with ops.colocate_with(table_variable): 2559 load_parameters_op = ( 2560 tpu_ops.load_tpu_embedding_proximal_yogi_parameters( 2561 parameters=table_variable, 2562 v=v_variable, 2563 m=m_variable, 2564 table_name=table, 2565 num_shards=num_hosts, 2566 shard_id=host_id, 2567 config=config)) 2568 # Set config to None to enforce that config is only loaded to the first 2569 # table. 2570 config = None 2571 load_op_list.append(load_parameters_op) 2572 return load_op_list 2573 2574 def retrieve_ops_fn(): 2575 """Returns the retrieve ops for Proximal Yogi embedding tables. 2576 2577 Returns: 2578 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2579 """ 2580 retrieve_op_list = [] 2581 config = config_proto 2582 for host_id, table_variable, v_variable, m_variable in (zip( 2583 range(num_hosts), table_variables, v_variables, m_variables)): 2584 with ops.colocate_with(table_variable): 2585 retrieved_table, retrieved_v, retrieved_m = ( 2586 tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters( 2587 table_name=table, 2588 num_shards=num_hosts, 2589 shard_id=host_id, 2590 config=config)) 2591 retrieve_parameters_op = control_flow_ops.group( 2592 state_ops.assign(table_variable, retrieved_table), 2593 state_ops.assign(v_variable, retrieved_v), 2594 state_ops.assign(m_variable, retrieved_m)) 2595 config = None 2596 retrieve_op_list.append(retrieve_parameters_op) 2597 return retrieve_op_list 2598 2599 return slot_variables, load_ops_fn, retrieve_ops_fn 2600 2601 2602class _MomentumHandler(_OptimizerHandler): 2603 """Handles Momentum specific logic.""" 2604 2605 def set_optimization_parameters(self, table_descriptor): 2606 (table_descriptor.optimization_parameters.momentum.SetInParent()) 2607 table_descriptor.optimization_parameters.momentum.momentum = ( 2608 self._optimization_parameters.momentum) 2609 table_descriptor.optimization_parameters.momentum.use_nesterov = ( 2610 self._optimization_parameters.use_nesterov) 2611 2612 def get_default_slot_variable_names(self, table): 2613 return MomentumSlotVariableNames('{}/{}'.format(table, 'Momentum')) 2614 2615 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2616 table_config, table_variables, config_proto): 2617 2618 momenta_initializer = init_ops.zeros_initializer() 2619 momenta_variables = _create_partitioned_variables( 2620 name=slot_variable_names.momenta, 2621 num_hosts=num_hosts, 2622 vocabulary_size=table_config.vocabulary_size, 2623 embedding_dimension=table_config.dimension, 2624 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2625 initializer=momenta_initializer) 2626 slot_variables = MomentumSlotVariables(momenta_variables) 2627 2628 def load_ops_fn(): 2629 """Returns the retrieve ops for Momentum embedding tables. 2630 2631 Returns: 2632 A list of ops to load embedding and slot variables from CPU to TPU. 
2633 """ 2634 load_op_list = [] 2635 config = config_proto 2636 for host_id, table_variable, momenta_variable in (zip( 2637 range(num_hosts), table_variables, momenta_variables)): 2638 with ops.colocate_with(table_variable): 2639 load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters( 2640 parameters=table_variable, 2641 momenta=momenta_variable, 2642 table_name=table, 2643 num_shards=num_hosts, 2644 shard_id=host_id, 2645 config=config, 2646 ) 2647 config = None 2648 load_op_list.append(load_parameters_op) 2649 return load_op_list 2650 2651 def retrieve_ops_fn(): 2652 """Returns the retrieve ops for Momentum embedding tables. 2653 2654 Returns: 2655 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2656 """ 2657 retrieve_op_list = [] 2658 config = config_proto 2659 for host_id, table_variable, momenta_variable in (zip( 2660 range(num_hosts), table_variables, momenta_variables)): 2661 with ops.colocate_with(table_variable): 2662 retrieved_table, retrieved_momenta = ( 2663 tpu_ops.retrieve_tpu_embedding_momentum_parameters( 2664 table_name=table, 2665 num_shards=num_hosts, 2666 shard_id=host_id, 2667 config=config, 2668 )) 2669 retrieve_parameters_op = control_flow_ops.group( 2670 state_ops.assign(table_variable, retrieved_table), 2671 state_ops.assign(momenta_variable, retrieved_momenta)) 2672 config = None 2673 retrieve_op_list.append(retrieve_parameters_op) 2674 return retrieve_op_list 2675 2676 return slot_variables, load_ops_fn, retrieve_ops_fn 2677 2678 2679class _RMSPropHandler(_OptimizerHandler): 2680 """Handles RMS prop specific logic.""" 2681 2682 def set_optimization_parameters(self, table_descriptor): 2683 (table_descriptor.optimization_parameters.rms_prop.SetInParent()) 2684 table_descriptor.optimization_parameters.rms_prop.rho = ( 2685 self._optimization_parameters.rho) 2686 table_descriptor.optimization_parameters.rms_prop.epsilon = ( 2687 self._optimization_parameters.epsilon) 2688 table_descriptor.optimization_parameters.rms_prop.momentum = ( 2689 self._optimization_parameters.momentum) 2690 2691 def get_default_slot_variable_names(self, table): 2692 return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'), 2693 '{}/{}/mom'.format(table, 'RMSProp')) 2694 2695 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2696 table_config, table_variables, config_proto): 2697 2698 ms_variables = _create_partitioned_variables( 2699 name=slot_variable_names.ms, 2700 num_hosts=num_hosts, 2701 vocabulary_size=table_config.vocabulary_size, 2702 embedding_dimension=table_config.dimension, 2703 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2704 initializer=init_ops.zeros_initializer(), 2705 ) 2706 mom_variables = _create_partitioned_variables( 2707 name=slot_variable_names.mom, 2708 num_hosts=num_hosts, 2709 vocabulary_size=table_config.vocabulary_size, 2710 embedding_dimension=table_config.dimension, 2711 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2712 initializer=init_ops.zeros_initializer(), 2713 ) 2714 slot_variables = RMSPropSlotVariables(ms_variables, mom_variables) 2715 2716 def load_ops_fn(): 2717 """Returns the retrieve ops for RMS Prop embedding tables. 2718 2719 Returns: 2720 A list of ops to load embedding and slot variables from CPU to TPU. 
2721 """ 2722 load_op_list = [] 2723 config = config_proto 2724 for host_id, table_variable, ms_variable, mom_variable in (zip( 2725 range(num_hosts), table_variables, ms_variables, mom_variables)): 2726 with ops.colocate_with(table_variable): 2727 load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters( 2728 parameters=table_variable, 2729 ms=ms_variable, 2730 mom=mom_variable, 2731 table_name=table, 2732 num_shards=num_hosts, 2733 shard_id=host_id, 2734 config=config, 2735 ) 2736 config = None 2737 load_op_list.append(load_parameters_op) 2738 return load_op_list 2739 2740 def retrieve_ops_fn(): 2741 """Returns the retrieve ops for RMS Prop embedding tables. 2742 2743 Returns: 2744 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2745 """ 2746 retrieve_op_list = [] 2747 config = config_proto 2748 for host_id, table_variable, ms_variable, mom_variable in (zip( 2749 range(num_hosts), table_variables, ms_variables, mom_variables)): 2750 with ops.colocate_with(table_variable): 2751 retrieved_table, retrieved_ms, retrieved_mom = ( 2752 tpu_ops.retrieve_tpu_embedding_rms_prop_parameters( 2753 table_name=table, 2754 num_shards=num_hosts, 2755 shard_id=host_id, 2756 config=config, 2757 )) 2758 retrieve_parameters_op = control_flow_ops.group( 2759 state_ops.assign(table_variable, retrieved_table), 2760 state_ops.assign(ms_variable, retrieved_ms), 2761 state_ops.assign(mom_variable, retrieved_mom)) 2762 config = None 2763 retrieve_op_list.append(retrieve_parameters_op) 2764 return retrieve_op_list 2765 2766 return slot_variables, load_ops_fn, retrieve_ops_fn 2767 2768 2769class _FrequencyEstimatorHandler(_OptimizerHandler): 2770 """Handles frequency estimator specific logic.""" 2771 2772 def set_optimization_parameters(self, table_descriptor): 2773 table_descriptor.optimization_parameters.frequency_estimator.SetInParent() 2774 freq = table_descriptor.optimization_parameters.frequency_estimator 2775 freq.tau = self._optimization_parameters.tau 2776 freq.max_delta = self._optimization_parameters.max_delta 2777 freq.outlier_threshold = self._optimization_parameters.outlier_threshold 2778 freq.weight_exponent = self._optimization_parameters.weight_exponent 2779 2780 def get_default_slot_variable_names(self, table): 2781 return FrequencyEstimatorSlotVariableNames( 2782 '{}/FrequencyEstimator'.format(table)) 2783 2784 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2785 table_config, table_variables, config_proto): 2786 if table_config.dimension != 1: 2787 raise ValueError('FrequencyEstimator tables should only have a dimension ' 2788 'of 1. Received dimension {}'.format( 2789 table_config.dimension)) 2790 2791 last_hit_step_variables = _create_partitioned_variables( 2792 name=slot_variable_names.last_hit_step, 2793 num_hosts=num_hosts, 2794 vocabulary_size=table_config.vocabulary_size, 2795 embedding_dimension=table_config.dimension, 2796 collections=[ops.GraphKeys.GLOBAL_VARIABLES], 2797 initializer=init_ops.zeros_initializer(), 2798 ) 2799 slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables) 2800 2801 def load_ops_fn(): 2802 """Returns the retrieve ops for Frequency Estimator embedding tables. 2803 2804 Returns: 2805 A list of ops to load embedding and slot variables from CPU to TPU. 
2806 """ 2807 load_op_list = [] 2808 config = config_proto 2809 for host_id, table_variable, last_hit_step_variable in (zip( 2810 range(num_hosts), table_variables, last_hit_step_variables)): 2811 with ops.colocate_with(table_variable): 2812 load_parameters_op = ( 2813 tpu_ops.load_tpu_embedding_frequency_estimator_parameters( 2814 parameters=table_variable, 2815 last_hit_step=last_hit_step_variable, 2816 table_name=table, 2817 num_shards=num_hosts, 2818 shard_id=host_id, 2819 config=config)) 2820 config = None 2821 load_op_list.append(load_parameters_op) 2822 return load_op_list 2823 2824 def retrieve_ops_fn(): 2825 """Returns the retrieve ops for Frequency Estimator embedding tables. 2826 2827 Returns: 2828 A list of ops to retrieve embedding and slot variables from TPU to CPU. 2829 """ 2830 retrieve_op_list = [] 2831 config = config_proto 2832 for host_id, table_variable, last_hit_step_variable in (zip( 2833 range(num_hosts), table_variables, last_hit_step_variables)): 2834 with ops.colocate_with(table_variable): 2835 retrieved_table, retrieved_last_hit_step = ( 2836 tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters( 2837 table_name=table, 2838 num_shards=num_hosts, 2839 shard_id=host_id, 2840 config=config, 2841 )) 2842 retrieve_parameters_op = control_flow_ops.group( 2843 state_ops.assign(table_variable, retrieved_table), 2844 state_ops.assign(last_hit_step_variable, retrieved_last_hit_step)) 2845 config = None 2846 retrieve_op_list.append(retrieve_parameters_op) 2847 return retrieve_op_list 2848 2849 return slot_variables, load_ops_fn, retrieve_ops_fn 2850 2851 2852class _StochasticGradientDescentHandler(_OptimizerHandler): 2853 """Handles stochastic gradient descent specific logic.""" 2854 2855 def set_optimization_parameters(self, table_descriptor): 2856 (table_descriptor.optimization_parameters.stochastic_gradient_descent 2857 .SetInParent()) 2858 2859 def get_default_slot_variable_names(self, table): 2860 return None 2861 2862 def create_variables_and_ops(self, table, slot_variable_names, num_hosts, 2863 table_config, table_variables, config_proto): 2864 del table_config 2865 2866 def load_ops_fn(): 2867 """Returns the retrieve ops for AdaGrad embedding tables. 2868 2869 Returns: 2870 A list of ops to load embedding and slot variables from CPU to TPU. 2871 """ 2872 load_op_list = [] 2873 config = config_proto 2874 for host_id, table_variable in enumerate(table_variables): 2875 with ops.colocate_with(table_variable): 2876 load_parameters_op = ( 2877 tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters( 2878 parameters=table_variable, 2879 table_name=table, 2880 num_shards=num_hosts, 2881 shard_id=host_id, 2882 config=config)) 2883 config = None 2884 load_op_list.append(load_parameters_op) 2885 return load_op_list 2886 2887 def retrieve_ops_fn(): 2888 """Returns the retrieve ops for SGD embedding tables. 2889 2890 Returns: 2891 A list of ops to retrieve embedding and slot variables from TPU to CPU. 
2892 """ 2893 retrieve_op_list = [] 2894 config = config_proto 2895 for host_id, table_variable in enumerate(table_variables): 2896 with ops.colocate_with(table_variable): 2897 retrieved_table = ( 2898 tpu_ops 2899 .retrieve_tpu_embedding_stochastic_gradient_descent_parameters( 2900 table_name=table, 2901 num_shards=num_hosts, 2902 shard_id=host_id, 2903 config=config)) 2904 retrieve_parameters_op = control_flow_ops.group( 2905 state_ops.assign(table_variable, retrieved_table)) 2906 config = None 2907 retrieve_op_list.append(retrieve_parameters_op) 2908 return retrieve_op_list 2909 2910 return None, load_ops_fn, retrieve_ops_fn 2911 2912 2913def _get_optimization_handler(optimization_parameters): 2914 """Gets the optimization handler given the parameter type.""" 2915 if isinstance(optimization_parameters, AdagradParameters): 2916 return _AdagradHandler(optimization_parameters) 2917 elif isinstance(optimization_parameters, AdagradMomentumParameters): 2918 return _AdagradMomentumHandler(optimization_parameters) 2919 elif isinstance(optimization_parameters, ProximalAdagradParameters): 2920 return _ProximalAdagradHandler(optimization_parameters) 2921 elif isinstance(optimization_parameters, AdamParameters): 2922 return _AdamHandler(optimization_parameters) 2923 elif isinstance(optimization_parameters, FtrlParameters): 2924 return _FtrlHandler(optimization_parameters) 2925 elif isinstance(optimization_parameters, ProximalYogiParameters): 2926 return _ProximalYogiHandler(optimization_parameters) 2927 elif isinstance(optimization_parameters, StochasticGradientDescentParameters): 2928 return _StochasticGradientDescentHandler(optimization_parameters) 2929 elif isinstance(optimization_parameters, MomentumParameters): 2930 return _MomentumHandler(optimization_parameters) 2931 elif isinstance(optimization_parameters, RMSPropParameters): 2932 return _RMSPropHandler(optimization_parameters) 2933 elif isinstance(optimization_parameters, FrequencyEstimatorParameters): 2934 return _FrequencyEstimatorHandler(optimization_parameters) 2935 return NotImplementedError() 2936 2937 2938def _create_ordered_dict(d): 2939 """Create an OrderedDict from Dict.""" 2940 return collections.OrderedDict((k, d[k]) for k in sorted(d)) 2941 2942 2943def _create_combiners(table_to_config_dict, table_to_features_dict): 2944 """Create a per feature list of combiners, ordered by table.""" 2945 combiners = [] 2946 for table in table_to_config_dict: 2947 combiner = table_to_config_dict[table].combiner or 'sum' 2948 combiners.extend([combiner] * len(table_to_features_dict[table])) 2949 return combiners 2950 2951 2952def _create_table_to_features_dict(feature_to_config_dict): 2953 """Create mapping from table to a list of its features.""" 2954 table_to_features_dict_tmp = {} 2955 for feature, feature_config in feature_to_config_dict.items(): 2956 if feature_config.table_id in table_to_features_dict_tmp: 2957 table_to_features_dict_tmp[feature_config.table_id].append(feature) 2958 else: 2959 table_to_features_dict_tmp[feature_config.table_id] = [feature] 2960 2961 table_to_features_dict = collections.OrderedDict() 2962 for table in sorted(table_to_features_dict_tmp): 2963 table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table]) 2964 return table_to_features_dict 2965 2966 2967def _create_device_fn(hosts): 2968 """Create device_fn() to use with _create_partitioned_variables().""" 2969 2970 def device_fn(op): 2971 """Returns the `device` for `op`.""" 2972 part_match = re.match(r'.*/part_(\d+)(/|$)', op.name) 2973 
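    # `part_<idx>` names come from the fixed-size partitioner used in
    # `_create_partitioned_variables()`, while `dummy_<idx>` names come from
    # the padding variables it adds when vocabulary_size < num_hosts; in both
    # cases the index selects the host that owns the shard.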
dummy_match = re.match(r'.*dummy_(\d+).*', op.name) 2974 if not part_match and not dummy_match: 2975 raise RuntimeError( 2976 'Internal Error: Expected {} to contain /part_* or dummy_*'.format( 2977 op.name)) 2978 2979 if part_match: 2980 idx = int(part_match.group(1)) 2981 else: 2982 idx = int(dummy_match.group(1)) # pytype: disable=attribute-error 2983 2984 device = hosts[idx] 2985 logging.debug('assigning {} to {}.', op, device) 2986 return device 2987 2988 return device_fn 2989 2990 2991def _create_partitioned_variables(name, 2992 num_hosts, 2993 vocabulary_size, 2994 embedding_dimension, 2995 initializer, 2996 collections=None): # pylint: disable=redefined-outer-name 2997 """Creates PartitionedVariables based on `num_hosts` for `table`.""" 2998 2999 num_slices = min(vocabulary_size, num_hosts) 3000 3001 var_list = list( 3002 variable_scope.get_variable( 3003 name, 3004 shape=(vocabulary_size, embedding_dimension), 3005 partitioner=partitioned_variables.fixed_size_partitioner(num_slices), 3006 dtype=dtypes.float32, 3007 initializer=initializer, 3008 collections=collections, 3009 trainable=False)) 3010 3011 if vocabulary_size >= num_hosts: 3012 return var_list 3013 3014 # For padded part, define the dummy variable to be loaded into TPU system. 3015 for idx in range(num_hosts - vocabulary_size): 3016 var_list.append( 3017 variable_scope.get_variable( 3018 'dummy_{}_{}'.format(vocabulary_size + idx, name), 3019 shape=(1, embedding_dimension), 3020 dtype=dtypes.float32, 3021 initializer=initializer, 3022 collections=[ops.GraphKeys.LOCAL_VARIABLES], 3023 trainable=False)) 3024 3025 return var_list 3026
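

# The function below is an illustrative sketch, not part of the public API:
# it shows how the classes above are commonly wired together for training.
# The table/feature names, sizes and the SGD learning rate are arbitrary
# placeholder values, and `master` must point at a real TPU worker.
def _example_tpu_embedding_setup(master, num_cores):
  """Builds a `TPUEmbedding` for one table/feature pair (example only)."""
  table_to_config_dict = {
      'video': TableConfig(vocabulary_size=1000, dimension=16),
  }
  feature_to_config_dict = {
      'watched': FeatureConfig(table_id='video'),
  }
  # The global batch size must be a multiple of the number of TPU cores.
  batch_size = 8 * num_cores
  embedding = TPUEmbedding(
      table_to_config_dict,
      feature_to_config_dict,
      batch_size,
      mode=TRAINING,
      master=master,
      optimization_parameters=StochasticGradientDescentParameters(0.1))
  # `embedding.config_proto` is then passed to `tpu.initialize_system()`, and
  # `embedding.create_variables_and_ops()` creates the host-side tables plus
  # the load/retrieve ops used around the training loop.
  return embedding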