xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/layers/embeddings.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Embedding layer."""
16# pylint: disable=g-classes-have-attributes
17
18from tensorflow.python.keras import backend
19from tensorflow.python.keras import constraints
20from tensorflow.python.keras import initializers
21from tensorflow.python.keras import regularizers
22from tensorflow.python.keras.engine import base_layer_utils
23from tensorflow.python.keras.engine.base_layer import Layer
24from tensorflow.python.keras.utils import tf_utils
25from tensorflow.python.ops import embedding_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.util.tf_export import keras_export
28
29
@keras_export('keras.layers.Embedding')
class Embedding(Layer):
  """Turns positive integers (indexes) into dense vectors of fixed size.

  e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`

  This layer can only be used as the first layer in a model.

  Example:

  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
  >>> # The model will take as input an integer matrix of size (batch,
  >>> # input_length), and the largest integer (i.e. word index) in the input
  >>> # should be no larger than 999 (vocabulary size).
  >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch
  >>> # dimension.
  >>> input_array = np.random.randint(1000, size=(32, 10))
  >>> model.compile('rmsprop', 'mse')
  >>> output_array = model.predict(input_array)
  >>> print(output_array.shape)
  (32, 10, 64)

  Args:
    input_dim: Integer. Size of the vocabulary,
      i.e. maximum integer index + 1.
    output_dim: Integer. Dimension of the dense embedding.
    embeddings_initializer: Initializer for the `embeddings`
      matrix (see `keras.initializers`).
    embeddings_regularizer: Regularizer function applied to
      the `embeddings` matrix (see `keras.regularizers`).
    activity_regularizer: Regularizer function for the layer's activity
      (its output); see `keras.regularizers`.
    embeddings_constraint: Constraint function applied to
      the `embeddings` matrix (see `keras.constraints`).
    mask_zero: Boolean, whether or not the input value 0 is a special "padding"
      value that should be masked out.
      This is useful when using recurrent layers
      which may take variable length input.
      If this is `True`, then all subsequent layers
      in the model need to support masking or an exception will be raised.
      If mask_zero is set to True, as a consequence, index 0 cannot be
      used in the vocabulary (input_dim should equal size of
      vocabulary + 1).
    input_length: Length of input sequences, when it is constant.
      For inputs of rank 3 or higher, this may be a list/tuple giving every
      non-batch dimension (entries may be `None`).
      This argument is required if you are going to connect
      `Flatten` then `Dense` layers downstream
      (without it, the shape of the dense outputs cannot be computed).

  Input shape:
    2D tensor with shape: `(batch_size, input_length)`.
    Higher-rank integer inputs are also accepted.

  Output shape:
    3D tensor with shape: `(batch_size, input_length, output_dim)`.
    In general, the input shape with `output_dim` appended as a new last axis.

  **Note on variable placement:**
  By default, if a GPU is available, the embedding matrix will be placed on
  the GPU. This achieves the best performance, but it might cause issues:

  - You may be using an optimizer that does not support sparse GPU kernels.
  In this case you will see an error upon training your model.
  - Your embedding matrix may be too large to fit on your GPU. In this case
  you will see an Out Of Memory (OOM) error.

  In such cases, you should place the embedding matrix on the CPU memory.
  You can do so with a device scope, as such:

  ```python
  with tf.device('cpu:0'):
    embedding_layer = Embedding(...)
    embedding_layer.build()
  ```

  The pre-built `embedding_layer` instance can then be added to a `Sequential`
  model (e.g. `model.add(embedding_layer)`), called in a Functional model
  (e.g. `x = embedding_layer(x)`), or used in a subclassed model.
  """

  def __init__(self,
               input_dim,
               output_dim,
               embeddings_initializer='uniform',
               embeddings_regularizer=None,
               activity_regularizer=None,
               embeddings_constraint=None,
               mask_zero=False,
               input_length=None,
               **kwargs):
    """Creates the layer; see the class docstring for argument semantics.

    Raises:
      ValueError: If `input_dim` or `output_dim` is not positive.
    """
    if 'input_shape' not in kwargs:
      if isinstance(input_length, (list, tuple)):
        # For 3D+ inputs `input_length` already lists every non-batch
        # dimension (this is how `compute_output_shape` interprets it), so it
        # must not be wrapped into a nested one-element tuple.
        kwargs['input_shape'] = tuple(input_length)
      elif input_length:
        kwargs['input_shape'] = (input_length,)
      else:
        # Unknown / variable sequence length.
        kwargs['input_shape'] = (None,)
    if input_dim <= 0 or output_dim <= 0:
      raise ValueError('Both `input_dim` and `output_dim` should be positive, '
                       'found input_dim {} and output_dim {}'.format(
                           input_dim, output_dim))
    if (not base_layer_utils.v2_dtype_behavior_enabled() and
        'dtype' not in kwargs):
      # In TF1, the dtype defaults to the input dtype which is typically int32,
      # so explicitly set it to floatx
      kwargs['dtype'] = backend.floatx()
    # We set autocast to False, as we do not want to cast floating- point inputs
    # to self.dtype. In call(), we cast to int32, and casting to self.dtype
    # before casting to int32 might cause the int32 values to be different due
    # to a loss of precision.
    kwargs['autocast'] = False
    super(Embedding, self).__init__(**kwargs)

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embeddings_initializer = initializers.get(embeddings_initializer)
    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.embeddings_constraint = constraints.get(embeddings_constraint)
    self.mask_zero = mask_zero
    # Masking is only advertised when zero-padding masking is requested.
    self.supports_masking = mask_zero
    self.input_length = input_length

  @tf_utils.shape_type_conversion
  def build(self, input_shape=None):
    """Creates the `(input_dim, output_dim)` embeddings weight.

    `input_shape` is unused: the weight shape depends only on the
    constructor arguments, which is why the layer can be built without it.
    """
    self.embeddings = self.add_weight(
        shape=(self.input_dim, self.output_dim),
        initializer=self.embeddings_initializer,
        name='embeddings',
        regularizer=self.embeddings_regularizer,
        constraint=self.embeddings_constraint,
        # The weight keeps its variable dtype; call() casts the looked-up
        # output instead (see the comment there).
        experimental_autocast=False)
    self.built = True

  def compute_mask(self, inputs, mask=None):
    """Returns a boolean mask that is False where `inputs == 0`.

    Returns None when `mask_zero` is False. Any incoming `mask` is ignored.
    """
    if not self.mask_zero:
      return None
    return math_ops.not_equal(inputs, 0)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` with `output_dim` appended.

    When `input_length` is set, it is checked against the non-batch
    dimensions of `input_shape`. `None` entries in `input_length` are filled
    in from `input_shape`.

    Raises:
      ValueError: If `input_length` disagrees with `input_shape` in rank or
        in any known dimension.
    """
    if self.input_length is None:
      return input_shape + (self.output_dim,)
    # input_length can be tuple if input is 3D or higher
    if isinstance(self.input_length, (list, tuple)):
      in_lens = list(self.input_length)
    else:
      in_lens = [self.input_length]
    mismatch_msg = ('"input_length" is %s, '
                    'but received input has shape %s' % (str(
                        self.input_length), str(input_shape)))
    if len(in_lens) != len(input_shape) - 1:
      raise ValueError(mismatch_msg)
    for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
      if s1 is not None and s2 is not None and s1 != s2:
        raise ValueError(mismatch_msg)
      if s1 is None:
        in_lens[i] = s2
    return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)

  def call(self, inputs):
    """Looks up the embedding vector for every index in `inputs`.

    Non-int32/int64 inputs (e.g. floats) are cast to int32 first. The result
    is cast to the policy's compute dtype when that differs from the variable
    dtype.
    """
    dtype = backend.dtype(inputs)
    if dtype != 'int32' and dtype != 'int64':
      inputs = math_ops.cast(inputs, 'int32')
    out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
    if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
      # Instead of casting the variable as in most layers, cast the output, as
      # this is mathematically equivalent but is faster.
      out = math_ops.cast(out, self._dtype_policy.compute_dtype)
    return out

  def get_config(self):
    """Returns the serializable layer config.

    Keys defined here override same-named keys from the base `Layer` config.
    """
    config = {
        'input_dim': self.input_dim,
        'output_dim': self.output_dim,
        'embeddings_initializer':
            initializers.serialize(self.embeddings_initializer),
        'embeddings_regularizer':
            regularizers.serialize(self.embeddings_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'embeddings_constraint':
            constraints.serialize(self.embeddings_constraint),
        'mask_zero': self.mask_zero,
        'input_length': self.input_length
    }
    base_config = super(Embedding, self).get_config()
    return {**base_config, **config}
215