# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This API defines FeatureColumn abstraction.
16
17FeatureColumns provide a high level abstraction for ingesting and representing
18features. FeatureColumns are also the primary way of encoding features for
19canned `tf.estimator.Estimator`s.
20
21When using FeatureColumns with `Estimators`, the type of feature column you
22should choose depends on (1) the feature type and (2) the model type.
23
241. Feature type:
25
26  * Continuous features can be represented by `numeric_column`.
27  * Categorical features can be represented by any `categorical_column_with_*`
28  column:
29    - `categorical_column_with_vocabulary_list`
30    - `categorical_column_with_vocabulary_file`
31    - `categorical_column_with_hash_bucket`
32    - `categorical_column_with_identity`
33    - `weighted_categorical_column`
34
352. Model type:
36
37  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).
38
39    Continuous features can be directly fed into deep neural network models.
40
41      age_column = numeric_column("age")
42
    To feed sparse features into DNN models, wrap the column with
    `embedding_column` or `indicator_column`. `indicator_column` is recommended
    for features with only a few possible values. For features with many
    possible values, `embedding_column` is recommended to keep the size of
    your model down.

      embedded_dept_column = embedding_column(
          categorical_column_with_vocabulary_list(
              "department", ["math", "philosophy", ...]), dimension=10)

  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).

    Sparse features can be fed directly into linear models. They behave like an
    indicator column but with an efficient implementation.

      dept_column = categorical_column_with_vocabulary_list("department",
          ["math", "philosophy", "english"])

    It is recommended that continuous features be bucketized before being
    fed into linear models.

      bucketized_age_column = bucketized_column(
          source_column=age_column,
          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

    Sparse features can be crossed (also known as conjoined or combined) in
    order to form non-linearities, and then fed into linear models.

      cross_dept_age_column = crossed_column(
          columns=["department", bucketized_age_column],
          hash_bucket_size=1000)

Example of building canned `Estimator`s using FeatureColumns:

  ```python
  # Define features and transformations
  deep_feature_columns = [age_column, embedded_dept_column]
  wide_feature_columns = [dept_column, bucketized_age_column,
      cross_dept_age_column]

  # Build deep model
  estimator = DNNClassifier(
      feature_columns=deep_feature_columns,
      hidden_units=[500, 250, 50])
  estimator.train(...)

  # Or build a wide model
  estimator = LinearClassifier(
      feature_columns=wide_feature_columns)
  estimator.train(...)

  # Or build a wide and deep model!
  estimator = DNNLinearCombinedClassifier(
      linear_feature_columns=wide_feature_columns,
      dnn_feature_columns=deep_feature_columns,
      dnn_hidden_units=[500, 250, 50])
  estimator.train(...)
  ```


FeatureColumns can also be transformed into a generic input layer for
custom models using `input_layer`.

Example of building a model using FeatureColumns; this can be used in a
`model_fn` which is given to the `tf.estimator.Estimator`:

  ```python
  # Building model via layers

  deep_feature_columns = [age_column, embedded_dept_column]
  columns_to_tensor = parse_feature_columns_from_examples(
      serialized=my_data,
      feature_columns=deep_feature_columns)
  first_layer = input_layer(
      features=columns_to_tensor,
      feature_columns=deep_feature_columns)
  second_layer = fully_connected(first_layer, ...)
  ```

NOTE: Functions prefixed with "_" indicate experimental or private parts of
the API subject to change, and should not be relied upon!
"""

import abc
import collections
import math
import re

import numpy as np
import six

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
                               'deprecated. Please use the new FeatureColumn '
                               'APIs instead.')


class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
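
  A minimal sketch of the contract (a hypothetical in-memory manager, for
  illustration only; the real implementations used by `DenseFeatures` and
  linear layers live below):

  ```python
  class DictStateManager(StateManager):
    # Stores variables in a plain dict, without any layer tracking.

    def __init__(self):
      self._vars = {}  # Maps (feature_column, name) to the variable.

    def create_variable(self, feature_column, name, shape, dtype=None,
                        trainable=True, use_resource=True, initializer=None):
      var = tf.Variable(initializer(shape=shape, dtype=dtype),
                        trainable=trainable, name=name)
      self._vars[(feature_column, name)] = var
      return var

    def get_variable(self, feature_column, name):
      return self._vars[(feature_column, name)]
  ```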
181  """
182
183  def create_variable(self,
184                      feature_column,
185                      name,
186                      shape,
187                      dtype=None,
188                      trainable=True,
189                      use_resource=True,
190                      initializer=None):
191    """Creates a new variable.
192
193    Args:
194      feature_column: A `FeatureColumn` object this variable corresponds to.
195      name: variable name.
196      shape: variable shape.
197      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
198      trainable: Whether this variable is trainable or not.
199      use_resource: If true, we use resource variables. Otherwise we use
200        RefVariable.
201      initializer: initializer instance (callable).
202
203    Returns:
204      The created variable.
205    """
206    del feature_column, name, shape, dtype, trainable, use_resource, initializer
207    raise NotImplementedError('StateManager.create_variable')
208
  def add_variable(self, feature_column, var):
    """Adds an existing variable to the state.

    Args:
      feature_column: A `FeatureColumn` object to associate this variable with.
      var: The variable.
    """
    del feature_column, var
    raise NotImplementedError('StateManager.add_variable')

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_variable')

  def add_resource(self, feature_column, name, resource):
    """Adds a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      name: Name of the resource.
      resource: The resource.
    """
    del feature_column, name, resource
    raise NotImplementedError('StateManager.add_resource')

  def has_resource(self, feature_column, name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.has_resource')

  def get_resource(self, feature_column, name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: Name of the resource.
    """
    del feature_column, name
    raise NotImplementedError('StateManager.get_resource')


@tf_export('__internal__.feature_column.StateManager', v1=[])
class _StateManagerImpl(StateManager):
  """Manages the state of DenseFeatures and LinearLayer.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
  """

  def __init__(self, layer, trainable):
    """Creates an _StateManagerImpl object.

    Args:
      layer: The input layer this state manager is associated with.
      trainable: Whether variables created are trainable by default.
    """
    self._trainable = trainable
    self._layer = layer
    if self._layer is not None and not hasattr(self._layer, '_resources'):
      self._layer._resources = data_structures.Mapping()  # pylint: disable=protected-access
    self._cols_to_vars_map = collections.defaultdict(lambda: {})
    self._cols_to_resources_map = collections.defaultdict(lambda: {})

  def create_variable(self,
                      feature_column,
                      name,
                      shape,
                      dtype=None,
                      trainable=True,
                      use_resource=True,
                      initializer=None):
    """Creates a new variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
      shape: variable shape.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      trainable: Whether this variable is trainable or not.
      use_resource: If true, we use resource variables. Otherwise we use
        RefVariable.
      initializer: initializer instance (callable).

    Returns:
      The created variable.
    """
    if name in self._cols_to_vars_map[feature_column]:
      raise ValueError('Variable already exists.')

    # We explicitly track these variables since `name` is not guaranteed to be
    # unique and disable manual tracking that the add_weight call does.
    with trackable.no_manual_dependency_tracking_scope(self._layer):
      var = self._layer.add_weight(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          trainable=self._trainable and trainable,
          use_resource=use_resource,
          # TODO(rohanj): Get rid of this hack once we have a mechanism for
          # specifying a default partitioner for an entire layer. In that case,
          # the default getter for Layers should work.
          getter=variable_scope.get_variable)
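    # A partitioned variable is tracked per partition so that each shard is
    # checkpointed under its own slice-offset-qualified name.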
    if isinstance(var, variables.PartitionedVariable):
      for v in var:
        part_name = name + '/' + str(v._get_save_slice_info().var_offset[0])  # pylint: disable=protected-access
        self._layer._track_trackable(v, feature_column.name + '/' + part_name)  # pylint: disable=protected-access
    elif isinstance(var, trackable.Trackable):
      self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access

    self._cols_to_vars_map[feature_column][name] = var
    return var

  def get_variable(self, feature_column, name):
    """Returns an existing variable.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      name: variable name.
    """
    if name in self._cols_to_vars_map[feature_column]:
      return self._cols_to_vars_map[feature_column][name]
    raise ValueError('Variable does not exist.')

  def add_resource(self, feature_column, resource_name, resource):
    """Adds a new resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
      resource: The resource.
    """
    self._cols_to_resources_map[feature_column][resource_name] = resource
    # pylint: disable=protected-access
    if self._layer is not None and isinstance(resource, trackable.Trackable):
      # Add trackable resources to the layer for serialization.
      if feature_column.name not in self._layer._resources:
        self._layer._resources[feature_column.name] = data_structures.Mapping()
      if resource_name not in self._layer._resources[feature_column.name]:
        self._layer._resources[feature_column.name][resource_name] = resource
    # pylint: enable=protected-access

  def has_resource(self, feature_column, resource_name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      resource_name: Name of the resource.
    """
    return resource_name in self._cols_to_resources_map[feature_column]

  def get_resource(self, feature_column, resource_name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this variable corresponds to.
      resource_name: Name of the resource.
    """
    if (feature_column not in self._cols_to_resources_map or
        resource_name not in self._cols_to_resources_map[feature_column]):
      raise ValueError('Resource does not exist.')
    return self._cols_to_resources_map[feature_column][resource_name]


def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on the feature columns passed in.

  Note that you will most likely not need to use this function directly. Check
  `input_layer` and `linear_model` to see whether they satisfy your use case.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor`, depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    transformation_cache = FeatureTransformationCache(features)
    for column in feature_columns:
      with ops.name_scope(
          None,
          default_name=_sanitize_column_name_for_variable_scope(column.name)):
        outputs[column] = transformation_cache.get(column, state_manager)
  return outputs


@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as the `features` argument of
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(column))
    config = column.parse_example_spec
    for key, value in six.iteritems(config):
      if key in result and value != result[key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result


@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9),...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9),...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying the dimension of the embedding, must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.nn.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, embedding values are l2-normalized to this value.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
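  # The default stddev of 1/sqrt(dimension) keeps the expected norm of each
  # embedding vector roughly constant as `dimension` grows.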
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)


@tf_export(v1=['feature_column.shared_embedding_columns'])
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying the dimension of the embedding, must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.nn.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
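  # Unwrap any weighted or sequence wrappers so the checks below compare the
  # underlying categorical column types and bucket counts.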
  while isinstance(
      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(
        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable,
            use_safe_embedding_lookup=use_safe_embedding_lookup))

  return result


@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying the dimension of the embedding, must be
      > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an
      example-level normalization on the column. For more information, see
      `tf.nn.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true, consider turning off if the above checks
      are not needed. Note that having empty rows will not trigger any error
      though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  num_buckets = c0.num_buckets
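  # Unwrap any weighted or sequence wrappers so the checks below compare the
  # underlying categorical column types and bucket counts.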
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result


@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  Example:

  Assume we have data with two features `a` and `b`.

  >>> data = {'a': [15, 9, 17, 19, 21, 18, 25, 30],
  ...    'b': [5.0, 6.4, 10.5, 13.6, 15.7, 19.9, 20.3, 0.0]}

  Let us represent the features `a` and `b` as numerical features.

  >>> a = tf.feature_column.numeric_column('a')
  >>> b = tf.feature_column.numeric_column('b')

  Feature columns describe a set of transformations applied to the inputs.

  For example, to "bucketize" feature `a`, wrap the `a` column in a
  `feature_column.bucketized_column`.
  Provided with `5` bucket boundaries, the bucketized_column API
  will bucket this feature into a total of `6` buckets.

  >>> a_buckets = tf.feature_column.bucketized_column(a,
  ...    boundaries=[10, 15, 20, 25, 30])

  Create a `DenseFeatures` layer which will apply the transformations
  described by the set of `tf.feature_column` objects:

  >>> feature_layer = tf.keras.layers.DenseFeatures([a_buckets, b])
  >>> print(feature_layer(data))
  tf.Tensor(
  [[ 0.   0.   1.   0.   0.   0.   5. ]
   [ 1.   0.   0.   0.   0.   0.   6.4]
   [ 0.   0.   1.   0.   0.   0.  10.5]
   [ 0.   0.   1.   0.   0.   0.  13.6]
   [ 0.   0.   0.   1.   0.   0.  15.7]
   [ 0.   0.   1.   0.   0.   0.  19.9]
   [ 0.   0.   0.   0.   1.   0.  20.3]
   [ 0.   0.   0.   0.   0.   1.   0. ]], shape=(8, 7), dtype=float32)

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers specifying the shape of the `Tensor`. An
      integer can be given, which means a single dimension `Tensor` with the
      given width. The `Tensor` representing the column will have the shape of
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype` or an iterable of
      values compatible with `dtype` which the column takes on during
      `tf.Example` parsing if data is missing. A default value of `None` will
      cause `tf.io.parse_example` to fail if an example does not contain this
      column. If a single value is provided, the same value will be applied as
      the default value for every item. If an iterable of values is provided,
      the shape of the `default_value` should be equal to the given `shape`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization, it
      can be used for any kind of TensorFlow transformation.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int.
    ValueError: if any dimension in shape is not a positive integer.
    TypeError: if `default_value` is an iterable but not compatible with `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  default_value = fc_utils.check_default_value(
      shape, default_value, dtype, key)

  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  return NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)


@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
  """Represents discretized dense input bucketed by `boundaries`.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input_tensor = [[-5, 10000],
                  [150,   10],
                  [5,    100]]
  ```

  then the output will be

  ```python
  output = [[0, 3],
            [3, 2],
            [1, 3]]
  ```

  Example:

  ```python
  price = tf.feature_column.numeric_column('price')
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  using `crossed_column`:

  ```python
  price = tf.feature_column.numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = tf.feature_column.crossed_column(
      [bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)):  # pylint: disable=protected-access
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not isinstance(boundaries, (list, tuple)):
    raise ValueError('boundaries must be a sorted list.')
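  # Boundaries must be strictly increasing; a repeated boundary would create
  # an empty bucket.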
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return BucketizedColumn(source_column, tuple(boundaries))


@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing.
  For string type input, `output_id = Hash(input_feature_string) %
  bucket_size`. For int type input, the value is first converted to its string
  representation and then hashed by the same formula.
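
  A rough sketch of the id computation this column performs (illustration
  only; `tf.strings.to_hash_bucket_fast` is the public counterpart of the
  fast, non-cryptographic hash op used internally):

  ```python
  ids = tf.strings.to_hash_bucket_fast(
      tf.constant(['Tensorflow', 'Keras']), num_buckets=10000)
  # Each id is an int64 in [0, 10000).
  ```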

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket(
      "keywords", 10000)
  columns = [keywords]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  linear_prediction, _, _ = tf.compat.v1.feature_column.linear_model(
      features, columns)

  # or
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket(
      "keywords", 10000)
  keywords_embedded = tf.feature_column.embedding_column(keywords, 16)
  columns = [keywords_embedded]
  features = {'keywords': tf.constant([
      ['Tensorflow', 'Keras', 'RNN', 'LSTM', 'CNN'],
      ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'],
      ['CNN', 'Tensorflow', 'LSTM', 'Keras', 'RNN']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int >= 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is less than 1.
    ValueError: `dtype` is neither string nor integer.
  """
  if hash_bucket_size is None:
    raise ValueError('hash_bucket_size must be set. key: {}'.format(key))

  if hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
                     'hash_bucket_size: {}, key: {}'.format(
                         hash_bucket_size, key))

  fc_utils.assert_key_is_string(key)
  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))

  return HashedCategoricalColumn(key, hash_bucket_size, dtype)


1242@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
1243def categorical_column_with_vocabulary_file(key,
1244                                            vocabulary_file,
1245                                            vocabulary_size=None,
1246                                            num_oov_buckets=0,
1247                                            default_value=None,
1248                                            dtype=dtypes.string):
1249  """A `CategoricalColumn` with a vocabulary file.
1250
1251  Use this when your inputs are in string or integer format, and you have a
1252  vocabulary file that maps each value to an integer ID. By default,
1253  out-of-vocabulary values are ignored. Use either (but not both) of
1254  `num_oov_buckets` and `default_value` to specify how to include
1255  out-of-vocabulary values.
1256
1257  For input dictionary `features`, `features[key]` is either `Tensor` or
1258  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1259  and `''` for string, which will be dropped by this feature column.
1260
1261  Example with `num_oov_buckets`:
  File 'states.txt' contains 5 lines, each with a U.S. state name. All inputs
  with values in that file are assigned an ID 0-4, corresponding to their line
  numbers. All other values are hashed and assigned ID 5.
1266
1267  ```python
1268  import tensorflow as tf
1269  states = tf.feature_column.categorical_column_with_vocabulary_file(
1270    key='states', vocabulary_file='states.txt', vocabulary_size=5,
1271    num_oov_buckets=1)
1272  columns = [states]
1273  features = {'states':tf.constant([['california', 'georgia', 'michigan',
1274  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
1275  'texas']])}
1276  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1277  columns)
1278  ```
1279
1280  Example with `default_value`:
  File 'states.txt' contains 6 lines - the first line is 'XX', and the other 5
  each have a U.S. state name. Both a literal 'XX' in input and values missing
  from the file are assigned ID 0. All others are assigned the corresponding
  line number 1-5.
1285
1286  ```python
1287  import tensorflow as tf
1288  states = tf.feature_column.categorical_column_with_vocabulary_file(
1289    key='states', vocabulary_file='states.txt', vocabulary_size=6,
1290    default_value=0)
1291  columns = [states]
1292  features = {'states':tf.constant([['california', 'georgia', 'michigan',
1293  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
1294  'texas']])}
1295  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1296  columns)
1297  ```
1298
1299  And to make an embedding with either:
1300
1301  ```python
1302  import tensorflow as tf
1303  states = tf.feature_column.categorical_column_with_vocabulary_file(
1304    key='states', vocabulary_file='states.txt', vocabulary_size=5,
1305    num_oov_buckets=1)
1306  columns = [tf.feature_column.embedding_column(states, 3)]
1307  features = {'states':tf.constant([['california', 'georgia', 'michigan',
1308  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
1309  'texas']])}
1310  input_layer = tf.keras.layers.DenseFeatures(columns)
1311  dense_tensor = input_layer(features)
1312  ```
1313
1314  Args:
1315    key: A unique string identifying the input feature. It is used as the
1316      column name and the dictionary key for feature parsing configs, feature
1317      `Tensor` objects, and feature columns.
1318    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of elements in the vocabulary. This must be no
      greater than the number of lines in `vocabulary_file`; if it is less,
      later values are ignored. If `None`, it is set to the number of lines in
      `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
1330    dtype: The type of features. Only string and integer types are supported.
1331
1332  Returns:
1333    A `CategoricalColumn` with a vocabulary file.
1334
1335  Raises:
1336    ValueError: `vocabulary_file` is missing or cannot be opened.
1337    ValueError: `vocabulary_size` is missing or < 1.
1338    ValueError: `num_oov_buckets` is a negative integer.
1339    ValueError: `num_oov_buckets` and `default_value` are both specified.
1340    ValueError: `dtype` is neither string nor integer.
1341  """
1342  return categorical_column_with_vocabulary_file_v2(
1343      key, vocabulary_file, vocabulary_size,
1344      dtype, default_value,
1345      num_oov_buckets)
1346
1347
1348@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
1349def categorical_column_with_vocabulary_file_v2(key,
1350                                               vocabulary_file,
1351                                               vocabulary_size=None,
1352                                               dtype=dtypes.string,
1353                                               default_value=None,
1354                                               num_oov_buckets=0,
1355                                               file_format=None):
1356  """A `CategoricalColumn` with a vocabulary file.
1357
1358  Use this when your inputs are in string or integer format, and you have a
1359  vocabulary file that maps each value to an integer ID. By default,
1360  out-of-vocabulary values are ignored. Use either (but not both) of
1361  `num_oov_buckets` and `default_value` to specify how to include
1362  out-of-vocabulary values.
1363
1364  For input dictionary `features`, `features[key]` is either `Tensor` or
1365  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1366  and `''` for string, which will be dropped by this feature column.
1367
1368  Example with `num_oov_buckets`:
1369  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
1370  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to their line numbers. All other values are hashed and assigned
  an ID 50-54.
1373
1374  ```python
1375  states = categorical_column_with_vocabulary_file(
1376      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
1377      num_oov_buckets=5)
1378  columns = [states, ...]
1379  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1380  linear_prediction = linear_model(features, columns)
1381  ```
1382
1383  Example with `default_value`:
1384  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
1385  other 50 each have a 2-character U.S. state abbreviation. Both a literal
1386  `'XX'` in input, and other values missing from the file, will be assigned
1387  ID 0. All others are assigned the corresponding line number 1-50.
1388
1389  ```python
1390  states = categorical_column_with_vocabulary_file(
1391      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
1392      default_value=0)
1393  columns = [states, ...]
1394  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1396  ```
1397
1398  And to make an embedding with either:
1399
1400  ```python
1401  columns = [embedding_column(states, 3),...]
1402  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1403  dense_tensor = input_layer(features, columns)
1404  ```
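
  The vocabulary may also be supplied as a GZIP-compressed `TFRecord` file of
  serialized string entries, in which case `file_format` is inferred from the
  file suffix (a minimal sketch; the file name is illustrative):

  ```python
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='states.tfrecord.gz',
      num_oov_buckets=1)  # file_format inferred as 'tfrecord_gzip'
  ```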
1405
1406  Args:
1407    key: A unique string identifying the input feature. It is used as the
1408      column name and the dictionary key for feature parsing configs, feature
1409      `Tensor` objects, and feature columns.
1410    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of elements in the vocabulary. This must be no
      greater than the number of lines in `vocabulary_file`; if it is less,
      later values are ignored. If `None`, it is set to the number of lines in
      `vocabulary_file`.
    dtype: The type of features. Only string and integer types are supported.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` cannot be specified with
      `default_value`.
    file_format: The format of the vocabulary file. The format is 'text' by
      default unless `vocabulary_file` is a string which ends in 'tfrecord.gz'.
      The accepted alternative value for `file_format` is 'tfrecord_gzip'.
1426
1427  Returns:
1428    A `CategoricalColumn` with a vocabulary file.
1429
1430  Raises:
1431    ValueError: `vocabulary_file` is missing or cannot be opened.
1432    ValueError: `vocabulary_size` is missing or < 1.
1433    ValueError: `num_oov_buckets` is a negative integer.
1434    ValueError: `num_oov_buckets` and `default_value` are both specified.
1435    ValueError: `dtype` is neither string nor integer.
1436  """
1437  if not vocabulary_file:
1438    raise ValueError('Missing vocabulary_file in {}.'.format(key))
1439
1440  if file_format is None and vocabulary_file.endswith('tfrecord.gz'):
1441    file_format = 'tfrecord_gzip'
1442
1443  if vocabulary_size is None:
1444    if not gfile.Exists(vocabulary_file):
1445      raise ValueError('vocabulary_file in {} does not exist.'.format(key))
1446
1447    if file_format == 'tfrecord_gzip':
1448      ds = readers.TFRecordDataset(vocabulary_file, 'GZIP')
1449      vocabulary_size = ds.reduce(0, lambda x, _: x + 1)
1450      if context.executing_eagerly():
1451        vocabulary_size = vocabulary_size.numpy()
1452    else:
1453      with gfile.GFile(vocabulary_file, mode='rb') as f:
1454        vocabulary_size = sum(1 for _ in f)
1455    logging.info(
1456        'vocabulary_size = %d in %s is inferred from the number of elements '
1457        'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
1458
1459  # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
1460  if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1:
1461    raise ValueError('Invalid vocabulary_size in {}.'.format(key))
1462  if num_oov_buckets:
1463    if default_value is not None:
1464      raise ValueError(
1465          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1466              key))
1467    if num_oov_buckets < 0:
1468      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1469          num_oov_buckets, key))
1470  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1471  fc_utils.assert_key_is_string(key)
1472  return VocabularyFileCategoricalColumn(
1473      key=key,
1474      vocabulary_file=vocabulary_file,
1475      vocabulary_size=vocabulary_size,
1476      num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets,
1477      default_value=-1 if default_value is None else default_value,
1478      dtype=dtype,
1479      file_format=file_format)
1480
1481
1482@tf_export('feature_column.categorical_column_with_vocabulary_list')
1483def categorical_column_with_vocabulary_list(key,
1484                                            vocabulary_list,
1485                                            dtype=None,
1486                                            default_value=-1,
1487                                            num_oov_buckets=0):
1488  """A `CategoricalColumn` with in-memory vocabulary.
1489
1490  Use this when your inputs are in string or integer format, and you have an
1491  in-memory vocabulary mapping each value to an integer ID. By default,
1492  out-of-vocabulary values are ignored. Use either (but not both) of
1493  `num_oov_buckets` and `default_value` to specify how to include
1494  out-of-vocabulary values.
1495
1496  For input dictionary `features`, `features[key]` is either `Tensor` or
1497  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1498  and `''` for string, which will be dropped by this feature column.
1499
1500  Example with `num_oov_buckets`:
1501  In the following example, each input in `vocabulary_list` is assigned an ID
1502  0-3 corresponding to its index (e.g., input 'B' produces output 2). All other
1503  inputs are hashed and assigned an ID 4-5.
1504
1505  ```python
1506  colors = categorical_column_with_vocabulary_list(
1507      key='colors', vocabulary_list=('R', 'G', 'B', 'Y'),
1508      num_oov_buckets=2)
1509  columns = [colors, ...]
1510  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1512  ```
1513
1514  Example with `default_value`:
1515  In the following example, each input in `vocabulary_list` is assigned an ID
1516  0-4 corresponding to its index (e.g., input 'B' produces output 3). All other
1517  inputs are assigned `default_value` 0.
1518
1520  ```python
1521  colors = categorical_column_with_vocabulary_list(
1522      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1523  columns = [colors, ...]
1524  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1526  ```
1527
1528  And to make an embedding with either:
1529
1530  ```python
1531  columns = [embedding_column(colors, 3),...]
1532  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1533  dense_tensor = input_layer(features, columns)
1534  ```
1535
1536  Args:
1537    key: A unique string identifying the input feature. It is used as the column
1538      name and the dictionary key for feature parsing configs, feature `Tensor`
1539      objects, and feature columns.
1540    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1541      is mapped to the index of its value (if present) in `vocabulary_list`.
1542      Must be castable to `dtype`.
1543    dtype: The type of features. Only string and integer types are supported. If
1544      `None`, it will be inferred from `vocabulary_list`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This cannot be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
      hash of the input value. A positive `num_oov_buckets` cannot be specified
      with `default_value`.
1553
1554  Returns:
1555    A `CategoricalColumn` with in-memory vocabulary.
1556
1557  Raises:
1558    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
1559    ValueError: `num_oov_buckets` is a negative integer.
1560    ValueError: `num_oov_buckets` and `default_value` are both specified.
1561    ValueError: if `dtype` is not integer or string.
1562  """
1563  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1564    raise ValueError(
1565        'vocabulary_list {} must be non-empty, column_name: {}'.format(
1566            vocabulary_list, key))
1567  if len(set(vocabulary_list)) != len(vocabulary_list):
1568    raise ValueError(
1569        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1570            vocabulary_list, key))
1571  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1572  if num_oov_buckets:
1573    if default_value != -1:
1574      raise ValueError(
1575          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1576              key))
1577    if num_oov_buckets < 0:
1578      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1579          num_oov_buckets, key))
1580  fc_utils.assert_string_or_int(
1581      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1582  if dtype is None:
1583    dtype = vocabulary_dtype
1584  elif dtype.is_integer != vocabulary_dtype.is_integer:
1585    raise ValueError(
1586        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1587            dtype, vocabulary_dtype, key))
1588  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1589  fc_utils.assert_key_is_string(key)
1590
1591  return VocabularyListCategoricalColumn(
1592      key=key,
1593      vocabulary_list=tuple(vocabulary_list),
1594      dtype=dtype,
1595      default_value=default_value,
1596      num_oov_buckets=num_oov_buckets)
1597
1598
1599@tf_export('feature_column.categorical_column_with_identity')
1600def categorical_column_with_identity(key, num_buckets, default_value=None):
1601  """A `CategoricalColumn` that returns identity values.
1602
  Use this when your inputs are integers in the range `[0, num_buckets)`, and
  you want to use the input value itself as the categorical ID. Values outside
  this range are mapped to `default_value` if specified; otherwise an error is
  raised.

  Typically, this is used for contiguous ranges of integer indexes, but it
  doesn't have to be. It might be inefficient, however, if many IDs are
  unused. Consider `categorical_column_with_hash_bucket` in that case.
1611
1612  For input dictionary `features`, `features[key]` is either `Tensor` or
1613  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1614  and `''` for string, which will be dropped by this feature column.
1615
  In the following examples, each input in the range `[0, 1000000)` is assigned
  its own value as the category ID. All other inputs are assigned the
  `default_value` 0. Note that a literal 0 in inputs will result in the same
  default ID.
1619
1620  Linear model:
1621
1622  ```python
1623  import tensorflow as tf
1624  video_id = tf.feature_column.categorical_column_with_identity(
1625      key='video_id', num_buckets=1000000, default_value=0)
1626  columns = [video_id]
1627  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1628  [33,78, 2, 73, 1]])}
1629  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1630  columns)
1631  ```
1632
1633  Embedding for a DNN model:
1634
1635  ```python
1636  import tensorflow as tf
1637  video_id = tf.feature_column.categorical_column_with_identity(
1638      key='video_id', num_buckets=1000000, default_value=0)
1639  columns = [tf.feature_column.embedding_column(video_id, 9)]
1640  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1641  [33,78, 2, 73, 1]])}
1642  input_layer = tf.keras.layers.DenseFeatures(columns)
1643  dense_tensor = input_layer(features)
1644  ```
1645
1646  Args:
1647    key: A unique string identifying the input feature. It is used as the
1648      column name and the dictionary key for feature parsing configs, feature
1649      `Tensor` objects, and feature columns.
1650    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1651    default_value: If set, values outside of range `[0, num_buckets)` will
1652      be replaced with this value. If not set, values >= num_buckets will
1653      cause a failure while values < 0 will be dropped.
1654
1655  Returns:
1656    A `CategoricalColumn` that returns identity values.
1657
1658  Raises:
1659    ValueError: if `num_buckets` is less than one.
1660    ValueError: if `default_value` is not in range `[0, num_buckets)`.
1661  """
1662  if num_buckets < 1:
1663    raise ValueError(
1664        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
1665  if (default_value is not None) and (
1666      (default_value < 0) or (default_value >= num_buckets)):
1667    raise ValueError(
1668        'default_value {} not in range [0, {}), column_name {}'.format(
1669            default_value, num_buckets, key))
1670  fc_utils.assert_key_is_string(key)
1671  return IdentityCategoricalColumn(
1672      key=key, number_buckets=num_buckets, default_value=default_value)
1673
1674
1675@tf_export('feature_column.indicator_column')
1676def indicator_column(categorical_column):
1677  """Represents multi-hot representation of given categorical column.
1678
1679  - For DNN model, `indicator_column` can be used to wrap any
1680    `categorical_column_*` (e.g., to feed to DNN). Consider to Use
1681    `embedding_column` if the number of buckets/unique(values) are large.
1682
1683  - For Wide (aka linear) model, `indicator_column` is the internal
1684    representation for categorical column when passing categorical column
1685    directly (as any element in feature_columns) to `linear_model`. See
1686    `linear_model` for details.
1687
1688  ```python
1689  name = indicator_column(categorical_column_with_vocabulary_list(
1690      'name', ['bob', 'george', 'wanda']))
1691  columns = [name, ...]
1692  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1693  dense_tensor = input_layer(features, columns)
1694
1695  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
1696  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
1697  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
1698  ```
1699
1700  Args:
1701    categorical_column: A `CategoricalColumn` which is created by
1702      `categorical_column_with_*` or `crossed_column` functions.
1703
1704  Returns:
1705    An `IndicatorColumn`.
1706
1707  Raises:
1708    ValueError: If `categorical_column` is not CategoricalColumn type.
1709  """
1710  if not isinstance(categorical_column,
1711                    (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
1712    raise ValueError(
1713        'Unsupported input type. Input must be a CategoricalColumn. '
1714        'Given: {}'.format(categorical_column))
1715  return IndicatorColumn(categorical_column)
1716
1717
1718@tf_export('feature_column.weighted_categorical_column')
1719def weighted_categorical_column(categorical_column,
1720                                weight_feature_key,
1721                                dtype=dtypes.float32):
1722  """Applies weight values to a `CategoricalColumn`.
1723
1724  Use this when each of your sparse inputs has both an ID and a value. For
1725  example, if you're representing text documents as a collection of word
1726  frequencies, you can provide 2 parallel sparse input features ('terms' and
1727  'frequencies' below).
1728
1729  Example:
1730
1731  Input `tf.Example` objects:
1732
1733  ```proto
1734  [
1735    features {
1736      feature {
1737        key: "terms"
1738        value {bytes_list {value: "very" value: "model"}}
1739      }
1740      feature {
1741        key: "frequencies"
1742        value {float_list {value: 0.3 value: 0.1}}
1743      }
1744    },
1745    features {
1746      feature {
1747        key: "terms"
1748        value {bytes_list {value: "when" value: "course" value: "human"}}
1749      }
1750      feature {
1751        key: "frequencies"
1752        value {float_list {value: 0.4 value: 0.1 value: 0.2}}
1753      }
1754    }
1755  ]
1756  ```
1757
1758  ```python
1759  categorical_column = categorical_column_with_hash_bucket(
1760      column_name='terms', hash_bucket_size=1000)
1761  weighted_column = weighted_categorical_column(
1762      categorical_column=categorical_column, weight_feature_key='frequencies')
1763  columns = [weighted_column, ...]
1764  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
1766  ```
1767
1768  This assumes the input dictionary contains a `SparseTensor` for key
1769  'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have
1770  the same indices and dense shape.
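
  For instance, the two parallel inputs might be built directly in eager mode
  (a minimal sketch; the weighted column is wrapped in `indicator_column` so it
  can feed a dense layer):

  ```python
  import tensorflow as tf
  terms = tf.feature_column.categorical_column_with_hash_bucket('terms', 1000)
  weighted = tf.feature_column.weighted_categorical_column(
      categorical_column=terms, weight_feature_key='frequencies')
  features = {
      'terms': tf.sparse.from_dense([['very', 'model'], ['when', 'course']]),
      'frequencies': tf.sparse.from_dense([[0.3, 0.1], [0.4, 0.1]]),
  }
  dense_tensor = tf.keras.layers.DenseFeatures(
      [tf.feature_column.indicator_column(weighted)])(features)
  ```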
1771
1772  Args:
1773    categorical_column: A `CategoricalColumn` created by
1774      `categorical_column_with_*` functions.
1775    weight_feature_key: String key for weight values.
1776    dtype: Type of weights, such as `tf.float32`. Only float and integer weights
1777      are supported.
1778
1779  Returns:
    A `CategoricalColumn` composed of two sparse features: one represents the
    IDs, the other represents the weight (value) of each ID in that example.
1782
1783  Raises:
1784    ValueError: if `dtype` is not convertible to float.
1785  """
1786  if (dtype is None) or not (dtype.is_integer or dtype.is_floating):
1787    raise ValueError('dtype {} is not convertible to float.'.format(dtype))
1788  return WeightedCategoricalColumn(
1789      categorical_column=categorical_column,
1790      weight_feature_key=weight_feature_key,
1791      dtype=dtype)
1792
1793
1794@tf_export('feature_column.crossed_column')
1795def crossed_column(keys, hash_bucket_size, hash_key=None):
1796  """Returns a column for performing crosses of categorical features.
1797
1798  Crossed features will be hashed according to `hash_bucket_size`. Conceptually,
1799  the transformation can be thought of as:
1800    Hash(cartesian product of features) % `hash_bucket_size`
1801
1802  For example, if the input features are:
1803
1804  * SparseTensor referred by first key:
1805
1806    ```python
1807    shape = [2, 2]
1808    {
1809        [0, 0]: "a"
1810        [1, 0]: "b"
1811        [1, 1]: "c"
1812    }
1813    ```
1814
1815  * SparseTensor referred by second key:
1816
1817    ```python
1818    shape = [2, 1]
1819    {
1820        [0, 0]: "d"
1821        [1, 0]: "e"
1822    }
1823    ```
1824
  then the crossed feature will look like:
1826
1827  ```python
  shape = [2, 2]
1829  {
1830      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
1831      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
1832      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
1833  }
1834  ```
1835
1836  Here is an example to create a linear model with crosses of string features:
1837
1838  ```python
  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50000)
1840  columns = [keywords_x_doc_terms, ...]
1841  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1842  linear_prediction = linear_model(features, columns)
1843  ```
1844
1845  You could also use vocabulary lookup before crossing:
1846
1847  ```python
1848  keywords = categorical_column_with_vocabulary_file(
      'keywords', '/path/to/vocabulary/file', vocabulary_size=1000)
  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50000)
1851  columns = [keywords_x_doc_terms, ...]
1852  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1853  linear_prediction = linear_model(features, columns)
1854  ```
1855
1856  If an input feature is of numeric type, you can use
1857  `categorical_column_with_identity`, or `bucketized_column`, as in the example:
1858
1859  ```python
1860  # vertical_id is an integer categorical feature.
  vertical_id = categorical_column_with_identity('vertical_id', 10000)
1862  price = numeric_column('price')
1863  # bucketized_column converts numerical feature to a categorical one.
1864  bucketized_price = bucketized_column(price, boundaries=[...])
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1866  columns = [vertical_id_x_price, ...]
1867  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1868  linear_prediction = linear_model(features, columns)
1869  ```
1870
  To use a crossed column in a DNN model, you need to wrap it in an embedding
  column, as in this example:
1873
1874  ```python
  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50000)
1876  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
1877  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
1878  ```
1879
1880  Args:
1881    keys: An iterable identifying the features to be crossed. Each element can
1882      be either:
1883      * string: Will use the corresponding feature which must be of string type.
1884      * `CategoricalColumn`: Will use the transformed tensor produced by this
1885        column. Does not support hashed categorical column.
    hash_bucket_size: An int >= 1. The number of buckets.
    hash_key: Optional. Specifies the hash_key that will be used by the
      `FingerprintCat64` function to combine the fingerprints of the crossed
      values in `SparseCrossOp`.
1889
1890  Returns:
1891    A `CrossedColumn`.
1892
1893  Raises:
1894    ValueError: If `len(keys) < 2`.
1895    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
1896    ValueError: If any of the keys is `HashedCategoricalColumn`.
1897    ValueError: If `hash_bucket_size < 1`.
1898  """
1899  if not hash_bucket_size or hash_bucket_size < 1:
    raise ValueError('hash_bucket_size must be at least 1. '
1901                     'hash_bucket_size: {}'.format(hash_bucket_size))
1902  if not keys or len(keys) < 2:
1903    raise ValueError(
1904        'keys must be a list with length > 1. Given: {}'.format(keys))
1905  for key in keys:
1906    if (not isinstance(key, six.string_types) and
1907        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
1908      raise ValueError(
          'Unsupported key type. All keys must be either a string or a '
          'categorical column other than HashedCategoricalColumn. '
1911          'Given: {}'.format(key))
1912    if isinstance(key,
1913                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
1914      raise ValueError(
1915          'categorical_column_with_hash_bucket is not supported for crossing. '
1916          'Hashing before crossing will increase probability of collision. '
1917          'Instead, use the feature name as a string. Given: {}'.format(key))
1918  return CrossedColumn(
1919      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
1920
1921
1922# TODO(b/181853833): Add a tf.type for instance type checking.
1923@tf_export('__internal__.feature_column.FeatureColumn', v1=[])
1924@six.add_metaclass(abc.ABCMeta)
1925class FeatureColumn(object):
1926  """Represents a feature column abstraction.
1927
  WARNING: Do not subclass this class unless you know what you are doing:
  the API is subject to future changes.

  To distinguish between the concept of a feature family and a specific binary
  feature within a family, we refer to a feature family like "country" as a
  feature column. For example, we can have a feature in a `tf.Example` format:
    {key: "country", value: [ "US" ]}
  In this example the value of the feature is "US" and "country" refers to the
  column of the feature.
1937
1938  This class is an abstract class. Users should not create instances of this.
1939  """
1940
1941  @abc.abstractproperty
1942  def name(self):
1943    """Returns string. Used for naming."""
1944    pass
1945
1946  def __lt__(self, other):
1947    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1948
1949    Feature columns need to occasionally be sortable, for example when used as
1950    keys in a features dictionary passed to a layer.
1951
1952    In CPython, `__lt__` must be defined for all objects in the
1953    sequence being sorted.
1954
1955    If any objects in the sequence being sorted do not have an `__lt__` method
1956    compatible with feature column objects (such as strings), then CPython will
1957    fall back to using the `__gt__` method below.
1958    https://docs.python.org/3/library/stdtypes.html#list.sort
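
    For example (a sketch; mixing columns and raw feature names):

    ```python
    a = numeric_column('price')
    sorted([a, 'keywords'])  # Works: comparison falls back to __gt__/__lt__.
    ```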
1959
1960    Args:
1961      other: The other object to compare to.
1962
1963    Returns:
1964      True if the string representation of this object is lexicographically less
1965      than the string representation of `other`. For FeatureColumn objects,
1966      this looks like "<__main__.FeatureColumn object at 0xa>".
1967    """
1968    return str(self) < str(other)
1969
1970  def __gt__(self, other):
1971    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1972
1973    Feature columns need to occasionally be sortable, for example when used as
1974    keys in a features dictionary passed to a layer.
1975
1976    `__gt__` is called when the "other" object being compared during the sort
1977    does not have `__lt__` defined.
1978    Example:
1979    ```
1980    # __lt__ only class
1981    class A():
1982      def __lt__(self, other): return str(self) < str(other)
1983
1984    a = A()
1985    a < "b" # True
1986    "0" < a # Error
1987
1988    # __lt__ and __gt__ class
1989    class B():
1990      def __lt__(self, other): return str(self) < str(other)
1991      def __gt__(self, other): return str(self) > str(other)
1992
1993    b = B()
1994    b < "c" # True
1995    "0" < b # True
1996    ```
1997
1998    Args:
1999      other: The other object to compare to.
2000
2001    Returns:
2002      True if the string representation of this object is lexicographically
2003      greater than the string representation of `other`. For FeatureColumn
2004      objects, this looks like "<__main__.FeatureColumn object at 0xa>".
2005    """
2006    return str(self) > str(other)
2007
2008  @abc.abstractmethod
2009  def transform_feature(self, transformation_cache, state_manager):
2010    """Returns intermediate representation (usually a `Tensor`).
2011
2012    Uses `transformation_cache` to create an intermediate representation
2013    (usually a `Tensor`) that other feature columns can use.
2014
2015    Example usage of `transformation_cache`:
    Let's say a `FeatureColumn` depends on a raw feature ('raw') and another
    `FeatureColumn` (input_fc). To access the corresponding `Tensor`s,
    transformation_cache will be used as follows:
2019
2020    ```python
2021    raw_tensor = transformation_cache.get('raw', state_manager)
2022    fc_tensor = transformation_cache.get(input_fc, state_manager)
2023    ```
2024
2025    Args:
2026      transformation_cache: A `FeatureTransformationCache` object to access
2027        features.
2028      state_manager: A `StateManager` to create / access resources such as
2029        lookup tables.
2030
2031    Returns:
2032      Transformed feature `Tensor`.
2033    """
2034    pass
2035
2036  @abc.abstractproperty
2037  def parse_example_spec(self):
2038    """Returns a `tf.Example` parsing spec as dict.
2039
    It is used to get the parsing spec for `tf.io.parse_example`. The returned
    spec is a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`,
    and other supported objects. Please check the documentation of
    `tf.io.parse_example` for all supported spec objects.

    Let's say a `FeatureColumn` depends on a raw feature ('raw') and another
    `FeatureColumn` (input_fc). One possible implementation of
    parse_example_spec is as follows:
2048
2049    ```python
2050    spec = {'raw': tf.io.FixedLenFeature(...)}
2051    spec.update(input_fc.parse_example_spec)
2052    return spec
2053    ```
2054    """
2055    pass
2056
2057  def create_state(self, state_manager):
2058    """Uses the `state_manager` to create state for the FeatureColumn.
2059
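    A typical override creates its variables through the manager; a minimal
    sketch in the style of an embedding-backed column (the variable name and
    shape attributes are illustrative):

    ```python
    def create_state(self, state_manager):
      state_manager.create_variable(
          self, name='embedding_weights',
          shape=(self.num_buckets, self.dimension), trainable=True)
    ```
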
2060    Args:
2061      state_manager: A `StateManager` to create / access resources such as
2062        lookup tables and variables.
2063    """
2064    pass
2065
2066  @abc.abstractproperty
2067  def _is_v2_column(self):
2068    """Returns whether this FeatureColumn is fully conformant to the new API.
2069
2070    This is needed for composition type cases where an EmbeddingColumn etc.
2071    might take in old categorical columns as input and then we want to use the
2072    old API.
2073    """
2074    pass
2075
2076  @abc.abstractproperty
2077  def parents(self):
2078    """Returns a list of immediate raw feature and FeatureColumn dependencies.
2079
    For example:
    # For the following feature columns:
    a = numeric_column('f1')
    c = crossed_column([a, 'f2'], hash_bucket_size=10)  # illustrative size
    # The expected parents are:
    a.parents = ['f1']
    c.parents = [a, 'f2']
2088    pass
2089
2090  def get_config(self):
2091    """Returns the config of the feature column.
2092
2093    A FeatureColumn config is a Python dictionary (serializable) containing the
2094    configuration of a FeatureColumn. The same FeatureColumn can be
2095    reinstantiated later from this configuration.
2096
2097    The config of a feature column does not include information about feature
2098    columns depending on it nor the FeatureColumn class name.
2099
2100    Example with (de)serialization practices followed in this file:
2101    ```python
2102    class SerializationExampleFeatureColumn(
2103        FeatureColumn, collections.namedtuple(
2104            'SerializationExampleFeatureColumn',
2105            ('dimension', 'parent', 'dtype', 'normalizer_fn'))):
2106
2107      def get_config(self):
2108        # Create a dict from the namedtuple.
2109        # Python attribute literals can be directly copied from / to the config.
2110        # For example 'dimension', assuming it is an integer literal.
2111        config = dict(zip(self._fields, self))
2112
2113        # (De)serialization of parent FeatureColumns should use the provided
2114        # (de)serialize_feature_column() methods that take care of de-duping.
2115        config['parent'] = serialize_feature_column(self.parent)
2116
2117        # Many objects provide custom (de)serialization e.g: for tf.DType
2118        # tf.DType.name, tf.as_dtype() can be used.
2119        config['dtype'] = self.dtype.name
2120
2121        # Non-trivial dependencies should be Keras-(de)serializable.
2122        config['normalizer_fn'] = generic_utils.serialize_keras_object(
2123            self.normalizer_fn)
2124
2125        return config
2126
2127      @classmethod
2128      def from_config(cls, config, custom_objects=None, columns_by_name=None):
2129        # This should do the inverse transform from `get_config` and construct
2130        # the namedtuple.
2131        kwargs = config.copy()
2132        kwargs['parent'] = deserialize_feature_column(
2133            config['parent'], custom_objects, columns_by_name)
2134        kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2135        kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object(
            config['normalizer_fn'], custom_objects=custom_objects)
2137        return cls(**kwargs)
2138
2139    ```
2140    Returns:
2141      A serializable Dict that can be used to deserialize the object with
2142      from_config.
2143    """
2144    return self._get_config()
2145
2146  def _get_config(self):
2147    raise NotImplementedError('Must be implemented in subclasses.')
2148
2149  @classmethod
2150  def from_config(cls, config, custom_objects=None, columns_by_name=None):
2151    """Creates a FeatureColumn from its config.
2152
2153    This method should be the reverse of `get_config`, capable of instantiating
2154    the same FeatureColumn from the config dictionary. See `get_config` for an
2155    example of common (de)serialization practices followed in this file.
2156
2157    TODO(b/118939620): This is a private method until consensus is reached on
2158    supporting object deserialization deduping within Keras.
2159
2160    Args:
2161      config: A Dict config acquired with `get_config`.
2162      custom_objects: Optional dictionary mapping names (strings) to custom
2163        classes or functions to be considered during deserialization.
2164      columns_by_name: A Dict[String, FeatureColumn] of existing columns in
2165        order to avoid duplication. Should be passed to any calls to
2166        deserialize_feature_column().
2167
2168    Returns:
2169      A FeatureColumn for the input config.
2170    """
2171    return cls._from_config(config, custom_objects, columns_by_name)
2172
2173  @classmethod
2174  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2175    raise NotImplementedError('Must be implemented in subclasses.')
2176
2177
2178# TODO(b/181853833): Add a tf.type for instance type checking.
2179@tf_export('__internal__.feature_column.DenseColumn', v1=[])
2180class DenseColumn(FeatureColumn):
2181  """Represents a column which can be represented as `Tensor`.
2182
2183  Some examples of this type are: numeric_column, embedding_column,
2184  indicator_column.
2185  """
2186
2187  @abc.abstractproperty
2188  def variable_shape(self):
2189    """`TensorShape` of `get_dense_tensor`, without batch dimension."""
2190    pass
2191
2192  @abc.abstractmethod
2193  def get_dense_tensor(self, transformation_cache, state_manager):
2194    """Returns a `Tensor`.
2195
2196    The output of this function will be used by model-builder-functions. For
2197    example the pseudo code of `input_layer` will be like:
2198
2199    ```python
2200    def input_layer(features, feature_columns, ...):
2201      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
2202      return tf.concat(outputs)
2203    ```
2204
2205    Args:
2206      transformation_cache: A `FeatureTransformationCache` object to access
2207        features.
2208      state_manager: A `StateManager` to create / access resources such as
2209        lookup tables.
2210
2211    Returns:
2212      `Tensor` of shape [batch_size] + `variable_shape`.
2213    """
2214    pass
2215
2216
2217def is_feature_column_v2(feature_columns):
2218  """Returns True if all feature columns are V2."""
2219  for feature_column in feature_columns:
2220    if not isinstance(feature_column, FeatureColumn):
2221      return False
2222    if not feature_column._is_v2_column:  # pylint: disable=protected-access
2223      return False
2224  return True
2225
2226
2227def _create_weighted_sum(column, transformation_cache, state_manager,
2228                         sparse_combiner, weight_var):
2229  """Creates a weighted sum for a dense/categorical column for linear_model."""
2230  if isinstance(column, CategoricalColumn):
2231    return _create_categorical_column_weighted_sum(
2232        column=column,
2233        transformation_cache=transformation_cache,
2234        state_manager=state_manager,
2235        sparse_combiner=sparse_combiner,
2236        weight_var=weight_var)
2237  else:
2238    return _create_dense_column_weighted_sum(
2239        column=column,
2240        transformation_cache=transformation_cache,
2241        state_manager=state_manager,
2242        weight_var=weight_var)
2243
2244
2245def _create_dense_column_weighted_sum(column, transformation_cache,
2246                                      state_manager, weight_var):
2247  """Create a weighted sum of a dense column for linear_model."""
2248  tensor = column.get_dense_tensor(transformation_cache, state_manager)
2249  num_elements = column.variable_shape.num_elements()
2250  batch_size = array_ops.shape(tensor)[0]
2251  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
2252  return math_ops.matmul(tensor, weight_var, name='weighted_sum')
2253
2254
2255class CategoricalColumn(FeatureColumn):
2256  """Represents a categorical feature.
2257
  A categorical feature is typically handled with a `tf.sparse.SparseTensor`
  of IDs.
2260  """
2261
2262  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
2263      'IdWeightPair', ('id_tensor', 'weight_tensor'))
2264
2265  @abc.abstractproperty
2266  def num_buckets(self):
2267    """Returns number of buckets in this sparse feature."""
2268    pass
2269
2270  @abc.abstractmethod
2271  def get_sparse_tensors(self, transformation_cache, state_manager):
2272    """Returns an IdWeightPair.
2273
2274    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
2275    weights.
2276
    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
    as the parsing output of a `VarLenFeature`, which is a ragged matrix.
2284    Args:
2285      transformation_cache: A `FeatureTransformationCache` object to access
2286        features.
2287      state_manager: A `StateManager` to create / access resources such as
2288        lookup tables.
2289    """
2290    pass
2291
2292
2293def _create_categorical_column_weighted_sum(
2294    column, transformation_cache, state_manager, sparse_combiner, weight_var):
2295  # pylint: disable=g-doc-return-or-yield,g-doc-args
2296  """Create a weighted sum of a categorical column for linear_model.
2297
  Note to maintainer: As an implementation detail, the weighted sum is
  implemented via embedding_lookup_sparse for efficiency. Mathematically,
  the two are the same.

  To be specific: conceptually, a categorical column can be treated as a
  multi-hot vector. Say:
2304
2305  ```python
2306    x = [0 0 1]  # categorical column input
2307    w = [a b c]  # weights
2308  ```
  The weighted sum is `c` in this case, which is the same as `w[2]`.
2310
2311  Another example is
2312
2313  ```python
2314    x = [0 1 1]  # categorical column input
2315    w = [a b c]  # weights
2316  ```
  The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
2318
  For both cases, we can implement the weighted sum via embedding_lookup with
  sparse_combiner = "sum".
2321  """
2322
2323  sparse_tensors = column.get_sparse_tensors(transformation_cache,
2324                                             state_manager)
2325  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
2326      array_ops.shape(sparse_tensors.id_tensor)[0], -1
2327  ])
2328  weight_tensor = sparse_tensors.weight_tensor
2329  if weight_tensor is not None:
2330    weight_tensor = sparse_ops.sparse_reshape(
2331        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
2332
2333  return embedding_ops.safe_embedding_lookup_sparse(
2334      weight_var,
2335      id_tensor,
2336      sparse_weights=weight_tensor,
2337      combiner=sparse_combiner,
2338      name='weighted_sum')
2339
2340
2341# TODO(b/181853833): Add a tf.type for instance type checking.
2342@tf_export('__internal__.feature_column.SequenceDenseColumn', v1=[])
2343class SequenceDenseColumn(FeatureColumn):
2344  """Represents dense sequence data."""
2345
2346  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
2347      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
2348
2349  @abc.abstractmethod
2350  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
2351    """Returns a `TensorSequenceLengthPair`.
2352
2353    Args:
2354      transformation_cache: A `FeatureTransformationCache` object to access
2355        features.
2356      state_manager: A `StateManager` to create / access resources such as
2357        lookup tables.
2358    """
2359    pass
2360
2361
2362@tf_export('__internal__.feature_column.FeatureTransformationCache', v1=[])
2363class FeatureTransformationCache(object):
2364  """Handles caching of transformations while building the model.
2365
2366  `FeatureColumn` specifies how to digest an input column to the network. Some
2367  feature columns require data transformations. This class caches those
2368  transformations.
2369
  Some features may be used in more than one place. For example, one can use a
  bucketized feature by itself and also in a cross with it. In that case we
  should create only one bucketization op instead of creating ops for each
  feature column separately. To handle re-use of transformed columns,
  `FeatureTransformationCache` caches all previously transformed columns.
2375
2376  Example:
2377  We're trying to use the following `FeatureColumn`s:
2378
2379  ```python
  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"], ...)
  ... = linear_model(features,
                     [bucketized_age, keywords, age_X_keywords])
2385  ```
2386
2387  If we transform each column independently, then we'll get duplication of
2388  bucketization (one for cross, one for bucketization itself).
2389  The `FeatureTransformationCache` eliminates this duplication.
2390  """
2391
2392  def __init__(self, features):
2393    """Creates a `FeatureTransformationCache`.
2394
2395    Args:
2396      features: A mapping from feature column to objects that are `Tensor` or
        `SparseTensor`, or can be converted to the same via
        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
2399        signifies a base feature (not-transformed). A `FeatureColumn` key
2400        means that this `Tensor` is the output of an existing `FeatureColumn`
2401        which can be reused.
2402    """
2403    self._features = features.copy()
2404    self._feature_tensors = {}
2405
2406  def get(self, key, state_manager, training=None):
2407    """Returns a `Tensor` for the given key.
2408
2409    A `str` key is used to access a base feature (not-transformed). When a
2410    `FeatureColumn` is passed, the transformed feature is returned if it
    already exists; otherwise the given `FeatureColumn` is asked to provide its
2412    transformed output, which is then cached.
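
    For example (a sketch; `bucketized` is a `FeatureColumn` assumed to be
    defined elsewhere):

    ```python
    cache = FeatureTransformationCache({'age': tf.constant([[25.], [61.]])})
    age_tensor = cache.get('age', state_manager=None)  # Raw feature lookup.
    buckets = cache.get(bucketized, state_manager=None)  # Transformed, cached.
    ```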
2413
2414    Args:
2415      key: a `str` or a `FeatureColumn`.
2416      state_manager: A StateManager object that holds the FeatureColumn state.
      training: Boolean indicating whether the column is being used in
2418        training mode. This argument is passed to the transform_feature method
2419        of any `FeatureColumn` that takes a `training` argument. For example, if
2420        a `FeatureColumn` performed dropout, it could expose a `training`
2421        argument to control whether the dropout should be applied.
2422
2423    Returns:
2424      The transformed `Tensor` corresponding to the `key`.
2425
2426    Raises:
2427      ValueError: if key is not found or a transformed `Tensor` cannot be
2428        computed.
2429    """
2430    if key in self._feature_tensors:
2431      # FeatureColumn is already transformed or converted.
2432      return self._feature_tensors[key]
2433
2434    if key in self._features:
2435      feature_tensor = self._get_raw_feature_as_tensor(key)
2436      self._feature_tensors[key] = feature_tensor
2437      return feature_tensor
2438
2439    if isinstance(key, six.string_types):
2440      raise ValueError('Feature {} is not in features dictionary.'.format(key))
2441
2442    if not isinstance(key, FeatureColumn):
2443      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
2444                      'Provided: {}'.format(key))
2445
2446    column = key
2447    logging.debug('Transforming feature_column %s.', column)
2448
2449    # Some columns may need information about whether the transformation is
2450    # happening in training or prediction mode, but not all columns expose this
2451    # argument.
2452    try:
2453      transformed = column.transform_feature(
2454          self, state_manager, training=training)
2455    except TypeError:
2456      transformed = column.transform_feature(self, state_manager)
2457    if transformed is None:
2458      raise ValueError('Column {} is not supported.'.format(column.name))
2459    self._feature_tensors[column] = transformed
2460    return transformed
2461
2462  def _get_raw_feature_as_tensor(self, key):
2463    """Gets the raw_feature (keyed by `key`) as `tensor`.
2464
    The raw feature is converted to a (sparse) tensor and its rank may be
    expanded.

    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
    the rank is 1. Dynamic rank is also supported. Rank-0 raw features are not
    supported and will raise an error.
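
    For example (a sketch): a rank-1 feature of shape `[3]` is expanded to
    shape `[3, 1]`, while rank-2 features pass through unchanged:

    ```python
    tf.constant([1, 2, 3])        # expanded to [[1], [2], [3]]
    tf.constant([[1], [2], [3]])  # returned as-is
    ```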
2470
2471    Args:
2472      key: A `str` key to access the raw feature.
2473
2474    Returns:
2475      A `Tensor` or `SparseTensor`.
2476
2477    Raises:
2478      ValueError: if the raw feature has rank 0.
2479    """
2480    raw_feature = self._features[key]
2481    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2482        raw_feature)
2483
2484    def expand_dims(input_tensor):
2485      # Input_tensor must have rank 1.
2486      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2487        return sparse_ops.sparse_reshape(
2488            input_tensor, [array_ops.shape(input_tensor)[0], 1])
2489      else:
2490        return array_ops.expand_dims(input_tensor, -1)
2491
2492    rank = feature_tensor.get_shape().ndims
2493    if rank is not None:
2494      if rank == 0:
2495        raise ValueError(
2496            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
2497                key, feature_tensor))
2498      return feature_tensor if rank != 1 else expand_dims(feature_tensor)
2499
2500    # Handle dynamic rank.
2501    with ops.control_dependencies([
2502        check_ops.assert_positive(
2503            array_ops.rank(feature_tensor),
2504            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
2505                key, feature_tensor))]):
2506      return control_flow_ops.cond(
2507          math_ops.equal(1, array_ops.rank(feature_tensor)),
2508          lambda: expand_dims(feature_tensor),
2509          lambda: feature_tensor)
2510
2511
2512# TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2513def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
2514  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
2515
2516  If `input_tensor` is already a `SparseTensor`, just return it.
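
  For example (a sketch in eager mode):

  ```python
  dense = tf.constant([['a', ''], ['b', 'c']])
  sp = _to_sparse_input_and_drop_ignore_values(dense)
  # sp.values == [b'a', b'b', b'c']; the '' cell is absent.
  ```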
2517
2518  Args:
2519    input_tensor: A string or integer `Tensor`.
    ignore_value: Entries in `input_tensor` equal to this value will be
      absent from the resulting `SparseTensor`. If `None`, the default value of
      `input_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
2523
2524  Returns:
2525    A `SparseTensor` with the same shape as `input_tensor`.
2526
2527  Raises:
2528    ValueError: when `input_tensor`'s rank is `None`.
2529  """
2530  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2531      input_tensor)
2532  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2533    return input_tensor
2534  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
2535    if ignore_value is None:
2536      if input_tensor.dtype == dtypes.string:
        # Special-cased because TF strings are converted to numpy objects by
        # default.
2538        ignore_value = ''
2539      elif input_tensor.dtype.is_integer:
2540        ignore_value = -1  # -1 has a special meaning of missing feature
2541      else:
2542        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
2543        # constructing a new numpy object of the given type, which yields the
2544        # default value for that type.
2545        ignore_value = input_tensor.dtype.as_numpy_dtype()
2546    ignore_value = math_ops.cast(
2547        ignore_value, input_tensor.dtype, name='ignore_value')
2548    indices = array_ops.where_v2(
2549        math_ops.not_equal(input_tensor, ignore_value), name='indices')
2550    return sparse_tensor_lib.SparseTensor(
2551        indices=indices,
2552        values=array_ops.gather_nd(input_tensor, indices, name='values'),
2553        dense_shape=array_ops.shape(
2554            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2555
2556
2557def _normalize_feature_columns(feature_columns):
2558  """Normalizes the `feature_columns` input.
2559
  This method converts `feature_columns` to a list as best as it can. In
  addition, it verifies the type and other properties of `feature_columns`
  required by downstream libraries.
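
  For example (a sketch): a single column is wrapped in a list, and the result
  is sorted by column name:

  ```python
  cols = _normalize_feature_columns(numeric_column('age'))
  # cols == [NumericColumn(key='age', ...)]
  ```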
2563
2564  Args:
2565    feature_columns: The raw feature columns, usually passed by users.
2566
2567  Returns:
2568    The normalized feature column list.
2569
2570  Raises:
2571    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
2572  """
2573  if isinstance(feature_columns, FeatureColumn):
2574    feature_columns = [feature_columns]
2575
2576  if isinstance(feature_columns, collections_abc.Iterator):
2577    feature_columns = list(feature_columns)
2578
2579  if isinstance(feature_columns, dict):
2580    raise ValueError('Expected feature_columns to be iterable, found dict.')
2581
2582  for column in feature_columns:
2583    if not isinstance(column, FeatureColumn):
2584      raise ValueError('Items of feature_columns must be a FeatureColumn. '
2585                       'Given (type {}): {}.'.format(type(column), column))
2586  if not feature_columns:
2587    raise ValueError('feature_columns must not be empty.')
2588  name_to_column = {}
2589  for column in feature_columns:
2590    if column.name in name_to_column:
2591      raise ValueError('Duplicate feature column name found for columns: {} '
2592                       'and {}. This usually means that these columns refer to '
2593                       'the same base feature. Either one must be discarded or '
2594                       'a duplicated but renamed item must be inserted in the '
2595                       'features dict.'.format(column,
2596                                               name_to_column[column.name]))
2597    name_to_column[column.name] = column
2598
2599  return sorted(feature_columns, key=lambda x: x.name)
2600
2601
2602class NumericColumn(
2603    DenseColumn,
2604    fc_old._DenseColumn,  # pylint: disable=protected-access
2605    collections.namedtuple(
2606        'NumericColumn',
2607        ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))):
2608  """see `numeric_column`."""
2609
2610  @property
2611  def _is_v2_column(self):
2612    return True
2613
2614  @property
2615  def name(self):
2616    """See `FeatureColumn` base class."""
2617    return self.key
2618
2619  @property
2620  def parse_example_spec(self):
2621    """See `FeatureColumn` base class."""
2622    return {
2623        self.key:
2624            parsing_ops.FixedLenFeature(self.shape, self.dtype,
2625                                        self.default_value)
2626    }
2627
2628  @property
2629  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2630                          _FEATURE_COLUMN_DEPRECATION)
2631  def _parse_example_spec(self):
2632    return self.parse_example_spec
2633
2634  def _transform_input_tensor(self, input_tensor):
2635    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2636      raise ValueError(
2637          'The corresponding Tensor of a numeric column must be a dense '
2638          'Tensor. SparseTensor is not supported. key: {}'.format(self.key))
2639    if self.normalizer_fn is not None:
2640      input_tensor = self.normalizer_fn(input_tensor)
2641    return math_ops.cast(input_tensor, dtypes.float32)
2642
2643  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2644                          _FEATURE_COLUMN_DEPRECATION)
2645  def _transform_feature(self, inputs):
2646    input_tensor = inputs.get(self.key)
2647    return self._transform_input_tensor(input_tensor)
2648
2649  def transform_feature(self, transformation_cache, state_manager):
2650    """See `FeatureColumn` base class.
2651
2652    In this case, we apply the `normalizer_fn` to the input tensor.
2653
2654    Args:
2655      transformation_cache: A `FeatureTransformationCache` object to access
2656        features.
2657      state_manager: A `StateManager` to create / access resources such as
2658        lookup tables.
2659
2660    Returns:
2661      Normalized input tensor.
2662    Raises:
2663      ValueError: If a SparseTensor is passed in.
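
    For example (a sketch with a hypothetical 'age' feature):

    ```python
    column = numeric_column('age', normalizer_fn=lambda x: (x - 50.) / 10.)
    # A batch [[40.], [60.]] is transformed to [[-1.], [1.]].
    ```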
2664    """
2665    input_tensor = transformation_cache.get(self.key, state_manager)
2666    return self._transform_input_tensor(input_tensor)
2667
2668  @property
2669  def variable_shape(self):
2670    """See `DenseColumn` base class."""
2671    return tensor_shape.TensorShape(self.shape)
2672
2673  @property
2674  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2675                          _FEATURE_COLUMN_DEPRECATION)
2676  def _variable_shape(self):
2677    return self.variable_shape
2678
2679  def get_dense_tensor(self, transformation_cache, state_manager):
2680    """Returns dense `Tensor` representing numeric feature.
2681
2682    Args:
2683      transformation_cache: A `FeatureTransformationCache` object to access
2684        features.
2685      state_manager: A `StateManager` to create / access resources such as
2686        lookup tables.
2687
2688    Returns:
2689      Dense `Tensor` created within `transform_feature`.
2690    """
2691    # Feature has already been transformed. Return the intermediate
2692    # representation created by transform_feature.
2693    return transformation_cache.get(self, state_manager)
2694
2695  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2696                          _FEATURE_COLUMN_DEPRECATION)
2697  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2698    del weight_collections
2699    del trainable
2700    return inputs.get(self)
2701
2702  @property
2703  def parents(self):
2704    """See 'FeatureColumn` base class."""
2705    return [self.key]
2706
2707  def get_config(self):
2708    """See 'FeatureColumn` base class."""
2709    config = dict(zip(self._fields, self))
2710    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
2711    config['normalizer_fn'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
2712        self.normalizer_fn)
2713    config['dtype'] = self.dtype.name
2714    return config
2715
2716  @classmethod
2717  def from_config(cls, config, custom_objects=None, columns_by_name=None):
2718    """See 'FeatureColumn` base class."""
2719    _check_config_keys(config, cls._fields)
2720    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
2721    kwargs = _standardize_and_copy_config(config)
2722    kwargs['normalizer_fn'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
2723        config['normalizer_fn'], custom_objects=custom_objects)
2724    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
2725
2726    return cls(**kwargs)
2727
2728
2729class BucketizedColumn(
2730    DenseColumn,
2731    CategoricalColumn,
2732    fc_old._DenseColumn,  # pylint: disable=protected-access
2733    fc_old._CategoricalColumn,  # pylint: disable=protected-access
2734    collections.namedtuple('BucketizedColumn',
2735                           ('source_column', 'boundaries'))):
2736  """See `bucketized_column`."""
2737
2738  @property
2739  def _is_v2_column(self):
2740    return (isinstance(self.source_column, FeatureColumn) and
2741            self.source_column._is_v2_column)  # pylint: disable=protected-access
2742
2743  @property
2744  def name(self):
2745    """See `FeatureColumn` base class."""
2746    return '{}_bucketized'.format(self.source_column.name)
2747
2748  @property
2749  def parse_example_spec(self):
2750    """See `FeatureColumn` base class."""
2751    return self.source_column.parse_example_spec
2752
2753  @property
2754  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2755                          _FEATURE_COLUMN_DEPRECATION)
2756  def _parse_example_spec(self):
2757    return self.source_column._parse_example_spec  # pylint: disable=protected-access
2758
2759  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2760                          _FEATURE_COLUMN_DEPRECATION)
2761  def _transform_feature(self, inputs):
2762    """Returns bucketized categorical `source_column` tensor."""
2763    source_tensor = inputs.get(self.source_column)
2764    return math_ops._bucketize(  # pylint: disable=protected-access
2765        source_tensor,
2766        boundaries=self.boundaries)
2767
2768  def transform_feature(self, transformation_cache, state_manager):
2769    """Returns bucketized categorical `source_column` tensor."""
2770    source_tensor = transformation_cache.get(self.source_column, state_manager)
2771    return math_ops._bucketize(  # pylint: disable=protected-access
2772        source_tensor,
2773        boundaries=self.boundaries)
2774
2775  @property
2776  def variable_shape(self):
2777    """See `DenseColumn` base class."""
2778    return tensor_shape.TensorShape(
2779        tuple(self.source_column.shape) + (len(self.boundaries) + 1,))
2780
2781  @property
2782  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2783                          _FEATURE_COLUMN_DEPRECATION)
2784  def _variable_shape(self):
2785    return self.variable_shape
2786
2787  def _get_dense_tensor_for_input_tensor(self, input_tensor):
2788    return array_ops.one_hot(
2789        indices=math_ops.cast(input_tensor, dtypes.int64),
2790        depth=len(self.boundaries) + 1,
2791        on_value=1.,
2792        off_value=0.)
2793
2794  def get_dense_tensor(self, transformation_cache, state_manager):
2795    """Returns one hot encoded dense `Tensor`."""
2796    input_tensor = transformation_cache.get(self, state_manager)
2797    return self._get_dense_tensor_for_input_tensor(input_tensor)
2798
2799  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2800                          _FEATURE_COLUMN_DEPRECATION)
2801  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
2802    del weight_collections
2803    del trainable
2804    input_tensor = inputs.get(self)
2805    return self._get_dense_tensor_for_input_tensor(input_tensor)
2806
2807  @property
2808  def num_buckets(self):
2809    """See `CategoricalColumn` base class."""
2810    # By construction, source_column is always one-dimensional.
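    # E.g. boundaries=[0, 10] gives len(boundaries) + 1 == 3 buckets per
    # element, so a source column of shape (2,) yields 3 * 2 == 6 id buckets.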
2811    return (len(self.boundaries) + 1) * self.source_column.shape[0]
2812
2813  @property
2814  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2815                          _FEATURE_COLUMN_DEPRECATION)
2816  def _num_buckets(self):
2817    return self.num_buckets
2818
2819  def _get_sparse_tensors_for_input_tensor(self, input_tensor):
2820    batch_size = array_ops.shape(input_tensor)[0]
2821    # By construction, source_column is always one-dimensional.
2822    source_dimension = self.source_column.shape[0]
2823
2824    i1 = array_ops.reshape(
2825        array_ops.tile(
2826            array_ops.expand_dims(math_ops.range(0, batch_size), 1),
2827            [1, source_dimension]),
2828        (-1,))
2829    i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size])
2830    # Flatten the bucket indices and make them unique across dimensions,
2831    # e.g. with k buckets, 2nd-dimension indices range from k to 2*k-1.
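    # Worked example (sketch): with 3 buckets and source_dimension == 2, a
    # bucketized batch [[0, 2], [1, 0]] flattens to values [0, 5, 1, 3] at
    # indices [[0, 0], [0, 1], [1, 0], [1, 1]].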
2832    bucket_indices = (
2833        array_ops.reshape(input_tensor, (-1,)) +
2834        (len(self.boundaries) + 1) * i2)
2835
2836    indices = math_ops.cast(
2837        array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64)
2838    dense_shape = math_ops.cast(
2839        array_ops.stack([batch_size, source_dimension]), dtypes.int64)
2840    sparse_tensor = sparse_tensor_lib.SparseTensor(
2841        indices=indices,
2842        values=bucket_indices,
2843        dense_shape=dense_shape)
2844    return CategoricalColumn.IdWeightPair(sparse_tensor, None)
2845
2846  def get_sparse_tensors(self, transformation_cache, state_manager):
2847    """Converts dense inputs to SparseTensor so downstream code can use it."""
2848    input_tensor = transformation_cache.get(self, state_manager)
2849    return self._get_sparse_tensors_for_input_tensor(input_tensor)
2850
2851  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2852                          _FEATURE_COLUMN_DEPRECATION)
2853  def _get_sparse_tensors(self, inputs, weight_collections=None,
2854                          trainable=None):
2855    """Converts dense inputs to SparseTensor so downstream code can use it."""
2856    del weight_collections
2857    del trainable
2858    input_tensor = inputs.get(self)
2859    return self._get_sparse_tensors_for_input_tensor(input_tensor)
2860
2861  @property
2862  def parents(self):
2863    """See 'FeatureColumn` base class."""
2864    return [self.source_column]
2865
2866  def get_config(self):
2867    """See 'FeatureColumn` base class."""
2868    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
2869    config = dict(zip(self._fields, self))
2870    config['source_column'] = serialize_feature_column(self.source_column)
2871    return config
2872
2873  @classmethod
2874  def from_config(cls, config, custom_objects=None, columns_by_name=None):
2875    """See 'FeatureColumn` base class."""
2876    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
2877    _check_config_keys(config, cls._fields)
2878    kwargs = _standardize_and_copy_config(config)
2879    kwargs['source_column'] = deserialize_feature_column(
2880        config['source_column'], custom_objects, columns_by_name)
2881    return cls(**kwargs)
2882
2883
2884class EmbeddingColumn(
2885    DenseColumn,
2886    SequenceDenseColumn,
2887    fc_old._DenseColumn,  # pylint: disable=protected-access
2888    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
2889    collections.namedtuple(
2890        'EmbeddingColumn',
2891        ('categorical_column', 'dimension', 'combiner', 'initializer',
2892         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
2893         'use_safe_embedding_lookup'))):
2894  """See `embedding_column`."""
2895
2896  def __new__(cls,
2897              categorical_column,
2898              dimension,
2899              combiner,
2900              initializer,
2901              ckpt_to_load_from,
2902              tensor_name_in_ckpt,
2903              max_norm,
2904              trainable,
2905              use_safe_embedding_lookup=True):
2906    return super(EmbeddingColumn, cls).__new__(
2907        cls,
2908        categorical_column=categorical_column,
2909        dimension=dimension,
2910        combiner=combiner,
2911        initializer=initializer,
2912        ckpt_to_load_from=ckpt_to_load_from,
2913        tensor_name_in_ckpt=tensor_name_in_ckpt,
2914        max_norm=max_norm,
2915        trainable=trainable,
2916        use_safe_embedding_lookup=use_safe_embedding_lookup)
2917
2918  @property
2919  def _is_v2_column(self):
2920    return (isinstance(self.categorical_column, FeatureColumn) and
2921            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
2922
2923  @property
2924  def name(self):
2925    """See `FeatureColumn` base class."""
2926    return '{}_embedding'.format(self.categorical_column.name)
2927
2928  @property
2929  def parse_example_spec(self):
2930    """See `FeatureColumn` base class."""
2931    return self.categorical_column.parse_example_spec
2932
2933  @property
2934  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2935                          _FEATURE_COLUMN_DEPRECATION)
2936  def _parse_example_spec(self):
2937    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access
2938
2939  def transform_feature(self, transformation_cache, state_manager):
2940    """Transforms underlying `categorical_column`."""
2941    return transformation_cache.get(self.categorical_column, state_manager)
2942
2943  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2944                          _FEATURE_COLUMN_DEPRECATION)
2945  def _transform_feature(self, inputs):
2946    return inputs.get(self.categorical_column)
2947
2948  @property
2949  def variable_shape(self):
2950    """See `DenseColumn` base class."""
2951    return tensor_shape.TensorShape([self.dimension])
2952
2953  @property
2954  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
2955                          _FEATURE_COLUMN_DEPRECATION)
2956  def _variable_shape(self):
2957    return self.variable_shape
2958
2959  def create_state(self, state_manager):
2960    """Creates the embedding lookup variable."""
2961    default_num_buckets = (self.categorical_column.num_buckets
2962                           if self._is_v2_column
2963                           else self.categorical_column._num_buckets)   # pylint: disable=protected-access
2964    num_buckets = getattr(self.categorical_column, 'num_buckets',
2965                          default_num_buckets)
2966    embedding_shape = (num_buckets, self.dimension)
2967    state_manager.create_variable(
2968        self,
2969        name='embedding_weights',
2970        shape=embedding_shape,
2971        dtype=dtypes.float32,
2972        trainable=self.trainable,
2973        use_resource=True,
2974        initializer=self.initializer)
2975
2976  def _get_dense_tensor_internal_helper(self, sparse_tensors,
2977                                        embedding_weights):
2978    sparse_ids = sparse_tensors.id_tensor
2979    sparse_weights = sparse_tensors.weight_tensor
2980
2981    if self.ckpt_to_load_from is not None:
2982      to_restore = embedding_weights
2983      if isinstance(to_restore, variables.PartitionedVariable):
2984        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
2985      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
2986          self.tensor_name_in_ckpt: to_restore
2987      })
2988
2989    sparse_id_rank = tensor_shape.dimension_value(
2990        sparse_ids.dense_shape.get_shape()[0])
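    # safe_embedding_lookup_sparse additionally prunes invalid ids and fills
    # empty rows with zeros; the plain embedding_lookup_sparse_v2 path is only
    # taken when the caller opted out of those checks and the id tensor has
    # rank <= 2.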
2991    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
2992    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
2993        sparse_id_rank <= 2):
2994      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
2995    # Return embedding lookup result.
2996    return embedding_lookup_sparse(
2997        embedding_weights,
2998        sparse_ids,
2999        sparse_weights,
3000        combiner=self.combiner,
3001        name='%s_weights' % self.name,
3002        max_norm=self.max_norm)
3003
3004  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
3005    """Private method that follows the signature of get_dense_tensor."""
3006    embedding_weights = state_manager.get_variable(
3007        self, name='embedding_weights')
3008    return self._get_dense_tensor_internal_helper(sparse_tensors,
3009                                                  embedding_weights)
3010
3011  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
3012                                     trainable):
3013    """Private method that follows the signature of _get_dense_tensor."""
3014    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
3015    if (weight_collections and
3016        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
3017      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
3018    embedding_weights = variable_scope.get_variable(
3019        name='embedding_weights',
3020        shape=embedding_shape,
3021        dtype=dtypes.float32,
3022        initializer=self.initializer,
3023        trainable=self.trainable and trainable,
3024        collections=weight_collections)
3025    return self._get_dense_tensor_internal_helper(sparse_tensors,
3026                                                  embedding_weights)
3027
3028  def get_dense_tensor(self, transformation_cache, state_manager):
3029    """Returns tensor after doing the embedding lookup.
3030
3031    Args:
3032      transformation_cache: A `FeatureTransformationCache` object to access
3033        features.
3034      state_manager: A `StateManager` to create / access resources such as
3035        lookup tables.
3036
3037    Returns:
3038      Embedding lookup tensor.
3039
3040    Raises:
3041      ValueError: `categorical_column` is SequenceCategoricalColumn.
3042    """
3043    if isinstance(self.categorical_column, SequenceCategoricalColumn):
3044      raise ValueError(
3045          'In embedding_column: {}. '
3046          'categorical_column must not be of type SequenceCategoricalColumn. '
3047          'Suggested fix A: If you wish to use DenseFeatures, use a '
3048          'non-sequence categorical_column_with_*. '
3049          'Suggested fix B: If you wish to create sequence input, use '
3050          'SequenceFeatures instead of DenseFeatures. '
3051          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3052                                       self.categorical_column))
3053    # Get sparse IDs and weights.
3054    sparse_tensors = self.categorical_column.get_sparse_tensors(
3055        transformation_cache, state_manager)
3056    return self._get_dense_tensor_internal(sparse_tensors, state_manager)
3057
3058  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3059                          _FEATURE_COLUMN_DEPRECATION)
3060  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3061    if isinstance(
3062        self.categorical_column,
3063        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
3064      raise ValueError(
3065          'In embedding_column: {}. '
3066          'categorical_column must not be of type _SequenceCategoricalColumn. '
3067          'Suggested fix A: If you wish to use DenseFeatures, use a '
3068          'non-sequence categorical_column_with_*. '
3069          'Suggested fix B: If you wish to create sequence input, use '
3070          'SequenceFeatures instead of DenseFeatures. '
3071          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3072                                       self.categorical_column))
3073    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
3074        inputs, weight_collections, trainable)
3075    return self._old_get_dense_tensor_internal(sparse_tensors,
3076                                               weight_collections, trainable)
3077
3078  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3079    """See `SequenceDenseColumn` base class."""
3080    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3081      raise ValueError(
3082          'In embedding_column: {}. '
3083          'categorical_column must be of type SequenceCategoricalColumn '
3084          'to use SequenceFeatures. '
3085          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3086          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3087                                       self.categorical_column))
3088    sparse_tensors = self.categorical_column.get_sparse_tensors(
3089        transformation_cache, state_manager)
3090    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
3091                                                   state_manager)
3092    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3093        sparse_tensors.id_tensor)
3094    return SequenceDenseColumn.TensorSequenceLengthPair(
3095        dense_tensor=dense_tensor, sequence_length=sequence_length)
3096
3097  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3098                          _FEATURE_COLUMN_DEPRECATION)
3099  def _get_sequence_dense_tensor(self,
3100                                 inputs,
3101                                 weight_collections=None,
3102                                 trainable=None):
3103    if not isinstance(
3104        self.categorical_column,
3105        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
3106      raise ValueError(
3107          'In embedding_column: {}. '
3108          'categorical_column must be of type SequenceCategoricalColumn '
3109          'to use SequenceFeatures. '
3110          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3111          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3112                                       self.categorical_column))
3113    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
3114    dense_tensor = self._old_get_dense_tensor_internal(
3115        sparse_tensors,
3116        weight_collections=weight_collections,
3117        trainable=trainable)
3118    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3119        sparse_tensors.id_tensor)
3120    return SequenceDenseColumn.TensorSequenceLengthPair(
3121        dense_tensor=dense_tensor, sequence_length=sequence_length)
3122
3123  @property
3124  def parents(self):
3125    """See 'FeatureColumn` base class."""
3126    return [self.categorical_column]
3127
3128  def get_config(self):
3129    """See 'FeatureColumn` base class."""
3130    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
3131    config = dict(zip(self._fields, self))
3132    config['categorical_column'] = serialization.serialize_feature_column(
3133        self.categorical_column)
3134    config['initializer'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
3135        self.initializer)
3136    return config
3137
3138  @classmethod
3139  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3140    """See 'FeatureColumn` base class."""
3141    if 'use_safe_embedding_lookup' not in config:
3142      config['use_safe_embedding_lookup'] = True
3143    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
3144    _check_config_keys(config, cls._fields)
3145    kwargs = _standardize_and_copy_config(config)
3146    kwargs['categorical_column'] = serialization.deserialize_feature_column(
3147        config['categorical_column'], custom_objects, columns_by_name)
3148    all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
3149    kwargs['initializer'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
3150        config['initializer'],
3151        module_objects=all_initializers,
3152        custom_objects=custom_objects)
3153    return cls(**kwargs)
3154
3155
3156def _raise_shared_embedding_column_error():
3157  raise ValueError('SharedEmbeddingColumns are not supported in '
3158                   '`linear_model` or `input_layer`. Please use '
3159                   '`DenseFeatures` or `LinearModel` instead.')
3160
3161
3162class SharedEmbeddingColumnCreator(autotrackable.AutoTrackable):
3163  """Class that creates a `SharedEmbeddingColumn`."""
3164
3165  def __init__(self,
3166               dimension,
3167               initializer,
3168               ckpt_to_load_from,
3169               tensor_name_in_ckpt,
3170               num_buckets,
3171               trainable,
3172               name='shared_embedding_column_creator',
3173               use_safe_embedding_lookup=True):
3174    self._dimension = dimension
3175    self._initializer = initializer
3176    self._ckpt_to_load_from = ckpt_to_load_from
3177    self._tensor_name_in_ckpt = tensor_name_in_ckpt
3178    self._num_buckets = num_buckets
3179    self._trainable = trainable
3180    self._name = name
3181    self._use_safe_embedding_lookup = use_safe_embedding_lookup
3182    # Map from graph keys to embedding_weight variables.
3183    self._embedding_weights = {}
3184
3185  def __call__(self, categorical_column, combiner, max_norm):
3186    return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm,
3187                                 self._use_safe_embedding_lookup)
3188
3189  @property
3190  def embedding_weights(self):
3191    key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
3192    if key not in self._embedding_weights:
3193      embedding_shape = (self._num_buckets, self._dimension)
3194      var = variable_scope.get_variable(
3195          name=self._name,
3196          shape=embedding_shape,
3197          dtype=dtypes.float32,
3198          initializer=self._initializer,
3199          trainable=self._trainable)
3200
3201      if self._ckpt_to_load_from is not None:
3202        to_restore = var
3203        if isinstance(to_restore, variables.PartitionedVariable):
3204          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
3205        checkpoint_utils.init_from_checkpoint(
3206            self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore})
3207      self._embedding_weights[key] = var
3208    return self._embedding_weights[key]
3209
3210  @property
3211  def dimension(self):
3212    return self._dimension
3213
3214
3215class SharedEmbeddingColumn(
3216    DenseColumn,
3217    SequenceDenseColumn,
3218    fc_old._DenseColumn,  # pylint: disable=protected-access
3219    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
3220    collections.namedtuple(
3221        'SharedEmbeddingColumn',
3222        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
3223         'max_norm', 'use_safe_embedding_lookup'))):
3224  """See `embedding_column`."""
3225
3226  def __new__(cls,
3227              categorical_column,
3228              shared_embedding_column_creator,
3229              combiner,
3230              max_norm,
3231              use_safe_embedding_lookup=True):
3232    return super(SharedEmbeddingColumn, cls).__new__(
3233        cls,
3234        categorical_column=categorical_column,
3235        shared_embedding_column_creator=shared_embedding_column_creator,
3236        combiner=combiner,
3237        max_norm=max_norm,
3238        use_safe_embedding_lookup=use_safe_embedding_lookup)
3239
3240  @property
3241  def _is_v2_column(self):
3242    return True
3243
3244  @property
3245  def name(self):
3246    """See `FeatureColumn` base class."""
3247    return '{}_shared_embedding'.format(self.categorical_column.name)
3248
3249  @property
3250  def parse_example_spec(self):
3251    """See `FeatureColumn` base class."""
3252    return self.categorical_column.parse_example_spec
3253
3254  @property
3255  def _parse_example_spec(self):
3256    return _raise_shared_embedding_column_error()
3257
3258  def transform_feature(self, transformation_cache, state_manager):
3259    """See `FeatureColumn` base class."""
3260    return transformation_cache.get(self.categorical_column, state_manager)
3261
3262  def _transform_feature(self, inputs):
3263    return _raise_shared_embedding_column_error()
3264
3265  @property
3266  def variable_shape(self):
3267    """See `DenseColumn` base class."""
3268    return tensor_shape.TensorShape(
3269        [self.shared_embedding_column_creator.dimension])
3270
3271  @property
3272  def _variable_shape(self):
3273    return _raise_shared_embedding_column_error()
3274
3275  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
3276    """Private method that follows the signature of _get_dense_tensor."""
3277    # This method is called from a variable_scope with name _var_scope_name,
3278    # which is shared among all shared embeddings. Open a name_scope here, so
3279    # that the ops for different columns have distinct names.
3280    with ops.name_scope(None, default_name=self.name):
3281      # Get sparse IDs and weights.
3282      sparse_tensors = self.categorical_column.get_sparse_tensors(
3283          transformation_cache, state_manager)
3284      sparse_ids = sparse_tensors.id_tensor
3285      sparse_weights = sparse_tensors.weight_tensor
3286
3287      embedding_weights = self.shared_embedding_column_creator.embedding_weights
3288
3289      sparse_id_rank = tensor_shape.dimension_value(
3290          sparse_ids.dense_shape.get_shape()[0])
3291      embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
3292      if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
3293          sparse_id_rank <= 2):
3294        embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
3295      # Return embedding lookup result.
3296      return embedding_lookup_sparse(
3297          embedding_weights,
3298          sparse_ids,
3299          sparse_weights,
3300          combiner=self.combiner,
3301          name='%s_weights' % self.name,
3302          max_norm=self.max_norm)
3303
3304  def get_dense_tensor(self, transformation_cache, state_manager):
3305    """Returns the embedding lookup result."""
3306    if isinstance(self.categorical_column, SequenceCategoricalColumn):
3307      raise ValueError(
3308          'In embedding_column: {}. '
3309          'categorical_column must not be of type SequenceCategoricalColumn. '
3310          'Suggested fix A: If you wish to use DenseFeatures, use a '
3311          'non-sequence categorical_column_with_*. '
3312          'Suggested fix B: If you wish to create sequence input, use '
3313          'SequenceFeatures instead of DenseFeatures. '
3314          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3315                                       self.categorical_column))
3316    return self._get_dense_tensor_internal(transformation_cache, state_manager)
3317
3318  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
3319    return _raise_shared_embedding_column_error()
3320
3321  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
3322    """See `SequenceDenseColumn` base class."""
3323    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
3324      raise ValueError(
3325          'In embedding_column: {}. '
3326          'categorical_column must be of type SequenceCategoricalColumn '
3327          'to use SequenceFeatures. '
3328          'Suggested fix: Use one of sequence_categorical_column_with_*. '
3329          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
3330                                       self.categorical_column))
3331    dense_tensor = self._get_dense_tensor_internal(transformation_cache,
3332                                                   state_manager)
3333    sparse_tensors = self.categorical_column.get_sparse_tensors(
3334        transformation_cache, state_manager)
3335    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
3336        sparse_tensors.id_tensor)
3337    return SequenceDenseColumn.TensorSequenceLengthPair(
3338        dense_tensor=dense_tensor, sequence_length=sequence_length)
3339
3340  def _get_sequence_dense_tensor(self,
3341                                 inputs,
3342                                 weight_collections=None,
3343                                 trainable=None):
3344    return _raise_shared_embedding_column_error()
3345
3346  @property
3347  def parents(self):
3348    """See 'FeatureColumn` base class."""
3349    return [self.categorical_column]
3350
3351
3352def _check_shape(shape, key):
3353  """Returns shape if it's valid, raises error otherwise."""
3354  assert shape is not None
3355  if not nest.is_nested(shape):
3356    shape = [shape]
3357  shape = tuple(shape)
3358  for dimension in shape:
3359    if not isinstance(dimension, int):
3360      raise TypeError('shape dimensions must be integers. '
3361                      'shape: {}, key: {}'.format(shape, key))
3362    if dimension < 1:
3363      raise ValueError('shape dimensions must be greater than 0. '
3364                       'shape: {}, key: {}'.format(shape, key))
3365  return shape
3366
3367
3368class HashedCategoricalColumn(
3369    CategoricalColumn,
3370    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3371    collections.namedtuple('HashedCategoricalColumn',
3372                           ('key', 'hash_bucket_size', 'dtype'))):
3373  """see `categorical_column_with_hash_bucket`."""
3374
3375  @property
3376  def _is_v2_column(self):
3377    return True
3378
3379  @property
3380  def name(self):
3381    """See `FeatureColumn` base class."""
3382    return self.key
3383
3384  @property
3385  def parse_example_spec(self):
3386    """See `FeatureColumn` base class."""
3387    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3388
3389  @property
3390  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3391                          _FEATURE_COLUMN_DEPRECATION)
3392  def _parse_example_spec(self):
3393    return self.parse_example_spec
3394
3395  def _transform_input_tensor(self, input_tensor):
3396    """Hashes the values in the feature_column."""
3397    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
3398      raise ValueError('HashedCategoricalColumn input must be a SparseTensor.')
3399
3400    fc_utils.assert_string_or_int(
3401        input_tensor.dtype,
3402        prefix='column_name: {} input_tensor'.format(self.key))
3403
3404    if self.dtype.is_integer != input_tensor.dtype.is_integer:
3405      raise ValueError(
3406          'Column dtype and SparseTensors dtype must be compatible. '
3407          'key: {}, column dtype: {}, tensor dtype: {}'.format(
3408              self.key, self.dtype, input_tensor.dtype))
3409
3410    if self.dtype == dtypes.string:
3411      sparse_values = input_tensor.values
3412    else:
3413      sparse_values = string_ops.as_string(input_tensor.values)
3414
3415    sparse_id_values = string_ops.string_to_hash_bucket_fast(
3416        sparse_values, self.hash_bucket_size, name='lookup')
3417    return sparse_tensor_lib.SparseTensor(
3418        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
3419
3420  def transform_feature(self, transformation_cache, state_manager):
3421    """Hashes the values in the feature_column."""
3422    input_tensor = _to_sparse_input_and_drop_ignore_values(
3423        transformation_cache.get(self.key, state_manager))
3424    return self._transform_input_tensor(input_tensor)
3425
3426  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3427                          _FEATURE_COLUMN_DEPRECATION)
3428  def _transform_feature(self, inputs):
3429    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3430    return self._transform_input_tensor(input_tensor)
3431
3432  @property
3433  def num_buckets(self):
3434    """Returns number of buckets in this sparse feature."""
3435    return self.hash_bucket_size
3436
3437  @property
3438  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3439                          _FEATURE_COLUMN_DEPRECATION)
3440  def _num_buckets(self):
3441    return self.num_buckets
3442
3443  def get_sparse_tensors(self, transformation_cache, state_manager):
3444    """See `CategoricalColumn` base class."""
3445    return CategoricalColumn.IdWeightPair(
3446        transformation_cache.get(self, state_manager), None)
3447
3448  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3449                          _FEATURE_COLUMN_DEPRECATION)
3450  def _get_sparse_tensors(self, inputs, weight_collections=None,
3451                          trainable=None):
3452    del weight_collections
3453    del trainable
3454    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3455
3456  @property
3457  def parents(self):
3458    """See 'FeatureColumn` base class."""
3459    return [self.key]
3460
3461  def get_config(self):
3462    """See 'FeatureColumn` base class."""
3463    config = dict(zip(self._fields, self))
3464    config['dtype'] = self.dtype.name
3465    return config
3466
3467  @classmethod
3468  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3469    """See 'FeatureColumn` base class."""
3470    _check_config_keys(config, cls._fields)
3471    kwargs = _standardize_and_copy_config(config)
3472    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3473    return cls(**kwargs)
3474
3475
3476class VocabularyFileCategoricalColumn(
3477    CategoricalColumn,
3478    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3479    collections.namedtuple(
3480        'VocabularyFileCategoricalColumn',
3481        ('key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets',
3482         'dtype', 'default_value', 'file_format'))):
3483  """See `categorical_column_with_vocabulary_file`."""
3484
3485  def __new__(cls,
3486              key,
3487              vocabulary_file,
3488              vocabulary_size,
3489              num_oov_buckets,
3490              dtype,
3491              default_value,
3492              file_format=None):
3493    return super(VocabularyFileCategoricalColumn, cls).__new__(
3494        cls,
3495        key=key,
3496        vocabulary_file=vocabulary_file,
3497        vocabulary_size=vocabulary_size,
3498        num_oov_buckets=num_oov_buckets,
3499        dtype=dtype,
3500        default_value=default_value,
3501        file_format=file_format)
3502
3503  @property
3504  def _is_v2_column(self):
3505    return True
3506
3507  @property
3508  def name(self):
3509    """See `FeatureColumn` base class."""
3510    return self.key
3511
3512  @property
3513  def parse_example_spec(self):
3514    """See `FeatureColumn` base class."""
3515    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3516
3517  @property
3518  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3519                          _FEATURE_COLUMN_DEPRECATION)
3520  def _parse_example_spec(self):
3521    return self.parse_example_spec
3522
3523  def _make_table_from_tfrecord_gzip_file(self, key_dtype, name):
3524    dataset = readers.TFRecordDataset(
3525        self.vocabulary_file, compression_type='GZIP')
3526
3527    def key_dtype_fn(key):
3528      return key if key_dtype is dtypes.string else string_ops.string_to_number(
3529          key, out_type=key_dtype)
3530
3531    return data_lookup_ops.index_table_from_dataset(
3532        dataset.map(key_dtype_fn),
3533        num_oov_buckets=self.num_oov_buckets,
3534        vocab_size=self.vocabulary_size,
3535        default_value=self.default_value,
3536        key_dtype=key_dtype,
3537        name=name)
3538
3539  def _make_table(self, key_dtype, state_manager):
3540    name = '{}_lookup'.format(self.key)
3541    if state_manager is None or not state_manager.has_resource(self, name):
3542      with ops.init_scope():
3543        if self.file_format == 'tfrecord_gzip':
3544          table = self._make_table_from_tfrecord_gzip_file(key_dtype, name)
3545        else:
3546          table = lookup_ops.index_table_from_file(
3547              vocabulary_file=self.vocabulary_file,
3548              num_oov_buckets=self.num_oov_buckets,
3549              vocab_size=self.vocabulary_size,
3550              default_value=self.default_value,
3551              key_dtype=key_dtype,
3552              name=name)
3553      if state_manager is not None:
3554        state_manager.add_resource(self, name, table)
3555    else:
3556      # Reuse the table already created for this column.
3557      table = state_manager.get_resource(self, name)
3558    return table
3559
3560  def _transform_input_tensor(self, input_tensor, state_manager=None):
3561    """Creates a lookup table for the vocabulary."""
3562    if self.dtype.is_integer != input_tensor.dtype.is_integer:
3563      raise ValueError(
3564          'Column dtype and SparseTensors dtype must be compatible. '
3565          'key: {}, column dtype: {}, tensor dtype: {}'.format(
3566              self.key, self.dtype, input_tensor.dtype))
3567
3568    fc_utils.assert_string_or_int(
3569        input_tensor.dtype,
3570        prefix='column_name: {} input_tensor'.format(self.key))
3571
3572    key_dtype = self.dtype
3573    if input_tensor.dtype.is_integer:
3574      # `index_table_from_file` requires 64-bit integer keys.
3575      key_dtype = dtypes.int64
3576      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
3577    return self._make_table(key_dtype, state_manager).lookup(input_tensor)
3578
3579  def transform_feature(self, transformation_cache, state_manager):
3580    """Creates a lookup table for the vocabulary."""
3581    input_tensor = _to_sparse_input_and_drop_ignore_values(
3582        transformation_cache.get(self.key, state_manager))
3583    return self._transform_input_tensor(input_tensor, state_manager)
3584
3585  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3586                          _FEATURE_COLUMN_DEPRECATION)
3587  def _transform_feature(self, inputs):
3588    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3589    return self._transform_input_tensor(input_tensor)
3590
3591  @property
3592  def num_buckets(self):
3593    """Returns number of buckets in this sparse feature."""
3594    return self.vocabulary_size + self.num_oov_buckets
3595
3596  @property
3597  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3598                          _FEATURE_COLUMN_DEPRECATION)
3599  def _num_buckets(self):
3600    return self.num_buckets
3601
3602  def get_sparse_tensors(self, transformation_cache, state_manager):
3603    """See `CategoricalColumn` base class."""
3604    return CategoricalColumn.IdWeightPair(
3605        transformation_cache.get(self, state_manager), None)
3606
3607  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3608                          _FEATURE_COLUMN_DEPRECATION)
3609  def _get_sparse_tensors(self, inputs, weight_collections=None,
3610                          trainable=None):
3611    del weight_collections
3612    del trainable
3613    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3614
3615  @property
3616  def parents(self):
3617    """See 'FeatureColumn` base class."""
3618    return [self.key]
3619
3620  def get_config(self):
3621    """See 'FeatureColumn` base class."""
3622    config = dict(zip(self._fields, self))
3623    config['dtype'] = self.dtype.name
3624    return config
3625
3626  @classmethod
3627  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3628    """See 'FeatureColumn` base class."""
3629    _check_config_keys(config, cls._fields)
3630    kwargs = _standardize_and_copy_config(config)
3631    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3632    return cls(**kwargs)
3633
3634
3635class VocabularyListCategoricalColumn(
3636    CategoricalColumn,
3637    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3638    collections.namedtuple(
3639        'VocabularyListCategoricalColumn',
3640        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
3641):
3642  """See `categorical_column_with_vocabulary_list`."""
3643
3644  @property
3645  def _is_v2_column(self):
3646    return True
3647
3648  @property
3649  def name(self):
3650    """See `FeatureColumn` base class."""
3651    return self.key
3652
3653  @property
3654  def parse_example_spec(self):
3655    """See `FeatureColumn` base class."""
3656    return {self.key: parsing_ops.VarLenFeature(self.dtype)}
3657
3658  @property
3659  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3660                          _FEATURE_COLUMN_DEPRECATION)
3661  def _parse_example_spec(self):
3662    return self.parse_example_spec
3663
3664  def _transform_input_tensor(self, input_tensor, state_manager=None):
3665    """Creates a lookup table for the vocabulary list."""
3666    if self.dtype.is_integer != input_tensor.dtype.is_integer:
3667      raise ValueError(
3668          'Column dtype and SparseTensors dtype must be compatible. '
3669          'key: {}, column dtype: {}, tensor dtype: {}'.format(
3670              self.key, self.dtype, input_tensor.dtype))
3671
3672    fc_utils.assert_string_or_int(
3673        input_tensor.dtype,
3674        prefix='column_name: {} input_tensor'.format(self.key))
3675
3676    key_dtype = self.dtype
3677    if input_tensor.dtype.is_integer:
3678      # `index_table_from_tensor` requires 64-bit integer keys.
3679      key_dtype = dtypes.int64
3680      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
3681
3682    name = '{}_lookup'.format(self.key)
3683    if state_manager is None or not state_manager.has_resource(self, name):
3684      with ops.init_scope():
3685        table = lookup_ops.index_table_from_tensor(
3686            vocabulary_list=tuple(self.vocabulary_list),
3687            default_value=self.default_value,
3688            num_oov_buckets=self.num_oov_buckets,
3689            dtype=key_dtype,
3690            name=name)
3691      if state_manager is not None:
3692        state_manager.add_resource(self, name, table)
3693    else:
3694      # Reuse the table already created for this column.
3695      table = state_manager.get_resource(self, name)
3696    return table.lookup(input_tensor)
3697
3698  def transform_feature(self, transformation_cache, state_manager):
3699    """Creates a lookup table for the vocabulary list."""
3700    input_tensor = _to_sparse_input_and_drop_ignore_values(
3701        transformation_cache.get(self.key, state_manager))
3702    return self._transform_input_tensor(input_tensor, state_manager)
3703
3704  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3705                          _FEATURE_COLUMN_DEPRECATION)
3706  def _transform_feature(self, inputs):
3707    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3708    return self._transform_input_tensor(input_tensor)
3709
3710  @property
3711  def num_buckets(self):
3712    """Returns number of buckets in this sparse feature."""
3713    return len(self.vocabulary_list) + self.num_oov_buckets
3714
3715  @property
3716  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3717                          _FEATURE_COLUMN_DEPRECATION)
3718  def _num_buckets(self):
3719    return self.num_buckets
3720
3721  def get_sparse_tensors(self, transformation_cache, state_manager):
3722    """See `CategoricalColumn` base class."""
3723    return CategoricalColumn.IdWeightPair(
3724        transformation_cache.get(self, state_manager), None)
3725
3726  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3727                          _FEATURE_COLUMN_DEPRECATION)
3728  def _get_sparse_tensors(self, inputs, weight_collections=None,
3729                          trainable=None):
3730    del weight_collections
3731    del trainable
3732    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3733
3734  @property
3735  def parents(self):
3736    """See 'FeatureColumn` base class."""
3737    return [self.key]
3738
3739  def get_config(self):
3740    """See 'FeatureColumn` base class."""
3741    config = dict(zip(self._fields, self))
3742    config['dtype'] = self.dtype.name
3743    return config
3744
3745  @classmethod
3746  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3747    """See 'FeatureColumn` base class."""
3748    _check_config_keys(config, cls._fields)
3749    kwargs = _standardize_and_copy_config(config)
3750    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3751    return cls(**kwargs)
3752
3753
3754class IdentityCategoricalColumn(
3755    CategoricalColumn,
3756    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3757    collections.namedtuple('IdentityCategoricalColumn',
3758                           ('key', 'number_buckets', 'default_value'))):
3760  """See `categorical_column_with_identity`."""
3761
3762  @property
3763  def _is_v2_column(self):
3764    return True
3765
3766  @property
3767  def name(self):
3768    """See `FeatureColumn` base class."""
3769    return self.key
3770
3771  @property
3772  def parse_example_spec(self):
3773    """See `FeatureColumn` base class."""
3774    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
3775
3776  @property
3777  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3778                          _FEATURE_COLUMN_DEPRECATION)
3779  def _parse_example_spec(self):
3780    return self.parse_example_spec
3781
3782  def _transform_input_tensor(self, input_tensor):
3783    """Returns a SparseTensor with identity values."""
3784    if not input_tensor.dtype.is_integer:
3785      raise ValueError(
3786          'Invalid input, not integer. key: {} dtype: {}'.format(
3787              self.key, input_tensor.dtype))
3788    values = input_tensor.values
3789    if values.dtype != dtypes.int64:
3790      values = math_ops.cast(values, dtypes.int64, name='values')
3791    if self.default_value is not None:
3793      num_buckets = math_ops.cast(
3794          self.num_buckets, dtypes.int64, name='num_buckets')
3795      zero = math_ops.cast(0, dtypes.int64, name='zero')
3796      # Assign default for out-of-range values.
3797      values = array_ops.where_v2(
3798          math_ops.logical_or(
3799              values < zero, values >= num_buckets, name='out_of_range'),
3800          array_ops.fill(
3801              dims=array_ops.shape(values),
3802              value=math_ops.cast(self.default_value, dtypes.int64),
3803              name='default_values'), values)
3804
3805    return sparse_tensor_lib.SparseTensor(
3806        indices=input_tensor.indices,
3807        values=values,
3808        dense_shape=input_tensor.dense_shape)
3809
3810  def transform_feature(self, transformation_cache, state_manager):
3811    """Returns a SparseTensor with identity values."""
3812    input_tensor = _to_sparse_input_and_drop_ignore_values(
3813        transformation_cache.get(self.key, state_manager))
3814    return self._transform_input_tensor(input_tensor)
3815
3816  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3817                          _FEATURE_COLUMN_DEPRECATION)
3818  def _transform_feature(self, inputs):
3819    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
3820    return self._transform_input_tensor(input_tensor)
3821
3822  @property
3823  def num_buckets(self):
3824    """Returns number of buckets in this sparse feature."""
3825    return self.number_buckets
3826
3827  @property
3828  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3829                          _FEATURE_COLUMN_DEPRECATION)
3830  def _num_buckets(self):
3831    return self.num_buckets
3832
3833  def get_sparse_tensors(self, transformation_cache, state_manager):
3834    """See `CategoricalColumn` base class."""
3835    return CategoricalColumn.IdWeightPair(
3836        transformation_cache.get(self, state_manager), None)
3837
3838  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3839                          _FEATURE_COLUMN_DEPRECATION)
3840  def _get_sparse_tensors(self, inputs, weight_collections=None,
3841                          trainable=None):
3842    del weight_collections
3843    del trainable
3844    return CategoricalColumn.IdWeightPair(inputs.get(self), None)
3845
3846  @property
3847  def parents(self):
3848    """See 'FeatureColumn` base class."""
3849    return [self.key]
3850
3851  def get_config(self):
3852    """See 'FeatureColumn` base class."""
3853    return dict(zip(self._fields, self))
3854
3855  @classmethod
3856  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3857    """See 'FeatureColumn` base class."""
3858    _check_config_keys(config, cls._fields)
3859    kwargs = _standardize_and_copy_config(config)
3860    return cls(**kwargs)
3861
3862
3863class WeightedCategoricalColumn(
3864    CategoricalColumn,
3865    fc_old._CategoricalColumn,  # pylint: disable=protected-access
3866    collections.namedtuple(
3867        'WeightedCategoricalColumn',
3868        ('categorical_column', 'weight_feature_key', 'dtype'))):
3869  """See `weighted_categorical_column`."""
3870
3871  @property
3872  def _is_v2_column(self):
3873    return (isinstance(self.categorical_column, FeatureColumn) and
3874            self.categorical_column._is_v2_column)  # pylint: disable=protected-access
3875
3876  @property
3877  def name(self):
3878    """See `FeatureColumn` base class."""
3879    return '{}_weighted_by_{}'.format(
3880        self.categorical_column.name, self.weight_feature_key)
3881
3882  @property
3883  def parse_example_spec(self):
3884    """See `FeatureColumn` base class."""
3885    config = self.categorical_column.parse_example_spec
3886    if self.weight_feature_key in config:
3887      raise ValueError('Parse config {} already exists for {}.'.format(
3888          config[self.weight_feature_key], self.weight_feature_key))
3889    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
3890    return config
3891
3892  @property
3893  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3894                          _FEATURE_COLUMN_DEPRECATION)
3895  def _parse_example_spec(self):
3896    config = self.categorical_column._parse_example_spec  # pylint: disable=protected-access
3897    if self.weight_feature_key in config:
3898      raise ValueError('Parse config {} already exists for {}.'.format(
3899          config[self.weight_feature_key], self.weight_feature_key))
3900    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
3901    return config
3902
3903  @property
3904  def num_buckets(self):
3905    """See `DenseColumn` base class."""
3906    return self.categorical_column.num_buckets
3907
3908  @property
3909  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3910                          _FEATURE_COLUMN_DEPRECATION)
3911  def _num_buckets(self):
3912    return self.categorical_column._num_buckets  # pylint: disable=protected-access
3913
3914  def _transform_weight_tensor(self, weight_tensor):
3915    if weight_tensor is None:
3916      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
3917    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
3918        weight_tensor)
3919    if self.dtype != weight_tensor.dtype.base_dtype:
3920      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
3921          self.dtype, weight_tensor.dtype))
3922    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
3923      # The weight tensor can be a regular Tensor. In this case, sparsify it.
3924      weight_tensor = _to_sparse_input_and_drop_ignore_values(
3925          weight_tensor, ignore_value=0.0)
3926    if not weight_tensor.dtype.is_floating:
3927      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
3928    return weight_tensor
3929
3930  def transform_feature(self, transformation_cache, state_manager):
3931    """Applies weights to tensor generated from `categorical_column`'."""
3932    weight_tensor = transformation_cache.get(self.weight_feature_key,
3933                                             state_manager)
3934    sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
3935    sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
3936        transformation_cache.get(self.categorical_column, state_manager))
3937    return (sparse_categorical_tensor, sparse_weight_tensor)
3938
3939  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3940                          _FEATURE_COLUMN_DEPRECATION)
3941  def _transform_feature(self, inputs):
3942    """Applies weights to tensor generated from `categorical_column`'."""
3943    weight_tensor = inputs.get(self.weight_feature_key)
3944    weight_tensor = self._transform_weight_tensor(weight_tensor)
3945    return (inputs.get(self.categorical_column), weight_tensor)
3946
3947  def get_sparse_tensors(self, transformation_cache, state_manager):
3948    """See `CategoricalColumn` base class."""
3949    tensors = transformation_cache.get(self, state_manager)
3950    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
3951
3952  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
3953                          _FEATURE_COLUMN_DEPRECATION)
3954  def _get_sparse_tensors(self, inputs, weight_collections=None,
3955                          trainable=None):
3956    del weight_collections
3957    del trainable
3958    tensors = inputs.get(self)
3959    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
3960
3961  @property
3962  def parents(self):
3963    """See 'FeatureColumn` base class."""
3964    return [self.categorical_column, self.weight_feature_key]
3965
3966  def get_config(self):
3967    """See 'FeatureColumn` base class."""
3968    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
3969    config = dict(zip(self._fields, self))
3970    config['categorical_column'] = serialize_feature_column(
3971        self.categorical_column)
3972    config['dtype'] = self.dtype.name
3973    return config
3974
3975  @classmethod
3976  def from_config(cls, config, custom_objects=None, columns_by_name=None):
3977    """See 'FeatureColumn` base class."""
3978    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
3979    _check_config_keys(config, cls._fields)
3980    kwargs = _standardize_and_copy_config(config)
3981    kwargs['categorical_column'] = deserialize_feature_column(
3982        config['categorical_column'], custom_objects, columns_by_name)
3983    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
3984    return cls(**kwargs)
3985
3986
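# Illustrative usage sketch for `WeightedCategoricalColumn` (hypothetical
# feature names), using the public `tf.feature_column` helpers:
#
#   colors = tf.feature_column.categorical_column_with_vocabulary_list(
#       'colors', ['R', 'G', 'B'])
#   weighted = tf.feature_column.weighted_categorical_column(
#       categorical_column=colors, weight_feature_key='color_weights')
#   # parse_example_spec now contains a VarLenFeature for 'color_weights' in
#   # addition to the spec of `colors`, and get_sparse_tensors() returns an
#   # IdWeightPair whose weight_tensor is populated from that feature.

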
class CrossedColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('CrossedColumn',
                           ('keys', 'hash_bucket_size', 'hash_key'))):
  """See `crossed_column`."""

  @property
  def _is_v2_column(self):
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        continue
      if not isinstance(key, FeatureColumn):
        return False
      if not key._is_v2_column:  # pylint: disable=protected-access
        return False
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    feature_names = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)):  # pylint: disable=protected-access
        feature_names.append(key.name)
      else:  # key must be a string
        feature_names.append(key)
    return '_X_'.join(sorted(feature_names))

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = {}
    for key in self.keys:
      if isinstance(key, FeatureColumn):
        config.update(key.parse_example_spec)
      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
        config.update(key._parse_example_spec)  # pylint: disable=protected-access
      else:  # key must be a string
        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def transform_feature(self, transformation_cache, state_manager):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(transformation_cache.get(key, state_manager))
      elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key.get_sparse_tensors(transformation_cache,
                                                 state_manager)
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(inputs.get(key))
      elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @property
  def num_buckets(self):
    """Returns the number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """See `CategoricalColumn` base class."""
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return list(self.keys)

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys])
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['keys'] = tuple([
        deserialize_feature_column(c, custom_objects, columns_by_name)
        for c in config['keys']
    ])
    return cls(**kwargs)


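# Illustrative sketch of the hashed-cross op that `transform_feature` defers
# to above (`sparse_ops.sparse_cross_hashed`, exposed publicly as
# `tf.sparse.cross_hashed`); the inputs are hypothetical:
#
#   a = tf.sparse.from_dense([['a1', 'a2']])
#   b = tf.sparse.from_dense([['b1']])
#   crossed = tf.sparse.cross_hashed([a, b], num_buckets=100)
#   # -> a SparseTensor whose row 0 holds two int64 ids in [0, 100), one for
#   #    each combination ('a1' x 'b1', 'a2' x 'b1').

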
def _collect_leaf_level_keys(cross):
  """Collects base keys by expanding all nested crosses.

  Args:
    cross: A `CrossedColumn`.

  Returns:
    A list of strings or `CategoricalColumn` instances.
  """
  leaf_level_keys = []
  for k in cross.keys:
    if isinstance(k, CrossedColumn):
      leaf_level_keys.extend(_collect_leaf_level_keys(k))
    else:
      leaf_level_keys.append(k)
  return leaf_level_keys


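# For instance (hypothetical keys), nested crosses are flattened recursively:
#
#   inner = tf.feature_column.crossed_column(['a', 'b'], hash_bucket_size=10)
#   outer = tf.feature_column.crossed_column([inner, 'c'], hash_bucket_size=10)
#   _collect_leaf_level_keys(outer)  # -> ['a', 'b', 'c']

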
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
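    # Note: the `ones_like` mask below is all True, so the `logical_and` does
    # not filter anything by itself; in effect it only checks that the weights
    # carry the same number of values as the ids.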
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (<= 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights


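# Worked illustration of the two pruning helpers (made-up values):
#
#   Given sparse values ids=[3, -1, 5] with weights=[1.0, 2.0, 0.0],
#   _prune_invalid_ids drops the entry whose id is negative:
#       ids=[3, 5], weights=[1.0, 0.0]
#   and _prune_invalid_weights applied to that result drops the entry whose
#   weight is not > 0:
#       ids=[3], weights=[1.0]

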
class IndicatorColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple('IndicatorColumn', ('categorical_column',))):
  """Represents a one-hot column for use in deep networks.

  Args:
    categorical_column: A `CategoricalColumn` which is created by a
      `categorical_column_with_*` function.
  """

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_indicator'.format(self.categorical_column.name)

  def _transform_id_weight_pair(self, id_weight_pair, size):
    id_tensor = id_weight_pair.id_tensor
    weight_tensor = id_weight_pair.weight_tensor

    # If the underlying column is weighted, return the input as a dense tensor.
    if weight_tensor is not None:
      weighted_column = sparse_ops.sparse_merge(
          sp_ids=id_tensor, sp_values=weight_tensor, vocab_size=int(size))
      # Remove (?, -1) index.
      weighted_column = sparse_ops.sparse_slice(weighted_column, [0, 0],
                                                weighted_column.dense_shape)
      # Use scatter_nd to merge duplicated indices if they exist, instead of
      # sparse_tensor_to_dense.
      return array_ops.scatter_nd(weighted_column.indices,
                                  weighted_column.values,
                                  weighted_column.dense_shape)

    dense_id_tensor = sparse_ops.sparse_tensor_to_dense(
        id_tensor, default_value=-1)

    # The one-hot tensor must be float for tf.concat reasons, since all other
    # inputs to input_layer are float32.
    one_hot_id_tensor = array_ops.one_hot(
        dense_id_tensor, depth=size, on_value=1.0, off_value=0.0)

    # Reduce to get a multi-hot per example.
    return math_ops.reduce_sum(one_hot_id_tensor, axis=[-2])

  def transform_feature(self, transformation_cache, state_manager):
    """Returns a dense `Tensor` representing the feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Transformed feature `Tensor`.

    Raises:
      ValueError: if input rank is not known at graph building time.
    """
    id_weight_pair = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._transform_id_weight_pair(id_weight_pair,
                                          self.variable_shape[-1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    id_weight_pair = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._transform_id_weight_pair(id_weight_pair,
                                          self._variable_shape[-1])

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  @property
  def variable_shape(self):
    """Returns a `TensorShape` representing the shape of the dense `Tensor`."""
    if isinstance(self.categorical_column, FeatureColumn):
      return tensor_shape.TensorShape([1, self.categorical_column.num_buckets])
    else:
      return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return tensor_shape.TensorShape([1, self.categorical_column._num_buckets])  # pylint: disable=protected-access

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns a dense `Tensor` representing the feature.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Dense `Tensor` created within `transform_feature`.

    Raises:
      ValueError: If `categorical_column` is a `SequenceCategoricalColumn`.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by transform_feature.
    return transformation_cache.get(self, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    del weight_collections
    del trainable
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    return inputs.get(self)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by transform_feature.
    dense_tensor = transformation_cache.get(self, state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # Do nothing with weight_collections and trainable since no variables are
    # created in this function.
    del weight_collections
    del trainable
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    dense_tensor = inputs.get(self)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)


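# Illustrative usage sketch for `IndicatorColumn` (hypothetical feature name):
#
#   dept = tf.feature_column.categorical_column_with_hash_bucket(
#       'dept', hash_bucket_size=100)
#   dept_indicator = tf.feature_column.indicator_column(dept)
#   # A DenseFeatures layer then yields a float32 multi-hot tensor of shape
#   # [batch_size, 100] for this column.
#   dense_layer = tf.keras.layers.DenseFeatures([dept_indicator])

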
def _verify_static_batch_size_equality(tensors, columns):
  """Verify equality between static batch sizes.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # batch_size is a Dimension object.
  expected_batch_size = None
  for i in range(len(tensors)):
    batch_size = tensor_shape.Dimension(tensor_shape.dimension_value(
        tensors[i].shape[0]))
    if batch_size.value is not None:
      if expected_batch_size is None:
        batch_size_column_index = i
        expected_batch_size = batch_size
      elif not expected_batch_size.is_compatible_with(batch_size):
        raise ValueError(
            'Batch size (first dimension) of each feature must be the same. '
            'Batch size of columns ({}, {}): ({}, {})'.format(
                columns[batch_size_column_index].name, columns[i].name,
                expected_batch_size, batch_size))


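# For instance (hypothetical shapes): if the tensors for columns 'a' and 'b'
# have static shapes [4, 1] and [8, 1], the check above raises the ValueError;
# tensors whose first dimension is statically unknown (None) are skipped.

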
class SequenceCategoricalColumn(
    CategoricalColumn,
    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('SequenceCategoricalColumn',
                           ('categorical_column',))):
  """Represents sequences of categorical data."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.name

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return self.categorical_column.transform_feature(transformation_cache,
                                                     state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns the number of buckets in this sparse feature."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _get_sparse_tensors_helper(self, sparse_tensors):
    id_tensor = sparse_tensors.id_tensor
    weight_tensor = sparse_tensors.weight_tensor
    # Expands the third dimension, if necessary, so that embeddings are not
    # combined during the embedding lookup. If the tensor is already 3D, it is
    # left as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # -1 doesn't work for dynamically shaped tensors with a 0-length dimension
    # at runtime, which happens for empty sequences.
    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
    if weight_tensor is not None:
      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an IdWeightPair.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate all weights should be
    taken to be 1. If specified, `weight_tensor` must have exactly the same
    shape and indices as `id_tensor`. The expected `SparseTensor` has the same
    format as the parsing output of a `VarLenFeature`, i.e. a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.
    """
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_sparse_tensors_helper(sparse_tensors)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._get_sparse_tensors_helper(sparse_tensors)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)


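# Illustrative usage sketch for `SequenceCategoricalColumn` (hypothetical
# feature name), which is normally built via the sequence helpers and consumed
# through a SequenceFeatures layer:
#
#   watches = tf.feature_column.sequence_categorical_column_with_identity(
#       'watched_video_ids', num_buckets=1000)
#   watches_embedding = tf.feature_column.embedding_column(watches, dimension=8)
#   sequence_input_layer = tf.keras.experimental.SequenceFeatures(
#       [watches_embedding])
#   # Calling the layer on a features dict returns (sequence_input,
#   # sequence_length) for downstream RNN-style models.

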
def _check_config_keys(config, expected_keys):
  """Checks that a config has all expected_keys."""
  if set(config.keys()) != set(expected_keys):
    raise ValueError('Invalid config: {}, expected keys: {}'.format(
        config, expected_keys))


def _standardize_and_copy_config(config):
  """Returns a shallow copy of config with lists turned to tuples.

  Keras serialization uses nest to listify everything.
  This causes problems with the NumericColumn shape, which becomes
  unhashable. We could try to solve this on the Keras side, but that
  would require lots of tracking to avoid changing existing behavior.
  Instead, we ensure here that we revive correctly.

  Args:
    config: dict that will be used to revive a `FeatureColumn`.

  Returns:
    Shallow copy of config with lists turned to tuples.
  """
  kwargs = config.copy()
  for k, v in kwargs.items():
    if isinstance(v, list):
      kwargs[k] = tuple(v)

  return kwargs


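# For instance:
#
#   _standardize_and_copy_config({'shape': [10, 5], 'dtype': 'float32'})
#   # -> {'shape': (10, 5), 'dtype': 'float32'}

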
def _sanitize_column_name_for_variable_scope(name):
  """Sanitizes user-provided feature names for use as variable scopes."""
  invalid_char = re.compile(r'[^A-Za-z0-9_.\-]')
  return invalid_char.sub('_', name)

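# For instance, invalid characters are replaced with underscores:
#
#   _sanitize_column_name_for_variable_scope('user:age bucket')
#   # -> 'user_age_bucket'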