xref: /aosp_15_r20/external/tensorflow/tensorflow/python/feature_column/sequence_feature_column.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""This API defines FeatureColumn for sequential input.
16
17NOTE: This API is a work in progress and will likely be changing frequently.
18"""
19
20
21import collections
22
23
24from tensorflow.python.feature_column import feature_column_v2 as fc
25from tensorflow.python.feature_column import utils as fc_utils
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import check_ops
31from tensorflow.python.ops import parsing_ops
32from tensorflow.python.ops import sparse_ops
33from tensorflow.python.util.tf_export import tf_export
34
35
36# pylint: disable=protected-access
def concatenate_context_input(context_input, sequence_input):
  """Appends `context_input` to every timestep of `sequence_input`.

  A time axis is inserted at dimension 1 of `context_input`, which is then
  tiled along that axis to the padded length of `sequence_input`. The tiled
  context is concatenated after the sequence features on dimension 2.

  Args:
    context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
    sequence_input: A `Tensor` of dtype `float32` and shape
      `[batch_size, padded_length, d0]`.

  Returns:
    A `Tensor` of dtype `float32` and shape
    `[batch_size, padded_length, d0 + d1]`.

  Raises:
    ValueError: If `sequence_input` does not have rank 3 or `context_input` does
      not have rank 2.
  """
  # Rank/dtype checks; the tiling below is gated on all of them so that the
  # runtime asserts execute before any shape arithmetic in graph mode.
  validations = [
      check_ops.assert_rank(
          sequence_input,
          3,
          message='sequence_input must have rank 3',
          data=[array_ops.shape(sequence_input)]),
      check_ops.assert_type(
          sequence_input,
          dtypes.float32,
          message='sequence_input must have dtype float32; got {}.'.format(
              sequence_input.dtype)),
      check_ops.assert_rank(
          context_input,
          2,
          message='context_input must have rank 2',
          data=[array_ops.shape(context_input)]),
      check_ops.assert_type(
          context_input,
          dtypes.float32,
          message='context_input must have dtype float32; got {}.'.format(
              context_input.dtype)),
  ]
  with ops.control_dependencies(validations):
    num_steps = array_ops.shape(sequence_input)[1]
    # [batch_size, d1] -> [batch_size, 1, d1]
    context_with_time_axis = array_ops.expand_dims(context_input, 1)
    # Repeat only along the time axis: multiples = [1, num_steps, 1].
    tile_multiples = array_ops.concat([[1], [num_steps], [1]], 0)
    context_per_step = array_ops.tile(context_with_time_axis, tile_multiples)
  return array_ops.concat([sequence_input, context_per_step], 2)
84
85
@tf_export('feature_column.sequence_categorical_column_with_identity')
def sequence_categorical_column_with_identity(
    key, num_buckets, default_value=None):
  """Returns a feature column that represents sequences of integers.

  Pass this to `embedding_column` or `indicator_column` to convert sequence
  categorical data into dense representation for input to sequence NN, such as
  RNN.

  Example:

  ```python
  watches = sequence_categorical_column_with_identity(
      'watches', num_buckets=1000)
  watches_embedding = embedding_column(watches, dimension=10)
  columns = [watches_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    num_buckets: Range of inputs. Namely, inputs are expected to be in the
      range `[0, num_buckets)`.
    default_value: If `None`, this column's graph operations will fail for
      out-of-range inputs. Otherwise, this value must be in the range
      `[0, num_buckets)`, and will replace out-of-range inputs.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: if `num_buckets` is less than one.
    ValueError: if `default_value` is not in range `[0, num_buckets)`.
  """
  # Validation is delegated to the non-sequence column constructor; this
  # function only wraps the result as a sequence column.
  return fc.SequenceCategoricalColumn(
      fc.categorical_column_with_identity(
          key=key,
          num_buckets=num_buckets,
          default_value=default_value))
133
134
@tf_export('feature_column.sequence_categorical_column_with_hash_bucket')
def sequence_categorical_column_with_hash_bucket(
    key, hash_bucket_size, dtype=dtypes.string):
  """A sequence of categorical terms where ids are set by hashing.

  Pass this to `embedding_column` or `indicator_column` to convert sequence
  categorical data into dense representation for input to sequence NN, such as
  RNN.

  Example:

  ```python
  tokens = sequence_categorical_column_with_hash_bucket(
      'tokens', hash_bucket_size=1000)
  tokens_embedding = embedding_column(tokens, dimension=10)
  columns = [tokens_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is not greater than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  # Validation is delegated to the non-sequence column constructor; this
  # function only wraps the result as a sequence column.
  return fc.SequenceCategoricalColumn(
      fc.categorical_column_with_hash_bucket(
          key=key,
          hash_bucket_size=hash_bucket_size,
          dtype=dtype))
179
180
@tf_export('feature_column.sequence_categorical_column_with_vocabulary_file')
def sequence_categorical_column_with_vocabulary_file(
    key, vocabulary_file, vocabulary_size=None, num_oov_buckets=0,
    default_value=None, dtype=dtypes.string):
  """A sequence of categorical terms where ids use a vocabulary file.

  Pass this to `embedding_column` or `indicator_column` to convert sequence
  categorical data into dense representation for input to sequence NN, such as
  RNN.

  Example:

  ```python
  states = sequence_categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  states_embedding = embedding_column(states, dimension=10)
  columns = [states_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than length of `vocabulary_file`, if less than length, later
      values are ignored. If None, it is set to the length of `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  # Validation is delegated to the non-sequence column constructor; this
  # function only wraps the result as a sequence column.
  return fc.SequenceCategoricalColumn(
      fc.categorical_column_with_vocabulary_file(
          key=key,
          vocabulary_file=vocabulary_file,
          vocabulary_size=vocabulary_size,
          num_oov_buckets=num_oov_buckets,
          default_value=default_value,
          dtype=dtype))
244
245
@tf_export('feature_column.sequence_categorical_column_with_vocabulary_list')
def sequence_categorical_column_with_vocabulary_list(
    key, vocabulary_list, dtype=None, default_value=-1, num_oov_buckets=0):
  """A sequence of categorical terms where ids use an in-memory list.

  Pass this to `embedding_column` or `indicator_column` to convert sequence
  categorical data into dense representation for input to sequence NN, such as
  RNN.

  Example:

  ```python
  colors = sequence_categorical_column_with_vocabulary_list(
      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
      num_oov_buckets=2)
  colors_embedding = embedding_column(colors, dimension=3)
  columns = [colors_embedding]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input feature.
    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
      is mapped to the index of its value (if present) in `vocabulary_list`.
      Must be castable to `dtype`.
    dtype: The type of features. Only string and integer types are supported.
      If `None`, it will be inferred from `vocabulary_list`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` can not be specified
      with `default_value`.

  Returns:
    A `SequenceCategoricalColumn`.

  Raises:
    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: if `dtype` is not integer or string.
  """
  # Validation is delegated to the non-sequence column constructor; this
  # function only wraps the result as a sequence column.
  return fc.SequenceCategoricalColumn(
      fc.categorical_column_with_vocabulary_list(
          key=key,
          vocabulary_list=vocabulary_list,
          dtype=dtype,
          default_value=default_value,
          num_oov_buckets=num_oov_buckets))
306
307
@tf_export('feature_column.sequence_numeric_column')
def sequence_numeric_column(
    key,
    shape=(1,),
    default_value=0.,
    dtype=dtypes.float32,
    normalizer_fn=None):
  """Returns a feature column that represents sequences of numeric data.

  Example:

  ```python
  temperature = sequence_numeric_column('temperature')
  columns = [temperature]

  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  sequence_feature_layer = SequenceFeatures(columns)
  sequence_input, sequence_length = sequence_feature_layer(features)
  sequence_length_mask = tf.sequence_mask(sequence_length)

  rnn_cell = tf.keras.layers.SimpleRNNCell(hidden_size)
  rnn_layer = tf.keras.layers.RNN(rnn_cell, return_state=True)
  outputs, state = rnn_layer(sequence_input, mask=sequence_length_mask)
  ```

  Args:
    key: A unique string identifying the input features.
    shape: The shape of the input data per sequence id. E.g. if `shape=(2,)`,
      each example must contain `2 * sequence_length` values.
    default_value: A single value compatible with `dtype` that is used for
      padding the sparse data into a dense `Tensor`.
    dtype: The type of values. Anything accepted by `tf.as_dtype` (a
      `tf.DType`, a NumPy dtype, or a dtype name such as `'int64'`); it must
      denote an integer or floating-point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization, it
      can be used for any kind of Tensorflow transformations.

  Returns:
    A `SequenceNumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int.
    TypeError: if `dtype` cannot be converted to a `tf.DType`.
    TypeError: if `normalizer_fn` is not `None` and not callable.
    ValueError: if any dimension in shape is not a positive integer.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = fc._check_shape(shape=shape, key=key)
  # Canonicalize to a `tf.DType`: the `is_integer`/`is_floating` checks below
  # and `SequenceNumericColumn.get_config` (which reads `dtype.name`) need one,
  # and non-`DType` inputs (e.g. NumPy dtypes, strings) previously failed with
  # an opaque AttributeError. Passing a `tf.DType` is unaffected.
  dtype = dtypes.as_dtype(dtype)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  return SequenceNumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)
369
370
def _assert_all_equal_and_return(tensors, name=None):
  """Returns `tensors[0]`, guarded by equality asserts against the others.

  With a single tensor it is returned unchanged. Otherwise an `assert_equal`
  op is built comparing the first tensor to each remaining one, and an
  `identity` of the first tensor is returned under a control dependency on
  all of those asserts.

  Args:
    tensors: A non-empty sequence of tensors.
    name: Optional name scope; defaults to 'assert_all_equal'.

  Returns:
    The first tensor, or an identity of it gated on the equality checks.
  """
  with ops.name_scope(name, 'assert_all_equal', values=tensors):
    reference = tensors[0]
    if len(tensors) == 1:
      return reference
    equality_checks = [
        check_ops.assert_equal(reference, other) for other in tensors[1:]
    ]
    with ops.control_dependencies(equality_checks):
      return array_ops.identity(reference)
381
382
class SequenceNumericColumn(
    fc.SequenceDenseColumn,
    collections.namedtuple(
        'SequenceNumericColumn',
        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
  """Represents sequences of numeric data.

  Fields:
    key: Feature name; also the column's `name`, its only parent, and the key
      of its parse spec.
    shape: Per-timestep shape, exposed as `variable_shape`.
    default_value: Padding value used when densifying the sparse input.
    dtype: Value type for the `VarLenFeature` parse spec; serialized by name.
    normalizer_fn: Optional callable applied to the raw input feature.
  """

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    var_len_spec = parsing_ops.VarLenFeature(self.dtype)
    return {self.key: var_len_spec}

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class.

    Fetches the raw feature for `self.key` and, if `normalizer_fn` is set,
    returns the result of applying it; otherwise returns the raw feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      The (optionally normalized) input tensor.
    """
    raw_feature = transformation_cache.get(self.key, state_manager)
    if self.normalizer_fn is None:
      return raw_feature
    return self.normalizer_fn(raw_feature)

  @property
  def variable_shape(self):
    """A `TensorShape` for a single timestep of this column's input."""
    return tensor_shape.TensorShape(self.shape)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """Densifies the transformed feature and computes sequence lengths.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      A `TensorSequenceLengthPair` whose dense tensor is reshaped to
      `[batch_size, -1] + variable_shape`.
    """
    per_step_shape = self.variable_shape
    sparse_feature = transformation_cache.get(self, state_manager)
    dense_feature = sparse_ops.sparse_tensor_to_dense(
        sparse_feature, default_value=self.default_value)

    # Target shape: [batch_size, T, *variable_shape], with T inferred.
    batch_dim = array_ops.shape(dense_feature)[:1]
    target_shape = array_ops.concat(
        [batch_dim, [-1], per_step_shape], axis=0)
    dense_feature = array_ops.reshape(dense_feature, shape=target_shape)

    # A rank-2 sparse input stores each timestep's values flattened along
    # the last axis, so every `num_elements` values count as one step. For
    # higher ranks the values are grouped in the trailing dimension and each
    # entry counts once.
    values_per_step = (
        per_step_shape.num_elements()
        if sparse_feature.shape.ndims == 2 else 1)
    lengths = fc_utils.sequence_length_from_sparse_tensor(
        sparse_feature, num_elements=values_per_step)

    return fc.SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_feature, sequence_length=lengths)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class.

    All namedtuple fields are emitted as-is, except `dtype`, which is replaced
    by its string name.
    """
    return dict(zip(self._fields, self), dtype=self.dtype.name)

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class.

    Inverse of `get_config`: the stored dtype name is converted back into a
    `tf.DType`.
    """
    fc._check_config_keys(config, cls._fields)
    constructor_kwargs = fc._standardize_and_copy_config(config)
    constructor_kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**constructor_kwargs)
478
479
480# pylint: enable=protected-access
481