# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Embedding layer."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.Embedding')
class Embedding(Layer):
  """Turns positive integers (indexes) into dense vectors of fixed size.

  e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`

  This layer can only be used as the first layer in a model.

  Example:

  >>> model = tf.keras.Sequential()
  >>> model.add(tf.keras.layers.Embedding(1000, 64, input_length=10))
  >>> # The model will take as input an integer matrix of size (batch,
  >>> # input_length), and the largest integer (i.e. word index) in the input
  >>> # should be no larger than 999 (vocabulary size).
  >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch
  >>> # dimension.
  >>> input_array = np.random.randint(1000, size=(32, 10))
  >>> model.compile('rmsprop', 'mse')
  >>> output_array = model.predict(input_array)
  >>> print(output_array.shape)
  (32, 10, 64)

  Args:
    input_dim: Integer. Size of the vocabulary,
      i.e. maximum integer index + 1.
    output_dim: Integer. Dimension of the dense embedding.
    embeddings_initializer: Initializer for the `embeddings`
      matrix (see `keras.initializers`).
    embeddings_regularizer: Regularizer function applied to
      the `embeddings` matrix (see `keras.regularizers`).
    activity_regularizer: Regularizer function applied to
      the output of the layer (see `keras.regularizers`).
    embeddings_constraint: Constraint function applied to
      the `embeddings` matrix (see `keras.constraints`).
    mask_zero: Boolean, whether or not the input value 0 is a special "padding"
      value that should be masked out.
      This is useful when using recurrent layers
      which may take variable length input.
      If this is `True`, then all subsequent layers
      in the model need to support masking or an exception will be raised.
      If `mask_zero` is set to `True`, as a consequence, index 0 cannot be
      used in the vocabulary (`input_dim` should equal size of
      vocabulary + 1).
    input_length: Length of input sequences, when it is constant.
      This argument is required if you are going to connect
      `Flatten` then `Dense` layers downstream
      (without it, the shape of the dense outputs cannot be computed).

  Input shape:
    2D tensor with shape: `(batch_size, input_length)`.

  Output shape:
    3D tensor with shape: `(batch_size, input_length, output_dim)`.
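
  When `mask_zero=True`, the layer also computes an output mask that
  mask-consuming layers (e.g. RNN layers) can use. As implemented in
  `compute_mask` below, every non-zero index is marked as valid:

  >>> layer = tf.keras.layers.Embedding(1000, 64, mask_zero=True)
  >>> layer.compute_mask(np.array([[1, 2, 0]])).numpy()
  array([[ True,  True, False]])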

  **Note on variable placement:**
  By default, if a GPU is available, the embedding matrix will be placed on
  the GPU. This achieves the best performance, but it might cause issues:

  - You may be using an optimizer that does not support sparse GPU kernels.
    In this case you will see an error upon training your model.
  - Your embedding matrix may be too large to fit on your GPU. In this case
    you will see an Out Of Memory (OOM) error.

  In such cases, you should place the embedding matrix in CPU memory.
  You can do so with a device scope, as such:

  ```python
  with tf.device('cpu:0'):
    embedding_layer = Embedding(...)
    embedding_layer.build(None)
  ```

  The pre-built `embedding_layer` instance can then be added to a `Sequential`
  model (e.g. `model.add(embedding_layer)`), called in a Functional model
  (e.g. `x = embedding_layer(x)`), or used in a subclassed model.
  """

  def __init__(self,
               input_dim,
               output_dim,
               embeddings_initializer='uniform',
               embeddings_regularizer=None,
               activity_regularizer=None,
               embeddings_constraint=None,
               mask_zero=False,
               input_length=None,
               **kwargs):
    if 'input_shape' not in kwargs:
      if input_length:
        kwargs['input_shape'] = (input_length,)
      else:
        kwargs['input_shape'] = (None,)
    if input_dim <= 0 or output_dim <= 0:
      raise ValueError('Both `input_dim` and `output_dim` should be positive, '
                       'found input_dim {} and output_dim {}'.format(
                           input_dim, output_dim))
    if (not base_layer_utils.v2_dtype_behavior_enabled() and
        'dtype' not in kwargs):
      # In TF1, the dtype defaults to the input dtype, which is typically
      # int32, so explicitly set it to floatx.
      kwargs['dtype'] = backend.floatx()
    # We set autocast to False, as we do not want to cast floating-point inputs
    # to self.dtype. In call(), we cast to int32, and casting to self.dtype
    # before casting to int32 might cause the int32 values to be different due
    # to a loss of precision.
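    # (For example, float16 can represent integers exactly only up to 2048, so
    # an input index of 2049 cast to float16 would round to 2048, and the
    # subsequent int32 cast would then recover the wrong index.)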
    kwargs['autocast'] = False
    super(Embedding, self).__init__(**kwargs)

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.embeddings_initializer = initializers.get(embeddings_initializer)
    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
    self.activity_regularizer = regularizers.get(activity_regularizer)
    self.embeddings_constraint = constraints.get(embeddings_constraint)
    self.mask_zero = mask_zero
    self.supports_masking = mask_zero
    self.input_length = input_length

  @tf_utils.shape_type_conversion
  def build(self, input_shape=None):
    self.embeddings = self.add_weight(
        shape=(self.input_dim, self.output_dim),
        initializer=self.embeddings_initializer,
        name='embeddings',
        regularizer=self.embeddings_regularizer,
        constraint=self.embeddings_constraint,
        experimental_autocast=False)
    self.built = True

  def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
      return None
    return math_ops.not_equal(inputs, 0)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    if self.input_length is None:
      return input_shape + (self.output_dim,)
    else:
      # input_length can be a tuple if the input is 3D or higher.
      if isinstance(self.input_length, (list, tuple)):
        in_lens = list(self.input_length)
      else:
        in_lens = [self.input_length]
      if len(in_lens) != len(input_shape) - 1:
        raise ValueError('"input_length" is %s, '
                         'but received input has shape %s' % (str(
                             self.input_length), str(input_shape)))
      else:
        for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
          if s1 is not None and s2 is not None and s1 != s2:
            raise ValueError('"input_length" is %s, '
                             'but received input has shape %s' % (str(
                                 self.input_length), str(input_shape)))
          elif s1 is None:
            in_lens[i] = s2
      return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)

  def call(self, inputs):
    dtype = backend.dtype(inputs)
    if dtype != 'int32' and dtype != 'int64':
      inputs = math_ops.cast(inputs, 'int32')
    out = embedding_ops.embedding_lookup_v2(self.embeddings, inputs)
    if self._dtype_policy.compute_dtype != self._dtype_policy.variable_dtype:
      # Instead of casting the variable as in most layers, cast the output, as
      # this is mathematically equivalent but is faster.
      out = math_ops.cast(out, self._dtype_policy.compute_dtype)
    return out

  def get_config(self):
    config = {
        'input_dim': self.input_dim,
        'output_dim': self.output_dim,
        'embeddings_initializer':
            initializers.serialize(self.embeddings_initializer),
        'embeddings_regularizer':
            regularizers.serialize(self.embeddings_regularizer),
        'activity_regularizer':
            regularizers.serialize(self.activity_regularizer),
        'embeddings_constraint':
            constraints.serialize(self.embeddings_constraint),
        'mask_zero': self.mask_zero,
        'input_length': self.input_length
    }
    base_config = super(Embedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
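

# --- Illustrative usage sketch (not part of the library) ---
# A minimal demonstration of the CPU-placement pattern described in the class
# docstring above. The vocabulary size, embedding width, and model below are
# arbitrary assumptions chosen for the demo.
if __name__ == '__main__':
  import numpy as np
  import tensorflow as tf

  # Build the embedding weights under a CPU device scope so that a large
  # embedding table does not have to fit in GPU memory.
  with tf.device('cpu:0'):
    embedding_layer = tf.keras.layers.Embedding(1000, 64, input_length=10)
    embedding_layer.build(None)

  # The pre-built layer can then be added to a `Sequential` model.
  model = tf.keras.Sequential([embedding_layer])
  model.compile('rmsprop', 'mse')
  output_array = model.predict(np.random.randint(1000, size=(32, 10)))
  print(output_array.shape)  # (32, 10, 64)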