1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Companion classes for mid level API for TPU Embeddings in TF2."""
16
17import abc
18import math
19import typing
20from typing import Any, Dict, Callable, List, Optional, Text, Tuple, TypeVar, Union
21
22from absl import logging
23
24from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
25from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
26from tensorflow.python.distribute import sharded_variable
27from tensorflow.python.framework import ops
28from tensorflow.python.framework.tensor_shape import TensorShape
29from tensorflow.python.ops import init_ops_v2
30from tensorflow.python.ops import variables as tf_variables
31from tensorflow.python.tpu.ops import tpu_ops
32from tensorflow.python.types import core
33from tensorflow.python.util.tf_export import tf_export
34
35
36TableVariable = TypeVar("TableVariable", sharded_variable.ShardedVariable,
37                        tf_variables.Variable)
38SlotVarCreationFnType = Callable[
39    [TableVariable, List[Text], List[init_ops_v2.Initializer]],
40    Dict[Text, TableVariable]]
41ClipValueType = Union[Tuple[float, float], float]
42
43
44class _Optimizer(metaclass=abc.ABCMeta):
45  """Base class for all optimizers, with common parameters."""
46
47  def __init__(
48      self,
49      learning_rate: Union[float, Callable[[], float]],
50      use_gradient_accumulation: bool,
51      clip_weight_min: Optional[float],
52      clip_weight_max: Optional[float],
53      weight_decay_factor: Optional[float],
54      multiply_weight_decay_factor_by_learning_rate: bool,
55      clipvalue: Optional[ClipValueType] = None,
56      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None):
57    self.learning_rate = learning_rate
58    self.use_gradient_accumulation = use_gradient_accumulation
59    self.clip_weight_min = clip_weight_min
60    self.clip_weight_max = clip_weight_max
61    if not use_gradient_accumulation and clipvalue is not None:
62      raise ValueError(
63          f"When `use_gradient_accumulation` is False, gradient clipping "
64          f"cannot be used and `clipvalue` should be left as None. "
65          f"Received value {clipvalue} for argument `clipvalue`.")
66    if clipvalue is None:
67      clipvalue = (None, None)
68    elif not isinstance(clipvalue, tuple):
69      clipvalue = (-1. * clipvalue, clipvalue)
70    self.clip_gradient_min, self.clip_gradient_max = clipvalue
71
72    self.weight_decay_factor = weight_decay_factor
73    self.multiply_weight_decay_factor_by_learning_rate = (
74        multiply_weight_decay_factor_by_learning_rate)
75
76    if (slot_variable_creation_fn is not None and
77        not callable(slot_variable_creation_fn)):
78      raise ValueError(
79          f"Argument `slot_variable_creation_fn` must be either None or a "
80          f"callable. Received: {slot_variable_creation_fn}")
81    self.slot_variable_creation_fn = slot_variable_creation_fn
82
83  @abc.abstractmethod
84  def _slot_names(self) -> List[Text]:
85    """Returns the name of all the slot variables.
86
87    This does not include the 'parameters' variable and these names must match
88    the names of the slot variables as used in the corresponding
89    `tpu_ops.load_tpu_embedding_*` ops.
90    """
91    raise NotImplementedError
92
93  @abc.abstractmethod
94  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
95    """Returns initializers for slot variables.
96
97    This returns a list parallel to self._slot_names().
98    """
99    raise NotImplementedError
100
101  def _set_optimization_parameters(
102      self, parameters: optimization_parameters_pb2.OptimizationParameters):
103    """Sets the optimizer fields in the OptimizationParameters."""
104    if self.use_gradient_accumulation:
105      parameters.gradient_accumulation_status = (
106          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED)
107    else:
108      parameters.gradient_accumulation_status = (
109          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
110
111    if self.clip_weight_min is not None:
112      parameters.clipping_limits.lower.value = self.clip_weight_min
113
114    if self.clip_weight_max is not None:
115      parameters.clipping_limits.upper.value = self.clip_weight_max
116
117    if self.clip_gradient_min is not None:
118      parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min
119
120    if self.clip_gradient_max is not None:
121      parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max
122
123    if self.weight_decay_factor:
124      parameters.weight_decay_factor = self.weight_decay_factor
125      if self.multiply_weight_decay_factor_by_learning_rate:
126        parameters.multiply_weight_decay_factor_by_learning_rate = True
127
128  @abc.abstractmethod
129  def _load(self) -> Callable[..., ops.Operation]:
130    """Returns the load function for the optimizer."""
131    raise NotImplementedError
132
133  @abc.abstractmethod
134  def _retrieve(self) -> Callable[..., core.Tensor]:
135    """Returns the retrieve function for the optimizer."""
136    raise NotImplementedError
137
138  def _create_slots(
139      self, table: "TableConfig",
140      variable_creator: Callable[[Text, init_ops_v2.Initializer],
141                                 tf_variables.Variable]
142  ) -> Dict[Text, tf_variables.Variable]:
143    """Creates slot variables for table.
144
145    Args:
146      table: The table variable to create slots for.
147      variable_creator: A function which creates variables. Takes parameters
148        'name', 'initializer'.
149
150    Returns:
151      A dict of variables, keyed by self._slot_names().
152    """
153    if self.slot_variable_creation_fn is not None:
154      return self.slot_variable_creation_fn(table, self._slot_names(),
155                                            self._slot_initializers())
156    else:
157      slots = {}
158      for slot, initializer in zip(self._slot_names(),
159                                   self._slot_initializers()):
160        slots[slot] = variable_creator(slot, initializer)
161      return slots
162
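  # Optimizers compare equal iff they are instances of the same class with
  # identical attribute values; __hash__ below is kept consistent with this.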
163  def __eq__(self, other: Any) -> Union[Any, bool]:
164    if isinstance(other, self.__class__):
165      return all([
166          attr1 == attr2
167          for attr1, attr2 in zip(self.__dict__.items(), other.__dict__.items())
168      ])
169    else:
170      return False
171
172  def __hash__(self) -> int:
173    return hash(tuple(self.__dict__.items()))
174
175
176@tf_export("tpu.experimental.embedding.SGD")
177class SGD(_Optimizer):
178  """Optimization parameters for stochastic gradient descent for TPU embeddings.
179
180  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
181  argument to set the global optimizer and its parameters:
182
183  ```python
184  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
185      ...
186      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
187  ```
188
189  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
190  optimizer parameter to set a table specific optimizer. This will override the
191  optimizer and parameters for global embedding optimizer defined above:
192
193  ```python
194  table_one = tf.tpu.experimental.embedding.TableConfig(
195      vocabulary_size=...,
196      dim=...,
197      optimizer=tf.tpu.experimental.embedding.SGD(0.2))
198  table_two = tf.tpu.experimental.embedding.TableConfig(
199      vocabulary_size=...,
200      dim=...)
201
202  feature_config = (
203      tf.tpu.experimental.embedding.FeatureConfig(
204          table=table_one),
205      tf.tpu.experimental.embedding.FeatureConfig(
206          table=table_two))
207
208  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
209      feature_config=feature_config,
210      batch_size=...,
211      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
212  ```
213
214  In the above example, the first feature will be looked up in a table that has
215  a learning rate of 0.2 while the second feature will be looked up in a table
216  that has a learning rate of 0.1.
217
218  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
219  complete description of these parameters and their impacts on the optimizer
220  algorithm.
221  """
222
223  def __init__(self,
224               learning_rate: Union[float, Callable[[], float]] = 0.01,
225               use_gradient_accumulation: bool = True,
226               clip_weight_min: Optional[float] = None,
227               clip_weight_max: Optional[float] = None,
228               weight_decay_factor: Optional[float] = None,
229               multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
230               clipvalue: Optional[ClipValueType] = None):
231    """Optimization parameters for stochastic gradient descent.
232
233    Args:
234      learning_rate: The learning rate. It should be a floating point value or a
235        callable taking no arguments for a dynamic learning rate.
236      use_gradient_accumulation: setting this to `False` makes embedding
237        gradients calculation less accurate but faster.
238      clip_weight_min: the minimum value to clip by; None means -infinity.
239      clip_weight_max: the maximum value to clip by; None means +infinity.
240      weight_decay_factor: amount of weight decay to apply; None means that the
241        weights are not decayed. Weights are decayed by multiplying the weight
242        by this factor each step.
243      multiply_weight_decay_factor_by_learning_rate: if true,
244        `weight_decay_factor` is multiplied by the current learning rate.
245      clipvalue: Controls clipping of the gradient. Set to either a single
246        positive scalar value to get clipping or a tuple of scalar values (min,
247        max) to set a separate maximum or minimum. If one of the two entries is
248        None, then there will be no clipping in that direction. Note that if
249        this is set, you may see a decrease in performance as gradient
250        accumulation will be enabled (it is normally off for SGD as it has no
251        effect on accuracy). See
252        'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for more
253        information on gradient accumulation and its impact on TPU embeddings.
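
    For example (illustrative values only):

    ```python
    # Clip gradients to the range [-10.0, 10.0].
    sgd = tf.tpu.experimental.embedding.SGD(0.1, clipvalue=10.0)
    # Clip gradients from below at -5.0 only; no clipping from above.
    sgd = tf.tpu.experimental.embedding.SGD(0.1, clipvalue=(-5.0, None))
    ```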
254    """
255    super().__init__(learning_rate, use_gradient_accumulation, clip_weight_min,
256                     clip_weight_max, weight_decay_factor,
257                     multiply_weight_decay_factor_by_learning_rate, clipvalue)
258
259  def _slot_names(self) -> List[Text]:
260    return []
261
262  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
263    return []
264
265  def _set_optimization_parameters(
266      self, parameters: optimization_parameters_pb2.OptimizationParameters):
267    super()._set_optimization_parameters(parameters)
268    parameters.stochastic_gradient_descent.SetInParent()
269
270  def _load(self) -> Callable[..., ops.Operation]:
271    return tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters
272
273  def _retrieve(self) -> Callable[..., core.Tensor]:
274    return tpu_ops.retrieve_tpu_embedding_stochastic_gradient_descent_parameters
275
276
277@tf_export("tpu.experimental.embedding.Adagrad")
278class Adagrad(_Optimizer):
279  """Optimization parameters for Adagrad with TPU embeddings.
280
281  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
282  argument to set the global optimizer and its parameters:
283
284  ```python
285  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
286      ...
287      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
288  ```
289
290  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
291  optimizer parameter to set a table specific optimizer. This will override the
292  optimizer and parameters for global embedding optimizer defined above:
293
294  ```python
295  table_one = tf.tpu.experimental.embedding.TableConfig(
296      vocabulary_size=...,
297      dim=...,
298      optimizer=tf.tpu.experimental.embedding.Adagrad(0.2))
299  table_two = tf.tpu.experimental.embedding.TableConfig(
300      vocabulary_size=...,
301      dim=...)
302
303  feature_config = (
304      tf.tpu.experimental.embedding.FeatureConfig(
305          table=table_one),
306      tf.tpu.experimental.embedding.FeatureConfig(
307          table=table_two))
308
309  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
310      feature_config=feature_config,
311      batch_size=...,
312      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
313  ```
314
315  In the above example, the first feature will be looked up in a table that has
316  a learning rate of 0.2 while the second feature will be looked up in a table
317  that has a learning rate of 0.1.
318
319  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
320  complete description of these parameters and their impacts on the optimizer
321  algorithm.
322  """
323
324  def __init__(
325      self,
326      learning_rate: Union[float, Callable[[], float]] = 0.001,
327      initial_accumulator_value: float = 0.1,
328      use_gradient_accumulation: bool = True,
329      clip_weight_min: Optional[float] = None,
330      clip_weight_max: Optional[float] = None,
331      weight_decay_factor: Optional[float] = None,
332      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
333      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
334      clipvalue: Optional[ClipValueType] = None):
335    """Optimization parameters for Adagrad.
336
337    Args:
338      learning_rate: The learning rate. It should be a floating point value or a
339        callable taking no arguments for a dynamic learning rate.
340      initial_accumulator_value: initial accumulator for Adagrad.
341      use_gradient_accumulation: setting this to `False` makes embedding
342        gradients calculation less accurate but faster.
343      clip_weight_min: the minimum value to clip by; None means -infinity.
344      clip_weight_max: the maximum value to clip by; None means +infinity.
345      weight_decay_factor: amount of weight decay to apply; None means that the
346        weights are not decayed.
347      multiply_weight_decay_factor_by_learning_rate: if true,
348        `weight_decay_factor` is multiplied by the current learning rate.
349      slot_variable_creation_fn: If you wish to directly control the creation
350        of the slot variables, set this to a callable taking three parameters:
351        a table variable, a list of slot names to create for it, and a list of
352        initializers. This function should return a dict with the slot names as
353        keys and the created variables as values with types matching the table
354        variable. When set to None (the default), uses the built-in variable
355        creation. See the example sketch at the end of this docstring.
356      clipvalue: Controls clipping of the gradient. Set to either a single
357        positive scalar value to get clipping or a tuple of scalar values (min,
358        max) to set a separate maximum or minimum. If one of the two entries is
359        None, then there will be no clipping in that direction.
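
    A minimal sketch of a `slot_variable_creation_fn` (illustrative only; any
    callable returning a dict that maps each slot name to a variable compatible
    with the table variable will do):

    ```python
    def create_slots(table_variable, slot_names, slot_initializers):
      slots = {}
      for name, initializer in zip(slot_names, slot_initializers):
        slots[name] = tf.Variable(
            initial_value=initializer(
                shape=table_variable.shape, dtype=table_variable.dtype),
            name=name,
            trainable=False)
      return slots

    optimizer = tf.tpu.experimental.embedding.Adagrad(
        0.1, slot_variable_creation_fn=create_slots)
    ```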
360    """
361    super().__init__(learning_rate, use_gradient_accumulation, clip_weight_min,
362                     clip_weight_max, weight_decay_factor,
363                     multiply_weight_decay_factor_by_learning_rate, clipvalue,
364                     slot_variable_creation_fn)
365    if initial_accumulator_value <= 0:
366      raise ValueError(
367          f"Argument `initial_accumulator_value` must be a positive float. "
368          f"Received: {initial_accumulator_value}")
369    self.initial_accumulator_value = initial_accumulator_value
370
371  def _slot_names(self) -> List[Text]:
372    return ["accumulators"]
373
374  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
375    return [init_ops_v2.Constant(self.initial_accumulator_value)]
376
377  def _set_optimization_parameters(
378      self, parameters: optimization_parameters_pb2.OptimizationParameters):
379    super()._set_optimization_parameters(parameters)
380    parameters.adagrad.SetInParent()
381
382  def _load(self) -> Callable[..., ops.Operation]:
383    return tpu_ops.load_tpu_embedding_adagrad_parameters
384
385  def _retrieve(self) -> Callable[..., core.Tensor]:
386    return tpu_ops.retrieve_tpu_embedding_adagrad_parameters
387
388
389@tf_export("tpu.experimental.embedding.AdagradMomentum")
390class AdagradMomentum(_Optimizer):
391  """Optimization parameters for Adagrad + Momentum with TPU embeddings.
392
393  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
394  argument to set the global optimizer and its parameters:
395
396  ```python
397  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
398      ...
399      optimizer=tf.tpu.experimental.embedding.AdagradMomentum(0.1))
400  ```
401
402  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
403  optimizer parameter to set a table specific optimizer. This will override the
404  optimizer and parameters for global embedding optimizer defined above:
405
406  ```python
407  table_one = tf.tpu.experimental.embedding.TableConfig(
408      vocabulary_size=...,
409      dim=...,
410      optimizer=tf.tpu.experimental.embedding.AdagradMomentum(0.2))
411  table_two = tf.tpu.experimental.embedding.TableConfig(
412      vocabulary_size=...,
413      dim=...)
414
415  feature_config = (
416      tf.tpu.experimental.embedding.FeatureConfig(
417          table=table_one),
418      tf.tpu.experimental.embedding.FeatureConfig(
419          table=table_two))
420
421  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
422      feature_config=feature_config,
423      batch_size=...,
424      optimizer=tf.tpu.experimental.embedding.AdagradMomentum(0.1))
425  ```
426
427  In the above example, the first feature will be looked up in a table that has
428  a learning rate of 0.2 while the second feature will be looked up in a table
429  that has a learning rate of 0.1.
430
431  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
432  complete description of these parameters and their impacts on the optimizer
433  algorithm.
434  """
435
436  def __init__(
437      self,
438      learning_rate: Union[float, Callable[[], float]] = 0.001,
439      momentum: float = 0.0,
440      use_nesterov: bool = False,
441      exponent: float = 2,
442      beta2: float = 1,
443      epsilon: float = 1e-10,
444      use_gradient_accumulation: bool = True,
445      clip_weight_min: Optional[float] = None,
446      clip_weight_max: Optional[float] = None,
447      weight_decay_factor: Optional[float] = None,
448      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
449      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
450      clipvalue: Optional[ClipValueType] = None):
451    """Optimization parameters for Adagrad + Momentum.
452
453    Args:
454      learning_rate: The learning rate. It should be a floating point value or a
455        callable taking no arguments for a dynamic learning rate.
456      momentum: Moving average parameter for the momentum accumulator.
457      use_nesterov: Whether to use the Nesterov variant of momentum. See
458        Sutskever et al., 2013.
459      exponent: Exponent for the Adagrad accumulator.
460      beta2: Moving average parameter for the Adagrad accumulator.
461      epsilon: A small constant added to the accumulator for numerical stability.
462      use_gradient_accumulation: setting this to `False` makes embedding
463        gradients calculation less accurate but faster.
464      clip_weight_min: the minimum value to clip by; None means -infinity.
465      clip_weight_max: the maximum value to clip by; None means +infinity.
466      weight_decay_factor: amount of weight decay to apply; None means that the
467        weights are not decayed.
468      multiply_weight_decay_factor_by_learning_rate: if true,
469        `weight_decay_factor` is multiplied by the current learning rate.
470      slot_variable_creation_fn: If you wish to directly control the creation of
471        the slot variables, set this to a callable taking three parameters: a
472          table variable, a list of slot names to create for it, and a list of
473          initializers. This function should return a dict with the slot names
474          as keys and the created variables as values with types matching the
475          table variable. When set to None (the default), uses the built-in
476          variable creation.
477      clipvalue: Controls clipping of the gradient. Set to either a single
478        positive scalar value to get clipping or a tuple of scalar values (min,
479        max) to set a separate maximum or minimum. If one of the two entries is
480        None, then there will be no clipping in that direction.
481    """
482    super().__init__(learning_rate, use_gradient_accumulation, clip_weight_min,
483                     clip_weight_max, weight_decay_factor,
484                     multiply_weight_decay_factor_by_learning_rate, clipvalue,
485                     slot_variable_creation_fn)
486    if epsilon <= 0:
487      raise ValueError("Adagrad momentum: epsilon must be positive")
488    if exponent <= 0:
489      raise ValueError("Adagrad momentum: Precondition exponent must >0")
490    self.momentum = momentum
491    self.use_nesterov = use_nesterov
492    self.exponent = exponent
493    self.beta2 = beta2
494    self.epsilon = epsilon
495
496  def _slot_names(self) -> List[Text]:
497    return ["accumulators", "momenta"]
498
499  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
500    return [init_ops_v2.Constant(), init_ops_v2.Constant()]
501
502  def _set_optimization_parameters(
503      self, parameters: optimization_parameters_pb2.OptimizationParameters):
504    super()._set_optimization_parameters(parameters)
505    parameters.adagrad_momentum.SetInParent()
506    parameters.adagrad_momentum.momentum = self.momentum
507    parameters.adagrad_momentum.use_nesterov = self.use_nesterov
508    parameters.adagrad_momentum.exponent = self.exponent
509    parameters.adagrad_momentum.beta2 = self.beta2
510    parameters.adagrad_momentum.epsilon = self.epsilon
511
512  def _load(self) -> Callable[..., ops.Operation]:
513    return tpu_ops.load_tpu_embedding_adagrad_momentum_parameters
514
515  def _retrieve(self) -> Callable[..., core.Tensor]:
516    return tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters
517
518
519@tf_export("tpu.experimental.embedding.FTRL")
520class FTRL(_Optimizer):
521  """Optimization parameters for FTRL with TPU embeddings.
522
523  See Algorithm 1 of this
524  [paper](https://research.google.com/pubs/archive/41159.pdf).
525
526  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
527  argument to set the global optimizer and its parameters:
528
529  ```python
530  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
531      ...
532      optimizer=tf.tpu.experimental.embedding.FTRL(0.1))
533  ```
534
535  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
536  optimizer parameter to set a table specific optimizer. This will override the
537  optimizer and parameters for global embedding optimizer defined above:
538
539  ```python
540  table_one = tf.tpu.experimental.embedding.TableConfig(
541      vocabulary_size=...,
542      dim=...,
543      optimizer=tf.tpu.experimental.embedding.FTRL(0.2))
544  table_two = tf.tpu.experimental.embedding.TableConfig(
545      vocabulary_size=...,
546      dim=...)
547
548  feature_config = (
549      tf.tpu.experimental.embedding.FeatureConfig(
550          table=table_one),
551      tf.tpu.experimental.embedding.FeatureConfig(
552          table=table_two))
553
554  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
555      feature_config=feature_config,
556      batch_size=...,
557      optimizer=tf.tpu.experimental.embedding.FTRL(0.1))
558  ```
559
560  In the above example, the first feature will be looked up in a table that has
561  a learning rate of 0.2 while the second feature will be looked up in a table
562  that has a learning rate of 0.1.
563
564  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
565  complete description of these parameters and their impacts on the optimizer
566  algorithm.
567  """
568
569  def __init__(
570      self,
571      learning_rate: Union[float, Callable[[], float]] = 0.001,
572      learning_rate_power: float = -0.5,
573      l1_regularization_strength: float = 0.0,
574      l2_regularization_strength: float = 0.0,
575      beta: float = 0.0,
576      initial_accumulator_value: float = 0.1,
577      use_gradient_accumulation: bool = True,
578      clip_weight_min: Optional[float] = None,
579      clip_weight_max: Optional[float] = None,
580      weight_decay_factor: Optional[float] = None,
581      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
582      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
583      clipvalue: Optional[ClipValueType] = None,
584      multiply_linear_by_learning_rate: bool = False,
585      allow_zero_accumulator: bool = False):
586    """Optimization parameters for Adagrad.
587
588    Args:
589      learning_rate: The learning rate. It should be a floating point value or a
590        callable taking no arguments for a dynamic learning rate.
591      learning_rate_power: A float value, must be less than or equal to zero.
592        Controls how the learning rate decreases during training. Use zero for a
593        fixed learning rate.
594      l1_regularization_strength: A float value, must be greater than or equal
595        to zero.
596      l2_regularization_strength: A float value, must be greater than or equal
597        to zero.
598      beta: A float value, representing the beta value from the paper.
599      initial_accumulator_value: The starting value for accumulators. Only zero
600        or positive values are allowed.
601      use_gradient_accumulation: setting this to `False` makes embedding
602        gradients calculation less accurate but faster.
603      clip_weight_min: the minimum value to clip by; None means -infinity.
604      clip_weight_max: the maximum value to clip by; None means +infinity.
605      weight_decay_factor: amount of weight decay to apply; None means that the
606        weights are not decayed.
607      multiply_weight_decay_factor_by_learning_rate: if true,
608        `weight_decay_factor` is multiplied by the current learning rate.
609      slot_variable_creation_fn: If you wish to directly control the creation of
610        the slot variables, set this to a callable taking three parameters: a
611          table variable, a list of slot names to create for it, and a list of
612          initializers. This function should return a dict with the slot names
613          as keys and the created variables as values with types matching the
614          table variable. When set to None (the default), uses the built-in
615          variable creation.
616      clipvalue: Controls clipping of the gradient. Set to either a single
617        positive scalar value to get clipping or a tuple of scalar values (min,
618        max) to set a separate maximum or minimum. If one of the two entries is
619        None, then there will be no clipping in that direction.
620      multiply_linear_by_learning_rate: If set to True, a modified formula is
621        used for FTRL that treats the "linear" accumulator as being
622        pre-multiplied by the learning rate (i.e., the accumulator named
623        "linear" actually stores "linear * learning_rate"). Other than
624        checkpoint compatibility, this is mathematically equivalent for a static
625        learning rate; for a dynamic learning rate, it is nearly the same as
626        long as the learning rate does not change quickly. The benefit of this
627        is that the modified formula handles zero and near-zero learning rates
628        without producing NaNs, improving flexibility for learning rate ramp-up.
629      allow_zero_accumulator: If set to True, changes some internal formulas to
630        allow zero and near-zero accumulator values at the cost of some
631        performance; this only needs to be set if you are using an initial
632        accumulator value of zero, which is uncommon.
633    """
634    super().__init__(learning_rate, use_gradient_accumulation, clip_weight_min,
635                     clip_weight_max, weight_decay_factor,
636                     multiply_weight_decay_factor_by_learning_rate, clipvalue,
637                     slot_variable_creation_fn)
638    if initial_accumulator_value < 0:
639      raise ValueError(
640          f"Argument `initial_accumulator_value` must be a non-negative float. "
641          f"Received: {initial_accumulator_value}")
642    self.initial_accumulator_value = initial_accumulator_value
643    self.learning_rate_power = learning_rate_power
644    self.l1_regularization_strength = l1_regularization_strength
645    self.l2_regularization_strength = l2_regularization_strength
646    self.beta = beta
647    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
648    self.allow_zero_accumulator = allow_zero_accumulator
649
650  def _slot_names(self) -> List[Text]:
651    return ["accumulators", "linears"]
652
653  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
654    return [
655        init_ops_v2.Constant(self.initial_accumulator_value),
656        init_ops_v2.Constant()
657    ]
658
659  def _set_optimization_parameters(
660      self, parameters: optimization_parameters_pb2.OptimizationParameters):
661    super()._set_optimization_parameters(parameters)
662    ftrl = parameters.ftrl
663    ftrl.l1 = self.l1_regularization_strength
664    ftrl.l2 = self.l2_regularization_strength
665    ftrl.lr_power = self.learning_rate_power
666    ftrl.beta = self.beta
667    ftrl.multiply_linear_by_lr = self.multiply_linear_by_learning_rate
668    ftrl.allow_zero_accumulator = self.allow_zero_accumulator
669
670  def _load(self) -> Callable[..., ops.Operation]:
671    return tpu_ops.load_tpu_embedding_ftrl_parameters
672
673  def _retrieve(self) -> Callable[..., core.Tensor]:
674    return tpu_ops.retrieve_tpu_embedding_ftrl_parameters
675
676
677@tf_export("tpu.experimental.embedding.Adam")
678class Adam(_Optimizer):
679  """Optimization parameters for Adam with TPU embeddings.
680
681  NOTE: By default this optimizer is lazy, i.e. it will not apply the gradient
682  update of zero to rows that were not looked up. You can change this behavior
683  by setting `lazy_adam` to `False`.
684
685  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
686  argument to set the global optimizer and its parameters:
687
688  ```python
689  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
690      ...
691      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
692  ```
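
  For example, to apply non-lazy (dense) Adam updates instead:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.Adam(0.1, lazy_adam=False))
  ```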
693
694  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
695  optimizer parameter to set a table specific optimizer. This will override the
696  optimizer and parameters for global embedding optimizer defined above:
697
698  ```python
699  table_one = tf.tpu.experimental.embedding.TableConfig(
700      vocabulary_size=...,
701      dim=...,
702      optimizer=tf.tpu.experimental.embedding.Adam(0.2))
703  table_two = tf.tpu.experimental.embedding.TableConfig(
704      vocabulary_size=...,
705      dim=...)
706
707  feature_config = (
708      tf.tpu.experimental.embedding.FeatureConfig(
709          table=table_one),
710      tf.tpu.experimental.embedding.FeatureConfig(
711          table=table_two))
712
713  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
714      feature_config=feature_config,
715      batch_size=...,
716      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
717  ```
718
719  In the above example, the first feature will be looked up in a table that has
720  a learning rate of 0.2 while the second feature will be looked up in a table
721  that has a learning rate of 0.1.
722
723  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
724  complete description of these parameters and their impacts on the optimizer
725  algorithm.
726  """
727
728  def __init__(
729      self,
730      learning_rate: Union[float, Callable[[], float]] = 0.001,
731      beta_1: float = 0.9,
732      beta_2: float = 0.999,
733      epsilon: float = 1e-07,
734      lazy_adam: bool = True,
735      sum_inside_sqrt: bool = True,
736      use_gradient_accumulation: bool = True,
737      clip_weight_min: Optional[float] = None,
738      clip_weight_max: Optional[float] = None,
739      weight_decay_factor: Optional[float] = None,
740      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
741      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
742      clipvalue: Optional[ClipValueType] = None):
743    """Optimization parameters for Adam.
744
745    See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
746    complete description of these parameters and their impacts on the optimizer
747    algorithm.
748
749    Args:
750      learning_rate: The learning rate. It should be a floating point value or a
751        callable taking no arguments for a dynamic learning rate.
752      beta_1: A float value. The exponential decay rate for the 1st moment
753        estimates.
754      beta_2: A float value. The exponential decay rate for the 2nd moment
755        estimates.
756      epsilon: A small constant for numerical stability.
757      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
758      sum_inside_sqrt: When this is true, the Adam update formula is changed
759        from `m / (sqrt(v) + epsilon)` to `m / sqrt(v + epsilon**2)`. This
760        option improves the performance of TPU training and is not expected to
761        harm model quality.
762      use_gradient_accumulation: Setting this to `False` makes embedding
763        gradients calculation less accurate but faster.
764      clip_weight_min: the minimum value to clip by; None means -infinity.
765      clip_weight_max: the maximum value to clip by; None means +infinity.
766      weight_decay_factor: amount of weight decay to apply; None means that the
767        weights are not decayed.
768      multiply_weight_decay_factor_by_learning_rate: if true,
769        `weight_decay_factor` is multiplied by the current learning rate.
770      slot_variable_creation_fn: If you wish to directly control the creation of
771        the slot variables, set this to a callable taking three parameters: a
772          table variable, a list of slot names to create for it, and a list of
773          initializers. This function should return a dict with the slot names
774          as keys and the created variables as values with types matching the
775          table variable. When set to None (the default), uses the built-in
776          variable creation.
777      clipvalue: Controls clipping of the gradient. Set to either a single
778        positive scalar value to get clipping or a tuple of scalar values (min,
779        max) to set a separate maximum or minimum. If one of the two entries is
780        None, then there will be no clipping in that direction.
781    """
782    super().__init__(
783        learning_rate, use_gradient_accumulation, clip_weight_min,
784        clip_weight_max, weight_decay_factor,
785        multiply_weight_decay_factor_by_learning_rate, clipvalue,
786        slot_variable_creation_fn)
787    if beta_1 < 0. or beta_1 >= 1.:
788      raise ValueError(
789          f"Argument `beta_1` must be >= 0 and < 1. Received: {beta_1}.")
790    if beta_2 < 0. or beta_2 >= 1.:
791      raise ValueError(
792          f"Argument `beta_2` must be >= 0 and < 1. Received: {beta_1}.")
793    if epsilon <= 0.:
794      raise ValueError("epsilon must be positive; got {}.".format(epsilon))
795    if not use_gradient_accumulation and not lazy_adam:
796      raise ValueError(
797          "When disabling lazy Adam (`lazy_adam=False`), "
798          "gradient accumulation must be used. "
799          "Set `use_gradient_accumulation` to False.")
800
801    self.beta_1 = beta_1
802    self.beta_2 = beta_2
803    self.epsilon = epsilon
804    self.lazy_adam = lazy_adam
805    self.sum_inside_sqrt = sum_inside_sqrt
806
807  def _slot_names(self) -> List[Text]:
808    return ["momenta", "velocities"]
809
810  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
811    return [init_ops_v2.Constant(), init_ops_v2.Constant()]
812
813  def _set_optimization_parameters(
814      self, parameters: optimization_parameters_pb2.OptimizationParameters):
815    super()._set_optimization_parameters(parameters)
816    parameters.adam.beta1 = self.beta_1
817    parameters.adam.beta2 = self.beta_2
818    parameters.adam.epsilon = self.epsilon
819    parameters.adam.use_non_lazy_adam = not self.lazy_adam
820    parameters.adam.use_sum_inside_sqrt = self.sum_inside_sqrt
821
822  def _load(self) -> Callable[..., ops.Operation]:
823    return tpu_ops.load_tpu_embedding_adam_parameters
824
825  def _retrieve(self) -> Callable[..., core.Tensor]:
826    return tpu_ops.retrieve_tpu_embedding_adam_parameters
827
828
829@tf_export("tpu.experimental.embedding.TableConfig")
830class TableConfig:
831  """Configuration data for one embedding table.
832
833  This class holds the configuration data for a single embedding table. It is
834  used as the `table` parameter of a
835  `tf.tpu.experimental.embedding.FeatureConfig`. Multiple
836  `tf.tpu.experimental.embedding.FeatureConfig` objects can use the same
837  `tf.tpu.experimental.embedding.TableConfig` object. In this case a shared
838  table will be created for those feature lookups.
839
840  ```python
841  table_config_one = tf.tpu.experimental.embedding.TableConfig(
842      vocabulary_size=...,
843      dim=...)
844  table_config_two = tf.tpu.experimental.embedding.TableConfig(
845      vocabulary_size=...,
846      dim=...)
847  feature_config = {
848      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
849          table=table_config_one),
850      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
851          table=table_config_one),
852      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
853          table=table_config_two)}
854  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
855      feature_config=feature_config,
856      batch_size=...,
857      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
858  ```
859
860  The above configuration has two tables and three features. The first two
861  features will be looked up in the first table and the third feature will be
862  looked up in the second table.
863
864  """
865
866  def __init__(self,
867               vocabulary_size: int,
868               dim: int,
869               initializer: Optional[Callable[[Any], None]] = None,
870               optimizer: Optional[_Optimizer] = None,
871               combiner: Text = "mean",
872               name: Optional[Text] = None):
873    """Embedding table configuration.
874
875    Args:
876      vocabulary_size: Size of the table's vocabulary (number of rows).
877      dim: The embedding dimension (width) of the table.
878      initializer: A callable initializer taking one parameter, the shape of the
879        variable that will be initialized. Will be called once per task, to
880        initialize that task's shard of the embedding table. If not specified,
881        defaults to `truncated_normal_initializer` with mean `0.0` and standard
882        deviation `1/sqrt(dim)`. See the example at the end of this docstring.
883      optimizer: An optional instance of an optimizer parameters class, instance
884        of one of `tf.tpu.experimental.embedding.SGD`,
885        `tf.tpu.experimental.embedding.Adagrad` or
886        `tf.tpu.experimental.embedding.Adam`. If set, will override the global
887        optimizer passed to `tf.tpu.experimental.embedding.TPUEmbedding`.
888      combiner: A string specifying how to reduce if there are multiple entries
889        in a single row. Currently 'mean', 'sqrtn', 'sum' are supported, with
890        'mean' the default. 'sqrtn' often achieves good accuracy, in particular
891        with bag-of-words columns. For more information, see
892        `tf.nn.embedding_lookup_sparse`.
893      name: An optional string used to name the table. Useful for debugging.
894
895    Returns:
896      `TableConfig`.
897
898    Raises:
899      ValueError: if `vocabulary_size` is not a positive integer.
900      ValueError: if `dim` is not a positive integer.
901      ValueError: if `initializer` is specified and is not callable.
902      ValueError: if `combiner` is not supported.
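
    For example, a table configured with a custom initializer (illustrative
    values only):

    ```python
    table = tf.tpu.experimental.embedding.TableConfig(
        vocabulary_size=1024,
        dim=64,
        initializer=tf.random_normal_initializer(stddev=0.05),
        optimizer=tf.tpu.experimental.embedding.Adagrad(0.1),
        combiner='sum',
        name='my_table')
    ```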
903    """
904    if not isinstance(vocabulary_size, int) or vocabulary_size < 1:
905      raise ValueError(
906          f"Argument `vocabulary_size` must be an int and must be >= 1. "
907          f"Received: {vocabulary_size}")
908
909    if not isinstance(dim, int) or dim < 1:
910      raise ValueError(
911          f"Argument `dim` (embedding dimension) "
912          f"must be an int and must be >= 1. Received: {dim}")
913
914    if (initializer is not None) and (not callable(initializer)):
915      raise ValueError(
916          f"Argument `initializer` must be a callable (or None). "
917          f"Received: {initializer}")
918    if initializer is None:
919      initializer = init_ops_v2.TruncatedNormal(mean=0.0,
920                                                stddev=1/math.sqrt(dim))
921    accepted_combiners = ("mean", "sum", "sqrtn")
922    if combiner not in accepted_combiners:
923      raise ValueError(
924          f"Argument `combiner` must be one of {accepted_combiners}. "
925          f"Received: {combiner}")
926
927    self.vocabulary_size = vocabulary_size
928    self.dim = dim
929    self.initializer = initializer
930    self.optimizer = optimizer
931    self.combiner = combiner
932    self.name = name
933
934  def __repr__(self):
935    # If using the default initializer, just print "None" for clarity.
936    initializer = self.initializer
937
938    if isinstance(initializer, init_ops_v2.TruncatedNormal):
939      # PY2 type checking can't infer type of initializer even after if.
940      initializer = typing.cast(init_ops_v2.TruncatedNormal, initializer)
941      if (initializer.mean == 0.0
942          and math.isclose(initializer.stddev, 1/math.sqrt(self.dim))):  # pytype: disable=module-attr (math.isclose not in PY2)
943        initializer = None
944
945    return (
946        "TableConfig(vocabulary_size={vocabulary_size!r}, dim={dim!r}, "
947        "initializer={initializer!r}, optimizer={optimizer!r}, "
948        "combiner={combiner!r}, name={name!r})".format(
949            vocabulary_size=self.vocabulary_size,
950            dim=self.dim,
951            initializer=initializer,
952            optimizer=self.optimizer,
953            combiner=self.combiner,
954            name=self.name,)
955    )
956
957
958@tf_export("tpu.experimental.embedding.FeatureConfig")
959class FeatureConfig:
960  """Configuration data for one embedding feature.
961
962  This class holds the configuration data for a single embedding feature. The
963  main use is to assign features to `tf.tpu.experimental.embedding.TableConfig`s
964  via the table parameter:
965
966  ```python
967  table_config_one = tf.tpu.experimental.embedding.TableConfig(
968      vocabulary_size=...,
969      dim=...)
970  table_config_two = tf.tpu.experimental.embedding.TableConfig(
971      vocabulary_size=...,
972      dim=...)
973  feature_config = {
974      'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
975          table=table_config_one),
976      'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
977          table=table_config_one),
978      'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
979          table=table_config_two)}
980  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
981      feature_config=feature_config,
982      batch_size=...,
983      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
984  ```
985
986  The above configuration has two tables and three features. The first two
987  features will be looked up in the first table and the third feature will be
988  looked up in the second table.
989
990  You can also specify the output shape for each feature. The output shape
991  should be the expected activation shape excluding the table dimension. For
992  dense and sparse tensors, the output shape should be the same as the input
993  shape excluding the last dimension. For ragged tensors, the output shape can
994  differ from the input shape.
995
996  NOTE: `max_sequence_length` will only be used when the input tensor has
997  rank 2 and `output_shape` is not set in the feature config.
998
999  When feeding features into `embedding.enqueue` they can be `tf.Tensor`s,
1000  `tf.SparseTensor`s or `tf.RaggedTensor`s. When the argument
1001  `max_sequence_length` is 0, the default, you should expect an output of
1002  `embedding.dequeue` for this feature of shape `(batch_size, dim)`. If
1003  `max_sequence_length` is greater than 0, the feature is embedded as a sequence
1004  and padded up to the given length. The shape of the output for this feature
1005  will be `(batch_size, max_sequence_length, dim)`.
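
  For example, a sequence feature padded or truncated to 10 steps (illustrative
  values only):

  ```python
  sequence_feature = tf.tpu.experimental.embedding.FeatureConfig(
      table=table_config_one,
      max_sequence_length=10)
  ```

  The dequeued activation for this feature will have shape
  `(batch_size, 10, dim)`.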
1006  """
1007
1008  def __init__(self,
1009               table: TableConfig,
1010               max_sequence_length: int = 0,
1011               validate_weights_and_indices: bool = True,
1012               output_shape: Optional[Union[List[int], TensorShape]] = None,
1013               name: Optional[Text] = None):
1014    """Feature configuration.
1015
1016    Args:
1017      table: An instance of `tf.tpu.experimental.embedding.TableConfig`,
1018        describing the table in which this feature should be looked up.
1019      max_sequence_length: If positive, the feature is a sequence feature with
1020        the corresponding maximum sequence length. If the sequence is longer
1021        than this, it will be truncated. If 0, the feature is not a sequence
1022        feature.
1023      validate_weights_and_indices: If true, uses safe_embedding_lookup during
1024        serving which ensures there are no empty rows and all weights and ids
1025        are positive at the expense of extra compute cost.
1026      output_shape: Optional argument to configure the output shape of the
1027        feature activation. If provided, the feature fed to `embedding.enqueue`
1028        has to match this shape (for ragged tensors, the input and output shapes
1029        can differ). If not provided, the shape can either be provided to
1030        `embedding.build` or auto-detected at runtime.
1031      name: An optional name for the feature, useful for debugging.
1032
1033    Returns:
1034      `FeatureConfig`.
1035
1036    Raises:
1037      ValueError: if `table` is not an instance of
1038        `tf.tpu.experimental.embedding.TableConfig`.
1039      ValueError: if `max_sequence_length` is not an integer or is negative.
1040    """
1041    if not isinstance(table, TableConfig):
1042      raise ValueError(f"Argument `table` has invalid type {type(table)}. "
1043                       "Expected `tf.tpu.experimental.embedding.TableConfig`.")
1044
1045    if not isinstance(max_sequence_length, int) or max_sequence_length < 0:
1046      raise ValueError(
1047          f"Argument `max_sequence_length` must be an int and must be >= 0. "
1048          f"Received: {max_sequence_length}")
1049
1050    self.table = table
1051    self.max_sequence_length = max_sequence_length
1052    self.name = name
1053    self.output_shape = TensorShape(output_shape)
1054
1055    if not isinstance(validate_weights_and_indices, bool):
1057      raise ValueError(
1058          f"Argument `validate_weights_and_indices` must be a boolean. "
1059          f"Received: {validate_weights_and_indices}")
1060
1061    self.validate_weights_and_indices = validate_weights_and_indices
1062
1063  def __repr__(self):
1064    return ("FeatureConfig(table={table!r}, "
1065            "max_sequence_length={max_sequence_length!r}, "
1066            "validate_weights_and_indices={"
1067            "validate_weights_and_indices!r}, name={name!r})".format(
1068                table=self.table,
1069                max_sequence_length=self.max_sequence_length,
1070                validate_weights_and_indices=self.validate_weights_and_indices,
1071                name=self.name))
1072
1073
1074def log_tpu_embedding_configuration(
1075    config: tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration) -> None:
1076  """Logs a TPUEmbeddingConfiguration proto across multiple statements.
1077
1078  Args:
1079    config: TPUEmbeddingConfiguration proto to log.  Necessary because
1080      logging.info has a maximum length for each log statement, which
1081      particularly large configs can exceed.
1082  """
1083  logging.info("Beginning log of TPUEmbeddingConfiguration.")
1084  for line in str(config).splitlines():
1085    logging.info(line)
1086  logging.info("Done with log of TPUEmbeddingConfiguration.")
1087