1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TPU embedding APIs."""
16
17import collections
18import copy
19import math
20import re
21from typing import Optional
22
23from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
24from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
25from tensorflow.python.eager import context
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import control_flow_ops
30from tensorflow.python.ops import init_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import partitioned_variables
33from tensorflow.python.ops import state_ops
34from tensorflow.python.ops import variable_scope
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
37from tensorflow.python.tpu.ops import tpu_ops
38from tensorflow.python.util.tf_export import tf_export
39
40TRAINING = elc.TPUEmbeddingConfiguration.TRAINING
41INFERENCE = elc.TPUEmbeddingConfiguration.INFERENCE
42
43
44# TODO(shizhiw): a more future-proof way is to have optimization_parameters such
45#  as AdagradParameters etc. instead of learning_rate.
46class TableConfig(
47    collections.namedtuple('TableConfig', [
48        'vocabulary_size',
49        'dimension',
50        'initializer',
51        'combiner',
52        'hot_id_replication',
53        'learning_rate',
54        'learning_rate_fn',
55        'optimization_parameters',
56    ])):
57  """Embedding table configuration."""
58
59  def __new__(cls,
60              vocabulary_size,
61              dimension,
62              initializer=None,
63              combiner='mean',
64              hot_id_replication=False,
65              learning_rate=None,
66              learning_rate_fn=None,
67              optimization_parameters=None):
68    """Embedding table configuration.
69
70    Args:
71      vocabulary_size: Number of vocabulary entries (rows) in the table.
72      dimension: The embedding dimension.
73      initializer: A variable initializer function to be used in embedding
74        variable initialization. If not specified, defaults to
75        `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and standard
76        deviation `1/sqrt(dimension)`.
77      combiner: A string specifying how to reduce if there are multiple entries
78        in a single row. Currently 'mean', 'sqrtn', 'sum' and None are
79        supported, with 'mean' the default. 'sqrtn' often achieves good
80        accuracy, in particular with bag-of-words columns. For more information,
81        see `tf.nn.embedding_lookup_sparse`. None is only valid for dense rather
82        than sparse tensors.
83      hot_id_replication: If true, enables hot id replication, which can make
84        embedding lookups faster if there are some hot rows in the table.
85      learning_rate: float, static learning rate for this table. If
86        learning_rate and learning_rate_fn are both `None`, static learning rate
87        as specified in local `optimization_parameters` will be used. In case
88        local `optimization_parameters` is `None`, global
89        `optimization_parameters` in `TPUEmbedding` constructor will be used.
90        `learning_rate_fn` must be `None` if `learning_rate` is not `None`.
91      learning_rate_fn: A function that returns the dynamic learning rate. It
92        will be passed the current global step. If learning_rate and
93        learning_rate_fn are both `None`, the static learning rate as specified
94        in `optimization_parameters` is used. `learning_rate` must be `None` if
95        `learning_rate_fn` is not `None`.
96      optimization_parameters: `AdagradParameters`, `AdamParameters`, or
97        `StochasticGradientDescentParameters`. Specifies the table-level optimizer.
98        If `None`, the global optimizer in the `TPUEmbedding` constructor is used.
99
100    Returns:
101      `TableConfig`.
102
103    Raises:
104      ValueError: if `vocabulary_size` is not a positive integer.
105      ValueError: if `dimension` is not a positive integer.
106      ValueError: if `initializer` is specified and is not callable.
107      ValueError: if `combiner` is not supported.
108      ValueError: if `learning_rate` and `learning_rate_fn` are both not
109        `None`.
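
    For illustration, a minimal sketch of a table with a table-level Adagrad
    optimizer, assuming this module is imported as `tpu_embedding` (the values
    shown are placeholders, not tuned recommendations):

    ```
    table_config = tpu_embedding.TableConfig(
        vocabulary_size=100000,
        dimension=64,
        combiner='mean',
        optimization_parameters=tpu_embedding.AdagradParameters(
            learning_rate=0.1, initial_accumulator=0.1))
    ```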
110    """
111    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
112      raise ValueError(f'vocabulary_size must be >= 1. '
113                       f'Received: {vocabulary_size}.')
114
115    if not isinstance(dimension, int) or dimension < 1:
116      raise ValueError(
117          f'dimension must be a positive int. Received: {dimension}.')
118
119    if (initializer is not None) and (not callable(initializer)):
120      raise ValueError(f'initializer must be callable if specified. '
121                       f'Received: {initializer}.')
122    if initializer is None:
123      initializer = init_ops.truncated_normal_initializer(
124          mean=0.0, stddev=1 / math.sqrt(dimension))
125
126    if combiner not in ('mean', 'sum', 'sqrtn', None):
127      raise ValueError(f'combiner must be "mean", "sum", "sqrtn" or None. '
128                       f'Received: {combiner}.')
129
130    if learning_rate is not None and learning_rate_fn is not None:
131      raise ValueError('At most one of learning_rate and learning_rate_fn '
132                       'can be set. Received: {} and {}'.format(
133                           learning_rate, learning_rate_fn))
134
135    if optimization_parameters is not None:
136      if not isinstance(optimization_parameters, _OptimizationParameters):
137        raise ValueError(f'`optimization_parameters` must inherit from '
138                         f'`_OptimizationParameters`. '
139                         f'Received: `type(optimization_parameters)`='
140                         f'{type(optimization_parameters)}.')
141
142    return super().__new__(cls, vocabulary_size, dimension, initializer,
143                           combiner, hot_id_replication, learning_rate,
144                           learning_rate_fn, optimization_parameters)
145
146
147class FeatureConfig(
148    collections.namedtuple('FeatureConfig',
149                           ['table_id', 'max_sequence_length', 'weight_key'])):
150  """Feature configuration."""
151
152  def __new__(cls, table_id, max_sequence_length=0, weight_key=None):
153    """Feature configuration.
154
155    Args:
156      table_id: Which table the feature uses for embedding lookups.
157      max_sequence_length: If positive, the feature is a sequence feature with
158        the corresponding maximum sequence length. If the sequence is longer
159        than this, it will be truncated. If 0, the feature is not a sequence
160        feature.
161      weight_key: If using weights for the combiner, this key specifies which
162        input feature contains the weights.
163
164    Returns:
165      `FeatureConfig`.
166
167    Raises:
168      ValueError: if `max_sequence_length` is not an integer or is negative.
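
    For illustration, a minimal sketch of a sequence feature, assuming this
    module is imported as `tpu_embedding` and a table named 'video' exists in
    `table_to_config_dict` (both names are placeholders):

    ```
    feature_config = tpu_embedding.FeatureConfig(
        table_id='video', max_sequence_length=10)
    ```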
169    """
170    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
171      raise ValueError(f'max_sequence_length must be zero or a positive int, '
172                       f'got {max_sequence_length}.')
173
174    return super().__new__(cls, table_id, max_sequence_length, weight_key)
175
176
177class EnqueueData(
178    collections.namedtuple(
179        'EnqueueData',
180        ['embedding_indices', 'sample_indices', 'aggregation_weights'])):
181  """Data to be enqueued through generate_enqueue_ops()."""
182
183  def __new__(cls,
184              embedding_indices,
185              sample_indices=None,
186              aggregation_weights=None):
187    """Data to be enqueued through generate_enqueue_ops().
188
189    Args:
190      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
191        corresponds to sp_ids.values in embedding_lookup_sparse(). Both int32
192        and int64 are allowed and will be converted to int32 internally.
193      sample_indices: A rank 2 Tensor specifying the training example to which
194        the corresponding embedding_indices and aggregation_weights values
195        belong. It corresponds to sp_ids.indices in embedding_lookup_sparse().
196        If it is None, we assume each embedding_indices belongs to a different
197        sample. Both int32 and int64 are allowed and will be converted to int32
198        internally.
199      aggregation_weights: A rank 1 Tensor containing aggregation weights. It
200        corresponds to sp_weights.values in embedding_lookup_sparse(). If it is
201        None, we assume all weights are 1. Both float32 and float64 are allowed
202        and will be converted to float32 internally.
203
204    Returns:
205      An EnqueueData tuple.
206
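    For illustration, a minimal sketch that wraps a `tf.sparse.SparseTensor` of
    ids via `from_sparse_tensor`, assuming this module is imported as
    `tpu_embedding` (the values are placeholders):

    ```
    sp_ids = tf.sparse.SparseTensor(
        indices=[[0, 0], [1, 0], [1, 1]],
        values=[3, 1, 4],
        dense_shape=[2, 2])
    enqueue_data = tpu_embedding.EnqueueData.from_sparse_tensor(sp_ids)
    ```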
207    """
208    return super().__new__(cls, embedding_indices, sample_indices,
209                           aggregation_weights)
210
211  @staticmethod
212  def from_sparse_tensor(sp_tensor, weights=None):
213    return EnqueueData(
214        sp_tensor.values,
215        sp_tensor.indices,
216        aggregation_weights=weights.values if weights is not None else None)
217
218
219class RaggedEnqueueData(
220    collections.namedtuple(
221        'RaggedEnqueueData',
222        ['embedding_indices', 'row_splits', 'aggregation_weights'])):
223  """RaggedTensor Data to be enqueued through generate_enqueue_ops()."""
224
225  def __new__(cls,
226              embedding_indices,
227              row_splits=None,
228              aggregation_weights=None):
229    """Data to be enqueued through generate_enqueue_ops().
230
231    Args:
232      embedding_indices: A rank 1 Tensor, indices into the embedding tables. It
233        corresponds to ids.values in embedding_lookup(), when ids is a
234        RaggedTensor. Both int32 and int64 are allowed and will be converted to
235        int32 internally.
236      row_splits: A rank 1 Tensor specifying the break points for
237        splitting embedding_indices and aggregation_weights. It corresponds to
238        ids.row_splits in embedding_lookup(), when ids is a RaggedTensor. Both
239        int32 and int64 are allowed and will be converted to int32 internally.
240      aggregation_weights: A rank 1 Tensor containing per training example
241        aggregation weights. It corresponds to the values field of a
242        RaggedTensor with the same row_splits as ids in embedding_lookup(), when
243        ids is a RaggedTensor.
244
245    Returns:
246      A RaggedEnqueueData tuple.
247
248    """
249    return super().__new__(cls, embedding_indices, row_splits,
250                           aggregation_weights)
251
252  @staticmethod
253  def from_ragged_tensor(rg_tensor, weights=None):
254    return RaggedEnqueueData(
255        rg_tensor.values,
256        rg_tensor.row_splits,
257        aggregation_weights=weights.values if weights is not None else None)
258
259
260def get_enqueue_datas_list_from_sparse_tensors_list(sp_tensors_list):
261  """Convenient function for generate_enqueue_ops().
262
263  Args:
264    sp_tensors_list: a list of dictionaries mapping feature name strings to
265      SparseTensor. Each dictionary is for one TPU core. Dictionaries for the
266      same host should be contiguous in the list.
267
268  Returns:
269    enqueue_datas_list: a list of dictionaries mapping feature name
270      strings to EnqueueData. Each dictionary is for one
271      TPU core. Dictionaries for the same host should be contiguous
272      in the list.
273
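  For illustration, a minimal sketch that builds one dictionary per TPU core,
  assuming this module is imported as `tpu_embedding`; `num_cores` and the
  'watched' feature name are placeholders:

  ```
  sp_tensors_list = [
      {'watched': tf.sparse.SparseTensor(
          indices=[[0, 0]], values=[7], dense_shape=[2, 1])}
      for _ in range(num_cores)
  ]
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_sparse_tensors_list(
          sp_tensors_list))
  ```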
274  """
275  enqueue_datas_list = []
276  for sp_tensors in sp_tensors_list:
277    enqueue_datas = collections.OrderedDict(
278        (k, EnqueueData.from_sparse_tensor(v)) for k, v in sp_tensors.items())
279    enqueue_datas_list.append(enqueue_datas)
280  return enqueue_datas_list
281
282
283def get_enqueue_datas_list_from_ragged_tensors_list(rg_tensors_list):
284  """Convenient function for generate_enqueue_ops().
285
286  Args:
287    rg_tensors_list: a list of dictionaries mapping feature name strings to
288      RaggedTensor. Each dictionary is for one TPU core. Dictionaries for the
289      same host should be contiguous in the list.
290
291  Returns:
292    enqueue_datas_list: a list of dictionaries mapping feature name
293      strings to RaggedEnqueueData. Each dictionary is for one
294      TPU core. Dictionaries for the same host should be contiguous
295      in the list.
296
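  For illustration, a minimal sketch mirroring the sparse variant above, using
  `tf.ragged.constant`; `num_cores` and the 'watched' feature name are
  placeholders:

  ```
  rg_tensors_list = [
      {'watched': tf.ragged.constant([[3, 1], [4]])}
      for _ in range(num_cores)
  ]
  enqueue_datas_list = (
      tpu_embedding.get_enqueue_datas_list_from_ragged_tensors_list(
          rg_tensors_list))
  ```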
297  """
298  enqueue_datas_list = []
299  for rg_tensors in rg_tensors_list:
300    enqueue_datas = collections.OrderedDict(
301        (k, RaggedEnqueueData.from_ragged_tensor(v))
302        for k, v in rg_tensors.items())
303    enqueue_datas_list.append(enqueue_datas)
304  return enqueue_datas_list
305
306
307AdamSlotVariableNames = collections.namedtuple('AdamSlotVariableNames',
308                                               ['m', 'v'])
309
310AdagradSlotVariableNames = collections.namedtuple('AdagradSlotVariableNames',
311                                                  ['accumulator'])
312
313MomentumSlotVariableNames = collections.namedtuple('MomentumSlotVariableNames',
314                                                   ['momenta'])
315
316AdagradMomentumSlotVariableNames = collections.namedtuple(
317    'AdagradMomentumSlotVariableNames', ['accumulator', 'momenta'])
318
319RMSPropSlotVariableNames = collections.namedtuple('RMSPropSlotVariableNames',
320                                                  ['ms', 'mom'])
321
322ProximalAdagradSlotVariableNames = collections.namedtuple(
323    'ProximalAdagradSlotVariableNames', ['accumulator'])
324
325FtrlSlotVariableNames = collections.namedtuple('FtrlSlotVariableNames',
326                                               ['accumulator', 'linear'])
327
328ProximalYogiSlotVariableNames = collections.namedtuple(
329    'ProximalYogiSlotVariableNames', ['v', 'm'])
330
331FrequencyEstimatorSlotVariableNames = collections.namedtuple(
332    'FrequencyEstimatorSlotVariableNames', ['last_hit_step'])
333
334AdamSlotVariables = collections.namedtuple('AdamSlotVariables', ['m', 'v'])
335
336MomentumSlotVariables = collections.namedtuple('MomentumSlotVariables',
337                                               ['momenta'])
338
339AdagradMomentumSlotVariables = collections.namedtuple(
340    'AdagradMomentumSlotVariables', ['accumulator', 'momenta'])
341
342RMSPropSlotVariables = collections.namedtuple('RMSPropSlotVariables',
343                                              ['ms', 'mom'])
344
345AdagradSlotVariables = collections.namedtuple('AdagradSlotVariables',
346                                              ['accumulator'])
347
348ProximalAdagradSlotVariables = collections.namedtuple(
349    'ProximalAdagradSlotVariables', ['accumulator'])
350
351FtrlSlotVariable = collections.namedtuple('FtrlSlotVariable',
352                                          ['accumulator', 'linear'])
353
354ProximalYogiSlotVariables = collections.namedtuple('ProximalYogiSlotVariables',
355                                                   ['v', 'm'])
356
357FrequencyEstimatorSlotVariables = collections.namedtuple(
358    'FrequencyEstimatorSlotVariables', ['last_hit_step'])
359
360VariablesAndOps = collections.namedtuple('VariablesAndOps', [
361    'embedding_variables_by_table', 'slot_variables_by_table', 'load_ops',
362    'retrieve_ops'
363])
364
365
366class _OptimizationParameters:
367  """Parameters common to all optimizations."""
368
369  def __init__(
370      self,
371      learning_rate: float,
372      use_gradient_accumulation: bool,
373      clip_weight_min: Optional[float],
374      clip_weight_max: Optional[float],
375      weight_decay_factor: Optional[float],
376      multiply_weight_decay_factor_by_learning_rate: Optional[bool],
377      clip_gradient_min: Optional[float] = None,
378      clip_gradient_max: Optional[float] = None,
379  ):
380    self.learning_rate = learning_rate
381    self.use_gradient_accumulation = use_gradient_accumulation
382    self.clip_weight_min = clip_weight_min
383    self.clip_weight_max = clip_weight_max
384    self.weight_decay_factor = weight_decay_factor
385    self.multiply_weight_decay_factor_by_learning_rate = (
386        multiply_weight_decay_factor_by_learning_rate)
387    self.clip_gradient_min = clip_gradient_min
388    self.clip_gradient_max = clip_gradient_max
389
390    if not use_gradient_accumulation and (clip_gradient_min is not None or
391                                          clip_gradient_max is not None):
392      raise ValueError('When using gradient clipping limits, gradient '
393                       'accumulation must be enabled.')
394
395
396@tf_export(v1=['tpu.experimental.AdagradParameters'])
397class AdagradParameters(_OptimizationParameters):
398  """Optimization parameters for Adagrad with TPU embeddings.
399
400  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
401  `optimization_parameters` argument to set the optimizer and its parameters.
402  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
403  for more details.
404
405  ```
406  estimator = tf.estimator.tpu.TPUEstimator(
407      ...
408      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
409          ...
410          optimization_parameters=tf.tpu.experimental.AdagradParameters(0.1),
411          ...))
412  ```
413
414  """
415
416  def __init__(
417      self,
418      learning_rate: float,
419      initial_accumulator: float = 0.1,
420      use_gradient_accumulation: bool = True,
421      clip_weight_min: Optional[float] = None,
422      clip_weight_max: Optional[float] = None,
423      weight_decay_factor: Optional[float] = None,
424      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
425      clip_gradient_min: Optional[float] = None,
426      clip_gradient_max: Optional[float] = None,
427  ):
428    """Optimization parameters for Adagrad.
429
430    Args:
431      learning_rate: used for updating embedding table.
432      initial_accumulator: initial accumulator for Adagrad.
433      use_gradient_accumulation: setting this to `False` makes embedding
434        gradients calculation less accurate but faster. Please see
435        `optimization_parameters.proto` for details.
436      clip_weight_min: the minimum value to clip by; None means -infinity.
437      clip_weight_max: the maximum value to clip by; None means +infinity.
438      weight_decay_factor: amount of weight decay to apply; None means that the
439        weights are not decayed.
440      multiply_weight_decay_factor_by_learning_rate: if true,
441        `weight_decay_factor` is multiplied by the current learning rate.
442      clip_gradient_min: the minimum value to clip by; None means -infinity.
443        Gradient accumulation must be set to true if this is set.
444      clip_gradient_max: the maximum value to clip by; None means +infinity.
445        Gradient accumulation must be set to true if this is set.
446    """
447    super().__init__(
448        learning_rate=learning_rate,
449        use_gradient_accumulation=use_gradient_accumulation,
450        clip_weight_min=clip_weight_min,
451        clip_weight_max=clip_weight_max,
452        weight_decay_factor=weight_decay_factor,
453        multiply_weight_decay_factor_by_learning_rate=(
454            multiply_weight_decay_factor_by_learning_rate),
455        clip_gradient_min=clip_gradient_min,
456        clip_gradient_max=clip_gradient_max,
457    )
458    if initial_accumulator <= 0:
459      raise ValueError(
460          f'Adagrad initial_accumulator must be greater than zero. '
461          f'Received: {initial_accumulator}.')
462    self.initial_accumulator = initial_accumulator
463
464
465class AdagradMomentumParameters(_OptimizationParameters):
466  """Optimization parameters for Adagrad + Momentum with TPU embeddings.
467
468  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
469  `optimization_parameters` argument to set the optimizer and its parameters.
470  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
471  for more details.
472
473  ```
474  estimator = tf.estimator.tpu.TPUEstimator(
475      ...
476      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
477          ...
478          optimization_parameters=AdagradMomentumParameters(0.1, momentum=0.9),
479          ...))
480  ```
481
482  """
483
484  def __init__(
485      self,
486      learning_rate: float,
487      momentum: float,
488      use_nesterov: bool = False,
489      exponent: float = 2,
490      beta2: float = 1,
491      epsilon: float = 1e-10,
492      use_gradient_accumulation: bool = True,
493      clip_weight_min: Optional[float] = None,
494      clip_weight_max: Optional[float] = None,
495      weight_decay_factor: Optional[float] = None,
496      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
497      clip_gradient_min: Optional[float] = None,
498      clip_gradient_max: Optional[float] = None,
499  ):
500    """Optimization parameters for Adagrad.
501
502    Args:
503      learning_rate: used for updating embedding table.
504      momentum: Moving average parameter for the momentum accumulator.
505      use_nesterov: Whether to use the Nesterov variant of momentum. See
506        Sutskever et al., 2013.
507      exponent: Exponent for the Adagrad accumulator.
508      beta2: Moving average parameter for the Adagrad accumulator.
509      epsilon: A small constant for numerical stability.
510      use_gradient_accumulation: setting this to `False` makes embedding
511        gradients calculation less accurate but faster. Please see
512        `optimization_parameters.proto` for details.
513      clip_weight_min: the minimum value to clip by; None means -infinity.
514      clip_weight_max: the maximum value to clip by; None means +infinity.
515      weight_decay_factor: amount of weight decay to apply; None means that the
516        weights are not decayed.
517      multiply_weight_decay_factor_by_learning_rate: if true,
518        `weight_decay_factor` is multiplied by the current learning rate.
519      clip_gradient_min: the minimum value to clip by; None means -infinity.
520        Gradient accumulation must be set to true if this is set.
521      clip_gradient_max: the maximum value to clip by; None means +infinity.
522        Gradient accumulation must be set to true if this is set.
523    """
524    super().__init__(
525        learning_rate=learning_rate,
526        use_gradient_accumulation=use_gradient_accumulation,
527        clip_weight_min=clip_weight_min,
528        clip_weight_max=clip_weight_max,
529        weight_decay_factor=weight_decay_factor,
530        multiply_weight_decay_factor_by_learning_rate=(
531            multiply_weight_decay_factor_by_learning_rate),
532        clip_gradient_min=clip_gradient_min,
533        clip_gradient_max=clip_gradient_max,
534    )
535    if epsilon <= 0:
536      raise ValueError('Adagrad momentum: epsilon must be positive')
537    if exponent <= 0:
538      raise ValueError('Adagrad momentum: precondition exponent must be > 0')
539    self.momentum = momentum
540    self.use_nesterov = use_nesterov
541    self.exponent = exponent
542    self.beta2 = beta2
543    self.epsilon = epsilon
544
545
546class ProximalAdagradParameters(_OptimizationParameters):
547  """Optimization parameters for ProximalAdagrad with TPU embeddings.
548
549  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
550  `optimization_parameters` argument to set the optimizer and its parameters.
551  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
552  for more details.
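
  For illustration, a sketch in the same style as the other optimizer classes
  in this module; `ProximalAdagradParameters` is referenced directly here
  because it is not exported under `tf.tpu.experimental`:

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=ProximalAdagradParameters(0.1),
          ...))
  ```
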
553  """
554
555  def __init__(
556      self,
557      learning_rate: float,
558      initial_accumulator: float = 0.1,
559      l1_regularization_strength: float = 0.0,
560      l2_regularization_strength: float = 0.0,
561      use_gradient_accumulation: bool = True,
562      clip_weight_min: Optional[float] = None,
563      clip_weight_max: Optional[float] = None,
564      weight_decay_factor: Optional[float] = None,
565      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
566      clip_gradient_min: Optional[float] = None,
567      clip_gradient_max: Optional[float] = None,
568  ):
569    """Optimization parameters for Adagrad.
570
571    Args:
572      learning_rate: used for updating embedding table.
573      initial_accumulator: initial accumulator for Adagrad.
574      l1_regularization_strength: A float value, must be greater than or equal
575        to zero.
576      l2_regularization_strength: A float value, must be greater than or equal
577        to zero.
578      use_gradient_accumulation: setting this to `False` makes embedding
579        gradients calculation less accurate but faster. Please see
580        `optimization_parameters.proto` for details.
581      clip_weight_min: the minimum value to clip by; None means -infinity.
582      clip_weight_max: the maximum value to clip by; None means +infinity.
583      weight_decay_factor: amount of weight decay to apply; None means that the
584        weights are not decayed.
585      multiply_weight_decay_factor_by_learning_rate: if true,
586        `weight_decay_factor` is multiplied by the current learning rate.
587      clip_gradient_min: the minimum value to clip by; None means -infinity.
588        Gradient accumulation must be set to true if this is set.
589      clip_gradient_max: the maximum value to clip by; None means +infinity.
590        Gradient accumulation must be set to true if this is set.
591    """
592    super().__init__(
593        learning_rate=learning_rate,
594        use_gradient_accumulation=use_gradient_accumulation,
595        clip_weight_min=clip_weight_min,
596        clip_weight_max=clip_weight_max,
597        weight_decay_factor=weight_decay_factor,
598        multiply_weight_decay_factor_by_learning_rate=(
599            multiply_weight_decay_factor_by_learning_rate),
600        clip_gradient_min=clip_gradient_min,
601        clip_gradient_max=clip_gradient_max,
602    )
603    if initial_accumulator <= 0:
604      raise ValueError(f'ProximalAdagrad initial_accumulator must be positive. '
605                       f'Received: {initial_accumulator}.')
606    if l1_regularization_strength < 0.:
607      raise ValueError('l1_regularization_strength must be greater than or '
608                       'equal to 0. got {}.'.format(l1_regularization_strength))
609
610    if l2_regularization_strength < 0.:
611      raise ValueError('l2_regularization_strength must be greater than or '
612                       'equal to 0. got {}.'.format(l2_regularization_strength))
613
614    self.initial_accumulator = initial_accumulator
615    self.l1_regularization_strength = l1_regularization_strength
616    self.l2_regularization_strength = l2_regularization_strength
617
618
619@tf_export(v1=['tpu.experimental.AdamParameters'])
620class AdamParameters(_OptimizationParameters):
621  """Optimization parameters for Adam with TPU embeddings.
622
623  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
624  `optimization_parameters` argument to set the optimizer and its parameters.
625  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
626  for more details.
627
628  ```
629  estimator = tf.estimator.tpu.TPUEstimator(
630      ...
631      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
632          ...
633          optimization_parameters=tf.tpu.experimental.AdamParameters(0.1),
634          ...))
635  ```
636
637  """
638
639  def __init__(
640      self,
641      learning_rate: float,
642      beta1: float = 0.9,
643      beta2: float = 0.999,
644      epsilon: float = 1e-08,
645      lazy_adam: bool = True,
646      sum_inside_sqrt: bool = True,
647      use_gradient_accumulation: bool = True,
648      clip_weight_min: Optional[float] = None,
649      clip_weight_max: Optional[float] = None,
650      weight_decay_factor: Optional[float] = None,
651      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
652      clip_gradient_min: Optional[float] = None,
653      clip_gradient_max: Optional[float] = None,
654  ):
655    """Optimization parameters for Adam.
656
657    Args:
658      learning_rate: a floating point value. The learning rate.
659      beta1: A float value. The exponential decay rate for the 1st moment
660        estimates.
661      beta2: A float value. The exponential decay rate for the 2nd moment
662        estimates.
663      epsilon: A small constant for numerical stability.
664      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster. See
665        `optimization_parameters.proto` for details.
666      sum_inside_sqrt: This improves training speed. Please see
667        `optimization_parameters.proto` for details.
668      use_gradient_accumulation: setting this to `False` makes embedding
669        gradients calculation less accurate but faster. Please see
670        `optimization_parameters.proto` for details.
671      clip_weight_min: the minimum value to clip by; None means -infinity.
672      clip_weight_max: the maximum value to clip by; None means +infinity.
673      weight_decay_factor: amount of weight decay to apply; None means that the
674        weights are not decayed.
675      multiply_weight_decay_factor_by_learning_rate: if true,
676        `weight_decay_factor` is multiplied by the current learning rate.
677      clip_gradient_min: the minimum value to clip by; None means -infinity.
678        Gradient accumulation must be set to true if this is set.
679      clip_gradient_max: the maximum value to clip by; None means +infinity.
680        Gradient accumulation must be set to true if this is set.
681    """
682    super().__init__(
683        learning_rate=learning_rate,
684        use_gradient_accumulation=use_gradient_accumulation,
685        clip_weight_min=clip_weight_min,
686        clip_weight_max=clip_weight_max,
687        weight_decay_factor=weight_decay_factor,
688        multiply_weight_decay_factor_by_learning_rate=(
689            multiply_weight_decay_factor_by_learning_rate),
690        clip_gradient_min=clip_gradient_min,
691        clip_gradient_max=clip_gradient_max,
692    )
693    if beta1 < 0. or beta1 >= 1.:
694      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
695    if beta2 < 0. or beta2 >= 1.:
696      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
697    if epsilon <= 0.:
698      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
699    if not use_gradient_accumulation and not lazy_adam:
700      raise ValueError(
701          'When disabling Lazy Adam, gradient accumulation must be used.')
702
703    self.beta1 = beta1
704    self.beta2 = beta2
705    self.epsilon = epsilon
706    self.lazy_adam = lazy_adam
707    self.sum_inside_sqrt = sum_inside_sqrt
708
709
710@tf_export(v1=['tpu.experimental.FtrlParameters'])
711class FtrlParameters(_OptimizationParameters):
712  """Optimization parameters for Ftrl with TPU embeddings.
713
714  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
715  `optimization_parameters` argument to set the optimizer and its parameters.
716  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
717  for more details.
718
719  ```
720  estimator = tf.estimator.tpu.TPUEstimator(
721      ...
722      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
723          ...
724          optimization_parameters=tf.tpu.experimental.FtrlParameters(0.1),
725          ...))
726  ```
727
728  """
729
730  def __init__(
731      self,
732      learning_rate: float,
733      learning_rate_power: float = -0.5,
734      initial_accumulator_value: float = 0.1,
735      l1_regularization_strength: float = 0.0,
736      l2_regularization_strength: float = 0.0,
737      use_gradient_accumulation: bool = True,
738      clip_weight_min: Optional[float] = None,
739      clip_weight_max: Optional[float] = None,
740      weight_decay_factor: Optional[float] = None,
741      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
742      multiply_linear_by_learning_rate: bool = False,
743      beta: float = 0,
744      allow_zero_accumulator: bool = False,
745      clip_gradient_min: Optional[float] = None,
746      clip_gradient_max: Optional[float] = None,
747  ):
748    """Optimization parameters for Ftrl.
749
750    Implements FTRL as described in the following [paper](
751    https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
752
753    Args:
754      learning_rate: a floating point value. The learning rate.
755      learning_rate_power: A float value, must be less than or equal to zero.
756        Controls how the learning rate decreases during training. Use zero for a
757        fixed learning rate. See section 3.1 in the
758        [paper](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
759      initial_accumulator_value: The starting value for accumulators. Only zero
760        or positive values are allowed.
761      l1_regularization_strength: A float value, must be greater than or equal
762        to zero.
763      l2_regularization_strength: A float value, must be greater than or equal
764        to zero.
765      use_gradient_accumulation: setting this to `False` makes embedding
766        gradients calculation less accurate but faster. Please see
767        `optimization_parameters.proto` for details.
768      clip_weight_min: the minimum value to clip by; None means -infinity.
769      clip_weight_max: the maximum value to clip by; None means +infinity.
770      weight_decay_factor: amount of weight decay to apply; None means that the
771        weights are not decayed.
772      multiply_weight_decay_factor_by_learning_rate: if true,
773        `weight_decay_factor` is multiplied by the current learning rate.
774      multiply_linear_by_learning_rate: When true, multiplies the usages of the
775        linear slot in the weight update by the learning rate. This is useful
776        when ramping up learning rate from 0 (which would normally produce
777        NaNs).
778      beta: The beta parameter for FTRL.
779      allow_zero_accumulator: Changes the implementation of the square root to
780        allow for the case of initial_accumulator_value being zero. This will
781        cause a slight performance drop.
782      clip_gradient_min: the minimum value to clip by; None means -infinity.
783        Gradient accumulation must be set to true if this is set.
784      clip_gradient_max: the maximum value to clip by; None means +infinity.
785        Gradient accumulation must be set to true if this is set.
786    """
787    super().__init__(
788        learning_rate=learning_rate,
789        use_gradient_accumulation=use_gradient_accumulation,
790        clip_weight_min=clip_weight_min,
791        clip_weight_max=clip_weight_max,
792        weight_decay_factor=weight_decay_factor,
793        multiply_weight_decay_factor_by_learning_rate=(
794            multiply_weight_decay_factor_by_learning_rate),
795        clip_gradient_min=clip_gradient_min,
796        clip_gradient_max=clip_gradient_max,
797    )
798    if learning_rate_power > 0.:
799      raise ValueError('learning_rate_power must be less than or equal to 0. '
800                       'got {}.'.format(learning_rate_power))
801
802    if initial_accumulator_value < 0.:
803      raise ValueError('initial_accumulator_value must be greater than or equal'
804                       ' to 0. got {}.'.format(initial_accumulator_value))
805
806    if l1_regularization_strength < 0.:
807      raise ValueError('l1_regularization_strength must be greater than or '
808                       'equal to 0. got {}.'.format(l1_regularization_strength))
809
810    if l2_regularization_strength < 0.:
811      raise ValueError('l2_regularization_strength must be greater than or '
812                       'equal to 0. got {}.'.format(l2_regularization_strength))
813
814    self.learning_rate_power = learning_rate_power
815    self.initial_accumulator_value = initial_accumulator_value
816    self.initial_linear_value = 0.0
817    self.l1_regularization_strength = l1_regularization_strength
818    self.l2_regularization_strength = l2_regularization_strength
819    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
820    self.beta = beta
821    self.allow_zero_accumulator = allow_zero_accumulator
822
823
824class ProximalYogiParameters(_OptimizationParameters):
825  # pylint: disable=line-too-long
826  """Optimization parameters for Proximal Yogi with TPU embeddings.
827
828  Implements the Yogi optimizer as described in
829  [Adaptive Methods for Nonconvex
830  Optimization](https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization).
831
832  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
833  `optimization_parameters` argument to set the optimizer and its parameters.
834  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
835  for more details.
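
  For illustration, a sketch in the same style as the other optimizer classes
  in this module; `ProximalYogiParameters` is referenced directly here because
  it is not exported under `tf.tpu.experimental`:

  ```
  estimator = tf.estimator.tpu.TPUEstimator(
      ...
      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
          ...
          optimization_parameters=ProximalYogiParameters(learning_rate=0.01),
          ...))
  ```
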
836  """
837
838  # pylint: enable=line-too-long
839
840  def __init__(
841      self,
842      learning_rate: float = 0.01,
843      beta1: float = 0.9,
844      beta2: float = 0.999,
845      epsilon: float = 1e-3,
846      l1_regularization_strength: float = 0.0,
847      l2_regularization_strength: float = 0.0,
848      initial_accumulator_value: float = 1e-6,
849      use_gradient_accumulation: bool = True,
850      clip_weight_min: Optional[float] = None,
851      clip_weight_max: Optional[float] = None,
852      weight_decay_factor: Optional[float] = None,
853      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
854      clip_gradient_min: Optional[float] = None,
855      clip_gradient_max: Optional[float] = None,
856  ):
857    """Optimization parameters for Proximal Yogi.
858
859    Args:
860      learning_rate: a floating point value. The learning rate.
861      beta1: A float value. The exponential decay rate for the 1st moment
862        estimates.
863      beta2: A float value. The exponential decay rate for the 2nd moment
864        estimates.
865      epsilon: A small constant for numerical stability.
866      l1_regularization_strength: A float value, must be greater than or equal
867        to zero.
868      l2_regularization_strength: A float value, must be greater than or equal
869        to zero.
870      initial_accumulator_value: The starting value for accumulators. Only zero
871        or positive values are allowed.
872      use_gradient_accumulation: setting this to `False` makes embedding
873        gradients calculation less accurate but faster. Please see
874        `optimization_parameters.proto` for details.
875      clip_weight_min: the minimum value to clip by; None means -infinity.
876      clip_weight_max: the maximum value to clip by; None means +infinity.
877      weight_decay_factor: amount of weight decay to apply; None means that the
878        weights are not decayed.
879      multiply_weight_decay_factor_by_learning_rate: if true,
880        `weight_decay_factor` is multiplied by the current learning rate.
881      clip_gradient_min: the minimum value to clip by; None means -infinity.
882        Gradient accumulation must be set to true if this is set.
883      clip_gradient_max: the maximum value to clip by; None means +infinity.
884        Gradient accumulation must be set to true if this is set.
885    """
886    super().__init__(
887        learning_rate=learning_rate,
888        use_gradient_accumulation=use_gradient_accumulation,
889        clip_weight_min=clip_weight_min,
890        clip_weight_max=clip_weight_max,
891        weight_decay_factor=weight_decay_factor,
892        multiply_weight_decay_factor_by_learning_rate=(
893            multiply_weight_decay_factor_by_learning_rate),
894        clip_gradient_min=clip_gradient_min,
895        clip_gradient_max=clip_gradient_max,
896    )
897    if beta1 < 0. or beta1 >= 1.:
898      raise ValueError('beta1 must be between 0. and 1; got {}.'.format(beta1))
899    if beta2 < 0. or beta2 >= 1.:
900      raise ValueError('beta2 must be between 0. and 1; got {}.'.format(beta2))
901    if epsilon <= 0.:
902      raise ValueError('epsilon must be positive; got {}.'.format(epsilon))
903    if l1_regularization_strength < 0.:
904      raise ValueError('l1_regularization_strength must be greater than or '
905                       'equal to 0. got {}.'.format(l1_regularization_strength))
906    if l2_regularization_strength < 0.:
907      raise ValueError('l2_regularization_strength must be greater than or '
908                       'equal to 0. got {}.'.format(l2_regularization_strength))
909
910    self.beta1 = beta1
911    self.beta2 = beta2
912    self.epsilon = epsilon
913    self.l1_regularization_strength = l1_regularization_strength
914    self.l2_regularization_strength = l2_regularization_strength
915    self.initial_accumulator_value = initial_accumulator_value
916
917
918class MomentumParameters(_OptimizationParameters):
919  """Optimization parameters for Momentum with TPU embeddings.
920
921  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
922  `optimization_parameters` argument to set the optimizer and its parameters.
923  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
924  for more details.
925
926  ```
927  estimator = tf.estimator.tpu.TPUEstimator(
928      ...
929      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
930          ...
931          optimization_parameters=MomentumParameters(0.1, momentum=0.9),
932          ...))
933  ```
934
935  """
936
937  def __init__(
938      self,
939      learning_rate: float,
940      momentum: float,
941      use_nesterov: bool = False,
942      use_gradient_accumulation: bool = True,
943      clip_weight_min: Optional[float] = None,
944      clip_weight_max: Optional[float] = None,
945      weight_decay_factor: Optional[float] = None,
946      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
947      clip_gradient_min: Optional[float] = None,
948      clip_gradient_max: Optional[float] = None,
949  ):
950    """Optimization parameters for momentum.
951
952    Args:
953      learning_rate: a floating point value. The learning rate.
954      momentum: a floating point value.  The momentum.
955      use_nesterov: If `True` use Nesterov Momentum. See (Sutskever et al.,
956        2013). This implementation always computes gradients at the value of the
957        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
958        variable(s) track the values called `theta_t + mu*v_t` in the paper.
959        This implementation is an approximation of the original formula, valid
960        for high values of momentum. It will compute the "adjusted gradient" in
961        NAG by assuming that the new gradient will be estimated by the current
962        average gradient plus the product of momentum and the change in the
963        average gradient.
964      use_gradient_accumulation: setting this to `False` makes embedding
965        gradients calculation less accurate but faster. Please see
966        `optimization_parameters.proto` for details.
967      clip_weight_min: the minimum value to clip by; None means -infinity.
968      clip_weight_max: the maximum value to clip by; None means +infinity.
969      weight_decay_factor: amount of weight decay to apply; None means that the
970        weights are not decayed.
971      multiply_weight_decay_factor_by_learning_rate: if true,
972        `weight_decay_factor` is multiplied by the current learning rate.
973      clip_gradient_min: the minimum value to clip by; None means -infinity.
974        Gradient accumulation must be set to true if this is set.
975      clip_gradient_max: the maximum value to clip by; None means +infinity.
976        Gradient accumulation must be set to true if this is set.
977    """
978    super().__init__(
979        learning_rate=learning_rate,
980        use_gradient_accumulation=use_gradient_accumulation,
981        clip_weight_min=clip_weight_min,
982        clip_weight_max=clip_weight_max,
983        weight_decay_factor=weight_decay_factor,
984        multiply_weight_decay_factor_by_learning_rate=(
985            multiply_weight_decay_factor_by_learning_rate),
986        clip_gradient_min=clip_gradient_min,
987        clip_gradient_max=clip_gradient_max,
988    )
989    self.momentum = momentum
990    self.use_nesterov = use_nesterov
991
992
993class RMSPropParameters(_OptimizationParameters):
994  """Optimization parameters for RMSProp with TPU embeddings.
995
996  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
997  `optimization_parameters` argument to set the optimizer and its parameters.
998  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
999  for more details.
1000
1001  ```
1002  estimator = tf.estimator.tpu.TPUEstimator(
1003      ...
1004      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1005          ...
1006          optimization_parameters=RMSPropParameters(0.1, ...),
1007          ...))
1008  ```
1009
1010  """
1011
1012  def __init__(
1013      self,
1014      learning_rate: float,
1015      rho: float,
1016      momentum: float,
1017      epsilon: float,
1018      use_gradient_accumulation: bool = True,
1019      clip_weight_min: Optional[float] = None,
1020      clip_weight_max: Optional[float] = None,
1021      weight_decay_factor: Optional[float] = None,
1022      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
1023      clip_gradient_min: Optional[float] = None,
1024      clip_gradient_max: Optional[float] = None,
1025  ):
1026    """Optimization parameters for RMS prop.
1027
1028    Args:
1029      learning_rate: a floating point value. The learning rate.
1030      rho: Discounting factor for the history/coming gradient.
1031      momentum: A float value. The momentum.
1032      epsilon: Small value to avoid zero denominator.
1033      use_gradient_accumulation: setting this to `False` makes embedding
1034        gradients calculation less accurate but faster. Please see
1035        `optimization_parameters.proto` for details.
1036      clip_weight_min: the minimum value to clip by; None means -infinity.
1037      clip_weight_max: the maximum value to clip by; None means +infinity.
1038      weight_decay_factor: amount of weight decay to apply; None means that the
1039        weights are not decayed.
1040      multiply_weight_decay_factor_by_learning_rate: if true,
1041        `weight_decay_factor` is multiplied by the current learning rate.
1042      clip_gradient_min: the minimum value to clip by; None means -infinity.
1043        Gradient accumulation must be set to true if this is set.
1044      clip_gradient_max: the maximum value to clip by; None means +infinity.
1045        Gradient accumulation must be set to true if this is set.
1046    """
1047    super().__init__(
1048        learning_rate=learning_rate,
1049        use_gradient_accumulation=use_gradient_accumulation,
1050        clip_weight_min=clip_weight_min,
1051        clip_weight_max=clip_weight_max,
1052        weight_decay_factor=weight_decay_factor,
1053        multiply_weight_decay_factor_by_learning_rate=(
1054            multiply_weight_decay_factor_by_learning_rate),
1055        clip_gradient_min=clip_gradient_min,
1056        clip_gradient_max=clip_gradient_max,
1057    )
1058    self.rho = rho
1059    self.momentum = momentum
1060    self.epsilon = epsilon
1061
1062
1063@tf_export(v1=['tpu.experimental.StochasticGradientDescentParameters'])
1064class StochasticGradientDescentParameters(_OptimizationParameters):
1065  """Optimization parameters for stochastic gradient descent for TPU embeddings.
1066
1067  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
1068  `optimization_parameters` argument to set the optimizer and its parameters.
1069  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
1070  for more details.
1071
1072  ```
1073  estimator = tf.estimator.tpu.TPUEstimator(
1074      ...
1075      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1076          ...
1077          optimization_parameters=(
1078              tf.tpu.experimental.StochasticGradientDescentParameters(0.1))))
1079  ```
1080
1081  """
1082
1083  def __init__(
1084      self,
1085      learning_rate: float,
1086      use_gradient_accumulation: bool = True,
1087      clip_weight_min: Optional[float] = None,
1088      clip_weight_max: Optional[float] = None,
1089      weight_decay_factor: Optional[float] = None,
1090      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
1091      clip_gradient_min: Optional[float] = None,
1092      clip_gradient_max: Optional[float] = None,
1093  ):
1094    """Optimization parameters for stochastic gradient descent.
1095
1096    Args:
1097      learning_rate: a floating point value. The learning rate.
1098      use_gradient_accumulation: setting this to `False` makes embedding
1099        gradients calculation less accurate but faster. Please see
1100        `optimization_parameters.proto` for details.
1101      clip_weight_min: the minimum value to clip by; None means -infinity.
1102      clip_weight_max: the maximum value to clip by; None means +infinity.
1103      weight_decay_factor: amount of weight decay to apply; None means that the
1104        weights are not decayed.
1105      multiply_weight_decay_factor_by_learning_rate: if true,
1106        `weight_decay_factor` is multiplied by the current learning rate.
1107      clip_gradient_min: the minimum value to clip by; None means -infinity.
1108      clip_gradient_max: the maximum value to clip by; None means +infinity.
1109    """
1110    super().__init__(
1111        learning_rate=learning_rate,
1112        use_gradient_accumulation=use_gradient_accumulation,
1113        clip_weight_min=clip_weight_min,
1114        clip_weight_max=clip_weight_max,
1115        weight_decay_factor=weight_decay_factor,
1116        multiply_weight_decay_factor_by_learning_rate=(
1117            multiply_weight_decay_factor_by_learning_rate),
1118        clip_gradient_min=clip_gradient_min,
1119        clip_gradient_max=clip_gradient_max,
1120    )
1121
1122
1123class FrequencyEstimatorParameters(_OptimizationParameters):
1124  """Optimization parameters for Frequency Estimator TPU embeddings.
1125
1126  This is a non-standard optimizer, which returns the estimated frequency of
1127  lookup for the feature passed to it. It should only be used on a table of
1128  width 1. The gradient fed back to the TPU embedding should always be zero.
1129  This can be accomplished by using `tf.stop_gradient` on the feature before
1130  using it.
1131
1132  You must use the dynamic learning rate mechanism to set the 'learning rate'
1133  for this table to be a float32 cast of the global training step counter.
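
  For illustration, a minimal sketch of wiring the global step into the table's
  dynamic learning rate; the function name, the `tf` and `tpu_embedding`
  aliases, and the parameter values are placeholders:

  ```
  def global_step_as_learning_rate(global_step):
    return tf.cast(global_step, tf.float32)

  table_config = tpu_embedding.TableConfig(
      vocabulary_size=1000000,
      dimension=1,
      learning_rate_fn=global_step_as_learning_rate,
      optimization_parameters=FrequencyEstimatorParameters(
          tau=0.1, max_delta=100.0, outlier_threshold=10.0,
          weight_exponent=1.0))
  ```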
1134
1135  See `tensorflow/core/protobuf/tpu/optimization_parameters.proto` for more
1136  details on this optimizer.
1137
1138  Pass this to `tf.estimator.tpu.experimental.EmbeddingConfigSpec` via the
1139  `optimization_parameters` argument to set the optimizer and its parameters.
1140  See the documentation for `tf.estimator.tpu.experimental.EmbeddingConfigSpec`
1141  for more details.
1142
1143  ```
1144  estimator = tf.estimator.tpu.TPUEstimator(
1145      ...
1146      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(
1147          ...
1148          optimization_parameters=FrequencyEstimatorParameters(tau=0.1, ...),
1149          ...))
1150  ```
1151
1152  """
1153
1154  def __init__(self, tau: float, max_delta: float, outlier_threshold: float,
1155               weight_exponent: float):
1156    """Optimization parameters for frequency estimator.
1157
1158    Args:
1159      tau: Learning rate between (0, 1) that is used to update the array.
1160      max_delta: Maximum value of delta, the difference between the current
1161        global step and the last global step at which the row was sampled.
1162      outlier_threshold: Threshold used to determine whether the current update
1163        is an outlier.
1164      weight_exponent: The weight exponent used to transform the estimated delta
1165        into weights.
1166    """
1167    super().__init__(
1168        learning_rate=1.0,
1169        use_gradient_accumulation=True,
1170        clip_weight_min=None,
1171        clip_weight_max=None,
1172        weight_decay_factor=None,
1173        multiply_weight_decay_factor_by_learning_rate=None,
1174    )
1175    self.tau = tau
1176    self.max_delta = max_delta
1177    self.outlier_threshold = outlier_threshold
1178    self.weight_exponent = weight_exponent
1179
1180
1181DeviceConfig = collections.namedtuple('DeviceConfig',
1182                                      ['num_hosts', 'num_cores', 'job_name'])
1183
1184
1185class TPUEmbedding:
1186  """API for using TPU for embedding.
1187
1188    Example:
1189    ```
1190    table_config_video = tpu_embedding.TableConfig(
1191        vocabulary_size=4, dimension=2, combiner='mean')
1192    table_config_user = table_config_video  # same shape/combiner for both
1193    table_to_config_dict = {'video': table_config_video,
1194                            'user': table_config_user}
1195    feature_to_config_dict = {'watched': tpu_embedding.FeatureConfig('video'),
1196                              'favorited': tpu_embedding.FeatureConfig('video'),
1197                              'friends': tpu_embedding.FeatureConfig('user')}
1198    batch_size = 4
1199    num_hosts = 1
1200    optimization_parameters = tpu_embedding.AdagradParameters(1., 1.)
1201    mode = tpu_embedding.TRAINING
1202    embedding = tpu_embedding.TPUEmbedding(
1203        table_to_config_dict, feature_to_config_dict,
1204        batch_size, mode, optimization_parameters=optimization_parameters)
1205
1206    batch_size_per_core = embedding.batch_size_per_core
1207    sparse_features_list = []
1208    for host in hosts:
1209      with ops.device(host):
1210        for _ in range(embedding.num_cores_per_host):
1211          sparse_features = {}
1212          sparse_features['watched'] = sparse_tensor.SparseTensor(...)
1213          sparse_features['favorited'] = sparse_tensor.SparseTensor(...)
1214          sparse_features['friends'] = sparse_tensor.SparseTensor(...)
1215          sparse_features_list.append(sparse_features)
1216
1217    enqueue_ops = embedding.generate_enqueue_ops(sparse_features_list)
1218    embedding_variables_and_ops = embedding.create_variables_and_ops()
1219
1220    def computation():
1221      activations = embedding.get_activations()
1222      loss = compute_loss(activations)
1223
1224      base_optimizer = gradient_descent.GradientDescentOptimizer(
1225          learning_rate=1)
1226      cross_shard_optimizer = tpu_optimizer.CrossShardOptimizer(
1227          base_optimizer)
1228
1229      train_op = cross_shard_optimizer.minimize(loss)
1230      gradients = (
1231          tpu_embedding_gradient.get_gradients_through_compute_gradients(
1232              cross_shard_optimizer, loss, activations))
1233      send_gradients_op = embedding.generate_send_gradients_op(gradients)
1234      with ops.control_dependencies([train_op, send_gradients_op]):
1235        loss = array_ops.identity(loss)
1236      return loss
1237    loss = tpu.shard(computation,
1238                     num_shards=embedding.num_cores)
1239
1240    with tf.compat.v1.Session() as sess:
1241      sess.run(tpu.initialize_system(embedding_config=
1242                                     embedding.config_proto))
1243      sess.run(variables.global_variables_initializer())
1244      sess.run(embedding_variables_and_ops.load_ops())
1245      sess.run(enqueue_ops)
1246      loss_val = sess.run(loss)
1247    ```
1248
1249  Example with weight decay:
1250
1251  >>> def learning_rate_fn(global_step):
1252  ...   return tf.compat.v1.train.polynomial_decay(
1253  ...     learning_rate=5e-5,
1254  ...     global_step=global_step,
1255  ...     decay_steps=100000,
1256  ...     end_learning_rate=0.0)
1257  >>> wordpiece_table_config = TableConfig(
1258  ...   vocabulary_size=119547,
1259  ...   dimension=256,
1260  ...   learning_rate_fn=learning_rate_fn)
1261  >>> wordpiece_feature_config = FeatureConfig(
1262  ...   table_id='bert/embeddings/word_embeddings',
1263  ...   max_sequence_length=512)
1264  >>> optimization_parameters = AdamParameters(
1265  ...   learning_rate=5e-5,
1266  ...   epsilon=1e-6,
1267  ...   weight_decay_factor=0.01,
1268  ...   multiply_weight_decay_factor_by_learning_rate=True)
1269  >>> tpu_embedding = TPUEmbedding(
1270  ...  table_to_config_dict={
1271  ...    'bert/embeddings/word_embeddings': wordpiece_table_config,
1272  ...  },
1273  ...  feature_to_config_dict={'input_ids': wordpiece_feature_config},
1274  ...  batch_size=128,
1275  ...  mode=TRAINING,
1276  ...  optimization_parameters=optimization_parameters,
1277  ...  master='')
1278  >>> with tf.Graph().as_default():
1279  ...   init_tpu_op = tf.compat.v1.tpu.initialize_system(
1280  ...     embedding_config=tpu_embedding.config_proto)
1281  ...   tf.compat.v1.Session().run(init_tpu_op)
1282  """
1283
1284  # TODO(shizhiw): Consider adding a field to FeatureConfig that indicates that
1285  # the feature should not be used to update embedding table (cr/204852758,
1286  # cr/204940540). Also, this can support different combiners for different
1287  # features within the same table.
1288  # TODO(shizhiw, b/118512626): Remove `batch_size` from `__init__` and move it
1289  # to `FeatureConfig`?
1290
1291  # TODO(shizhiw): will it be cleaner to make `table_to_config_dict` and
1292  # `feature_to_config_dict` lists of `TableSpec` and `FeatureSpec`
1293  # respectively?
1294
1295  # TODO(shizhiw): Consider adding `input_fn` as an option to remove boilerplate
1296  # for-loops around construction of inputs.
1297
1298  # `optimization_parameters` applies to all tables by default. A table can
1299  # override it by setting `optimization_parameters` in its own `TableConfig`
1300  # (see `_get_optimizer_handler_by_table`).
1301  def __init__(self,
1302               table_to_config_dict,
1303               feature_to_config_dict,
1304               batch_size,
1305               mode,
1306               master=None,
1307               optimization_parameters=None,
1308               cluster_def=None,
1309               pipeline_execution_with_tensor_core=False,
1310               partition_strategy='div',
1311               profile_data_directory=None,
1312               device_config=None,
1313               master_job_name=None):
1314    """API for using TPU for embedding lookups.
1315
1316    Args:
1317      table_to_config_dict: A dictionary mapping from string of table name to
1318        `TableConfig`. Table refers to an embedding table, e.g. `params`
1319        argument to `tf.nn.embedding_lookup_sparse()`.
1320      feature_to_config_dict: A dictionary mapping from string of feature name
1321        to `FeatureConfig`. Feature refers to ids to lookup in embedding table,
1322        e.g. `sp_ids` argument to `tf.nn.embedding_lookup_sparse()`.
1323      batch_size: An `int` representing the global batch size.
1324      mode: `TRAINING` or `INFERENCE`.
1325      master: A `string` representing the TensorFlow master to use.
1326      optimization_parameters: `AdagradParameters`, `AdamParameters`,
1327        `StochasticGradientDescentParameters`, etc. Must be set for training
1328        unless every table specifies its own optimizer (see the sketch after
1329        this argument list), and must be `None` for inference.
1330      cluster_def: A ClusterDef object describing the TPU cluster.
1331      pipeline_execution_with_tensor_core: setting this to `True` makes training
1332        faster, but the trained model will be different if step N and step N+1
1333        involve the same set of embedding IDs. Please see
1334        `tpu_embedding_configuration.proto` for details.
1335      partition_strategy: A string, either 'mod' or 'div', specifying how to map
1336        the lookup id to the embedding tensor. For more information see
1337        `tf.nn.embedding_lookup_sparse`.
1338      profile_data_directory: Directory where embedding lookup statistics are
1339        stored. These statistics summarize information about the inputs to the
1340        embedding lookup operation, in particular, the average number of
1341        embedding IDs per example and how well the embedding IDs are load
1342        balanced across the system. The lookup statistics are used during TPU
1343        initialization for embedding table partitioning. Collection of lookup
1344        statistics is done at runtime by  profiling the embedding inputs, only a
1345        small fraction of input samples are profiled to minimize host CPU
1346        overhead. Once a suitable number of samples are profiled, the lookup
1347        statistics are saved to table-specific files in the profile data
1348        directory generally at the end of a TPU training loop. The filename
1349        corresponding to each table is obtained by hashing table specific
1350        parameters (e.g., table name and number of features) and global
1351        configuration parameters (e.g., sharding strategy and task count). The
1352        same profile data directory can be shared among several models to reuse
1353        embedding lookup statistics.
1354      device_config: A DeviceConfig instance, used when `master` and
1355        `cluster_def` are both `None`.
1356      master_job_name: if set, overrides the master job name used to schedule
1357        embedding ops.
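
    A sketch showing that a per-table `optimization_parameters` set on a
    `TableConfig` overrides this global setting (names and values are
    illustrative; `master` is assumed to point at a TPU worker):

    ```
    table_config = tpu_embedding.TableConfig(
        vocabulary_size=1000,
        dimension=64,
        optimization_parameters=tpu_embedding.AdagradParameters(1., 1.))
    embedding = tpu_embedding.TPUEmbedding(
        table_to_config_dict={'table': table_config},
        feature_to_config_dict={
            'feature': tpu_embedding.FeatureConfig('table')},
        batch_size=128,
        mode=tpu_embedding.TRAINING,
        master=master)
    ```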
1358
1359    Raises:
1360      ValueError: if any input is invalid.
1361    """
1362    if partition_strategy not in ('div', 'mod'):
1363      raise ValueError(f'partition_strategy must be "div" or "mod". '
1364                       f'Received: {partition_strategy}.')
1365    self._partition_strategy = partition_strategy
1366
1367    self._profile_data_directory = profile_data_directory
1368
1369    _validate_table_to_config_dict(table_to_config_dict)
1370    # Avoid nondeterminism from `Dict` iteration order by using `OrderedDict`.
1371    self._table_to_config_dict = _create_ordered_dict(table_to_config_dict)
1372
1373    _validate_feature_to_config_dict(table_to_config_dict,
1374                                     feature_to_config_dict)
1375    self._feature_to_config_dict = _create_ordered_dict(feature_to_config_dict)
1376    self._table_to_features_dict = (
1377        _create_table_to_features_dict(self._feature_to_config_dict))
1378    self._combiners = _create_combiners(self._table_to_config_dict,
1379                                        self._table_to_features_dict)
1380
1381    self._batch_size = batch_size
1382
1383    if master is None and cluster_def is None:
1384      if device_config is None:
1385        raise ValueError('When master and cluster_def are both None, '
1386                         'device_config must be set but is not.')
1387      if device_config.num_cores % device_config.num_hosts:
1388        raise ValueError('num_hosts ({}) should divide num_cores ({}) '
1389                         'but does not.'.format(device_config.num_hosts,
1390                                                device_config.num_cores))
1391      self._num_hosts = device_config.num_hosts
1392      self._num_cores = device_config.num_cores
1393      self._num_cores_per_host = self._num_cores // self._num_hosts
1394      self._hosts = [
1395          '{}/replica:0/task:{}/device:CPU:0'.format(device_config.job_name, i)
1396          for i in range(self._num_hosts)
1397      ]
1398    else:
1399      tpu_system_metadata = (
1400          tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
1401              master,
1402              cluster_def=cluster_def))
1403      if tpu_system_metadata.num_cores == 0:
1404        raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
1405                         'TPUs.'.format(master))
1406      self._num_hosts = tpu_system_metadata.num_hosts
1407      if master_job_name is None:
1408        try:
1409          master_job_name = tpu_system_metadata_lib.master_job(
1410              master, cluster_def)
1411        except ValueError as e:
1412          raise ValueError(str(e) + ' Please specify a master_job_name.')
1413      self._hosts = []
1414      for device in tpu_system_metadata.devices:
1415        if 'device:CPU:' in device.name and (master_job_name is None or
1416                                             master_job_name in device.name):
1417          self._hosts.append(device.name)
1418      self._num_cores_per_host = tpu_system_metadata.num_of_cores_per_host
1419      self._num_cores = tpu_system_metadata.num_cores
1420
1421    _validate_batch_size(self._batch_size, self._num_cores)
1422    self._batch_size_per_core = self._batch_size // self._num_cores
1423
1424    # TODO(shizhiw): remove `mode`?
1425    if mode == TRAINING:
1426      _validate_optimization_parameters(optimization_parameters,
1427                                        self._table_to_config_dict)
1428      self._optimization_parameters = optimization_parameters
1429    elif mode == INFERENCE:
1430      if optimization_parameters is not None:
1431        raise ValueError(f'`optimization_parameters` should be `None` '
1432                         f'for inference mode. '
1433                         f'Received: {optimization_parameters}.')
1434      self._optimization_parameters = (StochasticGradientDescentParameters(1.))
1435    else:
1436      raise ValueError('`mode` only supports {} and {}; got {}.'.format(
1437          TRAINING, INFERENCE, mode))
1438    self._mode = mode
1439
1440    # TODO(shizhiw): move `optimization_parameters` into `_optimizer_handler`
1441    # and create special handler for inference that inherits from
1442    # StochasticGradientDescentHandler with more user-friendly error message
1443    # on get_slot().
1444    self._optimizer_handler_dict = self._get_optimizer_handler_by_table()
1445
1446    self._pipeline_execution_with_tensor_core = (
1447        pipeline_execution_with_tensor_core)
1448    self._learning_rate_fn = list(
1449        set(c.learning_rate_fn
1450            for c in self._table_to_config_dict.values()
1451            if c.learning_rate_fn is not None))
1452    self._learning_rate_fn_to_tag = {
1453        fn: tag for tag, fn in enumerate(self._learning_rate_fn)
1454    }
1455
1456    self._config_proto = self._create_config_proto()
1457
1458  @property
1459  def hosts(self):
1460    """A list of device names for CPU hosts.
1461
1462    Returns:
1463      A list of device names for CPU hosts.
1464    """
1465    return copy.copy(self._hosts)
1466
1467  # TODO(shizhiw): change to num_tensor_cores_per_host to be more explicit and
1468  # to be consistent with `tpu_embedding_configuration.proto`.
1469  @property
1470  def num_cores_per_host(self):
1471    """Number of TPU cores on a CPU host.
1472
1473    Returns:
1474      Number of TPU cores on a CPU host.
1475    """
1476    return self._num_cores_per_host
1477
1478  @property
1479  def num_cores(self):
1480    """Total number of TPU cores on all hosts.
1481
1482    Returns:
1483      Total number of TPU cores on all hosts.
1484    """
1485    return self._num_cores
1486
1487  @property
1488  def batch_size_per_core(self):
1489    """Batch size for each TPU core.
1490
1491    The tensors in `enqueue_datas_list` passed to `generate_enqueue_ops`
1492    must have a batch dimension equal to this.
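    For example, a global `batch_size` of 128 on 32 TPU cores gives a per-core
    batch size of 4.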
1493
1494    Returns:
1495      Batch size for each TPU core.
1496    """
1497    return self._batch_size_per_core
1498
1499  @property
1500  def config_proto(self):
1501    """The embedding config proto passed to `tpu.initialize_system()`.
1502
1503    Returns:
1504      a `TPUEmbeddingConfiguration` proto describing the desired
1505         configuration of the hardware embedding lookup tables, which
1506         is passed to `tpu.initialize_system()`.
1507    """
1508    return self._config_proto
1509
1510  @property
1511  def table_to_config_dict(self):
1512    return copy.copy(self._table_to_config_dict)
1513
1514  @property
1515  def feature_to_config_dict(self):
1516    return copy.copy(self._feature_to_config_dict)
1517
1518  @property
1519  def table_to_features_dict(self):
1520    return copy.copy(self._table_to_features_dict)
1521
1522  @property
1523  def optimization_parameters(self):
1524    return self._optimization_parameters
1525
1526  def _create_config_proto(self):
1527    """Create `TPUEmbeddingConfiguration`."""
1528    config_proto = elc.TPUEmbeddingConfiguration()
1529    for table in self._table_to_config_dict:
1530      table_descriptor = config_proto.table_descriptor.add()
1531      table_descriptor.name = table
1532
1533      table_config = self._table_to_config_dict[table]
1534      # For small tables, we pad to the number of hosts so that at least one
1535      # id will be assigned to each host.
1536      table_descriptor.vocabulary_size = max(table_config.vocabulary_size,
1537                                             len(self.hosts))
1538      table_descriptor.dimension = table_config.dimension
1539
1540      optimization_parameters = (
1541          self._optimizer_handler_dict[table].get_optimization_parameters())
1542
1543      parameters = table_descriptor.optimization_parameters
1544      if table_config.learning_rate:
1545        parameters.learning_rate.constant = table_config.learning_rate
1546      elif table_config.learning_rate_fn:
1547        parameters.learning_rate.dynamic.tag = (
1548            self._learning_rate_fn_to_tag[table_config.learning_rate_fn])
1549      else:
1550        parameters.learning_rate.constant = (
1551            optimization_parameters.learning_rate)
1552      parameters.gradient_accumulation_status = (
1553          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
1554          if optimization_parameters.use_gradient_accumulation else
1555          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
1556
1557      if optimization_parameters.clip_gradient_min is not None:
1558        parameters.gradient_clipping_limits.lower.value = (
1559            optimization_parameters.clip_gradient_min)
1560      if optimization_parameters.clip_gradient_max is not None:
1561        parameters.gradient_clipping_limits.upper.value = (
1562            optimization_parameters.clip_gradient_max)
1563
1564      if optimization_parameters.clip_weight_min is not None:
1565        parameters.clipping_limits.lower.value = (
1566            optimization_parameters.clip_weight_min)
1567      if optimization_parameters.clip_weight_max is not None:
1568        parameters.clipping_limits.upper.value = (
1569            optimization_parameters.clip_weight_max)
1570      if optimization_parameters.weight_decay_factor:
1571        parameters.weight_decay_factor = (
1572            optimization_parameters.weight_decay_factor)
1573        if (optimization_parameters
1574            .multiply_weight_decay_factor_by_learning_rate):
1575          parameters.multiply_weight_decay_factor_by_learning_rate = True
1576      if table_config.hot_id_replication:
1577        parameters.hot_id_replication_configuration.status = (
1578            optimization_parameters_pb2.HotIdReplicationConfiguration.ENABLED)
1579      optimizer_handler = self._optimizer_handler_dict[table]
1580      optimizer_handler.set_optimization_parameters(table_descriptor)
1581
1582    table_to_id = {
1583        table: i for i, table in enumerate(self._table_to_config_dict)
1584    }
1585
1586    # Set feature descriptor field in the config proto.
1587    for table in self._table_to_features_dict:
1588      features = self._table_to_features_dict[table]
1589      for feature in features:
1590        feature_descriptor = config_proto.feature_descriptor.add()
1591
1592        feature_descriptor.table_id = table_to_id[
1593            self._feature_to_config_dict[feature].table_id]
1594        if self._feature_to_config_dict[feature].max_sequence_length > 0:
1595          feature_descriptor.input_shape.extend([
1596              self._batch_size_per_core,
1597              self._feature_to_config_dict[feature].max_sequence_length
1598          ])
1599        else:
1600          feature_descriptor.input_shape.extend([self._batch_size_per_core])
1601
1602    config_proto.mode = self._mode
1603    config_proto.num_hosts = self._num_hosts
1604    config_proto.num_tensor_cores = self._num_cores
1605    config_proto.sharding_strategy = (
1606        elc.TPUEmbeddingConfiguration.DIV_DEFAULT if self._partition_strategy
1607        == 'div' else elc.TPUEmbeddingConfiguration.MOD)
1608    config_proto.pipeline_execution_with_tensor_core = (
1609        self._pipeline_execution_with_tensor_core)
1610    if self._profile_data_directory:
1611      config_proto.profile_data_directory = self._profile_data_directory
1612
1613    return config_proto
1614
1615  def create_variables_and_ops(self,
1616                               embedding_variable_name_by_table=None,
1617                               slot_variable_names_by_table=None):
1618    """Create embedding and slot variables, with ops to load and retrieve them.
1619
1620    N.B.: the ops that retrieve the embedding variables (including the slot
1621    variables) are returned as lambdas, because the call site may want to
1622    impose control dependencies between the TPU computation and the retrieval.
1623    For example, the following snippet ensures the TPU computation finishes
1624    first, and only then pulls the variables back from TPU to CPU.
1625
1626    ```
1627    update_ops = []
1628    with ops.control_dependencies([loss]):
1629      for op_fn in retrieve_parameters_op_fns:
1630        update_ops.append(op_fn())
1631    ```
1632
1633    Args:
1634      embedding_variable_name_by_table: A dictionary mapping from string of
1635        table name to string of embedding variable name. If `None`, defaults
1636        from `get_default_slot_variable_names()` will be used.
1637      slot_variable_names_by_table: A dictionary mapping from string of table
1638        name to `AdamSlotVariableNames`, `AdagradSlotVariableNames` etc. If
1639        `None`, defaults from `get_default_slot_variable_names()` will be used.
1640
1641    Returns:
1642      `tpu_embedding.VariablesAndOps` with:
1643        A dictionary mapping from string of table name to embedding variables,
1644        A dictionary mapping from string of table name to AdagradSlotVariables,
1645         AdamSlotVariables etc with slot variables,
1646        A function which returns a list of ops to load embedding and slot
1647         variables from CPU to TPU.
1648        A function which returns a list of ops to retrieve embedding and slot
1649         variables from TPU to CPU.
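
    For example (a sketch; `embedding` is a `TPUEmbedding` instance and `sess`
    is a session connected to the TPU master):

    ```
    vars_and_ops = embedding.create_variables_and_ops()
    sess.run(vars_and_ops.load_ops())      # Push CPU variables to the TPU.
    # ... run training steps ...
    sess.run(vars_and_ops.retrieve_ops())  # Pull the updated tables back.
    ```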
1650    """
1651    embedding_variables_by_table = {}
1652    slot_variables_by_table = {}
1653    load_op_fns = []
1654    retrieve_op_fns = []
1655
1656    for i, table in enumerate(self._table_to_config_dict):
1657      if embedding_variable_name_by_table:
1658        embedding_variable_name = embedding_variable_name_by_table[table]
1659      else:
1660        embedding_variable_name = table
1661      if slot_variable_names_by_table:
1662        slot_variable_names = slot_variable_names_by_table[table]
1663      else:
1664        optimizer_handler = self._optimizer_handler_dict[table]
1665        slot_variable_names = (
1666            optimizer_handler.get_default_slot_variable_names(table))
1667
1668      # TODO(b/139144091): Multi-host support for mid-level API in
1669      #  eager context (TF 2.0)
1670      # Workaround below allows single-host use case in TF 2.0
1671      if context.executing_eagerly():
1672        device = ''
1673      else:
1674        device = _create_device_fn(self._hosts)
1675
1676      with ops.device(device):
1677        table_variables = _create_partitioned_variables(
1678            name=embedding_variable_name,
1679            num_hosts=self._num_hosts,
1680            vocabulary_size=self._table_to_config_dict[table].vocabulary_size,
1681            embedding_dimension=self._table_to_config_dict[table].dimension,
1682            initializer=self._table_to_config_dict[table].initializer,
1683            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
1684        embedding_variables_by_table[table] = table_variables
1685
1686        # Only attach the embedding config to the load/retrieve nodes of the
1687        # first table on the first host; all other nodes reuse that config.
1688        config = None if i else self.config_proto.SerializeToString()
1689        slot_variables_for_table, load_ops_fn, retrieve_ops_fn = (
1690            self._optimizer_handler_dict[table].create_variables_and_ops(
1691                table, slot_variable_names, self._num_hosts,
1692                self._table_to_config_dict[table], table_variables, config))
1693        slot_variables_by_table[table] = slot_variables_for_table
1694        load_op_fns.append(load_ops_fn)
1695        retrieve_op_fns.append(retrieve_ops_fn)
1696
1697    def load_ops():
1698      """Calls and returns the load ops for each embedding table.
1699
1700      Returns:
1701        A list of ops to load embedding and slot variables from CPU to TPU.
1702      """
1703      load_ops_list = []
1704      for load_op_fn in load_op_fns:
1705        load_ops_list.extend(load_op_fn())
1706      return load_ops_list
1707
1708    def retrieve_ops():
1709      """Calls and returns the retrieve ops for each embedding table.
1710
1711      Returns:
1712        A list of ops to retrieve embedding and slot variables from TPU to CPU.
1713      """
1714      retrieve_ops_list = []
1715      for retrieve_op_fn in retrieve_op_fns:
1716        retrieve_ops_list.extend(retrieve_op_fn())
1717      return retrieve_ops_list
1718
1719    return VariablesAndOps(embedding_variables_by_table,
1720                           slot_variables_by_table, load_ops, retrieve_ops)
1721
1722  def generate_enqueue_ops(
1723      self,
1724      enqueue_datas_list,
1725      mode_override=None,
1726      ragged=False,
1727  ):
1728    """Generate enqueue ops.
1729
1730    Args:
1731      enqueue_datas_list: a list of dictionaries, each mapping feature name
1732        strings to `EnqueueData` or `RaggedEnqueueData`. Each dictionary is
1733        for one TPU core. Dictionaries for the same host should be contiguous
        in the list.
1734      mode_override: A string input that overrides the mode specified in the
1735        TPUEmbeddingConfiguration. Supported values are {'unspecified',
1736        'inference', 'training', 'backward_pass_only'}. When set to
1737        'unspecified', the mode set in TPUEmbeddingConfiguration is used,
1738        otherwise mode_override is used (optional).
1739      ragged: If True, creates RaggedTensor enqueue ops rather than
1740        SparseTensor.
1741
1742    Returns:
1743      Ops to enqueue to TPU for embedding.
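
    For example, a sketch for a model whose only feature is 'watched'
    (`sparse_watched_for_core` is assumed to be a function that builds, on the
    current host device, a `SparseTensor` whose batch dimension equals
    `batch_size_per_core`):

    ```
    enqueue_datas_list = []
    for host in embedding.hosts:
      with ops.device(host):
        for _ in range(embedding.num_cores_per_host):
          enqueue_datas_list.append({
              'watched': tpu_embedding.EnqueueData.from_sparse_tensor(
                  sparse_watched_for_core()),
          })
    enqueue_ops = embedding.generate_enqueue_ops(enqueue_datas_list)
    ```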
1744    """
1745    self._validate_generate_enqueue_ops_enqueue_datas_list(enqueue_datas_list)
1746    return [
1747        self._generate_enqueue_op(  # pylint: disable=g-complex-comprehension
1748            enqueue_datas,
1749            device_ordinal=i % self._num_cores_per_host,
1750            mode_override=mode_override,
1751            ragged=ragged,
1752        ) for i, enqueue_datas in enumerate(enqueue_datas_list)
1753    ]
1754
1755  def _validate_generate_enqueue_ops_enqueue_datas_list(self,
1756                                                        enqueue_datas_list):
1757    """Validate `enqueue_datas_list`."""
1758
1759    def _check_agreement(data, name, feature, enqueue_data):
1760      """Helper function to check device agreement."""
1761      if (data is not None and
1762          data.device != enqueue_data.embedding_indices.device):
1763        raise ValueError('Device of {0} does not agree with that of '
1764                         'embedding_indices for feature {1}.'.format(
1765                             name, feature))
1766
1767    feature_set = set(self._feature_to_config_dict.keys())
1768    contiguous_device = None
1769    for i, enqueue_datas in enumerate(enqueue_datas_list):
1770      used_feature_set = set(enqueue_datas.keys())
1771
1772      # Check features are valid.
1773      missing_feature_set = feature_set - used_feature_set
1774      if missing_feature_set:
1775        raise ValueError('`enqueue_datas_list[{}]` misses a feature that is '
1776                         'in `feature_to_config_dict`: {}.'.format(
1777                             i, missing_feature_set))
1778
1779      extra_feature_set = used_feature_set - feature_set
1780      if extra_feature_set:
1781        raise ValueError('`enqueue_datas_list[{}]` has a feature that is not '
1782                         'in `feature_to_config_dict`: {}.'.format(
1783                             i, extra_feature_set))
1784
1785      device = None
1786      device_feature = None
1787      for feature, enqueue_data in enqueue_datas.items():
1788        combiner = self._table_to_config_dict[
1789            self._feature_to_config_dict[feature].table_id].combiner
1790
1791        if isinstance(enqueue_data, EnqueueData):
1792          if enqueue_data.sample_indices is None and combiner:
1793            logging.warn(
1794                'No sample indices set for feature %s table %s but '
1795                'combiner is set to %s.', feature,
1796                self._feature_to_config_dict[feature].table_id, combiner)
1797          _check_agreement(enqueue_data.sample_indices, 'sample_indices',
1798                           feature, enqueue_data)
1799          _check_agreement(enqueue_data.aggregation_weights,
1800                           'aggregation_weights', feature, enqueue_data)
1801
1802        elif isinstance(enqueue_data, RaggedEnqueueData):
1803          if enqueue_data.row_splits is None and combiner:
1804            logging.warn(
1805                'No row splits set for feature %s table %s but '
1806                'combiner is set to %s.', feature,
1807                self._feature_to_config_dict[feature].table_id, combiner)
1808          _check_agreement(enqueue_data.row_splits, 'row_splits', feature,
1809                           enqueue_data)
1810          _check_agreement(enqueue_data.aggregation_weights,
1811                           'aggregation_weights', feature, enqueue_data)
1812        else:
1813          raise ValueError(
1814              '`enqueue_datas_list[{}]` has a feature that is not mapped to '
1815              '`EnqueueData` or `RaggedEnqueueData`. `feature`: {}'.format(
1816                  i, feature))
1817        # Check all features are on the same device.
1818        if device is None:
1819          device = enqueue_data.embedding_indices.device
1820          device_feature = feature
1821        else:
1822          if device != enqueue_data.embedding_indices.device:
1823            raise ValueError('Devices are different between features in '
1824                             '`enqueue_datas_list[{}]`; '
1825                             'devices: {}, {}; features: {}, {}.'.format(
1826                                 i, device,
1827                                 enqueue_data.embedding_indices.device, feature,
1828                                 device_feature))
1829
1830      if i % self._num_cores_per_host:
1831        if device != contiguous_device:
1832          raise ValueError('We expect the `enqueue_datas` which are on the '
1833                           'same host to be contiguous in '
1834                           '`enqueue_datas_list`, '
1835                           '`enqueue_datas_list[{}]` is on device {}, '
1836                           'but is expected to be on device {}.'.format(
1837                               i, device, contiguous_device))
1838      else:
1839        contiguous_device = device
1840
1841  def _generate_enqueue_op(self,
1842                           enqueue_datas,
1843                           device_ordinal,
1844                           mode_override=None,
1845                           ragged=False):
1846    """Creates op for enqueuing batch to TPU."""
1847    enqueue_data0 = list(enqueue_datas.values())[0]
1848    with ops.colocate_with(enqueue_data0.embedding_indices):
1849      return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
1850          device_ordinal=device_ordinal,
1851          combiners=self._combiners,
1852          mode_override=mode_override,
1853          **self._format_for_tpu_embedding_arbitrary_tensor_batch(
1854              enqueue_datas, ragged))
1855
1856  def _format_for_tpu_embedding_arbitrary_tensor_batch(self, enqueue_datas,
1857                                                       ragged):
1858    """Format features for `enqueue_tpu_embedding_arbitrary_tensor_batch()`.
1859
1860    Args:
1861      enqueue_datas: a `Dict` of `RaggedEnqueueData` objects for embedding.
1862      ragged: If True, extract row splits from the data rather than sample
1863        indices.
1864
1865    Returns:
1866      Dict of arguments for `enqueue_tpu_embedding_arbitrary_tensor_batch()`.
1867    """
1868
1869    kwargs = {
1870        'sample_indices_or_row_splits': [],
1871        'embedding_indices': [],
1872        'aggregation_weights': [],
1873    }
1874    int_zeros = array_ops.zeros((0,), dtype=dtypes.int64)
1875    float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)
1876    for table in self._table_to_features_dict:
1877      features = self._table_to_features_dict[table]
1878      for feature in features:
1879        enqueue_data = enqueue_datas[feature]
1880        if ragged:
1881          kwargs['sample_indices_or_row_splits'].append(
1882              enqueue_data.row_splits if enqueue_data
1883              .row_splits is not None else int_zeros)
1884        else:
1885          if (self._feature_to_config_dict[feature].max_sequence_length > 0 and
1886              enqueue_data.sample_indices is not None and
1887              enqueue_data.sample_indices.shape[1] == 2):
1888            # Pad the sample indices as if the enqueued sparse tensor were rank 3.
1889            sample_indices = array_ops.pad(
1890                enqueue_data.sample_indices, paddings=[[0, 0], [0, 1]])
1891            kwargs['sample_indices_or_row_splits'].append(sample_indices)
1892          else:
1893            # If sample_indices is missing or has a single column (i.e. the
1894            # enqueued tensor is rank 1), treat the input as a dense tensor.
1895            if (enqueue_data.sample_indices is None or
1896                enqueue_data.sample_indices.shape[1] == 1):
1897              kwargs['sample_indices_or_row_splits'].append(int_zeros)
1898            else:
1899              kwargs['sample_indices_or_row_splits'].append(
1900                  enqueue_data.sample_indices)
1901
1902        kwargs['aggregation_weights'].append(
1903            enqueue_data.aggregation_weights if enqueue_data
1904            .aggregation_weights is not None else float_zeros)
1905
1906        kwargs['embedding_indices'].append(enqueue_data.embedding_indices)
1907    return kwargs
1908
1909  def get_activations(self):
1910    """Get activations for features.
1911
1912    This should be called within `computation` that is passed to
1913      `tpu.replicate` and friends.
1914
1915    Returns:
1916      A dictionary mapping from `String` of feature name to `Tensor`
1917        of activation.
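
    For example (a sketch), inside the replicated computation:

    ```
    activations = embedding.get_activations()
    # For a non-sequence feature, the activation shape is
    # [batch_size_per_core, dimension].
    watched_embedding = activations['watched']
    ```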
1918    """
1919    recv_activations = tpu_ops.recv_tpu_embedding_activations(
1920        num_outputs=len(self._feature_to_config_dict),
1921        config=self._config_proto.SerializeToString())
1922
1923    activations = collections.OrderedDict()
1924    index = 0
1925    for table in self._table_to_features_dict:
1926      for feature in self._table_to_features_dict[table]:
1927        activations[feature] = recv_activations[index]
1928        index += 1
1929    return activations
1930
1931  def generate_send_gradients_op(self, feature_to_gradient_dict, step=None):
1932    """Send gradients to TPU embedding.
1933
1934    Args:
1935      feature_to_gradient_dict: dict mapping feature names to gradient wrt
1936        activations.
1937      step: the current global step, used for dynamic learning rate.
1938
1939    Returns:
1940      SendTPUEmbeddingGradients Op.
1941
1942    Raises:
1943      RuntimeError: If `mode` is not `TRAINING`.
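
    For example (a sketch; `grads` is assumed to map each feature name to the
    gradient of the loss with respect to its activation, e.g. as produced by
    `tpu_embedding_gradient.get_gradients_through_compute_gradients`):

    ```
    send_op = embedding.generate_send_gradients_op(
        grads, step=tf.compat.v1.train.get_global_step())
    ```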
1944    """
1945    if self._mode != TRAINING:
1946      raise RuntimeError('Gradients should only be sent to the TPU embedding '
1947                         'in training mode; got mode {}.'.format(
1948                             self._mode))
1949    if step is None and self._learning_rate_fn:
1950      raise ValueError('There are dynamic learning rates but step is None.')
1951
1952    gradients = []
1953    for table in self._table_to_features_dict:
1954      for feature in self._table_to_features_dict[table]:
1955        gradients.append(feature_to_gradient_dict[feature])
1956
1957    return tpu_ops.send_tpu_embedding_gradients(
1958        inputs=gradients,
1959        learning_rates=[
1960            math_ops.cast(fn(step), dtype=dtypes.float32)
1961            for fn in self._learning_rate_fn
1962        ],
1963        config=self.config_proto.SerializeToString())
1964
1965  def _get_optimizer_handler_by_table(self):
1966    optimizer_handlers = {}
1967    for table, table_config in self.table_to_config_dict.items():
1968      if table_config.optimization_parameters is not None:
1969        optimizer = table_config.optimization_parameters
1970      else:
1971        optimizer = self._optimization_parameters
1972      optimizer_handlers[table] = _get_optimization_handler(optimizer)
1973
1974    return optimizer_handlers
1975
1976
1977def _validate_table_to_config_dict(table_to_config_dict):
1978  """Validate `table_to_config_dict`."""
1979  for k, v in table_to_config_dict.items():
1980    if not isinstance(v, TableConfig):
1981      raise ValueError('Value of `table_to_config_dict` must be of type '
1982                       '`TableConfig`, got {} for {}.'.format(type(v), k))
1983
1984
1985def _validate_feature_to_config_dict(table_to_config_dict,
1986                                     feature_to_config_dict):
1987  """Validate `feature_to_config_dict`."""
1988  used_table_set = set(
1989      [feature.table_id for feature in feature_to_config_dict.values()])
1990  table_set = set(table_to_config_dict.keys())
1991
1992  unused_table_set = table_set - used_table_set
1993  if unused_table_set:
1994    raise ValueError(
1995        '`table_to_config_dict` specifies table that is not '
1996        'used in `feature_to_config_dict`: {}.'.format(unused_table_set))
1997
1998  extra_table_set = used_table_set - table_set
1999  if extra_table_set:
2000    raise ValueError(
2001        '`feature_to_config_dict` refers to a table that is not '
2002        'specified in `table_to_config_dict`: {}.'.format(extra_table_set))
2003
2004
2005def _validate_batch_size(batch_size, num_cores):
2006  if batch_size % num_cores:
2007    raise ValueError('`batch_size` is not a multiple of number of '
2008                     'cores. `batch_size`={}, `_num_cores`={}.'.format(
2009                         batch_size, num_cores))
2010
2011
2012def _validate_optimization_parameters(optimization_parameters,
2013                                      table_to_config_dict):
2014  """Validate the global optimization_parameters and per-table optimizers.
2015
2016  If the global optimizer is `None`, every per-table optimizer must be set.
2017
2018  Args:
2019      optimization_parameters: global optimizer provided in `TPUEmbedding`
2020        constructor.
2021      table_to_config_dict: A dictionary mapping from string of table name to
2022        `TableConfig`.
2023  """
2024  tbl_optimizer_missing = False
2025  for _, table_config in table_to_config_dict.items():
2026    if table_config.optimization_parameters is None:
2027      tbl_optimizer_missing = True
2028      break
2029
2030  if optimization_parameters:
2031    if not isinstance(optimization_parameters, _OptimizationParameters):
2032      raise ValueError('`optimization_parameters` must inherit from '
2033                       '`_OptimizationParameters`. '
2034                       '`type(optimization_parameters)`={}'.format(
2035                           type(optimization_parameters)))
2036  else:
2037    # Missing global optimization_parameters.
2038    if tbl_optimizer_missing:
2039      raise ValueError('`optimization_parameters` is missing.')
2040
2041
2042class _OptimizerHandler:
2043  """Interface class for handling optimizer specific logic."""
2044
2045  def __init__(self, optimization_parameters):
2046    self._optimization_parameters = optimization_parameters
2047
2048  def get_optimization_parameters(self):
2049    return self._optimization_parameters
2050
2051  def set_optimization_parameters(self, table_descriptor):
2052    raise NotImplementedError()
2053
2054  def get_default_slot_variable_names(self, table):
2055    raise NotImplementedError()
2056
2057  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2058                               table_config, table_variables, config_proto):
2059    raise NotImplementedError()
2060
2061
2062class _AdagradHandler(_OptimizerHandler):
2063  """Handles Adagrad specific logic."""
2064
2065  def set_optimization_parameters(self, table_descriptor):
2066    table_descriptor.optimization_parameters.adagrad.SetInParent()
2067
2068  def get_default_slot_variable_names(self, table):
2069    return AdagradSlotVariableNames('{}/{}'.format(table, 'Adagrad'))
2070
2071  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2072                               table_config, table_variables, config_proto):
2073    accumulator_initializer = init_ops.constant_initializer(
2074        self._optimization_parameters.initial_accumulator)
2075    accumulator_variables = _create_partitioned_variables(
2076        name=slot_variable_names.accumulator,
2077        num_hosts=num_hosts,
2078        vocabulary_size=table_config.vocabulary_size,
2079        embedding_dimension=table_config.dimension,
2080        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2081        initializer=accumulator_initializer)
2082    slot_variables = AdagradSlotVariables(accumulator_variables)
2083
2084    def load_ops_fn():
2085      """Returns the load ops for AdaGrad embedding tables.
2086
2087      Returns:
2088        A list of ops to load embedding and slot variables from CPU to TPU.
2089      """
2090      config = config_proto
2091      load_op_list = []
2092      for host_id, table_variable, accumulator_variable in zip(
2093          range(num_hosts), table_variables, accumulator_variables):
2094        with ops.colocate_with(table_variable):
2095          load_parameters_op = (
2096              tpu_ops.load_tpu_embedding_adagrad_parameters(
2097                  parameters=table_variable,
2098                  accumulators=accumulator_variable,
2099                  table_name=table,
2100                  num_shards=num_hosts,
2101                  shard_id=host_id,
2102                  config=config))
2103        config = None
2104        load_op_list.append(load_parameters_op)
2105      return load_op_list
2106
2107    def retrieve_ops_fn():
2108      """Returns the retrieve ops for AdaGrad embedding tables.
2109
2110      Returns:
2111        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2112      """
2113      config = config_proto
2114      retrieve_op_list = []
2115      for host_id, table_variable, accumulator_variable in (zip(
2116          range(num_hosts), table_variables, accumulator_variables)):
2117        with ops.colocate_with(table_variable):
2118          retrieved_table, retrieved_accumulator = (
2119              tpu_ops.retrieve_tpu_embedding_adagrad_parameters(
2120                  table_name=table,
2121                  num_shards=num_hosts,
2122                  shard_id=host_id,
2123                  config=config))
2124          retrieve_parameters_op = control_flow_ops.group(
2125              state_ops.assign(table_variable, retrieved_table),
2126              state_ops.assign(accumulator_variable, retrieved_accumulator))
2127        config = None
2128        retrieve_op_list.append(retrieve_parameters_op)
2129      return retrieve_op_list
2130
2131    return slot_variables, load_ops_fn, retrieve_ops_fn
2132
2133
2134class _AdagradMomentumHandler(_OptimizerHandler):
2135  """Handles Adagrad with Momentum specific logic.
2136
2137  Creates slot variables and defines their initializers. Defines load/retrieve
2138  operations to be used for loading variables into TPU memory (from host memory)
2139  and retrieving variables from TPU memory (into host memory).
2140  """
2141
2142  def set_optimization_parameters(self, table_descriptor):
2143    table_descriptor.optimization_parameters.adagrad_momentum.SetInParent()
2144    table_descriptor.optimization_parameters.adagrad_momentum.momentum = (
2145        self._optimization_parameters.momentum)
2146    table_descriptor.optimization_parameters.adagrad_momentum.use_nesterov = (
2147        self._optimization_parameters.use_nesterov)
2148    table_descriptor.optimization_parameters.adagrad_momentum.exponent = (
2149        self._optimization_parameters.exponent)
2150    table_descriptor.optimization_parameters.adagrad_momentum.beta2 = (
2151        self._optimization_parameters.beta2)
2152    table_descriptor.optimization_parameters.adagrad_momentum.epsilon = (
2153        self._optimization_parameters.epsilon)
2154
2155  def get_default_slot_variable_names(self, table):
2156    return AdagradMomentumSlotVariableNames(
2157        '{}/{}/Accumulator'.format(table, 'AdagradMomentum'),
2158        '{}/{}/Momentum'.format(table, 'AdagradMomentum'))
2159
2160  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2161                               table_config, table_variables, config_proto):
2162    accumulator_initializer = init_ops.zeros_initializer()
2163    accumulator_variables = _create_partitioned_variables(
2164        name=slot_variable_names.accumulator,
2165        num_hosts=num_hosts,
2166        vocabulary_size=table_config.vocabulary_size,
2167        embedding_dimension=table_config.dimension,
2168        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2169        initializer=accumulator_initializer)
2170    momenta_initializer = init_ops.zeros_initializer()
2171    momenta_variables = _create_partitioned_variables(
2172        name=slot_variable_names.momenta,
2173        num_hosts=num_hosts,
2174        vocabulary_size=table_config.vocabulary_size,
2175        embedding_dimension=table_config.dimension,
2176        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2177        initializer=momenta_initializer)
2178    slot_variables = AdagradMomentumSlotVariables(accumulator_variables,
2179                                                  momenta_variables)
2180
2181    def load_ops_fn():
2182      """Returns the load ops for AdaGrad with momentum embedding tables.
2183
2184      Returns:
2185        A list of ops to load embedding and slot variables from CPU to TPU.
2186      """
2187      config = config_proto
2188      load_op_list = []
2189      for host_id, table_variable, accumulator_variable, momenta_variable in zip(
2190          range(num_hosts), table_variables, accumulator_variables,
2191          momenta_variables):
2192        with ops.colocate_with(table_variable):
2193          load_parameters_op = (
2194              tpu_ops.load_tpu_embedding_adagrad_momentum_parameters(
2195                  parameters=table_variable,
2196                  accumulators=accumulator_variable,
2197                  momenta=momenta_variable,
2198                  table_name=table,
2199                  num_shards=num_hosts,
2200                  shard_id=host_id,
2201                  config=config))
2202        config = None
2203        load_op_list.append(load_parameters_op)
2204      return load_op_list
2205
2206    def retrieve_ops_fn():
2207      """Returns the retrieve ops for AdaGrad with momentum embedding tables.
2208
2209      Returns:
2210        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2211      """
2212      config = config_proto
2213      retrieve_op_list = []
2214      for host_id, table_variable, accumulator_variable, momenta_variable in (
2215          zip(
2216              range(num_hosts), table_variables, accumulator_variables,
2217              momenta_variables)):
2218        with ops.colocate_with(table_variable):
2219          retrieved_table, retrieved_accumulator, retrieved_momenta = (
2220              tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters(
2221                  table_name=table,
2222                  num_shards=num_hosts,
2223                  shard_id=host_id,
2224                  config=config))
2225          retrieve_parameters_op = control_flow_ops.group(
2226              state_ops.assign(table_variable, retrieved_table),
2227              state_ops.assign(accumulator_variable, retrieved_accumulator),
2228              state_ops.assign(momenta_variable, retrieved_momenta))
2229        config = None
2230        retrieve_op_list.append(retrieve_parameters_op)
2231      return retrieve_op_list
2232
2233    return slot_variables, load_ops_fn, retrieve_ops_fn
2234
2235
2236class _ProximalAdagradHandler(_OptimizerHandler):
2237  """Handles ProximalAdagrad specific logic."""
2238
2239  def set_optimization_parameters(self, table_descriptor):
2240    table_descriptor.optimization_parameters.proximal_adagrad.SetInParent()
2241    table_descriptor.optimization_parameters.proximal_adagrad.l1 = (
2242        self._optimization_parameters.l1_regularization_strength)
2243    table_descriptor.optimization_parameters.proximal_adagrad.l2 = (
2244        self._optimization_parameters.l2_regularization_strength)
2245
2246  def get_default_slot_variable_names(self, table):
2247    return ProximalAdagradSlotVariableNames('{}/{}'.format(
2248        table, 'ProximalAdagrad'))
2249
2250  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2251                               table_config, table_variables, config_proto):
2252    accumulator_initializer = init_ops.constant_initializer(
2253        self._optimization_parameters.initial_accumulator)
2254    accumulator_variables = _create_partitioned_variables(
2255        name=slot_variable_names.accumulator,
2256        num_hosts=num_hosts,
2257        vocabulary_size=table_config.vocabulary_size,
2258        embedding_dimension=table_config.dimension,
2259        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2260        initializer=accumulator_initializer)
2261    slot_variables = ProximalAdagradSlotVariables(accumulator_variables)
2262
2263    def load_ops_fn():
2264      """Returns the load ops for Proximal AdaGrad embedding tables.
2265
2266      Returns:
2267        A list of ops to load embedding and slot variables from CPU to TPU.
2268      """
2269      config = config_proto
2270      load_op_list = []
2271      for host_id, table_variable, accumulator_variable in zip(
2272          range(num_hosts), table_variables, accumulator_variables):
2273        with ops.colocate_with(table_variable):
2274          load_parameters_op = (
2275              tpu_ops.load_tpu_embedding_proximal_adagrad_parameters(
2276                  parameters=table_variable,
2277                  accumulators=accumulator_variable,
2278                  table_name=table,
2279                  num_shards=num_hosts,
2280                  shard_id=host_id,
2281                  config=config))
2282        config = None
2283        load_op_list.append(load_parameters_op)
2284      return load_op_list
2285
2286    def retrieve_ops_fn():
2287      """Returns the retrieve ops for Proximal AdaGrad embedding tables.
2288
2289      Returns:
2290        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2291      """
2292      config = config_proto
2293      retrieve_op_list = []
2294      for host_id, table_variable, accumulator_variable in (zip(
2295          range(num_hosts), table_variables, accumulator_variables)):
2296        with ops.colocate_with(table_variable):
2297          retrieved_table, retrieved_accumulator = (
2298              tpu_ops.retrieve_tpu_embedding_proximal_adagrad_parameters(
2299                  table_name=table,
2300                  num_shards=num_hosts,
2301                  shard_id=host_id,
2302                  config=config))
2303          retrieve_parameters_op = control_flow_ops.group(
2304              state_ops.assign(table_variable, retrieved_table),
2305              state_ops.assign(accumulator_variable, retrieved_accumulator))
2306        config = None
2307        retrieve_op_list.append(retrieve_parameters_op)
2308      return retrieve_op_list
2309
2310    return slot_variables, load_ops_fn, retrieve_ops_fn
2311
2312
2313class _AdamHandler(_OptimizerHandler):
2314  """Handles Adam specific logic."""
2315
2316  def set_optimization_parameters(self, table_descriptor):
2317    table_descriptor.optimization_parameters.adam.beta1 = (
2318        self._optimization_parameters.beta1)
2319    table_descriptor.optimization_parameters.adam.beta2 = (
2320        self._optimization_parameters.beta2)
2321    table_descriptor.optimization_parameters.adam.epsilon = (
2322        self._optimization_parameters.epsilon)
2323    table_descriptor.optimization_parameters.adam.use_non_lazy_adam = (
2324        not self._optimization_parameters.lazy_adam)
2325    table_descriptor.optimization_parameters.adam.use_sum_inside_sqrt = (
2326        self._optimization_parameters.sum_inside_sqrt)
2327
2328  def get_default_slot_variable_names(self, table):
2329    return AdamSlotVariableNames('{}/{}/m'.format(table, 'Adam'),
2330                                 '{}/{}/v'.format(table, 'Adam'))
2331
2332  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2333                               table_config, table_variables, config_proto):
2334    m_initializer = init_ops.zeros_initializer()
2335    m_variables = _create_partitioned_variables(
2336        name=slot_variable_names.m,
2337        num_hosts=num_hosts,
2338        vocabulary_size=table_config.vocabulary_size,
2339        embedding_dimension=table_config.dimension,
2340        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2341        initializer=m_initializer)
2342    v_initializer = init_ops.zeros_initializer()
2343    v_variables = _create_partitioned_variables(
2344        name=slot_variable_names.v,
2345        num_hosts=num_hosts,
2346        vocabulary_size=table_config.vocabulary_size,
2347        embedding_dimension=table_config.dimension,
2348        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2349        initializer=v_initializer)
2350    slot_variables = AdamSlotVariables(m_variables, v_variables)
2351
2352    def load_ops_fn():
2353      """Returns the load ops for Adam embedding tables.
2354
2355      Returns:
2356        A list of ops to load embedding and slot variables from CPU to TPU.
2357      """
2358      load_op_list = []
2359      config = config_proto
2360      for host_id, table_variable, m_variable, v_variable in (zip(
2361          range(num_hosts), table_variables, m_variables, v_variables)):
2362        with ops.colocate_with(table_variable):
2363          load_parameters_op = (
2364              tpu_ops.load_tpu_embedding_adam_parameters(
2365                  parameters=table_variable,
2366                  momenta=m_variable,
2367                  velocities=v_variable,
2368                  table_name=table,
2369                  num_shards=num_hosts,
2370                  shard_id=host_id,
2371                  config=config))
2372        # Set config to None to enforce that config is only loaded to the first
2373        # table.
2374        config = None
2375        load_op_list.append(load_parameters_op)
2376      return load_op_list
2377
2378    def retrieve_ops_fn():
2379      """Returns the retrieve ops for Adam embedding tables.
2380
2381      Returns:
2382        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2383      """
2384      retrieve_op_list = []
2385      config = config_proto
2386      for host_id, table_variable, m_variable, v_variable in (zip(
2387          range(num_hosts), table_variables, m_variables, v_variables)):
2388        with ops.colocate_with(table_variable):
2389          retrieved_table, retrieved_m, retrieved_v = (
2390              tpu_ops.retrieve_tpu_embedding_adam_parameters(
2391                  table_name=table,
2392                  num_shards=num_hosts,
2393                  shard_id=host_id,
2394                  config=config))
2395          retrieve_parameters_op = control_flow_ops.group(
2396              state_ops.assign(table_variable, retrieved_table),
2397              state_ops.assign(m_variable, retrieved_m),
2398              state_ops.assign(v_variable, retrieved_v))
2399        config = None
2400        retrieve_op_list.append(retrieve_parameters_op)
2401      return retrieve_op_list
2402
2403    return slot_variables, load_ops_fn, retrieve_ops_fn
2404
2405
2406class _FtrlHandler(_OptimizerHandler):
2407  """Handles Ftrl specific logic."""
2408
2409  def set_optimization_parameters(self, table_descriptor):
2410    table_descriptor.optimization_parameters.ftrl.lr_power = (
2411        self._optimization_parameters.learning_rate_power)
2412    table_descriptor.optimization_parameters.ftrl.l1 = (
2413        self._optimization_parameters.l1_regularization_strength)
2414    table_descriptor.optimization_parameters.ftrl.l2 = (
2415        self._optimization_parameters.l2_regularization_strength)
2416    table_descriptor.optimization_parameters.ftrl.multiply_linear_by_lr = (
2417        self._optimization_parameters.multiply_linear_by_learning_rate)
2418    table_descriptor.optimization_parameters.ftrl.beta = (
2419        self._optimization_parameters.beta)
2420    table_descriptor.optimization_parameters.ftrl.allow_zero_accumulator = (
2421        self._optimization_parameters.allow_zero_accumulator)
2422
2423  def get_default_slot_variable_names(self, table):
2424    # These match the default slot variable names created by
2425    # tf.train.FtrlOptimizer.
2426    return FtrlSlotVariableNames(
2427        '{}/{}'.format(table, 'Ftrl'),  # accumulator
2428        '{}/{}'.format(table, 'Ftrl_1'))  # linear
2429
2430  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2431                               table_config, table_variables, config_proto):
2432    accumulator_initializer = init_ops.constant_initializer(
2433        self._optimization_parameters.initial_accumulator_value)
2434    accumulator_variables = _create_partitioned_variables(
2435        name=slot_variable_names.accumulator,
2436        num_hosts=num_hosts,
2437        vocabulary_size=table_config.vocabulary_size,
2438        embedding_dimension=table_config.dimension,
2439        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2440        initializer=accumulator_initializer)
2441    linear_initializer = init_ops.constant_initializer(
2442        self._optimization_parameters.initial_linear_value)
2443    linear_variables = _create_partitioned_variables(
2444        name=slot_variable_names.linear,
2445        num_hosts=num_hosts,
2446        vocabulary_size=table_config.vocabulary_size,
2447        embedding_dimension=table_config.dimension,
2448        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2449        initializer=linear_initializer)
2450    slot_variables = FtrlSlotVariable(accumulator_variables, linear_variables)
2451
2452    def load_ops_fn():
2453      """Returns the load ops for Ftrl embedding tables.
2454
2455      Returns:
2456        A list of ops to load embedding and slot variables from CPU to TPU.
2457      """
2458      config = config_proto
2459      load_op_list = []
2460      for host_id, table_variable, accumulator_variable, linear_variable in zip(
2461          range(num_hosts), table_variables, accumulator_variables,
2462          linear_variables):
2463        with ops.colocate_with(table_variable):
2464          load_parameters_op = (
2465              tpu_ops.load_tpu_embedding_ftrl_parameters(
2466                  parameters=table_variable,
2467                  accumulators=accumulator_variable,
2468                  linears=linear_variable,
2469                  table_name=table,
2470                  num_shards=num_hosts,
2471                  shard_id=host_id,
2472                  config=config))
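        # Set config to None to enforce that config is only loaded to the
        # first table.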
2473        config = None
2474        load_op_list.append(load_parameters_op)
2475      return load_op_list
2476
2477    def retrieve_ops_fn():
2478      """Returns the retrieve ops for Ftrl embedding tables.
2479
2480      Returns:
2481        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2482      """
2483      config = config_proto
2484      retrieve_op_list = []
2485      for host_id, table_variable, accumulator_variable, linear_variable in zip(
2486          range(num_hosts), table_variables, accumulator_variables,
2487          linear_variables):
2488        with ops.colocate_with(table_variable):
2489          retrieved_table, retrieved_accumulator, retrieved_linear = (
2490              tpu_ops.retrieve_tpu_embedding_ftrl_parameters(
2491                  table_name=table,
2492                  num_shards=num_hosts,
2493                  shard_id=host_id,
2494                  config=config))
2495          retrieve_parameters_op = control_flow_ops.group(
2496              state_ops.assign(table_variable, retrieved_table),
2497              state_ops.assign(accumulator_variable, retrieved_accumulator),
2498              state_ops.assign(linear_variable, retrieved_linear))
2499        config = None
2500        retrieve_op_list.append(retrieve_parameters_op)
2501      return retrieve_op_list
2502
2503    return slot_variables, load_ops_fn, retrieve_ops_fn
2504
2505
2506class _ProximalYogiHandler(_OptimizerHandler):
2507  """Handles Proximal Yogi specific logic."""
2508
2509  def set_optimization_parameters(self, table_descriptor):
2510    table_descriptor.optimization_parameters.proximal_yogi.SetInParent()
2511    table_descriptor.optimization_parameters.proximal_yogi.beta1 = (
2512        self._optimization_parameters.beta1)
2513    table_descriptor.optimization_parameters.proximal_yogi.beta2 = (
2514        self._optimization_parameters.beta2)
2515    table_descriptor.optimization_parameters.proximal_yogi.epsilon = (
2516        self._optimization_parameters.epsilon)
2517    table_descriptor.optimization_parameters.proximal_yogi.l1 = (
2518        self._optimization_parameters.l1_regularization_strength)
2519    table_descriptor.optimization_parameters.proximal_yogi.l2 = (
2520        self._optimization_parameters.l2_regularization_strength)
2521
2522  def get_default_slot_variable_names(self, table):
2523    return ProximalYogiSlotVariableNames(
2524        '{}/{}'.format(table, 'ProximalYogi'),  # v
2525        '{}/{}_1'.format(table, 'ProximalYogi'))  # m
2526
2527  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2528                               table_config, table_variables, config_proto):
2529    v_initializer = init_ops.constant_initializer(
2530        self._optimization_parameters.initial_accumulator_value)
2531    v_variables = _create_partitioned_variables(
2532        name=slot_variable_names.v,
2533        num_hosts=num_hosts,
2534        vocabulary_size=table_config.vocabulary_size,
2535        embedding_dimension=table_config.dimension,
2536        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2537        initializer=v_initializer)
2538    m_initializer = init_ops.zeros_initializer()
2539    m_variables = _create_partitioned_variables(
2540        name=slot_variable_names.m,
2541        num_hosts=num_hosts,
2542        vocabulary_size=table_config.vocabulary_size,
2543        embedding_dimension=table_config.dimension,
2544        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2545        initializer=m_initializer)
2546    slot_variables = ProximalYogiSlotVariables(v_variables, m_variables)
2547
2548    def load_ops_fn():
2549      """Returns the load ops for Proximal Yogi embedding tables.
2550
2551      Returns:
2552        A list of ops to load embedding and slot variables from CPU to TPU.
2553      """
2554      load_op_list = []
2555      config = config_proto
2556      for host_id, table_variable, v_variable, m_variable in (zip(
2557          range(num_hosts), table_variables, v_variables, m_variables)):
2558        with ops.colocate_with(table_variable):
2559          load_parameters_op = (
2560              tpu_ops.load_tpu_embedding_proximal_yogi_parameters(
2561                  parameters=table_variable,
2562                  v=v_variable,
2563                  m=m_variable,
2564                  table_name=table,
2565                  num_shards=num_hosts,
2566                  shard_id=host_id,
2567                  config=config))
2568        # Set config to None to enforce that config is only loaded to the first
2569        # table.
2570        config = None
2571        load_op_list.append(load_parameters_op)
2572      return load_op_list
2573
2574    def retrieve_ops_fn():
2575      """Returns the retrieve ops for Proximal Yogi embedding tables.
2576
2577      Returns:
2578        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2579      """
2580      retrieve_op_list = []
2581      config = config_proto
2582      for host_id, table_variable, v_variable, m_variable in (zip(
2583          range(num_hosts), table_variables, v_variables, m_variables)):
2584        with ops.colocate_with(table_variable):
2585          retrieved_table, retrieved_v, retrieved_m = (
2586              tpu_ops.retrieve_tpu_embedding_proximal_yogi_parameters(
2587                  table_name=table,
2588                  num_shards=num_hosts,
2589                  shard_id=host_id,
2590                  config=config))
2591          retrieve_parameters_op = control_flow_ops.group(
2592              state_ops.assign(table_variable, retrieved_table),
2593              state_ops.assign(v_variable, retrieved_v),
2594              state_ops.assign(m_variable, retrieved_m))
2595        config = None
2596        retrieve_op_list.append(retrieve_parameters_op)
2597      return retrieve_op_list
2598
2599    return slot_variables, load_ops_fn, retrieve_ops_fn
2600
2601
2602class _MomentumHandler(_OptimizerHandler):
2603  """Handles Momentum specific logic."""
2604
2605  def set_optimization_parameters(self, table_descriptor):
2606    table_descriptor.optimization_parameters.momentum.SetInParent()
2607    table_descriptor.optimization_parameters.momentum.momentum = (
2608        self._optimization_parameters.momentum)
2609    table_descriptor.optimization_parameters.momentum.use_nesterov = (
2610        self._optimization_parameters.use_nesterov)
2611
2612  def get_default_slot_variable_names(self, table):
2613    return MomentumSlotVariableNames('{}/{}'.format(table, 'Momentum'))
2614
2615  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2616                               table_config, table_variables, config_proto):
2617
2618    momenta_initializer = init_ops.zeros_initializer()
2619    momenta_variables = _create_partitioned_variables(
2620        name=slot_variable_names.momenta,
2621        num_hosts=num_hosts,
2622        vocabulary_size=table_config.vocabulary_size,
2623        embedding_dimension=table_config.dimension,
2624        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2625        initializer=momenta_initializer)
2626    slot_variables = MomentumSlotVariables(momenta_variables)
2627
2628    def load_ops_fn():
2629      """Returns the retrieve ops for Momentum embedding tables.
2630
2631      Returns:
2632        A list of ops to load embedding and slot variables from CPU to TPU.
2633      """
2634      load_op_list = []
2635      config = config_proto
2636      for host_id, table_variable, momenta_variable in (zip(
2637          range(num_hosts), table_variables, momenta_variables)):
2638        with ops.colocate_with(table_variable):
2639          load_parameters_op = tpu_ops.load_tpu_embedding_momentum_parameters(
2640              parameters=table_variable,
2641              momenta=momenta_variable,
2642              table_name=table,
2643              num_shards=num_hosts,
2644              shard_id=host_id,
2645              config=config,
2646          )
2647        config = None
2648        load_op_list.append(load_parameters_op)
2649      return load_op_list
2650
2651    def retrieve_ops_fn():
2652      """Returns the retrieve ops for Momentum embedding tables.
2653
2654      Returns:
2655        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2656      """
2657      retrieve_op_list = []
2658      config = config_proto
2659      for host_id, table_variable, momenta_variable in (zip(
2660          range(num_hosts), table_variables, momenta_variables)):
2661        with ops.colocate_with(table_variable):
2662          retrieved_table, retrieved_momenta = (
2663              tpu_ops.retrieve_tpu_embedding_momentum_parameters(
2664                  table_name=table,
2665                  num_shards=num_hosts,
2666                  shard_id=host_id,
2667                  config=config,
2668              ))
2669          retrieve_parameters_op = control_flow_ops.group(
2670              state_ops.assign(table_variable, retrieved_table),
2671              state_ops.assign(momenta_variable, retrieved_momenta))
2672        config = None
2673        retrieve_op_list.append(retrieve_parameters_op)
2674      return retrieve_op_list
2675
2676    return slot_variables, load_ops_fn, retrieve_ops_fn
2677
2678
2679class _RMSPropHandler(_OptimizerHandler):
2680  """Handles RMS prop specific logic."""
2681
2682  def set_optimization_parameters(self, table_descriptor):
2683    table_descriptor.optimization_parameters.rms_prop.SetInParent()
2684    table_descriptor.optimization_parameters.rms_prop.rho = (
2685        self._optimization_parameters.rho)
2686    table_descriptor.optimization_parameters.rms_prop.epsilon = (
2687        self._optimization_parameters.epsilon)
2688    table_descriptor.optimization_parameters.rms_prop.momentum = (
2689        self._optimization_parameters.momentum)
2690
2691  def get_default_slot_variable_names(self, table):
2692    return RMSPropSlotVariableNames('{}/{}/ms'.format(table, 'RMSProp'),
2693                                    '{}/{}/mom'.format(table, 'RMSProp'))
2694
2695  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2696                               table_config, table_variables, config_proto):
2697
2698    ms_variables = _create_partitioned_variables(
2699        name=slot_variable_names.ms,
2700        num_hosts=num_hosts,
2701        vocabulary_size=table_config.vocabulary_size,
2702        embedding_dimension=table_config.dimension,
2703        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2704        initializer=init_ops.zeros_initializer(),
2705    )
2706    mom_variables = _create_partitioned_variables(
2707        name=slot_variable_names.mom,
2708        num_hosts=num_hosts,
2709        vocabulary_size=table_config.vocabulary_size,
2710        embedding_dimension=table_config.dimension,
2711        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2712        initializer=init_ops.zeros_initializer(),
2713    )
2714    slot_variables = RMSPropSlotVariables(ms_variables, mom_variables)
2715
2716    def load_ops_fn():
2717      """Returns the retrieve ops for RMS Prop embedding tables.
2718
2719      Returns:
2720        A list of ops to load embedding and slot variables from CPU to TPU.
2721      """
2722      load_op_list = []
2723      config = config_proto
2724      for host_id, table_variable, ms_variable, mom_variable in (zip(
2725          range(num_hosts), table_variables, ms_variables, mom_variables)):
2726        with ops.colocate_with(table_variable):
2727          load_parameters_op = tpu_ops.load_tpu_embedding_rms_prop_parameters(
2728              parameters=table_variable,
2729              ms=ms_variable,
2730              mom=mom_variable,
2731              table_name=table,
2732              num_shards=num_hosts,
2733              shard_id=host_id,
2734              config=config,
2735          )
2736        config = None
2737        load_op_list.append(load_parameters_op)
2738      return load_op_list
2739
2740    def retrieve_ops_fn():
2741      """Returns the retrieve ops for RMS Prop embedding tables.
2742
2743      Returns:
2744        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2745      """
2746      retrieve_op_list = []
2747      config = config_proto
2748      for host_id, table_variable, ms_variable, mom_variable in (zip(
2749          range(num_hosts), table_variables, ms_variables, mom_variables)):
2750        with ops.colocate_with(table_variable):
2751          retrieved_table, retrieved_ms, retrieved_mom = (
2752              tpu_ops.retrieve_tpu_embedding_rms_prop_parameters(
2753                  table_name=table,
2754                  num_shards=num_hosts,
2755                  shard_id=host_id,
2756                  config=config,
2757              ))
2758          retrieve_parameters_op = control_flow_ops.group(
2759              state_ops.assign(table_variable, retrieved_table),
2760              state_ops.assign(ms_variable, retrieved_ms),
2761              state_ops.assign(mom_variable, retrieved_mom))
2762        config = None
2763        retrieve_op_list.append(retrieve_parameters_op)
2764      return retrieve_op_list
2765
2766    return slot_variables, load_ops_fn, retrieve_ops_fn
2767
2768
2769class _FrequencyEstimatorHandler(_OptimizerHandler):
2770  """Handles frequency estimator specific logic."""
2771
2772  def set_optimization_parameters(self, table_descriptor):
2773    table_descriptor.optimization_parameters.frequency_estimator.SetInParent()
2774    freq = table_descriptor.optimization_parameters.frequency_estimator
2775    freq.tau = self._optimization_parameters.tau
2776    freq.max_delta = self._optimization_parameters.max_delta
2777    freq.outlier_threshold = self._optimization_parameters.outlier_threshold
2778    freq.weight_exponent = self._optimization_parameters.weight_exponent
2779
2780  def get_default_slot_variable_names(self, table):
2781    return FrequencyEstimatorSlotVariableNames(
2782        '{}/FrequencyEstimator'.format(table))
2783
2784  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2785                               table_config, table_variables, config_proto):
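    # A FrequencyEstimator table stores a single scalar estimate per id, hence
    # the dimension-1 requirement checked below.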
2786    if table_config.dimension != 1:
2787      raise ValueError('FrequencyEstimator tables should only have a dimension '
2788                       'of 1. Received dimension {}'.format(
2789                           table_config.dimension))
2790
2791    last_hit_step_variables = _create_partitioned_variables(
2792        name=slot_variable_names.last_hit_step,
2793        num_hosts=num_hosts,
2794        vocabulary_size=table_config.vocabulary_size,
2795        embedding_dimension=table_config.dimension,
2796        collections=[ops.GraphKeys.GLOBAL_VARIABLES],
2797        initializer=init_ops.zeros_initializer(),
2798    )
2799    slot_variables = FrequencyEstimatorSlotVariables(last_hit_step_variables)
2800
2801    def load_ops_fn():
2802      """Returns the retrieve ops for Frequency Estimator embedding tables.
2803
2804      Returns:
2805        A list of ops to load embedding and slot variables from CPU to TPU.
2806      """
2807      load_op_list = []
2808      config = config_proto
2809      for host_id, table_variable, last_hit_step_variable in (zip(
2810          range(num_hosts), table_variables, last_hit_step_variables)):
2811        with ops.colocate_with(table_variable):
2812          load_parameters_op = (
2813              tpu_ops.load_tpu_embedding_frequency_estimator_parameters(
2814                  parameters=table_variable,
2815                  last_hit_step=last_hit_step_variable,
2816                  table_name=table,
2817                  num_shards=num_hosts,
2818                  shard_id=host_id,
2819                  config=config))
2820        config = None
2821        load_op_list.append(load_parameters_op)
2822      return load_op_list
2823
2824    def retrieve_ops_fn():
2825      """Returns the retrieve ops for Frequency Estimator embedding tables.
2826
2827      Returns:
2828        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2829      """
2830      retrieve_op_list = []
2831      config = config_proto
2832      for host_id, table_variable, last_hit_step_variable in (zip(
2833          range(num_hosts), table_variables, last_hit_step_variables)):
2834        with ops.colocate_with(table_variable):
2835          retrieved_table, retrieved_last_hit_step = (
2836              tpu_ops.retrieve_tpu_embedding_frequency_estimator_parameters(
2837                  table_name=table,
2838                  num_shards=num_hosts,
2839                  shard_id=host_id,
2840                  config=config,
2841              ))
2842          retrieve_parameters_op = control_flow_ops.group(
2843              state_ops.assign(table_variable, retrieved_table),
2844              state_ops.assign(last_hit_step_variable, retrieved_last_hit_step))
2845        config = None
2846        retrieve_op_list.append(retrieve_parameters_op)
2847      return retrieve_op_list
2848
2849    return slot_variables, load_ops_fn, retrieve_ops_fn
2850
2851
2852class _StochasticGradientDescentHandler(_OptimizerHandler):
2853  """Handles stochastic gradient descent specific logic."""
2854
2855  def set_optimization_parameters(self, table_descriptor):
2856    (table_descriptor.optimization_parameters.stochastic_gradient_descent
2857     .SetInParent())
2858
2859  def get_default_slot_variable_names(self, table):
2860    return None
2861
2862  def create_variables_and_ops(self, table, slot_variable_names, num_hosts,
2863                               table_config, table_variables, config_proto):
2864    del table_config
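    # Plain SGD keeps no slot variables; only the embedding table itself is
    # loaded/retrieved, and slot_variables is returned as None below.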
2865
2866    def load_ops_fn():
2867      """Returns the retrieve ops for AdaGrad embedding tables.
2868
2869      Returns:
2870        A list of ops to load embedding and slot variables from CPU to TPU.
2871      """
2872      load_op_list = []
2873      config = config_proto
2874      for host_id, table_variable in enumerate(table_variables):
2875        with ops.colocate_with(table_variable):
2876          load_parameters_op = (
2877              tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters(
2878                  parameters=table_variable,
2879                  table_name=table,
2880                  num_shards=num_hosts,
2881                  shard_id=host_id,
2882                  config=config))
2883        config = None
2884        load_op_list.append(load_parameters_op)
2885      return load_op_list
2886
2887    def retrieve_ops_fn():
2888      """Returns the retrieve ops for SGD embedding tables.
2889
2890      Returns:
2891        A list of ops to retrieve embedding and slot variables from TPU to CPU.
2892      """
2893      retrieve_op_list = []
2894      config = config_proto
2895      for host_id, table_variable in enumerate(table_variables):
2896        with ops.colocate_with(table_variable):
2897          retrieved_table = (
2898              tpu_ops
2899              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
2900                  table_name=table,
2901                  num_shards=num_hosts,
2902                  shard_id=host_id,
2903                  config=config))
2904          retrieve_parameters_op = control_flow_ops.group(
2905              state_ops.assign(table_variable, retrieved_table))
2906        config = None
2907        retrieve_op_list.append(retrieve_parameters_op)
2908      return retrieve_op_list
2909
2910    return None, load_ops_fn, retrieve_ops_fn
2911
2912
2913def _get_optimization_handler(optimization_parameters):
2914  """Gets the optimization handler given the parameter type."""
2915  if isinstance(optimization_parameters, AdagradParameters):
2916    return _AdagradHandler(optimization_parameters)
2917  elif isinstance(optimization_parameters, AdagradMomentumParameters):
2918    return _AdagradMomentumHandler(optimization_parameters)
2919  elif isinstance(optimization_parameters, ProximalAdagradParameters):
2920    return _ProximalAdagradHandler(optimization_parameters)
2921  elif isinstance(optimization_parameters, AdamParameters):
2922    return _AdamHandler(optimization_parameters)
2923  elif isinstance(optimization_parameters, FtrlParameters):
2924    return _FtrlHandler(optimization_parameters)
2925  elif isinstance(optimization_parameters, ProximalYogiParameters):
2926    return _ProximalYogiHandler(optimization_parameters)
2927  elif isinstance(optimization_parameters, StochasticGradientDescentParameters):
2928    return _StochasticGradientDescentHandler(optimization_parameters)
2929  elif isinstance(optimization_parameters, MomentumParameters):
2930    return _MomentumHandler(optimization_parameters)
2931  elif isinstance(optimization_parameters, RMSPropParameters):
2932    return _RMSPropHandler(optimization_parameters)
2933  elif isinstance(optimization_parameters, FrequencyEstimatorParameters):
2934    return _FrequencyEstimatorHandler(optimization_parameters)
2935  raise NotImplementedError('Unsupported optimization_parameters type.')
2936
2937
2938def _create_ordered_dict(d):
2939  """Create an OrderedDict from Dict."""
2940  return collections.OrderedDict((k, d[k]) for k in sorted(d))
2941
2942
2943def _create_combiners(table_to_config_dict, table_to_features_dict):
2944  """Create a per feature list of combiners, ordered by table."""
2945  combiners = []
2946  for table in table_to_config_dict:
2947    combiner = table_to_config_dict[table].combiner or 'sum'
2948    combiners.extend([combiner] * len(table_to_features_dict[table]))
2949  return combiners
2950
2951
2952def _create_table_to_features_dict(feature_to_config_dict):
2953  """Create mapping from table to a list of its features."""
2954  table_to_features_dict_tmp = {}
2955  for feature, feature_config in feature_to_config_dict.items():
2956    if feature_config.table_id in table_to_features_dict_tmp:
2957      table_to_features_dict_tmp[feature_config.table_id].append(feature)
2958    else:
2959      table_to_features_dict_tmp[feature_config.table_id] = [feature]
2960
2961  table_to_features_dict = collections.OrderedDict()
2962  for table in sorted(table_to_features_dict_tmp):
2963    table_to_features_dict[table] = sorted(table_to_features_dict_tmp[table])
2964  return table_to_features_dict
2965
2966
2967def _create_device_fn(hosts):
2968  """Create device_fn() to use with _create_partitioned_variables()."""
2969
2970  def device_fn(op):
2971    """Returns the `device` for `op`."""
2972    part_match = re.match(r'.*/part_(\d+)(/|$)', op.name)
2973    dummy_match = re.match(r'.*dummy_(\d+).*', op.name)
2974    if not part_match and not dummy_match:
2975      raise RuntimeError(
2976          'Internal Error: Expected {} to contain /part_* or dummy_*'.format(
2977              op.name))
2978
2979    if part_match:
2980      idx = int(part_match.group(1))
2981    else:
2982      idx = int(dummy_match.group(1))  # pytype: disable=attribute-error
2983
2984    device = hosts[idx]
2985    logging.debug('assigning %s to %s.', op, device)
2986    return device
2987
2988  return device_fn
2989
2990
2991def _create_partitioned_variables(name,
2992                                  num_hosts,
2993                                  vocabulary_size,
2994                                  embedding_dimension,
2995                                  initializer,
2996                                  collections=None):  # pylint: disable=redefined-outer-name
2997  """Creates PartitionedVariables based on `num_hosts` for `table`."""
2998
2999  num_slices = min(vocabulary_size, num_hosts)
3000
3001  var_list = list(
3002      variable_scope.get_variable(
3003          name,
3004          shape=(vocabulary_size, embedding_dimension),
3005          partitioner=partitioned_variables.fixed_size_partitioner(num_slices),
3006          dtype=dtypes.float32,
3007          initializer=initializer,
3008          collections=collections,
3009          trainable=False))
3010
3011  if vocabulary_size >= num_hosts:
3012    return var_list
3013
3014  # For padded part, define the dummy variable to be loaded into TPU system.
3015  for idx in range(num_hosts - vocabulary_size):
3016    var_list.append(
3017        variable_scope.get_variable(
3018            'dummy_{}_{}'.format(vocabulary_size + idx, name),
3019            shape=(1, embedding_dimension),
3020            dtype=dtypes.float32,
3021            initializer=initializer,
3022            collections=[ops.GraphKeys.LOCAL_VARIABLES],
3023            trainable=False))
3024
3025  return var_list
3026